Skip to content

Commit f6b20ac

Browse files
committed
Filter NaN parameter rows in EnIF update
Surface and field parameters commonly have NaN for inactive grid cells. When EnIF stacks all parameters into a joint matrix, these NaN values corrupt boost_linear_regression (np.argmax treats NaN as maximum) and propagate through StandardScaler, contaminating whichever parameter sits at index 0 in the joint matrix.
1 parent 58cc8d7 commit f6b20ac

2 files changed

Lines changed: 290 additions & 7 deletions

File tree

src/ert/analysis/_enif_update.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
from collections.abc import Callable, Iterable
55

6+
import networkx as nx
67
import numpy as np
78
import polars as pl
89
import scipy as sp
@@ -147,12 +148,41 @@ def analysis_EnIF(
147148
iens_active_index=iens_active_index,
148149
)
149150

150-
X_full_scaler = StandardScaler()
151-
X_full_scaled = X_full_scaler.fit_transform(X_full.T)
151+
# np.argmax in boost regression treats NaN as maximum, corrupting H
152+
nan_row_mask = np.any(np.isnan(X_full), axis=1)
153+
if nan_row_mask.any():
154+
n_nan = int(nan_row_mask.sum())
155+
n_all_nan = int(np.all(np.isnan(X_full), axis=1).sum())
156+
n_partial_nan = n_nan - n_all_nan
157+
log_msg = (
158+
f"EnIF: Excluding {n_nan}/{len(nan_row_mask)} parameter rows "
159+
f"containing NaN ({n_all_nan} fully inactive"
160+
)
161+
if n_partial_nan > 0:
162+
log_msg += (
163+
f", {n_partial_nan} partially active — "
164+
f"these will not be updated for any realization"
165+
)
166+
log_msg += ")"
167+
logger.info(log_msg)
168+
progress_callback(AnalysisStatusEvent(msg=log_msg))
169+
170+
X_clean = X_full[~nan_row_mask]
171+
if X_clean.shape[0] == 0:
172+
msg = "All parameter rows contain NaN — cannot run EnIF update"
173+
data = DataSection(
174+
header=smoother_snapshot.header,
175+
data=smoother_snapshot.csv,
176+
extra=smoother_snapshot.extra,
177+
)
178+
raise ErtAnalysisError(msg, data=data)
179+
180+
X_clean_scaler = StandardScaler()
181+
X_clean_scaled = X_clean_scaler.fit_transform(X_clean.T)
152182

