Skip to content

Commit 2362eec

Browse files
authored
Workflow API Improvements (#109)
- Add block option to reconstruct call in workflow api - Add export training data method to workflow api - Fix position offset in pty-chi results - Add more debug logging
1 parent 3880050 commit 2362eec

File tree

15 files changed

+108
-55
lines changed

15 files changed

+108
-55
lines changed

src/ptychodus/api/object.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ def layer_spacing_m(self) -> Sequence[float]:
203203
def get_total_thickness_m(self) -> float:
204204
return sum(self._layer_spacing_m)
205205

206+
def __repr__(self) -> str:
207+
return f'{self._array.dtype}{self._array.shape}'
208+
206209

207210
class ObjectFileReader(ABC):
208211
@abstractmethod

src/ptychodus/api/parametric.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ def get_value(self) -> float:
211211

212212
return value
213213

214+
def set_value(self, value: float, *, notify: bool = True) -> None:
215+
super().set_value(float(value), notify=notify)
216+
214217
def set_value_from_string(self, value: str) -> None:
215218
self.set_value(float(value))
216219

src/ptychodus/api/probe.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,19 @@ def __init__(
169169
self._opr_weights = None
170170
elif numpy.issubdtype(opr_weights.dtype, numpy.floating):
171171
if opr_weights.ndim == 2:
172-
if opr_weights.shape[1] == self._array.shape[0]:
172+
num_weights_actual = opr_weights.shape[1]
173+
num_weights_expected = self._array.shape[0]
174+
175+
if num_weights_actual == num_weights_expected:
173176
self._opr_weights = opr_weights
174177
else:
175-
raise ValueError('opr_weights do not match the number of coherent probe modes')
178+
raise ValueError(
179+
(
180+
'inconsistent number of opr weights!'
181+
f' actual={num_weights_actual}'
182+
f' expected={num_weights_expected}'
183+
)
184+
)
176185
else:
177186
raise ValueError('opr_weights must be 2-dimensional ndarray')
178187
else:
@@ -267,6 +276,9 @@ def get_geometry(self) -> ProbeGeometry:
267276
def __len__(self) -> int:
268277
return 1 if self._opr_weights is None else self._opr_weights.shape[0]
269278

279+
def __repr__(self) -> str:
280+
return f'{self._array.dtype}{self._array.shape}'
281+
270282

271283
class ProbeFileReader(ABC):
272284
@abstractmethod

src/ptychodus/api/scan.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ def __len__(self) -> int:
8787
def nbytes(self) -> int:
8888
return self._coordinates_m.nbytes
8989

90+
def __repr__(self) -> str:
91+
return f'{self._coordinates_m.dtype}{self._coordinates_m.shape}'
92+
9093

9194
class ScanPointParseError(Exception):
9295
"""raised when the scan file cannot be parsed"""

src/ptychodus/api/workflow.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111

1212
class WorkflowProductAPI(ABC):
13+
@abstractmethod
14+
def get_product_index(self) -> int:
15+
pass
16+
1317
@abstractmethod
1418
def open_scan(self, file_path: Path, *, file_type: str | None = None) -> None:
1519
pass
@@ -41,7 +45,7 @@ def build_object(
4145
pass
4246

4347
@abstractmethod
44-
def reconstruct_local(self) -> WorkflowProductAPI:
48+
def reconstruct_local(self, block: bool = False) -> WorkflowProductAPI:
4549
pass
4650

4751
@abstractmethod
@@ -52,6 +56,10 @@ def reconstruct_remote(self) -> None:
5256
def save_product(self, file_path: Path, *, file_type: str | None = None) -> None:
5357
pass
5458

59+
@abstractmethod
60+
def export_training_data(self, file_path: Path) -> None:
61+
pass
62+
5563

5664
class WorkflowAPI(ABC):
5765
@abstractmethod
@@ -76,6 +84,11 @@ def export_assembled_patterns(self, file_path: Path) -> None:
7684
"""export assembled patterns"""
7785
pass
7886

87+
@abstractmethod
88+
def get_product(self, product_index: int) -> WorkflowProductAPI:
89+
"""returns a product by index"""
90+
pass
91+
7992
@abstractmethod
8093
def open_product(self, file_path: Path, *, file_type: str | None = None) -> WorkflowProductAPI:
8194
"""opens product from file"""
@@ -103,6 +116,10 @@ def save_settings(
103116
) -> None:
104117
pass
105118

119+
@abstractmethod
120+
def set_reconstructor(self, reconstructor_name: str) -> None:
121+
pass
122+
106123

107124
class FileBasedWorkflow(ABC):
108125
@property

src/ptychodus/model/core.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ def batch_mode_execute(
231231
input_path: Path,
232232
output_path: Path,
233233
*,
234-
product_file_type: str = 'NPZ',
235234
fluorescence_input_file_path: Path | None = None,
236235
fluorescence_output_file_path: Path | None = None,
237236
) -> int:
@@ -241,39 +240,20 @@ def batch_mode_execute(
241240
self.reconstructor.reconstructor_api.save_model(output_path)
242241
return output.result
243242

244-
input_product_index = self.product.product_api.open_product(
245-
input_path, file_type=product_file_type
246-
)
247-
248-
if input_product_index < 0:
249-
logger.error(f'Failed to open product "{input_path}"!')
250-
return -1
251-
252243
if action.lower() == 'reconstruct':
253-
logger.info('Reconstructing...')
254-
output_product_index = self.reconstructor.reconstructor_api.reconstruct(
255-
input_product_index
256-
)
257-
self.reconstructor.reconstructor_api.process_results(block=True)
258-
logger.info('Reconstruction complete.')
259-
260-
self.product.product_api.save_product(
261-
output_product_index, output_path, file_type=product_file_type
262-
)
244+
input_product_api = self.workflow.workflow_api.open_product(input_path)
245+
output_product_api = input_product_api.reconstruct_local(block=True)
246+
output_product_api.save_product(output_path)
263247

264248
if (
265249
fluorescence_input_file_path is not None
266250
and fluorescence_output_file_path is not None
267251
):
268252
self.fluorescence_core.enhance_fluorescence(
269-
output_product_index,
253+
output_product_api.get_product_index(),
270254
fluorescence_input_file_path,
271255
fluorescence_output_file_path,
272256
)
273-
elif action.lower() == 'prepare_training_data':
274-
self.reconstructor.reconstructor_api.export_training_data(
275-
output_path, input_product_index
276-
)
277257
else:
278258
logger.error(f'Unknown batch mode action "{action}"!')
279259
return -1

src/ptychodus/model/patterns/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,13 @@ def start(self) -> None:
4747
pass
4848

4949
def stop(self) -> None:
50-
self.dataset.finish_loading(block=False)
50+
self.patterns_api.finish_assembling_diffraction_patterns(block=False)
5151

5252
def _update(self, observable: Observable) -> None:
5353
if observable is self._reinit_observable:
5454
self.patterns_api.open_patterns(
5555
file_path=self.pattern_settings.file_path.get_value(),
5656
file_type=self.pattern_settings.file_type.get_value(),
5757
)
58-
self.dataset.start_loading()
58+
self.patterns_api.start_assembling_diffraction_patterns()
59+
self.patterns_api.finish_assembling_diffraction_patterns(block=True)

src/ptychodus/model/product/item_factory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def create_from_settings(self) -> ProductRepositoryItem:
136136
probe_item = self._probe_item_factory.create_from_settings(geometry)
137137
object_item = self._object_item_factory.create_from_settings(geometry)
138138

139-
return ProductRepositoryItem(
139+
item = ProductRepositoryItem(
140140
parent=self._repository,
141141
metadata_item=metadata_item,
142142
scan_item=scan_item,
@@ -146,3 +146,5 @@ def create_from_settings(self) -> ProductRepositoryItem:
146146
validator=ProductValidator(self._dataset, scan_item, geometry, probe_item, object_item),
147147
costs=list(),
148148
)
149+
logger.debug(f'Created product from settings: {item.get_name()}')
150+
return item

src/ptychodus/model/ptychi/helper.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -697,16 +697,14 @@ def create_product(
697697

698698
corrected_scan_points: list[ScanPoint] = list()
699699
object_geometry = object_in.get_geometry()
700-
rx_px = object_geometry.width_px / 2
701-
ry_px = object_geometry.height_px / 2
702700

703701
for uncorrected_point, pos_x_px, pos_y_px in zip(
704702
product.positions, position_x_px, position_y_px
705703
):
706704
object_point = ObjectPoint(
707705
index=uncorrected_point.index,
708-
position_x_px=float(pos_x_px + rx_px),
709-
position_y_px=float(pos_y_px + ry_px),
706+
position_x_px=float(pos_x_px),
707+
position_y_px=float(pos_y_px),
710708
)
711709
scan_point = object_geometry.map_object_point_to_scan_point(object_point)
712710
corrected_scan_points.append(scan_point)

src/ptychodus/model/reconstructor/api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,7 @@ def train(self, data_path: Path) -> TrainOutput:
155155
logger.warning('Reconstructor is not trainable!')
156156

157157
return result
158+
159+
def set_reconstructor(self, name: str) -> str:
160+
self._reconstructor_chooser.set_current_plugin(name)
161+
return self._reconstructor_chooser.get_current_plugin().simple_name

0 commit comments

Comments
 (0)