diff --git a/src/ert/dark_storage/endpoints/parameters.py b/src/ert/dark_storage/endpoints/parameters.py index 61314f56e15..56324277ff2 100644 --- a/src/ert/dark_storage/endpoints/parameters.py +++ b/src/ert/dark_storage/endpoints/parameters.py @@ -77,6 +77,27 @@ def get_parameter_std_dev( return Response(content=buffer.getvalue(), media_type="application/octet-stream") +@router.get("/ensembles/{ensemble_id}/parameters/{key}/mean") +def get_parameter_mean( + *, storage: Storage = DEFAULT_STORAGE, ensemble_id: UUID, key: str, z: int +) -> Response: + key = unquote(key) + with reraise_as_http_errors(logger): + ensemble = storage.get_ensemble(ensemble_id) + da = ensemble.calculate_mean_for_parameter_group(key) + + if z >= int(da.shape[2]): + logger.error("invalid z index") + raise HTTPException(status_code=500, detail="Internal server error") + + data_2d = da[:, :, z] + + buffer = io.BytesIO() + np.save(buffer, data_2d) + + return Response(content=buffer.getvalue(), media_type="application/octet-stream") + + def data_for_parameter(ensemble: Ensemble, key: str) -> pd.DataFrame: param_info = ensemble.experiment.parameter_info.get(key) diff --git a/src/ert/gui/tools/plot/plot_api.py b/src/ert/gui/tools/plot/plot_api.py index b6e9b6da35c..8749791819f 100644 --- a/src/ert/gui/tools/plot/plot_api.py +++ b/src/ert/gui/tools/plot/plot_api.py @@ -488,3 +488,23 @@ def std_dev_for_parameter( return np.load(io.BytesIO(http_response.content)) else: return np.array([]) + + def mean_for_parameter( + self, key: str, ensemble_id: str, z: int + ) -> npt.NDArray[np.float32]: + ensemble = self._get_ensemble_by_id(ensemble_id) + if not ensemble: + return np.array([]) + + with create_ertserver_client(self.ens_path) as client: + http_response = client.get( + f"/ensembles/{ensemble.id}/parameters/{PlotApi.escape(key)}/mean", + params={"z": z}, + timeout=self._timeout, + ) + + if http_response.status_code == 200: + # Deserialize the numpy array + return np.load(io.BytesIO(http_response.content)) + else: + return np.array([]) diff --git a/src/ert/gui/tools/plot/plot_widget.py b/src/ert/gui/tools/plot/plot_widget.py index 44b96050e11..efdb76bf4bb 100644 --- a/src/ert/gui/tools/plot/plot_widget.py +++ b/src/ert/gui/tools/plot/plot_widget.py @@ -11,7 +11,7 @@ NavigationToolbar2QT, ) from matplotlib.figure import Figure -from PyQt6.QtCore import QStringListModel, Qt, pyqtBoundSignal +from PyQt6.QtCore import QStringListModel, Qt from PyQt6.QtCore import pyqtSignal as Signal from PyQt6.QtCore import pyqtSlot as Slot from PyQt6.QtGui import QAction @@ -37,6 +37,7 @@ from .plottery.plots.ensemble import EnsemblePlot from .plottery.plots.gaussian_kde import GaussianKDEPlot from .plottery.plots.histogram import HistogramPlot + from .plottery.plots.mean import MeanPlot from .plottery.plots.misfits import MisfitsPlot from .plottery.plots.statistics import StatisticsPlot from .plottery.plots.std_dev import StdDevPlot @@ -114,6 +115,7 @@ def updateLayerWidget(self, layers: int) -> None: class PlotWidget(QWidget): customizationTriggered = Signal() layerIndexChanged = Signal(int) + plotUpdateRequested = Signal() updateLayerWidget = Signal(int) resetLayerWidget = Signal() showLayerWidget = Signal(bool) @@ -129,6 +131,7 @@ def __init__( "DistributionPlot", "CrossEnsembleStatisticsPlot", "StdDevPlot", + "MeanPlot", "ValuesOverIterationsPlot", "MisfitsPlot", "EverestGradientsPlot", @@ -162,6 +165,7 @@ def __init__( self._log_checkbox.setVisible(False) self._log_checkbox.setToolTip("Toggle data domain to log scale and back") self._log_checkbox.clicked.connect(self.logLogScaleButtonUsage) + self._log_checkbox.toggled.connect(lambda _: self.plotUpdateRequested.emit()) log_checkbox_row = QHBoxLayout() log_checkbox_row.addWidget(self._log_checkbox) @@ -177,10 +181,6 @@ def __init__( self._active = False self.resetPlot() - @property - def plotUpdateRequested(self) -> pyqtBoundSignal: - return self._log_checkbox.toggled - def resetPlot(self) -> None: self._figure.clear() diff --git a/src/ert/gui/tools/plot/plot_window.py b/src/ert/gui/tools/plot/plot_window.py index f8d4c5a7065..519833465d7 100644 --- a/src/ert/gui/tools/plot/plot_window.py +++ b/src/ert/gui/tools/plot/plot_window.py @@ -43,6 +43,7 @@ EverestGradientsPlot, GaussianKDEPlot, HistogramPlot, + MeanPlot, MisfitsPlot, StatisticsPlot, StdDevPlot, @@ -57,6 +58,7 @@ HISTOGRAM = "Histogram" STATISTICS = "Statistics" STD_DEV = "Std Dev" +MEAN = "Mean" MISFITS = "Misfits" EVEREST_RESPONSES_PLOT = "Batch responses" EVEREST_CONTROLS_PLOT = "Batch controls" @@ -65,6 +67,7 @@ RESPONSE_DEFAULT = 0 GEN_KW_DEFAULT = 3 STD_DEV_DEFAULT = 7 +MEAN_DEFAULT = 8 logger = logging.getLogger(__name__) @@ -214,6 +217,7 @@ def __init__( CROSS_ENSEMBLE_STATISTICS, CrossEnsembleStatisticsPlot() ) self.addPlotWidget(STD_DEV, StdDevPlot()) + self.addPlotWidget(MEAN, MeanPlot()) else: self.addPlotWidget(ENSEMBLE, EnsemblePlot()) self.addPlotWidget(EVEREST_CONTROLS_PLOT, ValuesOverIterationsPlot()) @@ -411,9 +415,14 @@ def fetch_data( for ensemble in selected_ensembles: try: - std_dev_images[ensemble.name] = self._api.std_dev_for_parameter( - key, ensemble.id, layer - ) + if isinstance(plot_widget._plotter, MeanPlot): + std_dev_images[ensemble.name] = ( + self._api.mean_for_parameter(key, ensemble.id, layer) + ) + else: + std_dev_images[ensemble.name] = ( + self._api.std_dev_for_parameter(key, ensemble.id, layer) + ) except BaseException as e: handle_exception(e) else: @@ -507,6 +516,7 @@ def addPlotWidget( | DistributionPlot | CrossEnsembleStatisticsPlot | StdDevPlot + | MeanPlot | MisfitsPlot | ValuesOverIterationsPlot | EverestGradientsPlot, diff --git a/src/ert/gui/tools/plot/plottery/plots/__init__.py b/src/ert/gui/tools/plot/plottery/plots/__init__.py index e04dfdec238..f94dbbfec77 100644 --- a/src/ert/gui/tools/plot/plottery/plots/__init__.py +++ b/src/ert/gui/tools/plot/plottery/plots/__init__.py @@ -4,6 +4,7 @@ from .everest_gradients_plot import EverestGradientsPlot from .gaussian_kde import GaussianKDEPlot from .histogram import HistogramPlot +from .mean import MeanPlot from .misfits import MisfitsPlot from .statistics import StatisticsPlot from .std_dev import StdDevPlot @@ -16,6 +17,7 @@ "EverestGradientsPlot", "GaussianKDEPlot", "HistogramPlot", + "MeanPlot", "MisfitsPlot", "StatisticsPlot", "StdDevPlot", diff --git a/src/ert/gui/tools/plot/plottery/plots/mean.py b/src/ert/gui/tools/plot/plottery/plots/mean.py new file mode 100644 index 00000000000..5a0290030d4 --- /dev/null +++ b/src/ert/gui/tools/plot/plottery/plots/mean.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import matplotlib.pyplot as plt +import numpy as np +import numpy.typing as npt +import pandas as pd +from matplotlib.figure import Figure +from mpl_toolkits.axes_grid1 import make_axes_locatable + +if TYPE_CHECKING: + from ert.gui.tools.plot.plot_api import EnsembleObject, PlotApiKeyDefinition + from ert.gui.tools.plot.plottery import PlotContext + + +class MeanPlot: + def __init__(self) -> None: + self.dimensionality = 3 + self.requires_observations = False + + def plot( + self, + figure: Figure, + plot_context: PlotContext, + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], + observation_data: pd.DataFrame, + mean_data: dict[str, npt.NDArray[np.float32]], + key_def: PlotApiKeyDefinition | None = None, + ) -> None: + ensemble_count = len(plot_context.ensembles()) + layer = plot_context.layer + if layer is not None: + vmin: float = np.inf + vmax: float = -np.inf + heatmaps = [] + boxplot_axes = [] + + figure.set_layout_engine("constrained") + gridspec = figure.add_gridspec(2, ensemble_count, hspace=0.2) + + for i, ensemble in enumerate(reversed(plot_context.ensembles()), start=1): + ax_heat = figure.add_subplot(gridspec[0, i - 1]) + ax_box = figure.add_subplot(gridspec[1, i - 1]) + data = mean_data[ensemble.name] + if data.size == 0: + ax_heat.set_axis_off() + ax_box.set_axis_off() + ax_heat.text( + 0.5, + 0.5, + f"No data for {ensemble.experiment_name} : {ensemble.name}", + ha="center", + va="center", + ) + else: + vmin = min(vmin, float(np.min(data))) + vmax = max(vmax, float(np.max(data))) + + im = ax_heat.imshow(data, cmap="plasma", aspect="equal") + heatmaps.append(im) + + ax_box.boxplot(data.flatten(), orientation="vertical", widths=0.5) + boxplot_axes.append(ax_box) + + min_value = np.min(data) + mean_value = np.mean(data) + max_value = np.max(data) + + ax_box.annotate( + f"Min: {min_value:.2f}\nMean: {mean_value:.2f}\n" + f"Max: {max_value:.2f}", + xy=(1, 1), + xycoords="axes fraction", + ha="right", + va="top", + fontsize=8, + fontweight="bold", + bbox={ + "facecolor": "white", + "edgecolor": "black", + "boxstyle": "round,pad=0.2", + }, + ) + + ax_box.spines["top"].set_visible(False) + ax_box.spines["right"].set_visible(False) + ax_box.spines["bottom"].set_visible(False) + ax_box.spines["left"].set_visible(True) + + ax_box.set_xticks([]) + ax_box.set_xticklabels([]) + + ax_heat.set_ylabel("") + ax_box.set_ylabel("Mean", fontsize=8) + + self._colorbar(im) + + ax_heat.set_title( + f"{ensemble.experiment_name} : {ensemble.name} layer={layer}", + wrap=True, + fontsize=10, + ) + + norm = plt.Normalize(vmin, vmax) + for im in heatmaps: + im.set_norm(norm) + + padding = 0.05 * (vmax - vmin) + if padding > 0.0: + for ax_box in boxplot_axes: + ax_box.set_ylim(vmin - padding, vmax + padding) + + @staticmethod + def _colorbar(mappable: Any) -> Any: + last_axes = plt.gca() + ax = mappable.axes + assert ax is not None + fig = ax.figure + assert fig is not None + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + cbar = fig.colorbar(mappable, cax=cax) + plt.sca(last_axes) + return cbar diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index a37bf57b3dd..9118f04f4c1 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -27,6 +27,7 @@ ParameterConfig, SummaryConfig, ) +from ert.config.field import Field, field_transform from ert.exceptions import StorageError from ert.substitutions import substitute_runpath_name @@ -892,7 +893,23 @@ def calculate_std_dev_for_parameter_group( data = self.load_parameters(parameter_group) if isinstance(data, pl.DataFrame): return data.drop("realization").std().to_numpy().reshape(-1) - return data.std("realizations")["values"].to_numpy() + param_config = self.experiment.parameter_configuration.get(parameter_group) + values = data["values"] + if isinstance(param_config, Field) and param_config.output_transformation: + values = field_transform(values, param_config.output_transformation) + return values.std("realizations").to_numpy() + + def calculate_mean_for_parameter_group( + self, parameter_group: str + ) -> npt.NDArray[np.float64]: + data = self.load_parameters(parameter_group) + if isinstance(data, pl.DataFrame): + return data.drop("realization", strict=False).mean().to_numpy().reshape(-1) + param_config = self.experiment.parameter_configuration.get(parameter_group) + values = data["values"] + if isinstance(param_config, Field) and param_config.output_transformation: + values = field_transform(values, param_config.output_transformation) + return values.mean("realizations").to_numpy() def get_parameter_state( self, realization: int diff --git a/test-data/ert/heat_equation/config.ert b/test-data/ert/heat_equation/config.ert index fad78b07b6f..455aa5d7b25 100644 --- a/test-data/ert/heat_equation/config.ert +++ b/test-data/ert/heat_equation/config.ert @@ -22,7 +22,7 @@ GRID CASE.EGRID OBS_CONFIG observations_loc.txt ECLBASE HEAT -FIELD COND PARAMETER cond.bgrdecl INIT_FILES:cond.bgrdecl FORWARD_INIT:True +FIELD COND PARAMETER cond.bgrdecl INIT_FILES:cond.bgrdecl FORWARD_INIT:True INIT_TRANSFORM:LOG OUTPUT_TRANSFORM:EXP GEN_KW INIT_TEMP_SCALE init_temp_prior.txt GEN_KW CORR_LENGTH corr_length_prior.txt