
Commit 9c2fb87

create_empty_like
1 parent d6b8fd2 commit 9c2fb87

File tree

8 files changed: +346 −27 lines changed


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ dev = [
     "pre-commit-hooks>=6.0.0",
     "pytest>=8.4.2",
     "pytest-dependency>=0.6.0",
+    "pytest-order>=1.3.0",
     "typeguard>=4.4.4",
     "xdoctest[colors]>=1.3.0",
     "Pygments>=2.19.2"

src/mdio/creators/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
"""MDIO Data creation API."""

from mdio.creators.mdio import create_empty_like

__all__ = ["create_empty_like"]
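With this re-export, the new creator is importable directly from the package namespace; a minimal sketch of the short import path it enables (assuming the package layout introduced in this commit):

# Package-level import enabled by the re-export in __init__.py above.
from mdio.creators import create_empty_like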

src/mdio/creators/mdio.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
"""Creating MDIO v1 datasets."""

from __future__ import annotations

from datetime import UTC
from datetime import datetime
from typing import TYPE_CHECKING

from mdio.api.io import _normalize_path
from mdio.api.io import open_mdio
from mdio.api.io import to_mdio
from mdio.builder.template_registry import TemplateRegistry
from mdio.builder.xarray_builder import to_xarray_dataset
from mdio.converters.segy import populate_dim_coordinates
from mdio.converters.type_converter import to_structured_type
from mdio.core.grid import Grid

if TYPE_CHECKING:
    from pathlib import Path

    from segy.schema import HeaderSpec
    from upath import UPath
    from xarray import Dataset as xr_Dataset

    from mdio.builder.schemas import Dataset
    from mdio.core.dimension import Dimension


def create_empty_like(  # noqa: PLR0913
    input_path: UPath | Path | str,
    output_path: UPath | Path | str | None,
    keep_coordinates: bool = False,
    overwrite: bool = False,
) -> xr_Dataset:
    """Create an empty MDIO v1 file with the same structure as an existing one.

    Args:
        input_path: The path of the input MDIO file.
        output_path: The path of the output MDIO file.
            If None, the output will not be written to disk.
        keep_coordinates: Whether to keep the coordinates in the output file.
        overwrite: Whether to overwrite the output file if it exists.

    Returns:
        The output MDIO dataset.

    Raises:
        FileExistsError: If the output location already exists and overwrite is False.
    """
    input_path = _normalize_path(input_path)
    output_path = _normalize_path(output_path) if output_path is not None else None

    if not overwrite and output_path is not None and output_path.exists():
        err = f"Output location '{output_path.as_posix()}' exists. Set `overwrite=True` if intended."
        raise FileExistsError(err)

    ds = open_mdio(input_path)

    # Create a copy with the same structure but no data or, optionally, coordinates
    ds_output = ds.copy(data=None).reset_coords(drop=not keep_coordinates)

    # Dataset
    # Keep the name (which is the same as the used template name) and the original API version
    # ds_output.attrs["name"]
    # ds_output.attrs["apiVersion"]
    ds_output.attrs["createdOn"] = datetime.now(UTC)

    # Coordinates
    if not keep_coordinates:
        for coord_name in ds_output.coords:
            ds_output[coord_name].attrs["unitsV1"] = None

    # MDIO attributes
    attr = ds_output.attrs["attributes"]
    if attr is not None:
        attr.pop("gridOverrides", None)  # Empty dataset should not have gridOverrides
        # Keep the original values for the following attributes
        # attr["defaultVariableName"]
        # attr["surveyType"]
        # attr["gatherType"]

    # All traces should be marked as dead in an empty dataset
    if "trace_mask" in ds_output.variables:
        ds_output["trace_mask"][:] = False

    # Data variable
    var_name = attr["defaultVariableName"]
    var = ds_output[var_name]
    var.attrs["statsV1"] = None
    if not keep_coordinates:
        var.attrs["unitsV1"] = None

    # SEG-Y file header
    if "segy_file_header" in ds_output.variables:
        segy_file_header = ds_output["segy_file_header"]
        if segy_file_header is not None:
            segy_file_header.attrs["textHeader"] = None
            segy_file_header.attrs["binaryHeader"] = None
            segy_file_header.attrs["rawBinaryHeader"] = None

    if output_path is not None:
        to_mdio(ds_output, output_path=output_path, mode="w", compute=True)

    return ds_output
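A minimal usage sketch of the function above, assuming an existing MDIO store at a placeholder path; passing output_path=None skips the write and returns the skeleton dataset in memory, mirroring how the test further down exercises it:

from mdio.creators.mdio import create_empty_like

# Build an in-memory skeleton of an existing dataset (path is a placeholder).
empty = create_empty_like(
    input_path="seismic_3d.mdio",
    output_path=None,          # skip writing to disk
    keep_coordinates=True,     # keep coordinates such as cdp_x/cdp_y and their units
    overwrite=False,
)

# For a typical SEG-Y-derived dataset that carries a trace_mask variable,
# the skeleton keeps shape and metadata but no live traces or statistics.
assert not empty["trace_mask"].values.any()
assert empty[empty.attrs["attributes"]["defaultVariableName"]].attrs.get("statsV1") is None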

tests/conftest.py

Lines changed: 39 additions & 0 deletions
@@ -46,6 +46,17 @@ def zarr_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
     """Make a temp file for the output MDIO."""
     return tmp_path_factory.mktemp(r"mdio")
 
+@pytest.fixture(scope="session")
+def teapot_mdio_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
+    """Make a temp file for the output MDIO."""
+    return tmp_path_factory.mktemp(r"teapot.mdio")
+
+
+@pytest.fixture(scope="module")
+def mdio_4d_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
+    """Make a temp file for the output MDIO."""
+    return tmp_path_factory.mktemp(r"tmp_4d.mdio")
+
 
 @pytest.fixture(scope="module")
 def zarr_tmp2(tmp_path_factory: pytest.TempPathFactory) -> Path:  # pragma: no cover - used by disabled test

@@ -58,3 +69,31 @@ def segy_export_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
     """Make a temp file for the round-trip IBM SEG-Y."""
     tmp_dir = tmp_path_factory.mktemp("segy")
     return tmp_dir / "teapot_roundtrip.segy"
+
+
+@pytest.fixture(scope="class")
+def empty_mdio_with_headers(tmp_path_factory: pytest.TempPathFactory) -> Path:
+    """Make a temp file for empty MDIO testing."""
+    path = tmp_path_factory.mktemp(r"empty_with_headers.mdio")
+    return path
+
+
+# @pytest.fixture(scope="session")
+# def tmp_path_factory() -> pytest.TempPathFactory:
+#     """Custom tmp_path_factory implementation for local debugging."""
+#     from pathlib import Path  # noqa: PLC0415
+
+#     class DebugTempPathFactory:
+#         def __init__(self) -> None:
+#             pass
+
+#         def mktemp(self, basename: str, numbered: bool = True) -> Path:
+#             _ = numbered
+#             path = self.getbasetemp() / basename
+#             path.mkdir(parents=True, exist_ok=True)
+#             return path
+
+#         def getbasetemp(self) -> Path:
+#             return Path("tmp")
+
+#     return DebugTempPathFactory()
Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
"""Tests for the create_empty_like function."""

from __future__ import annotations

import math
from typing import TYPE_CHECKING

import numpy as np
import pytest
from segy.schema import HeaderField
from segy.schema import HeaderSpec
from segy.schema import ScalarType
from segy.standards import get_segy_standard

from mdio.builder.schemas.v1.units import LengthUnitEnum
from mdio.builder.schemas.v1.units import LengthUnitModel
from mdio.builder.schemas.v1.units import SpeedUnitEnum
from mdio.builder.schemas.v1.units import SpeedUnitModel
from mdio.builder.schemas.v1.units import TimeUnitEnum
from mdio.builder.schemas.v1.units import TimeUnitModel

if TYPE_CHECKING:
    from pathlib import Path

    from xarray import Dataset as xr_Dataset


from tests.integration.test_segy_roundtrip_teapot import get_teapot_segy_spec
from tests.integration.testing_helpers import get_values
from tests.integration.testing_helpers import validate_variable

from mdio import __version__
from mdio.api.io import open_mdio
from mdio.api.io import to_mdio
from mdio.builder.schemas.v1.stats import CenteredBinHistogram
from mdio.builder.schemas.v1.stats import SummaryStatistics
from mdio.converters.mdio import mdio_to_segy
from mdio.core import Dimension
from mdio.creators.mdio import create_empty_like


@pytest.mark.order(1000)
class TestCreateEmptyPostStack3DTimeMdio:
    """Tests for the create_empty_like function."""

    @classmethod
    def _get_customized_v10_trace_header_spec(cls) -> HeaderSpec:
        """Get the header spec for the MDIO dataset."""
        trace_header_fields = [
            HeaderField(name="inline", byte=17, format=ScalarType.INT32),
            HeaderField(name="crossline", byte=13, format=ScalarType.INT32),
            HeaderField(name="cdp_x", byte=181, format=ScalarType.INT32),
            HeaderField(name="cdp_y", byte=185, format=ScalarType.INT32),
            HeaderField(name="coordinate_scalar", byte=71, format=ScalarType.INT16),
        ]
        hs: HeaderSpec = get_segy_standard(1.0).trace.header
        hs.customize(fields=trace_header_fields)
        return hs

    @classmethod
    def _validate_dataset_metadata(cls, ds: xr_Dataset) -> None:
        """Validate the dataset metadata."""
        # Check basic metadata attributes
        expected_attrs = {
            "apiVersion": __version__,
            "name": "PostStack3DTime",
        }
        actual_attrs_json = ds.attrs

        # Compare one by one due to the ever-changing createdOn
        for key, value in expected_attrs.items():
            assert key in actual_attrs_json
            if key == "createdOn":
                assert actual_attrs_json[key] is not None
            else:
                assert actual_attrs_json[key] == value

        # Check that createdOn exists
        assert "createdOn" in actual_attrs_json
        assert actual_attrs_json["createdOn"] is not None

        # Validate template attributes
        attributes = ds.attrs["attributes"]
        assert attributes is not None
        assert len(attributes) == 3
        # Validate all attributes provided by the abstract template
        assert attributes["defaultVariableName"] == "amplitude"
        assert attributes["surveyType"] == "3D"
        assert attributes["gatherType"] == "stacked"
        assert "gridOverrides" not in attributes, "Empty dataset should not have gridOverrides"

    @classmethod
    def _validate_empty_mdio_dataset(cls, ds: xr_Dataset, has_headers: bool) -> None:
        """Validate an empty MDIO dataset structure and content."""
        # Check that the dataset has the expected shape
        assert ds.sizes == {"inline": 345, "crossline": 188, "time": 1501}

        # Validate the dimension coordinate variables
        validate_variable(ds, "inline", (345,), ("inline",), np.int32, range(1, 346), get_values)
        validate_variable(ds, "crossline", (188,), ("crossline",), np.int32, range(1, 189), get_values)
        validate_variable(ds, "time", (1501,), ("time",), np.int32, range(0, 3002, 2), get_values)

        # Validate the non-dimensional coordinate variables (should be empty for empty dataset)
        validate_variable(ds, "cdp_x", (345, 188), ("inline", "crossline"), np.float64, None, None)
        validate_variable(ds, "cdp_y", (345, 188), ("inline", "crossline"), np.float64, None, None)

        if has_headers:
            segy_spec = get_teapot_segy_spec()
            # Validate the headers (should be empty for empty dataset)
            # Infer the dtype from segy_spec and ignore endianness
            header_dtype = segy_spec.trace.header.dtype.newbyteorder("native")
            validate_variable(ds, "headers", (345, 188), ("inline", "crossline"), header_dtype, None, None)
            validate_variable(ds, "segy_file_header", (), (), np.dtype("U1"), None, None)

            assert "segy_file_header" in ds.variables
            assert ds["segy_file_header"].attrs.get("textHeader", None) is None, (
                "TextHeader should be empty for empty dataset"
            )
            assert ds["segy_file_header"].attrs.get("binaryHeader", None) is None, (
                "BinaryHeader should be empty for empty dataset"
            )
            assert ds["segy_file_header"].attrs.get("rawBinaryHeader", None) is None, (
                "RawBinaryHeader should be empty for empty dataset"
            )
        else:
            assert "headers" not in ds.variables
            assert "segy_file_header" not in ds.variables

        # Validate the trace mask
        validate_variable(ds, "trace_mask", (345, 188), ("inline", "crossline"), np.bool_, None, None)
        trace_mask = ds["trace_mask"].values
        assert not np.any(trace_mask), "All traces should be marked as dead in empty dataset"

        # Validate the amplitude data (should be empty)
        validate_variable(ds, "amplitude", (345, 188, 1501), ("inline", "crossline", "time"), np.float32, None, None)
        assert ds["amplitude"].attrs.get("statsV1", None) is None, "StatsV1 should be empty for empty dataset"
        assert ds["amplitude"].attrs.get("unitsV1", None) is None, "UnitsV1 should be empty for empty dataset"

    @pytest.mark.order(1001)
    @pytest.mark.dependency
    def test_create_empty_like(self, teapot_mdio_tmp: Path, empty_mdio_with_headers: Path) -> None:
        """Create an empty MDIO file like the input file."""
        _ = empty_mdio_with_headers
        ds = create_empty_like(
            input_path=teapot_mdio_tmp,
            output_path=None,  # We don't want to write to disk for now
            keep_coordinates=True,
            overwrite=True,
        )
        self._validate_dataset_metadata(ds)
        self._validate_empty_mdio_dataset(ds, has_headers=True)
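The @pytest.mark.dependency marker on test_create_empty_like lets later tests skip automatically when it fails. A minimal sketch of how a follow-up test could declare that dependency, assuming it lives in the same module (pytest-dependency's default scope); the follow-up test itself is hypothetical:

import pytest


@pytest.mark.order(1002)
@pytest.mark.dependency(depends=["TestCreateEmptyPostStack3DTimeMdio::test_create_empty_like"])
def test_follow_up_on_empty_mdio() -> None:
    """Hypothetical follow-up; skipped automatically if the creation test failed."""
    assert True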
