@@ -29,6 +29,7 @@ def load_inf_model(
2929 init_threads : bool = True ,
3030 allow_non_final : bool = True ,
3131 inference_augmentation : bool = False ,
32+ use_gaussian = True ,
3233 verbose : bool = False ,
3334 gpu = None ,
3435) -> nnUNetPredictor :
@@ -69,11 +70,11 @@ def load_inf_model(
6970
7071 predictor = nnUNetPredictor (
7172 tile_step_size = step_size ,
72- use_gaussian = True ,
73+ use_gaussian = use_gaussian ,
7374 use_mirroring = inference_augmentation , # <- mirroring augmentation!
7475 perform_everything_on_gpu = ddevice != "cpu" ,
7576 device = device ,
76- verbose = False ,
77+ verbose = verbose ,
7778 verbose_preprocessing = False ,
7879 cuda_id = 0 if gpu is None else gpu ,
7980 )
@@ -116,6 +117,8 @@ def run_inference(
116117 Returns:
117118 Segmentation (NII), Uncertainty Map (NII), Softmax Logits (numpy arr)
118119 """
120+ if logits :
121+ raise NotImplementedError ("logits=True" )
119122 if isinstance (input_nii , str ):
120123 assert input_nii .endswith (".nii.gz" ), f"input file is not a .nii.gz! Got { input_nii } "
121124 input_nii = NII .load (input_nii , seg = False )
@@ -124,52 +127,80 @@ def run_inference(
124127 if isinstance (input_nii , NII ):
125128 input_nii = [input_nii ]
126129 orientation = input_nii [0 ].orientation
127- zoom = input_nii [0 ].zoom
128130
129131 img_arrs = []
130132 # Prepare for nnUNet behavior
131133 for i in input_nii :
132134 if reorient_PIR :
133135 i .reorient_ ()
134- sitk_nii = sitk_utils .nii_to_sitk (i )
135- nii_img_converted = sitk .GetArrayFromImage (sitk_nii ).astype (np .float16 )[np .newaxis , :]
136- # nii_img_converted = i.get_array()
137- # nii_img_converted = np.pad(nii_img_converted, pad_width=pad_size, mode="edge")
138- # nii_img_converted = np.swapaxes(nii_img_converted, 0, 2)[np.newaxis, :].astype(np.float16)
136+ a = i .get_array ().astype (np .float16 )
137+ nii_img_converted = np .transpose (a , axes = a .ndim - 1 - np .arange (a .ndim ))[np .newaxis , :]
139138 img_arrs .append (nii_img_converted )
140139 try :
141140 img = np .vstack (img_arrs )
142141 except Exception :
143- print ([a .shape for a in img_arrs ])
142+ print ("could not stack images; shapes=" , [a .shape for a in img_arrs ])
144143 raise
145144 props = {
146- "sitk_stuff" : {
147- # this saves the sitk geometry information. This part is NOT used by nnU-Net!
148- "spacing" : sitk_nii .GetSpacing (), # type:ignore
149- "origin" : sitk_nii .GetOrigin (), # type:ignore
150- "direction" : sitk_nii .GetDirection (), # type:ignore
151- },
152- "spacing" : zoom [::- 1 ], # PIR
145+ "spacing" : i .zoom [::- 1 ], # PIR
153146 }
154- out = predictor .predict_single_npy_array (img , props , save_or_return_probabilities = logits )
155- if logits :
156- segmentation , _ , softmax_logits = out # type: ignore
157- softmax_logits = np .expand_dims (softmax_logits .astype (np .float16 ), 0 )
158- # softmax_logits = np.swapaxes(softmax_logits, 0, 3)
159- # PRI label
160- # softmax_logits = np.swapaxes(softmax_logits, 1, 2)
161- else :
162- segmentation , _ = out # type: ignore
163- softmax_logits = None
164- itk_image = sitk .GetImageFromArray (segmentation .astype (np .uint8 ))
165- itk_image .SetSpacing (sitk_nii .GetSpacing ())
166- itk_image .SetOrigin (sitk_nii .GetOrigin ())
167- itk_image .SetDirection (sitk_nii .GetDirection ())
168- seg_nii = sitk_utils .sitk_to_nii (itk_image , True )
169-
170- # segmentation = np.swapaxes(segmentation, 0, 2)
171- # assert isinstance(segmentation, np.ndarray)
172- # seg_nii = NII(nib.ni1.Nifti1Image(segmentation, affine=affine, header=header), seg=True)
173-
147+ out = predictor .predict_single_npy_array (img , props , logits = False , rescale = False )
148+ segmentation : np .ndarray = out # type: ignore
149+ softmax_logits = None
150+ segmentation = np .transpose (segmentation .astype (np .uint8 ), axes = segmentation .ndim - 1 - np .arange (segmentation .ndim ))
151+ assert segmentation .shape == input_nii [0 ].shape
152+ seg_nii = input_nii [0 ].set_array (segmentation .astype (np .uint8 ), seg = True )
174153 seg_nii .reorient_ (orientation , verbose = False )
175154 return seg_nii , None , softmax_logits
155+
156+
157+ # def predict_single_npy_array(predictor: nnUNetPredictor, img, props, logits, rescale):
158+ # return predictor.predict_single_npy_array(img, props, save_or_return_probabilities=logits, rescale=rescale)
159+ #
160+ # def fun(x):
161+ # return predictor.predict_single_npy_array(x, props, save_or_return_probabilities=False)[0][None]
162+ #
163+ # p = 750 if max_v % 700 > max_v % 800 else 800
164+ # patch_size = tuple(p for _ in img.shape)
165+ # overlap = min(50, max(predictor.configuration_manager.patch_size) // 2)
166+ # print(f"image very large ({img.shape}>1000); use sliding window", f"{patch_size=}", predictor.configuration_manager.patch_size)
167+ #
168+ # return sliding_nd_slices(img, patch_size=patch_size, overlap=overlap, fun=fun)[0], None
169+
170+
171+ def sliding_nd_slices (arr : np .ndarray , patch_size , overlap , fun ):
172+ print ("sliding window" )
173+ step = tuple (p - overlap for p in patch_size )
174+ half_overlap = overlap // 2
175+ shape = arr .shape
176+
177+ # Compute number of steps in each dimension
178+ ranges = [range (0 , max (s , 1 ), st ) if s != 1 else [0 ] for s , st in zip (shape , step )]
179+ result = np .zeros_like (arr )
180+ for starts in np .ndindex (* [len (r ) for r in ranges ]):
181+ # Compute actual start and end indices for this patch
182+ idx_start = [ranges [dim ][i ] for dim , i in enumerate (starts )]
183+ idx_start2 = [ranges [dim ][i ] + half_overlap if ranges [dim ][i ] != 0 else 0 for dim , i in enumerate (starts )]
184+ idx_start3 = [half_overlap if ranges [dim ][i ] != 0 else 0 for dim , i in enumerate (starts )]
185+ idx_end = [min (start + size , shape [dim ]) for start , size , dim in zip (idx_start , patch_size , range (len (shape )))]
186+ idx_end2 = [
187+ (start + size - half_overlap if start + size < shape [dim ] else shape [dim ])
188+ for start , size , dim in zip (idx_start , patch_size , range (len (shape )))
189+ ]
190+ idx_end3 = [(- half_overlap if a != shape [dim ] else None ) for a , dim in zip (idx_end2 , range (len (shape )))]
191+
192+ slices = tuple (slice (s , e ) for s , e in zip (idx_start , idx_end ))
193+ slices2 = tuple (slice (s , e ) for s , e in zip (idx_start2 , idx_end2 ))
194+ slices3 = tuple (slice (s , e ) for s , e in zip (idx_start3 , idx_end3 ))
195+ print ("sliding window" , slices )
196+ patch = arr [slices ]
197+ patch = fun (patch )
198+ result [slices2 ] = patch [slices3 ]
199+ return result
200+
201+
202+ # if __name__ == "__main__":
203+ # np.zeros((1, 2243, 472, 622))
204+ # x = sliding_nd_slices()
205+ # max_v=2243, (1, 2243, 472, 622)
206+ # image very large ((1, 2243, 472, 622)>1000); use sliding window patch_size=<generator object predict_single_npy_array.<locals>.<genexpr> at 0x7f89c12dfac0> [160, 192, 192]
0 commit comments