|
| 1 | +"""Unit tests for optimize_access_pattern module.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pytest |
| 9 | +from segy import SegyFactory |
| 10 | +from segy.standards import get_segy_standard |
| 11 | +from zarr.codecs import ZFPY |
| 12 | + |
| 13 | +from mdio import open_mdio |
| 14 | +from mdio import segy_to_mdio |
| 15 | +from mdio.builder.template_registry import get_template |
| 16 | +from mdio.optimize.access_pattern import OptimizedAccessPatternConfig |
| 17 | +from mdio.optimize.access_pattern import optimize_access_patterns |
| 18 | +from mdio.optimize.common import ZfpQuality |
| 19 | + |
| 20 | +if TYPE_CHECKING: |
| 21 | + from pathlib import Path |
| 22 | + |
| 23 | + |
# Inline numbers of the synthetic survey: 1 through 8 (8 inlines).
INLINES = np.arange(1, 9)
# Crossline numbers of the synthetic survey: 1 through 16 (16 crosslines).
CROSSLINES = np.arange(1, 17)
# Number of samples written per trace in the synthetic SEG-Y file.
NUM_SAMPLES = 64

# SEG-Y spec shared by the SEG-Y factory and the MDIO ingestion below
# (presumably revision 1 of the standard -- confirm against the `segy` docs).
SPEC = get_segy_standard(1)
| 29 | + |
| 30 | + |
@pytest.fixture
def test_segy_path(fake_segy_tmp: Path) -> Path:
    """Write a tiny synthetic 3D SEG-Y file and return its location.

    The file holds one trace per (inline, crossline) pair drawn from
    ``INLINES`` and ``CROSSLINES``, ordered inline-major. Every sample of a
    trace is set to that trace's sequential index.

    Args:
        fake_segy_tmp: Base path under which the SEG-Y file is created.

    Returns:
        Path to the written ``optimize_ap_test_3d.sgy`` file.
    """
    n_inlines = len(INLINES)
    n_crosslines = len(CROSSLINES)
    trace_count = n_inlines * n_crosslines

    segy_factory = SegyFactory(
        spec=SPEC,
        sample_interval=4000,
        samples_per_trace=NUM_SAMPLES,
    )
    trace_headers = segy_factory.create_trace_header_template(trace_count)
    trace_samples = segy_factory.create_trace_sample_template(trace_count)

    # Inline-major ordering: each inline number is repeated for every crossline.
    trace_headers["inline"] = np.repeat(INLINES, n_crosslines)
    trace_headers["crossline"] = np.tile(CROSSLINES, n_inlines)
    trace_headers["coordinate_scalar"] = 1

    # Broadcast the trace index across the sample axis.
    trace_index = np.arange(trace_count)
    trace_samples[:] = trace_index[:, np.newaxis]

    output_file = fake_segy_tmp / "optimize_ap_test_3d.sgy"
    with output_file.open(mode="wb") as stream:
        stream.write(segy_factory.create_textual_header())
        stream.write(segy_factory.create_binary_header())
        stream.write(segy_factory.create_traces(trace_headers, trace_samples))

    return output_file
| 54 | + |
| 55 | + |
@pytest.fixture
def mdio_dataset_path(test_segy_path: Path, zarr_tmp: Path) -> Path:
    """Ingest the synthetic SEG-Y file into MDIO and return the output path.

    Args:
        test_segy_path: Path to the synthetic SEG-Y file to ingest.
        zarr_tmp: Base path under which the MDIO dataset is written.

    Returns:
        Path to the written ``optimize_ap_test_3d.mdio`` dataset.
    """
    destination = zarr_tmp / "optimize_ap_test_3d.mdio"
    template = get_template("PostStack3DTime")
    segy_to_mdio(
        segy_spec=SPEC,
        mdio_template=template,
        input_path=test_segy_path,
        output_path=destination,
        overwrite=True,
    )
    return destination
| 68 | + |
| 69 | + |
class TestOptimizeAccessPattern:
    """Tests for optimize_access_pattern module.

    Every test receives ``mdio_dataset_path``, the path of an MDIO dataset
    converted from a small synthetic 3D SEG-Y file.
    """

    def test_optimize_access_patterns(self, mdio_dataset_path: Path) -> None:
        """Test optimization of access patterns.

        Requests optimized copies for the ``time`` and ``inline`` dimensions,
        then re-opens the dataset and checks that each ``fast_*`` variable
        exists with the requested chunks and a ZFPY serializer.
        """
        conf = OptimizedAccessPatternConfig(
            quality=ZfpQuality.HIGH,
            optimize_dimensions={"time": (512, 512, 4), "inline": (4, 512, 512)},
            processing_chunks={"inline": 512, "crossline": 512, "time": 512},
        )
        ds = open_mdio(mdio_dataset_path)
        optimize_access_patterns(ds, conf)

        # Re-open the dataset so the assertions run against a fresh handle.
        ds = open_mdio(mdio_dataset_path)

        assert "fast_time" in ds.variables
        assert ds["fast_time"].encoding["chunks"] == (512, 512, 4)
        assert isinstance(ds["fast_time"].encoding["serializer"], ZFPY)

        # Check the optimized variable itself; "inline" is one of the dataset's
        # own dimensions, so asserting on it would not verify the optimization.
        assert "fast_inline" in ds.variables
        assert ds["fast_inline"].encoding["chunks"] == (4, 512, 512)
        assert isinstance(ds["fast_inline"].encoding["serializer"], ZFPY)

    def test_missing_default_variable_name(self, mdio_dataset_path: Path) -> None:
        """Test case where default variable name is missing from dataset attributes."""
        conf = OptimizedAccessPatternConfig(
            quality=ZfpQuality.HIGH,
            optimize_dimensions={"time": (512, 512, 4)},
            processing_chunks={"inline": 512, "crossline": 512, "time": 512},
        )
        ds = open_mdio(mdio_dataset_path)
        # Remove the "attributes" entry from the opened dataset's attrs to
        # trigger the missing-default-variable-name error.
        del ds.attrs["attributes"]

        with pytest.raises(ValueError, match="Default variable name is missing from dataset attributes"):
            optimize_access_patterns(ds, conf)

    def test_missing_stats(self, mdio_dataset_path: Path) -> None:
        """Test case where statistics are missing from default variable."""
        conf = OptimizedAccessPatternConfig(
            quality=ZfpQuality.HIGH,
            optimize_dimensions={"time": (512, 512, 4)},
            processing_chunks={"inline": 512, "crossline": 512, "time": 512},
        )
        ds = open_mdio(mdio_dataset_path)
        # Remove the statistics attribute from the "amplitude" variable.
        del ds["amplitude"].attrs["statsV1"]

        with pytest.raises(ValueError, match="Statistics are missing from data"):
            optimize_access_patterns(ds, conf)

    def test_invalid_optimize_access_patterns(self, mdio_dataset_path: Path) -> None:
        """Test when optimize_dimensions contains invalid dimensions."""
        conf = OptimizedAccessPatternConfig(
            quality=ZfpQuality.HIGH,
            # "invalid" is deliberately not one of the dataset's dimensions.
            optimize_dimensions={"time": (512, 512, 4), "invalid": (4, 512, 512)},
            processing_chunks={"inline": 512, "crossline": 512, "time": 512},
        )
        ds = open_mdio(mdio_dataset_path)

        with pytest.raises(ValueError, match="Dimension to optimize 'invalid' not found"):
            optimize_access_patterns(ds, conf)
0 commit comments