Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d1cfb45
Minor type hint updates
santisoler Oct 16, 2025
3b84e30
Add a `callback` argument to `GaussNewtonConjugateGradient`
santisoler Oct 16, 2025
f7b6509
Add a Log protocol
santisoler Oct 17, 2025
175813f
Keep drafting a few ideas
santisoler Oct 17, 2025
45e1944
Merge branch 'main' into minimizer-log
santisoler Oct 17, 2025
20be4e2
Add a `get_minimizer_callback` method to Log
santisoler Oct 17, 2025
2097582
Extend MinimizerLog and add a __rich__ method to it
santisoler Oct 17, 2025
6ba91cd
Merge branch 'main' into minimizer-log
santisoler Oct 17, 2025
d62cc30
Merge branch 'main' into minimizer-log
santisoler Oct 17, 2025
3488c21
Merge branch 'main' into minimizer-log
santisoler Oct 18, 2025
b87bc4f
Keep allocating the Rich table
santisoler Oct 20, 2025
b81e869
Add option to not log the minimzer
santisoler Oct 20, 2025
ab49b00
Make MinimizerResult a child of dict
santisoler Oct 20, 2025
acd1429
Clunky implementation of nested logs
santisoler Oct 20, 2025
ee856ed
Fix bug in how minimizer was creating the MinimizerResult
santisoler Oct 20, 2025
d6c63e5
Move minimizer logs to the Inversion
santisoler Oct 20, 2025
68c697c
Remove panel from minimizer log tables
santisoler Oct 20, 2025
43d1162
Merge branch 'main' into minimizer-log
santisoler Oct 21, 2025
bd63125
Merge branch 'main' into minimizer-log
santisoler Nov 3, 2025
f566b99
Fix style
santisoler Nov 3, 2025
0af847d
Fix log_minimizers when minimizer is not a Minimizer
santisoler Nov 4, 2025
1e5b92d
Merge branch 'main' into minimizer-log
santisoler Nov 4, 2025
b4ec61b
Merge branch 'main' into minimizer-log
santisoler Mar 24, 2026
426c282
Add `cg_iters` to `MinimizerResult` in `GaussNewtonConjugateGradient`
santisoler Mar 24, 2026
731c788
Hide model columns in MinimizerLog table
santisoler Mar 24, 2026
d2b13ad
Move success/failure comment in Inversion to rich renderable
santisoler Mar 25, 2026
f8dcb5c
Remove commented lines
santisoler Mar 25, 2026
53c8876
Merge branch 'main' into minimizer-log
santisoler Mar 25, 2026
3faf383
Fix wording in docstring
santisoler Mar 25, 2026
ab23180
Add TODO comment
santisoler Mar 25, 2026
b248c4d
Simplify return
santisoler Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
453 changes: 310 additions & 143 deletions notebooks/21_dc-resistivity-inversion-w-beta-cooling.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/inversion_ideas/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .conditions import Condition
from .directive import Directive
from .minimizer import Minimizer
from .minimizer import Minimizer, MinimizerResult
from .objective_function import Combo, Objective, Scaled
from .simulation import Simulation

Expand All @@ -13,6 +13,7 @@
"Condition",
"Directive",
"Minimizer",
"MinimizerResult",
"Objective",
"Scaled",
"Simulation",
Expand Down
38 changes: 36 additions & 2 deletions src/inversion_ideas/base/minimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,50 @@
"""

from abc import ABC, abstractmethod
from collections.abc import Generator
from collections.abc import Callable, Generator

from ..typing import Model
from .objective_function import Objective


class MinimizerResult(dict):
"""
Dictionary to store results of a single minimization iteration.

This class is a child of ``dict``, but allows to access the values through
attributes.

Notes
-----
Inspired in the :class:`scipy.optimize.OptimizeResult`.
"""

def __getattr__(self, name):
try:
return self[name]
except KeyError as e:
raise AttributeError(name) from e

__setattr__ = dict.__setitem__ # type: ignore[assignment]
__delattr__ = dict.__delitem__ # type: ignore[assignment]

def __dir__(self):
return list(self.keys())


class Minimizer(ABC):
"""
Base class to represent minimizers as generators.
"""

@abstractmethod
def __call__(self, objective: Objective, initial_model: Model) -> Generator[Model]:
def __call__(
self,
objective: Objective,
initial_model: Model,
*,
callback: Callable[[MinimizerResult], None] | None = None,
) -> Generator[Model]:
"""
Minimize objective function.

