Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/ert/dark_storage/endpoints/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ def get_parameter_std_dev(
return Response(content=buffer.getvalue(), media_type="application/octet-stream")


@router.get("/ensembles/{ensemble_id}/parameters/{key}/mean")
def get_parameter_mean(
*, storage: Storage = DEFAULT_STORAGE, ensemble_id: UUID, key: str, z: int
) -> Response:
key = unquote(key)
with reraise_as_http_errors(logger):
ensemble = storage.get_ensemble(ensemble_id)
da = ensemble.calculate_mean_for_parameter_group(key)

if z >= int(da.shape[2]):
logger.error("invalid z index")
raise HTTPException(status_code=500, detail="Internal server error")

data_2d = da[:, :, z]

buffer = io.BytesIO()
np.save(buffer, data_2d)

return Response(content=buffer.getvalue(), media_type="application/octet-stream")


def data_for_parameter(ensemble: Ensemble, key: str) -> pd.DataFrame:
param_info = ensemble.experiment.parameter_info.get(key)

Expand Down
20 changes: 20 additions & 0 deletions src/ert/gui/tools/plot/plot_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,23 @@ def std_dev_for_parameter(
return np.load(io.BytesIO(http_response.content))
else:
return np.array([])

def mean_for_parameter(
self, key: str, ensemble_id: str, z: int
) -> npt.NDArray[np.float32]:
ensemble = self._get_ensemble_by_id(ensemble_id)
if not ensemble:
return np.array([])

with create_ertserver_client(self.ens_path) as client:
http_response = client.get(
f"/ensembles/{ensemble.id}/parameters/{PlotApi.escape(key)}/mean",
params={"z": z},
timeout=self._timeout,
)

if http_response.status_code == 200:
# Deserialize the numpy array
return np.load(io.BytesIO(http_response.content))
else:
return np.array([])
10 changes: 5 additions & 5 deletions src/ert/gui/tools/plot/plot_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
NavigationToolbar2QT,
)
from matplotlib.figure import Figure
from PyQt6.QtCore import QStringListModel, Qt, pyqtBoundSignal
from PyQt6.QtCore import QStringListModel, Qt
from PyQt6.QtCore import pyqtSignal as Signal
from PyQt6.QtCore import pyqtSlot as Slot
from PyQt6.QtGui import QAction
Expand All @@ -37,6 +37,7 @@
from .plottery.plots.ensemble import EnsemblePlot
from .plottery.plots.gaussian_kde import GaussianKDEPlot
from .plottery.plots.histogram import HistogramPlot
from .plottery.plots.mean import MeanPlot
from .plottery.plots.misfits import MisfitsPlot
from .plottery.plots.statistics import StatisticsPlot
from .plottery.plots.std_dev import StdDevPlot
Expand Down Expand Up @@ -114,6 +115,7 @@ def updateLayerWidget(self, layers: int) -> None:
class PlotWidget(QWidget):
customizationTriggered = Signal()
layerIndexChanged = Signal(int)
plotUpdateRequested = Signal()
updateLayerWidget = Signal(int)
resetLayerWidget = Signal()
showLayerWidget = Signal(bool)
Expand All @@ -129,6 +131,7 @@ def __init__(
"DistributionPlot",
"CrossEnsembleStatisticsPlot",
"StdDevPlot",
"MeanPlot",
"ValuesOverIterationsPlot",
"MisfitsPlot",
"EverestGradientsPlot",
Expand Down Expand Up @@ -162,6 +165,7 @@ def __init__(
self._log_checkbox.setVisible(False)
self._log_checkbox.setToolTip("Toggle data domain to log scale and back")
self._log_checkbox.clicked.connect(self.logLogScaleButtonUsage)
self._log_checkbox.toggled.connect(lambda _: self.plotUpdateRequested.emit())

log_checkbox_row = QHBoxLayout()
log_checkbox_row.addWidget(self._log_checkbox)
Expand All @@ -177,10 +181,6 @@ def __init__(
self._active = False
self.resetPlot()

@property
def plotUpdateRequested(self) -> pyqtBoundSignal:
return self._log_checkbox.toggled

def resetPlot(self) -> None:
self._figure.clear()

Expand Down
16 changes: 13 additions & 3 deletions src/ert/gui/tools/plot/plot_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
EverestGradientsPlot,
GaussianKDEPlot,
HistogramPlot,
MeanPlot,
MisfitsPlot,
StatisticsPlot,
StdDevPlot,
Expand All @@ -57,6 +58,7 @@
HISTOGRAM = "Histogram"
STATISTICS = "Statistics"
STD_DEV = "Std Dev"
MEAN = "Mean"
MISFITS = "Misfits"
EVEREST_RESPONSES_PLOT = "Batch responses"
EVEREST_CONTROLS_PLOT = "Batch controls"
Expand All @@ -65,6 +67,7 @@
RESPONSE_DEFAULT = 0
GEN_KW_DEFAULT = 3
STD_DEV_DEFAULT = 7
MEAN_DEFAULT = 8


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -214,6 +217,7 @@ def __init__(
CROSS_ENSEMBLE_STATISTICS, CrossEnsembleStatisticsPlot()
)
self.addPlotWidget(STD_DEV, StdDevPlot())
self.addPlotWidget(MEAN, MeanPlot())
else:
self.addPlotWidget(ENSEMBLE, EnsemblePlot())
self.addPlotWidget(EVEREST_CONTROLS_PLOT, ValuesOverIterationsPlot())
Expand Down Expand Up @@ -411,9 +415,14 @@ def fetch_data(

for ensemble in selected_ensembles:
try:
std_dev_images[ensemble.name] = self._api.std_dev_for_parameter(
key, ensemble.id, layer
)
if isinstance(plot_widget._plotter, MeanPlot):
std_dev_images[ensemble.name] = (
self._api.mean_for_parameter(key, ensemble.id, layer)
)
else:
std_dev_images[ensemble.name] = (
self._api.std_dev_for_parameter(key, ensemble.id, layer)
)
except BaseException as e:
handle_exception(e)
else:
Expand Down Expand Up @@ -507,6 +516,7 @@ def addPlotWidget(
| DistributionPlot
| CrossEnsembleStatisticsPlot
| StdDevPlot
| MeanPlot
| MisfitsPlot
| ValuesOverIterationsPlot
| EverestGradientsPlot,
Expand Down
2 changes: 2 additions & 0 deletions src/ert/gui/tools/plot/plottery/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .everest_gradients_plot import EverestGradientsPlot
from .gaussian_kde import GaussianKDEPlot
from .histogram import HistogramPlot
from .mean import MeanPlot
from .misfits import MisfitsPlot
from .statistics import StatisticsPlot
from .std_dev import StdDevPlot
Expand All @@ -16,6 +17,7 @@
"EverestGradientsPlot",
"GaussianKDEPlot",
"HistogramPlot",
"MeanPlot",
"MisfitsPlot",
"StatisticsPlot",
"StdDevPlot",
Expand Down
125 changes: 125 additions & 0 deletions src/ert/gui/tools/plot/plottery/plots/mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable

if TYPE_CHECKING:
from ert.gui.tools.plot.plot_api import EnsembleObject, PlotApiKeyDefinition
from ert.gui.tools.plot.plottery import PlotContext


class MeanPlot:
def __init__(self) -> None:
self.dimensionality = 3
self.requires_observations = False

def plot(
self,
figure: Figure,
plot_context: PlotContext,
ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame],
observation_data: pd.DataFrame,
mean_data: dict[str, npt.NDArray[np.float32]],
key_def: PlotApiKeyDefinition | None = None,
) -> None:
ensemble_count = len(plot_context.ensembles())
layer = plot_context.layer
if layer is not None:
vmin: float = np.inf
vmax: float = -np.inf
heatmaps = []
boxplot_axes = []

figure.set_layout_engine("constrained")
gridspec = figure.add_gridspec(2, ensemble_count, hspace=0.2)

for i, ensemble in enumerate(reversed(plot_context.ensembles()), start=1):
ax_heat = figure.add_subplot(gridspec[0, i - 1])
ax_box = figure.add_subplot(gridspec[1, i - 1])
data = mean_data[ensemble.name]
if data.size == 0:
ax_heat.set_axis_off()
ax_box.set_axis_off()
ax_heat.text(
0.5,
0.5,
f"No data for {ensemble.experiment_name} : {ensemble.name}",
ha="center",
va="center",
)
else:
vmin = min(vmin, float(np.min(data)))
vmax = max(vmax, float(np.max(data)))

im = ax_heat.imshow(data, cmap="plasma", aspect="equal")
heatmaps.append(im)

ax_box.boxplot(data.flatten(), orientation="vertical", widths=0.5)
boxplot_axes.append(ax_box)

min_value = np.min(data)
mean_value = np.mean(data)
max_value = np.max(data)

ax_box.annotate(
f"Min: {min_value:.2f}\nMean: {mean_value:.2f}\n"
f"Max: {max_value:.2f}",
xy=(1, 1),
xycoords="axes fraction",
ha="right",
va="top",
fontsize=8,
fontweight="bold",
bbox={
"facecolor": "white",
"edgecolor": "black",
"boxstyle": "round,pad=0.2",
},
)

ax_box.spines["top"].set_visible(False)
ax_box.spines["right"].set_visible(False)
ax_box.spines["bottom"].set_visible(False)
ax_box.spines["left"].set_visible(True)

ax_box.set_xticks([])
ax_box.set_xticklabels([])

ax_heat.set_ylabel("")
ax_box.set_ylabel("Mean", fontsize=8)

self._colorbar(im)

ax_heat.set_title(
f"{ensemble.experiment_name} : {ensemble.name} layer={layer}",
wrap=True,
fontsize=10,
)

norm = plt.Normalize(vmin, vmax)
for im in heatmaps:
im.set_norm(norm)

padding = 0.05 * (vmax - vmin)
if padding > 0.0:
for ax_box in boxplot_axes:
ax_box.set_ylim(vmin - padding, vmax + padding)

@staticmethod
def _colorbar(mappable: Any) -> Any:
last_axes = plt.gca()
ax = mappable.axes
assert ax is not None
fig = ax.figure
assert fig is not None
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = fig.colorbar(mappable, cax=cax)
plt.sca(last_axes)
return cbar
19 changes: 18 additions & 1 deletion src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ParameterConfig,
SummaryConfig,
)
from ert.config.field import Field, field_transform
from ert.exceptions import StorageError
from ert.substitutions import substitute_runpath_name

Expand Down Expand Up @@ -892,7 +893,23 @@
data = self.load_parameters(parameter_group)
if isinstance(data, pl.DataFrame):
return data.drop("realization").std().to_numpy().reshape(-1)
return data.std("realizations")["values"].to_numpy()
param_config = self.experiment.parameter_configuration.get(parameter_group)
values = data["values"]
if isinstance(param_config, Field) and param_config.output_transformation:
values = field_transform(values, param_config.output_transformation)

Check failure on line 899 in src/ert/storage/local_ensemble.py

View workflow job for this annotation

GitHub Actions / check-typing / type-checking (3.12)

Incompatible types in assignment (expression has type "ndarray[tuple[Any, ...], dtype[floating[_32Bit]]] | DataArray", variable has type "DataArray")
return values.std("realizations").to_numpy()

def calculate_mean_for_parameter_group(
self, parameter_group: str
) -> npt.NDArray[np.float64]:
data = self.load_parameters(parameter_group)
if isinstance(data, pl.DataFrame):
return data.drop("realization", strict=False).mean().to_numpy().reshape(-1)
param_config = self.experiment.parameter_configuration.get(parameter_group)
values = data["values"]
if isinstance(param_config, Field) and param_config.output_transformation:
values = field_transform(values, param_config.output_transformation)

Check failure on line 911 in src/ert/storage/local_ensemble.py

View workflow job for this annotation

GitHub Actions / check-typing / type-checking (3.12)

Incompatible types in assignment (expression has type "ndarray[tuple[Any, ...], dtype[floating[_32Bit]]] | DataArray", variable has type "DataArray")
return values.mean("realizations").to_numpy()

def get_parameter_state(
self, realization: int
Expand Down
2 changes: 1 addition & 1 deletion test-data/ert/heat_equation/config.ert
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ GRID CASE.EGRID
OBS_CONFIG observations_loc.txt
ECLBASE HEAT

FIELD COND PARAMETER cond.bgrdecl INIT_FILES:cond.bgrdecl FORWARD_INIT:True
FIELD COND PARAMETER cond.bgrdecl INIT_FILES:cond.bgrdecl FORWARD_INIT:True INIT_TRANSFORM:LOG OUTPUT_TRANSFORM:EXP

GEN_KW INIT_TEMP_SCALE init_temp_prior.txt
GEN_KW CORR_LENGTH corr_length_prior.txt
Expand Down
Loading