Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9cabc82
Initial plan
Copilot Mar 18, 2026
7739850
initial commit for model based refinment
cophus Feb 16, 2026
6e90b3a
fitting model working!
cophus Feb 17, 2026
f63af0c
adding docstrings
cophus Feb 17, 2026
24c47b4
fixing CoM behaviour
cophus Feb 20, 2026
984637c
updates
cophus Feb 20, 2026
1a4c121
slight tweaks
cophus Feb 22, 2026
9ff42ca
dataset4dstem.from_file consistent with read_4dstem args
arthurmccray Feb 25, 2026
207ba98
refactor in progress, base and rendering largely done, working not pe…
arthurmccray Feb 26, 2026
9ba0541
adding state saving to ModelDiffraction
arthurmccray Feb 27, 2026
fb8aae4
switching to state_dict saving
arthurmccray Feb 27, 2026
f00f255
first version of FitBase
arthurmccray Feb 27, 2026
22227c4
cleaning up FitBase and ModelDiffraction
arthurmccray Feb 28, 2026
f91142f
moving more stuff to FitBase
arthurmccray Feb 28, 2026
5e08803
reorganizing classes -- no functional change
arthurmccray Feb 28, 2026
c604a9f
splitting off ModelDiffractionVisualizations into separate file
arthurmccray Feb 28, 2026
f48e691
adding hard constraints like force_center for DiskTemplate
arthurmccray Mar 2, 2026
fad203b
adding docstrings and cleaning
arthurmccray Mar 2, 2026
fd34ae5
adding visualizations and overlays
arthurmccray Mar 3, 2026
38cb7bd
adding turning on/off individual components and parameters
arthurmccray Mar 3, 2026
fb4315d
fixing center disk duplication and a couple viz bugs
arthurmccray Mar 3, 2026
cf7f801
adding back parameter bounds
arthurmccray Mar 3, 2026
3f54795
updating colormaps
arthurmccray Mar 3, 2026
f12f53e
more consistent hard constraints of ranges
arthurmccray Mar 3, 2026
fa8ca66
cleaning up naming of Components
arthurmccray Mar 3, 2026
46e83cd
Fix docstring type hint for file_type parameter in from_file
Copilot Mar 18, 2026
8cc3389
Merge pull request #1 to make create-clean-fittings-branch
arthurmccray Mar 18, 2026
4b15483
Merge pull request #193 from arthurmccray/fitting_models_clean
arthurmccray Mar 18, 2026
fbf631c
Merge branch 'dev' into fitting_models_clean
arthurmccray Mar 25, 2026
e7a94ef
updating types for OptimizerParams and SchedulerParams
arthurmccray Mar 25, 2026
220fc01
Merge remote-tracking branch 'upstream/dev' into HEAD
Jun 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/quantem/core/datastructures/dataset4dstem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from os import PathLike
from typing import Any, Self

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -97,15 +98,15 @@ def __init__(
self._virtual_detectors = {} # Store detector information for regeneration

@classmethod
def from_file(cls, file_path: str, file_type: str) -> "Dataset4dstem":
def from_file(cls, file_path: str | PathLike, file_type: str | None = None) -> "Dataset4dstem":
"""
Create a new Dataset4dstem from a file.

Parameters
----------
file_path : str
file_path : str | PathLike
Path to the data file
file_type : str
file_type : str | None
The type of file reader needed. See rosettasciio for supported formats
https://hyperspy.org/rosettasciio/supported_formats/index.html

Expand Down
21 changes: 21 additions & 0 deletions src/quantem/core/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from quantem.core.fitting.background import DCBackground as DCBackground
from quantem.core.fitting.background import GaussianBackground as GaussianBackground
from quantem.core.fitting.base import Component as Component
from quantem.core.fitting.base import Model as Model
from quantem.core.fitting.base import ModelContext as ModelContext
from quantem.core.fitting.base import OriginND as OriginND
from quantem.core.fitting.base import Parameter as Parameter
from quantem.core.fitting.diffraction import DiskTemplate as DiskTemplate
from quantem.core.fitting.diffraction import SyntheticDiskLattice as SyntheticDiskLattice

__all__ = [
"Component",
"DCBackground",
"DiskTemplate",
"GaussianBackground",
"Model",
"ModelContext",
"OriginND",
"Parameter",
"SyntheticDiskLattice",
]
112 changes: 112 additions & 0 deletions src/quantem/core/fitting/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import annotations

from typing import Any, Sequence

import torch
from torch import nn

from quantem.core.fitting.base import OriginND, RenderComponent, RenderContext


class DCBackground(RenderComponent):
def __init__(
self,
*,
intensity: float | int | Sequence[float | int | None] = 0.0,
name: str = "dc_background",
constraint_params: dict[str, Any] | None = None,
):
"""
Build a constant background component.

Notes
-----
Validity is enforced via hard constraints/parameter bounds. Forward
intentionally avoids hard clamps for gradient flow.
"""
super().__init__()
self.name = str(name)
intensity_init, intensity_lo, intensity_hi = self.parse_bounded_init(
intensity, name="intensity"
)
self.intensity_raw = nn.Parameter(torch.tensor(intensity_init, dtype=torch.float32))
bounded_lo = 0.0 if intensity_lo is None else max(float(intensity_lo), 0.0)
self.register_parameter_bounds("intensity_raw", bounded_lo, intensity_hi)
if constraint_params is not None:
self.apply_constraint_params(constraint_params, strict=True)
self._enforce_parameter_bounds()

def forward(self, ctx: RenderContext) -> torch.Tensor:
"""
Render constant background from raw trainable intensity.

Notes
-----
Validity is enforced via hard constraints/parameter bounds, not via
forward-time hard clamps.
"""
inten = self.intensity_raw.to(device=ctx.device, dtype=ctx.dtype)
return torch.ones(ctx.shape, device=ctx.device, dtype=ctx.dtype) * inten


class GaussianBackground(RenderComponent): # TODO this should be N dimensional by default
def __init__(
self,
*,
sigma: float | int | Sequence[float | int | None] = (40.0, 5.0, None),
intensity: float | int | Sequence[float | int | None] = 0.0,
origin: OriginND | None = None,
origin_key: str = "origin",
name: str = "gaussian_background",
constraint_params: dict[str, Any] | None = None,
):
"""
Build a Gaussian background component centered at origin.

Notes
-----
``sigma_raw`` and ``intensity_raw`` validity is enforced via hard
constraints/parameter bounds. Forward intentionally avoids hard clamps
for gradient flow.
"""
super().__init__()
self.name = str(name)
self.origin = origin
self.origin_key = str(origin_key)
sigma_init, sigma_lo, sigma_hi = self.parse_bounded_init(sigma, name="sigma")
intensity_init, intensity_lo, intensity_hi = self.parse_bounded_init(
intensity, name="intensity"
)
self.sigma_raw = nn.Parameter(torch.tensor(sigma_init, dtype=torch.float32))
sigma_bounded_lo = 1e-6 if sigma_lo is None else max(float(sigma_lo), 1e-6)
self.register_parameter_bounds("sigma_raw", sigma_bounded_lo, sigma_hi)
self.intensity_raw = nn.Parameter(torch.tensor(intensity_init, dtype=torch.float32))
intensity_bounded_lo = 0.0 if intensity_lo is None else max(float(intensity_lo), 0.0)
self.register_parameter_bounds("intensity_raw", intensity_bounded_lo, intensity_hi)
if constraint_params is not None:
self.apply_constraint_params(constraint_params, strict=True)
self._enforce_parameter_bounds()

def set_origin(self, origin: OriginND) -> None:
self.origin = origin

def forward(self, ctx: RenderContext) -> torch.Tensor:
"""
Render Gaussian background from raw trainable parameters.

Notes
-----
Validity is enforced via hard constraints/parameter bounds, not via
forward-time hard clamps.
"""
if self.origin is None:
raise RuntimeError("GaussianBackground requires an OriginND instance.")

rr = torch.arange(ctx.shape[0], device=ctx.device, dtype=ctx.dtype)[:, None]
cc = torch.arange(ctx.shape[1], device=ctx.device, dtype=ctx.dtype)[None, :]
r0, c0 = self.origin.coords[0], self.origin.coords[1]

sigma = self.sigma_raw.to(device=ctx.device, dtype=ctx.dtype)
inten = self.intensity_raw.to(device=ctx.device, dtype=ctx.dtype)
r2 = (rr - r0) ** 2 + (cc - c0) ** 2
return inten * torch.exp(-0.5 * r2 / (sigma * sigma))
Loading