Expand All @@ -25,6 +56,9 @@ def __call__(self, objective: Objective, initial_model: Model) -> Generator[Mode
Objective function to be minimized.
initial_model : (n_params) array
Initial model used to start the minimization.
callback : callable, optional
Callable that gets called after each iteration.
Takes a :class:`inversion_ideas.base.MinimizerResult` as argument.

Returns
-------
Expand Down
118 changes: 106 additions & 12 deletions src/inversion_ideas/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@
"""

import typing
import warnings
from collections.abc import Callable

from rich.console import Group, RenderableType
from rich.live import Live
from rich.spinner import Spinner
from rich.text import Text
from rich.tree import Tree

from inversion_ideas.errors import ConvergenceWarning

from .base import Condition, Directive, Minimizer, Objective
from .inversion_log import InversionLog, InversionLogRich
from .typing import Model
from .inversion_log import InversionLog, InversionLogRich, MinimizerLog
from .typing import Log, Model
from .utils import get_logger


Expand Down Expand Up @@ -40,11 +49,15 @@ class Inversion:
no limit on the total amount of iterations.
cache_models : bool, optional
Whether to cache each model after each iteration.
log : InversionLog or bool, optional
Instance of :class:`InversionLog` to store information about the inversion.
log : Log or bool, optional
Instance of :class:`InversionLog` to store information about the inversion,
or any object that follows the :class:`inversion_ideas.typing.Log` protocol.
If `True`, a default :class:`InversionLog` is going to be used.
If `False`, no log will be assigned to the inversion, and :attr:`Inversion.log`
will be ``None``.
log_minimizers : bool, optional
Whether to log the minimizers or not. Logging minimizers is only possible when
the ``minimizer`` is an instance of :class:`inversion_ideas.base.Minimizer``.
minimizer_kwargs : dict, optional
Extra arguments that will be passed to the ``minimizer`` when called.
"""
Expand All @@ -59,7 +72,8 @@ def __init__(
stopping_criteria: Condition | Callable[[Model], bool],
max_iterations: int | None = None,
cache_models=False,
log: "InversionLog | bool" = True,
log: Log | InversionLog | bool = True,
log_minimizers: bool = True,
minimizer_kwargs: dict | None = None,
):
self.objective_function = objective_function
Expand All @@ -72,6 +86,7 @@ def __init__(
if minimizer_kwargs is None:
minimizer_kwargs = {}
self.minimizer_kwargs = minimizer_kwargs
self._log_minimizers = log_minimizers

# Assign log
if log is False:
Expand All @@ -86,6 +101,11 @@ def __init__(
# Assign model as a copy of the initial model
self.model = initial_model.copy()

# TODO: Support for handling custom callbacks for the minimizer
if log is not None and "callback" in self.minimizer_kwargs:
msg = "Passing a custom callback for the minimizer is not yet supported."
raise NotImplementedError(msg)

def __next__(self):
"""
Run next iteration in the inversion.
Expand All @@ -108,17 +128,21 @@ def __next__(self):

# Check for stopping criteria before trying to run the iteration
if self.stopping_criteria(self.model):
get_logger().info(
get_logger().debug(
"🎉 Inversion successfully finished due to stopping criteria."
)
self._stop_code = 0
raise StopIteration

# Check if maximum number of iterations have been reached
if self.max_iterations is not None and self.counter >= self.max_iterations:
get_logger().info(
msg = (
"⚠️ Inversion finished after reaching maximum number of iterations "
f"({self.max_iterations})."
)
get_logger().debug(msg)
warnings.warn(msg, ConvergenceWarning, stacklevel=2)
self._stop_code = 1
raise StopIteration

# Update stopping criteria (if necessary)
Expand All @@ -137,10 +161,18 @@ def __next__(self):
directive(self.model, self.counter)

# Minimize objective function
# ---------------------------
if isinstance(self.minimizer, Minimizer):
# Keep only the last model of the minimizer iterator
# Generate a new minimizer log for this iteration
minimizer_kwargs = self.minimizer_kwargs.copy()
if self.log is not None and self.log_minimizers:
minimizer_log = MinimizerLog()
self.minimizer_logs.append(minimizer_log)
minimizer_kwargs["callback"] = minimizer_log.update

# Unpack the generator and keep only the last model
*_, model = self.minimizer(
self.objective_function, self.model, **self.minimizer_kwargs
self.objective_function, self.model, **minimizer_kwargs
)
else:
model = self.minimizer(
Expand All @@ -161,8 +193,21 @@ def __next__(self):
return self.model

def __iter__(self):
self._stop_code = None
return self

@property
def stop_code(self) -> int | None:
"""
Code obtained after inversion stopped.

Code 0: the stopping criteria was met.
Code 1: the inversion stopped due to that the maximum number of iterations was
encountered..
Code ``None``: the inversion is still running or hasn't started yet.
"""
return getattr(self, "_stop_code", None)

@property
def counter(self) -> int:
"""
Expand All @@ -185,6 +230,22 @@ def models(self) -> list:
self._models = [self.initial_model]
return self._models

@property
def log_minimizers(self) -> bool:
"""Whether minimizers will be logged or not."""
return self._log_minimizers and isinstance(self.minimizer, Minimizer)

@property
def minimizer_logs(self) -> list[None | MinimizerLog] | None:
"""
Logs of minimizers.
"""
if not self.log_minimizers:
return None
if not hasattr(self, "_minimizer_logs"):
self._minimizer_logs = [None]
return self._minimizer_logs

def run(self, show_log=True) -> Model:
"""
Run the inversion.
Expand All @@ -195,11 +256,44 @@ def run(self, show_log=True) -> Model:
Whether to show the ``log`` (if it's defined) during the inversion.
"""
if show_log and self.log is not None:
if not hasattr(self.log, "live"):
if not isinstance(self.log, RenderableType):
# TODO: Add message
raise NotImplementedError()
with self.log.live() as live:

spinner = Spinner(
name="dots", text="Starting inversion...", style="green", speed=1
)
log = Tree(self.log) if self.log_minimizers else self.log
group = Group(log, spinner)

with Live(group, refresh_per_second=10) as live:
for _ in self:
live.refresh()
if self.log_minimizers:
minimizer_log = self.minimizer_logs[self.counter]
if minimizer_log is not None:
renderable = minimizer_log.__rich__()
renderable.title = (
f"Minimizer log for iteration {self.counter}"
)
log.add(renderable)
spinner.text = f"Running iteration {self.counter + 1}..."

# Finish inversion
group.renderables.pop(-1) # remove spinner
if self.stop_code == 0:
text = Text(
"🎉 Inversion successfully finished due to stopping criteria."
)
elif self.stop_code == 1:
text = Text(
"⚠️ Inversion finished after reaching maximum number of iterations "
f"({self.max_iterations})."
)
else:
text = Text(f"⚠️ Invalid stop code '{self.stop_code}'.")
group.renderables.append(text)
live.refresh()

else:
for _ in self:
pass
Expand Down
Loading
Loading