diff --git a/pyproject.toml b/pyproject.toml index 19882bf9..5ca6787b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,4 +82,4 @@ dev = [ "pre-commit>=4.2.0", "ruff>=0.11.5", "tomli>=2.2.1", -] \ No newline at end of file +] diff --git a/src/quantem/__init__.py b/src/quantem/__init__.py index ba70f629..8db92d1f 100644 --- a/src/quantem/__init__.py +++ b/src/quantem/__init__.py @@ -9,6 +9,7 @@ from quantem.core import visualization as visualization from quantem import imaging as imaging +from quantem import diffraction as diffraction from quantem import diffractive_imaging as diffractive_imaging __version__ = version("quantem") diff --git a/src/quantem/core/datastructures/__init__.py b/src/quantem/core/datastructures/__init__.py index dfb5b47a..ac8f3d64 100644 --- a/src/quantem/core/datastructures/__init__.py +++ b/src/quantem/core/datastructures/__init__.py @@ -2,6 +2,7 @@ from quantem.core.datastructures.vector import Vector as Vector from quantem.core.datastructures.dataset4dstem import Dataset4dstem as Dataset4dstem +from quantem.core.datastructures.polar4dstem import Polar4dstem as Polar4dstem from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d from quantem.core.datastructures.dataset2d import Dataset2d as Dataset2d diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index 94744978..1ef22b74 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -191,6 +191,11 @@ def sampling(self) -> NDArray: def sampling(self, value: NDArray | tuple | list | float | int) -> None: self._sampling = validate_ndinfo(value, self.ndim, "sampling") + @property + def origin_units(self) -> NDArray: + # Origin expressed in physical units: origin * sampling + return np.asarray(self.origin) * np.asarray(self.sampling) + @property def units(self) -> list[str]: return self._units @@ -368,6 
+373,35 @@ def _copy_custom_attributes(self, new_dataset: Self) -> None: # Skip attributes that can't be copied pass + def coords(self, axis: int) -> Any: + """ + Coordinate array for a given axis in pixel units. + + coords(d) = arange(shape[d]) - origin[d] + """ + axis = int(axis) + if axis < 0 or axis >= self.ndim: + raise ValueError(f"axis {axis} out of bounds for ndim={self.ndim}") + + xp = self._xp + n = int(self.shape[axis]) + origin_d = float(np.asarray(self.origin)[axis]) + + return xp.arange(n, dtype=float) - origin_d + + def coords_units(self, axis: int) -> Any: + """ + Coordinate array for a given axis in physical units. + + coords_units(d) = (arange(shape[d]) - origin[d]) * sampling[d] + """ + axis = int(axis) + if axis < 0 or axis >= self.ndim: + raise ValueError(f"axis {axis} out of bounds for ndim={self.ndim}") + + sampling_d = float(np.asarray(self.sampling)[axis]) + return self.coords(axis) * sampling_d + def mean(self, axes: int | tuple[int, ...] | None = None) -> Any: """ Computes and returns mean of the data array. 
diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index 67bdc279..4a628eb7 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -9,6 +9,7 @@ from quantem.core.datastructures.dataset2d import Dataset2d from quantem.core.datastructures.dataset4d import Dataset4d +from quantem.core.datastructures.polar4dstem import dataset4dstem_polar_transform from quantem.core.utils.validators import ensure_valid_array from quantem.core.visualization import show_2d from quantem.core.visualization.visualization_utils import ScalebarConfig @@ -78,7 +79,7 @@ def __init__( _token : object | None, optional Token to prevent direct instantiation, by default None """ - mdata_keys_4dstem = ["r_to_q_rotation_cw_deg", "ellipticity"] + mdata_keys_4dstem = ["q_to_r_rotation_ccw_deg", "q_transpose", "ellipticity"] for k in mdata_keys_4dstem: if k not in metadata.keys(): metadata[k] = None @@ -799,3 +800,5 @@ def median_filter_masked_pixels(self, mask: np.ndarray, kernel_width: int = 3): self.array[:, :, index_x, index_y] = np.median( self.array[:, :, x_min:x_max, y_min:y_max], axis=(2, 3) ) + + polar_transform = dataset4dstem_polar_transform diff --git a/src/quantem/core/datastructures/polar4dstem.py b/src/quantem/core/datastructures/polar4dstem.py new file mode 100644 index 00000000..6619af5c --- /dev/null +++ b/src/quantem/core/datastructures/polar4dstem.py @@ -0,0 +1,237 @@ +import numpy as np +from numpy.typing import NDArray +from typing import Any, TYPE_CHECKING +from scipy.ndimage import map_coordinates + +if TYPE_CHECKING: + from .dataset4dstem import Dataset4dstem + +from quantem.core.datastructures.dataset4d import Dataset4d + + +class Polar4dstem(Dataset4d): + """4D-STEM dataset in polar coordinates (scan_y, scan_x, phi, r).""" + + def __init__( + self, + array: NDArray | Any, + name: str, + origin: NDArray | tuple | list | float | int, + sampling: NDArray | 
tuple | list | float | int, + units: list[str] | tuple | list, + signal_units: str = "arb. units", + metadata: dict | None = None, + _token: object | None = None, + ): + if metadata is None: + metadata = {} + mdata_keys_polar = [ + "polar_radial_min", + "polar_radial_max", + "polar_radial_step", + "polar_num_annular_bins", + "polar_two_fold_rotation_symmetry", + "polar_origin_row", + "polar_origin_col", + "polar_ellipse_params", + ] + for k in mdata_keys_polar: + if k not in metadata: + metadata[k] = None + super().__init__( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + metadata=metadata, + _token=_token, + ) + + @classmethod + def from_array( + cls, + array: NDArray | Any, + name: str | None = None, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + metadata: dict | None = None, + ) -> "Polar4dstem": + array = np.asarray(array) + if array.ndim != 4: + raise ValueError("Polar4dstem.from_array expects a 4D array.") + if origin is None: + origin = np.zeros(4, dtype=float) + if sampling is None: + sampling = np.ones(4, dtype=float) + if units is None: + units = ["pixels", "pixels", "deg", "pixels"] + if metadata is None: + metadata = {} + return cls( + array=array, + name=name if name is not None else "Polar 4D-STEM dataset", + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + metadata=metadata, + _token=cls._token, + ) + + @property + def n_phi(self) -> int: + return int(self.array.shape[2]) + + @property + def n_r(self) -> int: + return int(self.array.shape[3]) + + +def _precompute_polar_coords( + ny: int, + nx: int, + origin_row: float, + origin_col: float, + ellipse_params: tuple[float, float, float] | None, + num_annular_bins: int, + radial_min: float, + radial_max: float | None, + radial_step: float, + 
two_fold_rotation_symmetry: bool, +) -> tuple[NDArray, NDArray, NDArray, float]: + origin_row = float(origin_row) + origin_col = float(origin_col) + if radial_step <= 0: + raise ValueError("radial_step must be > 0.") + if num_annular_bins < 1: + raise ValueError("num_annular_bins must be >= 1.") + if radial_max is None: + r_row_pos = origin_row + r_row_neg = (ny - 1) - origin_row + r_col_pos = origin_col + r_col_neg = (nx - 1) - origin_col + radial_max_eff = float(min(r_row_pos, r_row_neg, r_col_pos, r_col_neg)) + else: + radial_max_eff = float(radial_max) + if radial_max_eff <= radial_min: + radial_max_eff = radial_min + radial_step + radial_bins = np.arange(radial_min, radial_max_eff, radial_step, dtype=np.float64) + if radial_bins.size == 0: + radial_bins = np.array([radial_min], dtype=np.float64) + if two_fold_rotation_symmetry: + phi_range = np.pi + else: + phi_range = 2.0 * np.pi + phi_bins = np.linspace(0.0, phi_range, num_annular_bins, endpoint=False, dtype=np.float64) + phi_grid, r_grid = np.meshgrid(phi_bins, radial_bins, indexing="ij") + if ellipse_params is None: + x = r_grid * np.cos(phi_grid) + y = r_grid * np.sin(phi_grid) + else: + if len(ellipse_params) != 3: + raise ValueError("ellipse_params must be (a, b, theta_deg).") + a, b, theta_deg = ellipse_params + theta = np.deg2rad(theta_deg) + alpha = phi_grid - theta + u = (a / b) * r_grid * np.cos(alpha) + v_prime = r_grid * np.sin(alpha) + cos_t = np.cos(theta) + sin_t = np.sin(theta) + x = u * cos_t - v_prime * sin_t + y = u * sin_t + v_prime * cos_t + coords_y = y + origin_row + coords_x = x + origin_col + coords = np.stack((coords_y, coords_x), axis=0) + return coords, phi_bins, radial_bins, radial_max_eff + + +def dataset4dstem_polar_transform( + self: "Dataset4dstem", + origin_row: float | int | NDArray, + origin_col: float | int | NDArray, + ellipse_params: tuple[float, float, float] | None = None, + num_annular_bins: int = 180, + radial_min: float = 0.0, + radial_max: float | None = None, + 
radial_step: float = 1.0, + two_fold_rotation_symmetry: bool = False, + name: str | None = None, + signal_units: str | None = None, +) -> Polar4dstem: + if self.array.ndim != 4: + raise ValueError("polar_transform requires a 4D-STEM dataset (ndim=4).") + scan_y, scan_x, ny, nx = self.array.shape + origin_row_f = float(origin_row) + origin_col_f = float(origin_col) + coords, phi_bins, radial_bins, radial_max_eff = _precompute_polar_coords( + ny=ny, + nx=nx, + origin_row=origin_row_f, + origin_col=origin_col_f, + ellipse_params=ellipse_params, + num_annular_bins=num_annular_bins, + radial_min=radial_min, + radial_max=radial_max, + radial_step=radial_step, + two_fold_rotation_symmetry=two_fold_rotation_symmetry, + ) + n_phi = phi_bins.size + n_r = radial_bins.size + result_dtype = np.result_type(self.array.dtype, np.float32) + out = np.empty((scan_y, scan_x, n_phi, n_r), dtype=result_dtype) + for iy in range(scan_y): + for ix in range(scan_x): + dp = self.array[iy, ix] + out[iy, ix] = map_coordinates( + dp, + coords, + order=1, + mode="constant", + cval=0.0, + ) + if two_fold_rotation_symmetry: + phi_range = np.pi + else: + phi_range = 2.0 * np.pi + phi_step_deg = (phi_range / float(n_phi)) * (180.0 / np.pi) + sampling = np.zeros(4, dtype=float) + origin = np.zeros(4, dtype=float) + sampling[0:2] = np.asarray(self.sampling)[0:2] + sampling[2] = phi_step_deg + sampling[3] = float(np.asarray(self.sampling)[-1]) * radial_step + origin[0:2] = np.asarray(self.origin)[0:2] + origin[2] = 0.0 + origin[3] = radial_min * float(np.asarray(self.sampling)[-1]) + units = [ + self.units[0], + self.units[1], + "deg", + self.units[-1], + ] + metadata = dict(self.metadata) + metadata.update( + { + "polar_radial_min": float(radial_min), + "polar_radial_max": float(radial_max_eff), + "polar_radial_step": float(radial_step), + "polar_num_annular_bins": int(n_phi), + "polar_two_fold_rotation_symmetry": bool(two_fold_rotation_symmetry), + "polar_origin_row": origin_row_f, + 
"polar_origin_col": origin_col_f, + "polar_ellipse_params": tuple(ellipse_params) if ellipse_params is not None else None, + } + ) + return Polar4dstem( + array=out, + name=name if name is not None else f"{self.name}_polar", + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units if signal_units is not None else self.signal_units, + metadata=metadata, + _token=Polar4dstem._token, + ) diff --git a/src/quantem/core/fitting/background.py b/src/quantem/core/fitting/background.py index 000d4a01..6d678427 100644 --- a/src/quantem/core/fitting/background.py +++ b/src/quantem/core/fitting/background.py @@ -48,6 +48,17 @@ def forward(self, ctx: RenderContext) -> torch.Tensor: inten = self.intensity_raw.to(device=ctx.device, dtype=ctx.dtype) return torch.ones(ctx.shape, device=ctx.device, dtype=ctx.dtype) * inten + def forward_batched( + self, + ctx: RenderContext, + *, + intensity_raw_b: torch.Tensor, + ) -> torch.Tensor: + B = intensity_raw_b.shape[0] + return intensity_raw_b.view(B, 1, 1).expand(B, ctx.shape[0], ctx.shape[1]).to( + device=ctx.device, dtype=ctx.dtype + ) + class GaussianBackground(RenderComponent): # TODO this should be N dimensional by default def __init__( @@ -110,3 +121,21 @@ def forward(self, ctx: RenderContext) -> torch.Tensor: inten = self.intensity_raw.to(device=ctx.device, dtype=ctx.dtype) r2 = (rr - r0) ** 2 + (cc - c0) ** 2 return inten * torch.exp(-0.5 * r2 / (sigma * sigma)) + + def forward_batched( + self, + ctx: RenderContext, + *, + sigma_raw_b: torch.Tensor, + intensity_raw_b: torch.Tensor, + origin_coords_b: torch.Tensor, + ) -> torch.Tensor: + B = sigma_raw_b.shape[0] + rr = torch.arange(ctx.shape[0], device=ctx.device, dtype=ctx.dtype).view(1, ctx.shape[0], 1) + cc = torch.arange(ctx.shape[1], device=ctx.device, dtype=ctx.dtype).view(1, 1, ctx.shape[1]) + r0 = origin_coords_b[:, 0].view(B, 1, 1) + c0 = origin_coords_b[:, 1].view(B, 1, 1) + sigma = sigma_raw_b.view(B, 1, 1) + inten = intensity_raw_b.view(B, 1, 1) 
+ r2 = (rr - r0) ** 2 + (cc - c0) ** 2 + return inten * torch.exp(-0.5 * r2 / (sigma * sigma)) diff --git a/src/quantem/core/fitting/base.py b/src/quantem/core/fitting/base.py index 74ce1f0f..21f7a7f6 100644 --- a/src/quantem/core/fitting/base.py +++ b/src/quantem/core/fitting/base.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from types import NoneType from typing import Any, Literal, Self, Sequence, cast import numpy as np @@ -10,9 +11,6 @@ from quantem.core.ml.optimizer_mixin import ( OptimizerMixin, - OptimizerParams, - OptimizerType, - SchedulerType, ) @@ -95,16 +93,38 @@ def __init__(self, *, ndim: int, init: Sequence[float]): self.coords = nn.Parameter(torch.as_tensor(init, dtype=torch.float32).reshape(self.ndim)) -class RenderComponent(nn.Module): +class RenderComponent(OptimizerMixin,nn.Module): DEFAULT_HARD_CONSTRAINTS: dict[str, Any] = {} DEFAULT_SOFT_CONSTRAINTS: dict[str, Any] = {} + DEFAULT_CONSTRAINT_CONFIG: dict[str, Any] = {} + + DEFAULT_OPTIMIZER: str = "adam" + DEFAULT_LR: float = 1e-2 + DEFAULT_SCHEDULER_TYPE: str = "none" def __init__(self) -> None: - super().__init__() + nn.Module.__init__(self) + + self._optimizer = None + self._scheduler = None + self._optimizer_params = {} + self._scheduler_params = {} self.hard_constraints: dict[str, Any] = dict(self.DEFAULT_HARD_CONSTRAINTS) self.soft_constraints: dict[str, Any] = dict(self.DEFAULT_SOFT_CONSTRAINTS) + self.constraint_config: dict[str, Any] = dict(self.DEFAULT_CONSTRAINT_CONFIG) self.parameter_bounds: dict[str, tuple[float | None, float | None]] = {} + optimizer_params = { + "type": self.DEFAULT_OPTIMIZER, + "lr": self.DEFAULT_LR + } + self.optimizer_params = optimizer_params + + scheduler_params = { + "type": self.DEFAULT_SCHEDULER_TYPE + } + self.scheduler_params = scheduler_params + @staticmethod def parse_bounded_init( value: float | int | Sequence[float | int | None], *, name: str @@ -275,8 +295,107 @@ def constraint_loss( ) -> 
torch.Tensor: return torch.zeros((), device=ctx.device, dtype=ctx.dtype) + def get_optimization_parameters(self) -> dict[str, list]: + params = [p for p in self.parameters() if p.requires_grad] + if not params: + return {} + return {'default': params} + + def initialize_optimizer(self, + optimizer_params: dict[str, Any] | None = None, + scheduler_params: dict[str, Any] | None = None, + num_iter: int | None = None, + ) -> None: + trainable_params = (self.get_optimization_parameters()) + if not trainable_params: + self._optimizer = None + self._scheduler = None + return + if optimizer_params is not None: + self.optimizer_params = optimizer_params + else: + self.optimizer_params = { + "type": self.DEFAULT_OPTIMIZER, + "lr": self.DEFAULT_LR + } + self.set_optimizer(self.optimizer_params) + + if scheduler_params is not None: + self.scheduler_params = scheduler_params + else: + self.scheduler_params = { + "type": self.DEFAULT_SCHEDULER_TYPE + } + self.set_scheduler(self.scheduler_params, num_iter = num_iter) + return + + + def _infer_optimizer_rebuild_params(self) -> Any: + if self.optimizer_params: + return self.optimizer_params + if self.optimizer is not None: + opt_type: str | type[torch.optim.Optimizer] + if isinstance(self.optimizer, torch.optim.AdamW): + opt_type = "adamw" + elif isinstance(self.optimizer, torch.optim.Adam): + opt_type = "adam" + elif isinstance(self.optimizer, torch.optim.SGD): + opt_type = "sgd" + else: + opt_type = type(self.optimizer) + lr = float( + self.optimizer.param_groups[0].get( + "lr", getattr(self, "DEFAULT_LR", self.DEFAULT_LR) + ) + ) + return {"type": opt_type, "lr": lr} + return { + "type": getattr(self, "DEFAULT_OPTIMIZER_TYPE", self.DEFAULT_OPTIMIZER), + "lr": float(getattr(self, "DEFAULT_LR", self.DEFAULT_LR)), + } + + def _infer_scheduler_rebuild_params(self) -> Any: + if self.scheduler_params: + return self.scheduler_params + return { + "type": self.DEFAULT_SCHEDULER_TYPE, + } + + def 
_rebuild_optimizer_after_trainability_change(self) -> None: + trainable_params = (self.get_optimization_parameters()) + if not trainable_params: + self._optimizer = None + self._scheduler = None + return + # rebuild_params = self._infer_optimizer_rebuild_params() + # rebuild_params_scheduler = self._infer_scheduler_rebuild_params() + self.set_optimizer(self.optimizer_params) + self.set_scheduler(self.scheduler_params) + + def initialize_constraint_config(self, config: dict[str, Any], strict: bool = True) -> None: + if not hasattr(self, 'constraint_config'): + if strict: + raise AttributeError( + f"{self.__class__.__name__} does not have constraint_config attribute" + ) + return + if not isinstance(config, dict): + raise TypeError("constraint config must be a dict.") + + unknown: dict[str, Any] = {} + for k, v in config.items(): + if k in self.DEFAULT_CONSTRAINT_CONFIG: + self.constraint_config[k] = v + else: + unknown[k] = v + + if unknown and strict: + keys = ", ".join(str(k) for k in unknown.keys()) + raise KeyError(f"Unknown constraint keys for {self.__class__.__name__}: {keys}") + return -class AdditiveRenderModel(nn.Module): +class AdditiveRenderModel(nn.Module): # step all optimizers def __init__(self, *, origin: nn.Module, components: list[RenderComponent]): super().__init__() self.origin = origin @@ -340,70 +459,125 @@ def total_constraint_loss(self, ctx: RenderContext) -> torch.Tensor: loss = loss + component.constraint_loss(ctx) return loss + def initilize_independant_optimizers(self, + individual_optimizers: dict[str, dict[str, Any]] | None = None, + individual_schedulers: dict[str, dict[str, Any]] | None = None, + num_iter: int | None = None, + ) -> None: + for idx, module in enumerate(self.components): + component = cast(RenderComponent, module) + component_name = self._component_constraint_name(component, idx) + + component_optimizer_params = None + component_scheduler_params = None + if individual_optimizers is not None: + if component_name in 
individual_optimizers: + component_optimizer_params = individual_optimizers[component_name] + elif component.__class__.__name__ in individual_optimizers : + component_optimizer_params = individual_optimizers[component.__class__.__name__] + + if individual_schedulers is not None: + if component_name in individual_schedulers: + component_scheduler_params = individual_schedulers[component_name] + elif component.__class__.__name__ in individual_schedulers: + component_scheduler_params = individual_schedulers[component.__class__.__name__] + + component.initialize_optimizer(component_optimizer_params, component_scheduler_params, num_iter=num_iter) + + def set_independant_optimizer_params(self, + individual_optimizer_params: dict[str, dict[str, Any]], + individual_scheduler_params: dict[str, dict[str, Any]] | None = None, + num_iter: int | None = None, + ) -> None: + for component_name, param in individual_optimizer_params.items(): + scheduler_params = None + if individual_scheduler_params is not None and component_name in individual_scheduler_params: + scheduler_params = individual_scheduler_params[component_name] + + component = self._resolve_component_by_name(str(component_name)) + component.initialize_optimizer(param,scheduler_params,num_iter) + + def rebuild_independant_optimizers(self) -> None: + for module in self.components: + component = cast(RenderComponent, module) + component.initialize_optimizer() -@dataclass -class FitResult: - losses: list[float] - lrs: list[float] - final_loss: float - num_steps: int - metrics: dict[str, list[float]] = field(default_factory=dict) - - -class FitBase(OptimizerMixin): - DEFAULT_LR = 1e-2 - DEFAULT_OPTIMIZER_TYPE = "adam" - - def __init__(self): - super().__init__() - # Core wiring - self.loss_fn = torch.nn.MSELoss(reduction="mean") - self.model: AdditiveRenderModel | None = None - self.ctx: RenderContext | None = None - - # State/checkpoints - self.state_initialized: dict[str, torch.Tensor] | None = None - - # 
Histories/results - self.fit_history: dict[str, FitResult] = {} - - def get_optimization_parameters(self) -> Any: - if self.model is None: - return [] - return [p for p in self.model.parameters() if p.requires_grad] - - @property - def state_current(self) -> dict[str, torch.Tensor] | None: - if self.model is None: - return None - return self._get_model_state_dict_copy() + def step_optimizers(self) -> None: + for module in self.components: + component = cast(RenderComponent, module) + component.step_optimizer() + + def step_schedulers(self, loss: float | None = None) -> None: + for module in self.components: + component = cast(RenderComponent, module) + if hasattr(component, 'step_scheduler'): + try: + component.step_scheduler(loss) + except (AttributeError, TypeError): + pass + + def zero_grad_optimizers(self) -> None: + for module in self.components: + component = cast(RenderComponent, module) + component.zero_optimizer_grad() + + def _iter_named_components(self) -> list[tuple[str, RenderComponent]]: + """ + Return canonical component names paired with components. - @property - def render_initialized(self) -> np.ndarray: - if self.state_initialized is None: - raise RuntimeError("initialized state is unavailable. Call .define_model(...) first.") - return self._render_state_array(self.state_initialized) + Returns + ------- + list[tuple[str, RenderComponent]] + ``(name, component)`` entries using the model's canonical naming + rule. Names fall back to class-name/index behavior when ``.name`` is + missing. - @property - def render_current(self) -> np.ndarray: - if self.model is None or self.ctx is None: - raise RuntimeError("Call .define_model(...) first.") - return self.model(self.ctx).detach().cpu().numpy() + Raises + ------ + RuntimeError + If the model is not defined. 
+ """ + entries: list[tuple[str, RenderComponent]] = [] + for idx, module in enumerate(self.components): + component = cast(RenderComponent, module) + name = self._component_constraint_name(component, idx) + entries.append((name, component)) + return entries + + def _resolve_component_by_name(self, component_name: str) -> RenderComponent: + target = str(component_name) + for resolved_name, component in self._iter_named_components(): + if resolved_name == target: + return component + + for resolved_name, component in self._iter_named_components(): + if component.__class__.__name__ == target: + return component + + known = ", ".join(self.get_component_names()) + raise KeyError(f"Component not found: {target}. Known components: {known}") + + def get_component_names(self) -> list[str]: + """ + Return canonical component names. - def reset( - self, - reset_to: Literal["initialized"] = "initialized", - ) -> Self: - if reset_to != "initialized": - raise ValueError("FitBase.reset only supports reset_to='initialized'.") - if self.state_initialized is None: - raise RuntimeError("initialized state is unavailable. Call .define_model(...) first.") - self._load_model_state_dict_copy(self.state_initialized) - self._clear_fit_history_all() - return self + Returns + ------- + list[str] + Canonical component names. + """ + return [name for name, _ in self._iter_named_components()] + + def get_component_by_name(self, component_name: str) -> RenderComponent: + return self._resolve_component_by_name(component_name) + def set_component_trainable( - self, component_name: str, enabled: bool, rebuild_optimizer: bool = True + self, + component_name: str, + enabled: bool, + rebuild_optimizer: bool = True, + num_iter: int | None = None, ) -> None: """ Enable or disable optimization for all parameters in one component. 
@@ -439,7 +613,7 @@ def set_component_trainable( for _, param in component.named_parameters(recurse=True): param.requires_grad_(bool(enabled)) if rebuild_optimizer: - self._rebuild_optimizer_after_trainability_change() + component._rebuild_optimizer_after_trainability_change() def set_parameter_trainable( self, @@ -447,6 +621,7 @@ def set_parameter_trainable( parameter_name: str, enabled: bool, rebuild_optimizer: bool = True, + num_iter: int | None = None, ) -> None: """ Enable or disable optimization for one component parameter. @@ -488,7 +663,7 @@ def set_parameter_trainable( ) params[parameter_name].requires_grad_(bool(enabled)) if rebuild_optimizer: - self._rebuild_optimizer_after_trainability_change() + component._rebuild_optimizer_after_trainability_change() def set_parameters_trainable( self, @@ -496,6 +671,7 @@ def set_parameters_trainable( parameter_names: list[str], enabled: bool, rebuild_optimizer: bool = True, + num_iter: int | None = None, ) -> None: """ Enable or disable optimization for multiple component parameters. 
@@ -534,7 +710,7 @@ def set_parameters_trainable( for name in parameter_names: params[name].requires_grad_(bool(enabled)) if rebuild_optimizer: - self._rebuild_optimizer_after_trainability_change() + component._rebuild_optimizer_after_trainability_change() def get_component_trainable(self, component_name: str) -> dict[str, bool]: """ @@ -559,6 +735,162 @@ def get_component_trainable(self, component_name: str) -> dict[str, bool]: """ component = self._resolve_component_by_name(component_name) return {name: bool(param.requires_grad) for name, param in component.named_parameters()} + + def apply_constraint_configs( + self, constraint_configs: dict[str, Any], strict: bool = True + ) -> None: + for component_name, param in constraint_configs.items(): + component = self._resolve_component_by_name(str(component_name)) + component.initialize_constraint_config(param, strict=strict) + + + + + + + + +@dataclass +class FitResult: + losses: list[float] + lrs: list[float] + final_loss: float + num_steps: int + metrics: dict[str, list[float]] = field(default_factory=dict) + +class BCEMSELoss(nn.Module): + + def __init__( + self, + gamma: float = 1, + percentile: float = 0.95, + ): + super().__init__() + self.gamma = gamma + self.percentile = percentile + + def forward(self, pred, target): + mse_loss = torch.mean((pred - target) ** 2) + + threshold_pred = torch.quantile(pred, self.percentile) + threshold_target = torch.quantile(target, self.percentile) + + # Create binary masks + target_binary = (target > threshold_target).float() + pred_binary = (pred > threshold_pred).float() + + epsilon = 1e-7 + pred_binary_clamped = torch.clamp(pred_binary, epsilon, 1.0 - epsilon) + bce_loss = -torch.mean( + target_binary * torch.log(pred_binary_clamped) + + (1 - target_binary) * torch.log(1 - pred_binary_clamped) + ) + + # Combined loss + total_loss = mse_loss + self.gamma ** bce_loss + + return total_loss + +class SqrtMSELoss(nn.Module): + def __init__( + self, + gamma: float = 0.25, + ): + 
super().__init__() + self.gamma = gamma + self.mse_fn = torch.nn.MSELoss(reduction="mean") + + def forward(self, pred, target): + eps = 1 + pred_modified = (pred-torch.min(pred)+eps)**self.gamma + # pred_modified = pred_modified / torch.linalg.norm(pred_modified) + + target_modified = (target-torch.min(target)+eps)**self.gamma + # target_modified = target_modified / torch.linalg.norm(target_modified) + + loss = self.mse_fn(pred_modified, target_modified) + return loss + +class LogMSELoss(nn.Module): + def __init__( + self, + ): + super().__init__() + self.mse_fn = torch.nn.MSELoss(reduction="mean") + + def forward(self, pred, target): + return self.mse_fn(torch.log(1+pred), torch.log(1+target)) + +class FitBase(OptimizerMixin): + + def __init__(self): + super().__init__() + # Core wiring + # self.loss_fn = torch.nn.L1Loss(reduction="mean") + # self.loss_fn = torch.nn.MSELoss(reduction="mean") + self.loss_fn = SqrtMSELoss() + self.model: AdditiveRenderModel | None = None + self.ctx: RenderContext | None = None + + # State/checkpoints + self.state_initialized: dict[str, torch.Tensor] | None = None + + # Histories/results + self.fit_history: dict[str, FitResult] = {} + + # self.multi_optimizers: dict[str, torch.optim.Optimizer] | None = None + # self.use_mutiple_optimizers: bool = False + + def get_optimization_parameters(self) -> Any: + if self.model is None: + return [] + + return [p for p in self.model.parameters() if p.requires_grad] + + @property + def state_current(self) -> dict[str, torch.Tensor] | None: + if self.model is None: + return None + return self._get_model_state_dict_copy() + + @property + def render_initialized(self) -> np.ndarray: + if self.state_initialized is None: + raise RuntimeError("initialized state is unavailable. Call .define_model(...) 
first.") + return self._render_state_array(self.state_initialized) + + @property + def render_current(self) -> np.ndarray: + if self.model is None or self.ctx is None: + raise RuntimeError("Call .define_model(...) first.") + return self.model(self.ctx).detach().cpu().numpy() + + def set_loss(self, loss_fn: str = "sqrtmse", gamma: float = 0.25): + mode_in = loss_fn.strip().lower() + if mode_in in {"mse"}: + self.loss_fn = torch.nn.MSELoss(reduction="mean") + elif mode_in in {"sqrtmse"}: + self.loss_fn = SqrtMSELoss(gamma = gamma) + elif mode_in in {"logmse"}: + self.loss_fn = LogMSELoss() + elif mode_in in {"l1"}: + self.loss_fn = torch.nn.L1Loss(reduction="mean") + else: + raise ValueError( + "loss function must be mse, sqrtmse, logmse or l1" + ) + + def reset( + self, + reset_to: Literal["initialized"] = "initialized", + ) -> Self: + if reset_to != "initialized": + raise ValueError("FitBase.reset only supports reset_to='initialized'.") + if self.state_initialized is None: + raise RuntimeError("initialized state is unavailable. Call .define_model(...) 
first.") + self._load_model_state_dict_copy(self.state_initialized) + self._clear_fit_history_all() + return self def fit_render( self, @@ -567,8 +899,9 @@ def fit_render( n_steps: int, constraint_weight: float = 1.0, constraint_params: dict[str, Any] | None = None, - optimizer_params: OptimizerType | dict | None = None, - scheduler_params: SchedulerType | dict | None = None, + constraint_config_params: dict[str, Any] | None = None, + optimizer_params: dict[str, dict[str, Any]] | None = None, + scheduler_params: dict[str, dict[str, Any]] | None = None, progress: bool = False, run_key: str = "default", **kwargs: Any, @@ -616,50 +949,47 @@ def fit_render( raise RuntimeError("Model and context are not defined for fitting.") if constraint_params is not None: self.model.apply_constraint_params(constraint_params, strict=True) - - optimizer_rebuilt = False - if optimizer_params is not None: - self.set_optimizer(optimizer_params) - optimizer_rebuilt = True - elif self.optimizer is None: - if self.optimizer_params: - self.set_optimizer(self.optimizer_params) - else: - self.set_optimizer( - { - "type": getattr(self, "DEFAULT_OPTIMIZER_TYPE", "adamw"), - "lr": float(getattr(self, "DEFAULT_LR", self.DEFAULT_LR)), - } - ) - optimizer_rebuilt = True + if constraint_config_params is not None: + self.model.apply_constraint_configs(constraint_config_params, strict=True) n_steps = int(n_steps) - if scheduler_params is not None: - self.set_scheduler(scheduler_params, num_iter=n_steps) - elif self.scheduler is None and self.scheduler_params: - self.set_scheduler(self.scheduler_params, num_iter=n_steps) - elif optimizer_rebuilt and self.scheduler is not None and self.optimizer is not None: - self.scheduler.optimizer = self.optimizer + self.model.initilize_independant_optimizers( + optimizer_params, + scheduler_params, + num_iter=n_steps + ) pbar = tqdm(range(n_steps), desc="Fit render", disable=not progress) + loss_vals = torch.empty(n_steps, device=self.ctx.device) losses: list[float] 
= [] lrs: list[float] = [] - for _ in pbar: - self.zero_optimizer_grad() + for step in pbar: + self.model.zero_grad_optimizers() pred = self._forward_for_fit(target=target, **kwargs) data_loss = self._fidelity_loss(pred, target, **kwargs) constraint_loss = self._constraint_loss(pred, target, **kwargs) total_loss = data_loss + constraint_weight * constraint_loss total_loss.backward() - self.step_optimizer() + self.model.step_optimizers() if self.model is None or self.ctx is None: raise RuntimeError("Model and context are not defined for fitting.") self.model.apply_hard_constraints(self.ctx) - total_loss_value = float(total_loss.detach().cpu()) - self.step_scheduler(total_loss_value) - losses.append(total_loss_value) - lrs.append(float(self.get_current_lr())) + total_loss_value = (total_loss.detach()) + self.model.step_schedulers(total_loss_value) + # losses.append(total_loss_value) + loss_vals[step] = total_loss_value + + first_lr = 0.0 + if len(self.model.components) > 0: + first_comp = cast(RenderComponent, self.model.components[0]) + if hasattr(first_comp, 'get_current_lr'): + try: + first_lr = float(first_comp.get_current_lr()) + except (AttributeError, TypeError): + first_lr = 0.0 + lrs.append(first_lr) + losses = loss_vals.cpu().tolist() key = str(run_key) if key in self.fit_history: @@ -679,85 +1009,6 @@ def fit_render( self.fit_history[key] = result return result - def _iter_named_components(self) -> list[tuple[str, RenderComponent]]: - """ - Return canonical component names paired with components. - - Returns - ------- - list[tuple[str, RenderComponent]] - ``(name, component)`` entries using the model's canonical naming - rule. Names fall back to class-name/index behavior when ``.name`` is - missing. - - Raises - ------ - RuntimeError - If the model is not defined. - """ - if self.model is None: - raise RuntimeError("Call .define_model(...) 
first.") - entries: list[tuple[str, RenderComponent]] = [] - for idx, module in enumerate(self.model.components): - component = cast(RenderComponent, module) - name = self.model._component_constraint_name(component, idx) - entries.append((name, component)) - return entries - - def get_component_names(self) -> list[str]: - """ - Return canonical component names. - - Returns - ------- - list[str] - Canonical component names. - """ - return [name for name, _ in self._iter_named_components()] - - def _resolve_component_by_name(self, component_name: str) -> RenderComponent: - target = str(component_name) - for resolved_name, component in self._iter_named_components(): - if resolved_name == target: - return component - known = ", ".join(self.get_component_names()) - raise KeyError(f"Component not found: {target}. Known components: {known}") - - def _infer_optimizer_rebuild_params(self) -> dict[str, Any]: - if self.optimizer_params: - op = self.optimizer_params - if isinstance(op, OptimizerParams.NoneOptimizer): - return {"type": "none"} - out: dict[str, Any] = dict(op.params()) - out["type"] = op._name - return out - if self.optimizer is not None: - opt_type: str | type[torch.optim.Optimizer] - if isinstance(self.optimizer, torch.optim.AdamW): - opt_type = "adamw" - elif isinstance(self.optimizer, torch.optim.Adam): - opt_type = "adam" - elif isinstance(self.optimizer, torch.optim.SGD): - opt_type = "sgd" - else: - opt_type = type(self.optimizer) - lr = float( - self.optimizer.param_groups[0].get( - "lr", getattr(self, "DEFAULT_LR", self.DEFAULT_LR) - ) - ) - return {"type": opt_type, "lr": lr} - return { - "type": getattr(self, "DEFAULT_OPTIMIZER_TYPE", self.DEFAULT_OPTIMIZER_TYPE), - "lr": float(getattr(self, "DEFAULT_LR", self.DEFAULT_LR)), - } - - def _rebuild_optimizer_after_trainability_change(self) -> None: - if self.model is None: - raise RuntimeError("Call .define_model(...) 
first.") - rebuild_params = self._infer_optimizer_rebuild_params() - self.set_optimizer(rebuild_params) - self.set_scheduler({"type": "none"}) def _clone_state_dict(self, state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: return {k: v.detach().clone() for k, v in state.items()} @@ -815,6 +1066,76 @@ def _constraint_loss( raise RuntimeError("Model and context are not defined for fitting.") return self.model.total_constraint_loss(self.ctx) + def set_component_trainable( + self, + component_name: str, + enabled: bool, + rebuild_optimizer: bool = True, + num_iter: int | None = None + ) -> None: + if self.model is None: + raise RuntimeError("Call .define_model(...) first.") + self.model.set_component_trainable(component_name, enabled, rebuild_optimizer, num_iter) + + def set_parameter_trainable( + self, + component_name: str, + parameter_name: str, + enabled: bool, + rebuild_optimizer: bool = True, + num_iter: int | None = None + ) -> None: + if self.model is None: + raise RuntimeError("Call .define_model(...) first.") + self.model.set_parameter_trainable(component_name, parameter_name, enabled, rebuild_optimizer, num_iter) + + def set_parameters_trainable( + self, + component_name: str, + parameter_names: list[str], + enabled: bool, + rebuild_optimizer: bool = True, + num_iter: int | None = None + ) -> None: + if self.model is None: + raise RuntimeError("Call .define_model(...) first.") + self.model.set_parameters_trainable(component_name, parameter_names, enabled, rebuild_optimizer, num_iter) + + def get_component_trainable(self, component_name: str) -> dict[str, bool]: + if self.model is None: + raise RuntimeError("Call .define_model(...) first.") + return self.model.get_component_trainable(component_name) + + def get_component_names(self) -> list[str]: + if self.model is None: + raise RuntimeError("Call .define_model(...) 
first.") + return self.model.get_component_names() + + def _iter_named_components(self) -> list[tuple[str, RenderComponent]]: + if self.model is None: + raise RuntimeError("Call .define_model(...) first.") + return self.model._iter_named_components() + + def _resolve_component_by_name(self, component_name: str) -> RenderComponent: + if self.model is None: + raise RuntimeError("Call .define_model(...) first.") + return self.model._resolve_component_by_name(component_name) + + def get_component_by_name(self, component_name: str) -> RenderComponent: + return self._resolve_component_by_name(component_name) + + def _rebuild_optimizer_after_trainability_change(self) -> None: + if self.model is None: + raise RuntimeError("Call .define_model(...) first.") + self.model.rebuild_independant_optimizers() + + def apply_constraint_config_params( + self, constraint_configs: dict[str, Any], strict: bool = True + ) -> None: + if self.model is None: + raise RuntimeError("Call .define_model(...) first.") + self.model.apply_constraint_configs(constraint_configs, strict=strict) + Component = RenderComponent ModelContext = RenderContext diff --git a/src/quantem/core/fitting/diffraction.py b/src/quantem/core/fitting/diffraction.py index 30934bf5..a62c01f3 100644 --- a/src/quantem/core/fitting/diffraction.py +++ b/src/quantem/core/fitting/diffraction.py @@ -48,12 +48,72 @@ def put(rr: torch.Tensor, cc: torch.Tensor, ww: torch.Tensor) -> None: put(r0i + 1, c0i + 1, w11) +def _splat_patch_batched( + shape: tuple[int, int], + *, + r0: torch.Tensor, + c0: torch.Tensor, + vals: torch.Tensor, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + h, w = int(shape[0]), int(shape[1]) + B, N = r0.shape + + r_base = torch.floor(r0) + c_base = torch.floor(c0) + fr = r0 - r_base + fc = c0 - c_base + r0i = r_base.to(torch.long) + c0i = c_base.to(torch.long) + + w00 = (1.0 - fr) * (1.0 - fc) + w01 = (1.0 - fr) * fc + w10 = fr * (1.0 - fc) + w11 = fr * fc + + rr = torch.stack([r0i, r0i, r0i 
+ 1, r0i + 1], dim=0) + cc = torch.stack([c0i, c0i + 1, c0i, c0i + 1], dim=0) + ww = torch.stack([w00, w01, w10, w11], dim=0) + + keep = (rr >= 0) & (rr < h) & (cc >= 0) & (cc < w) + weighted = ww * vals.unsqueeze(0) * keep.to(dtype) + + rr_c = rr.clamp(0, h - 1) + cc_c = cc.clamp(0, w - 1) + flat_idx = rr_c * w + cc_c + + flat_idx_b = flat_idx.permute(1, 0, 2).reshape(B, -1) + weighted_b = weighted.permute(1, 0, 2).reshape(B, -1) + + out_flat = torch.zeros(B, h * w, device=device, dtype=dtype) + out_flat.scatter_add_(1, flat_idx_b, weighted_b) + return out_flat.reshape(B, h, w) + + class DiskTemplate(RenderComponent): DEFAULT_HARD_CONSTRAINTS: dict[str, bool] = { "force_center": False, "force_positive": True, + "force_norm": True, # force range [0,1] + "force_shrinkage": False, + "force_cutoff": False, + "force_circular_mask": True, + } + DEFAULT_SOFT_CONSTRAINTS: dict[str, float] = { + "tv_weight": 0.0, + "cutoff_weight": 0.0, + "circular_weight": 0.0, + } + DEFAULT_CONSTRAINT_CONFIG: dict[str, float] = { + "soft_cutoff_threshold": 0.0, + "soft_cutoff_target_ratio": 0.1, + "hard_cutoff_threshold": 0.35, + "shrinkage_amount": 0.25, + "circular_mask_radius_fraction": 0.95, + "circular_mask_sharpness": 0, + "soft_circular_mask": False, } - DEFAULT_SOFT_CONSTRAINTS: dict[str, float] = {"tv_weight": 0.0} def __init__( self, @@ -66,6 +126,7 @@ def __init__( origin_key: str = "origin", intensity: float | Sequence[float] = 1.0, constraint_params: dict[str, Any] | None = None, + constraint_config: dict[str, Any] | None = None, ): """ Build a disk template renderer centered at the shared origin. 
@@ -125,10 +186,17 @@ def __init__( cc = cc.astype(np.float32) - (wt - 1) * 0.5 self.register_buffer("dr", torch.as_tensor(rr.ravel(), dtype=torch.float32)) self.register_buffer("dc", torch.as_tensor(cc.ravel(), dtype=torch.float32)) + + self.constraint_config = self.DEFAULT_CONSTRAINT_CONFIG.copy() + if constraint_config is not None: + self.constraint_config.update(constraint_config) + if constraint_params is not None: self.apply_constraint_params(constraint_params, strict=True) if bool(self.hard_constraints.get("force_positive", False)): self._enforce_positivity() + if bool(self.hard_constraints.get("force_shrinkage", False)) and bool(self.hard_constraints.get("force_positive", True)): + raise RuntimeWarning("Setting shrinkage true and positivity false might cause negative values in disk template") @classmethod def from_array( @@ -142,6 +210,8 @@ def from_array( origin_key: str = "origin", intensity: float | Sequence[float] = 1.0, constraint_params: dict[str, Any] | None = None, + constraint_config: dict[str, Any] | None = None, + ) -> "DiskTemplate": return cls( name=name, @@ -152,6 +222,8 @@ def from_array( origin_key=origin_key, intensity=intensity, constraint_params=constraint_params, + constraint_config=constraint_config, + ) def set_origin(self, origin: OriginND) -> None: @@ -198,6 +270,25 @@ def forward(self, ctx: RenderContext) -> torch.Tensor: self.add_patch(out, r0=r0, c0=c0, scale=scale) return out + def forward_batched( + self, + ctx: RenderContext, + *, + template_raw_b: torch.Tensor, + intensity_raw_b: torch.Tensor, + origin_coords_b: torch.Tensor, + ) -> torch.Tensor: + B = template_raw_b.shape[0] + N = int(cast(torch.Tensor, self.dr).numel()) + dr = cast(torch.Tensor, self.dr).to(device=ctx.device, dtype=ctx.dtype) + dc = cast(torch.Tensor, self.dc).to(device=ctx.device, dtype=ctx.dtype) + r0 = origin_coords_b[:, 0:1] + dr.unsqueeze(0) + c0 = origin_coords_b[:, 1:2] + dc.unsqueeze(0) + vals = template_raw_b.reshape(B, N) * intensity_raw_b.view(B, 
1) + return _splat_patch_batched( + ctx.shape, r0=r0, c0=c0, vals=vals, device=ctx.device, dtype=ctx.dtype + ) + def _center_disk(self) -> None: with torch.no_grad(): template = self.template_raw @@ -238,12 +329,52 @@ def _enforce_positivity(self) -> None: with torch.no_grad(): self.template_raw.clamp_(min=0.0) self.intensity_raw.clamp_(min=0.0) + + def _enforce_norm(self) -> None: # pick value and cut off 5 percent of mean, or every iteration shrinkage, every iteration just subtract a valye of 0.01 + with torch.no_grad(): + self.template_raw -= self.template_raw.min() + self.template_raw /= self.template_raw.max() + + def _enforce_shrinkage(self) -> None: + with torch.no_grad(): + self.template_raw -= self.constraint_config["shrinkage_amount"] + + def _enforce_cutoff(self) -> None: + with torch.no_grad(): + mean_val = torch.max(self.template_raw) * self.constraint_config["hard_cutoff_threshold"] + self.template_raw[self.template_raw <= mean_val] = 0 + + def _enforce_circular_mask(self) -> None: + with torch.no_grad(): + h, w = self.template_raw.shape + radius = (min(h, w) / 2.0) * self.constraint_config["circular_mask_radius_fraction"] + + r = torch.arange(-h/2, h/2, device=self.template_raw.device, dtype=self.template_raw.dtype) + c = torch.arange(-w/2, w/2, device=self.template_raw.device, dtype=self.template_raw.dtype) + rr, cc = torch.meshgrid(r, c, indexing='ij') + circle_matrix = torch.sqrt(rr**2 + cc**2) + + if self.constraint_config["soft_circular_mask"]: + mask = torch.sigmoid(self.constraint_config["circular_mask_sharpness"]*(radius-circle_matrix)) + else: + mask = circle_matrix <= radius + self.template_raw *= mask + def enforce_hard_constraints(self, ctx: RenderContext) -> None: if bool(self.hard_constraints.get("force_center", False)): self._center_disk() + if bool(self.hard_constraints.get("force_cutoff", False)): + self._enforce_cutoff() + if bool(self.hard_constraints.get("force_circular_mask", False)): + self._enforce_circular_mask() + if 
bool(self.hard_constraints.get("force_shrinkage", False)): + self._enforce_shrinkage() if bool(self.hard_constraints.get("force_positive", False)): self._enforce_positivity() + if bool(self.hard_constraints.get("force_norm", False)): # could be put in positivity + self._enforce_norm() + super().enforce_hard_constraints(ctx) def constraint_loss( @@ -251,8 +382,17 @@ def constraint_loss( ) -> torch.Tensor: cfg = self.effective_soft_constraints(cast(dict[str, object] | None, params)) tv_weight = float(cfg.get("tv_weight", 0.0)) + cutoff_weight = float(cfg.get("cutoff_weight", 0.0)) + circular_weight = float(cfg.get("circular_weight", 0.0)) + if tv_weight <= 0.0: - return torch.zeros((), device=ctx.device, dtype=ctx.dtype) + tv_weight = 0.0 + if cutoff_weight <= 0.0: + cutoff_weight = 0.0 + if circular_weight <= 0.0: + circular_weight = 0.0 + + # tv loss calculation template = self.template_raw.to(device=ctx.device, dtype=ctx.dtype) tv_r = ( torch.mean(torch.abs(template[1:, :] - template[:-1, :])) @@ -264,7 +404,101 @@ def constraint_loss( if template.shape[1] > 1 else torch.zeros((), device=ctx.device, dtype=ctx.dtype) ) - return torch.as_tensor(tv_weight, device=ctx.device, dtype=ctx.dtype) * (tv_r + tv_c) + tv_loss = torch.as_tensor(tv_weight, device=ctx.device, dtype=ctx.dtype) * (tv_r + tv_c) + + # Cutoff loss calculation + num_px = torch.prod(torch.tensor(template.shape)) + px_under_threshold = torch.sum(template <= torch.mean(template)*self.constraint_config["soft_cutoff_threshold"])/num_px + cutoff_loss = torch.as_tensor(cutoff_weight, device=ctx.device, dtype=ctx.dtype) * torch.maximum( + px_under_threshold-self.constraint_config["soft_cutoff_target_ratio"], torch.as_tensor(0.0, device=ctx.device, dtype=ctx.dtype)) + + # circular loss calculation + h, w = template.shape + radius = (min(h, w) / 2.0) * self.constraint_config["circular_mask_radius_fraction"] + + r = torch.arange(h, device=ctx.device, dtype=ctx.dtype) - h / 2.0 + c = torch.arange(w, 
device=ctx.device, dtype=ctx.dtype) - w / 2.0 + rr, cc = torch.meshgrid(r, c, indexing='ij') + + circle_mask = torch.sqrt(rr**2 + cc**2) + + dist_from_radius = torch.abs(circle_mask - radius) + dist_from_radius = torch.relu(dist_from_radius) + circular_err = torch.mean(dist_from_radius * template) + + circular_loss = torch.as_tensor(circular_weight, device=ctx.device, dtype=ctx.dtype) * circular_err + + return cutoff_loss + tv_loss + circular_loss + + def constraint_loss_batched( + self, + ctx: RenderContext, + *, + template_raw_b: torch.Tensor, + params: dict[str, object] | None = None, + ) -> torch.Tensor: + """ + Per-sample analogue of ``constraint_loss`` for stacked templates. + + Parameters + ---------- + template_raw_b : torch.Tensor + Stacked templates with shape ``(B, H_t, W_t)``. + + Returns + ------- + torch.Tensor + Per-sample soft-constraint losses with shape ``(B,)``. Identical + semantics to ``constraint_loss`` on each slice, but reductions are + taken over ``dim=(1, 2)`` only. + """ + cfg = self.effective_soft_constraints(cast(dict[str, object] | None, params)) + tv_weight = max(float(cfg.get("tv_weight", 0.0)), 0.0) + cutoff_weight = max(float(cfg.get("cutoff_weight", 0.0)), 0.0) + circular_weight = max(float(cfg.get("circular_weight", 0.0)), 0.0) + + template = template_raw_b.to(device=ctx.device, dtype=ctx.dtype) + B, h, w = template.shape + + if h > 1: + tv_r = torch.mean(torch.abs(template[:, 1:, :] - template[:, :-1, :]), dim=(1, 2)) + else: + tv_r = torch.zeros(B, device=ctx.device, dtype=ctx.dtype) + if w > 1: + tv_c = torch.mean(torch.abs(template[:, :, 1:] - template[:, :, :-1]), dim=(1, 2)) + else: + tv_c = torch.zeros(B, device=ctx.device, dtype=ctx.dtype) + tv_loss = tv_weight * (tv_r + tv_c) + + # Per-sample cutoff (note: hard `<=` is non-differentiable; matches serial). 
+ per_sample_mean = template.mean(dim=(1, 2), keepdim=True) + thresh = per_sample_mean * float(self.constraint_config["soft_cutoff_threshold"]) + frac_under = (template <= thresh).to(dtype=ctx.dtype).mean(dim=(1, 2)) + target_ratio = float(self.constraint_config["soft_cutoff_target_ratio"]) + cutoff_loss = cutoff_weight * torch.relu(frac_under - target_ratio) + + # Per-sample circular: mask depends only on (H, W), so build once. + radius = (min(h, w) / 2.0) * float(self.constraint_config["circular_mask_radius_fraction"]) + r = torch.arange(h, device=ctx.device, dtype=ctx.dtype) - h / 2.0 + c = torch.arange(w, device=ctx.device, dtype=ctx.dtype) - w / 2.0 + rr, cc = torch.meshgrid(r, c, indexing="ij") + circle_mask = torch.sqrt(rr * rr + cc * cc) + dist_from_radius = torch.relu(torch.abs(circle_mask - radius)) # (H, W) + circular_err = (dist_from_radius.unsqueeze(0) * template).mean(dim=(1, 2)) + circular_loss = circular_weight * circular_err + + return tv_loss + cutoff_loss + circular_loss + + def get_optimization_parameters(self) -> dict[str, list[torch.nn.Parameter]]: + params = [] + for name, param in self.named_parameters(recurse=True): + if not name.startswith('origin.') and param.requires_grad: + params.append(param) + if not params: + return {} + return {'default': params} + + class SyntheticDiskLattice(RenderComponent): @@ -513,6 +747,7 @@ def enforce_hard_constraints(self, ctx: RenderContext) -> None: self.i0_raw[idx].clamp_(max=float(hi)) super().enforce_hard_constraints(ctx) + def forward(self, ctx: RenderContext) -> torch.Tensor: if self.origin is None: raise RuntimeError("SyntheticDiskLattice requires an OriginND instance.") @@ -532,9 +767,13 @@ def forward(self, ctx: RenderContext) -> torch.Tensor: b = torch.as_tensor(self.boundary_px, device=ctx.device, dtype=ctx.dtype) keep = (centers_r >= b) & (centers_r <= (ctx.shape[0] - 1) - b) keep = keep & (centers_c >= b) & (centers_c <= (ctx.shape[1] - 1) - b) - keep_idx = torch.nonzero(keep, 
as_tuple=False).reshape(-1) - if keep_idx.numel() == 0: + if not torch.any(keep): return out + + centers_r = centers_r[keep] + centers_c = centers_c[keep] + keep_idx = torch.nonzero(keep, as_tuple=False).reshape(-1) + num_disks = centers_r.shape[0] active_order = int( ctx.fields.get( @@ -543,39 +782,147 @@ def forward(self, ctx: RenderContext) -> torch.Tensor: ) active_order = max(0, min(active_order, self.max_intensity_order)) - dr, dc = self.disk.patch_offsets() - dr = dr.to(device=ctx.device, dtype=ctx.dtype) - dc = dc.to(device=ctx.device, dtype=ctx.dtype) - dr2 = dr * dr - dc2 = dc * dc - drdc = dr * dc - - for j in keep_idx: - rr0 = centers_r[j] - cc0 = centers_c[j] - - if self.per_disk_intensity: - inten = self.i0_raw[j] - if active_order >= 1 and self.ir is not None and self.ic is not None: - inten = inten + self.ir[j] * dr + self.ic[j] * dc - if ( - active_order >= 2 - and self.irr is not None - and self.icc is not None - and self.irc is not None - ): - inten = inten + self.irr[j] * dr2 + self.icc[j] * dc2 + self.irc[j] * drdc - else: - inten = self.i0_raw - if active_order >= 1: - assert self.ir is not None and self.ic is not None - inten = inten + self.ir * rr0 + self.ic * cc0 - if active_order >= 2: - assert self.irr is not None and self.icc is not None and self.irc is not None - inten = ( - inten + self.irr * rr0 * rr0 + self.icc * cc0 * cc0 + self.irc * rr0 * cc0 - ) - - self.disk.add_patch(out, r0=rr0, c0=cc0, scale=inten) + dr = cast(torch.Tensor, self.disk.dr).to(device=ctx.device, dtype=ctx.dtype) + dc = cast(torch.Tensor, self.disk.dc).to(device=ctx.device, dtype=ctx.dtype) + patch_vals = self.disk.patch_values().to(device=ctx.device, dtype=ctx.dtype) + num_pixels = patch_vals.shape[0] + if self.per_disk_intensity: + i0 = self.i0_raw[keep_idx][:, None] + inten = i0.expand(-1, num_pixels) + + if active_order >= 1 and self.ir is not None: + ir = self.ir[keep_idx][:, None] + ic = self.ic[keep_idx][:, None] + inten = inten + ir * dr[None, :] + ic * 
dc[None, :] + + if active_order >= 2 and self.irr is not None: + irr = self.irr[keep_idx][:, None] + icc = self.icc[keep_idx][:, None] + irc = self.irc[keep_idx][:, None] + inten = inten + irr * (dr*dr)[None, :] + icc * (dc*dc)[None, :] + irc * (dr*dc)[None, :] + else: + inten = self.i0_raw if isinstance(self.i0_raw, torch.Tensor) else self.i0_raw + if active_order >= 1: + inten = inten + self.ir * centers_r + self.ic * centers_c + if active_order >= 2: + inten = inten + self.irr * centers_r**2 + self.icc * centers_c**2 + self.irc * centers_r * centers_c + + inten = inten[:, None].expand(-1, num_pixels) if inten.ndim == 1 else inten.expand(num_disks, num_pixels) + total_pixels = num_disks * num_pixels + r0_all = centers_r[:, None].expand(-1, num_pixels).reshape(total_pixels) + c0_all = centers_c[:, None].expand(-1, num_pixels).reshape(total_pixels) + dr_all = dr[None, :].expand(num_disks, -1).reshape(total_pixels) + dc_all = dc[None, :].expand(num_disks, -1).reshape(total_pixels) + vals_all = (patch_vals[None, :] * inten).reshape(total_pixels) + _splat_patch( + out, + r0=r0_all, + c0=c0_all, + patch_vals=vals_all, + dr=dr_all, + dc=dc_all, + scale=torch.ones_like(vals_all) + ) return out + + def forward_batched( + self, + ctx: RenderContext, + *, + u_row_b: torch.Tensor, + u_col_b: torch.Tensor, + v_row_b: torch.Tensor, + v_col_b: torch.Tensor, + i0_raw_b: torch.Tensor, + ir_b: torch.Tensor | None, + ic_b: torch.Tensor | None, + irr_b: torch.Tensor | None, + icc_b: torch.Tensor | None, + irc_b: torch.Tensor | None, + template_raw_b: torch.Tensor, + origin_coords_b: torch.Tensor, + ) -> torch.Tensor: + if self.origin is None: + raise RuntimeError("SyntheticDiskLattice requires an OriginND instance.") + + B = u_row_b.shape[0] + uv = cast(torch.Tensor, self.uv_indices).to(device=ctx.device) + K = int(uv.shape[0]) + if K == 0: + return torch.zeros(B, ctx.shape[0], ctx.shape[1], device=ctx.device, dtype=ctx.dtype) + + u = uv[:, 0].to(dtype=ctx.dtype) + v = uv[:, 
1].to(dtype=ctx.dtype) + + r0_kb = ( + origin_coords_b[:, 0:1] + + u.unsqueeze(0) * u_row_b.unsqueeze(1) + + v.unsqueeze(0) * v_row_b.unsqueeze(1) + ) + c0_kb = ( + origin_coords_b[:, 1:2] + + u.unsqueeze(0) * u_col_b.unsqueeze(1) + + v.unsqueeze(0) * v_col_b.unsqueeze(1) + ) + + bb = torch.as_tensor(self.boundary_px, device=ctx.device, dtype=ctx.dtype) + keep = (r0_kb >= bb) & (r0_kb <= (ctx.shape[0] - 1) - bb) + keep = keep & (c0_kb >= bb) & (c0_kb <= (ctx.shape[1] - 1) - bb) + keep_f = keep.to(dtype=ctx.dtype) + + active_order = int( + ctx.fields.get( + "lattice_intensity_order_override", self.default_pattern_intensity_order + ) + ) + active_order = max(0, min(active_order, self.max_intensity_order)) + + dr = cast(torch.Tensor, self.disk.dr).to(device=ctx.device, dtype=ctx.dtype) + dc = cast(torch.Tensor, self.disk.dc).to(device=ctx.device, dtype=ctx.dtype) + N_pix = int(dr.numel()) + patch_vals = template_raw_b.reshape(B, N_pix) + + if self.per_disk_intensity: + inten = i0_raw_b.unsqueeze(2).expand(B, K, N_pix) + if active_order >= 1 and ir_b is not None and ic_b is not None: + inten = inten + ir_b.unsqueeze(2) * dr.view(1, 1, N_pix) + ic_b.unsqueeze(2) * dc.view(1, 1, N_pix) + if active_order >= 2 and irr_b is not None and icc_b is not None and irc_b is not None: + inten = ( + inten + + irr_b.unsqueeze(2) * (dr * dr).view(1, 1, N_pix) + + icc_b.unsqueeze(2) * (dc * dc).view(1, 1, N_pix) + + irc_b.unsqueeze(2) * (dr * dc).view(1, 1, N_pix) + ) + else: + inten = i0_raw_b.view(B, 1, 1).expand(B, K, N_pix).clone() + if active_order >= 1 and ir_b is not None and ic_b is not None: + inten = inten + ir_b.view(B, 1, 1) * r0_kb.unsqueeze(2) + ic_b.view(B, 1, 1) * c0_kb.unsqueeze(2) + if active_order >= 2 and irr_b is not None and icc_b is not None and irc_b is not None: + inten = ( + inten + + irr_b.view(B, 1, 1) * (r0_kb * r0_kb).unsqueeze(2) + + icc_b.view(B, 1, 1) * (c0_kb * c0_kb).unsqueeze(2) + + irc_b.view(B, 1, 1) * (r0_kb * c0_kb).unsqueeze(2) + ) + + inten = 
inten * keep_f.unsqueeze(2) + vals = patch_vals.unsqueeze(1) * inten + + r0_full = (r0_kb.unsqueeze(2) + dr.view(1, 1, N_pix)).reshape(B, K * N_pix) + c0_full = (c0_kb.unsqueeze(2) + dc.view(1, 1, N_pix)).reshape(B, K * N_pix) + vals_full = vals.reshape(B, K * N_pix) + + return _splat_patch_batched( + ctx.shape, r0=r0_full, c0=c0_full, vals=vals_full, + device=ctx.device, dtype=ctx.dtype, + ) + + def get_optimization_parameters(self) -> dict[str, list[torch.nn.Parameter]]: + params = [] + for name, param in self.named_parameters(recurse=True): + if not name.startswith('disk.') and param.requires_grad: + params.append(param) + if not params: + return {} + return {'default': params} diff --git a/src/quantem/core/io/file_readers.py b/src/quantem/core/io/file_readers.py index 4fe72645..6e970172 100644 --- a/src/quantem/core/io/file_readers.py +++ b/src/quantem/core/io/file_readers.py @@ -4,6 +4,7 @@ from typing import Any import h5py +import numpy as np from quantem.core.datastructures import Dataset as Dataset from quantem.core.datastructures import Dataset2d as Dataset2d @@ -19,23 +20,47 @@ def read_4dstem( **kwargs, ) -> Dataset4dstem: """ - File reader for 4D-STEM data + File reader for 4D-STEM data. Parameters ---------- - file_path: str | PathLike - Path to data - file_type: str - The type of file reader needed. See rosettasciio for supported formats + file_path : str | PathLike + Path to data. + file_type : str, optional + The type of file reader needed. See RosettaSciIO for supported formats: https://hyperspy.org/rosettasciio/supported_formats/index.html - dataset_index: int, optional + dataset_index : int, optional Index of the dataset to load if file contains multiple datasets. If None, automatically selects the first 4D dataset found. +<<<<<<< HEAD +<<<<<<< HEAD + If no 4D dataset is found but a 3D stack exists, a 3D dataset can be + interpreted as 4D if `scan_length` is provided. 
+ scan_length : int, optional + For 3D datasets shaped (n_frames, ny, nx) (after possibly moving the + scan axis to the front), interpret the data as a raster scan with shape + (scan_y, scan_x, ny, nx), where scan_y = n_frames // scan_length and + scan_x = scan_length. Required if you want to treat a 3D stack as 4D. + scan_axis : int, default 0 + Which axis of a 3D dataset is the scan/time axis before reshaping. + Must be 0 or 1. The specified axis is moved to axis 0 before the + (scan_y, scan_x) reshape. + transpose_scan_axes : bool, default False + Only used when interpreting a 3D dataset as 4D via `scan_length`. + If True, transpose the scan axes after reshaping so that + (scan_y, scan_x) -> (scan_x, scan_y). This effectively swaps the + interpretation of scan rows and columns in the final 4D array. + + **kwargs : dict + Additional keyword arguments to pass to the Dataset4dstem constructor. +======= +======= hot_pixel_filter: bool, optional If True, detect and replace hot detector pixels immediately after loading using `quantem.core.utils.filter.filter_hot_pixels` with its default parameters. For custom thresholds, call `filter_hot_pixels` directly on the array. +>>>>>>> dev **kwargs: dict Additional keyword arguments to pass to the file reader. @@ -51,9 +76,10 @@ def read_4dstem( Units for each dimension. If None, defaults to ["pixels"] * 4 signal_units : str, optional Units for the array values, by default "arb. units" +>>>>>>> upstream/fitting_models_clean Returns - -------- + ------- Dataset4dstem Examples @@ -76,6 +102,94 @@ def read_4dstem( ... hot_pixel_filter=True, ... 
) """ + + def _reshape_3d_to_4d( + imported_data: dict, + *, + dataset_index_local: int | None, + scan_length_local: int, + scan_axis_local: int, + transpose_scan_axes_local: bool, + ) -> dict: + data = imported_data["data"] + if data.ndim != 3: + raise ValueError( + f"Expected 3D data to reshape, got ndim={data.ndim} " + f"with shape {data.shape}" + ) + + if scan_axis_local not in (0, 1): + raise ValueError(f"scan_axis must be 0 or 1, got {scan_axis_local}") + + # Move scan axis to front so it becomes the frame axis + if scan_axis_local != 0: + data = np.moveaxis(data, scan_axis_local, 0) + + n_frames, ny, nx = data.shape + + if scan_length_local <= 0: + raise ValueError(f"scan_length must be positive, got {scan_length_local}") + if n_frames % scan_length_local != 0: + raise ValueError( + f"scan_length={scan_length_local} is not compatible with n_frames={n_frames}; " + f"n_frames % scan_length = {n_frames % scan_length_local}" + ) + + scan_y = n_frames // scan_length_local + scan_x = scan_length_local + + data_4d = data.reshape(scan_y, scan_x, ny, nx) + + if transpose_scan_axes_local: + data_4d = np.transpose(data_4d, (1, 0, 2, 3)) + scan_y, scan_x = scan_x, scan_y + + old_axes = imported_data.get("axes", None) + if old_axes is None or len(old_axes) != 3: + raise ValueError( + "Expected 3 axes for 3D data when reshaping to 4D; " + f"got axes={old_axes}" + ) + + ax_scan_y = { + "scale": 1.0, + "offset": 0.0, + "units": "pixels", + "name": "scan_y", + } + ax_scan_x = { + "scale": 1.0, + "offset": 0.0, + "units": "pixels", + "name": "scan_x", + } + + ax_qy = dict(old_axes[1]) + ax_qx = dict(old_axes[2]) + + imported_data_4d = imported_data.copy() + imported_data_4d["data"] = data_4d + imported_data_4d["axes"] = [ax_scan_y, ax_scan_x, ax_qy, ax_qx] + + original_shape = imported_data["data"].shape + new_shape = data_4d.shape + if dataset_index_local is not None: + print( + f"Using 3D dataset {dataset_index_local} with shape {original_shape} " + f"interpreted as 4D with 
shape={new_shape} " + f"(scan_axis={scan_axis_local}, scan_length={scan_length_local}, " + f"transpose_scan_axes={transpose_scan_axes_local})." + ) + else: + print( + f"Using 3D dataset with shape {original_shape} " + f"interpreted as 4D with shape={new_shape} " + f"(scan_axis={scan_axis_local}, scan_length={scan_length_local}, " + f"transpose_scan_axes={transpose_scan_axes_local})." + ) + + return imported_data_4d + if file_type is None: file_type = Path(file_path).suffix.lower().lstrip(".") @@ -87,29 +201,98 @@ def read_4dstem( file_reader = importlib.import_module(f"rsciio.{file_type}").file_reader data_list = file_reader(file_path, **kwargs) - # If specific index provided, use it + if not data_list: + raise ValueError(f"No datasets returned by rsciio.{file_type} for '{file_path}'") + + # Case 1: dataset_index specified explicitly if dataset_index is not None: imported_data = data_list[dataset_index] - if imported_data["data"].ndim != 4: + ndim = imported_data["data"].ndim + + if ndim == 4: + # Use 4D as-is + pass + elif ndim == 3: + if scan_length is None: + raise ValueError( + f"Dataset at index {dataset_index} is 3D (shape={imported_data['data'].shape}). " + "To interpret it as 4D-STEM, please provide scan_length." + ) + imported_data = _reshape_3d_to_4d( + imported_data, + dataset_index_local=dataset_index, + scan_length_local=scan_length, + scan_axis_local=scan_axis, + transpose_scan_axes_local=transpose_scan_axes, + ) + else: raise ValueError( - f"Dataset at index {dataset_index} has {imported_data['data'].ndim} dimensions, " - f"expected 4D. Shape: {imported_data['data'].shape}" + f"Dataset at index {dataset_index} has ndim={ndim}, " + f"expected 4D or 3D. Shape: {imported_data['data'].shape}" ) + else: - # Automatically find first 4D dataset + # Case 2: auto-select dataset four_d_datasets = [(i, d) for i, d in enumerate(data_list) if d["data"].ndim == 4] - if len(four_d_datasets) == 0: - print(f"No 4D datasets found in {file_path}. 
Available datasets:") - for i, d in enumerate(data_list): - print(f" Dataset {i}: shape {d['data'].shape}, ndim={d['data'].ndim}") - raise ValueError("No 4D dataset found in file") - - dataset_index, imported_data = four_d_datasets[0] - - if len(data_list) > 1: - print( - f"File contains {len(data_list)} dataset(s). Using dataset {dataset_index} with shape {imported_data['data'].shape}" + if four_d_datasets: + dataset_index, imported_data = four_d_datasets[0] + if len(data_list) > 1: + print( + f"File contains {len(data_list)} dataset(s). Using 4D dataset " + f"{dataset_index} with shape {imported_data['data'].shape}" + ) + else: + three_d_datasets = [(i, d) for i, d in enumerate(data_list) if d["data"].ndim == 3] + + if not three_d_datasets: + print(f"No 4D datasets found in {file_path}. Available datasets:") + for i, d in enumerate(data_list): + print(f" Dataset {i}: shape {d['data'].shape}, ndim={d['data'].ndim}") + raise ValueError("No 4D or 3D dataset found in file") + + if scan_length is None: + print(f"No 4D datasets found in {file_path}. Available datasets:") + for i, d in enumerate(data_list): + print(f" Dataset {i}: shape {d['data'].shape}, ndim={d['data'].ndim}") + raise ValueError( + "File contains only 3D datasets. To interpret one as 4D-STEM, " + "please specify scan_length so that n_frames % scan_length == 0." 
+ ) + + # Choose first 3D dataset compatible with scan_length along scan_axis + candidates: list[tuple[int, dict]] = [] + for i, d in three_d_datasets: + shape = d["data"].shape + if scan_axis < 0 or scan_axis > 2: + raise ValueError(f"scan_axis must be in [0, 2] for 3D data, got {scan_axis}") + n_frames_axis = shape[scan_axis] + if n_frames_axis % scan_length == 0: + candidates.append((i, d)) + + if not candidates: + print(f"3D datasets in {file_path}:") + for i, d in three_d_datasets: + print(f" Dataset {i}: shape {d['data'].shape}") + raise ValueError( + f"No 3D dataset has length along scan_axis={scan_axis} " + f"divisible by scan_length={scan_length}." + ) + + dataset_index, imported_data = candidates[0] + if len(candidates) > 1: + print( + f"Multiple 3D datasets compatible with scan_length={scan_length} " + f"along scan_axis={scan_axis}. Using dataset {dataset_index} " + f"with shape {imported_data['data'].shape}" + ) + + imported_data = _reshape_3d_to_4d( + imported_data, + dataset_index_local=dataset_index, + scan_length_local=scan_length, + scan_axis_local=scan_axis, + transpose_scan_axes_local=transpose_scan_axes, ) imported_axes = imported_data["axes"] diff --git a/src/quantem/core/utils/imaging_utils.py b/src/quantem/core/utils/imaging_utils.py index d352a051..a4585ad2 100644 --- a/src/quantem/core/utils/imaging_utils.py +++ b/src/quantem/core/utils/imaging_utils.py @@ -6,16 +6,22 @@ import numpy as np import torch from numpy.typing import NDArray -from scipy.ndimage import gaussian_filter +from scipy.ndimage import gaussian_filter, map_coordinates from quantem.core.utils.utils import generate_batches +def _parabolic_peak(v) -> float: + denom = 4.0 * v[1] - 2.0 * v[2] - 2.0 * v[0] + if denom == 0: + return 0.0 + return float((v[2] - v[0]) / denom) + + def dft_upsample( F: NDArray, up: int, shift: Tuple[float, float], - device: str = "cpu", ): """ Matrix multiplication DFT, from: @@ -24,27 +30,55 @@ def dft_upsample( image registration algorithms," Opt. 
Lett. 33, 156-158 (2008). http://www.sciencedirect.com/science/article/pii/S0045790612000778 """ - if device == "gpu": - import cupy as cp # type: ignore + M, N = F.shape + pixel_radius = 1.5 + num_row = int(math.ceil(pixel_radius * up)) + num_col = num_row - xp = cp - else: - xp = np + col_freq = np.fft.ifftshift(np.arange(N)) - math.floor(N / 2) + row_freq = np.fft.ifftshift(np.arange(M)) - math.floor(M / 2) - M, N = F.shape - du = np.ceil(1.5 * up).astype(int) - row = np.arange(-du, du + 1) - col = np.arange(-du, du + 1) - r_shift = shift[0] - M // 2 - c_shift = shift[1] - N // 2 - - kern_row = np.exp( - -2j * np.pi / (M * up) * np.outer(row, xp.fft.ifftshift(xp.arange(M)) - M // 2 + r_shift) - ) - kern_col = np.exp( - -2j * np.pi / (N * up) * np.outer(xp.fft.ifftshift(xp.arange(N)) - N // 2 + c_shift, col) + row_coords = np.arange(num_row, dtype=float) - float(shift[0]) + col_coords = np.arange(num_col, dtype=float) - float(shift[1]) + + factor_row = -2j * math.pi / (M * float(up)) + factor_col = -2j * math.pi / (N * float(up)) + + row_kern = np.exp(factor_row * (row_coords[:, None] * row_freq[None, :])).astype(F.dtype) + col_kern = np.exp(factor_col * (col_freq[:, None] * col_coords[None, :])).astype(F.dtype) + + return (row_kern @ F @ col_kern).real + + +def _upsampled_correlation_numpy( + imageCorr: NDArray, + upsampleFactor: int, + xyShift: NDArray, +) -> NDArray: + xyShift = np.round(xyShift * float(upsampleFactor)) / float(upsampleFactor) + globalShift = math.floor(math.ceil(upsampleFactor * 1.5) / 2.0) + upsampleCenter = float(globalShift) - (float(upsampleFactor) * xyShift) + + im_up = dft_upsample( + np.conj(imageCorr), upsampleFactor, (float(upsampleCenter[0]), float(upsampleCenter[1])) ) - return xp.real(kern_row @ F @ kern_col) + imageCorrUpsample = np.conj(im_up) + + flat_idx = int(np.argmax(imageCorrUpsample.real)) + r = flat_idx // imageCorrUpsample.shape[1] + c = flat_idx % imageCorrUpsample.shape[1] + + dx = 0.0 + dy = 0.0 + patch = 
imageCorrUpsample.real[r - 1 : r + 2, c - 1 : c + 2] + if patch.shape == (3, 3): + dx = _parabolic_peak(patch[:, 1]) + dy = _parabolic_peak(patch[1, :]) + + xySubShift = np.array([float(r), float(c)], dtype=float) - float(globalShift) + xyShift = xyShift + (xySubShift + np.array([dx, dy], dtype=float)) / float(upsampleFactor) + + return xyShift def cross_correlation_shift( @@ -55,7 +89,6 @@ def cross_correlation_shift( return_shifted_image: bool = False, fft_input: bool = False, fft_output: bool = False, - device: str = "cpu", ): """ Estimate subpixel shift between two 2D images using Fourier cross-correlation. @@ -67,98 +100,78 @@ def cross_correlation_shift( im : ndarray Image to align or its FFT if fft_input=True upsample_factor : int - Subpixel upsampling factor (must be > 1 for subpixel accuracy) - fft_input : bool - If True, assumes im_ref and im are already in Fourier space + Subpixel upsampling factor (torch-equivalent behavior): + - <= 2 : half-pixel refinement (parabolic, then rounded to nearest 0.5 px) + - > 2 : additional DFT upsample refinement + max_shift : float or None + Optional radial cutoff in pixel-shift units (keeps only shifts with |shift| <= max_shift) return_shifted_image : bool If True, return the shifted version of `im` aligned to `im_ref` - device : str - 'cpu' or 'gpu' (requires CuPy) + fft_input : bool + If True, assumes im_ref and im are already in Fourier space + fft_output : bool + If True and return_shifted_image=True, return the shifted image in Fourier space Returns ------- shifts : tuple of float (row_shift, col_shift) to align `im` to `im_ref` image_shifted : ndarray (optional) - Shifted image in real space, only returned if return_shifted_image=True + Shifted image in real space (or Fourier space if fft_output=True) """ - if device == "gpu": - import cupy as cp # type: ignore - - xp = cp - else: - xp = np + F_ref = np.asarray(im_ref) if fft_input else np.fft.fft2(np.asarray(im_ref)) + F_im = np.asarray(im) if fft_input else 
np.fft.fft2(np.asarray(im)) - # Fourier transforms - F_ref = im_ref if fft_input else xp.fft.fft2(im_ref) - F_im = im if fft_input else xp.fft.fft2(im) + cc = F_ref * np.conj(F_im) + cc_real = np.fft.ifft2(cc).real - # Correlation - cc = F_ref * xp.conj(F_im) - cc_real = xp.real(xp.fft.ifft2(cc)) + M, N = cc_real.shape if max_shift is not None: - x = np.fft.fftfreq(cc.shape[0], 1 / cc.shape[0]) - y = np.fft.fftfreq(cc.shape[1], 1 / cc.shape[1]) - mask = x[:, None] ** 2 + y[None, :] ** 2 >= max_shift**2 - cc_real[mask] = 0.0 + x = np.fft.fftfreq(M) * M + y = np.fft.fftfreq(N) * N + mask = x[:, None] ** 2 + y[None, :] ** 2 > float(max_shift) ** 2 + cc_real = cc_real.copy() + cc_real[mask] = -np.inf - # Coarse peak - peak = xp.unravel_index(xp.argmax(cc_real), cc_real.shape) - x0, y0 = peak + flat_idx = int(np.argmax(cc_real)) + x0 = flat_idx // N + y0 = flat_idx % N - # Parabolic refinement - x_inds = xp.mod(x0 + xp.arange(-1, 2), cc.shape[0]).astype(int) - y_inds = xp.mod(y0 + xp.arange(-1, 2), cc.shape[1]).astype(int) + x_inds = [((x0 + dx) % M) for dx in (-1, 0, 1)] + y_inds = [((y0 + dy) % N) for dy in (-1, 0, 1)] vx = cc_real[x_inds, y0] vy = cc_real[x0, y_inds] - def parabolic_peak(v): - return (v[2] - v[0]) / (4 * v[1] - 2 * v[2] - 2 * v[0]) - - dx = parabolic_peak(vx) - dy = parabolic_peak(vy) + dx = _parabolic_peak(vx) + dy = _parabolic_peak(vy) - x0 = (x0 + dx) % cc.shape[0] - y0 = (y0 + dy) % cc.shape[1] + x0 = np.round((float(x0) + float(dx)) * 2.0) / 2.0 + y0 = np.round((float(y0) + float(dy)) * 2.0) / 2.0 - if upsample_factor <= 1: - shifts = (x0, y0) - else: - # Local DFT upsampling - - local = dft_upsample(cc, upsample_factor, (x0, y0), device=device) - peak = np.unravel_index(xp.argmax(local), local.shape) + xy_shift = np.array([x0, y0], dtype=float) - try: - lx, ly = peak - icc = local[lx - 1 : lx + 2, ly - 1 : ly + 2] - if icc.shape == (3, 3): - dxf = parabolic_peak(icc[:, 1]) - dyf = parabolic_peak(icc[1, :]) - else: - raise ValueError("Subarray 
too close to edge") - except (IndexError, ValueError): - dxf = dyf = 0.0 - - shifts = np.array([x0, y0]) + (np.array(peak) - upsample_factor) / upsample_factor - shifts += np.array([dxf, dyf]) / upsample_factor + if upsample_factor > 2: + xy_shift = _upsampled_correlation_numpy(cc, int(upsample_factor), xy_shift) - shifts = (shifts + 0.5 * np.array(cc.shape)) % cc.shape - 0.5 * np.array(cc.shape) + shifts = np.empty(2, dtype=float) + shifts[0] = ((xy_shift[0] + M / 2) % M) - M / 2 + shifts[1] = ((xy_shift[1] + N / 2) % N) - N / 2 + shifts = (float(shifts[0]), float(shifts[1])) if not return_shifted_image: return shifts - # Fourier shift image (F_im assumed to be FFT) - kx = xp.fft.fftfreq(F_im.shape[0])[:, None] - ky = xp.fft.fftfreq(F_im.shape[1])[None, :] - phase_ramp = xp.exp(-2j * np.pi * (kx * shifts[0] + ky * shifts[1])) + kx = np.fft.fftfreq(F_im.shape[0])[:, None] + ky = np.fft.fftfreq(F_im.shape[1])[None, :] + phase_ramp = np.exp(-2j * np.pi * (kx * shifts[0] + ky * shifts[1])) F_im_shifted = F_im * phase_ramp + if fft_output: image_shifted = F_im_shifted else: - image_shifted = xp.real(xp.fft.ifft2(F_im_shifted)) + image_shifted = np.fft.ifft2(F_im_shifted).real return shifts, image_shifted @@ -175,7 +188,6 @@ def cross_correlation_shift_torch( xy_shift = align_images_fourier_torch(G1, G2, upsample_factor) - # convert to centered signed shifts as original code M, N = im_ref.shape dx = ((xy_shift[0] + M / 2) % M) - M / 2 dy = ((xy_shift[1] + N / 2) % N) - N / 2 @@ -197,12 +209,10 @@ def align_images_fourier_torch( cc = G1 * G2.conj() cc_real = torch.fft.ifft2(cc).real - # local max (integer) flat_idx = torch.argmax(cc_real) x0 = (flat_idx // cc_real.shape[1]).to(torch.long).item() y0 = (flat_idx % cc_real.shape[1]).to(torch.long).item() - # half pixel shifts: pick ±1 indices with wrap (mod) M, N = cc_real.shape x_inds = [((x0 + dx) % M) for dx in (-1, 0, 1)] y_inds = [((y0 + dy) % N) for dy in (-1, 0, 1)] @@ -210,14 +220,11 @@ def 
align_images_fourier_torch( vx = cc_real[x_inds, y0] vy = cc_real[x0, y_inds] - # parabolic half-pixel refine - # dx = (vx[2] - vx[0]) / (4*vx[1] - 2*vx[2] - 2*vx[0]) denom_x = 4.0 * vx[1] - 2.0 * vx[2] - 2.0 * vx[0] denom_y = 4.0 * vy[1] - 2.0 * vy[2] - 2.0 * vy[0] dx = (vx[2] - vx[0]) / denom_x if denom_x != 0 else torch.tensor(0.0, device=device) dy = (vy[2] - vy[0]) / denom_y if denom_y != 0 else torch.tensor(0.0, device=device) - # round to nearest half-pixel x0 = torch.round((x0 + dx) * 2.0) / 2.0 y0 = torch.round((y0 + dy) * 2.0) / 2.0 @@ -242,7 +249,6 @@ def upsampled_correlation_torch( xyShift: 2-element tensor (x,y) in image coords; must be half-pixel precision as described. Returns refined xyShift (tensor length 2). """ - assert upsampleFactor > 2 xyShift = torch.round(xyShift * float(upsampleFactor)) / float(upsampleFactor) @@ -253,26 +259,19 @@ def upsampled_correlation_torch( im_up = dftUpsample_torch(conj_input, upsampleFactor, upsampleCenter) imageCorrUpsample = im_up.conj() - # find maximum - # flatten argmax -> unravel to 2D flat_idx = torch.argmax(imageCorrUpsample.real) - # unravel_index xySubShift0 = (flat_idx // imageCorrUpsample.shape[1]).to(torch.long) xySubShift1 = (flat_idx % imageCorrUpsample.shape[1]).to(torch.long) xySubShift = torch.tensor([xySubShift0.item(), xySubShift1.item()]) - # parabolic subpixel refinement dx = 0.0 dy = 0.0 try: - # extract 3x3 patch around found peak r = xySubShift[0].item() c = xySubShift[1].item() patch = imageCorrUpsample.real[r - 1 : r + 2, c - 1 : c + 2] - # if patch is incomplete (near edge) this will raise / have wrong shape -> except if patch.shape == (3, 3): icc = patch - # dx corresponds to row direction (vertical axis) as in original code: dx = (icc[2, 1] - icc[0, 1]) / (4.0 * icc[1, 1] - 2.0 * icc[2, 1] - 2.0 * icc[0, 1]) dy = (icc[1, 2] - icc[1, 0]) / (4.0 * icc[1, 1] - 2.0 * icc[1, 2] - 2.0 * icc[1, 0]) dx = dx.item() @@ -282,7 +281,6 @@ def upsampled_correlation_torch( except Exception: dx, dy = 
0.0, 0.0 - # convert xySubShift to zero-centered by subtracting globalShift xySubShift = xySubShift.to(dtype=torch.get_default_dtype()) xySubShift = xySubShift - globalShift.to(xySubShift.dtype) @@ -311,13 +309,9 @@ def dftUpsample_torch( numRow = int(math.ceil(pixelRadius * upsampleFactor)) numCol = numRow - # prepare the vectors exactly like the numpy version - # col: frequency indices (centered) for N col_freq = torch.fft.ifftshift(torch.arange(N, device=device)) - math.floor(N / 2) - # row: frequency indices (centered) for M row_freq = torch.fft.ifftshift(torch.arange(M, device=device)) - math.floor(M / 2) - # small upsample grid coordinates (integer positions in the UPSAMPLED GRID) col_coords = torch.arange(numCol, device=device, dtype=torch.get_default_dtype()) - float( xyShift[1] ) @@ -325,28 +319,125 @@ def dftUpsample_torch( xyShift[0] ) - # build kernels: note factor signs and denominators match original numpy code - # colKern: shape (N, numCol) factor_col = -2j * math.pi / (N * float(upsampleFactor)) - # outer(col_freq, col_coords) -> shape (N, numCol) colKern = torch.exp(factor_col * (col_freq.unsqueeze(1) * col_coords.unsqueeze(0))).to( imageCorr.dtype ) - # rowKern: shape (numRow, M) factor_row = -2j * math.pi / (M * float(upsampleFactor)) - # outer(row_coords, row_freq) -> shape (numRow, M) rowKern = torch.exp(factor_row * (row_coords.unsqueeze(1) * row_freq.unsqueeze(0))).to( imageCorr.dtype ) - # perform the small-matrix DFT: (numRow, M) @ (M, N) @ (N, numCol) -> (numRow, numCol) imageUpsample = rowKern @ imageCorr @ colKern - # original code took xp.real(...) before returning return imageUpsample.real +def weighted_cross_correlation_shift( + im_ref=None, + im=None, + *, + cc=None, + weight_real=None, + upsample_factor: int = 1, + max_shift=None, + fft_input: bool = False, + fft_output: bool = False, + return_shifted_image: bool = False, +): + """ + Weighted peak selection + DFT subpixel refinement for Fourier cross-correlation. 
+ + Provide either: + - im_ref and im (real-space images, or Fourier-domain if fft_input=True), OR + - cc (the Fourier-domain cross-spectrum), where cc = F_ref * conj(F_im) + + The weight is applied ONLY in real-space correlation to choose the peak location, + but the subpixel refinement uses the true (unweighted) cross-spectrum `cc`. + + Returns + ------- + shift_rc : tuple[float, float] + (d_row, d_col) shift to apply to `im` to align it to `im_ref`. + shifted : ndarray (optional) + If return_shifted_image=True: shifted image. If fft_output=True returns FFT (corner-centered), + else returns real-space image. + """ + if cc is None: + if im_ref is None or im is None: + raise ValueError("Provide either `cc` or both `im_ref` and `im`.") + F_ref = np.asarray(im_ref) if fft_input else np.fft.fft2(np.asarray(im_ref)) + F_im = np.asarray(im) if fft_input else np.fft.fft2(np.asarray(im)) + cc = F_ref * np.conj(F_im) + else: + cc = np.asarray(cc) + F_im = None + + cc_real = np.fft.ifft2(cc).real + M, N = cc_real.shape + + if weight_real is not None: + w = np.asarray(weight_real) + if w.shape != cc_real.shape: + raise ValueError( + f"weight_real.shape={w.shape} must match correlation shape {cc_real.shape}."
+ ) + cc_pick = cc_real * w + else: + cc_pick = cc_real + + if max_shift is not None: + fr = np.fft.fftfreq(M) * M + fc = np.fft.fftfreq(N) * N + mask = fr[:, None] ** 2 + fc[None, :] ** 2 > float(max_shift) ** 2 + cc_pick = cc_pick.copy() + cc_pick[mask] = -np.inf + + flat_idx = int(np.argmax(cc_pick)) + x0 = flat_idx // N + y0 = flat_idx % N + + x_inds = [((x0 + dx) % M) for dx in (-1, 0, 1)] + y_inds = [((y0 + dy) % N) for dy in (-1, 0, 1)] + vx = cc_pick[x_inds, y0] + vy = cc_pick[x0, y_inds] + + dx = _parabolic_peak(vx) + dy = _parabolic_peak(vy) + + x0 = np.round((float(x0) + float(dx)) * 2.0) / 2.0 + y0 = np.round((float(y0) + float(dy)) * 2.0) / 2.0 + xy_shift = np.array([x0, y0], dtype=float) + + if upsample_factor > 2: + xy_shift = _upsampled_correlation_numpy(cc, int(upsample_factor), xy_shift) + + dr = ((xy_shift[0] + M / 2) % M) - M / 2 + dc = ((xy_shift[1] + N / 2) % N) - N / 2 + shift_rc = (float(dr), float(dc)) + + if not return_shifted_image: + return shift_rc + + if im is None: + raise ValueError( + "return_shifted_image=True requires `im` (or its FFT via fft_input=True)." + ) + + if F_im is None: + F_im = np.asarray(im) if fft_input else np.fft.fft2(np.asarray(im)) + + kr = np.fft.fftfreq(M)[:, None] + kc = np.fft.fftfreq(N)[None, :] + phase_ramp = np.exp(-2j * np.pi * (kr * shift_rc[0] + kc * shift_rc[1])) + F_im_shifted = F_im * phase_ramp + + if fft_output: + return shift_rc, F_im_shifted + return shift_rc, np.fft.ifft2(F_im_shifted).real + + def bilinear_kde( xa: NDArray, ya: NDArray, @@ -361,32 +452,6 @@ def bilinear_kde( ) -> NDArray | tuple[NDArray, NDArray]: """ Compute a bilinear kernel density estimate (KDE) with smooth threshold masking. - - Parameters - ---------- - xa : NDArray - Vertical (row) coordinates of input points. - ya : NDArray - Horizontal (col) coordinates of input points. - values : NDArray - Weights for each (xa, ya) point. - output_shape : tuple of int - Output image shape (rows, cols). 
- kde_sigma : float - Standard deviation of Gaussian KDE smoothing. - pad_value : float, default = 1.0 - Value to return when KDE support is too low. - threshold : float, default = 1e-3 - Minimum counts_KDE value for trusting the output signal. - lowpass_filter : bool, optional - If True, apply sinc-based inverse filtering to deconvolve the kernel. - max_batch_size : int or None, optional - Max number of points to process in one batch. - - Returns - ------- - NDArray - The estimated KDE image with threshold-masked output. """ rows, cols = output_shape xF = np.floor(xa.ravel()).astype(int) @@ -416,14 +481,12 @@ def bilinear_kde( inds_1D, weights=weights * w[start:end], minlength=rows * cols ) - # Reshape to 2D and apply Gaussian KDE pix_count = pix_count.reshape(output_shape) pix_output = pix_output.reshape(output_shape) pix_count = gaussian_filter(pix_count, kde_sigma) pix_output = gaussian_filter(pix_output, kde_sigma) - # Final image weight = np.minimum(pix_count / threshold, 1.0) image = pad_value * (1.0 - weight) + weight * (pix_output / np.maximum(pix_count, 1e-8)) @@ -455,23 +518,7 @@ def bilinear_array_interpolation( ) -> NDArray: """ Bilinear sampling of values from an array and pixel positions. 
- - Parameters - ---------- - image: np.ndarray - Image array to sample from - xa: np.ndarray - Vertical interpolation sampling positions of image array in pixels - ya: np.ndarray - Horizontal interpolation sampling positions of image array in pixels - - Returns - ------- - values: np.ndarray - Bilinear interpolation values of array at (xa,ya) positions - """ - xF = np.floor(xa.ravel()).astype("int") yF = np.floor(ya.ravel()).astype("int") dx = xa.ravel() - xF @@ -497,10 +544,7 @@ def bilinear_array_interpolation( values[start:end] += raveled_image[inds_1D] * weights - values = np.reshape( - values, - xa.shape, - ) + values = np.reshape(values, xa.shape) return values @@ -512,20 +556,7 @@ def fourier_cropping( """ Crops a corner-centered FFT array to retain only the lowest frequencies, equivalent to a center crop on the fftshifted version. - - Parameters: - ----------- - corner_centered_array : ndarray - 2D array (typically result of np.fft.fft2) with corner-centered DC - crop_shape : tuple of int - (height, width) of the desired cropped array (could be odd or even depending on arr.shape) - - Returns: - -------- - cropped : ndarray - Cropped array containing only the lowest frequencies, still corner-centered. """ - H, W = corner_centered_array.shape crop_h, crop_w = crop_shape @@ -536,13 +567,9 @@ def fourier_cropping( result = np.zeros(crop_shape, dtype=corner_centered_array.dtype) - # Top-left result[:h1, :w1] = corner_centered_array[:h1, :w1] - # Top-right result[:h1, -w2:] = corner_centered_array[:h1, -w2:] - # Bottom-left result[-h2:, :w1] = corner_centered_array[-h2:, :w1] - # Bottom-right result[-h2:, -w2:] = corner_centered_array[-h2:, -w2:] return result @@ -556,22 +583,6 @@ def compute_fsc_from_halfsets( """ Compute radially averaged Fourier Shell Correlation (FSC) from two half-set reconstructions. - - Parameters - ---------- - halfset_recons : list[torch.Tensor] - Two statistically-independent reconstructions, using half the dataset. 
- sampling: tuple[float,float] - Reconstruction sampling in Angstroms. - epsilon: float, optional - Small number to avoid dividing by zero - - Returns - ------- - q_bins: NDarray - Spatial frequency bins - fsc : NDarray - Fourier shell correlation as function of spatial frequency """ r1, r2 = halfset_recons @@ -601,12 +612,10 @@ def compute_fsc_from_halfsets( w0 = 1.0 - d_ind w1 = d_ind - # Flatten arrays cross = cross.reshape(-1) p1 = p1.reshape(-1) p2 = p2.reshape(-1) - # Accumulate cross_b = torch.bincount(inds_f, weights=cross * w0, minlength=num_bins) + torch.bincount( inds_f + 1, weights=cross * w1, minlength=num_bins ) @@ -636,45 +645,14 @@ def compute_spectral_snr_from_halfsets( ): """ Compute spectral SNR from two half-set reconstructions using symmetric/antisymmetric decomposition. - - The method decomposes the Fourier transforms into: - - Symmetric: (F₁ + F₂)/2 → signal + correlated noise - - Antisymmetric: (F₁ - F₂)/2 → uncorrelated noise only - - SSNR(q) = sqrt(signal_power / noise_power) - - where: - - signal_power = (|symmetric|² - |antisymmetric|²)₊ - - noise_power = |antisymmetric|² - - Parameters - ---------- - halfset_recons : list[torch.Tensor] - Two statistically-independent reconstructions, using half the dataset. - sampling: tuple[float,float] - Reconstruction sampling in Angstroms. - total_dose: float - Total _normalized_ electron dose, e.g. 
in DirectPtychography this is ~self.num_bf - epsilon: float, optional - Small number to avoid dividing by zero - - Returns - ------- - q_bins: NDarray - Spatial frequency bins - ssnr : NDarray - Radially averaged spectral SNR as function of spatial frequency """ - # Compute Fourier transforms halfset_1, halfset_2 = halfset_recons F1 = torch.fft.fft2(halfset_1) F2 = torch.fft.fft2(halfset_2) - # Symmetric and antisymmetric decomposition symmetric = (F1 + F2) / 2 antisymmetric = (F1 - F2) / 2 - # Power spectra noise_power = antisymmetric.abs() total_power = symmetric.abs() signal_power = (total_power - noise_power).clamp_min(0) @@ -698,11 +676,9 @@ def compute_spectral_snr_from_halfsets( w0 = 1.0 - d_ind w1 = d_ind - # Flatten arrays signal = signal_power.reshape(-1) noise = noise_power.reshape(-1) - # Accumulate signal_b = torch.bincount(inds_f, weights=signal * w0, minlength=num_bins) + torch.bincount( inds_f + 1, weights=signal * w1, minlength=num_bins ) @@ -725,20 +701,6 @@ def radially_average_fourier_array( ): """ Radially average a corner-centered Fourier array. - - Parameters - ---------- - corner_centered_array : list[torch.Tensor] - Fourier array to average radially. - sampling: tuple[float,float] - Reconstruction sampling in Angstroms. 
- - Returns - ------- - q_bins: NDarray - Spatial frequency bins - array_1d : NDarray - Radially averaged Fourier array as function of spatial frequency """ device = corner_centered_array.device nx, ny = corner_centered_array.shape @@ -759,10 +721,8 @@ def radially_average_fourier_array( w0 = 1.0 - d_ind w1 = d_ind - # Flatten arrays array = corner_centered_array.reshape(-1) - # Accumulate array_b = torch.bincount(inds_f, weights=array * w0, minlength=num_bins) + torch.bincount( inds_f + 1, weights=array * w1, minlength=num_bins ) @@ -841,9 +801,7 @@ def add_edges(i1, i2): inc = _find_wrap(phi_f[i1], phi_f[i2]) rel = rel_f[i1] + rel_f[i2] - edges.append( # ty:ignore[possibly-missing-attribute] - torch.stack([i1, i2, rel, inc], dim=1) - ) + edges.append(torch.stack([i1, i2, rel, inc], dim=1)) if wrap_around: add_edges(idx.flatten(), torch.roll(idx, -1, 1).flatten()) @@ -855,7 +813,6 @@ def add_edges(i1, i2): edges = torch.cat(edges, dim=0) edges = edges[edges[:, 2].argsort()] - # return integer tensors only (CPU) return ( edges[:, 0].long(), edges[:, 1].long(), @@ -884,7 +841,6 @@ def union(self, x, y, inc_xy): if rx == ry: return - # phase(y) + oy + inc = phase(x) + ox delta = ox - oy - inc_xy if self.rank[rx] < self.rank[ry]: @@ -962,18 +918,6 @@ def _unwrap_phase_2d_torch_poisson( ): """ Least-squares / Poisson phase unwrapping with optional mask. 
- - Parameters - ---------- - phi_wrapped : (H, W) tensor - Wrapped phase in (-pi, pi], any device - mask : (H, W) bool tensor, optional - True = valid pixel - - Returns - ------- - phi_unwrapped : (H, W) tensor - Unwrapped phase (same device as input) """ device = phi_wrapped.device dtype = phi_wrapped.dtype @@ -1013,10 +957,10 @@ def _unwrap_phase_2d_torch_poisson( denom = kx**2 + ky**2 + regularization_lambda else: denom = kx**2 + ky**2 - denom[0, 0] = 1.0 # avoid divide by zero + denom[0, 0] = 1.0 phi_hat = -div_hat / denom - phi_hat[0, 0] = 0.0 # fix piston + phi_hat[0, 0] = 0.0 phi = torch.fft.ifftn(phi_hat).real @@ -1165,3 +1109,67 @@ def radially_project_fourier_tensor( array_1d = array_1d[0] return q_bins_out, array_1d + + +def rotate_image( + im, + rotation_deg: float, + origin: tuple[float, float] | None = None, + clockwise: bool = True, + interpolation: str = "bilinear", + mode: str = "constant", + cval: float = 0.0, +): + """Rotate an array about a pixel origin using bilinear/bicubic interpolation.""" + im = np.asarray(im) + if im.ndim < 2: + raise ValueError("im must have at least 2 dimensions") + + H, W = im.shape[-2], im.shape[-1] + if origin is None: + r0 = float(H // 2) + c0 = float(W // 2) + else: + r0 = float(origin[0]) + c0 = float(origin[1]) + + interp = str(interpolation).lower() + if interp in {"bilinear", "linear"}: + order = 1 + elif interp in {"bicubic", "cubic"}: + order = 3 + else: + raise ValueError("interpolation must be 'bilinear' or 'bicubic'") + + theta = float(np.deg2rad(rotation_deg)) + if not clockwise: + theta = -theta + + ct = float(np.cos(theta)) + st = float(np.sin(theta)) + + r_out, c_out = np.meshgrid( + np.arange(H, dtype=np.float64), + np.arange(W, dtype=np.float64), + indexing="ij", + ) + + c_rel = c_out - c0 + r_rel = r_out - r0 + + c_in = ct * c_rel + st * r_rel + c0 + r_in = -st * c_rel + ct * r_rel + r0 + + coords = np.vstack((r_in.ravel(), c_in.ravel())) + + if im.ndim == 2: + out = map_coordinates(im, coords, 
order=order, mode=mode, cval=cval) + return out.reshape(H, W) + + prefix = im.shape[:-2] + n = int(np.prod(prefix)) if prefix else 1 + im_flat = im.reshape(n, H, W) + out_flat = np.empty((n, H * W), dtype=np.result_type(im_flat.dtype, np.float64)) + for i in range(n): + out_flat[i] = map_coordinates(im_flat[i], coords, order=order, mode=mode, cval=cval) + return out_flat.reshape(*prefix, H, W) diff --git a/src/quantem/diffraction/__init__.py b/src/quantem/diffraction/__init__.py index 4367682b..fed057f1 100644 --- a/src/quantem/diffraction/__init__.py +++ b/src/quantem/diffraction/__init__.py @@ -1 +1,6 @@ +from quantem.diffraction.polar import RDF as RDF +from quantem.diffraction.bragg_vectors import BraggVectors as BraggVectors +from quantem.diffraction.strain import StrainMap as StrainMap +from quantem.diffraction.strain_autocorrelation import StrainMapAutocorrelation as StrainMapAutocorrelation +from quantem.diffraction.maped import MAPED as MAPED from quantem.diffraction.model_fitting import ModelDiffraction as ModelDiffraction diff --git a/src/quantem/diffraction/bragg_vectors.py b/src/quantem/diffraction/bragg_vectors.py new file mode 100644 index 00000000..a2da2eba --- /dev/null +++ b/src/quantem/diffraction/bragg_vectors.py @@ -0,0 +1,1462 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal, Sequence, Union + +import numpy as np +import torch +from numpy.typing import NDArray + +from quantem.core.datastructures.dataset2d import Dataset2d +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.datastructures.vector import Vector +from quantem.core.io.serialize import AutoSerialize +from quantem.diffraction.bragg_vectors_visualization import ( + plot_basis_vectors, + plot_bvm, + plot_detection, + plot_diffraction_grid, + plot_lattice_fit, + plot_reference_lattice, + plot_template, +) +from quantem.diffraction.disk_detection import ( + cross_correlation, + detect_disks_batch, + 
estimate_central_beam, + make_template, + probe_centroid, + synthetic_probe, + template_fourier, +) +from quantem.diffraction.strain import StrainMap + +PEAK_FIELDS = ("q_row", "q_col", "intensity") + + +class BraggVectors(AutoSerialize): + """Correlation-based Bragg disk detection and lattice fitting for 4D-STEM. + + Workflow (each step writes state consumed by the next): + + 1. ``make_template_*`` – build a cross-correlation template, either from a + synthetic soft disk (:meth:`make_template_synthetic`), by averaging data + over an ROI (:meth:`make_template_from_data`), or from an explicit probe + image (:meth:`make_template_from_probe`). + 2. :meth:`detect_disks` – template-match every scan position; detected peaks + are stored in :attr:`peaks` (a :class:`Vector` of ``[q_row, q_col, + intensity]``, numpy-backed) and accumulated into the Bragg vector map + :attr:`bvm`. + 3. :meth:`choose_basis_vectors` – pick the lattice basis ``(origin, g1, g2)`` + from the BVM, automatically or by hand; also stores the numbered candidate + peaks. + 4. :meth:`index_peaks` – quick, lightweight: index just the picked candidate + peaks into a reference lattice (``reference_ab``/``reference_qpos``). + 5. :meth:`fit_lattice` – the heavy step: at every scan position, match the + detections to the reference within ``max_peak_shift``, intensity-weighted + least-squares fit the lattice vectors into ``u_array``/``v_array`` of shape + ``(scan_row, scan_col, 2)``, and compute the per-position ``mask_weight``. + 6. :meth:`calculate_strain_map` – hand the lattice vectors (and + ``mask_weight``) to a :class:`~quantem.diffraction.strain.StrainMap`. + + Detection runs in torch (CPU now, CUDA later); the ragged peak table is held + in a numpy-backed :class:`Vector`. The detector→scan rotation is read from + the parent dataset metadata (``q_to_r_rotation_ccw_deg`` + ``q_transpose``), + the single source of truth shared with the DPC/CoM workflow. 
@classmethod
def from_dataset(
    cls, dataset: Dataset4dstem, *, device: str = "cpu", name: str | None = None
) -> "BraggVectors":
    """Bind a new BraggVectors workflow to an existing 4D-STEM dataset.

    This is the supported way to build an instance: it forwards the private
    class token that ``__init__`` checks, so a direct ``BraggVectors(...)``
    call without the token is rejected.

    Parameters
    ----------
    dataset : Dataset4dstem
        The 4D-STEM dataset to analyze. It is stored on the workflow by
        reference (not copied).
    device : str, default="cpu"
        Torch device used for detection (e.g. ``"cpu"`` or ``"cuda"``).
    name : str, optional
        If given, assigned to ``dataset.name``. Note that this mutates the
        dataset object passed in.

    Returns
    -------
    BraggVectors
        A new workflow instance bound to ``dataset``.

    Raises
    ------
    TypeError
        If ``dataset`` is not a ``Dataset4dstem`` instance.
    """
    if isinstance(dataset, Dataset4dstem):
        # Rename first so the workflow is bound to the already-renamed dataset.
        if name is not None:
            dataset.name = name
        workflow = cls(dataset=dataset, device=device, _token=cls._token)
        return workflow
    raise TypeError("BraggVectors.from_dataset expects a Dataset4dstem instance.")
def make_template_synthetic(
    self,
    radius: float | None = None,
    edge: float = 1.0,
    center: tuple[float, float] | None = None,
    subtract_mean: bool = True,
) -> "BraggVectors":
    """Use a synthetic soft-edged disk as the correlation template.

    Parameters
    ----------
    radius : float, optional
        Disk radius in pixels. When omitted, the radius returned by
        :func:`~quantem.diffraction.disk_detection.estimate_central_beam` on
        the dataset's mean diffraction pattern is used; pass it explicitly
        when several disks share comparable intensity.
    edge : float, default=1.0
        Width in pixels of the ``tanh`` edge falloff.
    center : tuple of float, optional
        ``(row, col)`` disk center; defaults to the detector center
        ``(H // 2, W // 2)``.
    subtract_mean : bool, default=True
        If ``True``, make the template zero-sum — a band-pass kernel that
        suppresses uniform background in the correlation.

    Returns
    -------
    BraggVectors
        ``self``, for method chaining. The chosen ``radius``, ``edge`` and
        ``center`` are recorded under ``metadata["template"]``.
    """
    # detector (diffraction-pattern) shape = last two dataset axes
    det_shape = (int(self.dataset.shape[-2]), int(self.dataset.shape[-1]))

    if radius is None:
        mean_pattern = np.asarray(self.dataset.dp_mean.array)
        mean_pattern_t = torch.as_tensor(
            mean_pattern, dtype=torch.float, device=self.device
        )
        # only the radius estimate is needed; the estimated center is discarded
        _center_est, radius = estimate_central_beam(mean_pattern_t)

    if center is None:
        center = (det_shape[0] // 2, det_shape[1] // 2)

    radius_px = float(radius)
    disk = synthetic_probe(det_shape, radius_px, edge=edge, center=center)
    self._set_template(disk, center=center, subtract_mean=subtract_mean)

    center_rc = (float(center[0]), float(center[1]))
    self.metadata["template"] = dict(
        kind="synthetic",
        radius=radius_px,
        edge=float(edge),
        center=center_rc,
    )
    return self
+ """ + data = torch.as_tensor( + np.asarray(self.dataset.array), dtype=torch.float, device=self.device + ) + if roi is None: + probe = data.mean(dim=(0, 1)) + else: + m = torch.as_tensor(np.asarray(roi) > 0, device=self.device) + if not bool(m.any()): + raise ValueError("roi selects no scan positions.") + probe = data[m].mean(dim=0) + if center is None: + center = probe_centroid(probe) + self._set_template(probe, center=center, subtract_mean=subtract_mean) + self.metadata["template"] = { + "kind": "from_data", + "roi": roi is not None, + "center": (float(center[0]), float(center[1])), + } + return self + + def make_template_from_probe( + self, + probe: NDArray | torch.Tensor, + center: tuple[float, float] | None = None, + subtract_mean: bool = True, + ) -> "BraggVectors": + """Build the template from an explicit probe image (e.g. a measured vacuum probe). + + Parameters + ---------- + probe : np.ndarray or torch.Tensor + Probe image; must match the diffraction-pattern shape. + center : tuple of float, optional + ``(row, col)`` probe center rolled to the origin; defaults to the + probe's intensity centroid. + subtract_mean : bool, default=True + If ``True``, make the template zero-sum — a band-pass kernel that + suppresses uniform background in the correlation. + + Returns + ------- + BraggVectors + ``self``, for method chaining. + """ + probe_t = torch.as_tensor(probe, dtype=torch.float, device=self.device) + if center is None and tuple(probe_t.shape) == tuple(self.dataset.shape[-2:]): + center = probe_centroid(probe_t) + self._set_template(probe_t, center=center, subtract_mean=subtract_mean) + self.metadata["template"] = { + "kind": "from_probe", + "center": (float(center[0]), float(center[1])), + } + return self + + @property + def template(self) -> np.ndarray | None: + """The correlation template, fftshifted to the image center for display (numpy). 
def correlation_map(self, row: int, col: int) -> np.ndarray:
    """Correlate a single diffraction pattern with the template (numpy result).

    The pattern at scan position ``(row, col)`` is converted to a float
    tensor on ``self.device`` and passed, together with the stored template
    Fourier transform, to ``cross_correlation``. Peaks in the returned map
    sit at absolute disk positions (no fftshift needed), matching what
    :meth:`detect_disks` searches.

    Parameters
    ----------
    row : int
        Scan row of the diffraction pattern to correlate.
    col : int
        Scan column of the diffraction pattern to correlate.

    Returns
    -------
    np.ndarray
        ``(H, W)`` real-space correlation map.

    Raises
    ------
    ValueError
        If no template has been built yet.
    """
    template_ft = self._template_ft
    if template_ft is None:
        raise ValueError("Run a make_template_* method before correlation_map().")

    pattern = np.asarray(self.dataset.array[row, col])
    pattern_t = torch.as_tensor(pattern, dtype=torch.float, device=self.device)

    # the second return value of cross_correlation is not needed here
    corr_map, _unused = cross_correlation(pattern_t, template_ft)
    return corr_map.detach().cpu().numpy()
+ With ``positions=None`` the full scan is processed, :attr:`peaks` and + :attr:`bvm` are populated, and the same Vector is returned. Patterns are + processed in batches (the FFTs and subpixel refinement run together across + the batch, which is far faster on a GPU); results are identical to detecting + each pattern on its own. + + Parameters + ---------- + positions : list of tuple of int, optional + ``(row, col)`` scan positions to test on. ``None`` (default) processes + the full scan and updates the workflow state. + min_abs_intensity : float, default=0.0 + Drop correlation peaks below this absolute intensity. + min_spacing : float, default=0.0 + Minimum spacing in pixels between kept peaks; closer / dimmer peaks are + suppressed. + edge_boundary : int, default=1 + Width in pixels of the border in which peaks are ignored. + subpixel : {"none", "parabolic", "upsample"}, default="upsample" + Subpixel refinement mode; see + :func:`~quantem.diffraction.disk_detection.detect_disks`. + upsample_factor : int, default=16 + Upsampling factor for the ``"upsample"`` subpixel refinement. + max_num_peaks : int, default=1000 + Maximum number of peaks to keep per pattern. + batch_size : int, optional + Number of patterns per batch. ``None`` (default) picks a size from the + detector dimensions. + progressbar : bool, default=True + If ``True``, show a tqdm progress bar over the full-scan detection. + + Returns + ------- + Vector + Detected peaks (``[q_row, q_col, intensity]``): shape ``(scan_row, + scan_col)`` for the full scan, or ``(len(positions),)`` for a test run. 
def compute_bvm(self, sampling: float = 1.0) -> Dataset2d:
    """Histogram every detected peak's intensity into a Bragg vector map.

    Each peak is rounded to its nearest detector pixel (indices clipped to
    the detector bounds) and its intensity is added to that pixel, so
    repeated hits on one pixel accumulate.

    Parameters
    ----------
    sampling : float, default=1.0
        Reciprocal-space sampling (per pixel) stored on the returned dataset
        for both axes.

    Returns
    -------
    Dataset2d
        ``(H, W)`` Bragg vector map, also stored on :attr:`bvm`.

    Raises
    ------
    ValueError
        If :meth:`detect_disks` has not populated :attr:`peaks` yet.
    """
    if self.peaks is None:
        raise ValueError("Run detect_disks() before compute_bvm().")

    n_rows = int(self.dataset.shape[-2])
    n_cols = int(self.dataset.shape[-1])
    # one row per detected peak, columns ordered (q_row, q_col, intensity)
    table = self.peaks.select_fields("q_row", "q_col", "intensity").flatten()

    def _pixel_index(values, size):
        # nearest-pixel index, clamped so out-of-range peaks land on the border
        return np.clip(np.round(values).astype(int), 0, size - 1)

    accum = np.zeros((n_rows, n_cols), dtype=float)
    n_peaks = table.shape[0]
    if n_peaks > 0:
        pix_r = _pixel_index(table[:, 0], n_rows)
        pix_c = _pixel_index(table[:, 1], n_cols)
        # np.add.at is unbuffered: duplicate (row, col) pairs each contribute
        np.add.at(accum, (pix_r, pix_c), table[:, 2])

    self.bvm = Dataset2d.from_array(
        accum,
        name="bragg_vector_map",
        sampling=(sampling, sampling),
        signal_units="intensity",
    )
    return self.bvm
+ + With ``plot=True`` (default) the Bragg vector map is shown with the + candidate peaks numbered and the chosen basis overlaid, so the index to + pass back here can be read straight off the figure. + + Parameters + ---------- + origin : int or tuple of float or np.ndarray, optional + Candidate index, or absolute ``(row, col)`` lattice origin. Picked + automatically if omitted. + g1 : int or tuple of float or np.ndarray, optional + Candidate index (vector taken as ``peak - origin``), or a ``(row, col)`` + offset *from the origin*. Picked automatically if omitted. + g2 : int or tuple of float or np.ndarray, optional + Candidate index (vector taken as ``peak - origin``), or a ``(row, col)`` + offset *from the origin*. Picked automatically if omitted. + num_candidates : int, default=100 + Number of brightest candidate peaks to consider (and to number on the + plot). + min_spacing : float, default=2.0 + Minimum spacing in pixels between candidate peaks. + min_abs_intensity : float, default=0.0 + Drop candidate peaks below this absolute intensity. + plot : bool, default=True + If ``True``, show the Bragg vector map with the numbered candidates and + the chosen basis overlaid via + :func:`~quantem.diffraction.bragg_vectors_visualization.plot_basis_vectors`. + returnfig : bool, default=False + If ``True``, return ``(fig, ax)`` from the overlay plot instead of + ``self`` (implies ``plot=True``). + **show_kwargs + Display-scaling options forwarded to the overlay plot's + :func:`~quantem.core.visualization.show_2d` call (e.g. ``norm``, + ``vmin``, ``vmax``, ``cmap``, ``lower_quantile``, ``upper_quantile``). + + Returns + ------- + BraggVectors or tuple + ``self`` for method chaining; or ``(fig, ax)`` when ``returnfig=True``. 
+ """ + if self.bvm is None: + raise ValueError("Run detect_disks()/compute_bvm() before choosing basis vectors.") + + cand_rc, cand_int = self._bvm_candidates(num_candidates, min_spacing, min_abs_intensity) + if cand_rc.shape[0] == 0: + raise RuntimeError("No candidate peaks found in the Bragg vector map.") + # remember the numbered candidates so index_peaks() reuses this exact set. + self.candidates_rc = cand_rc + self.candidates_intensity = cand_int + + def _candidate(idx: int) -> NDArray: + i = int(idx) + if not -cand_rc.shape[0] <= i < cand_rc.shape[0]: + raise IndexError( + f"candidate index {i} out of range for {cand_rc.shape[0]} candidates " + f"(increase num_candidates or loosen min_spacing/min_abs_intensity)." + ) + return cand_rc[i] + + if origin is None: + # candidates are returned brightest-first, so [0] is the central beam. + origin_rc = cand_rc[0] + elif np.ndim(origin) == 0: + origin_rc = _candidate(origin) + else: + origin_rc = np.asarray(origin, dtype=float).reshape(2) + + rel = cand_rc - origin_rc + dist = np.linalg.norm(rel, axis=1) + valid = dist > max(1e-6, min_spacing) + # "shortest/brightest": reward bright peaks close to the origin. + quality = cand_int / (dist + 1e-12) + + if g1 is None: + if not valid.any(): + raise RuntimeError("Could not find a g1 candidate distinct from the origin.") + g1_rc = rel[int(np.argmax(np.where(valid, quality, -np.inf)))] + elif np.ndim(g1) == 0: + g1_rc = _candidate(g1) - origin_rc + else: + g1_rc = np.asarray(g1, dtype=float).reshape(2) + + if g2 is None: + # g2 = highest quality * sin^2(theta) where theta is the angle to g1; + # sin^2 = 1 - cos^2 vanishes for peaks collinear with g1. 
+ g1n = g1_rc / (np.linalg.norm(g1_rc) + 1e-12) + cos = rel @ g1n / (dist + 1e-12) + sin2 = np.clip(1.0 - cos**2, 0.0, 1.0) + g2_score = np.where(valid, quality * sin2, -np.inf) + if g2_score.max() <= 0.0: + raise RuntimeError("Could not find a g2 candidate non-collinear with g1.") + g2_rc = rel[int(np.argmax(g2_score))] + elif np.ndim(g2) == 0: + g2_rc = _candidate(g2) - origin_rc + else: + g2_rc = np.asarray(g2, dtype=float).reshape(2) + + self.origin = np.asarray(origin_rc, dtype=float).reshape(2) + self.g1 = np.asarray(g1_rc, dtype=float).reshape(2) + self.g2 = np.asarray(g2_rc, dtype=float).reshape(2) + + if plot or returnfig: + fig, ax = plot_basis_vectors( + np.asarray(self.bvm.array), + cand_rc, + cand_int, + self.origin, + self.g1, + self.g2, + **show_kwargs, + ) + if returnfig: + return fig, ax + return self + + def index_peaks( + self, + *, + plot: bool = True, + returnfig: bool = False, + **show_kwargs, + ): + """Index the chosen candidate peaks into a reference lattice. + + Assigns integer Miller indices ``(a, b)`` to the numbered candidate peaks + picked in :meth:`choose_basis_vectors` — not the full per-position + detections; that heavy lifting happens in :meth:`fit_lattice`. Each + candidate gets ``[a, b] = round(B^-1 (q - origin))`` with ``B`` columns + ``g1, g2``; when two candidates round to the same ``(a, b)`` only the + brightest is kept. The result is a compact reference lattice stored on + :attr:`reference_ab` / :attr:`reference_qpos` / :attr:`reference_intensity` + which :meth:`fit_lattice` matches against at every scan position. This step + is quick and lightweight. + + With ``plot=True`` (default) the reference lattice is drawn over the Bragg + vector map, each site ringed and labelled with its ``(a, b)`` index. The ring + color encodes how far the picked candidate sits from its ideal lattice site + ``origin + a*g1 + b*g2``, so a mis-picked or duplicate candidate (which rings + far from zero offset) is easy to spot. 
def fit_lattice(
    self,
    min_num_peaks: int = 5,
    max_peak_shift: float | None = None,
    *,
    progressbar: bool = True,
    plot: bool = True,
    returnfig: bool = False,
):
    """Per-position weighted least-squares fit of the lattice vectors.

    This is the heavy step. At every scan position the detected peaks are
    matched to the *ideal* reference lattice sites from :meth:`index_peaks` —
    ``origin + a*g1 + b*g2`` for each reference ``(a, b)`` — keeping a peak
    only when it lands within ``max_peak_shift`` of its nearest ideal site
    (not the measured candidate position). ``q = x0 + a*g1 + b*g2`` is then
    fit by intensity-weighted least squares over the matched peaks. The
    fitted ``g1``/``g2`` go into :attr:`u_array`/:attr:`v_array` (shape
    ``(scan_row, scan_col, 2)``, row/col components). Positions with no
    detections, with fewer than ``min_num_peaks`` matched peaks, with a
    failed fit, or whose fitted basis is singular are left ``nan``.

    Two diagnostics are stored per position. :attr:`fit_error` is the RMS fit
    residual over the *matched* peaks, as returned by the fit helper.
    :attr:`mask_weight` is a lattice *order parameter* in ``0``–``1``: every
    detected peak is snapped to the nearest site of the just-fitted lattice
    and the intensity-weighted RMS of those displacements (the zero beam
    excluded) is normalized by ``rms_rand = sqrt(|g1 x g2| / (2*pi))``,
    computed from the *reference* basis — the RMS displacement expected from
    intensity scattered at random with no lattice — as
    ``1 - RMS_all / rms_rand`` clipped to ``[0, 1]``. Because it weighs *all*
    detected intensity against the lattice (not just the matched peaks, as
    :attr:`fit_error` does), a clean single crystal approaches ``1`` while
    positions with strong off-lattice intensity (a second grain, a mis-index,
    many spurious peaks) fall toward ``0``; weak false positives carry little
    intensity and barely move it. Positions with no valid fit are ``0``. It is
    the default reference weighting handed to :meth:`calculate_strain_map`.

    Parameters
    ----------
    min_num_peaks : int, default=5
        Minimum number of matched peaks required to fit a position; positions
        with fewer are left ``nan``.
    max_peak_shift : float, optional
        Inclusion radius in pixels: a detected peak is kept only if it lands
        within this distance of its nearest *ideal* reference lattice site
        (``origin + a*g1 + b*g2``). Defaults to ``0.5 * min(|g1|, |g2|)`` —
        half the shorter lattice spacing.
    progressbar : bool, default=True
        If ``True``, show a tqdm progress bar over the scan positions (skipped
        silently when tqdm cannot be set up).
    plot : bool, default=True
        If ``True``, show the fit diagnostics (mask weight + RMS error) via
        :func:`~quantem.diffraction.bragg_vectors_visualization.plot_lattice_fit`.
    returnfig : bool, default=False
        If ``True``, return ``(fig, ax)`` instead of ``self`` (implies ``plot``).

    Returns
    -------
    BraggVectors or tuple
        ``self`` for method chaining; or ``(fig, ax)`` when ``returnfig=True``.

    Raises
    ------
    ValueError
        If :meth:`detect_disks`, :meth:`choose_basis_vectors` or
        :meth:`index_peaks` has not been run yet.
    """
    if self.peaks is None:
        raise ValueError("Run detect_disks() before fit_lattice().")
    if self.origin is None or self.g1 is None or self.g2 is None:
        # Without this guard the reshape below fails with an unrelated
        # "cannot reshape array" message.
        raise ValueError("Run choose_basis_vectors() before fit_lattice().")
    if self.reference_ab is None or self.reference_qpos is None:
        raise ValueError("Run index_peaks() before fit_lattice().")

    ref_ab = self.reference_ab.astype(float)

    # Ideal lattice sites for the reference (a, b) set: origin + a*g1 + b*g2.
    # Per-position detections are matched to these *ideal* points (not the
    # measured candidate positions), so a peak is included only when it lands
    # within max_peak_shift of where the reference lattice puts it.
    o = np.asarray(self.origin, dtype=float).reshape(2)
    g1 = np.asarray(self.g1, dtype=float).reshape(2)
    g2 = np.asarray(self.g2, dtype=float).reshape(2)
    ideal_qpos = o[None, :] + ref_ab[:, 0:1] * g1[None, :] + ref_ab[:, 1:2] * g2[None, :]

    if max_peak_shift is None:
        max_peak_shift = 0.5 * float(min(np.linalg.norm(g1), np.linalg.norm(g2)))

    # Normalization scale for the mask weight: the intensity-weighted RMS
    # peak-to-nearest-site displacement expected from intensity scattered at
    # random (uniformly over a unit cell of area |g1 x g2|) -- i.e. with no
    # lattice order at all. sqrt(A / 2pi) is the equal-area-disk value; the
    # mask weight is then 1 - RMS_all / rms_rand, an order parameter in [0, 1].
    cell_area = abs(float(g1[0] * g2[1] - g1[1] * g2[0]))
    rms_rand = float(np.sqrt(cell_area / (2.0 * np.pi))) if cell_area > 0 else 1.0

    scan_r, scan_c = int(self.dataset.shape[0]), int(self.dataset.shape[1])
    u_array = np.full((scan_r, scan_c, 2), np.nan, dtype=float)
    v_array = np.full((scan_r, scan_c, 2), np.nan, dtype=float)
    mask_weight = np.zeros((scan_r, scan_c), dtype=float)
    fit_error = np.full((scan_r, scan_c), np.nan, dtype=float)

    fields = self.peaks.fields
    i_qr, i_qc = fields.index("q_row"), fields.index("q_col")
    i_int = fields.index("intensity")

    coords: Any = list(np.ndindex(scan_r, scan_c))
    if progressbar:
        # Deliberately best-effort: the progress bar is cosmetic, so any
        # failure to set it up falls back to the plain coordinate list.
        try:
            from tqdm.auto import tqdm

            coords = tqdm(coords, desc="fit_lattice", leave=True)
        except Exception:
            pass

    for r, c in coords:
        cell = self.peaks[r, c].array
        if cell.shape[0] == 0:
            continue
        qpos = cell[:, [i_qr, i_qc]]
        inten = cell[:, i_int]

        # nearest *ideal* lattice site for each detected peak, kept if close enough
        d = np.linalg.norm(qpos[:, None, :] - ideal_qpos[None, :, :], axis=2)
        nearest = np.argmin(d, axis=1)
        matched = d[np.arange(d.shape[0]), nearest] <= max_peak_shift

        if int(matched.sum()) < min_num_peaks:
            continue

        beta, rms = _fit_lattice_vectors(
            qpos[matched, 0],
            qpos[matched, 1],
            ref_ab[nearest[matched], 0],
            ref_ab[nearest[matched], 1],
            inten[matched],
        )
        if beta is None:
            continue

        # beta = [x0, g1, g2]; rows of `mat` are the fitted g1, g2.
        x0 = beta[0]
        mat = np.stack([beta[1], beta[2]])  # (2, 2)
        try:
            mat_inv = np.linalg.inv(mat)
        except np.linalg.LinAlgError:
            # Singular fitted basis (e.g. every matched peak on one lattice
            # row): not a usable 2D lattice. Treat the position as unfitted
            # rather than aborting the whole scan and losing all results.
            continue

        u_array[r, c] = beta[1]
        v_array[r, c] = beta[2]
        fit_error[r, c] = rms

        # mask weight = lattice "order parameter": snap EVERY detected peak to
        # the nearest site of the just-fitted lattice and take the
        # intensity-weighted RMS of those displacements, excluding the zero
        # beam. Unlike fit_error (matched peaks only), this sees OFF-lattice
        # intensity -- a second grain, a mis-index, or many spurious peaks
        # drive it up -- while weak false positives, carrying little
        # intensity, barely move it. Normalized by rms_rand: 1 (all intensity
        # on the lattice) down to 0 (scattered as if there were no lattice).
        ab = np.rint((qpos - x0[None, :]) @ mat_inv)
        sites = x0[None, :] + ab @ mat
        disp = np.linalg.norm(qpos - sites, axis=1)
        nonzero = ~np.all(ab == 0, axis=1)  # drop the central (zero) beam
        w = np.clip(inten[nonzero], 0.0, None)
        wsum = float(w.sum())
        if wsum > 0:
            rms_all = float(np.sqrt(np.sum(w * disp[nonzero] ** 2) / wsum))
            # np.clip passes NaN through; keep the documented 0..1 range by
            # leaving the weight at 0 when the RMS is not finite.
            if np.isfinite(rms_all):
                mask_weight[r, c] = float(np.clip(1.0 - rms_all / rms_rand, 0.0, 1.0))

    self.u_array = u_array
    self.v_array = v_array
    self.mask_weight = mask_weight
    self.fit_error = fit_error
    self.metadata["fit"] = {
        "min_num_peaks": int(min_num_peaks),
        "max_peak_shift": float(max_peak_shift),
    }

    if not (plot or returnfig):
        return self

    fig, ax = plot_lattice_fit(mask_weight, fit_error)
    if returnfig:
        return fig, ax
    return self
def show_template(
    self,
    position: tuple[int, int] = (0, 0),
    *,
    crop_factor: float | None = None,
    returnfig: bool = False,
    **kwargs,
):
    """Show the mean diffraction pattern, the template, and one correlation map.

    Parameters
    ----------
    position : tuple of int, default=(0, 0)
        ``(row, col)`` scan position whose correlation map is shown.
    crop_factor : float, optional
        If given, a crop ``(center_row, center_col, crop_factor * radius)`` is
        passed to the plot, where the center comes from
        :func:`~quantem.diffraction.disk_detection.estimate_central_beam` on
        the mean diffraction pattern and ``radius`` is the template radius
        recorded in ``metadata["template"]`` if present, else the radius
        estimated by the same call. For example, ``crop_factor=2.0`` requests
        two beam radii either side of the center. ``None`` (default) shows the
        full panels.
    returnfig : bool, default=False
        If ``True``, return the ``(fig, ax)`` for further customization.
    **kwargs
        Extra keyword arguments forwarded to
        :func:`~quantem.diffraction.bragg_vectors_visualization.plot_template`.

    Returns
    -------
    tuple or None
        ``(fig, ax)`` when ``returnfig=True``; otherwise ``None``.

    Raises
    ------
    ValueError
        If no template has been built yet.
    """
    if self._template is None:
        raise ValueError("Run a make_template_* method before show_template().")

    scan_rc = (int(position[0]), int(position[1]))
    mean_pattern = np.asarray(self.dataset.dp_mean.array)

    if crop_factor is None:
        crop = None
    else:
        # Crop about the measured central-beam position rather than the
        # geometric detector center: the unscattered beam (and its correlation
        # peak) is generally a few pixels off-center.
        beam_center, estimated_radius = estimate_central_beam(mean_pattern)
        recorded_radius = self.metadata.get("template", {}).get("radius")
        beam_radius = estimated_radius if recorded_radius is None else recorded_radius
        half_width = float(crop_factor) * float(beam_radius)
        crop = (float(beam_center[0]), float(beam_center[1]), half_width)

    template_img = self.template
    corr_img = self.correlation_map(scan_rc[0], scan_rc[1])
    fig, ax = plot_template(
        mean_pattern,
        template_img,
        corr_img,
        scan_rc,
        crop=crop,
        **kwargs,
    )
    if not returnfig:
        return None
    return fig, ax
+ image : np.ndarray or Dataset2d, optional + Real-space navigation image (e.g. a virtual dark-field image) shown on + the left with the positions marked. ``None`` (default) shows only the + diffraction tiles. + ncols : int, default=4 + Number of diffraction tiles per row. + image_kwargs : dict, optional + Keyword arguments styling the navigation image (passed to + :func:`~quantem.core.visualization.show_2d`). + marker_radius : float, optional + Radius in image pixels of the scan-position marker rings. + linewidth : float, default=0.5 + Stroke width of the scan-position markers. + sigma_plot : float, optional + Gaussian blur (sigma) applied to the *displayed* patterns only. + returnfig : bool, default=False + If ``True``, return the ``(fig, ax)`` for further customization. + **show_kwargs + Extra keyword arguments (e.g. ``norm``, ``cmap``, ``cbar``, ``axsize``) + styling the diffraction tiles via + :func:`~quantem.core.visualization.show_2d`. + + Returns + ------- + tuple + ``(fig, ax)`` when ``returnfig=True``; otherwise nothing. 
def show_detection(
    self,
    positions: list[tuple[int, int]] | None = None,
    *,
    min_abs_intensity: float = 0.0,
    min_spacing: float = 0.0,
    edge_boundary: int = 1,
    subpixel: str = "upsample",
    upsample_factor: int = 16,
    max_num_peaks: int = 1000,
    image: np.ndarray | None = None,
    peak_radius: float = 6.0,
    marker_radius: float | None = None,
    linewidth: float = 1.0,
    sigma_plot: float | None = None,
    image_kwargs: dict | None = None,
    returnfig: bool = False,
    **plot_kwargs,
):
    """Run detection on a handful of patterns and draw the peaks on top of them.

    Intended for tuning the detection hyperparameters. Detection is delegated
    to :meth:`detect_disks` (called with ``positions`` and the progress bar
    disabled), so the detection keywords here mirror that method. The workflow
    state (:attr:`peaks`/:attr:`bvm`) is left untouched, which makes repeated
    calls safe while tuning; call :meth:`detect_disks` with the same
    ``positions`` to get the raw peaks.

    Parameters
    ----------
    positions : list of tuple of int, optional
        ``(row, col)`` scan positions to detect on. ``None`` (default) uses the
        positions returned by :meth:`_sample_positions`.
    min_abs_intensity : float, default=0.0
        Drop correlation peaks below this absolute intensity.
    min_spacing : float, default=0.0
        Minimum spacing in pixels between kept peaks.
    edge_boundary : int, default=1
        Width in pixels of the border in which peaks are ignored.
    subpixel : {"none", "parabolic", "upsample"}, default="upsample"
        Subpixel refinement mode; see :meth:`detect_disks`.
    upsample_factor : int, default=16
        Upsampling factor for the ``"upsample"`` subpixel refinement.
    max_num_peaks : int, default=1000
        Maximum number of peaks to keep per pattern.
    image : np.ndarray or Dataset2d, optional
        Real-space navigation image shown on the left with the chosen
        positions marked. An object exposing ``.array`` is unwrapped, and its
        ``name`` attribute (when truthy) becomes the panel title. ``None``
        (default) shows only the diffraction tiles.
    peak_radius : float, default=6.0
        Radius in diffraction pixels of the rings drawn at detected peaks.
    marker_radius : float, optional
        Radius in image pixels of the scan-position marker rings.
    linewidth : float, default=1.0
        Stroke width of the peak and scan-position rings.
    sigma_plot : float, optional
        Gaussian blur (sigma) applied to the *displayed* patterns only;
        detection still uses the raw data.
    image_kwargs : dict, optional
        Keyword arguments styling the navigation image.
    returnfig : bool, default=False
        If ``True``, return the ``(fig, ax)`` for further customization.
    **plot_kwargs
        Extra keyword arguments forwarded to
        :func:`~quantem.diffraction.bragg_vectors_visualization.plot_detection`
        (e.g. ``ncols``, ``norm``, ``cmap``, ``cbar``, ``axsize``).

    Returns
    -------
    tuple
        ``(fig, ax)`` when ``returnfig=True``; otherwise nothing.
    """
    default_title = "virtual image"
    scan_positions = self._sample_positions() if positions is None else positions

    detected = self.detect_disks(
        positions=scan_positions,
        min_abs_intensity=min_abs_intensity,
        min_spacing=min_spacing,
        edge_boundary=edge_boundary,
        subpixel=subpixel,
        upsample_factor=upsample_factor,
        max_num_peaks=max_num_peaks,
        progressbar=False,
    )

    nav_array = None
    nav_title = default_title
    if image is not None:
        # Prefer the image's own name for the title; accept either a wrapper
        # exposing ``.array`` or a plain array-like.
        nav_title = getattr(image, "name", None) or default_title
        nav_source = image.array if hasattr(image, "array") else image
        nav_array = np.asarray(nav_source)

    patterns = [
        np.asarray(self.dataset.array[row, col], dtype=float) for row, col in scan_positions
    ]
    # One peak table per requested position, in the same order.
    peak_tables = [detected[k].array for k in range(len(scan_positions))]

    figure, axes = plot_detection(
        nav_array,
        patterns,
        peak_tables,
        scan_positions,
        peak_radius=peak_radius,
        marker_radius=marker_radius,
        linewidth=linewidth,
        sigma_plot=sigma_plot,
        image_title=nav_title,
        image_kwargs=image_kwargs,
        **plot_kwargs,
    )
    return (figure, axes) if returnfig else None
def _sample_positions(self) -> list[tuple[int, int]]:
    """Pick scan positions at the quarter and three-quarter marks of each scan axis.

    Returns
    -------
    list of tuple of int
        Up to four ``(row, col)`` scan positions, row-major ordered. Fewer than
        four are returned when the two marks of an axis coincide (e.g. an axis
        of length 1).
    """

    def quarter_marks(extent: int) -> list[int]:
        # 1/4 and 3/4 of the axis, clamped into [0, extent - 1]; the set drops
        # duplicates so tiny axes do not yield repeated positions.
        last = extent - 1
        return sorted({min(max(k * extent // 4, 0), last) for k in (1, 3)})

    n_rows = int(self.dataset.shape[0])
    n_cols = int(self.dataset.shape[1])
    row_marks = quarter_marks(n_rows)
    col_marks = quarter_marks(n_cols)

    sampled = []
    for row in row_marks:
        for col in col_marks:
            sampled.append((row, col))
    return sampled
+ + Patterns are stacked into chunks of ``batch_size`` and passed to + :func:`~quantem.diffraction.disk_detection.detect_disks_batch`. + + Parameters + ---------- + coords : list of tuple of int + ``(row, col)`` scan positions to detect on. + detect_kwargs : dict + Keyword arguments forwarded to + :func:`~quantem.diffraction.disk_detection.detect_disks_batch`. + batch_size : int or None + Number of patterns per batch. ``None`` picks a size from the detector + dimensions. + progressbar : bool + If ``True``, show a tqdm progress bar over the patterns. + + Returns + ------- + list of np.ndarray + One ``(M, 3)`` array of ``[q_row, q_col, intensity]`` per position, in + ``coords`` order. + """ + H, W = int(self.dataset.shape[-2]), int(self.dataset.shape[-1]) + if batch_size is None: + batch_size = int(min(1024, max(1, 16_000_000 // (H * W)))) + + it = range(0, len(coords), batch_size) + if progressbar: + try: + from tqdm.auto import tqdm + + bar = tqdm(total=len(coords), desc="detect_disks", leave=True) + except Exception: + bar = None + else: + bar = None + + results: list[NDArray] = [] + for start in it: + chunk = coords[start : start + batch_size] + dps = torch.stack( + [ + torch.as_tensor( + np.asarray(self.dataset.array[r, c]), + dtype=torch.float, + device=self.device, + ) + for r, c in chunk + ], + dim=0, + ) + out = detect_disks_batch(dps, self._template_ft, **detect_kwargs) + results.extend( + arr if arr.shape[0] else np.empty((0, len(PEAK_FIELDS)), dtype=float) + for arr in out + ) + if bar is not None: + bar.update(len(chunk)) + + if bar is not None: + bar.close() + return results + + def _bvm_candidates( + self, num_candidates: int, min_spacing: float, min_abs_intensity: float + ) -> tuple[NDArray, NDArray]: + """Find the brightest, well-separated local maxima in the Bragg vector map. + + Parameters + ---------- + num_candidates : int + Maximum number of candidate peaks to return. + min_spacing : float + Minimum spacing in pixels between candidates. 
def _fractional_indices(q: NDArray, origin: NDArray, g1: NDArray, g2: NDArray) -> NDArray:
    """Express peaks ``q`` in continuous (unrounded) ``(a, b)`` lattice coordinates.

    For each peak this solves ``B [a, b]^T = q - origin`` in the least-squares
    sense, where the columns of ``B`` are the lattice vectors:
    ``B = [[g1_row, g2_row], [g1_col, g2_col]]``.

    Parameters
    ----------
    q : np.ndarray
        ``(N, 2)`` ``[row, col]`` peak positions.
    origin : np.ndarray
        ``(2,)`` lattice origin ``[row, col]``.
    g1 : np.ndarray
        ``(2,)`` first lattice vector ``[row, col]``.
    g2 : np.ndarray
        ``(2,)`` second lattice vector ``[row, col]``.

    Returns
    -------
    np.ndarray
        ``(N, 2)`` float array of continuous ``[a, b]`` coordinates; an empty
        ``(0, 2)`` array when ``q`` has no rows.
    """
    n_peaks = q.shape[0]
    if n_peaks == 0:
        return np.empty((0, 2), dtype=float)

    # Basis matrix with g1 and g2 as columns (row components on top).
    basis = np.array([[g1[0], g2[0]], [g1[1], g2[1]]], dtype=float)
    # Peak offsets from the origin, one peak per column: shape (2, N).
    offsets = np.transpose(q - origin[None, :])
    coords, *_ = np.linalg.lstsq(basis, offsets, rcond=None)
    return coords.T  # back to one peak per row: (N, 2)
+ g2 : np.ndarray + ``(2,)`` second lattice vector ``[row, col]``. + + Returns + ------- + np.ndarray + ``(N, 2)`` int array of ``[a, b]`` Miller indices. + """ + if q.shape[0] == 0: + return np.empty((0, 2), dtype=int) + return np.round(_fractional_indices(q, origin, g1, g2)).astype(int) + + +def _fit_lattice_vectors( + q_row: NDArray, + q_col: NDArray, + a: NDArray, + b: NDArray, + intensity: NDArray, +) -> tuple[NDArray | None, NDArray | None]: + """Intensity-weighted lattice fit ``q = x0 + a*g1 + b*g2`` for one pattern. + + Parameters + ---------- + q_row : np.ndarray + ``(N,)`` peak row positions. + q_col : np.ndarray + ``(N,)`` peak column positions. + a : np.ndarray + ``(N,)`` Miller index along ``g1``. + b : np.ndarray + ``(N,)`` Miller index along ``g2``. + intensity : np.ndarray + ``(N,)`` peak intensities, used as fit weights (``sqrt`` of the clamped + intensity). + + Returns + ------- + beta : np.ndarray or None + ``(3, 2)`` fit ``[x0; g1; g2]`` (row/col components per row), or ``None`` if + the fit is rank-deficient (e.g. all peaks share one lattice row). + rms : float + RMS fit residual in pixels (``nan`` when ``beta`` is ``None``). 
+ """ + design = np.stack([np.ones_like(a), a, b], axis=1) # (N, 3) + target = np.stack([q_row, q_col], axis=1) # (N, 2) + w = np.sqrt(np.clip(intensity, 0.0, None))[:, None] + + if np.linalg.matrix_rank(design * w) < 3: + return None, float("nan") + + beta = np.linalg.lstsq(design * w, target * w, rcond=None)[0] # (3, 2): x0, g1, g2 + resid = target - design @ beta + rms = float(np.sqrt(np.mean(np.sum(resid**2, axis=1)))) + return beta, rms diff --git a/src/quantem/diffraction/bragg_vectors_visualization.py b/src/quantem/diffraction/bragg_vectors_visualization.py new file mode 100644 index 00000000..047927bc --- /dev/null +++ b/src/quantem/diffraction/bragg_vectors_visualization.py @@ -0,0 +1,776 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np + + +def plot_template( + dp_mean: np.ndarray, + template: np.ndarray, + corr_map: np.ndarray, + position: tuple[int, int], + *, + crop: tuple[float, float, float] | None = None, + figsize: tuple[float, float] = (13, 4), +): + """Mean diffraction pattern, the (centered) template, and one correlation map. + + Parameters + ---------- + dp_mean : np.ndarray + Mean diffraction pattern. + template : np.ndarray + The correlation template, centered for display. + corr_map : np.ndarray + Correlation map computed at ``position``. + position : tuple of int + ``(row, col)`` scan position the correlation map was computed at. + crop : tuple of float, optional + ``(center_row, center_col, half_width)`` zoom window (in pixels). The mean + diffraction and correlation panels are centered on ``(center_row, + center_col)`` -- the central-beam position -- while the template panel is + centered on its own array center (it is displayed fftshifted to there). The + view spans ``half_width`` either side of the center, clamped to each image's + bounds, so an over-large ``half_width`` just shows the full image. ``None`` + (default) shows the full panels. 
+ figsize : tuple of float, default=(13, 4) + Figure size in inches. + + Returns + ------- + tuple + ``(fig, ax)`` with ``ax`` a length-3 array of axes. + """ + fig, ax = plt.subplots(1, 3, figsize=figsize) + ax[0].imshow(dp_mean, cmap="gray") + ax[0].set_title("mean diffraction") + ax[1].imshow(template, cmap="gray") + ax[1].set_title("template (centered)") + ax[2].imshow(corr_map, cmap="viridis") + ax[2].set_title(f"correlation @ {tuple(position)}") + for a in ax: + a.set_xticks([]) + a.set_yticks([]) + if crop is not None: + cr, cc, hw = float(crop[0]), float(crop[1]), float(crop[2]) + th, tw = template.shape[:2] + # The mean-diffraction beam and the correlation peak sit at the beam center + # (cr, cc); the displayed template is fftshifted to its own array center, so + # zoom that panel about its center instead. + panel_centers = ((cr, cc), (th / 2.0, tw / 2.0), (cr, cc)) + for a, img, (ecr, ecc) in zip(ax, (dp_mean, template, corr_map), panel_centers): + h, w = img.shape[:2] + a.set_xlim(max(ecc - hw, -0.5), min(ecc + hw, w - 0.5)) + a.set_ylim(min(ecr + hw, h - 0.5), max(ecr - hw, -0.5)) + fig.tight_layout() + return fig, ax + + +def _mark_positions(ax, positions, *, radius=None, linewidth=0.5): + """Overlay numbered red markers at each ``(row, col)`` scan position. + + Parameters + ---------- + ax : matplotlib.axes.Axes + Axis to draw the markers on. + positions : sequence of tuple of int + ``(row, col)`` scan positions to mark, numbered in order. + radius : float, optional + Marker radius in image pixels. If given, each marker is a ring of that + radius drawn as a circle patch in data coordinates; if ``None``, a fixed + screen-size scatter marker is used instead. + linewidth : float, default=0.5 + Ring stroke width. 
+ """ + from matplotlib.patches import Circle + + for i, (r, c) in enumerate(positions): + if radius is None: + ax.scatter(c, r, s=70, facecolors="none", edgecolors="red", linewidths=linewidth) + else: + ax.add_patch( + Circle((c, r), radius=radius, fill=False, edgecolor="red", linewidth=linewidth) + ) + ax.annotate( + str(i), + (c, r), + color="red", + fontsize=11, + fontweight="bold", + xytext=(4, 4), + textcoords="offset points", + ) + + +def _blur(dp: np.ndarray, sigma: float | None) -> np.ndarray: + """Gaussian-blur a diffraction pattern for display only (passthrough if ``sigma`` falsy). + + Smoothing the *displayed* pattern (the detection still runs on the raw data) + makes it easier to judge by eye which correlation peaks sit on real disks and + which are noise. + + Parameters + ---------- + dp : np.ndarray + Diffraction pattern to blur. + sigma : float or None + Gaussian blur width in pixels. If falsy (``None``, ``0``, or negative), + ``dp`` is returned unchanged. + + Returns + ------- + np.ndarray + The blurred pattern, or ``dp`` unchanged when ``sigma`` is falsy. + """ + if not sigma or sigma <= 0: + return dp + from scipy.ndimage import gaussian_filter + + return gaussian_filter(np.asarray(dp, dtype=float), float(sigma)) + + +def _grid_axes(n, ncols, *, show_image, axsize=None, figsize=None): + """Figure with an optional left nav-image axis and a right ``nrows x ncols`` tile grid. + + With ``show_image`` a navigation-image axis is placed to the left of the tile + grid; otherwise the figure is just the grid. + + Parameters + ---------- + n : int + Number of tiles to lay out. + ncols : int + Number of tile columns, clamped to ``n``. + show_image : bool + If ``True``, add a navigation-image axis to the left of the tile grid. + axsize : tuple of float, optional + Per-tile size in inches, ``(w, h)``. Sizes the figure when ``figsize`` is + not given, so a larger ``axsize`` zooms every tile. 
def plot_diffraction_grid(
    image: np.ndarray | None,
    dps: list[np.ndarray],
    positions: list[tuple[int, int]],
    *,
    ncols: int = 4,
    image_title: str = "navigation image",
    image_kwargs: dict | None = None,
    marker_radius: float | None = None,
    linewidth: float = 0.5,
    sigma_plot: float | None = None,
    axsize: tuple[float, float] | None = None,
    figsize: tuple[float, float] | None = None,
    **show_kwargs,
):
    """Tile the diffraction patterns at ``positions``, optionally next to a nav image.

    With an ``image`` (e.g. a virtual dark-field image), it is drawn on the left
    and every scan position is marked with a numbered red marker at
    ``(x=col, y=row)``. The diffraction patterns are tiled ``ncols`` wide on the
    right. Everything is rendered with
    :func:`~quantem.core.visualization.show_2d`; the navigation image and the
    tiles are styled independently (``image_kwargs`` vs ``show_kwargs``).

    Parameters
    ----------
    image : np.ndarray or None
        Navigation image drawn at left with the scan positions marked. Pass
        ``None`` to show only the diffraction grid.
    dps : list of np.ndarray
        Diffraction patterns to tile, one per entry in ``positions``.
    positions : list of tuple of int
        ``(row, col)`` scan positions, used for the tile titles and the
        nav-image markers.
    ncols : int, default=4
        Number of columns in the diffraction-pattern tile grid.
    image_title : str, default="navigation image"
        Title for the navigation image (``image_kwargs`` may override it).
    image_kwargs : dict, optional
        Extra keyword arguments (e.g. ``norm``, ``cmap``, ``scalebar``) for the
        navigation image's :func:`show_2d` call. The image defaults to
        ``cmap="gray"`` unless overridden here.
    marker_radius : float, optional
        Scan-position marker radius in image pixels; ``None`` uses a fixed
        screen-size marker.
    linewidth : float, default=0.5
        Stroke width of the scan-position markers.
    sigma_plot : float, optional
        Gaussian blur width (pixels) applied to the *displayed* patterns only.
    axsize : tuple of float, optional
        Per-tile size in inches; a larger value zooms the tiles.
    figsize : tuple of float, optional
        Explicit figure size in inches.
    **show_kwargs
        Forwarded to :func:`show_2d` for the diffraction-pattern tiles (e.g.
        ``norm``, ``cmap``, ``cbar``). A ``title`` entry, if present, replaces
        the default per-tile title ``"<index>: (<row>,<col>)"`` on every tile.

    Returns
    -------
    tuple
        ``(fig, (ax_image, dp_axes))`` with ``ax_image`` ``None`` when no image is
        shown.
    """
    from quantem.core.visualization import show_2d

    n_tiles = len(positions)
    has_nav = image is not None
    # _grid_axes hands back the column count it actually used for the grid.
    fig, nav_ax, tile_axes, ncols = _grid_axes(
        n_tiles, ncols, show_image=has_nav, axsize=axsize, figsize=figsize
    )

    if has_nav:
        nav_style = {"cmap": "gray", "title": image_title}
        nav_style.update(image_kwargs or {})
        show_2d(np.asarray(image), figax=(fig, nav_ax), **nav_style)
        _mark_positions(nav_ax, positions, radius=marker_radius, linewidth=linewidth)

    # A caller-supplied title applies to every tile; pull it out so it is not
    # passed to show_2d twice.
    fixed_title = show_kwargs.pop("title", None)
    for idx in range(n_tiles):
        scan_row, scan_col = positions[idx]
        if fixed_title is None:
            tile_title = f"{idx}: ({scan_row},{scan_col})"
        else:
            tile_title = fixed_title
        grid_row, grid_col = divmod(idx, ncols)
        show_2d(
            _blur(dps[idx], sigma_plot),
            figax=(fig, tile_axes[grid_row, grid_col]),
            title=tile_title,
            **show_kwargs,
        )
    return fig, (nav_ax, tile_axes)
Pass ``None`` to show only the diffraction tiles. + dps : list of np.ndarray + Diffraction patterns to tile, one per entry in ``positions``. + peaks : list of np.ndarray + Detected peaks per pattern; ``peaks[i]`` is an ``(M, 3)`` array of + ``[q_row, q_col, intensity]``. Each peak is drawn as a cyan ring at + ``(x=q_col, y=q_row)``. + positions : list of tuple of int + ``(row, col)`` scan positions, used for the tile titles and the nav-image + markers. + ncols : int, default=4 + Number of columns in the diffraction-pattern tile grid. + peak_radius : float, default=6.0 + Radius of the cyan peak rings, in diffraction pixels. + marker_radius : float, optional + Scan-position marker radius in image pixels; ``None`` uses a fixed + screen-size marker. + linewidth : float, default=0.5 + Stroke width of both the scan-position markers and the peak rings. + sigma_plot : float, optional + Gaussian blur width (pixels) applied to the *displayed* patterns only + (detection still uses the raw data) to ease telling real disks from false + positives. + image_title : str, default="virtual image" + Title for the navigation image. + image_kwargs : dict, optional + Extra keyword arguments for the navigation image's :func:`show_2d` call, + styling it independently of the tiles. + axsize : tuple of float, optional + Per-tile size in inches; a larger value zooms the tiles. + figsize : tuple of float, optional + Explicit figure size in inches. + **show_kwargs + Forwarded to :func:`show_2d` for the diffraction-pattern tiles (e.g. + ``norm``, ``cmap``, ``cbar``). + + Returns + ------- + tuple + ``(fig, (ax_image, dp_axes))`` with ``ax_image`` ``None`` when no image is + shown. 
+ """ + from matplotlib.patches import Circle + + from quantem.core.visualization import show_2d + + n = len(positions) + show_image = image is not None + fig, ax_image, dp_axes, ncols = _grid_axes( + n, ncols, show_image=show_image, axsize=axsize, figsize=figsize + ) + + if show_image: + image_show_kwargs = {"cmap": "gray", "title": image_title, **(image_kwargs or {})} + show_2d(np.asarray(image), figax=(fig, ax_image), **image_show_kwargs) + _mark_positions(ax_image, positions, radius=marker_radius, linewidth=linewidth) + + user_title = show_kwargs.pop("title", None) + for i in range(n): + a = dp_axes[i // ncols, i % ncols] + r, c = positions[i] + pk = peaks[i] + title = user_title if user_title is not None else f"{i}: ({r},{c}) n={pk.shape[0]}" + show_2d(_blur(dps[i], sigma_plot), figax=(fig, a), title=title, **show_kwargs) + for q in pk: + a.add_patch( + Circle( + (q[1], q[0]), + radius=peak_radius, + fill=False, + edgecolor="cyan", + linewidth=linewidth, + ) + ) + + return fig, (ax_image, dp_axes) + + +def plot_basis_vectors( + bvm: np.ndarray, + cand_rc: np.ndarray, + cand_int: np.ndarray, + origin: np.ndarray, + g1: np.ndarray, + g2: np.ndarray, + *, + cmap: str = "gray", + norm: str | dict = "log_auto", + zoom: bool = True, + figsize: tuple[float, float] = (6, 6), + **show_kwargs, +): + """The Bragg vector map with the candidate peaks, origin, and basis vectors overlaid. + + Every candidate peak is drawn as a numbered cyan ring; those numbers are the + indices accepted by + :meth:`~quantem.diffraction.bragg_vectors.BraggVectors.choose_basis_vectors` + for overriding ``origin``/``g1``/``g2`` by peak. The chosen origin is a green + marker and ``g1`` (red) / ``g2`` (blue) are arrows drawn from it, labelled at + their midpoints so the labels never sit on top of the candidate numbers. 
+ + Parameters + ---------- + bvm : np.ndarray + Bragg vector map; rendered through + :func:`~quantem.core.visualization.show_2d` so its display scaling is + controlled by ``norm`` and ``show_kwargs``. + cand_rc : np.ndarray + ``(N, 2)`` ``[row, col]`` candidate peak positions, brightest first; the + ring labels are their row indices. + cand_int : np.ndarray + ``(N,)`` candidate intensities (unused for drawing; kept for parity with + the candidate API). + origin : np.ndarray + ``(row, col)`` chosen lattice origin. + g1 : np.ndarray + First lattice vector as a ``(row, col)`` offset from ``origin``. + g2 : np.ndarray + Second lattice vector as a ``(row, col)`` offset from ``origin``. + cmap : str, default="gray" + Colormap for the Bragg vector map; gray keeps the colored overlays legible. + norm : str or dict, default="log_auto" + Intensity scaling forwarded to :func:`show_2d` (e.g. ``"linear_auto"``, + ``"log_auto"``, ``"power_sqrt"``, or ``{"power": 0.5}``). + zoom : bool, default=True + If ``True``, frame the view to the candidate bounding box (plus a margin) + so the numbered peaks are large enough to read. + figsize : tuple of float, default=(6, 6) + Figure size in inches. + **show_kwargs + Extra keyword arguments forwarded to :func:`show_2d` for fine display + control (e.g. ``vmin``, ``vmax``, ``lower_quantile``, ``upper_quantile``). + + Returns + ------- + tuple + ``(fig, ax)``. 
+ """ + import matplotlib.patheffects as path_effects + + from quantem.core.visualization import show_2d + + stroke = [path_effects.withStroke(linewidth=2.5, foreground="black")] + origin_color = (0.0, 0.7, 0.0) + g1_color = (1.0, 0.0, 0.0) + g2_color = (0.0, 0.7, 1.0) + + fig, ax = plt.subplots(figsize=figsize) + show_2d(np.asarray(bvm), figax=(fig, ax), cmap=cmap, norm=norm, **show_kwargs) + + cand_rc = np.asarray(cand_rc, dtype=float).reshape(-1, 2) + for i, (r, c) in enumerate(cand_rc): + ax.scatter(c, r, s=60, facecolors="none", edgecolors="cyan", linewidths=1.0, zorder=3) + ax.annotate( + str(i), + (c, r), + color="cyan", + fontsize=9, + fontweight="bold", + xytext=(4, 4), + textcoords="offset points", + path_effects=stroke, + zorder=4, + ) + + o = np.asarray(origin, dtype=float).reshape(2) + ax.scatter( + o[1], + o[0], + s=160, + marker="P", + facecolors=[origin_color], + edgecolors="white", + linewidths=1.5, + zorder=6, + ) + for g, label, color in ( + (np.asarray(g1, float), "g1", g1_color), + (np.asarray(g2, float), "g2", g2_color), + ): + tip = (o[1] + g[1], o[0] + g[0]) + ax.annotate( + "", + xy=tip, + xytext=(o[1], o[0]), + arrowprops=dict(arrowstyle="-|>", color=color, lw=2.4, shrinkA=0, shrinkB=0), + zorder=5, + ) + # Label at the arrow midpoint, nudged perpendicular to the shaft (in screen + # space) so it clears both the arrow and the candidate numbers on the peaks. 
def plot_bvm(
    bvm: np.ndarray,
    counts: np.ndarray,
    *,
    figsize: tuple[float, float] = (10, 4),
):
    """Draw the log-scaled Bragg vector map next to the per-position peak count.

    Parameters
    ----------
    bvm : np.ndarray
        Bragg vector map; displayed as ``log1p(bvm)`` with the ``inferno``
        colormap.
    counts : np.ndarray
        Per-position peak count, shape ``(scan_row, scan_col)``; displayed with
        the ``viridis`` colormap and a colorbar.
    figsize : tuple of float, default=(10, 4)
        Figure size in inches.

    Returns
    -------
    tuple
        ``(fig, ax)`` with ``ax`` a length-2 array of axes.
    """
    fig, ax = plt.subplots(1, 2, figsize=figsize)
    bvm_ax, counts_ax = ax[0], ax[1]

    bvm_ax.imshow(np.log1p(bvm), cmap="inferno")
    bvm_ax.set_title("Bragg vector map (log)")

    counts_image = counts_ax.imshow(counts, cmap="viridis")
    counts_ax.set_title("peaks per position")

    # Pixel ticks carry no meaning for either panel.
    for panel in (bvm_ax, counts_ax):
        panel.set_xticks([])
        panel.set_yticks([])

    # Only the count panel gets a colorbar.
    fig.colorbar(counts_image, ax=counts_ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    return fig, ax
+ cmap : str, default="gray" + Colormap for the Bragg vector map; gray keeps the colored overlays legible. + norm : str or dict, default="log_auto" + Intensity scaling forwarded to :func:`show_2d` (e.g. ``"linear_auto"``, + ``"log_auto"``, ``"power_sqrt"``, or ``{"power": 0.5}``). + zoom : bool, default=True + If ``True``, frame the view to the reference bounding box (plus a margin) + so the labelled sites are large enough to read. + figsize : tuple of float, default=(6, 6) + Figure size in inches. + **show_kwargs + Extra keyword arguments forwarded to :func:`show_2d` for fine display + control (e.g. ``vmin``, ``vmax``, ``lower_quantile``, ``upper_quantile``). + + Returns + ------- + tuple + ``(fig, ax)``. + """ + import matplotlib.patheffects as path_effects + from matplotlib.cm import ScalarMappable + from matplotlib.colors import Normalize + + from quantem.core.visualization import show_2d + + stroke = [path_effects.withStroke(linewidth=2.5, foreground="black")] + origin_color = (0.0, 0.7, 0.0) + g1_color = (1.0, 0.0, 0.0) + g2_color = (0.0, 0.7, 1.0) + + ref_qpos = np.asarray(ref_qpos, dtype=float).reshape(-1, 2) + ref_ab = np.asarray(ref_ab, dtype=int).reshape(-1, 2) + o = np.asarray(origin, dtype=float).reshape(2) + g1 = np.asarray(g1, dtype=float).reshape(2) + g2 = np.asarray(g2, dtype=float).reshape(2) + + fig, ax = plt.subplots(figsize=figsize) + show_2d(np.asarray(bvm), figax=(fig, ax), cmap=cmap, norm=norm, **show_kwargs) + + # offset of each picked candidate from its ideal lattice site origin + a*g1 + b*g2; + # this is the QC indicator -- a mis-picked or strongly strained candidate rings + # far from zero, while a clean pick rings near it. 
+ ideal_qpos = o[None, :] + ref_ab[:, 0:1] * g1[None, :] + ref_ab[:, 1:2] * g2[None, :] + offset = np.linalg.norm(ref_qpos - ideal_qpos, axis=1) + + # color the rings by that offset, scaled to half the shorter lattice spacing (the + # default fit_lattice match radius) so the colorbar previews which candidates sit + # near the inclusion tolerance. + radius = 0.5 * float(min(np.hypot(*g1), np.hypot(*g2))) + cmap_offset = plt.get_cmap("plasma") + norm_offset = Normalize(vmin=0.0, vmax=radius if radius > 0 else 1.0) + + ax.scatter( + ref_qpos[:, 1], + ref_qpos[:, 0], + s=80, + facecolors="none", + edgecolors=cmap_offset(norm_offset(offset)), + linewidths=1.8, + zorder=3, + ) + for (r, c), (a, b) in zip(ref_qpos, ref_ab): + ax.annotate( + f"{int(a)},{int(b)}", + (c, r), + color="white", + fontsize=9, + fontweight="bold", + xytext=(4, 4), + textcoords="offset points", + path_effects=stroke, + zorder=4, + ) + + ax.scatter( + o[1], + o[0], + s=160, + marker="P", + facecolors=[origin_color], + edgecolors="white", + linewidths=1.5, + zorder=6, + ) + for g, label, color in ((g1, "g1", g1_color), (g2, "g2", g2_color)): + ax.annotate( + "", + xy=(o[1] + g[1], o[0] + g[0]), + xytext=(o[1], o[0]), + arrowprops=dict(arrowstyle="-|>", color=color, lw=2.4, shrinkA=0, shrinkB=0), + zorder=5, + ) + gnorm = float(np.hypot(g[0], g[1])) + 1e-12 + ax.annotate( + label, + (o[1] + g[1] / 2.0, o[0] + g[0] / 2.0), + color=color, + fontsize=13, + fontweight="bold", + xytext=(g[0] / gnorm * 15.0, g[1] / gnorm * 15.0), + textcoords="offset points", + ha="center", + va="center", + path_effects=stroke, + zorder=7, + ) + + if zoom and ref_qpos.shape[0]: + rmin, cmin = ref_qpos.min(axis=0) + rmax, cmax = ref_qpos.max(axis=0) + margin = 0.12 * max(rmax - rmin, cmax - cmin, 1.0) + 6.0 + h, w = bvm.shape[:2] + ax.set_xlim(max(cmin - margin, -0.5), min(cmax + margin, w - 0.5)) + ax.set_ylim(min(rmax + margin, h - 0.5), max(rmin - margin, -0.5)) + + ax.set_xticks([]) + ax.set_yticks([]) + 
ax.set_title("reference lattice (ring color = offset from ideal index; +origin, g1, g2)") + + sm = ScalarMappable(norm=norm_offset, cmap=cmap_offset) + sm.set_array([]) + cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04, extend="max") + cbar.set_label("peak offset from ideal index (px)") + + fig.tight_layout() + return fig, ax + + +def plot_lattice_fit( + mask_weight: np.ndarray, + fit_error: np.ndarray, + *, + figsize: tuple[float, float] = (11, 4.5), +): + """Per-position diagnostics from :meth:`BraggVectors.fit_lattice`. + + Left: the mask weight — a lattice *order parameter* per position (how well all + detected intensity snaps to the fitted lattice, intensity-weighted) — so ``0`` is + a position dominated by off-lattice intensity and ``1`` a clean single crystal. + This is the weighting handed to the strain reference. Right: the RMS lattice-fit + residual over the matched peaks, in pixels (low = a clean fit; high = a poorly + fit, overlapping, or strongly strained position). + + Parameters + ---------- + mask_weight : np.ndarray + ``(scan_row, scan_col)`` lattice-order-parameter weight in ``[0, 1]``. + fit_error : np.ndarray + ``(scan_row, scan_col)`` RMS fit residual in pixels (``nan`` where no fit + was made). + figsize : tuple of float, default=(11, 4.5) + Figure size in inches. + + Returns + ------- + tuple + ``(fig, ax)`` with ``ax`` a length-2 array of axes. 
+ """ + fig, ax = plt.subplots(1, 2, figsize=figsize) + + im0 = ax[0].imshow(np.asarray(mask_weight), cmap="viridis", vmin=0.0, vmax=1.0) + ax[0].set_title("mask weight (lattice order)") + fig.colorbar(im0, ax=ax[0], fraction=0.046, pad=0.04) + + im1 = ax[1].imshow(np.asarray(fit_error), cmap="magma") + ax[1].set_title("fit RMS error (px)") + fig.colorbar(im1, ax=ax[1], fraction=0.046, pad=0.04) + + for a in ax: + a.set_xticks([]) + a.set_yticks([]) + fig.tight_layout() + return fig, ax diff --git a/src/quantem/diffraction/disk_detection.py b/src/quantem/diffraction/disk_detection.py new file mode 100644 index 00000000..05c5600c --- /dev/null +++ b/src/quantem/diffraction/disk_detection.py @@ -0,0 +1,1093 @@ +from __future__ import annotations + +import numpy as np +import torch + +SUBPIXEL_MODES = ("none", "parabolic", "upsample") + + +def make_template( + probe: torch.Tensor, + center: tuple[float, float] | None = None, + subtract_mean: bool = False, +) -> torch.Tensor: + """Build a cross-correlation template from a (vacuum) probe image. + + The probe is normalized to unit sum and rolled so its center sits at the array + origin ``[0, 0]`` (FFT corner), so correlation peaks land at absolute disk + positions. + + Parameters + ---------- + probe : torch.Tensor + ``(H, W)`` probe / vacuum disk image. + center : tuple of float, optional + ``(row, col)`` probe center rolled to the origin; defaults to the geometric + center ``(H // 2, W // 2)``. + subtract_mean : bool, default=False + If ``True``, make the template zero-sum — a band-pass kernel that suppresses + the uniform background in the correlation. + + Returns + ------- + torch.Tensor + ``(H, W)`` template, corner-centered (and zero-sum when ``subtract_mean``). 
+ """ + probe = torch.as_tensor(probe) + total = probe.sum() + if total != 0: + probe = probe / total + + H, W = probe.shape + if center is None: + cr, cc = H // 2, W // 2 + else: + cr, cc = int(round(float(center[0]))), int(round(float(center[1]))) + + template = torch.roll(probe, shifts=(-cr, -cc), dims=(0, 1)) + if subtract_mean: + template = template - template.mean() + return template + + +def synthetic_probe( + shape: tuple[int, int], + radius: float, + edge: float = 1.0, + center: tuple[float, float] | None = None, +) -> torch.Tensor: + """Soft-edged disk for a synthetic correlation template. + + Returns ``0.5 - 0.5*tanh((r - radius)/edge)``: a disk of the given ``radius`` + (pixels) with a ``tanh`` falloff over ``edge`` pixels. + + Parameters + ---------- + shape : tuple of int + ``(H, W)`` output shape in pixels. + radius : float + Disk radius in pixels. + edge : float, default=1.0 + Width in pixels of the ``tanh`` edge falloff. + center : tuple of float, optional + ``(row, col)`` disk center; defaults to the geometric center + ``((H - 1) / 2, (W - 1) / 2)``. + + Returns + ------- + torch.Tensor + ``(H, W)`` soft-edged disk image. + """ + H, W = int(shape[0]), int(shape[1]) + if center is None: + cr, cc = (H - 1) / 2.0, (W - 1) / 2.0 + else: + cr, cc = float(center[0]), float(center[1]) + rows = torch.arange(H, dtype=torch.float).view(H, 1) + cols = torch.arange(W, dtype=torch.float).view(1, W) + rr = torch.sqrt((rows - cr) ** 2 + (cols - cc) ** 2) + edge = max(float(edge), 1e-6) + return 0.5 - 0.5 * torch.tanh((rr - float(radius)) / edge) + + +def _central_blob(image: torch.Tensor, threshold: float) -> tuple[np.ndarray | None, np.ndarray]: + """Mask of the connected bright region containing the brightest pixel. 
+ + Thresholds ``image`` at ``threshold * max`` and keeps only the connected + component holding the brightest pixel — normally the central (unscattered) + disk of a mean diffraction pattern or vacuum probe — so other diffracted disks + are excluded. + + Parameters + ---------- + image : torch.Tensor + ``(H, W)`` image, e.g. a mean diffraction pattern or vacuum probe. + threshold : float + Fraction of the peak intensity (after min-subtraction) used to threshold + the image before connected-component labeling. + + Returns + ------- + blob_mask : np.ndarray or None + Boolean ``(H, W)`` mask of the central component, or ``None`` for an empty + / flat image. + img : np.ndarray + The min-subtracted image. + """ + from scipy import ndimage + + img = np.asarray(torch.as_tensor(image, dtype=torch.float).detach().cpu()) + img = img - img.min() + peak = float(img.max()) + if peak <= 0: + return None, img + labels, n = ndimage.label(img >= threshold * peak) + if n == 0: + return None, img + peak_label = int(labels[np.unravel_index(int(np.argmax(img)), img.shape)]) + return labels == peak_label, img + + +def estimate_central_beam( + image: torch.Tensor, + threshold: float = 0.5, + plot_result: bool = False, + **kwargs, +) -> tuple[tuple[float, float], float]: + """Center ``(row, col)`` and radius (pixels) of the central (direct) beam. + + Locates the connected bright region containing the brightest pixel (see + :func:`_central_blob`) — the unscattered / direct beam of a mean diffraction + pattern — and returns its intensity-weighted center together with an + area-equivalent radius (``A = pi r^2``). Other diffracted disks are excluded, so + the estimate holds whether the pattern shows one disk or many. Falls back to the + geometric center and unit radius for an empty / flat image. + + Parameters + ---------- + image : torch.Tensor or Dataset2d + ``(H, W)`` image — a raw array / tensor or a :class:`Dataset2d` (e.g. + ``dataset.dp_mean``). 
+ threshold : float, default=0.5 + Fraction of the peak intensity used to threshold the image when isolating + the central beam (see :func:`_central_blob`). + plot_result : bool, default=False + If ``True``, show the image in greyscale with the fitted beam drawn as a red + circle. + **kwargs + Extra keyword arguments (e.g. ``norm``, ``cbar``, ``scalebar``) forwarded to + :func:`~quantem.core.visualization.show_2d` when ``plot_result=True``. + + Returns + ------- + center : tuple of float + ``(row, col)`` intensity-weighted center of the central beam. + radius : float + Area-equivalent radius in pixels (``A = pi r^2``). + """ + arr = image.array if hasattr(image, "array") else image + blob, img = _central_blob(arr, threshold) + if blob is None: + center = (img.shape[0] / 2.0, img.shape[1] / 2.0) + radius = 1.0 + else: + radius = float(np.sqrt(max(float(blob.sum()), 1.0) / np.pi)) + w = img * blob + total = float(w.sum()) + if total <= 0: + rr, cc = np.nonzero(blob) + center = (float(rr.mean()), float(cc.mean())) + else: + rows = np.arange(img.shape[0])[:, None] + cols = np.arange(img.shape[1])[None, :] + center = (float((w * rows).sum() / total), float((w * cols).sum() / total)) + + if plot_result: + from matplotlib.patches import Circle + + from quantem.core.visualization import show_2d + + show_kwargs = {"cmap": "gray", "title": "central beam", **kwargs} + _fig, ax = show_2d(arr, **show_kwargs) + ax.add_patch( + Circle((center[1], center[0]), radius, fill=False, edgecolor="red", linewidth=1.5) + ) + + return center, radius + + +def probe_centroid(probe: torch.Tensor) -> tuple[float, float]: + """Intensity-weighted ``(row, col)`` centroid of a probe image. + + Parameters + ---------- + probe : torch.Tensor + ``(H, W)`` probe image. Negative values are clamped to zero before + weighting. + + Returns + ------- + tuple of float + ``(row, col)`` intensity-weighted centroid; the geometric center for a + non-positive image. 
+ """ + p = torch.clamp(torch.as_tensor(probe, dtype=torch.float), min=0.0) + total = p.sum() + if total <= 0: + return (p.shape[0] / 2.0, p.shape[1] / 2.0) + rows = torch.arange(p.shape[0], dtype=torch.float).view(-1, 1) + cols = torch.arange(p.shape[1], dtype=torch.float).view(1, -1) + return (float((p * rows).sum() / total), float((p * cols).sum() / total)) + + +def template_fourier(template: torch.Tensor) -> torch.Tensor: + """Pre-compute the conjugate FT of a template for repeated correlation. + + Parameters + ---------- + template : torch.Tensor + ``(H, W)`` corner-centered correlation template. + + Returns + ------- + torch.Tensor + ``(H, W)`` complex ``conj(fft2(template))``, ready to multiply against + ``fft2(dp)``. + """ + return torch.conj(torch.fft.fft2(template)) + + +def cross_correlation( + dp: torch.Tensor, + template_ft: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Cross-correlate a diffraction pattern with a template. + + Parameters + ---------- + dp : torch.Tensor + ``(H, W)`` diffraction pattern. + template_ft : torch.Tensor + ``(H, W)`` pre-computed template FT from :func:`template_fourier`. + + Returns + ------- + corr_map : torch.Tensor + ``(H, W)`` real-space correlation map ``relu(real(ifft2(m)))`` (used for + peak finding). + m : torch.Tensor + ``(H, W)`` Fourier-domain product ``fft2(dp) * template_ft`` (used for DFT + subpixel refinement). + """ + dp = torch.as_tensor(dp) + m = torch.fft.fft2(dp) * template_ft + corr_map = torch.clamp(torch.fft.ifft2(m).real, min=0.0) + return corr_map, m + + +def cross_correlation_batch( + dps: torch.Tensor, + template_ft: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Cross-correlate a stack of diffraction patterns with one template. + + Batched form of :func:`cross_correlation`. ``fft2`` acts on the trailing two + axes and the ``(H, W)`` ``template_ft`` broadcasts over the batch, so the result + for each pattern is bit-identical to :func:`cross_correlation`. 
+ + Parameters + ---------- + dps : torch.Tensor + ``(B, H, W)`` stack of diffraction patterns. + template_ft : torch.Tensor + ``(H, W)`` pre-computed template FT from :func:`template_fourier`. + + Returns + ------- + corr_map : torch.Tensor + ``(B, H, W)`` real-space correlation maps. + m : torch.Tensor + ``(B, H, W)`` Fourier-domain products. + """ + dps = torch.as_tensor(dps) + m = torch.fft.fft2(dps) * template_ft + corr_map = torch.clamp(torch.fft.ifft2(m).real, min=0.0) + return corr_map, m + + +def _corr_map_rfft(dps: torch.Tensor, template_ft: torch.Tensor) -> torch.Tensor: + """Real-FFT correlation map(s), used when no Fourier product is needed downstream. + + For real ``dps`` and a real template the Fourier product is conjugate-symmetric, + so ``relu(real(ifft2(fft2(dps) * template_ft)))`` is reproduced exactly by an + ``rfft2`` / ``irfft2`` pair at roughly half the FFT cost. Only the correlation map + is returned — the complex product (needed solely for DFT upsampling) is skipped. + + Parameters + ---------- + dps : torch.Tensor + ``(H, W)`` or ``(B, H, W)`` diffraction pattern(s). + template_ft : torch.Tensor + ``(H, W)`` pre-computed template FT from :func:`template_fourier`. + + Returns + ------- + torch.Tensor + ``(H, W)`` or ``(B, H, W)`` real-space correlation map(s), ``relu``-clamped. + """ + dps = torch.as_tensor(dps) + H, W = dps.shape[-2], dps.shape[-1] + prod = torch.fft.rfft2(dps) * template_ft[..., : W // 2 + 1] + corr_map = torch.fft.irfft2(prod, s=(H, W)) + return torch.clamp(corr_map, min=0.0) + + +def detect_disks( + dp: torch.Tensor, + template_ft: torch.Tensor, + *, + min_abs_intensity: float = 0.0, + min_spacing: float = 0.0, + edge_boundary: int = 1, + subpixel: str = "upsample", + upsample_factor: int = 16, + max_num_peaks: int = 1000, +) -> np.ndarray: + """Detect Bragg disks in one diffraction pattern by template matching. + + Parameters + ---------- + dp : torch.Tensor + ``(H, W)`` diffraction pattern. 
+ template_ft : torch.Tensor + ``(H, W)`` pre-computed template FT from :func:`template_fourier`. + min_abs_intensity : float, default=0.0 + Drop correlation peaks below this absolute intensity. + min_spacing : float, default=0.0 + Minimum spacing in pixels between kept peaks; closer / dimmer peaks are + suppressed. + edge_boundary : int, default=1 + Width in pixels of the border in which peaks are ignored. + subpixel : {"none", "parabolic", "upsample"}, default="upsample" + ``"none"`` returns pixel-resolution peaks; ``"parabolic"`` adds a 3-point + quadratic refinement; ``"upsample"`` further refines each peak by + Guizar-Sicairos DFT upsampling. + upsample_factor : int, default=16 + Upsampling factor for the ``"upsample"`` subpixel refinement. + max_num_peaks : int, default=1000 + Maximum number of peaks to keep (after intensity sorting). + + Returns + ------- + np.ndarray + ``(M, 3)`` array of ``[q_row, q_col, intensity]`` rows, sorted by descending + intensity. + """ + if subpixel not in SUBPIXEL_MODES: + raise ValueError(f"subpixel must be in {SUBPIXEL_MODES}, got {subpixel!r}") + + corr_map, m = cross_correlation(dp, template_ft) + + peaks = _local_maxima(corr_map, edge_boundary) + peaks = _filter_maxima(peaks, min_abs_intensity, min_spacing, max_num_peaks) + + if peaks.shape[0] == 0 or subpixel == "none": + return _to_numpy(peaks) + + peaks = _refine_parabolic(corr_map, peaks) + + if subpixel == "parabolic": + return _to_numpy(peaks) + + peaks = _refine_dft(m, peaks, upsample_factor) + return _to_numpy(peaks) + + +def detect_disks_batch( + dps: torch.Tensor, + template_ft: torch.Tensor, + *, + min_abs_intensity: float = 0.0, + min_spacing: float = 0.0, + edge_boundary: int = 1, + subpixel: str = "upsample", + upsample_factor: int = 16, + max_num_peaks: int = 1000, +) -> list[np.ndarray]: + """Detect Bragg disks across a stack of diffraction patterns (batched). 
+ + Batched equivalent of :func:`detect_disks`: cross-correlation, peak extraction, + ``min_spacing`` suppression, and subpixel refinement are all batched across + patterns. The local-maxima search and greedy ``min_spacing`` suppression are + vectorized over the whole stack (no per-pattern Python loop) but reproduce the + per-pattern greedy result bit-for-bit, so each output matches :func:`detect_disks` + for that pattern. When ``subpixel`` is not ``"upsample"`` the correlation maps are + formed with a real FFT (``rfft2``), skipping the complex Fourier product that only + DFT upsampling needs. + + Parameters + ---------- + dps : torch.Tensor + ``(B, H, W)`` stack of diffraction patterns. + template_ft : torch.Tensor + ``(H, W)`` pre-computed template FT from :func:`template_fourier`. + min_abs_intensity : float, default=0.0 + Drop correlation peaks below this absolute intensity. + min_spacing : float, default=0.0 + Minimum spacing in pixels between kept peaks; closer / dimmer peaks are + suppressed. + edge_boundary : int, default=1 + Width in pixels of the border in which peaks are ignored. + subpixel : {"none", "parabolic", "upsample"}, default="upsample" + ``"none"`` returns pixel-resolution peaks; ``"parabolic"`` adds a 3-point + quadratic refinement; ``"upsample"`` further refines each peak by + Guizar-Sicairos DFT upsampling. + upsample_factor : int, default=16 + Upsampling factor for the ``"upsample"`` subpixel refinement. + max_num_peaks : int, default=1000 + Maximum number of peaks to keep per pattern (after intensity sorting). + + Returns + ------- + list of np.ndarray + Length-``B`` list of ``(M, 3)`` arrays of ``[q_row, q_col, intensity]`` + rows, each sorted by descending intensity. 
+ """ + if subpixel not in SUBPIXEL_MODES: + raise ValueError(f"subpixel must be in {SUBPIXEL_MODES}, got {subpixel!r}") + + if subpixel == "upsample": + corr_map, m = cross_correlation_batch(dps, template_ft) + else: + corr_map = _corr_map_rfft(dps, template_ft) + m = None + + peaks_all, bidx, counts = _detect_peaks_batched( + corr_map, edge_boundary, min_abs_intensity, min_spacing, max_num_peaks + ) + + if subpixel == "none" or peaks_all.shape[0] == 0: + return _split_by_counts(peaks_all, counts) + + peaks_all = _refine_parabolic_batched(corr_map, peaks_all, bidx) + if subpixel == "upsample": + peaks_all = _refine_dft_batched(m, peaks_all, bidx, upsample_factor) + + return _split_by_counts(peaks_all, counts) + + +# ---- helpers ---- + + +def _local_maxima(corr_map: torch.Tensor, edge_boundary: int) -> torch.Tensor: + """Find 8-neighbor local maxima, sorted by descending intensity. + + Parameters + ---------- + corr_map : torch.Tensor + ``(H, W)`` correlation map. + edge_boundary : int + Width in pixels of the border in which maxima are ignored. + + Returns + ------- + torch.Tensor + ``(K, 3)`` tensor of ``[row, col, intensity]`` maxima, sorted by descending + intensity. + """ + is_max = _local_maxima_mask(corr_map, edge_boundary) + return _extract_maxima(corr_map, is_max) + + +def _local_maxima_mask(a: torch.Tensor, edge_boundary: int) -> torch.Tensor: + """Boolean 8-neighbor local-maxima mask of ``a``. + + Works on a single ``(H, W)`` map or a batch ``(B, H, W)`` — the neighbor + comparisons and edge masking use the trailing two axes — so the same code drives + the single-pattern and batched detection paths bit-identically. + + Parameters + ---------- + a : torch.Tensor + ``(H, W)`` or ``(B, H, W)`` correlation map(s). + edge_boundary : int + Width in pixels of the border (clamped to at least 1) set to ``False``. + + Returns + ------- + torch.Tensor + Boolean mask the same shape as ``a``, ``True`` at 8-neighbor local maxima. 
+ """ + is_max = ( + (a >= torch.roll(a, (-1, 0), dims=(-2, -1))) + & (a > torch.roll(a, (1, 0), dims=(-2, -1))) + & (a >= torch.roll(a, (0, -1), dims=(-2, -1))) + & (a > torch.roll(a, (0, 1), dims=(-2, -1))) + & (a >= torch.roll(a, (-1, -1), dims=(-2, -1))) + & (a > torch.roll(a, (-1, 1), dims=(-2, -1))) + & (a >= torch.roll(a, (1, -1), dims=(-2, -1))) + & (a > torch.roll(a, (1, 1), dims=(-2, -1))) + ) + + eb = max(1, int(edge_boundary)) + is_max[..., :eb, :] = False + is_max[..., -eb:, :] = False + is_max[..., :, :eb] = False + is_max[..., :, -eb:] = False + return is_max + + +def _extract_maxima(a: torch.Tensor, is_max: torch.Tensor) -> torch.Tensor: + """Gather masked maxima of one ``(H, W)`` map into a descending-sorted ``(K, 3)``. + + Rows are taken in row-major ``nonzero`` order, then stably sorted by descending + intensity — matching the original single-pattern behaviour. + + Parameters + ---------- + a : torch.Tensor + ``(H, W)`` correlation map. + is_max : torch.Tensor + Boolean ``(H, W)`` local-maxima mask from :func:`_local_maxima_mask`. + + Returns + ------- + torch.Tensor + ``(K, 3)`` tensor of ``[row, col, intensity]`` maxima, sorted by descending + intensity. + """ + rows, cols = torch.nonzero(is_max, as_tuple=True) + intensity = a[rows, cols] + order = torch.argsort(intensity, descending=True) + return torch.stack((rows[order].to(a.dtype), cols[order].to(a.dtype), intensity[order]), dim=1) + + +def _filter_maxima( + peaks: torch.Tensor, + min_abs_intensity: float, + min_spacing: float, + max_num_peaks: int, +) -> torch.Tensor: + """Drop dim peaks, suppress peaks closer than ``min_spacing``, cap the count. + + Parameters + ---------- + peaks : torch.Tensor + ``(K, 3)`` ``[row, col, intensity]`` peaks, sorted by descending intensity. + min_abs_intensity : float + Drop peaks below this absolute intensity (ignored when ``<= 0``). 
+ min_spacing : float + Minimum spacing in pixels; for each kept peak, dimmer peaks within this + distance are suppressed (ignored when ``<= 0``). + max_num_peaks : int + Maximum number of peaks to keep; the brightest are retained. + + Returns + ------- + torch.Tensor + ``(M, 3)`` filtered peaks. + """ + if peaks.shape[0] == 0: + return peaks + + if min_abs_intensity > 0: + peaks = peaks[peaks[:, 2] >= min_abs_intensity] + + if min_spacing > 0 and peaks.shape[0] > 1: + keep = torch.ones(peaks.shape[0], dtype=torch.bool, device=peaks.device) + rc = peaks[:, :2] + for i in range(peaks.shape[0]): + if not keep[i]: + continue + d2 = ((rc - rc[i]) ** 2).sum(dim=1) + too_close = d2 < min_spacing**2 + too_close[: i + 1] = False + keep[too_close] = False + peaks = peaks[keep] + + if max_num_peaks is not None and peaks.shape[0] > max_num_peaks: + peaks = peaks[:max_num_peaks] + + return peaks + + +def _detect_peaks_batched( + corr: torch.Tensor, + edge_boundary: int, + min_abs_intensity: float, + min_spacing: float, + max_num_peaks: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Extract, suppress, and cap peaks across a whole ``(B, H, W)`` stack at once. + + Vectorized replacement for the per-pattern ``_extract_maxima`` + ``_filter_maxima`` + loop. Local maxima are found and intensity-thresholded for the whole stack, then + the greedy ``min_spacing`` suppression runs as a single loop over the *rank* axis + (brightest-first) that is batched over every pattern. Suppressing peak ``k``'s + fainter neighbours only when ``k`` is still kept reproduces the per-pattern greedy + result of :func:`_filter_maxima` exactly, with no Python loop over patterns. + + Parameters + ---------- + corr : torch.Tensor + ``(B, H, W)`` correlation maps. + edge_boundary : int + Width in pixels of the border in which maxima are ignored. + min_abs_intensity : float + Drop peaks below this absolute intensity (ignored when ``<= 0``). 
+ min_spacing : float + Minimum spacing in pixels; for each kept peak, fainter peaks within this + distance are suppressed (ignored when ``<= 0``). + max_num_peaks : int + Maximum number of peaks kept per pattern (the brightest are retained). + + Returns + ------- + peaks : torch.Tensor + ``(T, 3)`` ``[row, col, intensity]`` peaks pooled over all patterns, grouped + by ascending pattern index and sorted by descending intensity within a pattern. + bidx : torch.Tensor + ``(T,)`` pattern index of each peak. + counts : torch.Tensor + ``(B,)`` number of peaks kept per pattern. + """ + B = corr.shape[0] + device = corr.device + dtype = corr.dtype + + mask = _local_maxima_mask(corr, edge_boundary) + if min_abs_intensity > 0: + mask = mask & (corr >= min_abs_intensity) + + idx = torch.nonzero(mask) # (T0, 3): [b, row, col] + counts = torch.zeros(B, dtype=torch.long, device=device) + if idx.shape[0] == 0: + return corr.new_zeros((0, 3)), torch.zeros(0, dtype=torch.long, device=device), counts + + bcand, rcand, ccand = idx[:, 0], idx[:, 1], idx[:, 2] + inten = corr[bcand, rcand, ccand] + + # Group candidates by pattern, descending intensity within each pattern. A global + # descending-intensity sort followed by a stable sort on the pattern index keeps the + # within-pattern order identical to the per-pattern argsort in _extract_maxima. + o1 = torch.argsort(inten, descending=True) + order = o1[torch.argsort(bcand[o1], stable=True)] + bcand, rcand, ccand, inten = bcand[order], rcand[order], ccand[order], inten[order] + + counts = torch.bincount(bcand, minlength=B) + kmax = int(counts.max()) + + # Pack candidates into a dense (B, kmax) grid indexed by [pattern, brightness rank]. 
+ starts = torch.zeros(B, dtype=torch.long, device=device) + starts[1:] = torch.cumsum(counts, 0)[:-1] + rank = torch.arange(bcand.shape[0], device=device) - starts[bcand] + flat = bcand * kmax + rank + + rows = torch.zeros(B * kmax, dtype=dtype, device=device) + cols = torch.zeros(B * kmax, dtype=dtype, device=device) + vals = torch.zeros(B * kmax, dtype=dtype, device=device) + valid = torch.zeros(B * kmax, dtype=torch.bool, device=device) + rows[flat] = rcand.to(dtype) + cols[flat] = ccand.to(dtype) + vals[flat] = inten + valid[flat] = True + rows, cols, valid = rows.view(B, kmax), cols.view(B, kmax), valid.view(B, kmax) + + keep = valid.clone() + if min_spacing > 0 and kmax > 1: + s2 = float(min_spacing) ** 2 + for k in range(kmax - 1): + active = keep[:, k] & valid[:, k] # (B,) brightest unsuppressed peak at rank k + dr = rows[:, k + 1 :] - rows[:, k : k + 1] + dc = cols[:, k + 1 :] - cols[:, k : k + 1] + suppress = ((dr * dr + dc * dc) < s2) & active[:, None] & valid[:, k + 1 :] + keep[:, k + 1 :] &= ~suppress + + if max_num_peaks is not None: + kept_rank = torch.cumsum(keep.to(torch.long), dim=1) - 1 + keep &= kept_rank < max_num_peaks + + final = (keep & valid).view(-1) + sel = torch.nonzero(final, as_tuple=True)[0] # row-major: pattern-major, rank-minor + peaks = torch.stack((rows.view(-1)[sel], cols.view(-1)[sel], vals[sel]), dim=1) + bidx = torch.div(sel, kmax, rounding_mode="floor") + counts = torch.bincount(bidx, minlength=B) + return peaks, bidx, counts + + +def _split_by_counts(peaks: torch.Tensor, counts: torch.Tensor) -> list[np.ndarray]: + """Split a pattern-grouped ``(T, 3)`` peak stack into one ``(M, 3)`` array per pattern. + + Parameters + ---------- + peaks : torch.Tensor + ``(T, 3)`` peaks ordered by ascending pattern index (contiguous per pattern). + counts : torch.Tensor + ``(B,)`` number of peaks belonging to each pattern, in order. + + Returns + ------- + list of np.ndarray + Length-``B`` list of ``(M, 3)`` arrays. 
+ """ + out = [] + start = 0 + for cnt in counts.tolist(): + out.append(_to_numpy(peaks[start : start + cnt])) + start += cnt + return out + + +def _refine_parabolic(corr_map: torch.Tensor, peaks: torch.Tensor) -> torch.Tensor: + """3-point quadratic subpixel refinement of every peak (vectorized over peaks). + + Parameters + ---------- + corr_map : torch.Tensor + ``(H, W)`` correlation map the peaks were found in. + peaks : torch.Tensor + ``(M, 3)`` ``[row, col, intensity]`` peaks to refine. + + Returns + ------- + torch.Tensor + ``(M, 3)`` peaks with subpixel ``[row, col]`` and bilinearly interpolated + intensity. + """ + if peaks.shape[0] == 0: + return peaks + a = corr_map + H, W = a.shape + out = peaks.clone() + zero = torch.zeros(out.shape[0], device=a.device, dtype=a.dtype) + + r = out[:, 0].round().long().clamp(0, H - 1) + c = out[:, 1].round().long().clamp(0, W - 1) + + r_in = (r > 0) & (r < H - 1) + ix0, ix1, ix2 = a[(r - 1).clamp(0, H - 1), c], a[r, c], a[(r + 1).clamp(0, H - 1), c] + denom_r = 4.0 * ix1 - 2.0 * ix2 - 2.0 * ix0 + dr = torch.where(r_in & (denom_r != 0), (ix2 - ix0) / denom_r, zero) + + c_in = (c > 0) & (c < W - 1) + iy0, iy1, iy2 = a[r, (c - 1).clamp(0, W - 1)], a[r, c], a[r, (c + 1).clamp(0, W - 1)] + denom_c = 4.0 * iy1 - 2.0 * iy2 - 2.0 * iy0 + dc = torch.where(c_in & (denom_c != 0), (iy2 - iy0) / denom_c, zero) + + r_sub = r.to(a.dtype) + dr + c_sub = c.to(a.dtype) + dc + out[:, 0] = r_sub + out[:, 1] = c_sub + out[:, 2] = _bilinear(a, r_sub, c_sub) + return out + + +def _refine_parabolic_batched( + corr: torch.Tensor, peaks: torch.Tensor, bidx: torch.Tensor +) -> torch.Tensor: + """3-point quadratic refinement of peaks pooled across a batch of correlation maps. + + Batched form of :func:`_refine_parabolic`. Reading ``corr[bidx, r, c]`` gathers + exactly the values the single-pattern path would read, so each peak refines + identically. + + Parameters + ---------- + corr : torch.Tensor + ``(B, H, W)`` correlation maps. 
+ peaks : torch.Tensor + ``(T, 3)`` ``[row, col, intensity]`` peaks pooled over all patterns. + bidx : torch.Tensor + ``(T,)`` batch index selecting each peak's correlation map. + + Returns + ------- + torch.Tensor + ``(T, 3)`` peaks with subpixel ``[row, col]`` and bilinearly interpolated + intensity. + """ + if peaks.shape[0] == 0: + return peaks + H, W = corr.shape[-2], corr.shape[-1] + dtype = corr.dtype + out = peaks.clone() + zero = torch.zeros(out.shape[0], device=corr.device, dtype=dtype) + + r = out[:, 0].round().long().clamp(0, H - 1) + c = out[:, 1].round().long().clamp(0, W - 1) + + r_in = (r > 0) & (r < H - 1) + ix0 = corr[bidx, (r - 1).clamp(0, H - 1), c] + ix1 = corr[bidx, r, c] + ix2 = corr[bidx, (r + 1).clamp(0, H - 1), c] + denom_r = 4.0 * ix1 - 2.0 * ix2 - 2.0 * ix0 + dr = torch.where(r_in & (denom_r != 0), (ix2 - ix0) / denom_r, zero) + + c_in = (c > 0) & (c < W - 1) + iy0 = corr[bidx, r, (c - 1).clamp(0, W - 1)] + iy1 = corr[bidx, r, c] + iy2 = corr[bidx, r, (c + 1).clamp(0, W - 1)] + denom_c = 4.0 * iy1 - 2.0 * iy2 - 2.0 * iy0 + dc = torch.where(c_in & (denom_c != 0), (iy2 - iy0) / denom_c, zero) + + r_sub = r.to(dtype) + dr + c_sub = c.to(dtype) + dc + out[:, 0] = r_sub + out[:, 1] = c_sub + out[:, 2] = _bilinear_batched(corr, bidx, r_sub, c_sub) + return out + + +def _refine_dft(m: torch.Tensor, peaks: torch.Tensor, upsample_factor: int) -> torch.Tensor: + """Guizar-Sicairos DFT upsampling refinement of every peak (vectorized over peaks). + + Each peak is rounded to half-pixel precision (matching py4DSTEM multicorr) before + upsampling, then all peaks are refined together with one batched DFT upsampling. + + Parameters + ---------- + m : torch.Tensor + ``(H, W)`` Fourier-domain correlation product from :func:`cross_correlation`. + peaks : torch.Tensor + ``(M, 3)`` ``[row, col, intensity]`` peaks to refine. + upsample_factor : int + DFT upsampling factor. 
+ + Returns + ------- + torch.Tensor + ``(M, 3)`` peaks with DFT-refined ``[row, col]`` (intensity unchanged). + """ + if peaks.shape[0] == 0: + return peaks + out = peaks.clone() + xy = torch.round(out[:, :2] * 2.0) / 2.0 + refined = _upsampled_correlation_batch(m, int(upsample_factor), xy) + out[:, :2] = refined + return out + + +def _refine_dft_batched( + m: torch.Tensor, peaks: torch.Tensor, bidx: torch.Tensor, upsample_factor: int +) -> torch.Tensor: + """DFT-upsampling refinement of peaks pooled across a batch of correlation products. + + Batched form of :func:`_refine_dft`. Peaks are processed in sub-chunks so the + gathered per-peak product ``m[bidx]`` (``(chunk, H, W)`` complex) stays within a + fixed memory budget regardless of how many peaks were found. + + Parameters + ---------- + m : torch.Tensor + ``(B, H, W)`` Fourier-domain correlation products. + peaks : torch.Tensor + ``(T, 3)`` ``[row, col, intensity]`` peaks pooled over all patterns. + bidx : torch.Tensor + ``(T,)`` batch index selecting each peak's correlation product. + upsample_factor : int + DFT upsampling factor. + + Returns + ------- + torch.Tensor + ``(T, 3)`` peaks with DFT-refined ``[row, col]`` (intensity unchanged). + """ + if peaks.shape[0] == 0: + return peaks + out = peaks.clone() + M, N = m.shape[-2], m.shape[-1] + xy = torch.round(out[:, :2] * 2.0) / 2.0 + cap = max(1, 8_000_000 // (M * N)) + total = peaks.shape[0] + for start in range(0, total, cap): + stop = min(start + cap, total) + b = bidx[start:stop] + out[start:stop, :2] = _upsampled_correlation_batch( + m[b], int(upsample_factor), xy[start:stop] + ) + return out + + +def _upsampled_correlation_batch( + m: torch.Tensor, upsample_factor: int, xy: torch.Tensor +) -> torch.Tensor: + """Batched DFT upsampling of the correlation peak for many shifts at once. + + Vectorizes :func:`~quantem.core.utils.imaging_utils.upsampled_correlation_torch` + over the shifts ``xy`` with batched matmuls. 
Two exact speedups are applied: (1) ``m`` + is conjugate-symmetric (the diffraction pattern and template are real), so only the + non-negative half of the row frequencies is contracted — halving the dominant matmul — + with a rank-1 correction on the Nyquist column; (2) the per-peak DFT kernels are + factored into shared base kernels times per-peak phases, so ``torch.exp`` runs on far + fewer elements. Both are algebraically identical to the full-spectrum DFT upsample. + + Parameters + ---------- + m : torch.Tensor + ``(M, N)`` or ``(K, M, N)`` Fourier-domain correlation product(s). A single + ``(M, N)`` product broadcasts over all ``K`` shifts. + upsample_factor : int + DFT upsampling factor. + xy : torch.Tensor + ``(K, 2)`` ``[row, col]`` peak shifts at half-pixel precision. + + Returns + ------- + torch.Tensor + ``(K, 2)`` DFT-refined ``[row, col]`` positions. + """ + import math + + device = m.device + dtype = torch.get_default_dtype() + uf = float(upsample_factor) + M, N = m.shape[-2], m.shape[-1] + + xy = torch.round(xy * uf) / uf + global_shift = math.floor(math.ceil(uf * 1.5) / 2.0) + upsample_center = global_shift - uf * xy # (K, 2): [row, col] + + num = int(math.ceil(1.5 * uf)) + half = M // 2 + 1 + col_freq = (torch.fft.ifftshift(torch.arange(N, device=device)) - math.floor(N / 2)).to(dtype) + row_freq = (torch.fft.ifftshift(torch.arange(M, device=device)) - math.floor(M / 2)).to(dtype) + + # ``m = F(dp) * conj(F(template))`` is 2D conjugate-symmetric for real ``dp`` and + # template, so the upper-half row frequencies are redundant. Contract only the + # non-negative half (rows ``u = 0 .. M // 2``) and fold the conjugate upper rows + # back in with a weight of 2 — 1 for the self-paired DC row and, when ``M`` is even, + # the Nyquist row. This halves the dominant matmul over the row axis. 
+ row_freq = row_freq[:half] + row_weight = torch.full((half,), 2.0, device=device, dtype=dtype) + row_weight[0] = 1.0 + if M % 2 == 0: + row_weight[-1] = 1.0 + + base = torch.arange(num, device=device, dtype=dtype) + factor_col = -2j * math.pi / (N * uf) + factor_row = -2j * math.pi / (M * uf) + + # Factor each kernel ``exp(f · freq · (base - center))`` into a peak-independent base + # kernel ``exp(f · freq · base)`` times a per-peak phase ``exp(-f · freq · center)``. + # The base kernels are tiny and built once; the phases fold into ``m`` and the + # intermediate product, so ``torch.exp`` runs on ~20x fewer elements. + base_col = torch.exp(factor_col * (col_freq[:, None] * base[None, :])) # (N, num) + base_row = torch.exp(factor_row * (base[:, None] * row_freq[None, :]))[None] # (1, num, half) + phase_col = torch.exp(-factor_col * (col_freq[None, :] * upsample_center[:, 1:2])) # (K, N) + phase_row = torch.exp(-factor_row * (upsample_center[:, 0:1] * row_freq[None, :])) # (K, half) + + mc_half = m.conj()[..., :half, :] # (K, half, N) — non-negative row frequencies only + # Fold the row phase and conjugate-symmetry weight into ``m`` so the shared base row + # kernel is reused across peaks; then contract rows, apply the column phase, and + # contract columns against the shared base column kernel. + prod = torch.matmul(base_row, mc_half * (row_weight * phase_row)[:, :, None]) # (K, num, N) + up = torch.matmul(prod * phase_col[:, None, :], base_col) # (K, num, num) + + if N % 2 == 0: + # Row-only folding is exact for every column except the Nyquist column ``v = N/2``, + # whose conjugate partner stays in the same column. Correct it with an exact rank-1 + # update ``pr_diff ⊗ col_kern[:, N/2]``, where ``pr_diff = -2i Im(S)`` and ``S`` sums + # only the kept interior rows. 
+ nyq = N // 2 + interior = row_weight - 1.0 # 1 on interior rows, 0 on the DC / Nyquist rows + s = torch.matmul(base_row, (mc_half[..., nyq] * (interior * phase_row)).unsqueeze(-1)) + pr_diff = -2j * s[..., 0].imag # (K, num) + col_nyq = base_col[nyq, :][None, :] * phase_col[:, nyq : nyq + 1] # (K, num) + up = up + pr_diff[:, :, None] * col_nyq[:, None, :] + + image_up = up.real + + K = xy.shape[0] + kidx = torch.arange(K, device=device) + idx = torch.argmax(image_up.reshape(K, -1), dim=1) + sub_r = torch.div(idx, num, rounding_mode="floor") + sub_c = idx % num + + # 3-point parabolic refinement around the upsampled maximum (interior only) + interior = (sub_r > 0) & (sub_r < num - 1) & (sub_c > 0) & (sub_c < num - 1) + rr = sub_r.clamp(1, num - 2) + cc = sub_c.clamp(1, num - 2) + c11 = image_up[kidx, rr, cc] + c21, c01 = image_up[kidx, rr + 1, cc], image_up[kidx, rr - 1, cc] + c12, c10 = image_up[kidx, rr, cc + 1], image_up[kidx, rr, cc - 1] + zero = torch.zeros(K, device=device, dtype=dtype) + denom_x = 4.0 * c11 - 2.0 * c21 - 2.0 * c01 + denom_y = 4.0 * c11 - 2.0 * c12 - 2.0 * c10 + dx = torch.where(interior & (denom_x != 0), (c21 - c01) / denom_x, zero) + dy = torch.where(interior & (denom_y != 0), (c12 - c10) / denom_y, zero) + + sub = torch.stack([sub_r.to(dtype), sub_c.to(dtype)], dim=1) - global_shift + return xy + (sub + torch.stack([dx, dy], dim=1)) / uf + + +def _bilinear(a: torch.Tensor, r: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + """Bilinear interpolation of ``a`` at fractional ``(r, c)`` (vectorized over peaks). + + Parameters + ---------- + a : torch.Tensor + ``(H, W)`` map to sample. + r : torch.Tensor + ``(K,)`` fractional row coordinates. + c : torch.Tensor + ``(K,)`` fractional column coordinates. + + Returns + ------- + torch.Tensor + ``(K,)`` interpolated values. 
+ """ + H, W = a.shape + r0 = torch.floor(r).long() + c0 = torch.floor(c).long() + r1 = (r0 + 1).clamp(max=H - 1) + c1 = (c0 + 1).clamp(max=W - 1) + r0 = r0.clamp(0, H - 1) + c0 = c0.clamp(0, W - 1) + dr = r - r0.to(a.dtype) + dc = c - c0.to(a.dtype) + return ( + (1 - dr) * (1 - dc) * a[r0, c0] + + (1 - dr) * dc * a[r0, c1] + + dr * (1 - dc) * a[r1, c0] + + dr * dc * a[r1, c1] + ) + + +def _bilinear_batched( + corr: torch.Tensor, bidx: torch.Tensor, r: torch.Tensor, c: torch.Tensor +) -> torch.Tensor: + """Bilinear interpolation of a batch ``corr`` ``(B, H, W)`` at per-peak ``(r, c)``. + + Batched form of :func:`_bilinear`. The four corner samples are gathered as + ``corr[bidx, r0, c0]`` etc., matching the single-pattern interpolation + value-for-value. + + Parameters + ---------- + corr : torch.Tensor + ``(B, H, W)`` maps to sample. + bidx : torch.Tensor + ``(T,)`` batch index selecting each peak's map. + r : torch.Tensor + ``(T,)`` fractional row coordinates. + c : torch.Tensor + ``(T,)`` fractional column coordinates. + + Returns + ------- + torch.Tensor + ``(T,)`` interpolated values. + """ + H, W = corr.shape[-2], corr.shape[-1] + dtype = corr.dtype + r0 = torch.floor(r).long() + c0 = torch.floor(c).long() + r1 = (r0 + 1).clamp(max=H - 1) + c1 = (c0 + 1).clamp(max=W - 1) + r0 = r0.clamp(0, H - 1) + c0 = c0.clamp(0, W - 1) + dr = r - r0.to(dtype) + dc = c - c0.to(dtype) + return ( + (1 - dr) * (1 - dc) * corr[bidx, r0, c0] + + (1 - dr) * dc * corr[bidx, r0, c1] + + dr * (1 - dc) * corr[bidx, r1, c0] + + dr * dc * corr[bidx, r1, c1] + ) + + +def _to_numpy(peaks: torch.Tensor) -> np.ndarray: + """Convert a peaks tensor to a contiguous ``(M, 3)`` float64 numpy array. + + Parameters + ---------- + peaks : torch.Tensor + ``(M, 3)`` (or flat) peaks tensor on any device. + + Returns + ------- + np.ndarray + ``(M, 3)`` ``float64`` array of ``[q_row, q_col, intensity]`` rows. 
@classmethod
def from_datasets(cls, datasets: Sequence[Dataset4dstem]) -> MAPED:
    """
    Build a MAPED instance from a non-empty sequence of Dataset4dstem.

    This is the supported constructor: it validates the input and then calls
    the class with the private token that ``__init__`` requires.

    Parameters
    ----------
    datasets
        Sequence of Dataset4dstem instances (a ``str``/``bytes`` is rejected).

    Returns
    -------
    MAPED
        New MAPED instance holding the datasets in a fresh list.

    Raises
    ------
    TypeError
        If ``datasets`` is not a sequence, or any element is not a Dataset4dstem.
    ValueError
        If ``datasets`` is empty.
    """
    type_msg = "MAPED.from_datasets expects a sequence of Dataset4dstem instances."
    if isinstance(datasets, (str, bytes)) or not isinstance(datasets, Sequence):
        raise TypeError(type_msg)

    members: list[Dataset4dstem] = list(datasets)
    if any(not isinstance(member, Dataset4dstem) for member in members):
        raise TypeError(type_msg)
    if not members:
        raise ValueError("MAPED.from_datasets expects a non-empty sequence of Dataset4dstem instances.")

    return cls(datasets=members, _token=cls._token)
def diffraction_origin(
    self,
    origins=None,
    sigma=None,
    plot_origins: bool = True,
    plot_indices=None,
    **plot_kwargs: Any,
) -> MAPED:
    """
    Choose or automatically find the origin in diffraction space.

    Parameters
    ----------
    origins
        Optional manual origins. Can be:
        - a single (row, col) pair (tuple, list or array), applied to all datasets
        - a sequence/array of n (row, col) pairs (one per dataset)
        The two forms are told apart by shape, so a tuple of n pairs is accepted
        as per-dataset origins. Values are converted to integers.
        If None, the origin of each dataset is the argmax of its mean DP.
    sigma
        Optional low-pass smoothing sigma (pixels) applied to each mean DP prior to peak finding.
    plot_origins
        If True, plot mean diffraction patterns with overlaid origin markers.
    plot_indices
        Optional indices to plot. If None, plots all datasets.
    **plot_kwargs
        Passed to show_2d.

    Returns
    -------
    MAPED
        ``self``, to allow call chaining.

    Raises
    ------
    RuntimeError
        If preprocess() has not been run (``self.dp_mean`` missing).
    IndexError
        If ``plot_indices`` contains an index outside ``[0, n)``.
    ValueError
        If ``origins`` is neither a single pair nor n pairs.

    Stores
    ------
    self.diffraction_origins : np.ndarray
        Array of shape (n, 2) with integer (row, col) origins.
    """
    n = len(self.datasets)
    if not hasattr(self, "dp_mean"):
        raise RuntimeError("Run preprocess() first so self.dp_mean exists.")

    if plot_indices is None:
        plot_indices_list = list(range(n))
    else:
        plot_indices_list = list(plot_indices)
        for i in plot_indices_list:
            if i < 0 or i >= n:
                raise IndexError("plot_indices contains an out-of-range index.")

    if origins is None:
        smooth = sigma is not None and float(sigma) > 0
        origins_arr = np.zeros((n, 2), dtype=int)
        for i in range(n):
            dp_use = np.asarray(self.dp_mean[i])
            if smooth:
                dp_use = gaussian_filter(
                    dp_use.astype(float, copy=False), float(sigma), mode="nearest"
                )
            r, c = np.unravel_index(int(np.argmax(dp_use)), dp_use.shape)
            origins_arr[i, 0] = int(r)
            origins_arr[i, 1] = int(c)
    else:
        # Classify by shape rather than by container type: checking
        # ``isinstance(tuple) and len == 2`` mistook a tuple of two (row, col)
        # pairs for one origin and tiled it into a wrongly shaped array.
        origins_in = np.asarray(list(origins), dtype=int)
        if origins_in.shape == (2,):
            origins_arr = np.tile(origins_in[None, :], (n, 1))
        elif origins_in.shape == (n, 2):
            origins_arr = origins_in
        elif origins_in.shape[0] != n:
            raise ValueError("origins must be a single (row,col) tuple or a list of length n.")
        else:
            raise ValueError("origins must have shape (n, 2) after conversion.")

    self.diffraction_origins = origins_arr

    if plot_origins:
        arrays = [np.asarray(self.dp_mean[i]) for i in plot_indices_list]
        titles = [f"{i} - Mean Diffraction Pattern" for i in plot_indices_list]
        fig, ax = show_2d(arrays, title=titles, returnfig=True, **plot_kwargs)
        # show_2d may hand back a single axis or an array of axes; flatten either way.
        axs = np.ravel(np.asarray(ax, dtype=object))
        for j, i in enumerate(plot_indices_list):
            r, c = self.diffraction_origins[i]
            axs[j].plot([c], [r], marker="+", color="red", markersize=16, markeredgewidth=2)

    return self
patterns using weighted cross-correlation in Fourier space. + + Parameters + ---------- + edge_blend + Tukey window edge taper (pixels). + padding + Passed to shift_images for plotting. + pad_val + Passed to shift_images for plotting. + upsample_factor + Subpixel upsampling factor for correlation peak estimation. + weight_scale + Radial weight falloff scale (fraction of mean DP size). + plot_aligned + If True, plot aligned mean diffraction patterns. + **plot_kwargs + Passed to show_2d when plotting. + + Stores + ------ + self.diffraction_shifts : np.ndarray + Array of shape (n, 2) with (row, col) shifts to align diffraction patterns. + """ + if not hasattr(self, "dp_mean"): + raise RuntimeError("Run preprocess() first so self.dp_mean exists.") + if not hasattr(self, "diffraction_origins"): + raise RuntimeError("Run diffraction_origin() first so self.diffraction_origins exists.") + + H, W = np.asarray(self.dp_mean[0]).shape + + w = tukey(H, alpha=2.0 * float(edge_blend) / float(H))[:, None] * tukey( + W, alpha=2.0 * float(edge_blend) / float(W) + )[None, :] + + r = np.fft.fftfreq(H, 1.0 / float(H))[:, None] + c = np.fft.fftfreq(W, 1.0 / float(W))[None, :] + + n = len(self.dp_mean) + self.diffraction_shifts = np.zeros((n, 2), dtype=float) + + G_ref = np.fft.fft2(w * np.asarray(self.dp_mean[0])) + xy0 = np.asarray(self.diffraction_origins[0], dtype=float) + + for ind in range(1, n): + G = np.fft.fft2(w * np.asarray(self.dp_mean[ind])) + xy = np.asarray(self.diffraction_origins[ind], dtype=float) + + dr2 = (r - xy0[0] + xy[0]) ** 2 + (c - xy0[1] + xy[1]) ** 2 + im_weight = np.clip( + 1.0 - np.sqrt(dr2) / float(np.mean((H, W))) / float(weight_scale), + 0.0, + 1.0, + ) + im_weight = np.sin(im_weight * np.pi / 2.0) ** 2 + + shift_rc, G_shift = weighted_cross_correlation_shift( + im_ref=G_ref, + im=G, + weight_real=im_weight * 0.0 + 1.0, + upsample_factor=int(upsample_factor), + fft_input=True, + fft_output=True, + return_shifted_image=True, + ) + 
def real_space_align(
    self,
    num_images=None,
    num_iter: int = 3,
    edge_blend: float = 1.0,
    padding=None,
    pad_val: str | float = "median",
    upsample_factor: int = 100,
    max_shift=None,
    shift_method: str = "bilinear",
    edge_filter: bool = True,
    edge_sigma: float = 2.0,
    hanning_filter: bool = False,
    plot_aligned: bool = True,
    **plot_kwargs: Any,
) -> MAPED:
    """
    Align real-space mean BF images using iterative average-reference correlation.

    Each iteration shifts every (optionally edge-filtered) image by its current
    shift estimate, averages the FFTs of all shifted images into a reference,
    and adds the shift measured between that reference and each image
    (all but the first) to the running estimate. The final shifts are re-centred
    to zero mean.

    Parameters
    ----------
    num_images
        If provided, align only the first num_images images.
    num_iter
        Number of refinement iterations.
    edge_blend
        Used to set default correlation padding when max_shift is None.
    padding
        Passed to shift_images for plotting.
    pad_val
        Passed to shift_images for plotting.
    upsample_factor
        Subpixel upsampling factor for correlation peak estimation.
    max_shift
        Optional maximum shift constraint passed to weighted_cross_correlation_shift.
    shift_method
        Passed to shift_images for plotting ('bilinear' or 'fourier').
    edge_filter
        If True, correlate on gradient magnitude instead of raw intensity.
    edge_sigma
        Gaussian sigma applied to gradients when edge_filter is True.
    hanning_filter
        If True, apply a Hanning window prior to FFT.
    plot_aligned
        If True, plot aligned mean BF images.
    **plot_kwargs
        Passed to show_2d when plotting.

    Returns
    -------
    MAPED
        ``self``, to allow call chaining.

    Raises
    ------
    RuntimeError
        If preprocess() has not been run, ``self.im_bf`` is empty, or the
        window sums to zero.
    ValueError
        If the BF images differ in shape, ``num_images`` is not positive, or
        ``num_iter`` is less than 1.

    Stores
    ------
    self.real_space_shifts : np.ndarray
        Array of shape (n_total, 2) with (row, col) shifts for aligned datasets.
    """
    if not hasattr(self, "im_bf"):
        raise RuntimeError("Run preprocess() first so self.im_bf exists.")
    if len(self.im_bf) == 0:
        raise RuntimeError("No images found in self.im_bf.")

    H, W = self.im_bf[0].shape
    for im in self.im_bf:
        if im.shape != (H, W):
            raise ValueError("all self.im_bf images must have the same shape")

    # n = number of images actually aligned (the first n), capped at n_total.
    n_total = len(self.im_bf)
    if num_images is None:
        n = n_total
    else:
        n = int(num_images)
        if n <= 0:
            raise ValueError("num_images must be positive")
        n = min(n, n_total)

    if int(num_iter) < 1:
        raise ValueError("num_iter must be >= 1")

    # Zero border of pad_cc pixels on every side of the correlation canvas:
    # sized from max_shift when given, otherwise from edge_blend, plus 4 pixels.
    if max_shift is not None:
        pad_cc = int(np.ceil(float(max_shift))) + 4
    else:
        pad_cc = int(np.ceil(float(edge_blend))) + 4

    Hp = H + 2 * pad_cc
    Wp = W + 2 * pad_cc
    r0 = pad_cc
    c0 = pad_cc

    # Window on the padded canvas: ones (or a 2D Hanning window) over the image
    # footprint, zero in the border. It weights both the mean subtraction and
    # the data that enters the FFT.
    w_h = np.ones((H, W), dtype=float)
    if hanning_filter:
        w_h = np.hanning(H)[:, None] * np.hanning(W)[None, :]
    w_h_pad = np.zeros((Hp, Wp), dtype=float)
    w_h_pad[r0 : r0 + H, c0 : c0 + W] = w_h
    w_h_sum = float(np.sum(w_h_pad))
    if w_h_sum <= 0:
        raise RuntimeError("hanning window sum is zero")

    # 3x3 Sobel kernel; wx and wx.T give the two orthogonal derivative directions.
    if edge_filter:
        wx = np.array(
            [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]],
            dtype=float,
        )
    else:
        wx = None

    # Unshifted correlation inputs (raw BF or smoothed gradient magnitude),
    # embedded once in the padded canvas and re-shifted on every iteration.
    base_pad = np.zeros((n, Hp, Wp), dtype=float)
    for i in range(n):
        im0 = np.asarray(self.im_bf[i], dtype=float)

        if edge_filter:
            gx = convolve2d(im0, wx, mode="same", boundary="symm")
            gy = convolve2d(im0, wx.T, mode="same", boundary="symm")
            gx = gaussian_filter(gx, float(edge_sigma), mode="nearest")
            gy = gaussian_filter(gy, float(edge_sigma), mode="nearest")
            im_use = np.sqrt(gx * gx + gy * gy)
        else:
            im_use = im0

        base_pad[i, r0 : r0 + H, c0 : c0 + W] = im_use

    # Running (row, col) shift estimate per image, refined additively.
    shifts = np.zeros((n, 2), dtype=float)

    for _ in range(int(num_iter)):
        G_list = np.empty((n, Hp, Wp), dtype=np.complex128)

        for i in range(n):
            # Apply the current estimate with linear interpolation, filling with zeros.
            im_a = ndi_shift(
                base_pad[i],
                shift=(shifts[i, 0], shifts[i, 1]),
                order=1,
                mode="constant",
                cval=0.0,
                prefilter=False,
            )
            # Subtract the window-weighted mean before windowing.
            im_mean = float(np.sum(im_a * w_h_pad) / w_h_sum)
            im_win = (im_a - im_mean) * w_h_pad
            G_list[i] = np.fft.fft2(im_win)

        # Reference = average of all currently shifted images (in Fourier space).
        G_ref = np.mean(G_list, axis=0)

        # Image 0 is never re-estimated, so its running shift stays at zero.
        for i in range(1, n):
            drc = weighted_cross_correlation_shift(
                im_ref=G_ref,
                im=G_list[i],
                weight_real=None,
                upsample_factor=int(upsample_factor),
                max_shift=max_shift,
                fft_input=True,
                fft_output=False,
                return_shifted_image=False,
            )
            shifts[i, 0] += float(drc[0])
            shifts[i, 1] += float(drc[1])

        # Keeps shifts expressed relative to image 0; since shifts[0] is zero
        # inside this loop, this leaves the values unchanged.
        shifts -= shifts[0][None, :]

    # Re-centre so the mean shift over the n aligned images is zero.
    shifts -= np.mean(shifts, axis=0)[None, :]

    # Images beyond the first n keep a zero shift.
    self.real_space_shifts = np.zeros((n_total, 2), dtype=float)
    self.real_space_shifts[:n, :] = shifts

    if plot_aligned:
        im_aligned = shift_images(
            images=self.im_bf[:n],
            shifts_rc=self.real_space_shifts[:n, :],
            edge_blend=float(edge_blend),
            padding=padding,
            pad_val=pad_val,
            shift_method=shift_method,
        )
        show_2d(im_aligned, **plot_kwargs)

    return self
+ dtype + Output dtype. If None, uses parent dtype. + scale_output + If True and dtype is integer, scale to full dynamic range using global max. + plot_result + If True, plot merged BF and merged mean DP. + **plot_kwargs + Passed to show_2d. + + Returns + ------- + Dataset4dstem + Merged dataset. + """ + if not hasattr(self, "real_space_shifts"): + raise RuntimeError("Run real_space_align() first so self.real_space_shifts exists.") + if not hasattr(self, "diffraction_shifts"): + raise RuntimeError("Run diffraction_align() first so self.diffraction_shifts exists.") + + arrays = [np.asarray(d.array) for d in self.datasets] + n = len(arrays) + if n == 0: + raise RuntimeError("No datasets found in self.datasets.") + + Rs, Cs, H, W = arrays[0].shape + for a in arrays: + if a.shape != (Rs, Cs, H, W): + raise ValueError("All dataset arrays must have the same shape (Rs, Cs, H, W).") + + rs_shifts = np.asarray(self.real_space_shifts, dtype=float) + dp_shifts = np.asarray(self.diffraction_shifts, dtype=float) + if rs_shifts.shape != (n, 2): + raise ValueError("self.real_space_shifts must have shape (n, 2).") + if dp_shifts.shape != (n, 2): + raise ValueError("self.diffraction_shifts must have shape (n, 2).") + + if dtype is None: + dtype_out = np.asarray(arrays[0]).dtype + warnings.warn(f"dtype=None; using parent dtype {dtype_out}.", RuntimeWarning) + else: + dtype_out = np.dtype(dtype) + + real_space_padding = int(real_space_padding) + diffraction_padding = int(diffraction_padding) + + Rout = Rs + 2 * real_space_padding + Cout = Cs + 2 * real_space_padding + + Hp = H + 2 * diffraction_padding + Wp = W + 2 * diffraction_padding + rp0 = diffraction_padding + cp0 = diffraction_padding + + method = str(shift_method).strip().lower() + if method not in {"bilinear", "fourier"}: + raise ValueError("shift_method must be 'bilinear' or 'fourier'.") + + if real_space_edge_blend and float(real_space_edge_blend) > 0: + alpha_r = min(1.0, 2.0 * float(real_space_edge_blend) / float(Rs)) + 
alpha_c = min(1.0, 2.0 * float(real_space_edge_blend) / float(Cs)) + w_rs = tukey(Rs, alpha=alpha_r)[:, None] * tukey(Cs, alpha=alpha_c)[None, :] + else: + w_rs = np.ones((Rs, Cs), dtype=float) + w_rs = w_rs.astype(float, copy=False) + + if diffraction_edge_blend and float(diffraction_edge_blend) > 0: + alpha_dr = min(1.0, 2.0 * float(diffraction_edge_blend) / float(H)) + alpha_dc = min(1.0, 2.0 * float(diffraction_edge_blend) / float(W)) + w_dp = tukey(H, alpha=alpha_dr)[:, None] * tukey(W, alpha=alpha_dc)[None, :] + else: + w_dp = np.ones((H, W), dtype=float) + w_dp = w_dp.astype(float, copy=False) + + dp_means = [np.mean(a, axis=(0, 1), dtype=np.float64) for a in arrays] + v = np.stack(dp_means, axis=0).reshape(-1) + + if isinstance(diffraction_pad_val, str): + s = diffraction_pad_val.strip().lower() + if s == "min": + pad_val_dp = float(np.min(v)) + elif s == "max": + pad_val_dp = float(np.max(v)) + elif s == "mean": + pad_val_dp = float(np.mean(v)) + elif s == "median": + pad_val_dp = float(np.median(v)) + else: + raise ValueError("diffraction_pad_val must be a float or one of {'min','max','mean','median'}.") + else: + pad_val_dp = float(diffraction_pad_val) + + wdp_pad = np.zeros((Hp, Wp), dtype=float) + wdp_pad[rp0 : rp0 + H, cp0 : cp0 + W] = w_dp + + wdp_shifted = np.zeros((n, Hp, Wp), dtype=float) + if method == "fourier": + kr = np.fft.fftfreq(Hp)[:, None] + kc = np.fft.fftfreq(Wp)[None, :] + Fw = np.fft.fft2(wdp_pad) + ramps: list[np.ndarray] = [] + for i in range(n): + dr, dc = dp_shifts[i, 0], dp_shifts[i, 1] + ramp = np.exp(-2j * np.pi * (kr * dr + kc * dc)) + ramps.append(ramp) + w_i = np.fft.ifft2(Fw * ramp).real + wdp_shifted[i] = np.clip(w_i, 0.0, 1.0) + else: + for i in range(n): + w_i = ndi_shift( + wdp_pad, + shift=(dp_shifts[i, 0], dp_shifts[i, 1]), + order=1, + mode="constant", + cval=0.0, + prefilter=False, + ) + wdp_shifted[i] = np.clip(w_i, 0.0, 1.0) + ramps = [] + + coverage = np.clip(np.sum(wdp_shifted, axis=0), 0.0, 1.0) + edge_w_dp = 
1.0 - coverage + + merged = np.zeros((Rout, Cout, Hp, Wp), dtype=np.float64) + + dp_local = np.zeros((H, W), dtype=np.float64) + dp_pad = np.zeros((Hp, Wp), dtype=np.float64) + dp_shifted_tmp = np.zeros((Hp, Wp), dtype=np.float64) + num_tmp = np.zeros((Hp, Wp), dtype=np.float64) + den_tmp = np.zeros((Hp, Wp), dtype=np.float64) + + for ro in tqdm(range(Rout), desc="Merging (rows)"): + r_base = ro - real_space_padding + for co in range(Cout): + c_base = co - real_space_padding + + num_tmp.fill(0.0) + den_tmp.fill(0.0) + max_wi = 0.0 + + for i in range(n): + r_in = r_base - rs_shifts[i, 0] + c_in = c_base - rs_shifts[i, 1] + + r0 = int(np.floor(r_in)) + c0 = int(np.floor(c_in)) + if r0 < 0 or r0 >= Rs - 1 or c0 < 0 or c0 >= Cs - 1: + continue + + dr = r_in - r0 + dc = c_in - c0 + + w00 = (1.0 - dr) * (1.0 - dc) + w10 = dr * (1.0 - dc) + w01 = (1.0 - dr) * dc + w11 = dr * dc + + wi = ( + w00 * w_rs[r0, c0] + + w10 * w_rs[r0 + 1, c0] + + w01 * w_rs[r0, c0 + 1] + + w11 * w_rs[r0 + 1, c0 + 1] + ) + if wi <= 0.0: + continue + if wi > max_wi: + max_wi = wi + + a = arrays[i] + dp_local[:] = ( + w00 * a[r0, c0] + + w10 * a[r0 + 1, c0] + + w01 * a[r0, c0 + 1] + + w11 * a[r0 + 1, c0 + 1] + ) + + dp_pad.fill(0.0) + dp_pad[rp0 : rp0 + H, cp0 : cp0 + W] = dp_local * w_dp + + if method == "fourier": + ramp = ramps[i] + dp_shifted_tmp[:] = np.fft.ifft2(np.fft.fft2(dp_pad) * ramp).real + else: + dp_shifted_tmp[:] = ndi_shift( + dp_pad, + shift=(dp_shifts[i, 0], dp_shifts[i, 1]), + order=1, + mode="constant", + cval=0.0, + prefilter=False, + ) + + num_tmp += wi * dp_shifted_tmp + den_tmp += wi * wdp_shifted[i] + + if max_wi <= 0.0: + merged[ro, co] = 0.0 + continue + + num = num_tmp + edge_w_dp * pad_val_dp + den = den_tmp + edge_w_dp + + out = np.empty_like(num) + np.divide(num, den, out=out, where=den != 0.0) + out[den == 0.0] = 0.0 + merged[ro, co] = out + + self.im_bf_merged = np.mean(merged, axis=(2, 3), dtype=np.float64) + self.dp_mean_merged = np.mean(merged, axis=(0, 1), 
dtype=np.float64) + + if np.issubdtype(dtype_out, np.integer): + info = np.iinfo(dtype_out) + dmin = float(info.min) + dmax = float(info.max) + + merged_f = merged + + if scale_output: + peak = float(np.max(merged_f)) + if peak <= 0.0: + merged_scaled = merged_f + else: + merged_scaled = merged_f * (dmax / peak) + + if np.issubdtype(dtype_out, np.unsignedinteger): + lo, hi = 0.0, dmax + else: + lo, hi = dmin, dmax + + merged_out = np.rint(np.clip(merged_scaled, lo, hi)).astype(dtype_out) + else: + below = float(np.min(merged_f)) + above = float(np.max(merged_f)) + if below < dmin or above > dmax: + warnings.warn( + f"Output overflow for dtype {dtype_out}: data range [{below}, {above}] exceeds " + f"[{dmin}, {dmax}]. Values will be clipped.", + RuntimeWarning, + ) + merged_out = np.rint(np.clip(merged_f, dmin, dmax)).astype(dtype_out) + else: + merged_out = merged.astype(dtype_out, copy=False) + + dataset_merged = Dataset4dstem.from_array(array=merged_out) + dataset_merged.im_bf_merged = self.im_bf_merged + dataset_merged.dp_mean_merged = self.dp_mean_merged + + if plot_result: + show_2d( + [[self.im_bf_merged, self.dp_mean_merged]], + title=[["Merged Bright Field", "Merged Mean Diffraction Pattern"]], + **plot_kwargs, + ) + + return dataset_merged + + +def shift_images( + images, + shifts_rc, + edge_blend: float = 8.0, + padding=None, + pad_val: str | float = 0.0, + shift_method: str = "bilinear", +): + """ + Shift and blend a stack of 2D images into a common padded canvas. + + Parameters + ---------- + images + Sequence of (H, W) arrays. + shifts_rc + Array-like of shape (n, 2) with (row, col) shifts for each image. + edge_blend + Tukey taper width in pixels for image blending. + padding + Output padding. If None, set from max shift and edge_blend. + pad_val + Fill value outside support ('min','max','mean','median' or float). + shift_method + 'bilinear' or 'fourier'. + + Returns + ------- + np.ndarray + Blended image of shape (H + 2*padding, W + 2*padding). 
+ """ + images = [np.asarray(im, dtype=float) for im in images] + if len(images) == 0: + raise ValueError("images must be non-empty") + + H, W = images[0].shape + for im in images: + if im.shape != (H, W): + raise ValueError("all images must have the same shape") + + shifts_rc = np.asarray(shifts_rc, dtype=float) + if shifts_rc.shape != (len(images), 2): + raise ValueError("shifts_rc must have shape (len(images), 2)") + + if isinstance(pad_val, str): + s = pad_val.strip().lower() + v = np.stack(images, axis=0).reshape(-1) + if s == "min": + pad_val_f = float(np.min(v)) + elif s == "max": + pad_val_f = float(np.max(v)) + elif s == "mean": + pad_val_f = float(np.mean(v)) + elif s == "median": + pad_val_f = float(np.median(v)) + else: + raise ValueError("pad_val must be a float or one of {'min','max','mean','median'}") + else: + pad_val_f = float(pad_val) + + if padding is None: + max_shift = float(np.max(np.abs(shifts_rc))) if shifts_rc.size else 0.0 + padding = int(np.ceil(max_shift + float(edge_blend))) + 2 + padding = int(padding) + + alpha_r = min(1.0, 2.0 * float(edge_blend) / float(H)) if edge_blend > 0 else 0.0 + alpha_c = min(1.0, 2.0 * float(edge_blend) / float(W)) if edge_blend > 0 else 0.0 + w = tukey(H, alpha=alpha_r)[:, None] * tukey(W, alpha=alpha_c)[None, :] + w = w.astype(float, copy=False) + + Hp = H + 2 * padding + Wp = W + 2 * padding + + stack_w = np.zeros((len(images), Hp, Wp), dtype=float) + stack = np.zeros_like(stack_w) + + r0 = padding + c0 = padding + stack_w[:, r0 : r0 + H, c0 : c0 + W] = w[None, :, :] + for ind, im in enumerate(images): + stack[ind, r0 : r0 + H, c0 : c0 + W] = im * w + + method = str(shift_method).strip().lower() + if method not in {"bilinear", "fourier"}: + raise ValueError("shift_method must be 'bilinear' or 'fourier'") + + if method == "fourier": + kr = np.fft.fftfreq(Hp)[:, None] + kc = np.fft.fftfreq(Wp)[None, :] + for ind in range(len(images)): + dr, dc = shifts_rc[ind, 0], shifts_rc[ind, 1] + ramp = np.exp(-2j * 
np.pi * (kr * dr + kc * dc)) + + F = np.fft.fft2(stack[ind]) + stack[ind] = np.fft.ifft2(F * ramp).real + + Fw = np.fft.fft2(stack_w[ind]) + stack_w[ind] = np.fft.ifft2(Fw * ramp).real + stack_w[ind] = np.clip(stack_w[ind], 0.0, 1.0) + else: + for ind in range(len(images)): + stack[ind] = ndi_shift( + stack[ind], + shift=(shifts_rc[ind, 0], shifts_rc[ind, 1]), + order=1, + mode="constant", + cval=0.0, + prefilter=False, + ) + stack_w[ind] = ndi_shift( + stack_w[ind], + shift=(shifts_rc[ind, 0], shifts_rc[ind, 1]), + order=1, + mode="constant", + cval=0.0, + prefilter=False, + ) + stack_w[ind] = np.clip(stack_w[ind], 0.0, 1.0) + + edge_w = np.clip(1.0 - np.sum(stack_w, axis=0), 0.0, 1.0) + + num = np.sum(stack, axis=0) + edge_w * pad_val_f + den = np.sum(stack_w, axis=0) + edge_w + + out = np.empty_like(num) + np.divide(num, den, out=out, where=den != 0.0) + out[den == 0.0] = 0.0 + + return out diff --git a/src/quantem/diffraction/model_fitting.py b/src/quantem/diffraction/model_fitting.py index 41e6b8c6..b148fa91 100644 --- a/src/quantem/diffraction/model_fitting.py +++ b/src/quantem/diffraction/model_fitting.py @@ -4,10 +4,13 @@ import numpy as np import torch +import torch.nn.functional as F from scipy.ndimage import shift as ndi_shift from scipy.signal.windows import tukey +from tqdm import tqdm from quantem.core.datastructures import Dataset2d, Dataset3d, Dataset4d, Dataset4dstem +from quantem.core.fitting.background import DCBackground, GaussianBackground from quantem.core.fitting.base import ( AdditiveRenderModel, FitBase, @@ -17,9 +20,9 @@ ) from quantem.core.fitting.diffraction import DiskTemplate, SyntheticDiskLattice from quantem.core.io.serialize import AutoSerialize -from quantem.core.ml.optimizer_mixin import OptimizerType, SchedulerType from quantem.core.utils.imaging_utils import cross_correlation_shift from quantem.diffraction.model_fitting_visualizations import ModelDiffractionVisualizations +from quantem.diffraction.strain import StrainMap def 
_parse_init(value: float | int | Sequence[float | int | None], *, name: str) -> float: @@ -54,9 +57,19 @@ def __init__(self, dataset: Any, _token: object | None = None): self.state_mean_refined: dict[str, torch.Tensor] | None = None self.mean_refined: bool = False + self.state_individual_refined: np.ndarray | None = None + self.individual_refined: bool = False + # Misc metadata self.metadata: dict[str, Any] = {} + self.u_ref: np.ndarray | None = None + self.v_ref: np.ndarray | None = None + self.u_array: np.ndarray | None = None + self.v_array: np.ndarray | None = None + + self.real_space = False + @classmethod def from_dataset( cls, dataset: Dataset2d | Dataset3d | Dataset4d | Dataset4dstem | Any @@ -312,10 +325,65 @@ def preprocess( upsample_factor: int = 32, max_shift: float | None = None, shift_order: int = 1, + gamma: float = 0.5, + mode: str = "linear", + rows=None, + cols = None, ) -> "ModelDiffraction": arr = np.asarray(self.dataset.array) + self.mask = np.ones(arr.shape[:2]) + if arr.ndim < 2: raise ValueError("dataset.array must have at least 2 dimensions.") + mode_in = mode.strip().lower() + if mode_in in {"linear", "patterson", "paterson", "acf", "autocorrelation"}: + mode_norm = "linear" + elif mode_in in {"log", "cepstrum", "cepstral"}: + mode_norm = "log" + elif mode_in in {"gamma", "power", "sqrt"}: + mode_norm = "gamma" + else: + raise ValueError( + "mode must be 'linear', 'log', or 'gamma' (aliases: 'patterson'->'linear', 'cepstrum'/'cepstral'->'log')." 
+ ) + + self.metadata["mode"] = mode_norm + if mode_norm == "gamma": + self.metadata["gamma"] = gamma + + if mode_norm == "linear": + arr = arr + elif mode_norm == "log": + arr = np.log1p(arr) + elif mode_norm == "gamma": + arr = np.power(np.clip(arr, 0.0, None), self.metadata["gamma"]) + else: + raise RuntimeError("Unreachable: normalized mode mapping failed.") + self.dataset.array = np.asarray(arr) + + if rows is None and cols is None: + rows = range(self.dataset.shape[0]) + cols = range(self.dataset.shape[1]) + elif rows is not None and cols is None: + cols = range(self.dataset.shape[1]) + elif rows is None and cols is not None: + rows = range(self.dataset.shape[0]) + else: + rows = rows + cols = cols + + if isinstance(rows, int): + rows = np.array([rows]).astype(int) + else: + rows = np.asarray(rows).astype(int) + + if isinstance(cols, int): + cols = np.array([cols]).astype(int) + else: + cols = np.asarray(cols).astype(int) + + arr = arr[np.ix_(rows, cols)] + h, w = arr.shape[-2], arr.shape[-1] self.index_shape = tuple(arr.shape[:-2]) @@ -429,10 +497,11 @@ def fit_mean_diffraction_pattern( *, n_steps: int = 200, reset: bool | Literal["initialized", "mean_refined"] = False, - optimizer_params: OptimizerType | dict | None = None, - scheduler_params: SchedulerType | dict | None = None, + optimizer_params: dict | None = None, + scheduler_params: dict | None = None, constraint_weight: float = 1.0, constraint_params: dict[str, Any] | None = None, + constraint_config_params: dict[str, Any] | None = None, progress: bool = True, ) -> "ModelDiffraction": """ @@ -484,10 +553,12 @@ def fit_mean_diffraction_pattern( raise ValueError("reset must be False, True, 'initialized', or 'mean_refined'.") self.fit_render( + # target=torch.tensor(self.dataset[30,30].array.astype("float32")), target=self.target_mean, n_steps=int(n_steps), constraint_weight=float(constraint_weight), constraint_params=constraint_params, + constraint_config_params=constraint_config_params, 
optimizer_params=optimizer_params, scheduler_params=scheduler_params, progress=bool(progress), @@ -501,7 +572,10 @@ def fit_mean_diffraction_pattern( def reset( self, - reset_to: Literal["initialized", "mean_refined"] = "mean_refined", + reset_to: Literal["initialized", "mean_refined", "individual"] = "mean_refined", + reset_history: bool = True, + individual_row: int = 0, + individual_col: int = 0, ) -> "ModelDiffraction": if reset_to == "initialized": state = self.state_initialized @@ -509,22 +583,578 @@ def reset( raise RuntimeError( "initialized state is unavailable. Call .define_model(...) first." ) - self._clear_fit_history_all() + if reset_history: + self._clear_fit_history_all() elif reset_to == "mean_refined": state = self.state_mean_refined if state is None: raise RuntimeError( "mean_refined state is unavailable. Run .fit_mean_diffraction_pattern(...) first." ) - mean_hist = self.fit_history.get("mean") - self._clear_fit_history_all() - if mean_hist is not None: - self.fit_history["mean"] = mean_hist + if reset_history: + mean_hist = self.fit_history.get("mean") + self._clear_fit_history_all() + if mean_hist is not None: + self.fit_history["mean"] = mean_hist + elif reset_to == "individual": + if self.state_individual_refined is None: + raise ValueError("individual states is unavalible. Run fit_individual_diffraction_pattern(....) 
first") + if (individual_row >= self.state_individual_refined.shape[0]) or (individual_col >= self.state_individual_refined.shape[1]): + raise ValueError("row and column values not in range") + state = self.state_individual_refined[individual_row, individual_col] + if reset_history: + self._clear_fit_history_all() else: - raise ValueError("reset_to must be 'initialized' or 'mean_refined'.") + raise ValueError("reset_to must be 'initialized' or 'mean_refined' or 'individual'.") self._load_model_state_dict_copy(state) return self + + def fit_individual_diffraction_pattern( + self, + *, + rows=None, + cols = None, + n_steps: int = 200, + reset: bool | Literal["initialized", "mean_refined"], + optimizer_params: dict | None = None, + scheduler_params: dict | None = None, + constraint_weight: float = 1.0, + constraint_params: dict[str, Any] | None = None, + constraint_config_params: dict[str, Any] | None = None, + progress: bool = True, + batch_size: int | None = None, + frozen_components: list[str] | str | None = None, + sample_trainability: dict[str, Any] | None = None, + **_compat_kwargs: Any, + ) -> "ModelDiffraction": + if batch_size is not None: + return self.fit_individual_diffraction_pattern_batched( + rows=rows, + cols=cols, + batch_size=int(batch_size), + n_steps=int(n_steps), + reset=cast(Literal["initialized", "mean_refined"], reset), + optimizer_params=optimizer_params, + scheduler_params=scheduler_params, + constraint_weight=float(constraint_weight), + constraint_params=constraint_params, + constraint_config_params=constraint_config_params, + frozen_components=frozen_components, + sample_trainability=sample_trainability, + progress=progress, + ) + + if self.model is None or self.ctx is None or self.target_mean is None: + raise RuntimeError("Call .define_model(...) 
first.") + if reset not in ("initialized", "mean_refined"): + raise ValueError("reset must be initialized', or 'mean_refined'.") + self.reset(reset_to=cast(Literal["initialized", "mean_refined"], reset)) + if not isinstance(self.dataset, Dataset4d): + raise ValueError("Dataset must be Dataset4d or Dataset4dstem for fit_individual_diffraction_pattern") + + scan_r = self.dataset.shape[0] + scan_c = self.dataset.shape[1] + + if rows is None and cols is None: + rows = range(self.dataset.shape[0]) + cols = range(self.dataset.shape[1]) + elif rows is not None and cols is None: + cols = range(self.dataset.shape[1]) + elif rows is None and cols is not None: + rows = range(self.dataset.shape[0]) + else: + rows = rows + cols = cols + + if isinstance(rows, int): + rows = np.array([rows]).astype(int) + else: + rows = np.asarray(rows).astype(int) + + if isinstance(cols, int): + cols = np.array([cols]).astype(int) + else: + cols = np.asarray(cols).astype(int) + + self.state_individual_refined = np.full(shape=(scan_r, scan_c), fill_value=None, dtype=object) + + if progress: + pbar = tqdm(total=rows.shape[0] * cols.shape[0], desc="Fit individual") + + for r in rows: + for c in cols: + # print(self.dataset.array[r,c].shape) + self.reset(reset_to=cast(Literal["initialized", "mean_refined"], reset), reset_history=False) + self.fit_render( + target=torch.as_tensor(self.dataset.array[r,c],device=self.ctx.device,dtype=self.ctx.dtype), + n_steps=int(n_steps), + constraint_weight=float(constraint_weight), + constraint_params=constraint_params, + optimizer_params=optimizer_params, + scheduler_params=scheduler_params, + progress=False, + run_key=f"individual_{r}_{c}", + ) + + s_fit = self._get_model_state_dict_copy() + self.state_individual_refined[r,c] = self._clone_state_dict(s_fit) + # self.reset(reset_to=cast(Literal["initialized", "mean_refined"], reset), reset_history=False) + if progress: + pbar.update(1) + if progress: + pbar.close() + + self.individual_refined=True + return self + 
+ def fit_individual_diffraction_pattern_batched( + self, + *, + rows: Any = None, + cols: Any = None, + batch_size: int = 16, + n_steps: int = 200, + reset: Literal["initialized", "mean_refined"] = "mean_refined", + optimizer_params: dict | None = None, + scheduler_params: dict | None = None, + constraint_weight: float = 1.0, + constraint_params: dict[str, Any] | None = None, + constraint_config_params: dict[str, Any] | None = None, + frozen_components: list[str] | str | None = None, + sample_trainability: dict[str, Any] | None = None, + progress: bool = True, + ) -> "ModelDiffraction": + """ + Per-pattern fit, vectorized across a batch dimension on a single GPU. + + See ``fit_individual_diffraction_pattern`` for argument semantics. The + batched version runs ``batch_size`` patterns in parallel per optimizer + step, with per-sample stacked parameters and per-sample Adam moments. + + Soft constraint losses are added per-sample via + ``model.total_constraint_loss``-style accumulation (currently only + ``DiskTemplate`` contributes). Hard constraints and parameter bounds + are still enforced between steps. + + Parameters + ---------- + scheduler_params : dict | None, optional + Either a single spec dict (``{"type": "cosine", "t_max": N}``) + applied to every parameter group, or a dict keyed by component + name/class. Supported types: ``none``, ``cosine`` (a.k.a. + ``cosine_annealing``), ``linear``, ``exponential``. ``plateau`` + and ``cyclic`` fall back to constant LR with a warning. + constraint_weight : float + Global multiplier on the per-sample soft-constraint loss. + constraint_config_params : dict | None + Applied to the model once before the fit (``apply_constraint_configs``). + frozen_components : list[str] | str | None + Component name(s) whose parameters are frozen across all samples. + Resolved via ``AdditiveRenderModel._component_constraint_name``; + both instance names (``"disk0"``) and class names (``"DiskTemplate"``) + work. 
+ sample_trainability : dict[str, ArrayLike] | None + Per-canonical-key (e.g. ``"disk.template_raw"``) boolean array of + length ``len(positions)``; ``False`` entries freeze that parameter + for the matching scan position. + """ + if self.model is None or self.ctx is None or self.target_mean is None: + raise RuntimeError("Call .define_model(...) first.") + if not isinstance(self.dataset, Dataset4d): + raise ValueError("Dataset must be Dataset4d or Dataset4dstem.") + if reset not in ("initialized", "mean_refined"): + raise ValueError("reset must be 'initialized' or 'mean_refined'.") + + # Choose initialization state and load into the live model so we can + # read current parameter values + constraint settings off the modules. + if reset == "mean_refined": + if self.state_mean_refined is None: + raise RuntimeError("mean_refined state is unavailable. Run fit_mean_diffraction_pattern first.") + init_state = self._clone_state_dict(self.state_mean_refined) + else: + if self.state_initialized is None: + raise RuntimeError("initialized state is unavailable. 
Call define_model first.") + init_state = self._clone_state_dict(self.state_initialized) + self._load_model_state_dict_copy(init_state) + + if constraint_params is not None: + self.model.apply_constraint_params(constraint_params, strict=True) + if constraint_config_params is not None: + self.model.apply_constraint_configs(constraint_config_params, strict=True) + + scan_r = int(self.dataset.shape[0]) + scan_c = int(self.dataset.shape[1]) + rows_arr, cols_arr = _resolve_rows_cols_for_batched(rows, cols, scan_r, scan_c) + positions: list[tuple[int, int]] = [(int(r), int(c)) for r in rows_arr for c in cols_arr] + if len(positions) == 0: + return self + + if self.state_individual_refined is None or self.state_individual_refined.shape != (scan_r, scan_c): + self.state_individual_refined = np.full(shape=(scan_r, scan_c), fill_value=None, dtype=object) + + ctx = self.ctx + components_list = list(self.model.components) + plan = _BatchedPlan.from_model(self.model, components_list, optimizer_params or {}) + plan.set_scheduler_params(scheduler_params) + + # Build per-position trainability arrays (default True everywhere). + n_positions = len(positions) + is_trainable_pos: dict[str, np.ndarray] = { + key: np.ones(n_positions, dtype=bool) for key in plan.lrs + } + frozen_keys = plan.resolve_component_keys(frozen_components) + for key in frozen_keys: + is_trainable_pos[key][:] = False + hard_skip_keys: set[str] = set(frozen_keys) + if sample_trainability is not None: + for key, arr in sample_trainability.items(): + if key not in is_trainable_pos: + raise KeyError( + f"sample_trainability key '{key}' is not a known stacked-param. " + f"Known: {sorted(is_trainable_pos)}" + ) + mask = np.asarray(arr, dtype=bool).reshape(-1) + if mask.shape != (n_positions,): + raise ValueError( + f"sample_trainability['{key}'] has shape {mask.shape}, " + f"expected ({n_positions},)." 
+ ) + is_trainable_pos[key] = mask + + loss_fn = self.loss_fn + + total_steps = len(positions) * n_steps + pbar = tqdm(total=total_steps, desc="Fit individual (batched)", disable=not progress) + + for start in range(0, len(positions), int(batch_size)): + chunk = positions[start:start + int(batch_size)] + B = len(chunk) + + targets = torch.stack( + [ + torch.as_tensor(self.dataset.array[r, c], device=ctx.device, dtype=ctx.dtype) + for (r, c) in chunk + ], + dim=0, + ) + + stacked = plan.build_stacked_params(B) + + adam_state: dict[str, dict[str, torch.Tensor]] = { + name: { + "m": torch.zeros_like(p.detach()), + "v": torch.zeros_like(p.detach()), + } + for name, p in stacked.items() + } + + # Per-chunk trainability slice as (B,) bool tensors on device. + chunk_slice = slice(start, start + B) + chunk_trainable: dict[str, torch.Tensor] = {} + mixed_trainable_keys: list[str] = [] + init_snapshots: dict[str, torch.Tensor] = {} + for key in stacked: + if key in is_trainable_pos: + mask_np = is_trainable_pos[key][chunk_slice] + mask = torch.as_tensor(mask_np, device=ctx.device, dtype=torch.bool) + chunk_trainable[key] = mask + # A "mixed" key has some frozen samples and some trainable. + # We snapshot the init value so frozen samples can be + # restored after hard constraints / Adam. + if not bool(mask.all()) and not bool((~mask).all()): + mixed_trainable_keys.append(key) + init_snapshots[key] = stacked[key].detach().clone() + + for step in range(int(n_steps)): + pred = plan.batched_forward(ctx, stacked) + # Per-sample fidelity loss summed → scalar with per-sample grads + diff2 = (pred.float() - targets.float()) + # Match SqrtMSELoss behavior approximately when loss_fn is SqrtMSELoss: + # gamma-power transform of (x - min(x) + 1), per-sample independently. 
+ from quantem.core.fitting.base import SqrtMSELoss, LogMSELoss + if isinstance(loss_fn, SqrtMSELoss): + gamma = float(loss_fn.gamma) + eps = 1.0 + pred_min = pred.amin(dim=(1, 2), keepdim=True) + tgt_min = targets.amin(dim=(1, 2), keepdim=True) + pred_mod = (pred - pred_min + eps) ** gamma + tgt_mod = (targets - tgt_min + eps) ** gamma + per_sample_loss = ((pred_mod - tgt_mod) ** 2).mean(dim=(1, 2)) + elif isinstance(loss_fn, LogMSELoss): + per_sample_loss = ((torch.log1p(pred) - torch.log1p(targets)) ** 2).mean(dim=(1, 2)) + else: + per_sample_loss = (diff2 * diff2).mean(dim=(1, 2)) + + # Add per-sample soft constraint loss. + if float(constraint_weight) != 0.0: + constraint_per_sample = plan.batched_constraint_loss(ctx, stacked) + per_sample_loss = per_sample_loss + float(constraint_weight) * constraint_per_sample + + total_loss = per_sample_loss.sum() + + grads_list = list(torch.autograd.grad(total_loss, list(stacked.values()))) + + # Apply trainability masks: zero out per-sample gradient entries + # for frozen samples/params before the Adam step. + for i, (name, g) in enumerate(zip(stacked.keys(), grads_list)): + if g is None: + continue + mask_b = chunk_trainable.get(name) + if mask_b is None: + continue + if bool(mask_b.all()): + continue + # Broadcast (B,) over remaining dims. 
def get_individual_uv_vectors(self) -> "ModelDiffraction":
    """
    Collect the fitted lattice vectors of every scan position into arrays.

    Reads ``self.state_individual_refined`` and, for each position that has
    a stored state, copies the entries whose key ends in ``u_row``,
    ``u_col``, ``v_row`` or ``v_col`` into ``self.u_array`` /
    ``self.v_array`` (both of shape ``(scan_row, scan_col, 2)``, last axis
    ordered ``(row, col)``). Positions without a stored state, and
    components missing from a state, are left as NaN.

    Returns
    -------
    ModelDiffraction
        ``self``, with ``u_array`` and ``v_array`` set.

    Raises
    ------
    RuntimeError
        If no individual fit has been stored yet.
    """
    # Check the precondition first so a failed call never leaves
    # half-initialized arrays on the instance.
    if self.state_individual_refined is None:
        raise RuntimeError("Call .fit_individual_diffraction_pattern(...) on all patterns first.")

    scan_r = self.dataset.shape[0]
    scan_c = self.dataset.shape[1]

    u_array = np.full((scan_r, scan_c, 2), np.nan, dtype=float)
    v_array = np.full((scan_r, scan_c, 2), np.nan, dtype=float)

    # Final key segment -> (destination array, index along the last axis).
    slots = {
        "u_row": (u_array, 0),
        "u_col": (u_array, 1),
        "v_row": (v_array, 0),
        "v_col": (v_array, 1),
    }

    for r in range(scan_r):
        for c in range(scan_c):
            pos_state = self.state_individual_refined[r, c]
            if pos_state is None:
                # Position was not fitted: keep the NaN fill and move on.
                continue
            for key, value in pos_state.items():
                slot = slots.get(key.split(".")[-1])
                if slot is None:
                    continue
                if isinstance(value, torch.Tensor):
                    # Tensors may carry grad or live on an accelerator.
                    value = value.detach().cpu().numpy()
                dest, axis_idx = slot
                dest[r, c, axis_idx] = np.asarray(value, dtype=float).item()

    self.u_array = u_array
    self.v_array = v_array
    return self
+ + The raw signal (summed intensity of the non-central fitted disks, or the + diffracted intensity outside a central disk for the radial method) is min-max + normalized to ``[0, 1]``; no contrast windowing or smoothing is applied. Display + contrast is applied via the ``mask_range`` argument of + :meth:`StrainMap.plot_strain`, :meth:`StrainMap.update_reference`, and + :meth:`StrainMap.estimate_strain_precision` (e.g. ``mask_range=(0.37, 0.5)``): + weights at/below ``low`` render black, at/above ``high`` render full color. The + weight-map plot shows the distribution so an appropriate ``mask_range`` can be + chosen. + + Parameters + ---------- + use_radial_method : bool, default=False + If ``True``, weight by the total diffracted intensity *outside* a central + disk (radius ``exclusion_radius_fraction`` of the detector width). If + ``False`` (default, recommended), weight by the summed intensity of the + fitted non-central lattice disks -- requires + :meth:`fit_individual_diffraction_pattern` to have been run. + exclusion_radius_fraction : float, default=0.1 + Central-disk radius (fraction of detector width) excluded by the radial + method. + plot : bool, default=True + If ``True``, show the resulting per-position weight map (so a separate + ``plt.imshow(mask)`` cell is unnecessary). + figsize : tuple of float, default=(5, 4) + Figure size in inches for the weight-map plot. + + Returns + ------- + ModelDiffraction + ``self``, with :attr:`mask` set. 
+ """ + if not isinstance(self.dataset, (Dataset4d, Dataset4dstem)): + raise ValueError("Dataset must be Dataset4d or Dataset4dstem.") + + scan_r = self.dataset.shape[0] + scan_c = self.dataset.shape[1] + self.i0_sum_array = np.zeros(shape=(scan_r, scan_c)) + + if use_radial_method: + center_y, center_x = np.array(self.dataset.shape[:-2]) / 2 + y, x = np.ogrid[: self.dataset.shape[-2], : self.dataset.shape[-1]] + radius_map = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2) + exclusion_radius = exclusion_radius_fraction * self.dataset.shape[-1] + outside_mask = radius_map > exclusion_radius + for r in range(scan_r): + for c in range(scan_c): + self.i0_sum_array[r, c] = np.sum(self.dataset.array[r, c][outside_mask]) + else: + if self.state_individual_refined is None: + raise RuntimeError("Call .fit_individual_diffraction_pattern(...) first.") + for r in range(scan_r): + for c in range(scan_c): + pos_state = self.state_individual_refined[r, c] + if pos_state is None: + continue + i0_raw = None + uv_indices = None + for key in pos_state.keys(): + if key.endswith("i0_raw"): + i0_raw = pos_state[key].cpu().numpy() + if key.endswith("uv_indices"): + uv_indices = pos_state[key].cpu().numpy() + if i0_raw is None or uv_indices is None: + continue + is_not_center = ~((uv_indices[:, 0] == 0) & (uv_indices[:, 1] == 0)) + self.i0_sum_array[r, c] = np.sum(i0_raw[is_not_center]) + + # Min-max normalization to [0, 1], no contrast windowing (set mask_range at + # display). Subtracting the floor -- rather than only dividing by the max -- keeps + # the weight from saturating near 1.0 when every position carries a baseline + # lattice intensity, and matches the cepstral _amplitude_mask_weight. Degenerate + # (constant / non-finite) input falls back to uniform full weight. 
def calculate_strain_map(
    self,
    u_ref: np.ndarray | None = None,
    v_ref: np.ndarray | None = None,
    mask: np.ndarray | None = None,
) -> StrainMap:
    """Assemble a :class:`StrainMap` from the per-position fitted lattice vectors.

    If ``u_array`` / ``v_array`` have not been computed yet they are gathered
    first via :meth:`get_individual_uv_vectors`. The dataset's first
    ``units`` / ``sampling`` entry (or the scalar value, when not a
    sequence) is passed along as ``ds_units`` / ``ds_sampling``.

    Parameters
    ----------
    u_ref : np.ndarray, optional
        Reference for the first lattice vector, forwarded to
        :class:`StrainMap`.
    v_ref : np.ndarray, optional
        Reference for the second lattice vector, forwarded to
        :class:`StrainMap`.
    mask : np.ndarray, optional
        Per-position weighting forwarded to :class:`StrainMap`. When
        omitted, :attr:`mask` is used.

    Returns
    -------
    StrainMap
        Strain map built from ``u_array`` and ``v_array``.

    Raises
    ------
    ValueError
        If the dataset is not a ``Dataset4d`` / ``Dataset4dstem``.
    """
    vectors_missing = self.u_array is None or self.v_array is None
    if vectors_missing:
        self.get_individual_uv_vectors()

    dataset = self.dataset
    if not isinstance(dataset, (Dataset4d, Dataset4dstem)):
        raise ValueError("Dataset must be Dataset4d or Dataset4dstem.")

    weights = self.mask if mask is None else mask

    # Sentinel distinguishes "attribute absent" from a stored None.
    absent = object()

    ds_units = None
    units = getattr(dataset, "units", absent)
    if units is not absent:
        first_units = units[0] if isinstance(units, (tuple, list)) else units
        ds_units = str(first_units)

    ds_sampling = None
    sampling = getattr(dataset, "sampling", absent)
    if sampling is not absent:
        is_sequence = isinstance(sampling, (tuple, list, np.ndarray))
        ds_sampling = float(sampling[0] if is_sequence else sampling)

    return StrainMap(
        u_array=self.u_array,
        v_array=self.v_array,
        ds_shape=dataset.shape,
        real_space=self.real_space,
        u_ref=u_ref,
        v_ref=v_ref,
        mask=weights,
        ds_sampling=ds_sampling,
        ds_units=ds_units,
    )
def _adam_step_inplace(
    stacked: dict[str, torch.Tensor],
    grads: tuple[torch.Tensor, ...],
    adam_state: dict[str, dict[str, torch.Tensor]],
    lrs: dict[str, float],
    chunk_trainable: dict[str, torch.Tensor],
    t: int,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> None:
    """Apply one in-place Adam update to each stacked parameter, honouring per-sample freezing.

    Parameters
    ----------
    stacked
        Parameter name -> stacked tensor whose leading dimension indexes samples.
        Updated in place.
    grads
        Gradients in the same order as ``stacked``; ``None`` entries are skipped.
    adam_state
        Parameter name -> ``{"m": first moment, "v": second moment}``. Updated in place.
    lrs
        Parameter name -> learning rate. Names without an entry use ``1e-2``.
    chunk_trainable
        Parameter name -> per-sample mask of length ``B`` (truthy = trainable). A
        missing entry, or a mask that is entirely true, means every sample is updated.
    t
        Step index used for bias correction; must be >= 1 (``t == 0`` would make the
        correction terms zero).
    beta1, beta2, eps
        Standard Adam hyper-parameters.

    Notes
    -----
    For frozen samples (mask false):

    1. the gradient is treated as exactly zero, so no new gradient statistics (not
       even NaN/Inf values) enter the moments; the existing moments only decay;
    2. the parameter values are left untouched, even if momentum from earlier steps
       is still stored in ``adam_state``.
    """
    bias1 = 1.0 - beta1 ** t
    bias2 = 1.0 - beta2 ** t

    with torch.no_grad():
        for (name, p), g in zip(stacked.items(), grads):
            if g is None:
                continue

            st = adam_state[name]
            mask_b = chunk_trainable.get(name)

            keep = None
            if mask_b is not None and not bool(mask_b.all()):
                # Broadcast the (B,) mask over every trailing parameter dimension.
                view_shape = (g.shape[0],) + (1,) * (g.ndim - 1)
                keep = mask_b.to(device=g.device).bool().view(view_shape)
                # torch.where (rather than g * mask) so NaN/Inf gradients of frozen
                # samples cannot leak into the moments.
                g = torch.where(keep, g, torch.zeros_like(g))

            st["m"].mul_(beta1).add_(g, alpha=1.0 - beta1)
            st["v"].mul_(beta2).addcmul_(g, g, value=1.0 - beta2)

            m_hat = st["m"] / bias1
            v_hat = st["v"] / bias2

            if keep is not None:
                # Leftover momentum must not keep moving a frozen sample.
                m_hat = torch.where(keep, m_hat, torch.zeros_like(m_hat))

            lr = float(lrs.get(name, 1e-2))
            p.data.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)
idx + elif isinstance(comp, GaussianBackground) and self.gaussbg is None: + self.gaussbg = comp + self.gaussbg_idx = idx + elif isinstance(comp, SyntheticDiskLattice) and self.lat is None: + self.lat = comp + self.lat_idx = idx + else: + raise TypeError( + f"Batched fit does not yet support component type {type(comp).__name__} " + f"at index {idx} (or duplicate of an already-handled type)." + ) + + # Track trainability for each parameter + if self.disk is not None and self.disk_idx is not None: + lr = _lr_for_component(self.disk, self.disk_idx, optimizer_params, model) + name = model._component_constraint_name(self.disk, self.disk_idx) + + # Check if parameters are trainable + self._trainable_flags["disk.template_raw"] = self.disk.template_raw.requires_grad + self._trainable_flags["disk.intensity_raw"] = self.disk.intensity_raw.requires_grad + + # Only add to lrs if trainable + if self._trainable_flags["disk.template_raw"]: + self.lrs["disk.template_raw"] = lr + if self._trainable_flags["disk.intensity_raw"]: + self.lrs["disk.intensity_raw"] = lr + + self.component_keys[name] = ["disk.template_raw", "disk.intensity_raw"] + self.component_keys[self.disk.__class__.__name__] = list(self.component_keys[name]) + + if self.dcbg is not None and self.dcbg_idx is not None: + lr = _lr_for_component(self.dcbg, self.dcbg_idx, optimizer_params, model) + name = model._component_constraint_name(self.dcbg, self.dcbg_idx) + + self._trainable_flags["dcbg.intensity_raw"] = self.dcbg.intensity_raw.requires_grad + if self._trainable_flags["dcbg.intensity_raw"]: + self.lrs["dcbg.intensity_raw"] = lr + + self.component_keys[name] = ["dcbg.intensity_raw"] + self.component_keys[self.dcbg.__class__.__name__] = list(self.component_keys[name]) + + if self.gaussbg is not None and self.gaussbg_idx is not None: + lr = _lr_for_component(self.gaussbg, self.gaussbg_idx, optimizer_params, model) + name = model._component_constraint_name(self.gaussbg, self.gaussbg_idx) + + 
self._trainable_flags["gaussbg.sigma_raw"] = self.gaussbg.sigma_raw.requires_grad + self._trainable_flags["gaussbg.intensity_raw"] = self.gaussbg.intensity_raw.requires_grad + + if self._trainable_flags["gaussbg.sigma_raw"]: + self.lrs["gaussbg.sigma_raw"] = lr + if self._trainable_flags["gaussbg.intensity_raw"]: + self.lrs["gaussbg.intensity_raw"] = lr + + self.component_keys[name] = ["gaussbg.sigma_raw", "gaussbg.intensity_raw"] + self.component_keys[self.gaussbg.__class__.__name__] = list(self.component_keys[name]) + + if self.lat is not None and self.lat_idx is not None: + lr = _lr_for_component(self.lat, self.lat_idx, optimizer_params, model) + name = model._component_constraint_name(self.lat, self.lat_idx) + lat_keys = [] + + for k in ("u_row", "u_col", "v_row", "v_col", "i0_raw", "ir", "ic", "irr", "icc", "irc"): + full_key = f"lat.{k}" + param = getattr(self.lat, k, None) + if param is not None: + is_trainable = param.requires_grad if isinstance(param, torch.nn.Parameter) else False + self._trainable_flags[full_key] = is_trainable + if is_trainable: + self.lrs[full_key] = lr + lat_keys.append(full_key) + + # Origin uses the lattice's LR + self._trainable_flags["origin.coords"] = self.origin.coords.requires_grad + if self._trainable_flags["origin.coords"]: + self.lrs["origin.coords"] = lr + + self.component_keys[name] = lat_keys + self.component_keys[self.lat.__class__.__name__] = list(lat_keys) + elif self.gaussbg is not None: + if self._trainable_flags.get("gaussbg.intensity_raw", False): + self.lrs["origin.coords"] = self.lrs["gaussbg.intensity_raw"] + else: + if self.origin.coords.requires_grad: + self.lrs["origin.coords"] = 1e-2 + + # Default schedulers: constant LR for every key + self.scheduler_specs = {k: {"type": "none"} for k in self.lrs} + return self + + # ... 
def build_stacked_params(self, B: int) -> dict[str, torch.Tensor]:
    """Return per-sample copies (leading dimension ``B``) of every trainable parameter.

    Only keys flagged as trainable in ``self._trainable_flags`` are included; frozen
    parameters are left out. Keys are inserted in the fixed order origin, disk, DC
    background, Gaussian background, lattice.
    """
    assert self.origin is not None
    is_trainable = self._trainable_flags.get
    stacked: dict[str, torch.Tensor] = {}

    # (key prefix, component, attribute names) for the fixed-attribute components.
    fixed_groups = (
        ("origin", self.origin, ("coords",)),
        ("disk", self.disk, ("template_raw", "intensity_raw")),
        ("dcbg", self.dcbg, ("intensity_raw",)),
        ("gaussbg", self.gaussbg, ("sigma_raw", "intensity_raw")),
    )
    for prefix, component, attr_names in fixed_groups:
        if component is None:
            continue
        for attr_name in attr_names:
            key = f"{prefix}.{attr_name}"
            if is_trainable(key, False):
                stacked[key] = self._stack(getattr(component, attr_name), B)

    lattice = self.lat
    if lattice is not None:
        # The first five attributes must exist on the lattice; the intensity
        # gradient terms are optional and may be missing or None.
        candidates = [
            (attr_name, getattr(lattice, attr_name))
            for attr_name in ("u_row", "u_col", "v_row", "v_col", "i0_raw")
        ]
        candidates += [
            (attr_name, getattr(lattice, attr_name, None))
            for attr_name in ("ir", "ic", "irr", "icc", "irc")
        ]
        for attr_name, tensor in candidates:
            key = f"lat.{attr_name}"
            if tensor is not None and is_trainable(key, False):
                stacked[key] = self._stack(tensor, B)

    return stacked
parameter support.""" + B = next(iter(stacked.values())).shape[0] if stacked else 1 + + # Get origin (might be frozen) + origin_b = stacked.get("origin.coords") + if origin_b is None: + # Origin is frozen - expand from component + origin_b = self.origin.coords.unsqueeze(0).expand(B, *self.origin.coords.shape).clone() + + pred = torch.zeros(B, ctx.shape[0], ctx.shape[1], device=ctx.device, dtype=ctx.dtype) + + if self.disk is not None: + # Handle potentially frozen disk params + template_b = stacked.get("disk.template_raw") + if template_b is None: + template_b = self.disk.template_raw.unsqueeze(0).expand(B, *self.disk.template_raw.shape).clone() + + intensity_b = stacked.get("disk.intensity_raw") + if intensity_b is None: + intensity_b = self.disk.intensity_raw.unsqueeze(0).expand(B).clone() + + pred = pred + self.disk.forward_batched( + ctx, + template_raw_b=template_b, + intensity_raw_b=intensity_b, + origin_coords_b=origin_b, + ) + + if self.dcbg is not None: + intensity_b = stacked.get("dcbg.intensity_raw") + if intensity_b is None: + intensity_b = self.dcbg.intensity_raw.unsqueeze(0).expand(B).clone() + pred = pred + self.dcbg.forward_batched(ctx, intensity_raw_b=intensity_b) + + if self.gaussbg is not None: + sigma_b = stacked.get("gaussbg.sigma_raw") + if sigma_b is None: + sigma_b = self.gaussbg.sigma_raw.unsqueeze(0).expand(B).clone() + + intensity_b = stacked.get("gaussbg.intensity_raw") + if intensity_b is None: + intensity_b = self.gaussbg.intensity_raw.unsqueeze(0).expand(B).clone() + + pred = pred + self.gaussbg.forward_batched( + ctx, + sigma_raw_b=sigma_b, + intensity_raw_b=intensity_b, + origin_coords_b=origin_b, + ) + + if self.lat is not None: + # Lattice requires all params - expand frozen ones + u_row_b = stacked.get("lat.u_row") + if u_row_b is None: + u_row_b = self.lat.u_row.unsqueeze(0).expand(B).clone() + + u_col_b = stacked.get("lat.u_col") + if u_col_b is None: + u_col_b = self.lat.u_col.unsqueeze(0).expand(B).clone() + + v_row_b = 
def apply_hard_constraints(
    self,
    stacked: dict[str, torch.Tensor],
    skip_keys: set[str] | None = None,
) -> None:
    """Apply the batched hard constraints in place, leaving frozen parameters untouched.

    Parameters
    ----------
    stacked
        Stacked-parameter key -> tensor; constrained tensors are modified in place.
    skip_keys
        Extra keys to leave unconstrained. Every key flagged as not trainable in
        ``self._trainable_flags`` is skipped as well. The caller's set is not mutated.

    Notes
    -----
    Three groups of constraints are applied, in this order:

    1. element-wise ``parameter_bounds`` of each component;
    2. the composite disk-template constraints enabled in ``disk.hard_constraints``
       (center, cutoff, circular mask, shrinkage, positivity, norm);
    3. non-negativity of the lattice ``i0_raw`` when ``force_positive_intensity`` is set.

    The template constraints only need the template tensor. A frozen disk intensity
    is absent from ``stacked`` and must not disable them.
    """
    skipped = set(skip_keys) if skip_keys else set()
    skipped |= {k for k, trainable in self._trainable_flags.items() if not trainable}

    # 1. Parameter bounds (always elementwise; safe on any shape).
    bounded_components = (
        ("disk", self.disk),
        ("dcbg", self.dcbg),
        ("gaussbg", self.gaussbg),
        ("lat", self.lat),
    )
    for prefix, component in bounded_components:
        if component is None:
            continue
        for pname, (lo, hi) in component.parameter_bounds.items():
            key = f"{prefix}.{pname}"
            if key in stacked and key not in skipped:
                self._clamp_bounds_inplace(stacked[key], lo, hi)

    # 2. DiskTemplate composite hard constraints.
    template = stacked.get("disk.template_raw")
    if (
        self.disk is not None
        and template is not None
        and "disk.template_raw" not in skipped
    ):
        hard = self.disk.hard_constraints
        cfg = self.disk.constraint_config
        intensity = stacked.get("disk.intensity_raw")
        intensity_frozen = "disk.intensity_raw" in skipped

        if bool(hard.get("force_center", False)):
            self._batched_center_disk(template)
        if bool(hard.get("force_cutoff", False)):
            self._batched_enforce_cutoff(template, cfg)
        if bool(hard.get("force_circular_mask", False)):
            self._batched_enforce_circular_mask(template, cfg)
        if bool(hard.get("force_shrinkage", False)):
            template.sub_(float(cfg.get("shrinkage_amount", 0.25)))
        if bool(hard.get("force_positive", False)):
            template.clamp_(min=0.0)
            # The intensity is clamped only when it is actually being optimized.
            if not intensity_frozen and intensity is not None:
                intensity.clamp_(min=0.0)
        if bool(hard.get("force_norm", False)):
            self._batched_enforce_norm(template)

    # 3. Lattice base intensity positivity.
    if (
        self.lat is not None
        and "lat.i0_raw" not in skipped
        and bool(self.lat.hard_constraints.get("force_positive_intensity", False))
    ):
        i0 = stacked.get("lat.i0_raw")
        if i0 is not None:
            i0.clamp_(min=0.0)
@staticmethod
def _batched_enforce_circular_mask(template_b: torch.Tensor, cfg: dict[str, Any]) -> None:
    """Multiply every template in the batch by one shared circular mask, in place.

    ``template_b`` must be 3-D ``(B, h, w)``. The mask radius is
    ``min(h, w) / 2`` scaled by ``cfg["circular_mask_radius_fraction"]`` (default
    0.95). Pixel coordinates run from ``-h/2`` and ``-w/2`` in unit steps. With
    ``cfg["soft_circular_mask"]`` set, the mask is a sigmoid of the distance to the
    radius with slope ``cfg["circular_mask_sharpness"]`` (default 0); otherwise it
    is a binary disk.
    """
    _, height, width = template_b.shape
    fraction = float(cfg.get("circular_mask_radius_fraction", 0.95))
    radius = (min(height, width) / 2.0) * fraction

    opts = {"device": template_b.device, "dtype": template_b.dtype}
    # Column/row vectors broadcast to the full (h, w) grid of distances.
    rows = torch.arange(-height / 2, height / 2, **opts).view(-1, 1)
    cols = torch.arange(-width / 2, width / 2, **opts).view(1, -1)
    distance = torch.sqrt(rows * rows + cols * cols)

    if not bool(cfg.get("soft_circular_mask", False)):
        mask2d = (distance <= radius).to(dtype=template_b.dtype)
    else:
        sharpness = float(cfg.get("circular_mask_sharpness", 0))
        mask2d = torch.sigmoid(sharpness * (radius - distance))

    template_b.mul_(mask2d.view(1, height, width))
init_state.get("origin.coords") + if origin_val is not None: + origin_val = origin_val.detach().clone() + + if origin_val is not None: + for key in list(out.keys()): + if key == "origin.coords" or key.endswith(".origin.coords"): + out[key] = origin_val.clone() + + # DiskTemplate (shared with lat.disk) + if self.disk is not None and self.disk_idx is not None: + if "disk.template_raw" in stacked: + t_val = stacked["disk.template_raw"][b].detach().clone() + else: + # Template was frozen + t_val = init_state.get(f"components.{self.disk_idx}.template_raw") + if t_val is not None: + t_val = t_val.detach().clone() + + if "disk.intensity_raw" in stacked: + i_val = stacked["disk.intensity_raw"][b].detach().clone() + else: + # Intensity was frozen + i_val = init_state.get(f"components.{self.disk_idx}.intensity_raw") + if i_val is not None: + i_val = i_val.detach().clone() + + if t_val is not None: + for key in list(out.keys()): + if key.endswith(".template_raw"): + out[key] = t_val.clone() + + if i_val is not None: + for key in ( + f"components.{self.disk_idx}.intensity_raw", + f"components.{self.lat_idx}.disk.intensity_raw" if self.lat_idx is not None else None, + ): + if key is not None and key in out: + out[key] = i_val.clone() + + # DCBackground + if self.dcbg is not None and self.dcbg_idx is not None: + if "dcbg.intensity_raw" in stacked: + out[f"components.{self.dcbg_idx}.intensity_raw"] = stacked["dcbg.intensity_raw"][b].detach().clone() + else: + # Was frozen - keep init value (already in out) + pass + + # GaussianBackground + if self.gaussbg is not None and self.gaussbg_idx is not None: + if "gaussbg.sigma_raw" in stacked: + out[f"components.{self.gaussbg_idx}.sigma_raw"] = stacked["gaussbg.sigma_raw"][b].detach().clone() + if "gaussbg.intensity_raw" in stacked: + out[f"components.{self.gaussbg_idx}.intensity_raw"] = stacked["gaussbg.intensity_raw"][b].detach().clone() + + # SyntheticDiskLattice scalar + tensor params + if self.lat is not None and self.lat_idx is 
def set_scheduler_params(self, scheduler_params: Any) -> None:
    """Configure the learning-rate scheduler spec of each stacked-parameter key.

    ``scheduler_params`` may be:

    * ``None`` or any non-dict value: ignored, nothing changes;
    * a single spec (a dict containing ``"type"`` or ``"name"``): applied to every
      key already present in ``self.scheduler_specs``;
    * a dict mapping component names to specs: each spec is applied to the keys that
      ``self.component_keys`` lists for that name. Non-dict specs, unknown names and
      keys missing from ``self.scheduler_specs`` are skipped.

    Specs are normalized before being stored: the type is lower-cased, ``cosine``
    and ``cosineannealing`` become ``cosine_annealing``, ``t_max`` is renamed to
    ``T_max``, and the numeric options are coerced to ``int``/``float``.
    """
    if not isinstance(scheduler_params, dict):
        return

    def _canonical(raw: dict[str, Any]) -> dict[str, Any]:
        spec = dict(raw)
        # Both spellings are removed from the spec; "type" wins over "name".
        alias = spec.pop("name", "none")
        kind = str(spec.pop("type", alias)).lower()
        if kind in ("cosine", "cosineannealing"):
            kind = "cosine_annealing"

        if kind == "cosine_annealing":
            if "t_max" in spec:
                spec["T_max"] = spec.pop("t_max")
            if spec.get("T_max") is not None:
                spec["T_max"] = int(float(spec["T_max"]))
            if "eta_min" in spec:
                spec["eta_min"] = float(spec["eta_min"])
        elif kind == "exponential":
            if "gamma" in spec:
                spec["gamma"] = float(spec["gamma"])
        elif kind == "linear":
            for factor in ("start_factor", "end_factor"):
                if spec.get(factor) is not None:
                    spec[factor] = float(spec[factor])
            if spec.get("total_iters") is not None:
                spec["total_iters"] = int(float(spec["total_iters"]))

        spec["type"] = kind
        return spec

    # A single spec is broadcast to every known key.
    if "type" in scheduler_params or "name" in scheduler_params:
        shared = _canonical(scheduler_params)
        for key in self.scheduler_specs:
            self.scheduler_specs[key] = dict(shared)
        return

    # Otherwise: one spec per component name.
    for component_name, raw_spec in scheduler_params.items():
        if not isinstance(raw_spec, dict):
            continue
        normalized = _canonical(raw_spec)
        for key in self.component_keys.get(str(component_name), []):
            if key in self.scheduler_specs:
                self.scheduler_specs[key] = dict(normalized)
def resolve_component_keys(self, components: Any) -> list[str]:
    """Map one component name, or an iterable of names, to the stacked-param keys they own.

    ``None`` yields an empty list. Keys are returned in first-seen order without
    duplicates.

    Raises
    ------
    KeyError
        If a name is not present in ``self.component_keys``.
    """
    if components is None:
        return []

    names = [components] if isinstance(components, str) else components

    # A dict is used as an insertion-ordered set of the keys collected so far.
    collected: dict[str, None] = {}
    for name in names:
        owned = self.component_keys.get(str(name))
        if owned is None:
            raise KeyError(
                f"Unknown component name '{name}' for batched fit. "
                f"Known: {sorted(self.component_keys)}"
            )
        collected.update(dict.fromkeys(owned))
    return list(collected)
first.") + if md.dataset.shape[0] <= pattern_row or md.dataset.shape[1] <= pattern_col: + raise ValueError("individual row or column outside bounds of dataset") fig = plt.figure(figsize=(12, 7)) gs = gridspec.GridSpec(2, 1, height_ratios=[1, 2], hspace=0.3) ax_top = fig.add_subplot(gs[0]) - md.plot_losses(figax=(fig, ax_top), plot_lrs=True) + md.plot_losses(figax=(fig, ax_top), plot_lrs=True, plot_individual=individual_loss,individual_row=pattern_row,individual_col=pattern_col) - ref = np.asarray(md.image_ref, dtype=np.float32) - pred = md.render_current + if individual_loss: + pred = md.render_individual_pattern(row=pattern_row, col=pattern_col) + ref = np.asarray(md.dataset.array[pattern_row, pattern_col], dtype=np.float32) + else: + ref = np.asarray(md.image_ref, dtype=np.float32) + pred = md.render_current + refp = ref if power == 1.0 else np.maximum(ref, 0.0) ** float(power) predp = pred if power == 1.0 else np.maximum(pred, 0.0) ** float(power) vmin = float(min(refp.min(), predp.min())) @@ -239,6 +257,8 @@ def visualize( ) mean_hist = md.fit_history.get("mean") + if individual_loss: + mean_hist = md.fit_history.get(f"individual_{pattern_row}_{pattern_row}") if mean_hist is not None and len(mean_hist.losses) > 0: fig.suptitle( f"Final loss: {mean_hist.losses[-1]:.3e} | Iters: {len(mean_hist.losses)}", @@ -248,9 +268,13 @@ def visualize( plt.show() return fig, axs - def plot_mean_model( + def plot_model( self, *, + plot_mean_model: bool = False, + plot_individual_model: bool = False, + pattern_row: int = 0, + pattern_col: int= 0, power: float = 0.25, returnfig: bool = False, axsize: tuple[int, int] = (6, 6), @@ -260,10 +284,10 @@ def plot_mean_model( overlay_on: Literal["model", "both"] = "model", origin_marker_kwargs: dict[str, Any] | None = None, disk_marker_kwargs: dict[str, Any] | None = None, - **_: Any, + **kwargs, ) -> tuple[Any, Any] | None: """ - Plot reference and model mean diffraction images. + Plot reference and indiviual diffraction images. 
Parameters ---------- @@ -305,24 +329,34 @@ def plot_mean_model( md.preprocess() if md.image_ref is None or md.model is None or md.ctx is None: raise RuntimeError("Call .define_model(...) first.") - - ref = np.asarray(md.image_ref, dtype=np.float32) - pred = md.render_current + if plot_mean_model and plot_individual_model: + raise RuntimeError("can only plot mean or plot individual, not both") + if plot_individual_model: + pred = md.render_individual_pattern(pattern_row, pattern_col) + ref = np.asarray(md.dataset.array[pattern_row, pattern_col], dtype=np.float32) + elif plot_mean_model: + ref = np.asarray(md.image_ref, dtype=np.float32) + pred = md.render_mean_refined + else: + ref = np.asarray(md.image_ref, dtype=np.float32) + pred = md.render_current refp = ref if power == 1.0 else np.maximum(ref, 0.0) ** float(power) predp = pred if power == 1.0 else np.maximum(pred, 0.0) ** float(power) vmin = float(min(refp.min(), predp.min())) vmax = float(max(refp.max(), predp.max())) + t1 = kwargs.pop("title", "") fig, ax = show_2d( [refp, predp], - title=["image_ref", "model"], + title=[t1 + " image_ref", t1 + " model"], cmap=config.get("viz.cmap"), - cbar=False, + cbar=True, returnfig=True, axsize=axsize, vmin=vmin, vmax=vmax, + **kwargs, ) if overlay: if overlay_on not in ("model", "both"): diff --git a/src/quantem/diffraction/polar.py b/src/quantem/diffraction/polar.py new file mode 100644 index 00000000..7e87eff3 --- /dev/null +++ b/src/quantem/diffraction/polar.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, List, Union + +import matplotlib.pyplot as plt +import numpy as np +from numpy.typing import NDArray + +from quantem.core.datastructures.dataset2d import Dataset2d +from quantem.core.datastructures.dataset3d import Dataset3d +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.datastructures.polar4dstem import Polar4dstem +from quantem.core.io.serialize 
class RDF(AutoSerialize):
    """
    Radial distribution / fluctuation electron microscopy analysis helper.

    This class wraps a 4D-STEM (or 2D diffraction) dataset and stores a
    polar-transformed representation as a Polar4dstem instance in `self.polar`.
    Analysis methods (radial statistics, PDF, plotting) are currently stubs
    that raise ``NotImplementedError``.

    Instances are created with :meth:`from_data`; calling the constructor
    directly (without the private token) raises ``RuntimeError``.
    """

    # Sentinel that from_data() hands to __init__ to prove the instance is
    # being built through the supported constructor.
    _token = object()

    def __init__(
        self,
        polar: Polar4dstem,
        input_data: Any | None = None,
        _token: object | None = None,
    ):
        """
        Store the polar dataset and reset all analysis placeholders to None.

        Parameters
        ----------
        polar
            Polar-transformed dataset kept as ``self.polar``.
        input_data
            The object the polar data was derived from, kept as
            ``self.input_data``.
        _token
            Private construction token; must be the class-level ``_token``.

        Raises
        ------
        RuntimeError
            If ``_token`` is not the class token (direct instantiation).
        """
        if _token is not self._token:
            # Report the real class name so the hint points at a callable that exists.
            raise RuntimeError(
                f"Use {type(self).__name__}.from_data() to instantiate this class."
            )

        super().__init__()
        self.polar = polar
        self.input_data = input_data

        # Placeholders for analysis results (to be populated by future methods)
        self.radial_mean: NDArray | None = None
        self.radial_var: NDArray | None = None
        self.radial_var_norm: NDArray | None = None

        self.pdf_r: NDArray | None = None
        self.pdf_reduced: NDArray | None = None
        self.pdf: NDArray | None = None

        self.Sk: NDArray | None = None
        self.fk: NDArray | None = None
        self.bg: NDArray | None = None
        self.offset: float | None = None
        self.Sk_mask: NDArray | None = None

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------
    @staticmethod
    def _dataset2d_as_4dstem(data: Dataset2d) -> Dataset4dstem:
        """Wrap a single 2D pattern as a trivial (1, 1, ky, kx) Dataset4dstem."""
        arr2d = data.array
        if arr2d.ndim != 2:
            raise ValueError("Dataset2d for RDF must be 2D.")
        return Dataset4dstem.from_array(
            array=arr2d[None, None, ...],  # (1, 1, ky, kx)
            name=f"{data.name}_as4dstem" if getattr(data, "name", None) else "rdf_4dstem_from_2d",
            # prepend two scan axes (origin 0, sampling 1, units "pixels")
            origin=np.concatenate(
                [np.zeros(2, dtype=float), np.asarray(data.origin, dtype=float)]
            ),
            sampling=np.concatenate(
                [np.ones(2, dtype=float), np.asarray(data.sampling, dtype=float)]
            ),
            units=["pixels", "pixels"] + list(data.units),
            signal_units=data.signal_units,
        )

    @classmethod
    def from_data(
        cls,
        data: Union[NDArray, Dataset2d, Dataset3d, Dataset4dstem, Polar4dstem],
        *,
        origin_row: float | None = None,
        origin_col: float | None = None,
        ellipse_params: tuple[float, float, float] | None = None,
        num_annular_bins: int = 180,
        radial_min: float = 0.0,
        radial_max: float | None = None,
        radial_step: float = 1.0,
        two_fold_rotation_symmetry: bool = False,
    ) -> "RDF":
        """
        Create an RDF object from various input types.

        Parameters
        ----------
        data
            Supported inputs:
              - 2D numpy array (single diffraction pattern)
              - 4D numpy array (scan_y, scan_x, ky, kx)
              - Dataset2d
              - Dataset4dstem
              - Polar4dstem (used directly, no transform applied)
        origin_row, origin_col
            Diffraction-space origin (in pixels). If None, defaults to the
            central pixel of the diffraction pattern, ``(n - 1) / 2``.
        Other parameters
            Passed through to Dataset4dstem.polar_transform when needed.

        Returns
        -------
        RDF
            Instance whose ``polar`` holds the polar data and whose
            ``input_data`` holds the dataset it was built from (for numpy
            input, the Dataset2d / Dataset4dstem wrapper created here).

        Raises
        ------
        NotImplementedError
            For Dataset3d input.
        ValueError
            For arrays that are not 2D or 4D, or a Dataset2d whose array is
            not 2D.
        """
        # Polar input: use directly
        if isinstance(data, Polar4dstem):
            return cls(polar=data, input_data=data, _token=cls._token)

        # Resolve every other input to (source kept on the instance, 4D-STEM to transform)
        if isinstance(data, Dataset4dstem):
            source = ds4 = data
        elif isinstance(data, Dataset2d):
            source = data
            ds4 = cls._dataset2d_as_4dstem(data)
        elif isinstance(data, Dataset3d):
            # Dataset3d input: not yet specified how to interpret
            raise NotImplementedError(
                f"{cls.__name__}.from_data does not yet support Dataset3d inputs."
            )
        else:
            # Numpy array input
            arr = ensure_valid_array(data)
            if arr.ndim == 2:
                source = Dataset2d.from_array(arr, name="rdf_input_2d")
                ds4 = cls._dataset2d_as_4dstem(source)
            elif arr.ndim == 4:
                source = ds4 = Dataset4dstem.from_array(arr, name="rdf_input_4dstem")
            else:
                raise ValueError(f"{cls.__name__}.from_data only supports 2D or 4D arrays.")

        # Default the origin to the central pixel of the diffraction pattern
        ny, nx = ds4.array.shape[-2:]
        if origin_row is None:
            origin_row = (ny - 1) / 2.0
        if origin_col is None:
            origin_col = (nx - 1) / 2.0

        polar = ds4.polar_transform(
            origin_row=origin_row,
            origin_col=origin_col,
            ellipse_params=ellipse_params,
            num_annular_bins=num_annular_bins,
            radial_min=radial_min,
            radial_max=radial_max,
            radial_step=radial_step,
            two_fold_rotation_symmetry=two_fold_rotation_symmetry,
        )
        return cls(polar=polar, input_data=source, _token=cls._token)

    # ------------------------------------------------------------------
    # Convenience accessors
    # ------------------------------------------------------------------
    @property
    def qq(self) -> Any:
        """
        Scattering vector coordinate array along the radial dimension of `self.polar`,
        in physical units (using Polar4dstem.sampling and origin).
        """
        # Polar4dstem dims: (scan_y, scan_x, phi, r)
        # radial axis is 3
        return self.polar.coords_units(3)

    @property
    def radial_bins(self) -> Any:
        """
        Radial bin centers in pixel units (convenience alias).
        """
        return self.polar.coords(3)

    # ------------------------------------------------------------------
    # Analysis method stubs (py4DSTEM-style API)
    # ------------------------------------------------------------------
    def calculate_radial_statistics(
        self,
        mask_realspace: NDArray | None = None,
        plot_results_mean: bool = False,
        plot_results_var: bool = False,
        figsize: tuple[float, float] = (8, 4),
        returnval: bool = False,
        returnfig: bool = False,
        progress_bar: bool = True,
    ):
        """
        Stub for radial statistics (FEM-style) calculation on the polar data.

        Intended to compute radial mean, variance, and normalized variance
        from self.polar. Not implemented yet.
        """
        raise NotImplementedError("calculate_radial_statistics is not implemented yet.")

    def plot_radial_mean(
        self,
        log_x: bool = False,
        log_y: bool = False,
        figsize: tuple[float, float] = (8, 4),
        returnfig: bool = False,
    ):
        """
        Stub for plotting radial mean intensity vs scattering vector.
        """
        raise NotImplementedError("plot_radial_mean is not implemented yet.")

    def plot_radial_var_norm(
        self,
        figsize: tuple[float, float] = (8, 4),
        returnfig: bool = False,
    ):
        """
        Stub for plotting normalized radial variance vs scattering vector.
        """
        raise NotImplementedError("plot_radial_var_norm is not implemented yet.")

    def calculate_pair_dist_function(
        self,
        k_min: float = 0.05,
        k_max: float | None = None,
        k_width: float = 0.25,
        k_lowpass: float | None = None,
        k_highpass: float | None = None,
        r_min: float = 0.0,
        r_max: float = 20.0,
        r_step: float = 0.02,
        damp_origin_fluctuations: bool = True,
        enforce_positivity: bool = True,
        density: float | None = None,
        plot_background_fits: bool = False,
        plot_sf_estimate: bool = False,
        plot_reduced_pdf: bool = True,
        plot_pdf: bool = False,
        figsize: tuple[float, float] = (8, 4),
        maxfev: int | None = None,
        returnval: bool = False,
        returnfig: bool = False,
    ):
        """
        Stub for pair distribution function (PDF) calculation from radial statistics.

        Intended to estimate S(k), background, and transform to real-space g(r)/G(r).
        """
        raise NotImplementedError("calculate_pair_dist_function is not implemented yet.")

    def plot_background_fits(
        self,
        figsize: tuple[float, float] = (8, 4),
        returnfig: bool = False,
    ):
        """
        Stub for plotting background fit vs radial mean intensity.
        """
        raise NotImplementedError("plot_background_fits is not implemented yet.")

    def plot_sf_estimate(
        self,
        figsize: tuple[float, float] = (8, 4),
        returnfig: bool = False,
    ):
        """
        Stub for plotting reduced structure factor S(k).
        """
        raise NotImplementedError("plot_sf_estimate is not implemented yet.")

    def plot_reduced_pdf(
        self,
        figsize: tuple[float, float] = (8, 4),
        returnfig: bool = False,
    ):
        """
        Stub for plotting reduced PDF g(r).
        """
        raise NotImplementedError("plot_reduced_pdf is not implemented yet.")

    def plot_pdf(
        self,
        figsize: tuple[float, float] = (8, 4),
        returnfig: bool = False,
    ):
        """
        Stub for plotting full PDF G(r).
        """
        raise NotImplementedError("plot_pdf is not implemented yet.")
def __init__(
    self,
    u_array: np.ndarray,
    v_array: np.ndarray,
    ds_shape: tuple[int, ...],
    real_space: bool,
    u_ref: np.ndarray | None = None,
    v_ref: np.ndarray | None = None,
    mask: np.ndarray | None = None,
    ds_sampling: float | None = None,
    ds_units: str | None = None,
):
    """Store the lattice vectors and scan metadata, then compute the strain maps.

    Parameters
    ----------
    u_array, v_array : np.ndarray
        Per-position lattice vectors, stored unchanged.
    ds_shape : tuple of int
        Shape of the parent scan grid; its first two entries size the default mask.
    real_space : bool
        Modality flag, stored and later passed to the strain computation.
    u_ref, v_ref : np.ndarray, optional
        Fixed reference vectors. When given they are kept (as float arrays) and
        take precedence over the median reference on every re-fit.
    mask : np.ndarray, optional
        Per-position weight / ROI. Defaults to all ones.
    ds_sampling : float, optional
        Scan sampling; ``None`` means ``1.0``.
    ds_units : str, optional
        Units of ``ds_sampling``; ``None`` means ``"pixels"``.
    """
    super().__init__()

    # raw inputs and scan geometry
    self.u_array, self.v_array = u_array, v_array
    self.ds_shape, self.real_space = ds_shape, real_space
    self.ds_sampling = ds_sampling if ds_sampling is not None else 1.0
    self.ds_units = ds_units if ds_units is not None else "pixels"

    # Per-position weighting / ROI in [0, 1]. A mask already inside [0, 1] is
    # taken as-is (its producers emit a [0, 1] weight, and re-normalizing a
    # near-constant mask would squash it towards 0). Only a mask reaching
    # outside [0, 1] is min-max rescaled; a constant or non-finite-range mask
    # falls back to uniform full weight.
    if mask is None:
        weights = np.ones(ds_shape[:2], dtype=float)
    else:
        weights = np.asarray(mask, dtype=float)
    lowest = np.nanmin(weights)
    highest = np.nanmax(weights)
    range_is_finite = np.isfinite(lowest) and np.isfinite(highest)
    if not range_is_finite or highest <= lowest:
        weights = np.ones_like(weights)
    elif lowest < 0.0 or highest > 1.0:
        weights = (weights - lowest) / (highest - lowest)
    self.mask = weights

    # user-supplied reference vectors persist across re-fits (None = use median)
    if u_ref is None:
        self._u_ref_fixed = None
    else:
        self._u_ref_fixed = np.asarray(u_ref, dtype=float)
    if v_ref is None:
        self._v_ref_fixed = None
    else:
        self._v_ref_fixed = np.asarray(v_ref, dtype=float)
    self.u_ref = None
    self.v_ref = None

    # fills u_ref / v_ref and the e_rr / e_cc / e_rc / phi maps
    self.update_reference()
+ """ + u_med, v_med = _reference_lattice(self.u_array, self.v_array, self.mask, strain_mask) + + if u_ref is not None: + self.u_ref = np.asarray(u_ref, dtype=float) + elif self._u_ref_fixed is not None: + self.u_ref = self._u_ref_fixed + else: + self.u_ref = u_med + + if v_ref is not None: + self.v_ref = np.asarray(v_ref, dtype=float) + elif self._v_ref_fixed is not None: + self.v_ref = self._v_ref_fixed + else: + self.v_ref = v_med + + e_rr, e_cc, e_rc, phi = _strain_tensor( + self.u_array, self.v_array, self.u_ref, self.v_ref, self.real_space + ) + self.e_rr = Dataset2d.from_array(e_rr, name="strain e_rr", signal_units="fractional") + self.e_cc = Dataset2d.from_array(e_cc, name="strain e_cc", signal_units="fractional") + self.e_rc = Dataset2d.from_array(e_rc, name="strain e_rc", signal_units="fractional") + self.phi = Dataset2d.from_array(phi, name="strain rotation", signal_units="radians") + + if plot_strain_roi: + self.plot_strain_roi(strain_mask=strain_mask, **plot_kwargs) + return self + + def rotate_strain( + self, rotation_angle: float = 0.0 + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Tensor-rotate the strain into a frame rotated by ``rotation_angle`` (degrees). + + The rotation field ``phi`` is invariant under frame rotation and is not + transformed. + + Parameters + ---------- + rotation_angle : float, default=0.0 + Frame rotation angle, in degrees. + + Returns + ------- + tuple of np.ndarray + ``(e_uu, e_vv, e_uv)`` strain components in the rotated frame. + """ + return _rotate_strain_tensor( + self.e_rr.array, self.e_cc.array, self.e_rc.array, rotation_angle + ) + + def plot_strain_roi( + self, + strain_mask: np.ndarray | None = None, + plot_rotation: bool = True, + cmap_strain: str = "RdBu_r", + cmap_rotation: str = "PiYG", + layout: str = "horizontal", + figsize: tuple[float, float] | None = None, + **kwargs, + ): + """Plot the strain in the raw row/col reference frame, color-scaled to the ROI. 
+ + The color range is symmetric about zero and set by the largest absolute + strain (and rotation) *inside the reference ROI* — ``strain_mask`` if given, + else ``self.mask`` — so a well-chosen, strain-free reference region reads as + flat (near mid-color) and any residual gradient or tilt stands out. The ROI + itself is drawn in color while everything outside it is shown in greyscale, + so the chosen reference region is obvious at a glance. Unlike + :meth:`plot_strain`, no display rotation is applied: the panels show the raw + ``e_rr``/``e_cc``/``e_rc`` that :meth:`update_reference` just computed. + + Parameters + ---------- + strain_mask : np.ndarray, optional + ROI defining the color range (and the reference region). If omitted, + ``self.mask`` is used. + plot_rotation : bool, default=True + Whether to include the rotation (``phi``) panel. + cmap_strain : str, default="RdBu_r" + Colormap for the strain panels. + cmap_rotation : str, default="PiYG" + Colormap for the rotation panel. + layout : {"horizontal", "vertical"}, default="horizontal" + Panel arrangement. + figsize : tuple of float, optional + Figure size in inches; if omitted it is derived from the layout. + **kwargs + Forwarded to + :func:`~quantem.diffraction.strain_visualization.plot_strain_panels`. + + Returns + ------- + tuple + ``(fig, ax)`` from :func:`plot_strain_panels`. 
+ """ + roi_src = self.mask if strain_mask is None else strain_mask + e_rr, e_cc, e_rc, phi = ( + self.e_rr.array, + self.e_cc.array, + self.e_rc.array, + self.phi.array, + ) + + inside = np.asarray(roi_src) > 0 if roi_src is not None else np.ones(e_rr.shape, bool) + if not inside.any(): + inside = np.ones(e_rr.shape, bool) + + strain_stack = np.stack([e_rr[inside], e_cc[inside], e_rc[inside]]) + smax = float(np.nanmax(np.abs(strain_stack))) * 100.0 + rmax = float(np.rad2deg(np.nanmax(np.abs(phi[inside])))) + smax = smax if smax > 0 else 1e-6 + rmax = rmax if rmax > 0 else 1e-6 + + return plot_strain_panels( + e_rr, + e_cc, + e_rc, + phi, + self.mask, + self.u_ref, + self.v_ref, + self.ds_shape, + ds_sampling=self.ds_sampling, + ds_units=self.ds_units, + strain_range_percent=(-smax, smax), + rotation_range_degrees=(-rmax, rmax), + roi=inside, + plot_rotation=plot_rotation, + cmap_strain=cmap_strain, + cmap_rotation=cmap_rotation, + layout=layout, + figsize=figsize, + panel_titles=( + r"$\epsilon_{rr}$ $\updownarrow$", + r"$\epsilon_{cc}$ $\leftrightarrow$", + r"$\epsilon_{rc}$ $\nwarrow\!\!\!\!\!\!\!\!\!\:\searrow$", + ), + **kwargs, + ) + + def plot_strain( + self, + rotation_angle: float = 20.0, + strain_range_percent: tuple[float, float] = (-3.0, 3.0), + rotation_range_degrees: tuple[float, float] = (-2.0, 2.0), + mask_range: tuple[float, float] = (0.0, 1.0), + plot_rotation: bool = True, + plot_gvecs: bool = False, + plot_scalebar: bool = False, + cmap_strain: str = "RdBu_r", + cmap_rotation: str = "PiYG", + layout: str = "horizontal", + figsize: tuple[float, float] | None = None, + **kwargs, + ): + """Plot the strain (rotated into the display frame) and rotation panels. + + Parameters + ---------- + rotation_angle : float, default=20.0 + Angle (degrees) by which the strain tensor is rotated into the display + frame before plotting. + strain_range_percent : tuple of float, default=(-3.0, 3.0) + Symmetric color range for the strain panels, in percent. 
+ rotation_range_degrees : tuple of float, default=(-2.0, 2.0) + Symmetric color range for the rotation panel, in degrees. + mask_range : tuple of float, default=(0.0, 1.0) + ``(low, high)`` window remapping the mask brightness: positions with + mask ``>= high`` are shown at full color, ``<= low`` are black, and + values between ramp linearly from black to full. The default leaves the + normalized mask unchanged. + plot_rotation : bool, default=True + Whether to include the rotation (``phi``) panel. + plot_gvecs : bool, default=False + Whether to overlay the reference lattice vectors. + plot_scalebar : bool, default=False + Whether to draw a real-space scale bar. + cmap_strain : str, default="RdBu_r" + Colormap for the strain panels. + cmap_rotation : str, default="PiYG" + Colormap for the rotation panel. + layout : {"horizontal", "vertical"}, default="horizontal" + Panel arrangement. + figsize : tuple of float, optional + Figure size in inches; if omitted it is derived from the layout. + **kwargs + Forwarded to + :func:`~quantem.diffraction.strain_visualization.plot_strain_panels`. + + Returns + ------- + tuple + ``(fig, ax)`` from :func:`plot_strain_panels`. 
+ """ + e_uu, e_vv, e_uv = self.rotate_strain(rotation_angle) + return plot_strain_panels( + e_uu, + e_vv, + e_uv, + self.phi.array, + self.mask, + self.u_ref, + self.v_ref, + self.ds_shape, + ds_sampling=self.ds_sampling, + ds_units=self.ds_units, + strain_range_percent=strain_range_percent, + rotation_range_degrees=rotation_range_degrees, + mask_range=mask_range, + plot_rotation=plot_rotation, + plot_gvecs=plot_gvecs, + plot_scalebar=plot_scalebar, + cmap_strain=cmap_strain, + cmap_rotation=cmap_rotation, + layout=layout, + figsize=figsize, + **kwargs, + ) + + def estimate_strain_precision( + self, + mask_range: tuple[float, float] = (0.0, 1.0), + rotation_angle: float = 0.0, + window: int = 5, + mask_threshold: float = 0.5, + min_neighbors: int = 3, + component: str = "combined", + bins: int = 50, + bounds: tuple[float, float] | None = None, + plot: bool = True, + returnfig: bool = False, + ): + """Estimate strain *precision* (random scatter) from local median deviations. + + This measures repeatability, not accuracy. Without a ground truth (e.g. a + simulation) it cannot detect systematic error — only how far each position + scatters from its local neighborhood. For every position the deviation from + the median of its surrounding well-indexed neighbors is + + ``error(r, c) = | strain(r, c) - median( strain over neighbors with + scaled mask > mask_threshold ) |`` + + computed for each tensor component (the center position is excluded from its + own median). The three strain components are reduced to one rotation-invariant + number via the Frobenius norm of the symmetric strain-tensor deviation, + + ``combined = sqrt(d_uu**2 + d_vv**2 + 2*d_uv**2)``, + + (equivalently the root-sum-square of the principal-strain deviations) so a + single strain precision can be quoted and compared between datasets. Rotation + precision is reported separately, not folded into ``combined``. 
+ + Each component's precision is summarized by the mask-weighted **median** of its + per-position deviations — the center of the histogram bulk. The median is used + (not the mean or RMS) because a handful of bad-fit pixels form a heavy tail that + would drag a second moment far to the right of where the distribution actually + sits, leaving the reported number disconnected from the histogram; the median + ignores that tail. A weighted histogram of the chosen component is shown, marked + with its median. + + Parameters + ---------- + mask_range : tuple of float, default=(0.0, 1.0) + ``(low, high)`` window remapping :attr:`mask` to ``[0, 1]`` (same + convention as :meth:`plot_strain`); the remapped mask both selects which + positions are trusted (``> mask_threshold`` -- used as neighbors *and* as + the set the precision is computed over) and weights the histogram and the + median. + rotation_angle : float, default=0.0 + Frame rotation (degrees) applied before measuring per-component precision, + matching :meth:`plot_strain`. ``0`` reports the raw row/col frame + (``e_uu == e_rr`` ...). The combined number is rotation-invariant. + window : int, default=5 + Odd edge length (px) of the neighborhood bounding box; the footprint is + the inscribed disk of radius ``window / 2`` (3 -> 8 neighbors, 5 -> 20, + 7 -> 36). A pure linear strain ramp cancels in the (symmetric) median, so + larger windows mostly just steady the median — at the cost of blurring + *curved* strain and biasing the masked edges. ``5`` roughly halves the + noise-floor over-estimate of ``3`` (~9% -> ~4%) while staying local. + mask_threshold : float, default=0.5 + A position is trusted only if its scaled mask exceeds this value. 
Trusted + positions are the ones used as local-median neighbors *and* the ones whose + deviations enter the reported median and histogram; sub-threshold positions + are excluded from both (not merely down-weighted), so a poorly-indexed + pixel cannot leak its scatter into the precision. + min_neighbors : int, default=3 + Minimum number of valid neighbors required; positions with fewer get no + precision estimate (``nan``, dropped from the statistics). + component : {"combined","e_uu","e_vv","e_uv","rotation"}, default="combined" + Which error distribution to histogram. + bins : int, default=50 + Number of histogram bins, or a sequence of explicit bin edges. With a + bin *count* and no ``bounds``, the range defaults to ``[0, weighted 99th + percentile]`` of the trusted deviations -- robust to the heavy outlier + tail, which otherwise sets the range to its max and crushes the bulk into + the first bin. Passing explicit edges (or ``bounds``) overrides this. + bounds : tuple of float, optional + ``(low, high)`` histogram range in display units (percent for strain, + degrees for rotation). Fix it to compare datasets on the same axis, or to + see the full tail. Values outside the range are left out of the bars (no + overflow spike); the median is computed from all trusted positions + regardless, and ``out_of_range_fraction`` records how much was off-range. + plot : bool, default=True + If ``True``, draw the weighted precision histogram. + returnfig : bool, default=False + If ``True``, return ``(fig, ax)`` instead of the results dict. 
+ + Returns + ------- + dict or tuple + A results dict with the ``precision`` (mask-weighted median local + deviation) per component and ``combined`` (strain in percent, rotation in + degrees), the normalized ``counts`` and ``edges`` of the histogrammed + ``component`` (and ``counts_raw``, the weighted bin sums), + ``out_of_range_fraction`` (weighted mass outside the histogram range, + excluded from the bars), and the chosen settings; or ``(fig, ax)`` when + ``returnfig=True``. + """ + if window < 3 or window % 2 == 0: + raise ValueError("window must be an odd integer >= 3.") + valid_components = ("combined", "e_uu", "e_vv", "e_uv", "rotation") + if component not in valid_components: + raise ValueError(f"component must be one of {valid_components}.") + + # number of neighbors in the circular footprint (matches _local_masked_median) + p = window // 2 + oy, ox = np.ogrid[-p : p + 1, -p : p + 1] + n_neighbors = int(np.sum((oy ** 2 + ox ** 2) <= (window / 2.0) ** 2) - 1) + + # per-component fields in the (optionally rotated) display frame; phi is + # rotation-invariant and is carried through unchanged + e_uu, e_vv, e_uv = self.rotate_strain(rotation_angle) + fields = {"e_uu": e_uu, "e_vv": e_vv, "e_uv": e_uv, "rotation": self.phi.array} + + # remap the mask exactly as plot_strain does, then use it both to select + # neighbors (> mask_threshold) and to weight the histogram / mean + low, high = float(mask_range[0]), float(mask_range[1]) + m = np.asarray(self.mask, dtype=float) + if high > low: + scaled = np.clip((m - low) / (high - low), 0.0, 1.0) + else: + scaled = (m >= high).astype(float) + valid = scaled > float(mask_threshold) + + # per-component local-median deviation, native units (fractional / radians) + dev = { + name: np.abs(field - _local_masked_median(field, valid, window, min_neighbors)) + for name, field in fields.items() + } + # single rotation-invariant number: Frobenius norm of the symmetric + # strain-tensor deviation (== root-sum-square of the 
principal-strain + # deviations). Rotation is reported separately, not folded in: in nanobeam + # data it is partly a systematic (tilt/descan) and would mix radians into a + # percent figure. + dev["combined"] = np.sqrt( + dev["e_uu"] ** 2 + dev["e_vv"] ** 2 + 2.0 * dev["e_uv"] ** 2 + ) + + # display-unit scaling: strain -> percent, rotation -> degrees + scale = { + "e_uu": 100.0, + "e_vv": 100.0, + "e_uv": 100.0, + "rotation": float(np.rad2deg(1.0)), + "combined": 100.0, + } + + # Precision = the weighted MEDIAN of each per-position deviation distribution, + # in display units. Restricted to trusted positions (valid == scaled > + # mask_threshold, the SAME set used to pick neighbors) and mask-weighted within + # it -- otherwise sub-threshold junk pixels, already excluded as neighbors, + # would leak in. The median sits at the center of the histogram bulk and is + # immune to the heavy outlier tail that a mean / RMS would chase out to the + # right (a few bad-fit pixels dominate a second moment but not the median). + def _weighted_median(err_native: np.ndarray, factor: float) -> float: + e = err_native * factor + use = np.isfinite(e) & valid + return _weighted_quantile(e[use], scaled[use], 0.5) + + precision = {name: _weighted_median(dev[name], scale[name]) for name in scale} + + # weighted histogram of the chosen component, over the same trusted positions + # as the median above. This is purely a picture of the common error values, so + # anything beyond the bin range is left OUT of the bars -- no overflow spike at + # the edge to crush the bulk. Nothing is lost: the median is computed from all + # trusted positions regardless. Bars are normalized by the total trusted + # weight, so each bar is the true fraction of all trusted positions and the + # off-range mass simply isn't drawn (the bars sum to 1 - out_of_range_fraction). 
+ e = dev[component] * scale[component] + use = np.isfinite(e) & valid # trusted positions only, consistent with median + e_f = e[use] + w_f = scaled[use] + # Default histogram range: a robust weighted upper percentile, NOT the raw + # max. A handful of bad-fit positions can reach tens of percent; used as the + # range they crush the entire bulk into the first bin and leave the rest of + # the axis empty (a spurious "spike at 0" plus a far outlier spike). Capping + # at the weighted 99th percentile keeps the common error values readable; the + # few positions past it spill into out_of_range_fraction (reported, not + # drawn). An explicit `bounds`, or passing bin EDGES as `bins`, overrides it. + if bounds is None and np.ndim(bins) == 0 and e_f.size and float(w_f.sum()) > 0: + hi_default = _weighted_quantile(e_f, w_f, 0.99) + if np.isfinite(hi_default) and hi_default > 0: + bounds = (0.0, hi_default) + edges = np.histogram_bin_edges(e_f, bins=bins, range=bounds) + lo, hi = float(edges[0]), float(edges[-1]) + wtot = float(w_f.sum()) + frac_below = float(w_f[e_f < lo].sum()) / wtot if wtot > 0 else 0.0 + frac_above = float(w_f[e_f > hi].sum()) / wtot if wtot > 0 else 0.0 + out_of_range_fraction = frac_below + frac_above + counts_raw, edges = np.histogram(e_f, bins=edges, weights=w_f) + counts = counts_raw / wtot if wtot > 0 else counts_raw + + unit = "°" if component == "rotation" else "%" + result = { + "precision": precision, + "component": component, + "unit": unit, + "counts": counts, + "counts_raw": counts_raw, + "edges": edges, + "out_of_range_fraction": out_of_range_fraction, + "window": int(window), + "n_neighbors": n_neighbors, + "mask_threshold": float(mask_threshold), + "mask_range": (low, high), + "rotation_angle": float(rotation_angle), + } + + print("Strain precision (median local deviation, mask-weighted)") + print( + f" reference={n_neighbors} neighbors (disk, window={window}) " + f"mask>{mask_threshold:g} min_neighbors={min_neighbors} " + 
f"rotation_angle={rotation_angle:g} deg" + ) + for name in ("e_uu", "e_vv", "e_uv"): + print(f" {name:<9}: {precision[name]:7.4f} %") + print(f" {'rotation':<9}: {precision['rotation']:7.4f} deg") + print( + f" {'combined':<9}: {precision['combined']:7.4f} % " + "(strain-only Frobenius norm; rotation excluded)" + ) + + if not (plot or returnfig): + return result + + fig, ax = plot_strain_precision_histogram(edges, counts, precision, component, unit) + if returnfig: + return fig, ax + return result + + +# ---- module-level fitting functions ---- + + +def _weighted_quantile(values: np.ndarray, weights: np.ndarray, q: float) -> float: + """Weighted ``q``-quantile of ``values`` (``q`` in ``[0, 1]``); ``nan`` if no weight. + + Uses cumulative-weight interpolation with weights centered on each sorted sample, + so with uniform weights it tracks ``np.quantile``'s linear interpolation and is + robust to a heavy upper tail (the median ignores how far the outliers reach). + """ + values = np.asarray(values, dtype=float) + weights = np.asarray(weights, dtype=float) + total = float(weights.sum()) + if values.size == 0 or total <= 0: + return float("nan") + order = np.argsort(values) + v = values[order] + w = weights[order] + cw = np.cumsum(w) - 0.5 * w + return float(np.interp(q * total, cw, v)) + + +def _local_masked_median( + field: np.ndarray, + valid: np.ndarray, + window: int, + min_neighbors: int, +) -> np.ndarray: + """Median of each position's surrounding neighbors over valid (masked) pixels. + + The center position is excluded ("surrounding" only); a neighbor contributes + only where ``valid`` is True and the field is finite. Neighbors are taken over a + circular (isotropic) footprint of radius ``window / 2`` inscribed in the + ``window`` x ``window`` box — a disk avoids the square's far corners, which + over-weight the diagonals and sample the most strain-different points. Positions + left with fewer than ``min_neighbors`` contributing neighbors return ``nan``. 
+ + Parameters + ---------- + field : np.ndarray + ``(scan_row, scan_col)`` field to take local medians of. + valid : np.ndarray + ``(scan_row, scan_col)`` boolean mask of usable neighbor positions. + window : int + Odd edge length of the bounding box; the footprint is the disk of radius + ``window / 2`` within it (3 -> 8 neighbors, 5 -> 20, 7 -> 36). + min_neighbors : int + Minimum contributing neighbors required, else ``nan``. + + Returns + ------- + np.ndarray + ``(scan_row, scan_col)`` local masked median (``nan`` where undefined). + """ + p = window // 2 + fpad = np.pad(np.asarray(field, dtype=float), p, mode="constant", constant_values=np.nan) + vpad = np.pad(np.asarray(valid, dtype=bool), p, mode="constant", constant_values=False) + + # writable per-position (window, window) neighborhoods + fw = sliding_window_view(fpad, (window, window)).copy() + vw = sliding_window_view(vpad, (window, window)) + fw[~vw] = np.nan + fw[:, :, p, p] = np.nan # exclude the center position from its own median + # restrict the square box to a circular footprint of radius window/2 + oy, ox = np.ogrid[-p : p + 1, -p : p + 1] + outside = (oy ** 2 + ox ** 2) > (window / 2.0) ** 2 + fw[:, :, outside] = np.nan + + flat = fw.reshape(fw.shape[0], fw.shape[1], -1) + count = np.sum(np.isfinite(flat), axis=-1) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + med = np.nanmedian(flat, axis=-1) + med[count < min_neighbors] = np.nan + return med + + +def _reference_lattice( + u_array: np.ndarray, + v_array: np.ndarray, + mask: np.ndarray | None = None, + strain_mask: np.ndarray | None = None, +) -> tuple[np.ndarray, np.ndarray]: + """Weighted-median reference lattice vectors, else the global median. + + The reference is the per-component **weighted median** of the lattice vectors. 
+ Weights come from ``strain_mask`` if given, else the continuous ``mask`` + (the ``[0, 1]`` per-position weight from :meth:`create_mask` / ``fit_lattice``): + strong, well-indexed positions dominate the reference and weak / vacuum / bad-fit + positions are down-weighted. A boolean ROI (weights in ``{0, 1}``) reduces to the + plain median over the selected positions, so an explicit ``strain_mask`` behaves + as before. The weighted median (not ``mask == 1``) is used because a continuous + weight rarely hits *exactly* 1 -- the old exact-equality test collapsed a min-max + normalized mask to its single global-max position and made the reference one + arbitrary pixel. + + Parameters + ---------- + u_array : np.ndarray + Per-position first lattice vector, shape ``(scan_row, scan_col, 2)``. + v_array : np.ndarray + Per-position second lattice vector, shape ``(scan_row, scan_col, 2)``. + mask : np.ndarray, optional + ``(scan_row, scan_col)`` per-position weight in ``[0, 1]``. Used as the median + weights when ``strain_mask`` is not given. + strain_mask : np.ndarray, optional + ``(scan_row, scan_col)`` ROI / weight taking precedence over ``mask``. + + Returns + ------- + tuple of np.ndarray + ``(u_ref, v_ref)``, each a length-2 reference vector. + """ + if strain_mask is not None: + w = np.asarray(strain_mask, dtype=float).reshape(-1) + elif mask is not None: + w = np.asarray(mask, dtype=float).reshape(-1) + else: + w = None + + u_flat = u_array.reshape(-1, 2) + v_flat = v_array.reshape(-1, 2) + + def _wmed(vals: np.ndarray) -> float: + # weighted median over finite, positively-weighted positions; positions + # fit_lattice could not fit are NaN and must be dropped, else the reference + # (and the whole strain map) collapses to NaN. Falls back to the unweighted + # nan-median when no weight is given or none survives. 
+ finite = np.isfinite(vals) + ww = np.ones_like(vals) if w is None else w + use = finite & (ww > 0) + if not use.any(): + return float(np.nanmedian(vals)) if finite.any() else float("nan") + return _weighted_quantile(vals[use], ww[use], 0.5) + + u_ref = np.array((_wmed(u_flat[:, 0]), _wmed(u_flat[:, 1])), dtype=float) + v_ref = np.array((_wmed(v_flat[:, 0]), _wmed(v_flat[:, 1])), dtype=float) + return u_ref, v_ref + + +def _strain_tensor( + u_array: np.ndarray, + v_array: np.ndarray, + u_ref: np.ndarray, + v_ref: np.ndarray, + real_space: bool, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Per-position strain tensor from lattice vectors relative to a reference. + + Two measurement modalities are supported and are arranged to give *identical* + strain for the same physical deformation, so correlation (Bragg) and cepstral + (autocorrelation) maps can be compared directly: + + * ``real_space=False`` -- reciprocal-space lattice vectors (nanobeam Bragg + disks), which contract under tension. The per-position transform is + ``strain_trans = U_ref @ inv(U)``. + * ``real_space=True`` -- real-space lattice vectors (cepstral / Patterson + autocorrelation peaks, or DPC), which expand under tension. The transform is + ``strain_trans = (U @ inv(U_ref)).T``. + + Both expressions evaluate to ``F.T`` (the transpose of the real-space deformation + gradient), so the normal strains, shear, and rotation come out the same + regardless of modality, and the reciprocal-space sign convention (``const = -1``, + which sets the shear/rotation handedness) is shared. + + Parameters + ---------- + u_array : np.ndarray + Per-position first lattice vector, shape ``(scan_row, scan_col, 2)``. + v_array : np.ndarray + Per-position second lattice vector, shape ``(scan_row, scan_col, 2)``. + u_ref : np.ndarray + Reference first lattice vector (length 2). + v_ref : np.ndarray + Reference second lattice vector (length 2). 
+ real_space : bool + ``False`` for reciprocal-space (Bragg/correlation) vectors; ``True`` for + real-space (cepstral autocorrelation / DPC) vectors. Selects the per-position + transform above; both yield matching strain. + + Returns + ------- + tuple of np.ndarray + ``(e_rr, e_cc, e_rc, phi)``, each of shape ``(scan_row, scan_col)``. + """ + scan_r, scan_c = u_array.shape[0], u_array.shape[1] + Uref = np.stack((u_ref, v_ref), axis=1).astype(float) + strain_trans = np.zeros((scan_r, scan_c, 2, 2)) + + # For real-space vectors the reference is inverted once (it is shared by every + # position); a non-finite or singular reference leaves the whole map undefined. + Uref_inv = None + if real_space and np.all(np.isfinite(Uref)) and abs(np.linalg.det(Uref)) >= 1e-12: + Uref_inv = np.linalg.inv(Uref) + + for r in range(scan_r): + for c in range(scan_c): + U = np.stack((u_array[r, c, :], v_array[r, c, :]), axis=1) + # Positions fit_lattice could not fit are NaN; a degenerate (collinear) + # fit is singular. Either way there is no meaningful inverse -- leave the + # strain NaN (masked out downstream) rather than feeding NaN into pinv, + # whose SVD does not converge and raises LinAlgError. + if not np.all(np.isfinite(U)) or abs(np.linalg.det(U)) < 1e-12: + strain_trans[r, c, :, :] = np.nan + continue + if real_space: + # real-space vectors expand under tension: (U @ U_ref^-1).T == F.T + if Uref_inv is None: + strain_trans[r, c, :, :] = np.nan + else: + strain_trans[r, c, :, :] = (U @ Uref_inv).T + else: + # reciprocal-space vectors contract under tension: U_ref @ U^-1 == F.T + strain_trans[r, c, :, :] = Uref @ np.linalg.inv(U) + + # const = -1 is the reciprocal-space (nanobeam) shear/rotation convention. Both + # modalities reduce strain_trans to F.T above, so the convention is shared. 
+ const = -1 + e_rr = strain_trans[:, :, 0, 0] - 1 + e_cc = strain_trans[:, :, 1, 1] - 1 + e_rc = strain_trans[:, :, 1, 0] * 0.5 * const + strain_trans[:, :, 0, 1] * 0.5 * const + phi = strain_trans[:, :, 1, 0] * -0.5 * const + strain_trans[:, :, 0, 1] * 0.5 * const + return e_rr, e_cc, e_rc, phi + + +def _rotate_strain_tensor( + e_rr: np.ndarray, + e_cc: np.ndarray, + e_rc: np.ndarray, + rotation_angle: float, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Rotate a 2D strain tensor by ``rotation_angle`` (degrees). + + Parameters + ---------- + e_rr : np.ndarray + Row-row (normal) strain component. + e_cc : np.ndarray + Column-column (normal) strain component. + e_rc : np.ndarray + Row-column (shear) strain component. + rotation_angle : float + Frame rotation angle, in degrees. + + Returns + ------- + tuple of np.ndarray + ``(e_uu, e_vv, e_uv)`` in the rotated frame. + """ + angle = np.deg2rad(rotation_angle) + c = np.cos(angle) + s = np.sin(angle) + e_uu = e_rr * (c * c) + 2.0 * e_rc * (c * s) + e_cc * (s * s) + e_vv = e_rr * (s * s) - 2.0 * e_rc * (c * s) + e_cc * (c * c) + e_uv = (e_cc - e_rr) * (c * s) + e_rc * (c * c - s * s) + return e_uu, e_vv, e_uv diff --git a/src/quantem/diffraction/strain_autocorrelation.py b/src/quantem/diffraction/strain_autocorrelation.py new file mode 100644 index 00000000..c9d4282a --- /dev/null +++ b/src/quantem/diffraction/strain_autocorrelation.py @@ -0,0 +1,2107 @@ +from __future__ import annotations + +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray +from scipy.ndimage import distance_transform_edt + +from quantem.core.datastructures.dataset2d import Dataset2d +from quantem.core.datastructures.dataset3d import Dataset3d +from quantem.core.datastructures.dataset4d import Dataset4d +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.io.serialize import AutoSerialize +from quantem.core.utils.imaging_utils import 
dft_upsample, rotate_image +from quantem.core.utils.utils import electron_wavelength_angstrom +from quantem.core.utils.validators import ensure_valid_array +from quantem.core.visualization import ScalebarConfig, show_2d +from quantem.diffraction.strain import StrainMap + + +class StrainMapAutocorrelation(AutoSerialize): + """Cepstral / autocorrelation lattice fitting and strain mapping for 4D-STEM. + + An alternative to correlation-based :class:`~quantem.diffraction.bragg_vectors.BraggVectors` + that needs no disk template: every diffraction pattern is transformed into the + autocorrelation (real-space) domain, where the crystal periodicity shows up as a + lattice of sharp peaks, and those peaks are fit per scan position to recover the + local lattice vectors. Because the transform lives in real space, the peaks track + real-space lattice spacings (they *expand* under tension, opposite to Bragg disks), + so ``real_space=True`` -- :func:`~quantem.diffraction.strain._strain_tensor` then + reduces both this and the reciprocal-space Bragg path to the same deformation + gradient, and the two strain maps can be compared directly. + + Workflow (each step writes state consumed by the next): + + 1. :meth:`diffraction_mask` -- build a soft mask over the detector that suppresses + the bright central beam / vacuum so the transform is dominated by the lattice. + 2. :meth:`preprocess` -- transform every pattern and average the magnitudes into a + mean transform image. Three ``mode`` choices set the intensity scaling applied + before the FFT: ``"linear"`` (Patterson / autocorrelation), ``"log"`` + (cepstrum), or ``"gamma"`` (power law). The detector->scan rotation + (``q_to_r_rotation_ccw_deg`` + ``q_transpose``) is read from the parent dataset + metadata, the single source of truth shared with the DPC/CoM and Bragg + workflows. + 3. 
def __init__(
    self,
    dataset: Dataset4dstem,
    input_data: Any | None = None,
    _token: object | None = None,
):
    """Private constructor; use :meth:`from_dataset` or :meth:`from_array`.

    The ``_token`` guard rejects direct instantiation; the factory classmethods
    pass the class sentinel and are the supported entry points.

    Parameters
    ----------
    dataset : Dataset4dstem
        The 4D-STEM dataset to analyze.
    input_data : Any, optional
        The original object handed to a factory, stored unchanged for provenance.
    _token : object, optional
        Internal sentinel; anything other than the class token raises
        ``RuntimeError``.

    Raises
    ------
    RuntimeError
        If ``_token`` is not the class sentinel.
    """
    expected_token = self._token
    if _token is not expected_token:
        raise RuntimeError(
            "Use StrainMapAutocorrelation.from_dataset() or StrainMapAutocorrelation.from_array() to instantiate this class."
        )

    # The two-argument super() is deliberate: the zero-argument form relies on a
    # compiler-created __class__ cell, which is missing when the method source is
    # re-exec'd from a string (Jupyter autoreload) and then fails with
    # "super(): __class__ cell not found".
    super(StrainMapAutocorrelation, self).__init__()

    self.dataset = dataset
    self.input_data = input_data
    self.strain = None
    self.metadata: dict[str, Any] = {}

    # mean transform images, filled in by preprocess()
    self.transform: Dataset2d | None = None
    self.transform_rotated: Dataset2d | None = None

    self.mean_img_peaks: NDArray | None = None
    self.mean_img_weights: NDArray | None = None

    # Neutral detector masks until diffraction_mask() is run: multiply by one,
    # add zero. Shaped like the trailing two (detector) axes of the data.
    detector_shape = self.dataset.array.shape[2:]
    self.mask_diffraction = np.ones(detector_shape)
    self.mask_diffraction_inv = np.zeros(detector_shape)

    # Initial basis from choose_lattice_vector(), per-position fits from
    # fit_lattice_vectors(), per-position weight from create_mask().
    self.u: np.ndarray | None = None
    self.v: np.ndarray | None = None
    self.u_peak_fit: Dataset3d | None = None
    self.v_peak_fit: Dataset3d | None = None
    self.u_ref: np.ndarray | None = None
    self.v_ref: np.ndarray | None = None
    self.u_array: np.ndarray | None = None
    self.v_array: np.ndarray | None = None
    self.mask_weight: np.ndarray | None = None
@classmethod
def from_array(cls, array: NDArray, *, name: str = "strain_map_autocorrelation") -> "StrainMapAutocorrelation":
    """Build a cepstral strain workflow directly from a raw 4D array.

    The array is passed through ``ensure_valid_array``, required to be
    4-dimensional, and wrapped in a :class:`Dataset4dstem` called ``name``. The
    object the caller supplied (not the validated array) is kept as
    ``input_data``.

    Parameters
    ----------
    array : np.ndarray
        4D-STEM data with shape ``(scan_row, scan_col, dp_row, dp_col)``.
    name : str, default="strain_map_autocorrelation"
        Name given to the wrapping :class:`Dataset4dstem`.

    Returns
    -------
    StrainMapAutocorrelation
        A new workflow instance bound to the wrapped dataset.

    Raises
    ------
    ValueError
        If the validated array is not 4-dimensional.
    """
    validated = ensure_valid_array(array)
    if validated.ndim == 4:
        wrapped = Dataset4dstem.from_array(validated, name=name)
        return cls(dataset=wrapped, input_data=array, _token=cls._token)
    raise ValueError(
        "StrainMapAutocorrelation.from_array expects a 4D array with shape (scan_r, scan_c, dp_r, dp_c)."
    )
+ + Parameters + ---------- + threshold : float, optional + Absolute mean-intensity level below which detector pixels are masked out. If + ``None`` (default), the level is taken from ``threshold_percentile`` of the + mean diffraction pattern, so no manual intensity value is needed. + threshold_percentile : float, default=50.0 + Percentile (``0``-``100``) of the mean diffraction pattern used to set the + masking level when ``threshold`` is ``None``. The default keeps the brighter + half of the detector (disks and central beam) and masks the dimmer vacuum; + raise it to mask more aggressively. Ignored when ``threshold`` is given. + edge_blend : float, default=64.0 + Feather width in pixels of the raised-cosine taper between kept and masked + regions; larger values give a softer transition. + plot_mask : bool, default=True + If ``True``, show the raw mean pattern beside the masked/filled pattern (log + scaled) so the mask can be checked. + figsize : tuple of float, default=(8, 4) + Figure size in inches for the diagnostic plot. + + Returns + ------- + StrainMapAutocorrelation + ``self``, with :attr:`mask_diffraction` and :attr:`mask_diffraction_inv` set. 
def preprocess(
    self,
    mode: str = "linear",
    q_to_r_rotation_ccw_deg: float | None = None,
    q_transpose: bool | None = None,
    skip=None,
    plot_transform: bool = True,
    cropping_factor: float = 0.25,
    gamma: float = 0.5,
    **plot_kwargs: Any,
) -> "StrainMapAutocorrelation":
    """Transform every pattern into the autocorrelation domain and average it.

    For each diffraction pattern the masked/filled pattern (from
    :meth:`diffraction_mask`) is intensity-scaled per ``mode``, Fourier transformed,
    and its magnitude accumulated; the mean over all scan positions is stored
    (fft-shifted, origin centered) as :attr:`transform`. A display copy rotated by
    the detector->scan rotation is stored as :attr:`transform_rotated`.

    The magnitudes are accumulated one scan row at a time, so peak memory is one
    row of patterns rather than a masked copy plus a complex FFT of the whole 4D
    stack.

    The detector->scan rotation (``q_to_r_rotation_ccw_deg``) and transpose
    (``q_transpose``) default to the parent dataset metadata; pass them explicitly
    to override. A ``UserWarning`` reports when parent metadata is used and when a
    value falls back to its built-in default.

    Parameters
    ----------
    mode : {"linear", "log", "gamma"}, default="linear"
        Intensity scaling applied before the FFT. ``"linear"`` is the Patterson /
        autocorrelation (aliases ``"patterson"``, ``"acf"``, ``"autocorrelation"``;
        the misspelling ``"paterson"`` is tolerated); ``"log"`` is the cepstrum,
        ``log1p(I)`` (aliases ``"cepstrum"``, ``"cepstral"``); ``"gamma"`` raises
        the intensity, clipped at zero, to the power ``gamma`` (aliases
        ``"power"``, ``"sqrt"``). Matching ignores case and surrounding spaces.
    q_to_r_rotation_ccw_deg : float, optional
        Counter-clockwise detector->scan rotation in degrees for the rotated display
        transform. Defaults to ``dataset.metadata["q_to_r_rotation_ccw_deg"]`` if
        present, else ``0`` (with a warning).
    q_transpose : bool, optional
        Whether to transpose the detector axes before rotating the display transform.
        Defaults to ``dataset.metadata["q_transpose"]`` if present, else ``False``
        (with a warning).
    skip : int, optional
        If given, subsample the scan by this stride (``array[::skip, ::skip]``) when
        building the mean transform -- a fast preview over fewer patterns.
    plot_transform : bool, default=True
        If ``True``, show the original and rotated mean transforms via
        :meth:`plot_transform`.
    cropping_factor : float, default=0.25
        Fraction of the transform width/height shown when ``plot_transform`` is
        ``True``.
    gamma : float, default=0.5
        Exponent for ``mode="gamma"`` (ignored otherwise).
    **plot_kwargs
        Forwarded to :meth:`plot_transform`.

    Returns
    -------
    StrainMapAutocorrelation
        ``self``, with :attr:`transform` and :attr:`transform_rotated` set and
        ``metadata`` updated (``mode``, ``gamma`` for gamma mode,
        ``sampling_real``, ``real_units``, ``q_to_r_rotation_ccw_deg``,
        ``q_transpose``).

    Raises
    ------
    ValueError
        If ``mode`` is not one of the recognized names or aliases.
    """
    import warnings

    mode_in = mode.strip().lower()
    if mode_in in {"linear", "patterson", "paterson", "acf", "autocorrelation"}:
        mode_norm = "linear"
    elif mode_in in {"log", "cepstrum", "cepstral"}:
        mode_norm = "log"
    elif mode_in in {"gamma", "power", "sqrt"}:
        mode_norm = "gamma"
    else:
        raise ValueError(
            "mode must be 'linear', 'log', or 'gamma' (aliases: 'patterson'->'linear', 'cepstrum'/'cepstral'->'log')."
        )

    self.metadata["mode"] = mode_norm
    if mode_norm == "gamma":
        self.metadata["gamma"] = gamma

    def _axis_calibration(axis: int):
        # Sampling and unit label for one detector axis: "A"/"Å" is used as-is,
        # "mrad" is divided by 1000 and by the electron wavelength, and anything
        # else is treated as uncalibrated pixels.
        unit = self.dataset.units[axis]
        if unit in {"A", "Å"}:
            return self.dataset.sampling[axis], unit
        if unit == "mrad":
            wavelength = electron_wavelength_angstrom(self.dataset.metadata["energy"])
            return self.dataset.sampling[axis] / 1000.0 / wavelength, unit
        return 1.0, "pixels"

    qrow_sampling_ang, qrow_unit = _axis_calibration(2)
    qcol_sampling_ang, qcol_unit = _axis_calibration(3)

    # Pixel size of the transform: reciprocal of the detector field of view.
    self.metadata["sampling_real"] = np.array(
        (
            1.0 / (qrow_sampling_ang * self.dataset.shape[2]),
            1.0 / (qcol_sampling_ang * self.dataset.shape[3]),
        ),
        dtype=float,
    )

    if qrow_unit == "pixels" and qcol_unit == "pixels":
        self.metadata["real_units"] = "1/pixels"
    else:
        self.metadata["real_units"] = r"$\mathrm{\AA}$"

    parent_rot = self.dataset.metadata.get("q_to_r_rotation_ccw_deg", None)
    parent_tr = self.dataset.metadata.get("q_transpose", None)

    used_parent = False
    if q_to_r_rotation_ccw_deg is None and parent_rot is not None:
        q_to_r_rotation_ccw_deg = parent_rot
        used_parent = True
    if q_transpose is None and parent_tr is not None:
        q_transpose = parent_tr
        used_parent = True

    if used_parent:
        warnings.warn(
            "StrainMapAutocorrelation.preprocess: using parent Dataset4dstem metadata "
            f"(q_to_r_rotation_ccw_deg={q_to_r_rotation_ccw_deg or 0.0}, "
            f"q_transpose={q_transpose or False}).",
            UserWarning,
        )

    # Fall back to built-in defaults, naming only the values actually defaulted
    # (an explicitly supplied rotation must not be reported as reset to 0).
    defaulted = []
    if q_to_r_rotation_ccw_deg is None:
        q_to_r_rotation_ccw_deg = 0.0
        defaulted.append("q_to_r_rotation_ccw_deg=0.0")
    if q_transpose is None:
        q_transpose = False
        defaulted.append("q_transpose=False")
    if defaulted:
        warnings.warn(
            "StrainMapAutocorrelation.preprocess: setting " + " and ".join(defaulted) + ".",
            UserWarning,
        )

    self.metadata["q_to_r_rotation_ccw_deg"] = q_to_r_rotation_ccw_deg
    self.metadata["q_transpose"] = q_transpose

    def _scale_intensity(dp):
        # Intensity scaling applied before the FFT, selected by the normalized mode.
        if mode_norm == "linear":
            return dp
        if mode_norm == "log":
            return np.log1p(dp)
        if mode_norm == "gamma":
            return np.power(np.clip(dp, 0.0, None), self.metadata["gamma"])
        raise RuntimeError("Unreachable: normalized mode mapping failed.")

    arr = self.dataset.array if skip is None else self.dataset.array[::skip, ::skip]
    n_rows, n_cols = arr.shape[0], arr.shape[1]
    mask = self.mask_diffraction[None, :, :]
    fill = self.mask_diffraction_inv[None, :, :]

    # Sum |FFT| one scan row at a time instead of transforming the full 4D stack.
    magnitude_sum = np.zeros(arr.shape[2:], dtype=float)
    for r in range(n_rows):
        dp_proc = _scale_intensity(arr[r] * mask + fill)
        magnitude_sum += np.abs(np.fft.fft2(dp_proc)).sum(axis=0)
    im = np.fft.fftshift(magnitude_sum / (n_rows * n_cols))

    self.transform = Dataset2d.from_array(
        im,
        origin=(im.shape[0] // 2, im.shape[1] // 2),
        sampling=(1.0, 1.0),
        units=(qrow_unit, qcol_unit),
        signal_units="intensity",
    )

    im_plot = self.transform.array
    if self.metadata["q_transpose"]:
        im_plot = im_plot.T

    im_rotated = rotate_image(
        im_plot,
        self.metadata["q_to_r_rotation_ccw_deg"],
        clockwise=False,
    )
    # Center the origin on the image actually stored: after a transpose of a
    # non-square detector the axes are swapped relative to `im`.
    rot_shape = np.shape(im_rotated)
    self.transform_rotated = Dataset2d.from_array(
        im_rotated,
        origin=(rot_shape[0] // 2, rot_shape[1] // 2),
        sampling=(1.0, 1.0),
        units=(self.metadata["real_units"], self.metadata["real_units"]),
        signal_units="intensity",
    )

    if plot_transform:
        self.plot_transform(cropping_factor=cropping_factor, **plot_kwargs)

    return self
def plot_single_transform(
    self,
    row: int = 0,
    col: int = 0,
    cropping_factor: float = 0.25,
    scalebar_fraction: float = 0.25,
    **plot_kwargs: Any,
):
    """Show the transform of a single scan position with its fitted lattice vectors.

    Recomputes the transform of the masked/filled pattern at ``(row, col)`` using
    the ``mode`` recorded in ``metadata`` (``"linear"`` if none is recorded), and
    overlays the per-position fitted ``u``/``v`` vectors from
    :meth:`fit_lattice_vectors` together with :attr:`mean_img_peaks`. The color
    range is taken from the *mean* transform (its value at the maximum of the
    radially weighted image), as in :meth:`plot_transform`.

    All inputs are validated before any plotting work is done.

    Parameters
    ----------
    row : int, default=0
        Scan row index of the pattern to transform; must be in range (negative
        indices are rejected).
    col : int, default=0
        Scan column index of the pattern to transform; same restriction as ``row``.
    cropping_factor : float, default=0.25
        Fraction of the full transform width/height shown about the center.
    scalebar_fraction : float, default=0.25
        Target scale-bar length as a fraction of the cropped view width.
    **plot_kwargs
        Forwarded to :func:`~quantem.core.visualization.show_2d` (overriding the
        defaults computed here).

    Returns
    -------
    tuple
        ``(fig, ax)`` from :func:`~quantem.core.visualization.show_2d`, matching
        :meth:`plot_transform`, so the figure can be saved or customized.

    Raises
    ------
    ValueError
        If :meth:`preprocess` or :meth:`fit_lattice_vectors` has not been run, if
        ``row``/``col`` is out of bounds, or if the recorded mode is unknown.
    """
    if self.transform is None or self.transform_rotated is None:
        raise ValueError("Run preprocess() first to compute transform images.")
    if self.u_peak_fit is None or self.v_peak_fit is None:
        raise ValueError("Run fit_lattice_vectors() first.")

    n_rows, n_cols = self.dataset.array.shape[0], self.dataset.array.shape[1]
    if not (0 <= row < n_rows and 0 <= col < n_cols):
        raise ValueError("row or column value out of bounds")

    mode = self.metadata.get("mode", "linear").lower()
    if mode not in {"linear", "log", "gamma"}:
        raise ValueError("metadata['mode'] must be 'linear', 'log', or 'gamma'")

    # Transform of this one pattern, scaled the same way as in preprocess().
    dp = self.dataset.array[row, col] * self.mask_diffraction + self.mask_diffraction_inv
    if mode == "log":
        dp = np.log1p(dp)
    elif mode == "gamma":
        dp = np.power(np.clip(dp, 0.0, None), self.metadata["gamma"])
    im = np.fft.fftshift(np.abs(np.fft.fft2(dp)))

    sampling = np.mean(self.metadata["sampling_real"])
    units = self.metadata.get("real_units", r"$\mathrm{\AA}$")

    view_w_px = self.transform.shape[1] * cropping_factor
    sb_len = _nice_length_units(scalebar_fraction * view_w_px * sampling)

    # Color limit from the mean transform: weighting by distance from the center
    # keeps the central DC peak from setting the scale.
    kr = (np.arange(self.transform.shape[0], dtype=float) - self.transform.shape[0] // 2)[:, None]
    kc = (np.arange(self.transform.shape[1], dtype=float) - self.transform.shape[1] // 2)[None, :]
    mean_im = self.transform.array
    weighted = mean_im * np.sqrt(kr * kr + kc * kc)
    peak_index = np.unravel_index(np.nanargmax(weighted), weighted.shape)

    defaults = dict(
        vmin=0.0,
        vmax=mean_im[peak_index],
        title=(f"Row={row} Col={col} Transform"),
        scalebar=ScalebarConfig(
            sampling=sampling,
            units=units,
            length=sb_len if sb_len > 0 else None,
        ),
    )
    defaults.update(plot_kwargs)

    fig, ax = show_2d(im, **defaults)

    _overlay_lattice_vectors(
        ax=ax,
        shape=self.transform.shape,
        u_rc=self.u_peak_fit.array[row, col, :2],
        v_rc=self.v_peak_fit.array[row, col, :2],
        rot_ccw_deg=self.metadata["q_to_r_rotation_ccw_deg"],
        q_transpose=self.metadata["q_transpose"],
        peaks_plot=self.mean_img_peaks,
    )

    for a in _flatten_axes(ax):
        _apply_center_crop_limits(a, self.transform.shape, cropping_factor)

    return fig, ax
These seed the per-position fit in + :meth:`fit_lattice_vectors`. Vectors are ``(row, col)`` offsets from the + transform center, in transform pixels. + + Parameters + ---------- + u : tuple of float or np.ndarray + Initial first lattice vector ``(d_row, d_col)`` relative to the center. + v : tuple of float or np.ndarray + Initial second lattice vector ``(d_row, d_col)`` relative to the center. + define_in_rotated : bool, default=False + If ``True``, ``u``/``v`` are given in the rotated (display) frame and are + converted back to the raw detector frame before fitting. + refine_gaussian : bool, default=True + If ``True``, refine each peak by a 2D isotropic Gaussian fit; otherwise use + the parabolic sub-pixel estimate only. + refine_dft : bool, default=False + If ``True``, additionally refine by DFT upsampling (uses ``upsample``). + refine_all_peaks : bool, default=False + If ``True``, auto-detect all peaks above ``threshold_percentile`` and fit + the basis to the full set by weighted least squares (storing the detected + peaks/weights for reuse), rather than refining only ``u`` and ``v``. + refine_radius_px : float, default=2.0 + Half-width in pixels of the window used for sub-pixel/Gaussian refinement. + upsample : int, default=16 + DFT upsampling factor used when ``refine_dft=True``. + gaussian_maxfev : int, default=100 + Maximum function evaluations for the Gaussian fit. + threshold_percentile : float, default=0.9975 + Intensity percentile (0--1) above which local maxima are kept as peaks when + ``refine_all_peaks=True``. + min_peak_spacing : float, default=0 + Minimum spacing in pixels between accepted peaks when + ``refine_all_peaks=True`` (0 disables the spacing filter). + plot : bool, default=True + If ``True``, show the transform with the refined ``u``/``v`` overlaid. + cropping_factor : float, default=0.25 + Fraction of the transform shown about the center when ``plot=True``. + **plot_kwargs + Forwarded to :meth:`plot_transform`. 
+ + Returns + ------- + StrainMapAutocorrelation + ``self``, with :attr:`u` and :attr:`v` set (and the detected peaks/weights + when ``refine_all_peaks=True``). + """ + if self.transform is None or self.transform_rotated is None: + raise ValueError("Run preprocess() first to compute transform images.") + + u_rc = np.asarray(u, dtype=float).reshape(2) + v_rc = np.asarray(v, dtype=float).reshape(2) + + rot_ccw = self.metadata["q_to_r_rotation_ccw_deg"] + q_transpose = self.metadata["q_transpose"] + + if define_in_rotated: + u_rc = _display_vec_to_raw(u_rc, rotation_ccw_deg=rot_ccw, transpose=q_transpose) + v_rc = _display_vec_to_raw(v_rc, rotation_ccw_deg=rot_ccw, transpose=q_transpose) + + u_fit_abs, v_fit_abs, peaks, weights = _refine_lattice_vectors( + self.transform.array, + u_rc=u_rc, + v_rc=v_rc, + radius_px=refine_radius_px, + refine_gaussian=refine_gaussian, + refine_dft=refine_dft, + refine_all_peaks=refine_all_peaks, + peaks=None, + weights=None, + upsample=upsample, + maxfev=gaussian_maxfev, + threshold_percentile=threshold_percentile, + min_peak_spacing = min_peak_spacing, + ) + + self.u = u_fit_abs[:2] + self.v = v_fit_abs[:2] + if refine_all_peaks: + self.mean_img_peaks = peaks + self.mean_img_weights = weights + + self.metadata["choose_define_in_rotated"] = define_in_rotated + self.metadata["choose_refine_gaussian"] = refine_gaussian + self.metadata["choose_refine_dft"] = refine_dft + self.metadata["choose_refine_all_peaks"] = refine_all_peaks + self.metadata["choose_refine_radius_px"] = refine_radius_px + self.metadata["choose_upsample"] = upsample + self.metadata["choose_gaussian_maxfev"] = gaussian_maxfev + self.metadata["choose_threshold_percentile"] = threshold_percentile + + if plot: + fig, ax = self.plot_transform(cropping_factor=cropping_factor, **plot_kwargs) + _overlay_lattice_vectors( + ax=ax, + shape=self.transform.shape, + u_rc=self.u, + v_rc=self.v, + rot_ccw_deg=rot_ccw, + q_transpose=q_transpose, + peaks_plot=self.mean_img_peaks, + ) + 
def fit_lattice_vectors(
    self,
    refine_gaussian: bool = True,
    refine_dft: bool = False,
    refine_all_peaks: bool = False,
    refine_radius_px: float = 2.0,
    upsample: int = 16,
    gaussian_maxfev: int = 100,
    progressbar: bool = True,
    device: str = "cpu",
) -> "StrainMapAutocorrelation":
    """Fit the lattice vectors at every scan position.

    Each pattern is masked/filled, scaled by ``metadata["mode"]`` and Fourier
    transformed; the basis is then refined starting from the seed
    ``(self.u, self.v)``. Positions go to :attr:`u_array`/:attr:`v_array`
    (``(scan_row, scan_col, 2)``) and the full fit records to
    :attr:`u_peak_fit`/:attr:`v_peak_fit` (``(scan_row, scan_col, 5)``).
    :attr:`mask_weight` is refreshed from the fitted amplitudes at the end.

    Two code paths exist:

    * ``refine_dft=False`` delegates to :meth:`_fit_lattice_vectors_batched`,
      which processes stacks of positions on ``device``. ``upsample`` and
      ``gaussian_maxfev`` are *not* forwarded to that path.
    * ``refine_dft=True`` loops over positions on the CPU with NumPy and calls
      ``_refine_lattice_vectors`` for each; ``device`` is unused there.

    Parameters
    ----------
    refine_gaussian : bool, default=True
        Enable the Gaussian peak fit.
    refine_dft : bool, default=False
        Enable DFT upsampling (selects the per-position path).
    refine_all_peaks : bool, default=False
        Fit the basis to all peaks stored by :meth:`choose_lattice_vector`
        (which must have been run with ``refine_all_peaks=True``).
    refine_radius_px : float, default=2.0
        Half-width in pixels of the refinement window.
    upsample : int, default=16
        DFT upsampling factor; only used when ``refine_dft=True``.
    gaussian_maxfev : int, default=100
        Maximum function evaluations per Gaussian fit; only used when
        ``refine_dft=True``.
    progressbar : bool, default=True
        Show a tqdm progress bar when tqdm is available.
    device : str, default="cpu"
        Torch device string for the batched path (``refine_dft=False``).

    Returns
    -------
    StrainMapAutocorrelation
        ``self``, with the fit arrays and :attr:`mask_weight` set.

    Raises
    ------
    ValueError
        If no seed basis is set, if ``refine_all_peaks=True`` without stored
        peaks/weights, or if ``metadata["mode"]`` is unsupported.
    """
    if self.u is None or self.v is None:
        raise ValueError("Run choose_lattice_vector() first to set initial lattice vectors (self.u, self.v).")
    if refine_all_peaks:
        if self.mean_img_peaks is None or self.mean_img_weights is None:
            raise ValueError("Run choose_lattice_vector() with refine_all_peaks=True to determine which peaks to fit")

    scan_r = self.dataset.shape[0]
    scan_c = self.dataset.shape[1]

    self.u_peak_fit = Dataset3d.from_shape(
        (scan_r, scan_c, 5),
        name="u_peak_fit",
        signal_units="mixed",
    )
    self.v_peak_fit = Dataset3d.from_shape(
        (scan_r, scan_c, 5),
        name="v_peak_fit",
        signal_units="mixed",
    )

    self.u_array = np.zeros((scan_r, scan_c, 2))
    self.v_array = np.zeros((scan_r, scan_c, 2))

    u0 = np.asarray(self.u, dtype=float).reshape(2)
    v0 = np.asarray(self.v, dtype=float).reshape(2)

    if not refine_dft:
        # Fast path: transforms and peak fits are batched across scan positions.
        self._fit_lattice_vectors_batched(
            u0=u0,
            v0=v0,
            refine_gaussian=refine_gaussian,
            refine_radius_px=refine_radius_px,
            device=device,
            progressbar=progressbar,
            refine_all_peaks=refine_all_peaks,
            peaks=self.mean_img_peaks if refine_all_peaks else None,
            weights=self.mean_img_weights if refine_all_peaks else None,
        )
    else:
        # Per-position fallback for DFT upsampling. The intensity scaling is
        # loop-invariant, so it is chosen (and validated) once up front rather
        # than re-dispatched for every scan position.
        mode = self.metadata.get("mode", "linear").lower()
        if mode == "gamma":
            gamma = self.metadata["gamma"]

            def scale(dp):
                return np.power(np.clip(dp, 0.0, None), gamma)

        elif mode == "log":
            scale = np.log1p
        elif mode == "linear":

            def scale(dp):
                return dp

        else:
            raise ValueError("metadata['mode'] must be 'linear', 'log', or 'gamma'")

        positions = np.ndindex(scan_r, scan_c)
        if progressbar:
            # Best effort: a missing or broken tqdm must not prevent the fit.
            try:
                from tqdm.auto import tqdm  # type: ignore

                positions = tqdm(
                    positions, total=scan_r * scan_c, desc="fit_lattice_vectors", leave=True
                )
            except Exception:
                pass

        for r, c in positions:
            dp = self.dataset.array[r, c] * self.mask_diffraction + self.mask_diffraction_inv
            im = np.fft.fftshift(np.abs(np.fft.fft2(scale(dp))))

            u_fit_abs, v_fit_abs, _, _ = _refine_lattice_vectors(
                im,
                u_rc=u0,
                v_rc=v0,
                radius_px=refine_radius_px,
                refine_gaussian=refine_gaussian,
                refine_dft=refine_dft,
                refine_all_peaks=refine_all_peaks,
                peaks=self.mean_img_peaks,
                weights=self.mean_img_weights,
                upsample=upsample,
                maxfev=gaussian_maxfev,
            )

            self.u_peak_fit.array[r, c, :] = u_fit_abs
            self.v_peak_fit.array[r, c, :] = v_fit_abs

            self.u_array[r, c, 0] = u_fit_abs[0]
            self.u_array[r, c, 1] = u_fit_abs[1]
            self.v_array[r, c, 0] = v_fit_abs[0]
            self.v_array[r, c, 1] = v_fit_abs[1]

    self.metadata["fit_refine_gaussian"] = refine_gaussian
    self.metadata["fit_refine_dft"] = refine_dft
    self.metadata["fit_refine_all_peaks"] = refine_all_peaks
    self.metadata["fit_refine_radius_px"] = refine_radius_px
    self.metadata["fit_upsample"] = upsample
    self.metadata["fit_gaussian_maxfev"] = gaussian_maxfev

    # The per-position weight comes straight from the fitted peak amplitudes, so
    # create_mask() is only needed to plot it or to switch estimator.
    self.mask_weight = self._amplitude_mask_weight()

    return self
+ self.mask_weight = self._amplitude_mask_weight() + + return self + + def _fit_lattice_vectors_batched( + self, + *, + u0: NDArray, + v0: NDArray, + refine_gaussian: bool, + refine_radius_px: float, + device: str, + progressbar: bool, + refine_all_peaks: bool = False, + peaks: NDArray | None = None, + weights: NDArray | None = None, + ) -> None: + """Batched torch implementation of the non-DFT :meth:`fit_lattice_vectors` path. + + The transform (mask/fill -> ``mode`` -> ``|FFT|`` -> fftshift) and the + parabolic/isotropic-Gaussian peak refinement are evaluated for a stack of scan + positions at once on ``device``, the analogue of the correlation pipeline's + :func:`~quantem.diffraction.disk_detection.detect_disks_batch`. This reproduces + :func:`_refine_lattice_vectors` (with ``refine_dft=False``) -- to ~1e-6 px on + clean peaks, within a few 0.01 px on noisy non-Gaussian peaks (bounded-trf vs LM + optimizer variance) -- while removing the per-position + ``scipy.optimize.curve_fit`` overhead. Results are + written into :attr:`u_array`/:attr:`v_array` and :attr:`u_peak_fit`/ + :attr:`v_peak_fit`. + + With ``refine_all_peaks=False`` only the two seed vectors ``u0``/``v0`` are + refined per position. With ``refine_all_peaks=True`` every peak in ``peaks`` (the + basis detected by :meth:`choose_lattice_vector`, weighted by ``weights``) is + batch-refined across the stack and the basis is recovered per position by the + same intensity-weighted least-squares fit as the per-position all-peaks branch -- + so that path is batched too rather than looping ``scipy.optimize.curve_fit`` over + ``n_positions x n_peaks``. In both cases the amplitude/width/background stored in + :attr:`u_peak_fit`/:attr:`v_peak_fit` (columns 2-4, the source of the mask weight) + come from the single-peak refinement at ``u0``/``v0`` -- the background-subtracted + Gaussian height. 
The all-peaks least squares only improves the u/v *positions* + (columns 0-1); the raw cepstral value at those positions rides the + central-autocorrelation pedestal and would invert the mask (bright in vacuum), so + it is not used for the weight. + """ + scan_r, scan_c, H, W = self.dataset.array.shape + n_pos = scan_r * scan_c + rcent, ccent = H // 2, W // 2 + + mode = self.metadata.get("mode", "linear").lower() + if mode not in ("linear", "log", "gamma"): + raise ValueError("metadata['mode'] must be 'linear', 'log', or 'gamma'") + gamma = float(self.metadata["gamma"]) if mode == "gamma" else None + + dev = torch.device(device) + mask = torch.as_tensor(self.mask_diffraction, dtype=torch.float64, device=dev) + mask_inv = torch.as_tensor(self.mask_diffraction_inv, dtype=torch.float64, device=dev) + + if refine_all_peaks: + # Precompute the fixed pieces of the per-position all-peaks fit (these do not + # change across scan positions): the detected peak seeds, their normalized + # weights, and the seed basis used to assign integer (h, k) indices. + peaks_arr = np.asarray(peaks, dtype=float).reshape(-1, 2) + w = np.asarray(weights, dtype=float).reshape(-1) + wsum = float(w.sum()) + w_norm = w / wsum if wsum != 0 else np.full(w.shape, 1.0 / max(1, w.size)) + sqrt_w = np.sqrt(w_norm)[:, None] + A_seed = np.column_stack((u0, v0)) # (2, 2) + n_pk = peaks_arr.shape[0] + + # Match the correlation chunking heuristic (see BraggVectors._detect_positions). 
+ batch_size = int(min(1024, max(1, 16_000_000 // (H * W)))) + + starts = range(0, n_pos, batch_size) + if progressbar: + try: + from tqdm.auto import tqdm # type: ignore + + bar = tqdm(total=n_pos, desc="fit_lattice_vectors", leave=True) + except Exception: + bar = None + else: + bar = None + + for start in starts: + stop = min(start + batch_size, n_pos) + idxs = [(idx // scan_c, idx % scan_c) for idx in range(start, stop)] + + dps = torch.stack( + [ + torch.as_tensor( + np.asarray(self.dataset.array[r, c]), dtype=torch.float64, device=dev + ) + for r, c in idxs + ], + dim=0, + ) + dpm = dps * mask + mask_inv + if mode == "linear": + tr = dpm + elif mode == "log": + tr = torch.log1p(dpm) + else: # gamma + tr = dpm.clamp(min=0.0).pow(gamma) + ims = torch.fft.fftshift(torch.fft.fft2(tr).abs(), dim=(-2, -1)) + + if refine_all_peaks: + # Batch-refine each detected peak across the stack (one vectorized LM + # solve per peak), then recover the basis per position with the same + # weighted least squares as _refine_lattice_vectors' all-peaks branch. + pts_all = np.empty((len(idxs), n_pk, 2), dtype=float) + for j in range(n_pk): + rj = _refine_peaks_batched( + ims, peaks_arr[j], radius_px=refine_radius_px, + refine_gaussian=refine_gaussian, + ).cpu().numpy() + pts_all[:, j, :] = rj[:, :2] + # Amplitude/width/background for the mask weight come from the SAME + # single-peak refinement at the u/v seeds as the single-peak branch + # below -- i.e. the background-subtracted Gaussian height. (The all-peaks + # lstsq only improves the u/v *positions*; the raw cepstral value at those + # positions rides the central-autocorrelation pedestal and would invert + # the mask in vacuum.) 
def _amplitude_mask_weight(self) -> np.ndarray:
    """Return a ``[0, 1]`` per-position weight from the fitted u/v amplitudes.

    The signal is the mean of column 2 of :attr:`u_peak_fit` and
    :attr:`v_peak_fit`, min-max normalized over the scan (NaNs are ignored when
    finding the extremes). If the fits are missing, or the extremes are
    non-finite or equal, a uniform array of ones is returned instead.
    """
    scan_shape = (self.dataset.shape[0], self.dataset.shape[1])
    if self.u_peak_fit is None or self.v_peak_fit is None:
        return np.ones(scan_shape)

    amplitude = (self.u_peak_fit.array[:, :, 2] + self.v_peak_fit.array[:, :, 2]) / 2.0
    lowest = np.nanmin(amplitude)
    highest = np.nanmax(amplitude)

    usable = np.isfinite(lowest) and np.isfinite(highest) and highest > lowest
    if not usable:
        return np.ones(scan_shape)
    return (amplitude - lowest) / (highest - lowest)
def create_mask(
    self,
    use_radial_method: bool = False,
    exclusion_radius_fraction: float = 0.1,
    plot: bool = True,
    figsize: tuple[float, float] = (5, 4),
):
    """(Re)compute the per-position weight :attr:`mask_weight`.

    Two estimators are available, both min-max normalized to ``[0, 1]`` with no
    smoothing or contrast windowing:

    * amplitude (default): the mean fitted peak amplitude, via
      :meth:`_amplitude_mask_weight`; requires the per-position fits.
    * radial: the summed pattern intensity outside a central disk whose radius
      is ``exclusion_radius_fraction`` times the detector width.

    Parameters
    ----------
    use_radial_method : bool, default=False
        Select the radial estimator instead of the amplitude estimator.
    exclusion_radius_fraction : float, default=0.1
        Radius of the excluded central disk as a fraction of the size of the
        last detector axis (radial estimator only).
    plot : bool, default=True
        If ``True``, show the resulting weight map in grayscale on ``[0, 1]``.
    figsize : tuple of float, default=(5, 4)
        Figure size in inches for the weight-map plot.

    Returns
    -------
    StrainMapAutocorrelation
        ``self``, with :attr:`mask_weight` set.

    Raises
    ------
    ValueError
        If the dataset is not a ``Dataset4d``/``Dataset4dstem``.
    RuntimeError
        If the amplitude estimator is requested before the fits exist.
    """
    if not isinstance(self.dataset, (Dataset4d, Dataset4dstem)):
        raise ValueError("Dataset must be Dataset4d or Dataset4dstem.")

    n_scan_rows = self.dataset.shape[0]
    n_scan_cols = self.dataset.shape[1]

    if not use_radial_method:
        if self.u_peak_fit is None or self.v_peak_fit is None:
            raise RuntimeError("For intensity-based masking, run fit_lattice_vectors() first.")
        self.mask_weight = self._amplitude_mask_weight()
    else:
        # Geometry lives on the detector (last two) axes of the dataset.
        det_rows = self.dataset.shape[-2]
        det_cols = self.dataset.shape[-1]
        mid_y = det_rows / 2.0
        mid_x = det_cols / 2.0
        yy, xx = np.ogrid[:det_rows, :det_cols]
        distance = np.sqrt((xx - mid_x)**2 + (yy - mid_y)**2)
        outer = distance > exclusion_radius_fraction * det_cols

        signal = np.empty(shape=(n_scan_rows, n_scan_cols))
        for r, c in np.ndindex(n_scan_rows, n_scan_cols):
            signal[r, c] = np.sum(self.dataset.array[r, c][outer])

        # Min-max normalize; a constant or non-finite signal gives uniform weight.
        lowest = np.nanmin(signal)
        highest = np.nanmax(signal)
        if np.isfinite(lowest) and np.isfinite(highest) and highest > lowest:
            self.mask_weight = (signal - lowest) / (highest - lowest)
        else:
            self.mask_weight = np.ones((n_scan_rows, n_scan_cols))

    if plot:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        image = ax.imshow(self.mask_weight, cmap="gray", vmin=0.0, vmax=1.0)
        ax.set_title("Per-position lattice weight (mask_weight)")
        ax.set_xlabel("scan column")
        ax.set_ylabel("scan row")
        fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
        fig.tight_layout()

    return self
def calculate_strain_map(
    self,
    u_ref: np.ndarray | None = None,
    v_ref: np.ndarray | None = None,
    mask: np.ndarray | None = None,
) -> StrainMap:
    """Build a :class:`StrainMap` from the fitted per-position lattice vectors.

    Passes :attr:`u_array`/:attr:`v_array`, the dataset shape, :attr:`real_space`
    and the first-axis sampling and units of the dataset to :class:`StrainMap`.

    Parameters
    ----------
    u_ref : np.ndarray, optional
        Reference for the first lattice vector, forwarded unchanged.
    v_ref : np.ndarray, optional
        Reference for the second lattice vector, forwarded unchanged.
    mask : np.ndarray, optional
        Per-position weighting forwarded to :class:`StrainMap`. When omitted,
        :attr:`mask_weight` is used (populated by :meth:`fit_lattice_vectors`
        and recomputed by :meth:`create_mask`).

    Returns
    -------
    StrainMap
        A strain map initialized from the fitted lattice vectors.

    Raises
    ------
    ValueError
        If the per-position vectors have not been fitted, or the dataset is
        not a ``Dataset4d``/``Dataset4dstem``.
    """
    if self.u_array is None or self.v_array is None:
        raise ValueError("Run fit_lattice_vectors() before calculate_strain_map().")
    if not isinstance(self.dataset, (Dataset4d, Dataset4dstem)):
        raise ValueError("Dataset must be Dataset4d or Dataset4dstem.")

    weighting = self.mask_weight if mask is None else mask

    strain_kwargs = dict(
        u_array=self.u_array,
        v_array=self.v_array,
        ds_shape=tuple(self.dataset.shape),
        real_space=self.real_space,
        u_ref=u_ref,
        v_ref=v_ref,
        mask=weighting,
        ds_sampling=float(self.dataset.sampling[0]),
        ds_units=str(self.dataset.units[0]),
    )
    return StrainMap(**strain_kwargs)
+ """ + if self.u_array is None or self.v_array is None: + raise ValueError("Run fit_lattice_vectors() before calculate_strain_map().") + if not isinstance(self.dataset, (Dataset4d, Dataset4dstem)): + raise ValueError("Dataset must be Dataset4d or Dataset4dstem.") + + if mask is None: + mask = self.mask_weight + + ds_sampling = float(self.dataset.sampling[0]) + ds_units = str(self.dataset.units[0]) + + return StrainMap( + u_array=self.u_array, + v_array=self.v_array, + ds_shape=tuple(self.dataset.shape), + real_space=self.real_space, + u_ref=u_ref, + v_ref=v_ref, + mask=mask, + ds_sampling=ds_sampling, + ds_units=ds_units, + ) + + def plot_lattice_vectors( + self, + subtract_mean: bool = True, + max_shift: float = 1.0, + cmap: str = "PiYG_r", + axsize: tuple[float, float] | None = None, + figsize: tuple[float, float] | None = None, + **imshow_kwargs: Any, + ): + """Plot the four per-position lattice-vector component maps. + + Shows ``u_row``, ``u_col``, ``v_row``, ``v_col`` from :attr:`u_array`/ + :attr:`v_array` as four panels (optionally with the mean subtracted), on a + shared symmetric color range. Positions whose vectors deviate from the + :meth:`choose_lattice_vector` seed by more than ``max_shift`` are masked out, so + bad fits do not blow up the color scale. A quick diagnostic of fit smoothness + before computing strain. + + Parameters + ---------- + subtract_mean : bool, default=True + If ``True``, subtract the (in-range) mean of each component so deviations + from the average lattice are shown. + max_shift : float, default=1.0 + Maximum allowed deviation (pixels) of a vector from the seed before its + position is masked out of the plot and the color-range statistics. + cmap : str, default="PiYG_r" + Diverging colormap; masked positions are drawn black. + axsize : tuple of float, optional + Per-panel size in inches; defaults to ``(4, 4)`` when ``figsize`` is unset. 
+ figsize : tuple of float, optional + Overall figure size in inches; defaults to ``(4 * axsize[0], axsize[1])``. + **imshow_kwargs + Forwarded to :meth:`matplotlib.axes.Axes.imshow`. + + Returns + ------- + tuple + ``(fig, ax)`` with the four-panel figure. + """ + if self.u_array is None or self.v_array is None: + raise ValueError("Run fit_lattice_vectors() first to compute u_array and v_array.") + if self.u is None or self.v is None: + raise ValueError("Run choose_lattice_vector() first to set self.u and self.v.") + + im0 = self.u_array[:, :, 0] + im1 = self.u_array[:, :, 1] + im2 = self.v_array[:, :, 0] + im3 = self.v_array[:, :, 1] + + du0 = im0 - self.u[0] + du1 = im1 - self.u[1] + dv0 = im2 - self.v[0] + dv1 = im3 - self.v[1] + + max_shift2 = max_shift * max_shift + mu = (du0 * du0 + du1 * du1) <= max_shift2 + mv = (dv0 * dv0 + dv1 * dv1) <= max_shift2 + + if subtract_mean: + if np.any(mu): + im0 = im0 - np.mean(im0[mu]) + im1 = im1 - np.mean(im1[mu]) + else: + im0 = im0 - np.mean(im0) + im1 = im1 - np.mean(im1) + + if np.any(mv): + im2 = im2 - np.mean(im2[mv]) + im3 = im3 - np.mean(im3[mv]) + else: + im2 = im2 - np.mean(im2) + im3 = im3 - np.mean(im3) + + vals = [] + if np.any(mu): + vals.append(np.abs(im0[mu])) + vals.append(np.abs(im1[mu])) + if np.any(mv): + vals.append(np.abs(im2[mv])) + vals.append(np.abs(im3[mv])) + + if vals: + vlim = np.max(np.concatenate(vals)) + else: + vlim = np.max(np.abs(np.stack([im0, im1, im2, im3], axis=0))) + + vmin = -vlim + vmax = vlim + + cm = plt.get_cmap(cmap).copy() + cm.set_bad(color="black") + + m0 = np.ma.array(im0, mask=~mu) + m1 = np.ma.array(im1, mask=~mu) + m2 = np.ma.array(im2, mask=~mv) + m3 = np.ma.array(im3, mask=~mv) + + if axsize is None and figsize is None: + axsize = (4.0, 4.0) + if figsize is None: + figsize = (axsize[0] * 4.0, axsize[1]) + + fig, ax = plt.subplots(1, 4, figsize=figsize) + + ax[0].imshow(m0, cmap=cm, vmin=vmin, vmax=vmax, **imshow_kwargs) + ax[1].imshow(m1, cmap=cm, vmin=vmin, 
def _nice_length_units(target: float) -> float:
    """Snap ``target`` to a 1/2/5/10 x 10^n scale-bar length.

    The mantissa of ``target`` is bucketed at 1.5, 3.5 and 7.5. Non-positive or
    non-finite input yields ``0.0``.
    """
    if not np.isfinite(target) or target <= 0:
        return 0.0

    decade = 10.0 ** np.floor(np.log10(target))
    mantissa = target / decade
    for upper_bound, step in ((1.5, 1.0), (3.5, 2.0), (7.5, 5.0)):
        if mantissa < upper_bound:
            return step * decade
    return 10.0 * decade


def _apply_center_crop_limits(ax: Any, shape: tuple[int, int], cropping_factor: float) -> None:
    """Zoom ``ax`` to the central ``cropping_factor`` fraction of a ``shape`` image.

    A factor of 1 or more leaves the axes untouched. The current y-axis
    direction (inverted or not) is kept.

    Raises
    ------
    ValueError
        If ``cropping_factor`` is not positive (this includes NaN).
    """
    if cropping_factor >= 1.0:
        return
    # Anything reaching this point is below 1 (or NaN), so only positivity remains.
    if not (cropping_factor > 0.0):
        raise ValueError("cropping_factor must be in (0, 1].")

    n_rows, n_cols = shape
    mid_row = n_rows // 2
    mid_col = n_cols // 2
    half_height = 0.5 * cropping_factor * n_rows
    half_width = 0.5 * cropping_factor * n_cols

    ax.set_xlim(mid_col - half_width, mid_col + half_width)

    first, second = ax.get_ylim()
    low, high = mid_row - half_height, mid_row + half_height
    if first > second:
        ax.set_ylim(high, low)
    else:
        ax.set_ylim(low, high)


def _flatten_axes(ax: Any) -> list[Any]:
    """Flatten an axes container (ndarray, or nested list/tuple) to a flat list.

    Anything that is not an ndarray, list or tuple is treated as a single axis.
    """
    if isinstance(ax, np.ndarray):
        return list(ax.flat)
    if isinstance(ax, (list, tuple)):
        return [leaf for item in ax for leaf in _flatten_axes(item)]
    return [ax]


def _raw_vec_to_display(vec_rc: NDArray, *, rotation_ccw_deg: float, transpose: bool) -> NDArray:
    """Map a raw ``(row, col)`` vector into the rotated display frame.

    The components are swapped first when ``transpose`` is set, then the vector
    is rotated by ``rotation_ccw_deg`` degrees.
    """
    d_row, d_col = np.asarray(vec_rc, dtype=float).reshape(2)
    if transpose:
        d_row, d_col = d_col, d_row

    angle = np.deg2rad(rotation_ccw_deg)
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)

    rotated = (cos_a * d_row - sin_a * d_col, sin_a * d_row + cos_a * d_col)
    return np.array(rotated, dtype=float)
def _display_vec_to_raw(vec_rc: NDArray, *, rotation_ccw_deg: float, transpose: bool) -> NDArray:
    """Map a rotated-display ``(row, col)`` vector back to the raw frame.

    The rotation by ``rotation_ccw_deg`` degrees is undone first; the two
    components are then swapped when ``transpose`` is set.
    """
    d_row, d_col = np.asarray(vec_rc, dtype=float).reshape(2)

    angle = np.deg2rad(rotation_ccw_deg)
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)

    raw_row = cos_a * d_row + sin_a * d_col
    raw_col = -sin_a * d_row + cos_a * d_col
    if transpose:
        raw_row, raw_col = raw_col, raw_row

    return np.array((raw_row, raw_col), dtype=float)


def _plot_lattice_vectors(ax: Any, center_rc: tuple[float, float], u_rc: NDArray, v_rc: NDArray) -> None:
    """Draw the ``u`` (red) and ``v`` (cyan) lattice vectors from ``center_rc`` on ``ax``.

    Each vector is drawn as a line from the center, a dot at its tip and a text
    label next to the tip. Rows map to the y axis and columns to the x axis.
    """
    r0, c0 = center_rc
    styles = (
        (u_rc, "u", (1.0, 0.0, 0.0)),
        (v_rc, "v", (0.0, 0.7, 1.0)),
    )
    for vec, label, color in styles:
        d_row, d_col = np.asarray(vec, dtype=float).reshape(2)
        tip_x = c0 + d_col
        tip_y = r0 + d_row
        ax.plot([c0, tip_x], [r0, tip_y], linewidth=2.75, color=color)
        ax.plot([tip_x], [tip_y], marker="o", markersize=6.0, color=color)
        ax.text(tip_x, tip_y, f" {label}", color=color, fontsize=18, va="center")


def _plot_peaks(ax: Any, center_rc: tuple[float, float], peaks_plot: NDArray) -> None:
    """Mark every peak in ``peaks_plot`` with a green dot, offset from ``center_rc``."""
    r0, c0 = center_rc
    green = (0.0, 1.0, 0.0)
    for peak in peaks_plot:
        d_row, d_col = np.asarray(peak, dtype=float).reshape(2)
        ax.plot([c0 + d_col], [r0 + d_row], marker="o", markersize=6.0, color=green)
ax.plot([c0 + dc], [r0 + dr], marker="o", markersize=6.0, color=color) + + for p in peaks_plot: + _draw(np.asarray(p, dtype=float).reshape(2), (0.0, 1.0, 0.0)) + +def _overlay_lattice_vectors( + *, + ax: Any, + shape: tuple[int, int], + u_rc: NDArray, + v_rc: NDArray, + rot_ccw_deg: float, + q_transpose: bool, + peaks_plot: NDArray | None = None, +) -> None: + """Overlay lattice vectors on the original (and, if present, rotated) transform axes. + + Draws ``u``/``v`` (and any ``peaks_plot``) on the first axis in raw coordinates, and + on the second axis (if any) in the rotated display frame. + """ + axs = _flatten_axes(ax) + if not axs: + return + + H, W = shape + center_rc = (H // 2, W // 2) + + _plot_lattice_vectors(axs[0], center_rc, u_rc, v_rc) + if peaks_plot is not None: + _plot_peaks(axs[0], center_rc, peaks_plot) + + if len(axs) >= 2: + u_disp = _raw_vec_to_display(u_rc, rotation_ccw_deg=rot_ccw_deg, transpose=q_transpose) + v_disp = _raw_vec_to_display(v_rc, rotation_ccw_deg=rot_ccw_deg, transpose=q_transpose) + _plot_lattice_vectors(axs[1], center_rc, u_disp, v_disp) + + +def _parabolic_vertex_delta(v_m1: float, v_0: float, v_p1: float) -> float: + """Sub-pixel vertex offset (in ``[-1, 1]``) of a parabola through three samples. + + Given values at offsets ``-1, 0, +1``, returns the offset of the parabola's extremum + from the center sample; ``0`` for a degenerate (flat/non-finite) fit. + """ + denom = v_m1 - 2.0 * v_0 + v_p1 + if denom == 0 or not np.isfinite(denom): + return 0.0 + delta = 0.5 * (v_m1 - v_p1) / denom + if not np.isfinite(delta): + return 0.0 + return np.clip(delta, -1.0, 1.0) + + +def _parabolic_peak_rc_amp(im: NDArray, r_guess: float, c_guess: float) -> tuple[float, float, float]: + """3-point parabolic sub-pixel peak near ``(r_guess, c_guess)`` and its amplitude. 
+ + Snaps to the brightest pixel in the 3x3 window around the (absolute-coordinate) + guess, refines row and column independently by a parabolic vertex, and returns + ``(r_sub, c_sub, amp)`` -- the sub-pixel peak position and the image value at the + rounded vertex. Shared by the per-position :func:`_refine_lattice_vectors` and the + batched :meth:`StrainMapAutocorrelation._fit_lattice_vectors_batched` so they agree + exactly. + """ + H, W = im.shape + r0 = int(np.clip(int(np.round(r_guess)), 0, H - 1)) + c0 = int(np.clip(int(np.round(c_guess)), 0, W - 1)) + win = im[max(0, r0 - 1) : min(H, r0 + 2), max(0, c0 - 1) : min(W, c0 + 2)] + if win.size == 0: + return r_guess, c_guess, 0.0 + + ir, ic = np.unravel_index(np.argmax(win), win.shape) + r_peak = max(0, r0 - 1) + ir + c_peak = max(0, c0 - 1) + ic + + if 0 < r_peak < H - 1: + col = im[r_peak - 1 : r_peak + 2, c_peak] + dr = _parabolic_vertex_delta(col[0], col[1], col[2]) + else: + dr = 0.0 + + if 0 < c_peak < W - 1: + row = im[r_peak, c_peak - 1 : c_peak + 2] + dc = _parabolic_vertex_delta(row[0], row[1], row[2]) + else: + dc = 0.0 + + r_sub = r_peak + dr + c_sub = c_peak + dc + r_int = int(np.clip(int(np.round(r_sub)), 0, H - 1)) + c_int = int(np.clip(int(np.round(c_sub)), 0, W - 1)) + return r_sub, c_sub, float(im[r_int, c_int]) + + +def _refine_peak_subpixel( + im: NDArray, + *, + r_guess: float, + c_guess: float, + radius_px: float = 2.0, +) -> tuple[float, float]: + """Refine a peak near ``(r_guess, c_guess)`` to sub-pixel ``(row, col)``. + + Finds the brightest pixel in a ``radius_px`` window, then applies independent + parabolic vertex offsets along each axis. 
+ """ + im = np.asarray(im, dtype=float) + H, W = im.shape + + r0 = int(np.clip(int(np.round(r_guess)), 0, H - 1)) + c0 = int(np.clip(int(np.round(c_guess)), 0, W - 1)) + rad = int(max(0, int(np.ceil(radius_px)))) + + r1 = max(0, r0 - rad) + r2 = min(H, r0 + rad + 1) + c1 = max(0, c0 - rad) + c2 = min(W, c0 + rad + 1) + + win = im[r1:r2, c1:c2] + if win.size == 0: + return r_guess, c_guess + + ir, ic = np.unravel_index(np.argmax(win), win.shape) + r_peak = r1 + ir + c_peak = c1 + ic + + if 0 < r_peak < H - 1: + col = im[r_peak - 1 : r_peak + 2, c_peak] + dr = _parabolic_vertex_delta(col[0], col[1], col[2]) + else: + dr = 0.0 + + if 0 < c_peak < W - 1: + row = im[r_peak, c_peak - 1 : c_peak + 2] + dc = _parabolic_vertex_delta(row[0], row[1], row[2]) + else: + dc = 0.0 + + return r_peak + dr, c_peak + dc + + +def _refine_peak_subpixel_dft( + im: NDArray, + *, + r0: float, + c0: float, + upsample: int, +) -> tuple[float, float]: + """Refine a peak location to subpixel precision via local DFT upsampling. + + Uses a matrix-multiply DFT (the Guizar-Sicairos upsampled cross-correlation + trick) to evaluate the image's Fourier interpolant on a fine grid in a small + neighborhood around the initial estimate, then locates the maximum of that + upsampled patch with a 3-point parabolic vertex refinement. This avoids + interpolating the whole image and is accurate to roughly ``1 / upsample`` of a + pixel. + + Parameters + ---------- + im : NDArray + 2D real image (a transform panel) whose peak is being refined. + r0, c0 : float + Initial peak estimate in pixel coordinates (row, column). If ``r0`` or + ``c0`` is a torch tensor it is converted to a Python float. + upsample : int + DFT upsampling factor. Values ``<= 1`` skip refinement and return the + input estimate unchanged; larger values give finer subpixel resolution. + + Returns + ------- + tuple[float, float] + The refined ``(row, column)`` peak location in pixel coordinates. 
+ """ + if upsample <= 1: + return r0, c0 + + im = np.asarray(im, dtype=float) + if torch.is_tensor(r0): + r0 = float(r0.item()) + if torch.is_tensor(c0): + c0 = float(c0.item()) + F = np.fft.fft2(np.fft.fftshift(im)) + + up = upsample + du = int(np.fix(np.ceil(1.5 * up))) + patch = np.abs(dft_upsample(F, up=up, shift=(r0, c0))) + patch = np.asarray(patch, dtype=float) + + i0, j0 = np.unravel_index(np.argmax(patch), patch.shape) + + if 0 < i0 < patch.shape[0] - 1: + col = patch[i0 - 1 : i0 + 2, j0] + di = _parabolic_vertex_delta(col[0], col[1], col[2]) + else: + di = 0.0 + + if 0 < j0 < patch.shape[1] - 1: + row = patch[i0, j0 - 1 : j0 + 2] + dj = _parabolic_vertex_delta(row[0], row[1], row[2]) + else: + dj = 0.0 + M, N = im.shape + dr = ((float(i0) - du + di)) / up + dc = ((float(j0) - du + dj)) / up + + return r0 + dr, c0 + dc + + +def _refine_lattice_vectors( + im: NDArray, + *, + u_rc: NDArray, + v_rc: NDArray, + radius_px: float = 2.0, + refine_gaussian: bool = True, + refine_dft: bool = False, + refine_all_peaks: bool = False, + peaks: NDArray | None = None, + weights: NDArray | None = None, + upsample: int = 16, + maxfev: int = 100, + threshold_percentile: float = 0.9975, + min_peak_spacing: float = 0, +) -> tuple[NDArray, NDArray, NDArray, NDArray]: + """Refine the two lattice vectors of a transform panel to subpixel precision. + + Starting from integer-pixel guesses for the ``u`` and ``v`` lattice vectors + (expressed as row/column offsets from the panel center), this locates the + corresponding autocorrelation/cepstral peaks and refines them with up to + three successively finer stages: a 3-point parabolic vertex estimate, an + isotropic 2D Gaussian least-squares fit, and DFT upsampling. When + ``refine_all_peaks`` is set, all bright peaks are detected and the lattice + basis is recovered by an intensity-weighted least-squares fit to the full + peak lattice instead of refining only the two seed vectors. 
+ + Parameters + ---------- + im : NDArray + 2D real transform panel (Patterson/cepstral image) to fit. + u_rc, v_rc : NDArray + Length-2 initial lattice vectors as ``(row, column)`` offsets relative to + the panel center. + radius_px : float, optional + Half-width (in pixels) of the fitting window used for the Gaussian fit. + Default 2.0. + refine_gaussian : bool, optional + If True (default), refine each peak with an isotropic 2D Gaussian fit. + refine_dft : bool, optional + If True, follow the Gaussian fit with DFT-upsampled refinement + (requires ``upsample > 1``). Default False. + refine_all_peaks : bool, optional + If True, detect every bright peak and solve a weighted least-squares fit + for the lattice basis rather than refining only ``u_rc`` and ``v_rc``. + Default False. + peaks : NDArray or None, optional + Precomputed peak positions (row/col offsets from center) to use when + ``refine_all_peaks`` is set; if None they are detected automatically. + weights : NDArray or None, optional + Precomputed peak weights paired with ``peaks``; if None they are derived + from peak amplitudes. + upsample : int, optional + DFT upsampling factor for ``refine_dft``. Default 16. + maxfev : int, optional + Maximum function evaluations for the Gaussian ``curve_fit``. Default 100. + threshold_percentile : float, optional + Fractional intensity percentile (0-1) used as the peak-detection + threshold when auto-detecting peaks. Default 0.9975. + min_peak_spacing : float, optional + Minimum allowed spacing (in pixels) between detected peaks; ``0`` (default) + disables the spacing filter. + + Returns + ------- + tuple[NDArray, NDArray, NDArray, NDArray] + ``(u_result, v_result, pts, weights)``. ``u_result`` and ``v_result`` are + length-5 arrays ``(row_offset, col_offset, amplitude, sigma, background)`` + giving the refined lattice vectors relative to the panel center. 
``pts`` + and ``weights`` are the detected peak positions and their normalized + weights when ``refine_all_peaks`` is True, otherwise both are ``None``. + When ``refine_all_peaks`` is True the position comes from the + intensity-weighted all-peaks least squares, but the amplitude/sigma/background + come from the single-peak refinement at the u/v seed -- the + background-subtracted Gaussian height, not the raw cepstral value, which rides + the central pedestal and inverts the mask in vacuum. + """ + from scipy.optimize import curve_fit + + im = np.asarray(im, dtype=float) + if im.ndim != 2: + raise ValueError("im must be 2D.") + + H, W = im.shape + r_center = H // 2 + c_center = W // 2 + + def _fit_gaussian_isotropic( + *, + r0: float, + c0: float, + radius_px: float, + maxfev: int, + ) -> tuple[float, float, float, float, float]: + rad = int(max(1, int(np.ceil(radius_px)))) + r0i = int(np.clip(int(np.round(r0)), 0, H - 1)) + c0i = int(np.clip(int(np.round(c0)), 0, W - 1)) + + r1 = max(0, r0i - rad) + r2 = min(H, r0i + rad + 1) + c1 = max(0, c0i - rad) + c2 = min(W, c0i + rad + 1) + + win = im[r1:r2, c1:c2] + if win.size == 0: + return r0, c0, 0.0, 0.0, 0.0 + + ir, ic = np.unravel_index(np.argmax(win), win.shape) + r_peak = r1 + ir + c_peak = c1 + ic + + bg0 = np.median(win) + amp0 = win[ir, ic] - bg0 + sig0 = max(0.75, radius_px / 2.0) + + rr = np.arange(r1, r2, dtype=float)[:, None] + cc = np.arange(c1, c2, dtype=float)[None, :] + RR = np.broadcast_to(rr, win.shape) + CC = np.broadcast_to(cc, win.shape) + + def _g2( + coords: tuple[NDArray, NDArray], + row: float, + col: float, + amp: float, + sigma: float, + background: float, + ) -> NDArray: + r, c = coords + sig = np.maximum(sigma, 1e-12) + return background + amp * np.exp(-((r - row) ** 2 + (c - col) ** 2) / (2.0 * sig * sig)) + + p0 = (r_peak, c_peak, max(0.0, amp0), sig0, bg0) + + rlo = r1 - 0.5 + rhi = (r2 - 1) + 0.5 + clo = c1 - 0.5 + chi = (c2 - 1) + 0.5 + + bounds_lo = (rlo, clo, 0.0, 0.25, -np.inf) + 
bounds_hi = (rhi, chi, np.inf, radius_px * 4.0, np.inf) + + try: + popt, _ = curve_fit( + _g2, + (RR.ravel(), CC.ravel()), + win.ravel(), + p0=p0, + bounds=(bounds_lo, bounds_hi), + maxfev=maxfev, + ) + row, col, amp, sig, bg = popt + if not (np.isfinite(row) and np.isfinite(col) and np.isfinite(amp) and np.isfinite(sig) and np.isfinite(bg)): + return r0, c0, p0[2], 0.0, 0.0 + return row, col, amp, sig, bg + except Exception: + return r0, c0, p0[2], 0.0, 0.0 + + def _refine_one(vec: NDArray) -> NDArray: + vec = np.asarray(vec, dtype=float).reshape(2) + r_guess = r_center + vec[0] + c_guess = c_center + vec[1] + + r_par, c_par, amp_par = _parabolic_peak_rc_amp(im, r_guess, c_guess) + + if refine_gaussian: + r_fit, c_fit, amp, sig, bg = _fit_gaussian_isotropic( + r0=r_par, + c0=c_par, + radius_px=radius_px, + maxfev=maxfev, + ) + else: + r_fit, c_fit, amp, sig, bg = r_par, c_par, amp_par, 0.0, 0.0 + + if refine_dft and upsample > 1: + r_dft, c_dft = _refine_peak_subpixel_dft( + im, + r0=r_fit, + c0=c_fit, + upsample=upsample, + ) + r_fit, c_fit = r_dft, c_dft + + return np.array((r_fit - r_center, c_fit - c_center, amp, sig, bg), dtype=float) + + def _find_initial_peaks_weights(initial_peaks: NDArray) -> tuple[NDArray, NDArray, NDArray]: + if initial_peaks is None: + threshold = np.percentile(im, threshold_percentile*100) + p = np.logical_and.reduce(( + im>np.roll(im,(-1,-1),axis = (0,1)), + im>np.roll(im,(-1, 0),axis = (0,1)), + im>np.roll(im,(-1, 1),axis = (0,1)), + im>np.roll(im,( 0,-1),axis = (0,1)), + im>np.roll(im,( 0, 1),axis = (0,1)), + im>np.roll(im,( 1,-1),axis = (0,1)), + im>np.roll(im,( 1, 0),axis = (0,1)), + im>np.roll(im,( 1, 1),axis = (0,1)), + im>threshold, + ) ) + initial_peaks = np.argwhere(p) + r0 = (r_center, c_center) + initial_peaks = initial_peaks - r0 + if min_peak_spacing > 0: + intensities = im[initial_peaks[:, 0] + r_center, initial_peaks[:, 1] + c_center] + sorted_indices = np.argsort(intensities)[::-1] + initial_peaks_sorted = 
initial_peaks[sorted_indices] + accepted_peaks = [] + + for peak in initial_peaks_sorted: + if len(accepted_peaks) == 0: + accepted_peaks.append(peak) + else: + distances = np.sqrt(np.sum((np.array(accepted_peaks) - peak)**2, axis=1)) + if np.all(distances >= min_peak_spacing): + accepted_peaks.append(peak) + + initial_peaks = np.array(accepted_peaks) + + peak_sp = np.zeros(initial_peaks.shape) + weights = np.zeros((initial_peaks.shape[0], 1)) + it = 0 + for peak in initial_peaks: + refined_peak = _refine_one(peak) + peak_sp[it, :] = refined_peak[:2] + weights[it] = refined_peak[2] + it += 1 + return peak_sp, weights + + + if refine_all_peaks: + if peaks is None or weights is None: + pts, weights = _find_initial_peaks_weights(None) + else: + pts,_ = _find_initial_peaks_weights(peaks) + + A = np.column_stack((u_rc, v_rc)) + ab0_float = np.linalg.lstsq(A, pts.T, rcond=None)[0] + ab0 = (np.round(ab0_float)).T + + weights /= weights.sum() + A = np.ones((pts.shape[0], 3)) + A[:,:2] = ab0 + pts_weighted = pts * np.sqrt(weights) + A_weighted = A * np.sqrt(weights) + uvr0 = np.linalg.lstsq(A_weighted, pts_weighted, rcond=None)[0] + u_refined = uvr0[0,:] + v_refined = uvr0[1,:] + # Position from the weighted all-peaks lstsq; amplitude/width/background from + # the SAME single-peak refinement at the u/v seeds as the single-peak return + # below -- the background-subtracted Gaussian height (the crystalline order + # parameter), not the raw cepstral value, which rides the central pedestal and + # would invert the mask in vacuum. 
def _refine_peaks_batched(
    ims: torch.Tensor,
    vec: NDArray,
    *,
    radius_px: float,
    refine_gaussian: bool,
    lm_iters: int = 50,
) -> torch.Tensor:
    """Batched, vectorized version of the single-image ``_refine_one`` peak refinement.

    Refines one lattice peak (located near ``center + vec``) for every transform in the
    stack ``ims`` of shape ``(B, H, W)`` at once. The parabolic sub-pixel step matches
    :func:`_refine_lattice_vectors`' ``_parabolic_peak_rc_amp`` exactly; when
    ``refine_gaussian`` is set, the per-position ``scipy.optimize.curve_fit`` of an
    isotropic 2D Gaussian is replaced by a batched Levenberg-Marquardt solve of the
    identical objective and bounds. On clean, well-isolated peaks the two agree to
    ~1e-6 px; on noisy, non-Gaussian peaks (e.g. a sinc-like cepstral peak on the sloped
    tail of a bright neighbor) the bounded-trf and LM optima can differ by a few 0.01 px
    -- amplitude-preserving and well below strain-relevant precision. The batched solve
    removes the dominant per-call overhead, so it runs far faster (and scales onto a GPU
    via the tensor ``device``).

    Parameters
    ----------
    ims : torch.Tensor
        Stack of transformed (fftshifted ``|FFT|``) images, shape ``(B, H, W)``, float.
    vec : NDArray
        Lattice vector ``(row, col)`` relative to the panel center; the peak is sought
        near ``(H // 2 + row, W // 2 + col)``.
    radius_px : float
        Half-width in pixels of the Gaussian-fit window.
    refine_gaussian : bool
        If ``True``, refine with the batched isotropic-Gaussian LM fit; otherwise return
        the parabolic estimate only (``sigma`` and ``background`` set to 0).
    lm_iters : int, default=50
        Number of Levenberg-Marquardt iterations.

    Returns
    -------
    torch.Tensor
        Shape ``(B, 5)``: ``(row_offset, col_offset, amplitude, sigma, background)`` with
        ``row_offset``/``col_offset`` measured relative to the panel center (matching
        :func:`_refine_lattice_vectors`).
    """
    Bn, Hn, Wn = ims.shape
    dev, dt = ims.device, ims.dtype
    rcent, ccent = Hn // 2, Wn // 2
    # Batch index used for per-image advanced indexing (one pixel per image).
    bidx = torch.arange(Bn, device=dev)

    # Clamp the seed into the interior so the 3x3 window below is always in-bounds
    # (mirrors the clamping in _parabolic_peak_rc_amp; matters only for edge peaks,
    # which are interior in normal use but must not crash the batched solve).
    r0 = int(np.clip(round(rcent + float(vec[0])), 1, Hn - 2))
    c0 = int(np.clip(round(ccent + float(vec[1])), 1, Wn - 2))

    # --- 3x3 argmax around the seed (matches _parabolic_peak_rc_amp) ---
    # The window is flattened row-major, so flat index -> (am // 3, am % 3).
    win3 = ims[:, r0 - 1 : r0 + 2, c0 - 1 : c0 + 2].reshape(Bn, -1)
    am = win3.argmax(1)
    r_peak = r0 - 1 + torch.div(am, 3, rounding_mode="floor")
    c_peak = c0 - 1 + (am % 3)

    def _parab(vm1: torch.Tensor, v0_: torch.Tensor, vp1: torch.Tensor) -> torch.Tensor:
        # Element-wise 3-point parabolic vertex offset; degenerate fits (zero
        # curvature term or non-finite result) are forced to 0, then clamped to [-1, 1].
        denom = vm1 - 2.0 * v0_ + vp1
        delta = 0.5 * (vm1 - vp1) / denom
        delta = torch.where(torch.isfinite(delta), delta, torch.zeros_like(delta))
        delta = torch.where(denom == 0, torch.zeros_like(delta), delta)
        return delta.clamp(-1.0, 1.0)

    v0_ = ims[bidx, r_peak, c_peak]
    # Neighbor indices are clamped only to keep the gather in-bounds; for peaks on an
    # image border the resulting offset is discarded by the torch.where that follows.
    dr = _parab(ims[bidx, (r_peak - 1).clamp(0, Hn - 1), c_peak], v0_, ims[bidx, (r_peak + 1).clamp(0, Hn - 1), c_peak])
    dr = torch.where((r_peak > 0) & (r_peak < Hn - 1), dr, torch.zeros_like(dr))
    dc = _parab(ims[bidx, r_peak, (c_peak - 1).clamp(0, Wn - 1)], v0_, ims[bidx, r_peak, (c_peak + 1).clamp(0, Wn - 1)])
    dc = torch.where((c_peak > 0) & (c_peak < Wn - 1), dc, torch.zeros_like(dc))
    r_sub = r_peak.to(dt) + dr
    c_sub = c_peak.to(dt) + dc

    if not refine_gaussian:
        # Parabolic-only result: amplitude is the pixel value at the rounded
        # sub-pixel position; sigma and background are reported as 0.
        ri = r_sub.round().long().clamp(0, Hn - 1)
        ci = c_sub.round().long().clamp(0, Wn - 1)
        amp = ims[bidx, ri, ci]
        zeros = torch.zeros(Bn, device=dev, dtype=dt)
        return torch.stack([r_sub - rcent, c_sub - ccent, amp, zeros, zeros], 1)

    # --- isotropic Gaussian fit over a (2*rad+1)^2 window, batched LM ---
    rad = int(max(1, np.ceil(radius_px)))
    P = 2 * rad + 1
    r0i = r_sub.round().long()
    c0i = c_sub.round().long()
    off = torch.arange(-rad, rad + 1, device=dev)
    rows_idx = (r0i[:, None, None] + off[None, :, None]).expand(Bn, P, P)
    cols_idx = (c0i[:, None, None] + off[None, None, :]).expand(Bn, P, P)
    bexp = bidx[:, None, None].expand(Bn, P, P)
    # NOTE: pixel values are gathered at clamped indices, while the coordinate grids
    # RR/CC below keep the unclamped indices -- a window that overhangs the image
    # therefore repeats edge pixels at out-of-image coordinates.
    win = ims[bexp, rows_idx.clamp(0, Hn - 1), cols_idx.clamp(0, Wn - 1)]
    RR = rows_idx.to(dt).reshape(Bn, -1)
    CC = cols_idx.to(dt).reshape(Bn, -1)
    y = win.reshape(Bn, -1)

    # Initial parameters per image: brightest window pixel as the center, window
    # median as background, peak height above the median (floored at 0) as amplitude.
    am2 = y.argmax(1)
    r_pk2 = (r0i - rad + torch.div(am2, P, rounding_mode="floor")).to(dt)
    c_pk2 = (c0i - rad + (am2 % P)).to(dt)
    bg0 = y.median(1).values
    amp0 = (y.max(1).values - bg0).clamp(min=0.0)
    sig0 = torch.full((Bn,), max(0.75, radius_px / 2.0), device=dev, dtype=dt)
    p = torch.stack([r_pk2, c_pk2, amp0, sig0, bg0], 1)

    # Box bounds: center within half a pixel beyond the window, sigma <= 4 * radius_px.
    rlo = (r0i - rad).to(dt) - 0.5
    rhi = (r0i + rad).to(dt) + 0.5
    clo = (c0i - rad).to(dt) - 0.5
    chi = (c0i + rad).to(dt) + 0.5
    sig_hi = radius_px * 4.0

    # Per-image damping factor and a broadcastable 5x5 identity for the damping term.
    lam = torch.full((Bn,), 1e-2, device=dev, dtype=dt)
    eye5 = torch.eye(5, device=dev, dtype=dt)[None]

    def _sse(p_: torch.Tensor) -> torch.Tensor:
        # Sum of squared residuals of the Gaussian model for parameters p_ (B, 5).
        row, col, amp, sig, bg = [p_[:, i : i + 1] for i in range(5)]
        sig = sig.clamp(min=1e-9)
        E = torch.exp(-((RR - row) ** 2 + (CC - col) ** 2) / (2.0 * sig * sig))
        return ((bg + amp * E - y) ** 2).sum(1)

    for _ in range(lm_iters):
        row, col, amp, sig, bg = [p[:, i : i + 1] for i in range(5)]
        sig = sig.clamp(min=1e-9)
        d_row = RR - row
        d_col = CC - col
        E = torch.exp(-(d_row * d_row + d_col * d_col) / (2.0 * sig * sig))
        res = (bg + amp * E) - y
        # Analytic partial derivatives of the model w.r.t. (row, col, amp, sigma, bg).
        g_row = amp * E * d_row / (sig * sig)
        g_col = amp * E * d_col / (sig * sig)
        g_amp = E
        g_sig = amp * E * ((d_row * d_row + d_col * d_col) / (sig**3))
        g_bg = torch.ones_like(E)
        J = torch.stack([g_row, g_col, g_amp, g_sig, g_bg], 2)  # (B, N, 5)
        JT = J.transpose(1, 2)
        JTJ = JT @ J
        JTr = JT @ res[..., None]
        # Damping scaled by diag(J^T J); the diagonal is floored to keep A invertible.
        diag = torch.diagonal(JTJ, dim1=1, dim2=2).clamp(min=1e-12)
        A = JTJ + lam[:, None, None] * eye5 * diag[:, None, :]
        delta = torch.linalg.solve(A, -JTr)[..., 0]
        pn = p + delta
        # Project the trial step back into the box bounds.
        pn[:, 0] = pn[:, 0].clamp(rlo, rhi)
        pn[:, 1] = pn[:, 1].clamp(clo, chi)
        pn[:, 2] = pn[:, 2].clamp(min=0.0)
        pn[:, 3] = pn[:, 3].clamp(0.25, sig_hi)
        # Accept the step per image only if it lowers the SSE; relax damping on
        # success (x0.5, floor 1e-9) and tighten it on failure (x3, cap 1e6).
        s_new = _sse(pn)
        s_old = res.pow(2).sum(1)
        better = s_new < s_old
        p = torch.where(better[:, None], pn, p)
        lam = torch.where(better, (lam * 0.5).clamp(min=1e-9), (lam * 3.0).clamp(max=1e6))

    row, col, amp, sig, bg = [p[:, i] for i in range(5)]
    # If any fitted parameter is non-finite, fall back to the parabolic position
    # (amp/sig/bg are returned as fitted, without sanitizing).
    ok = torch.isfinite(p).all(1)
    row = torch.where(ok, row, r_sub)
    col = torch.where(ok, col, c_sub)
    return torch.stack([row - rcent, col - ccent, amp, sig, bg], 1)
float] = (0.0, 1.0), + roi: np.ndarray | None = None, + plot_rotation: bool = True, + plot_gvecs: bool = False, + plot_scalebar: bool = False, + cmap_strain: str = "RdBu_r", + cmap_rotation: str = "PiYG", + layout: str = "horizontal", + figsize: tuple[float, float] | None = None, + panel_titles: tuple[str, str, str] | None = None, + **kwargs, +): + """Render strain (e_uu, e_vv, e_uv) and rotation panels. + + Strain arrays are fractional (multiplied by 100 for display); ``rotation`` is + in radians (converted to degrees for display). ``panel_titles`` overrides the + three strain-panel titles (e.g. to label the raw row/col reference frame). + + The mask modulates panel brightness (black where masked out). ``mask_range`` + ``(low, high)`` remaps it linearly before display: mask values ``>= high`` show + full color, ``<= low`` go black, and values between ramp from black to full. + The default ``(0.0, 1.0)`` leaves the already-normalized mask unchanged. + + When ``roi`` (a boolean ``(scan_row, scan_col)`` array) is given, positions + inside it are drawn in color and positions outside it in greyscale (the same + field, desaturated), so a chosen reference region stands out from its context. + """ + if mask is None: + mask = np.ones(ds_shape[:2]) + + # remap the mask brightness onto the [low, high] window: <= low -> black, + # >= high -> full color, linear between. default (0, 1) is a no-op. 
+ low, high = float(mask_range[0]), float(mask_range[1]) + if high > low: + mask = np.clip((np.asarray(mask, dtype=float) - low) / (high - low), 0.0, 1.0) + else: + mask = (np.asarray(mask, dtype=float) >= high).astype(float) + + if cmap_rotation is None: + cmap_rotation = cmap_strain + + if layout not in ["horizontal", "vertical"]: + raise ValueError("layout must be 'horizontal' or 'vertical'") + + ncols = 4 if plot_rotation else 3 + is_horizontal = layout == "horizontal" + + if figsize is None: + figsize = (8, 3) if is_horizontal else (6, 6) + + if is_horizontal: + fig, ax = plt.subplots(1, ncols, figsize=figsize) + else: + fig, ax = plt.subplots(ncols, 1, figsize=figsize) + + cm_strain = plt.get_cmap(cmap_strain).copy() + cm_strain.set_bad(color="black") + cm_rot = plt.get_cmap(cmap_rotation).copy() + cm_rot.set_bad(color="black") + + euu_pct = e_uu * 100 + evv_pct = e_vv * 100 + euv_pct = e_uv * 100 + rot_deg = np.rad2deg(rotation) + + roi_bool = None if roi is None else np.asarray(roi).astype(bool) + gray_cm = plt.get_cmap("gray").copy() + gray_cm.set_bad(color="black") + + def _roi_compose(norm_vals, color_cm): + """Color the field inside the ROI; show it in greyscale outside the ROI.""" + rgb = color_cm(norm_vals)[:, :, :3] + if roi_bool is None: + return rgb + rgb_gray = gray_cm(norm_vals)[:, :, :3] + return np.where(roi_bool[:, :, np.newaxis], rgb, rgb_gray) + + norm_strain = Normalize(vmin=strain_range_percent[0], vmax=strain_range_percent[1]) + euu_disp = _roi_compose(norm_strain(euu_pct), cm_strain) + evv_disp = _roi_compose(norm_strain(evv_pct), cm_strain) + euv_disp = _roi_compose(norm_strain(euv_pct), cm_strain) + + title_fs = 16 + ax[0].imshow(euu_disp * mask[:, :, np.newaxis]) + ax[1].imshow(evv_disp * mask[:, :, np.newaxis]) + ax[2].imshow(euv_disp * mask[:, :, np.newaxis]) + + if panel_titles is None: + panel_titles = ( + r"$\epsilon_{uu}$ $\updownarrow$", + r"$\epsilon_{vv}$ $\leftrightarrow$", + r"$\epsilon_{uv}$ 
$\nwarrow\!\!\!\!\!\!\!\!\!\:\searrow$", + ) + ax[0].set_title(panel_titles[0], fontsize=title_fs) + ax[1].set_title(panel_titles[1], fontsize=title_fs) + ax[2].set_title(panel_titles[2], fontsize=title_fs) + + if plot_rotation: + norm_rot = Normalize(vmin=rotation_range_degrees[0], vmax=rotation_range_degrees[1]) + rot_disp = _roi_compose(norm_rot(rot_deg), cm_rot) + ax[3].imshow(rot_disp * mask[:, :, np.newaxis]) + ax[3].set_title(r"Rotation $\circlearrowleft$", fontsize=title_fs) + + for a in ax: + a.set_xticks([]) + a.set_yticks([]) + a.set_facecolor("black") + a.set_aspect("equal") + + if plot_scalebar: + scalebar_kwargs = {} + for key, value in kwargs.items(): + if key.startswith("scalebar_"): + scalebar_key = key[len("scalebar_"):] + scalebar_kwargs[scalebar_key] = value + + scalebar_defaults = { + "sampling": ds_sampling, + "units": ds_units, + "length": None, + "width_px": 1, + "pad_px": 0.5, + "color": "black", + "loc": "lower left", + "fontsize": 12, + "bold": True, + } + scalebar_defaults.update(scalebar_kwargs) + scalebar_config = ScalebarConfig(**scalebar_defaults) + add_scalebar_to_ax( + ax[0], + array_size=int(ds_shape[0]), + sampling=scalebar_config.sampling, + length_units=scalebar_config.length, + units=scalebar_config.units, + width_px=scalebar_config.width_px, + pad_px=scalebar_config.pad_px, + color=scalebar_config.color, + loc=scalebar_config.loc, + fontsize=scalebar_config.fontsize, + bold=scalebar_config.bold, + ) + + cb_size = 0.02 + cb_pad = 0.02 + + def _finalize_layout(): + # set_aspect("equal") only resizes/recenters each panel at draw time, so + # get_position() before a draw returns stale boxes -- placing the colorbars + # and g-vector compass off those boxes then spills them off the figure. + # Settle the layout cheaply first so every box read below is the real one. 
+ try: + fig.draw_without_rendering() + except AttributeError: # matplotlib < 3.5 + fig.canvas.draw() + + if is_horizontal: + # Reserve a bottom band wide enough for the colorbar + its tick labels and + # title (fontsize 16) and a right band for the rotation-panel gap; widen the + # right band when the g-vector compass is drawn in it. These keep the figure + # usable when saved "as is" (no bbox_inches='tight'). + right = 0.78 if plot_gvecs else 0.93 + fig.subplots_adjust(left=0.04, right=right, top=0.88, bottom=0.24, wspace=0.05) + if plot_rotation: + # nudge the rotation panel right for a visual gap from the strain panels; + # 0.03 stays inside the reserved right band so nothing is clipped. + pos3 = ax[3].get_position() + ax[3].set_position([pos3.x0 + 0.03, pos3.y0, pos3.width, pos3.height]) + _finalize_layout() + + cb_orientation = "horizontal" + b0 = ax[0].get_position() + b2 = ax[2].get_position() + cb_y = b2.y0 - cb_pad - cb_size + strain_cb_pos = [b0.x0, cb_y, b2.x1 - b0.x0, cb_size] + + if plot_rotation: + b3 = ax[3].get_position() + rot_cb_pos = [b3.x0, cb_y, b3.x1 - b3.x0, cb_size] + last_pos = b3 + else: + rot_cb_pos = None + last_pos = b2 + + else: + # Top band for the panel titles, right band for the vertical colorbars + labels. 
+ fig.subplots_adjust(left=0.04, right=0.80, top=0.92, bottom=0.06, hspace=0.15) + _finalize_layout() + + cb_orientation = "vertical" + b0 = ax[0].get_position() + b2 = ax[2].get_position() + strain_cb_pos = [b0.x1 + cb_pad, b2.y0, cb_size, b0.y1 - b2.y0] + + if plot_rotation: + b3 = ax[3].get_position() + rot_cb_pos = [b0.x1 + cb_pad, b3.y0, cb_size, b3.y1 - b3.y0] + last_pos = b3 + else: + rot_cb_pos = None + last_pos = b2 + + cax1 = fig.add_axes(strain_cb_pos) + sm_strain = ScalarMappable(norm=norm_strain, cmap=cm_strain) + cbar1 = fig.colorbar(sm_strain, cax=cax1, orientation=cb_orientation) + cbar1.set_label("Strain", fontsize=title_fs) + cbar1.formatter = FuncFormatter(lambda v, _pos: f"{v:g}%") + cbar1.update_ticks() + cbar1.ax.tick_params(labelsize=12) + + if plot_rotation and rot_cb_pos is not None: + cax2 = fig.add_axes(rot_cb_pos) + sm_rot = ScalarMappable(norm=norm_rot, cmap=cm_rot) + cbar2 = fig.colorbar(sm_rot, cax=cax2, orientation=cb_orientation) + cbar2.set_label("Rotation", fontsize=title_fs) + cbar2.formatter = FuncFormatter(lambda v, _pos: f"{v:g}°") + cbar2.update_ticks() + cbar2.ax.tick_params(labelsize=12) + + if plot_gvecs: + if u_ref is None or v_ref is None: + print("Warning: u_ref and v_ref not found. Call fit_strain() first.") + return fig, ax + + # The compass goes in the reserved margin beside the last panel; clamp its + # right edge to 0.99 so it never spills off the figure when saved "as is". 
def plot_strain_precision_histogram(
    edges: np.ndarray,
    counts: np.ndarray,
    precision: dict[str, float],
    component: str,
    unit: str,
    *,
    figsize: tuple[float, float] = (6.0, 4.0),
):
    """Draw the weighted histogram of local strain deviation for one component.

    Parameters
    ----------
    edges
        Bin edges of the histogram, in display units (``unit``).
    counts
        Bar height for each bin (the mask-weighted, normalized distribution);
        one value per pair of consecutive ``edges``.
    precision
        Weighted-median local deviation per component. ``precision[component]``
        is marked with a vertical line; the keys ``"e_uu"``, ``"e_vv"``,
        ``"e_uv"``, ``"rotation"`` and ``"combined"`` are all read for the
        summary box, so each must be present.
    component
        Key into ``precision`` selecting the plotted component; also used
        verbatim in the x-axis label.
    unit
        Unit string shown in the x-axis label and next to the median value.
    figsize
        Figure size forwarded to ``plt.subplots``.

    Returns
    -------
    tuple
        ``(fig, ax)`` -- the created figure and its single axes.
    """
    fig, ax = plt.subplots(figsize=figsize)
    bin_edges = np.asarray(edges, dtype=float)
    heights = np.asarray(counts, dtype=float)

    # One bar per bin, centred between its two edges and exactly one bin wide.
    ax.bar(
        0.5 * (bin_edges[:-1] + bin_edges[1:]),
        heights,
        width=np.diff(bin_edges),
        align="center",
        color="#4C72B0",
        edgecolor="white",
        linewidth=0.3,
    )

    marker = precision[component]
    if np.isfinite(marker):
        ax.axvline(marker, color="crimson", ls="-", lw=2)

        # The line is labelled inline instead of through a legend, because a
        # legend would land on top of the summary box in the upper-right
        # corner. When the line falls in the right half of the plotted range
        # the label goes to its left, otherwise to its right.
        extent = float(bin_edges[-1] - bin_edges[0])
        in_right_half = extent > 0 and (marker - bin_edges[0]) / extent > 0.5
        if in_right_half:
            text_offset, text_align = (-6, 0), "right"
        else:
            text_offset, text_align = (6, 0), "left"

        ax.annotate(
            f"median = {marker:.3g} {unit}",
            xy=(marker, 0.96),
            xycoords=("data", "axes fraction"),
            xytext=text_offset,
            textcoords="offset points",
            ha=text_align,
            va="top",
            color="crimson",
            fontsize=9,
        )

    ax.set_xlabel(f"{component} deviation ({unit})", fontsize=12)
    ax.set_ylabel("weighted fraction", fontsize=12)
    ax.set_title("Strain precision (median local deviation)", fontsize=13)
    ax.tick_params(labelsize=10)

    # (displayed name, key in ``precision``, unit suffix) for the summary box.
    summary_rows = (
        (r"$\epsilon_{uu}$", "e_uu", "%"),
        (r"$\epsilon_{vv}$", "e_vv", "%"),
        (r"$\epsilon_{uv}$", "e_uv", "%"),
        ("rotation", "rotation", "°"),
        ("combined", "combined", "%"),
    )
    summary_lines = ["median:"]
    for shown_name, key, suffix in summary_rows:
        summary_lines.append(f" {shown_name}: {precision[key]:.3g} {suffix}")

    ax.text(
        0.97,
        0.97,
        "\n".join(summary_lines),
        transform=ax.transAxes,
        ha="right",
        va="top",
        fontsize=9,
        family="monospace",
        bbox=dict(boxstyle="round", fc="white", ec="0.7", alpha=0.9),
    )

    fig.tight_layout()
    return fig, ax
"""Tests for imaging utilities in quantem.core.utils.imaging_utils."""

import numpy as np
import pytest
from scipy.ndimage import gaussian_filter

# Skip this whole module at collection time when torch is not installed.
# Kept ahead of the quantem import on purpose (hence the E402 waiver below).
# NOTE(review): presumably importing imaging_utils itself needs torch -- confirm.
torch = pytest.importorskip("torch")

from quantem.core.utils.imaging_utils import (  # noqa: E402
    cross_correlation_shift,
    cross_correlation_shift_torch,
    weighted_cross_correlation_shift,
)

# (shift applied to the image, upsample factor, absolute tolerance); shared by
# the numpy and torch variants so both are exercised on identical cases.
_SHIFT_CASES = [
    ((5.0, -3.0), 1000, 1e-3),
    ((-7.123, 1.789), 1000, 1e-3),
]


@pytest.fixture
def spot_image():
    """Return a 64x64 image holding one Gaussian spot (sigma 2) with peak 1."""
    im = np.zeros((64, 64), dtype=np.float64)
    im[32, 32] = 1.0
    im = gaussian_filter(im, 2.0)
    im /= np.max(im)
    return im


def _fourier_shift_numpy(im: np.ndarray, shift_rc: tuple[float, float]) -> np.ndarray:
    """Apply a (row, col) phase ramp to the FFT of ``im``; return the real part."""
    dr, dc = shift_rc
    kr = np.fft.fftfreq(im.shape[0])[:, None]
    kc = np.fft.fftfreq(im.shape[1])[None, :]
    phase = np.exp(-2j * np.pi * (kr * dr + kc * dc))
    return np.fft.ifft2(np.fft.fft2(im) * phase).real


def _wrap_shift_rc(shift_rc: tuple[float, float], shape: tuple[int, int]) -> tuple[float, float]:
    """Wrap a (row, col) shift into ``[-M/2, M/2)`` x ``[-N/2, N/2)``."""
    dr, dc = shift_rc
    n_rows, n_cols = shape
    dr = ((dr + n_rows / 2) % n_rows) - n_rows / 2
    dc = ((dc + n_cols / 2) % n_cols) - n_cols / 2
    return float(dr), float(dc)


@pytest.mark.parametrize("shift_true, upsample_factor, atol", _SHIFT_CASES)
def test_cross_correlation_shift_numpy_matches_expected(
    spot_image, shift_true, upsample_factor, atol
):
    im_ref = spot_image
    im = _fourier_shift_numpy(im_ref, shift_true)
    # The measured value is compared against the negated (wrapped) applied shift.
    expected = _wrap_shift_rc((-shift_true[0], -shift_true[1]), im_ref.shape)

    meas = cross_correlation_shift(im_ref, im, upsample_factor=upsample_factor)

    assert meas[0] == pytest.approx(expected[0], abs=atol)
    assert meas[1] == pytest.approx(expected[1], abs=atol)


@pytest.mark.parametrize("shift_true, upsample_factor, atol", _SHIFT_CASES)
def test_cross_correlation_shift_torch_matches_expected(
    spot_image, shift_true, upsample_factor, atol
):
    im_ref = spot_image
    im = _fourier_shift_numpy(im_ref, shift_true)
    expected = _wrap_shift_rc((-shift_true[0], -shift_true[1]), im_ref.shape)

    t_ref = torch.from_numpy(im_ref)
    t_im = torch.from_numpy(im)
    meas = (
        cross_correlation_shift_torch(t_ref, t_im, upsample_factor=upsample_factor)
        .cpu()
        .numpy()
    )

    assert float(meas[0]) == pytest.approx(expected[0], abs=atol)
    assert float(meas[1]) == pytest.approx(expected[1], abs=atol)


@pytest.fixture
def peak_grid_images():
    """Return ``(im_ref, im, weight)`` for the weighted-correlation tests.

    ``im_ref`` and ``im`` are 80x80 images of a 4x4 grid of unit peaks; the
    rows of the grid in ``im`` are offset by +10 relative to ``im_ref``. One
    peak is raised to 3.0 in each image -- at (37, 27) in ``im_ref`` and at
    (27, 27) in ``im`` -- and both images are blurred with sigma 1.
    ``weight`` is a Gaussian (sigma 3) of the wrapped shift distance, largest
    at zero shift.
    """
    im_ref = np.zeros((80, 80), dtype=float)
    im = np.zeros_like(im_ref)

    r_ref = np.array([17, 27, 37, 47], dtype=int)
    r_im = np.array([27, 37, 47, 57], dtype=int)  # shifted +10 rows
    c = np.array([17, 27, 37, 47], dtype=int)

    # Fill every (row, col) combination of the grid in one assignment.
    im_ref[np.ix_(r_ref, c)] = 1.0
    im[np.ix_(r_im, c)] = 1.0

    im_ref[37, 27] = 3.0
    im[27, 27] = 3.0

    im_ref = gaussian_filter(im_ref, 1.0)
    im = gaussian_filter(im, 1.0)

    # Smooth wrapped radial weight centered at 0 shift
    n_rows, n_cols = im_ref.shape
    fr = np.fft.fftfreq(n_rows) * n_rows
    fc = np.fft.fftfreq(n_cols) * n_cols
    dr2 = fr[:, None] ** 2 + fc[None, :] ** 2

    sigma = 3.0
    weight = np.exp(dr2 / (-2.0 * sigma**2))

    return im_ref, im, weight


def test_weighted_cross_correlation_shift_unweighted_prefers_full_overlap(peak_grid_images):
    # The weight map is deliberately not passed in this case.
    im_ref, im, _ = peak_grid_images
    shift = weighted_cross_correlation_shift(im_ref, im, upsample_factor=1000)
    assert np.allclose(shift, (-10.0, 0.0), atol=1e-3)


def test_weighted_cross_correlation_shift_weighted_prefers_near_zero(peak_grid_images):
    im_ref, im, weight = peak_grid_images
    shift = weighted_cross_correlation_shift(
        im_ref, im, weight_real=weight, upsample_factor=1000
    )
    assert np.allclose(shift, (0.0, 0.0), atol=1e-3)