Skip to content

Commit efcc1d1

Browse files
committed
add unit tests for optimize_access_pattern module
1 parent f9f06a7 commit efcc1d1

File tree

1 file changed

+129
-0
lines changed

1 file changed

+129
-0
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""Unit tests for optimize_access_pattern module."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import numpy as np
8+
import pytest
9+
from segy import SegyFactory
10+
from segy.standards import get_segy_standard
11+
from zarr.codecs import ZFPY
12+
13+
from mdio import open_mdio
14+
from mdio import segy_to_mdio
15+
from mdio.builder.template_registry import get_template
16+
from mdio.optimize.access_pattern import OptimizedAccessPatternConfig
17+
from mdio.optimize.access_pattern import optimize_access_patterns
18+
from mdio.optimize.common import ZfpQuality
19+
20+
if TYPE_CHECKING:
21+
from pathlib import Path
22+
23+
24+
# Header values of the synthetic survey grid: inlines 1..8 and crosslines 1..16.
INLINES = np.arange(start=1, stop=9)
CROSSLINES = np.arange(start=1, stop=17)

# Samples written per trace in the synthetic SEG-Y file.
NUM_SAMPLES = 64

# SEG-Y spec returned by ``get_segy_standard(1)``; used both to write the file and to ingest it.
SPEC = get_segy_standard(1)
29+
30+
31+
@pytest.fixture
def test_segy_path(fake_segy_tmp: Path) -> Path:
    """Write a small synthetic 3D SEG-Y file and return its path.

    One trace is created for every (inline, crossline) pair drawn from the
    module-level ``INLINES`` and ``CROSSLINES`` arrays, and the sample buffer
    is filled with each trace's sequential index.
    """
    n_inlines, n_crosslines = len(INLINES), len(CROSSLINES)
    trace_count = n_inlines * n_crosslines

    segy_factory = SegyFactory(spec=SPEC, sample_interval=4000, samples_per_trace=NUM_SAMPLES)
    trace_headers = segy_factory.create_trace_header_template(trace_count)
    trace_samples = segy_factory.create_trace_sample_template(trace_count)

    # Inline-major ordering: every inline value is held while the crosslines cycle.
    trace_headers["inline"] = np.repeat(INLINES, n_crosslines)
    trace_headers["crossline"] = np.tile(CROSSLINES, n_inlines)
    trace_headers["coordinate_scalar"] = 1

    # Add a trailing axis so the per-trace index broadcasts over the samples.
    trace_samples[:] = np.arange(trace_count)[:, np.newaxis]

    output_file = fake_segy_tmp / "optimize_ap_test_3d.sgy"
    with output_file.open(mode="wb") as stream:
        stream.write(segy_factory.create_textual_header())
        stream.write(segy_factory.create_binary_header())
        stream.write(segy_factory.create_traces(trace_headers, trace_samples))

    return output_file
54+
55+
56+
@pytest.fixture
def mdio_dataset_path(test_segy_path: Path, zarr_tmp: Path) -> Path:
    """Ingest the synthetic SEG-Y file into MDIO and return the output path.

    The conversion uses the ``PostStack3DTime`` template and overwrites any
    existing output at the destination.
    """
    template = get_template("PostStack3DTime")
    destination = zarr_tmp / "optimize_ap_test_3d.mdio"

    segy_to_mdio(
        segy_spec=SPEC,
        mdio_template=template,
        input_path=test_segy_path,
        output_path=destination,
        overwrite=True,
    )

    return destination
68+
69+
70+
class TestOptimizeAccessPattern:
    """Tests for optimize_access_pattern module.

    Every test receives ``mdio_dataset_path``, the MDIO dataset converted from
    the synthetic SEG-Y fixture, and calls ``optimize_access_patterns`` on it.
    """

    def test_optimize_access_patterns(self, mdio_dataset_path: Path) -> None:
        """Test optimization of access patterns.

        Requests optimized copies for the ``time`` and ``inline`` dimensions and
        checks that both ``fast_*`` variables exist with the requested chunking
        and a ZFPY serializer.
        """
        conf = OptimizedAccessPatternConfig(
            quality=ZfpQuality.HIGH,
            optimize_dimensions={"time": (512, 512, 4), "inline": (4, 512, 512)},
            processing_chunks={"inline": 512, "crossline": 512, "time": 512},
        )
        ds = open_mdio(mdio_dataset_path)
        optimize_access_patterns(ds, conf)

        # Re-open from the path so the assertions run against a freshly loaded dataset.
        ds = open_mdio(mdio_dataset_path)

        assert "fast_time" in ds.variables
        assert ds["fast_time"].encoding["chunks"] == (512, 512, 4)
        assert isinstance(ds["fast_time"].encoding["serializer"], ZFPY)

        # Check for the optimized variable itself. The plain "inline" name is presumably
        # present before optimization as well, so asserting on it would prove nothing.
        assert "fast_inline" in ds.variables
        assert ds["fast_inline"].encoding["chunks"] == (4, 512, 512)
        assert isinstance(ds["fast_inline"].encoding["serializer"], ZFPY)

    def test_missing_default_variable_name(self, mdio_dataset_path: Path) -> None:
        """Test case where default variable name is missing from dataset attributes.

        Removes the ``attributes`` entry from the dataset attrs and expects a
        ``ValueError`` from ``optimize_access_patterns``.
        """
        conf = OptimizedAccessPatternConfig(
            quality=ZfpQuality.HIGH,
            optimize_dimensions={"time": (512, 512, 4)},
            processing_chunks={"inline": 512, "crossline": 512, "time": 512},
        )
        ds = open_mdio(mdio_dataset_path)
        del ds.attrs["attributes"]

        with pytest.raises(ValueError, match="Default variable name is missing from dataset attributes"):
            optimize_access_patterns(ds, conf)

    def test_missing_stats(self, mdio_dataset_path: Path) -> None:
        """Test case where statistics are missing from default variable.

        Removes ``statsV1`` from the ``amplitude`` variable attrs and expects a
        ``ValueError`` from ``optimize_access_patterns``.
        """
        conf = OptimizedAccessPatternConfig(
            quality=ZfpQuality.HIGH,
            optimize_dimensions={"time": (512, 512, 4)},
            processing_chunks={"inline": 512, "crossline": 512, "time": 512},
        )
        ds = open_mdio(mdio_dataset_path)
        del ds["amplitude"].attrs["statsV1"]

        with pytest.raises(ValueError, match="Statistics are missing from data"):
            optimize_access_patterns(ds, conf)

    def test_invalid_optimize_access_patterns(self, mdio_dataset_path: Path) -> None:
        """Test when optimize_dimensions contains invalid dimensions.

        Passes an ``invalid`` key alongside a valid one and expects a
        ``ValueError`` naming the unknown dimension.
        """
        conf = OptimizedAccessPatternConfig(
            quality=ZfpQuality.HIGH,
            optimize_dimensions={"time": (512, 512, 4), "invalid": (4, 512, 512)},
            processing_chunks={"inline": 512, "crossline": 512, "time": 512},
        )
        ds = open_mdio(mdio_dataset_path)

        with pytest.raises(ValueError, match="Dimension to optimize 'invalid' not found"):
            optimize_access_patterns(ds, conf)

0 commit comments

Comments
 (0)