153183
# Call fit: Learn sparse linear map only
154184
H = linear_boost_ic_regression(
155-
U=X_full_scaled,
185+
U=X_clean_scaled,
156186
Y=S.T,
157187
verbose_level=5,
158188
)
@@ -166,10 +196,23 @@ def analysis_EnIF(
166196
for param_group in updated_parameters:
167197
config_node = source_ensemble.experiment.parameter_configuration[param_group]
168198
X_local = source_ensemble.load_parameters_numpy(param_group, iens_active_index)
199+
200+
local_nan_mask = np.any(np.isnan(X_local), axis=1)
201+
X_local_clean = X_local[~local_nan_mask]
202+
203+
if X_local_clean.shape[0] == 0:
204+
continue
205+
169206
X_local_scaler = StandardScaler()
170-
X_scaled = X_local_scaler.fit_transform(X_local.T)
207+
X_scaled = X_local_scaler.fit_transform(X_local_clean.T)
171208

172209
graph_u_sub = config_node.load_parameter_graph()
210+
if local_nan_mask.any():
211+
nan_nodes = set(np.where(local_nan_mask)[0].tolist())
212+
graph_u_sub.remove_nodes_from(nan_nodes)
213+
graph_u_sub = nx.convert_node_labels_to_integers(
214+
graph_u_sub, ordering="sorted"
215+
)
173216

174217
# This works for up to ~10^5 parameters
175218
Prec_u_sub = fit_precision_cholesky_approximate(
@@ -202,16 +245,18 @@ def analysis_EnIF(
202245
neighbor_propagation_order=15, verbose_level=1
203246
)
204247

205-
X_full = gtmap.transport(
206-
X_full_scaled,
248+
X_updated = gtmap.transport(
249+
X_clean_scaled,
207250
S.T,
208251
observation_values,
209252
update_indices=update_indices,
210253
iterative=True,
211254
verbose_level=5,
212255
seed=random_seed,
213256
)
214-
X_full = X_full_scaler.inverse_transform(X_full).T
257+
X_updated = X_clean_scaler.inverse_transform(X_updated).T
258+
259+
X_full[~nan_row_mask] = X_updated
215260

216261
# Iterate over parameters to store the updated ensemble
217262
log_msg = f"Storing {len(updated_parameters)} updated parameter groups"

tests/ert/unit_tests/analysis/test_es_update.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import polars as pl
77
import pytest
8+
import xarray as xr
89
from tabulate import tabulate
910

1011
from ert.analysis import (
@@ -13,6 +14,7 @@
1314
build_strategy_map,
1415
smoother_update,
1516
)
17+
from ert.analysis._enif_update import enif_update
1618
from ert.analysis._es_update import _create_combined_ensemble_mask
1719
from ert.analysis._update_commons import (
1820
_compute_observation_statuses,
@@ -26,6 +28,7 @@
2628
GenKwConfig,
2729
ObservationSettings,
2830
OutlierSettings,
31+
SurfaceConfig,
2932
)
3033
from ert.storage import Ensemble, open_storage
3134

@@ -1293,3 +1296,238 @@ def test_that_create_combined_ensemble_mask_handles_different_length_masks(
12931296
) -> None:
12941297
result = _create_combined_ensemble_mask(ens_mask, active_realizations)
12951298
np.testing.assert_array_equal(result, expected)
1299+
1300+
1301+
def _make_enif_genkw(name: str) -> dict[str, Any]:
    """Return the JSON dump of a scalar GenKw parameter called ``name``.

    The parameter uses ``name`` both as its name and as its group, and has a
    uniform distribution on [0.8, 1.2].
    """
    genkw = GenKwConfig(
        name=name,
        group=name,
        distribution={"name": "uniform", "min": 0.8, "max": 1.2},
    )
    return genkw.model_dump(mode="json")
1307+
1308+
1309+
def _populate_enif_prior(
    prior, rng, *, ncol, nrow, ensemble_size, num_obs, genkw_names
):
    """Seed ``prior`` with a partly inactive surface, GenKw scalars and responses.

    Args:
        prior: Ensemble that receives the data through ``save_parameters`` and
            ``save_response``.
        rng: NumPy random generator; every stored value is drawn from it.
        ncol: Size of the first axis of the ``HEIMDAL_TOP`` surface array.
        nrow: Size of the second axis of the surface array.
        ensemble_size: Number of realizations to write.
        num_obs: Number of ``RESPONSE`` values stored per realization.
        genkw_names: Scalar parameter names; each gets one uniform draw from
            [0.8, 1.2] per realization.
    """
    # The mask is drawn once, so the same cells are NaN in every realization.
    inactive_cells = rng.random((ncol, nrow)) < 0.55
    realizations = range(ensemble_size)

    for realization in realizations:
        surface = rng.normal(1.0, 0.1, (ncol, nrow)).astype(np.float32)
        surface[inactive_cells] = np.nan
        surface_ds = xr.Dataset({"values": (["x", "y"], surface)})
        prior.save_parameters(surface_ds, "HEIMDAL_TOP", realization)

    for name in genkw_names:
        # One single-row frame per realization, stacked into one dataset.
        per_realization = []
        for realization in realizations:
            draw = float(rng.uniform(0.8, 1.2))
            per_realization.append(
                pl.DataFrame({name: [draw], "realization": realization})
            )
        prior.save_parameters(dataset=pl.concat(per_realization, how="vertical"))

    for realization in realizations:
        response = pl.DataFrame(
            {
                "response_key": "RESPONSE",
                "report_step": pl.Series(
                    np.zeros(num_obs, dtype=int), dtype=pl.UInt16
                ),
                "index": pl.Series(np.arange(num_obs), dtype=pl.UInt16),
                "values": pl.Series(rng.normal(1, 0.1, num_obs), dtype=pl.Float32),
            }
        )
        prior.save_response("gen_data", response, realization)
1349+
1350+
1351+
@pytest.fixture
def enif_surface_config():
    """Provide an updatable, forward-initialised 20 x 25 ``HEIMDAL_TOP`` surface."""
    # Grid layout of the surface.
    geometry = {
        "ncol": 20,
        "nrow": 25,
        "xori": 0.0,
        "yori": 0.0,
        "xinc": 25.0,
        "yinc": 25.0,
        "rotation": 0.0,
        "yflip": 1,
    }
    # File names attached to the surface configuration.
    files = {
        "forward_init_file": "init_surf",
        "output_file": "out_surf",
        "base_surface_path": "base_surf",
    }
    return SurfaceConfig(
        name="HEIMDAL_TOP",
        forward_init=True,
        update=True,
        **geometry,
        **files,
    )
1369+
1370+
1371+
@pytest.fixture
def enif_observations():
    """Provide five general observations of ``RESPONSE``, one per index 0..4.

    Every observation has value 1.0 and error 0.1 at restart 0.
    """

    def _observation(index):
        return {
            "type": "general_observation",
            "name": f"OBS_{index}",
            "data": "RESPONSE",
            "restart": 0,
            "index": index,
            "value": 1.0,
            "error": 0.1,
        }

    return [_observation(index) for index in range(5)]
1385+
1386+
1387+
@pytest.mark.filterwarnings("ignore::scipy.linalg.LinAlgWarning")
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_that_enif_with_nan_surface_cells_does_not_contaminate_genkw(
    tmp_path, enif_surface_config, enif_observations
):
    """Run EnIF with scalar parameters configured ahead of a NaN-bearing surface.

    Checks that the scalar parameters stay finite, that the NaN pattern of the
    surface is unchanged, and that the non-NaN surface cells are finite and
    differ from the prior.
    """
    rng = np.random.default_rng(42)
    ensemble_size = 20
    genkw_names = ["VOLUME_MULT", "IN_MULT"]
    realizations = np.arange(ensemble_size)

    # Scalar groups first, the surface with inactive cells last.
    parameter_configuration = [_make_enif_genkw(name) for name in genkw_names]
    parameter_configuration.append(enif_surface_config.model_dump(mode="json"))
    response_configuration = [
        GenDataConfig(keys=["RESPONSE"]).model_dump(mode="json")
    ]

    with open_storage(tmp_path, mode="w") as storage:
        experiment = storage.create_experiment(
            name="enif_surface_nan",
            experiment_config={
                "parameter_configuration": parameter_configuration,
                "response_configuration": response_configuration,
                "observations": enif_observations,
            },
        )
        prior = storage.create_ensemble(
            experiment, ensemble_size=ensemble_size, iteration=0, name="prior"
        )
        _populate_enif_prior(
            prior,
            rng,
            ncol=enif_surface_config.ncol,
            nrow=enif_surface_config.nrow,
            ensemble_size=ensemble_size,
            num_obs=5,
            genkw_names=genkw_names,
        )

        posterior = storage.create_ensemble(
            experiment,
            ensemble_size=ensemble_size,
            iteration=1,
            name="posterior_enif",
            prior_ensemble=prior,
        )
        enif_update(
            prior,
            posterior,
            observations=experiment.observation_keys,
            parameters=list(experiment.parameter_configuration.keys()),
            random_seed=42,
        )

        for name in genkw_names:
            updated = posterior.load_parameters_numpy(name, realizations)
            assert np.all(np.isfinite(updated)), (
                f"{name} should be finite, "
                f"got {np.isnan(updated).sum()} NaN out of {updated.size}"
            )

        surface_before = prior.load_parameters_numpy("HEIMDAL_TOP", realizations)
        surface_after = posterior.load_parameters_numpy("HEIMDAL_TOP", realizations)

        inactive = np.isnan(surface_before)
        assert np.array_equal(np.isnan(surface_after), inactive), (
            "Inactive surface cells (NaN) should be preserved after update"
        )

        active_after = surface_after[~inactive]
        active_before = surface_before[~inactive]
        assert np.all(np.isfinite(active_after)), (
            "Active surface cells should remain finite after update"
        )
        assert not np.array_equal(active_after, active_before), (
            "Active surface cells should be updated, not identical to prior"
        )
1464+
1465+
1466+
@pytest.mark.filterwarnings("ignore::scipy.linalg.LinAlgWarning")
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_that_enif_nan_filtering_preserves_genkw_regardless_of_parameter_order(
    tmp_path, enif_surface_config, enif_observations
):
    """Run EnIF with the NaN-bearing surface configured ahead of the scalars.

    Checks that the scalar parameters stay finite, that the NaN pattern of the
    surface is unchanged, and that the non-NaN surface cells are finite and
    differ from the prior.
    """
    rng = np.random.default_rng(42)
    ensemble_size = 20
    genkw_names = ["IN_MULT", "VOLUME_MULT"]

    with open_storage(tmp_path, mode="w") as storage:
        experiment = storage.create_experiment(
            name="enif_reversed_order",
            experiment_config={
                # Surface first, so its rows lead the joint parameter matrix.
                "parameter_configuration": [
                    enif_surface_config.model_dump(mode="json"),
                    *(_make_enif_genkw(name) for name in genkw_names),
                ],
                "response_configuration": [
                    GenDataConfig(keys=["RESPONSE"]).model_dump(mode="json")
                ],
                "observations": enif_observations,
            },
        )
        prior = storage.create_ensemble(
            experiment, ensemble_size=ensemble_size, iteration=0, name="prior"
        )
        _populate_enif_prior(
            prior,
            rng,
            ncol=enif_surface_config.ncol,
            nrow=enif_surface_config.nrow,
            ensemble_size=ensemble_size,
            # One response value per configured observation.
            num_obs=len(enif_observations),
            genkw_names=genkw_names,
        )

        posterior = storage.create_ensemble(
            experiment,
            ensemble_size=ensemble_size,
            iteration=1,
            name="posterior_enif",
            prior_ensemble=prior,
        )
        enif_update(
            prior,
            posterior,
            observations=experiment.observation_keys,
            parameters=list(experiment.parameter_configuration.keys()),
            random_seed=42,
        )

        iens = np.arange(ensemble_size)

        for name in genkw_names:
            post_vals = posterior.load_parameters_numpy(name, iens)
            assert np.all(np.isfinite(post_vals)), (
                f"{name} should be finite regardless of parameter order"
            )

        prior_surf = prior.load_parameters_numpy("HEIMDAL_TOP", iens)
        post_surf = posterior.load_parameters_numpy("HEIMDAL_TOP", iens)

        nan_mask = np.isnan(prior_surf)
        # Without this check, an update that overwrote inactive cells would
        # pass whenever the surface is the first parameter group.
        assert np.array_equal(np.isnan(post_surf), nan_mask), (
            "Inactive surface cells (NaN) should be preserved after update"
        )

        active = ~nan_mask
        assert np.all(np.isfinite(post_surf[active])), (
            "Active surface cells should remain finite after update"
        )
        assert not np.array_equal(post_surf[active], prior_surf[active]), (
            "Active surface cells should be updated, not identical to prior"
        )

0 commit comments

Comments
 (0)