Skip to content

Conversation

@michaelxu01
Copy link
Contributor

@michaelxu01 michaelxu01 commented Jan 15, 2026

  • Added line and affine correction
  • added output plot option for scan_update to compare current iteration to initial scan (need to add .h5 load/save for initial scan positions)
  • refactored scan from array to object
  • updated h5 to include scan positions and any metadata in a dictionary

This has been rebased on latest probe aberration merge. Still need to update and test conventional solvers

Example: affine-only updates after 800 iterations for an experiment:
image

@hexane360
Copy link
Owner

Sorry, I think this needs another rebase. LMK if you need help with it. Also, can you fix the tests? It should be as simple as changing the scan access. We should probably also take a look at the conventional engines before merge, shouldn't be too difficult.

# cast = to_real_dtype(sim.object.data.dtype)
xp = get_array_module(sim.scan.data)
update = xp.zeros_like(sim.scan.data, dtype=sim.scan.data.dtype)
for kind, weight in self.constraints.items():
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this deterministic? It seems like it could apply the updates in arbitrary order, we may want to add sorted() if it matters

Comment on lines +41 to +45
self.constraints[kind] = getattr(props, kind)
self.total_weight = sum(self.constraints.values())
# self.weight: t.Optional[float]
# self.type: t.Optional[str]
logger.info(f"Initialized scan constraint with kinds {list(self.constraints.keys())} and weights {list(self.constraints.values())} with total weight {self.total_weight:.4f}")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the meaning of weights here? Does it make more sense to add weight for each constraint as a relaxation parameter?

Comment on lines +68 to +78
if kind == 'affine':
update += scan_affine(sim.scan.data, state.previous) * weight
# sims.object.data = ## affine deform object
if kind == 'line' and state.row_bins is not None:
update += scan_line(sim.scan.data, state.previous, state.row_bins) * weight
if kind == 'hpf':
pass
if kind == 'lpf':
pass
if kind == 'default':
update += scan_default(sim.scan.data, state.previous) * weight
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be better as a dictionary of update functions. It could also be a match, but I don't remember what our minimum supported version is.


## double check that if position update is off (scan == prev_step), this doesn't break anything
# @partial(jit, donate_argnames=('pos',), cupy_fuse=True)
def scan_default(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be private functions (e.g. _scan_default())

Comment on lines +101 to +109
left = xp.matmul(pos_prev.T, disp_update)
right = xp.matmul(pos_prev.T, pos_prev)
A = xp.matmul(xp.linalg.inv(right), left)
constraint = xp.matmul(pos_prev, A)
#remove the middle shift, keep the middle unchanged
center_ones = xp.ones((1, 1), pos.dtype)
# center[0, 0:2] = xp.average(pos, axis = 0)
center = xp.concatenate([xp.average(pos, axis = 0, keepdims=True), center_ones], axis=1, dtype=pos.dtype)
center_shift = xp.matmul(center, A)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xp.matmul(x, y) should be replacable as x @ y


cost = xp.sum(xp.abs(sim.object.data - 1.0))
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe unnecessary for this PR, but we could probably make n_pos() a method of sim.scan to avoid this repetition

Comment on lines +127 to 130
## FIXME: the scan normalization here - happens before dropnans and scan data flattening, but may alter shape and therefore rows/cols? why is this needed
def _normalize_scan_shape(
patterns: Patterns, state: ReconsState
) -> t.Tuple[Patterns, ReconsState]:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This basically is for scan and patterns from heterogeneous sources, i.e. one from previous state. It's a bit of a hack, but should be possible to adapt.


## FIXME: output to Tuple? importance of array number types

@t.overload
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love this API, not sure what would be better. Maybe output ScanState directly? Or maybe better, keep make_raster_scan clean and include the metadata in the hook only

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants