Skip to content

Commit 62e9be7

Browse files
committed
Address some of PR review comments
1 parent de35aeb commit 62e9be7

File tree

4 files changed

+40
-54
lines changed

4 files changed

+40
-54
lines changed

src/mdio/api/create.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,6 @@ def create_empty( # noqa PLR0913
3838
3939
Args:
4040
mdio_template: The MDIO template or template name to use to define the dataset structure.
41-
NOTE: If you want to have a unit-aware MDIO model, you need to add the units
42-
to the template before calling this function. For example:
43-
'unit_aware_template = TemplateRegistry().get("PostStack3DTime")'
44-
'unit_aware_template.add_units({"time": UNITS_SECOND})'
45-
'unit_aware_template.add_units({"cdp_x": UNITS_METER})'
46-
'unit_aware_template.add_units({"cdp_y": UNITS_METER})'
47-
'create_empty(unit_aware_template, dimensions, output_path, headers, overwrite)'
4841
dimensions: The dimensions of the MDIO file.
4942
output_path: The universal path for the output MDIO v1 file.
5043
headers: SEG-Y v1.0 trace headers. Defaults to None.
@@ -139,7 +132,7 @@ def create_empty_like( # noqa PLR0913
139132
# Coordinates
140133
if not keep_coordinates:
141134
for coord_name in ds_output.coords:
142-
ds_output[coord_name].attrs["unitsV1"] = None
135+
ds_output[coord_name].attrs.pop("unitsV1", None)
143136

144137
# MDIO attributes
145138
attr = ds_output.attrs["attributes"]
@@ -157,17 +150,17 @@ def create_empty_like( # noqa PLR0913
157150
# Data variable
158151
var_name = attr["defaultVariableName"]
159152
var = ds_output[var_name]
160-
var.attrs["statsV1"] = None
153+
var.attrs.pop("statsV1", None)
161154
if not keep_coordinates:
162-
var.attrs["unitsV1"] = None
155+
var.attrs.pop("unitsV1", None)
163156

164157
# SEG-Y file header
165158
if "segy_file_header" in ds_output.variables:
166159
segy_file_header = ds_output["segy_file_header"]
167160
if segy_file_header is not None:
168-
segy_file_header.attrs["textHeader"] = None
169-
segy_file_header.attrs["binaryHeader"] = None
170-
segy_file_header.attrs["rawBinaryHeader"] = None
161+
segy_file_header.attrs.pop("textHeader", None)
162+
segy_file_header.attrs.pop("binaryHeader", None)
163+
segy_file_header.attrs.pop("rawBinaryHeader", None)
171164

172165
if output_path is not None:
173166
to_mdio(ds_output, output_path=output_path, mode="w", compute=True)

tests/integration/test_segy_roundtrip_teapot.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import numpy.testing as npt
1111
import pytest
1212
from tests.integration.testing_helpers import UNITS_METER
13+
from tests.integration.testing_helpers import UNITS_MILLISECOND
1314
from tests.integration.testing_helpers import UNITS_NONE
14-
from tests.integration.testing_helpers import UNITS_SECOND
1515
from tests.integration.testing_helpers import get_inline_header_values
1616
from tests.integration.testing_helpers import get_teapot_segy_spec
1717
from tests.integration.testing_helpers import get_values
@@ -159,7 +159,7 @@ def test_teapot_import(
159159
NOTE: This test must be executed before the 'TestReader' and 'TestExport' tests.
160160
"""
161161
unit_aware_template = TemplateRegistry().get("PostStack3DTime")
162-
unit_aware_template.add_units({"time": UNITS_SECOND})
162+
unit_aware_template.add_units({"time": UNITS_MILLISECOND})
163163
unit_aware_template.add_units({"cdp_x": UNITS_METER})
164164
unit_aware_template.add_units({"cdp_y": UNITS_METER})
165165
segy_to_mdio(
@@ -227,7 +227,9 @@ def test_grid(self, teapot_mdio_tmp: Path, teapot_segy_spec: SegySpec) -> None:
227227
validate_xr_variable(
228228
ds, "crossline", {"crossline": 188}, UNITS_NONE, np.int32, False, range(1, 189), get_values
229229
)
230-
validate_xr_variable(ds, "time", {"time": 1501}, UNITS_SECOND, np.int32, False, range(0, 3002, 2), get_values)
230+
validate_xr_variable(
231+
ds, "time", {"time": 1501}, UNITS_MILLISECOND, np.int32, False, range(0, 3002, 2), get_values
232+
)
231233

232234
# Validate the non-dimensional coordinate variables
233235
validate_xr_variable(ds, "cdp_x", {"inline": 345, "crossline": 188}, UNITS_METER, np.float64)

tests/integration/test_z_create_empty.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@
2929
from xarray import Dataset as xr_Dataset
3030

3131

32-
from tests.integration.testing_helpers import UNITS_FEET_PER_SECOND
33-
from tests.integration.testing_helpers import UNITS_FOOT
3432
from tests.integration.testing_helpers import UNITS_METER
3533
from tests.integration.testing_helpers import UNITS_METERS_PER_SECOND
34+
from tests.integration.testing_helpers import UNITS_MILLISECOND
3635
from tests.integration.testing_helpers import UNITS_NONE
3736
from tests.integration.testing_helpers import UNITS_SECOND
3837
from tests.integration.testing_helpers import get_teapot_segy_spec
@@ -51,34 +50,24 @@
5150
from mdio.core import Dimension
5251

5352

54-
class PostStack3DVelocityTemplate(Seismic3DPostStackTemplate):
53+
class PostStack3DVelocityMetricTemplate(Seismic3DPostStackTemplate):
5554
"""Custom template that uses 'velocity' as the default variable name instead of 'amplitude'."""
5655

5756
@property
5857
def _default_variable_name(self) -> str:
5958
"""Override the default variable name."""
6059
return "velocity"
6160

62-
def __init__(self, data_domain: str, is_metric: bool) -> None:
61+
def __init__(self, data_domain: str) -> None:
6362
super().__init__(data_domain)
64-
if is_metric:
65-
self._units.update(
66-
{
67-
"time": UNITS_SECOND,
68-
"cdp_x": UNITS_METER,
69-
"cdp_y": UNITS_METER,
70-
"velocity": UNITS_METERS_PER_SECOND,
71-
}
72-
)
73-
else:
74-
self._units.update(
75-
{
76-
"time": UNITS_SECOND,
77-
"cdp_x": UNITS_FOOT,
78-
"cdp_y": UNITS_FOOT,
79-
"velocity": UNITS_FEET_PER_SECOND,
80-
}
81-
)
63+
self._units.update(
64+
{
65+
"time": UNITS_MILLISECOND,
66+
"cdp_x": UNITS_METER,
67+
"cdp_y": UNITS_METER,
68+
"velocity": UNITS_METERS_PER_SECOND,
69+
}
70+
)
8271

8372
@property
8473
def _name(self) -> str:
@@ -95,25 +84,25 @@ def _create_empty_mdio(cls, create_headers: bool, output_path: Path, overwrite:
9584
"""Create a temporary empty MDIO file for testing."""
9685
# Create the grid with the specified dimensions
9786
dims = [
98-
Dimension(name="inline", coords=range(1, 346, 1)), # 100-300 with step 1
99-
Dimension(name="crossline", coords=range(1, 189, 1)), # 1000-1600 with step 2
100-
Dimension(name="time", coords=range(0, 3002, 2)), # 0-3 seconds 4ms sample rate
87+
Dimension(name="inline", coords=range(1, 346, 1)),
88+
Dimension(name="crossline", coords=range(1, 189, 1)),
89+
Dimension(name="time", coords=range(0, 3002, 2)),
10190
]
10291

10392
# If later on, we want to export to SEG-Y, we need to provide the trace header spec.
10493
# The HeaderSpec can be either standard or customized.
10594
headers = get_teapot_segy_spec().trace.header if create_headers else None
10695
# Create an empty MDIO v1 metric post-stack 3D time velocity dataset
10796
return create_empty(
108-
mdio_template=PostStack3DVelocityTemplate(data_domain="time", is_metric=True),
97+
mdio_template=PostStack3DVelocityMetricTemplate(data_domain="time"),
10998
dimensions=dims,
11099
output_path=output_path,
111100
headers=headers,
112101
overwrite=overwrite,
113102
)
114103

115104
@classmethod
116-
def validate_teapod_dataset_metadata(cls, ds: xr_Dataset, is_velocity: bool) -> None:
105+
def validate_teapot_dataset_metadata(cls, ds: xr_Dataset, is_velocity: bool) -> None:
117106
"""Validate the dataset metadata."""
118107
if is_velocity:
119108
assert ds.name == "PostStack3DVelocityTime"
@@ -137,7 +126,6 @@ def validate_teapod_dataset_metadata(cls, ds: xr_Dataset, is_velocity: bool) ->
137126

138127
# Check that createdOn exists
139128
assert "createdOn" in actual_attrs_json
140-
assert actual_attrs_json["createdOn"] is not None
141129

142130
# Validate template attributes
143131
attributes = ds.attrs["attributes"]
@@ -152,7 +140,7 @@ def validate_teapod_dataset_metadata(cls, ds: xr_Dataset, is_velocity: bool) ->
152140
assert attributes["gatherType"] == "stacked"
153141

154142
@classmethod
155-
def validate_teapod_dataset_variables(
143+
def validate_teapot_dataset_variables(
156144
cls, ds: xr_Dataset, header_dtype: np.dtype | None, is_velocity: bool
157145
) -> None:
158146
"""Validate an empty MDIO dataset structure and content."""
@@ -164,7 +152,9 @@ def validate_teapod_dataset_variables(
164152
validate_xr_variable(
165153
ds, "crossline", {"crossline": 188}, UNITS_NONE, np.int32, False, range(1, 189), get_values
166154
)
167-
validate_xr_variable(ds, "time", {"time": 1501}, UNITS_SECOND, np.int32, False, range(0, 3002, 2), get_values)
155+
validate_xr_variable(
156+
ds, "time", {"time": 1501}, UNITS_MILLISECOND, np.int32, False, range(0, 3002, 2), get_values
157+
)
168158

169159
# Validate the non-dimensional coordinate variables (should be empty for empty dataset)
170160
validate_xr_variable(ds, "cdp_x", {"inline": 345, "crossline": 188}, UNITS_METER, np.float64)
@@ -183,7 +173,7 @@ def validate_teapod_dataset_variables(
183173
# Validate the trace mask (should be all True for empty dataset)
184174
validate_xr_variable(ds, "trace_mask", {"inline": 345, "crossline": 188}, UNITS_NONE, np.bool_)
185175
trace_mask = ds["trace_mask"].values
186-
assert not np.any(trace_mask), "All traces should be marked as dead in empty dataset"
176+
assert not np.any(trace_mask), "Expected all `False` values in `trace_mask` but found `True`."
187177

188178
# Validate the velocity or amplitude data (should be empty)
189179
if is_velocity:
@@ -222,16 +212,16 @@ def mdio_no_headers(self, empty_mdio_dir: Path) -> Path:
222212
def test_dataset_metadata(self, mdio_with_headers: Path) -> None:
223213
"""Test dataset metadata for empty MDIO file."""
224214
ds = open_mdio(mdio_with_headers)
225-
self.validate_teapod_dataset_metadata(ds, is_velocity=True)
215+
self.validate_teapot_dataset_metadata(ds, is_velocity=True)
226216

227217
def test_variables(self, mdio_with_headers: Path, mdio_no_headers: Path) -> None:
228218
"""Test grid validation for empty MDIO file."""
229219
ds = open_mdio(mdio_with_headers)
230220
header_dtype = get_teapot_segy_spec().trace.header.dtype
231-
self.validate_teapod_dataset_variables(ds, header_dtype=header_dtype, is_velocity=True)
221+
self.validate_teapot_dataset_variables(ds, header_dtype=header_dtype, is_velocity=True)
232222

233223
ds = open_mdio(mdio_no_headers)
234-
self.validate_teapod_dataset_variables(ds, header_dtype=None, is_velocity=True)
224+
self.validate_teapot_dataset_variables(ds, header_dtype=None, is_velocity=True)
235225

236226
def test_overwrite_behavior(self, empty_mdio_dir: Path) -> None:
237227
"""Test overwrite parameter behavior in create_empty_mdio."""
@@ -258,9 +248,9 @@ def test_overwrite_behavior(self, empty_mdio_dir: Path) -> None:
258248

259249
# Validate that the MDIO file can be loaded correctly using the helper function
260250
ds = open_mdio(empty_mdio)
261-
self.validate_teapod_dataset_metadata(ds, is_velocity=True)
251+
self.validate_teapot_dataset_metadata(ds, is_velocity=True)
262252
header_dtype = get_teapot_segy_spec().trace.header.dtype
263-
self.validate_teapod_dataset_variables(ds, header_dtype=header_dtype, is_velocity=True)
253+
self.validate_teapot_dataset_variables(ds, header_dtype=header_dtype, is_velocity=True)
264254

265255
# Verify the garbage data was overwritten (should not exist)
266256
assert not garbage_file.exists(), "Garbage file should have been overwritten"
@@ -403,6 +393,6 @@ def test_create_empty_like(self, teapot_mdio_tmp: Path, empty_mdio_dir: Path) ->
403393
)
404394
assert ds is not None
405395

406-
self.validate_teapod_dataset_metadata(ds, is_velocity=False)
396+
self.validate_teapot_dataset_metadata(ds, is_velocity=False)
407397
header_dtype = get_teapot_segy_spec().trace.header.dtype
408-
self.validate_teapod_dataset_variables(ds, header_dtype=header_dtype, is_velocity=False)
398+
self.validate_teapot_dataset_variables(ds, header_dtype=header_dtype, is_velocity=False)

tests/integration/testing_helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
UNITS_NONE = None
2121
UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER)
2222
UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND)
23+
UNITS_MILLISECOND = TimeUnitModel(time=TimeUnitEnum.MILLISECOND)
2324
UNITS_METERS_PER_SECOND = SpeedUnitModel(speed=SpeedUnitEnum.METERS_PER_SECOND)
2425
UNITS_FOOT = LengthUnitModel(length=LengthUnitEnum.FOOT)
2526
UNITS_FEET_PER_SECOND = SpeedUnitModel(speed=SpeedUnitEnum.FEET_PER_SECOND)

0 commit comments

Comments
 (0)