Skip to content

Commit 1af18a0

Browse files
committed
improve nnunet inference for large files
1 parent c42cb25 commit 1af18a0

File tree

4 files changed

+330
-218
lines changed

4 files changed

+330
-218
lines changed

TPTBox/segmentation/TotalVibeSeg/inference_nnunet.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,13 @@ def run_inference_on_file(
5757
padd: int = 0,
5858
ddevice: Literal["cpu", "cuda", "mps"] = "cuda",
5959
_model_path=None,
60+
step_size=0.5,
6061
) -> tuple[Image_Reference, np.ndarray | None]:
6162
global model_path # noqa: PLW0603
6263
if _model_path is not None:
6364
_model_path = Path(_model_path)
6465
model_path = _model_path / "nnUNet_results"
65-
assert model_path.exists(), _model_path
66+
assert model_path.exists(), model_path
6667
if out_file is not None and Path(out_file).exists() and not override:
6768
return out_file, None
6869

@@ -78,7 +79,7 @@ def run_inference_on_file(
7879
nnunet_path = next(next(iter(model_path.glob(f"*{idx:03}*"))).glob("*__nnUNetPlans*"))
7980
folds = [int(f.name.split("fold_")[-1]) for f in nnunet_path.glob("fold*")]
8081
if max_folds is not None:
81-
folds = folds[:max_folds]
82+
folds = max_folds if isinstance(max_folds, list) else folds[:max_folds]
8283

8384
# if idx in _unets:
8485
# nnunet = _unets[idx]
@@ -90,6 +91,7 @@ def run_inference_on_file(
9091
use_folds=tuple(folds) if len(folds) != 5 else None,
9192
gpu=gpu,
9293
ddevice=ddevice,
94+
step_size=step_size,
9395
)
9496

9597
# _unets[idx] = nnunet
@@ -118,6 +120,7 @@ def run_inference_on_file(
118120
if zoom is not None:
119121
input_nii = [i.rescale_(zoom, mode=mode) for i in input_nii]
120122
input_nii = [squash_so_it_fits_in_float16(i) for i in input_nii]
123+
121124
if crop:
122125
crop = input_nii[0].compute_crop(minimum=20)
123126
input_nii = [i.apply_crop(crop) for i in input_nii]
@@ -158,8 +161,13 @@ def run_total_seg(
158161
fill_holes=False,
159162
crop=False,
160163
max_folds: int | None = None,
164+
_model_path=None,
165+
step_size=0.5,
161166
**_kargs,
162167
):
168+
global model_path
169+
if _model_path is not None:
170+
model_path = _model_path
163171
if dataset_id is None:
164172
for idx in known_idx:
165173
download_weights(idx)
@@ -210,4 +218,5 @@ def run_total_seg(
210218
fill_holes=fill_holes,
211219
crop=crop,
212220
max_folds=max_folds,
221+
step_size=step_size,
213222
)[0]

TPTBox/segmentation/nnUnet_utils/export_prediction.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,35 @@ def convert_predicted_logits_to_segmentation_with_correct_shape(
2020
properties_dict: dict,
2121
return_probabilities: bool = False,
2222
num_threads_torch: int = 8,
23+
rescale=True,
2324
):
2425
old_threads = torch.get_num_threads()
2526
torch.set_num_threads(num_threads_torch)
26-
27-
# resample to original shape
28-
current_spacing = (
29-
configuration_manager.spacing
30-
if len(configuration_manager.spacing) == len(properties_dict["shape_after_cropping_and_before_resampling"])
31-
else [properties_dict["spacing"][0], *configuration_manager.spacing]
32-
)
33-
predicted_logits = configuration_manager.resampling_fn_probabilities(
34-
predicted_logits,
35-
properties_dict["shape_after_cropping_and_before_resampling"],
36-
current_spacing,
37-
properties_dict["spacing"],
38-
)
27+
if rescale:
28+
# resample to original shape
29+
current_spacing = (
30+
configuration_manager.spacing
31+
if len(configuration_manager.spacing) == len(properties_dict["shape_after_cropping_and_before_resampling"])
32+
else [properties_dict["spacing"][0], *configuration_manager.spacing]
33+
)
34+
predicted_logits = configuration_manager.resampling_fn_probabilities(
35+
predicted_logits,
36+
properties_dict["shape_after_cropping_and_before_resampling"],
37+
current_spacing,
38+
properties_dict["spacing"],
39+
)
3940
# return value of resampling_fn_probabilities can be ndarray or Tensor but that doesnt matter because
4041
# apply_inference_nonlin will covnert to torch
41-
predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
42-
del predicted_logits
43-
segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)
44-
42+
# And this is stupid because convert_probabilities_to_segmentation transforms it back to a numpy...
43+
if label_manager.has_regions:
44+
# Softmax does not change when we use argmax in the next step
45+
predicted_logits = label_manager.apply_inference_nonlin(predicted_logits)
4546
# segmentation may be torch.Tensor but we continue with numpy
46-
if isinstance(segmentation, torch.Tensor):
47-
segmentation = segmentation.cpu().numpy()
47+
if isinstance(predicted_logits, torch.Tensor):
48+
predicted_logits = predicted_logits.cpu().numpy()
4849

50+
segmentation = label_manager.convert_probabilities_to_segmentation(predicted_logits)
51+
del predicted_logits
4952
# put segmentation in bbox (revert cropping)
5053
segmentation_reverted_cropping = np.zeros(
5154
properties_dict["shape_before_cropping"],

TPTBox/segmentation/nnUnet_utils/inference_api.py

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def load_inf_model(
2929
init_threads: bool = True,
3030
allow_non_final: bool = True,
3131
inference_augmentation: bool = False,
32+
use_gaussian=True,
3233
verbose: bool = False,
3334
gpu=None,
3435
) -> nnUNetPredictor:
@@ -69,11 +70,11 @@ def load_inf_model(
6970

7071
predictor = nnUNetPredictor(
7172
tile_step_size=step_size,
72-
use_gaussian=True,
73+
use_gaussian=use_gaussian,
7374
use_mirroring=inference_augmentation, # <- mirroring augmentation!
7475
perform_everything_on_gpu=ddevice != "cpu",
7576
device=device,
76-
verbose=False,
77+
verbose=verbose,
7778
verbose_preprocessing=False,
7879
cuda_id=0 if gpu is None else gpu,
7980
)
@@ -116,6 +117,8 @@ def run_inference(
116117
Returns:
117118
Segmentation (NII), Uncertainty Map (NII), Softmax Logits (numpy arr)
118119
"""
120+
if logits:
121+
raise NotImplementedError("logits=True")
119122
if isinstance(input_nii, str):
120123
assert input_nii.endswith(".nii.gz"), f"input file is not a .nii.gz! Got {input_nii}"
121124
input_nii = NII.load(input_nii, seg=False)
@@ -124,52 +127,80 @@ def run_inference(
124127
if isinstance(input_nii, NII):
125128
input_nii = [input_nii]
126129
orientation = input_nii[0].orientation
127-
zoom = input_nii[0].zoom
128130

129131
img_arrs = []
130132
# Prepare for nnUNet behavior
131133
for i in input_nii:
132134
if reorient_PIR:
133135
i.reorient_()
134-
sitk_nii = sitk_utils.nii_to_sitk(i)
135-
nii_img_converted = sitk.GetArrayFromImage(sitk_nii).astype(np.float16)[np.newaxis, :]
136-
# nii_img_converted = i.get_array()
137-
# nii_img_converted = np.pad(nii_img_converted, pad_width=pad_size, mode="edge")
138-
# nii_img_converted = np.swapaxes(nii_img_converted, 0, 2)[np.newaxis, :].astype(np.float16)
136+
a = i.get_array().astype(np.float16)
137+
nii_img_converted = np.transpose(a, axes=a.ndim - 1 - np.arange(a.ndim))[np.newaxis, :]
139138
img_arrs.append(nii_img_converted)
140139
try:
141140
img = np.vstack(img_arrs)
142141
except Exception:
143-
print([a.shape for a in img_arrs])
142+
print("could not stack images; shapes=", [a.shape for a in img_arrs])
144143
raise
145144
props = {
146-
"sitk_stuff": {
147-
# this saves the sitk geometry information. This part is NOT used by nnU-Net!
148-
"spacing": sitk_nii.GetSpacing(), # type:ignore
149-
"origin": sitk_nii.GetOrigin(), # type:ignore
150-
"direction": sitk_nii.GetDirection(), # type:ignore
151-
},
152-
"spacing": zoom[::-1], # PIR
145+
"spacing": i.zoom[::-1], # PIR
153146
}
154-
out = predictor.predict_single_npy_array(img, props, save_or_return_probabilities=logits)
155-
if logits:
156-
segmentation, _, softmax_logits = out # type: ignore
157-
softmax_logits = np.expand_dims(softmax_logits.astype(np.float16), 0)
158-
# softmax_logits = np.swapaxes(softmax_logits, 0, 3)
159-
# PRI label
160-
# softmax_logits = np.swapaxes(softmax_logits, 1, 2)
161-
else:
162-
segmentation, _ = out # type: ignore
163-
softmax_logits = None
164-
itk_image = sitk.GetImageFromArray(segmentation.astype(np.uint8))
165-
itk_image.SetSpacing(sitk_nii.GetSpacing())
166-
itk_image.SetOrigin(sitk_nii.GetOrigin())
167-
itk_image.SetDirection(sitk_nii.GetDirection())
168-
seg_nii = sitk_utils.sitk_to_nii(itk_image, True)
169-
170-
# segmentation = np.swapaxes(segmentation, 0, 2)
171-
# assert isinstance(segmentation, np.ndarray)
172-
# seg_nii = NII(nib.ni1.Nifti1Image(segmentation, affine=affine, header=header), seg=True)
173-
147+
out = predictor.predict_single_npy_array(img, props, logits=False, rescale=False)
148+
segmentation: np.ndarray = out # type: ignore
149+
softmax_logits = None
150+
segmentation = np.transpose(segmentation.astype(np.uint8), axes=segmentation.ndim - 1 - np.arange(segmentation.ndim))
151+
assert segmentation.shape == input_nii[0].shape
152+
seg_nii = input_nii[0].set_array(segmentation.astype(np.uint8), seg=True)
174153
seg_nii.reorient_(orientation, verbose=False)
175154
return seg_nii, None, softmax_logits
155+
156+
157+
# def predict_single_npy_array(predictor: nnUNetPredictor, img, props, logits, rescale):
158+
# return predictor.predict_single_npy_array(img, props, save_or_return_probabilities=logits, rescale=rescale)
159+
#
160+
# def fun(x):
161+
# return predictor.predict_single_npy_array(x, props, save_or_return_probabilities=False)[0][None]
162+
#
163+
# p = 750 if max_v % 700 > max_v % 800 else 800
164+
# patch_size = tuple(p for _ in img.shape)
165+
# overlap = min(50, max(predictor.configuration_manager.patch_size) // 2)
166+
# print(f"image very large ({img.shape}>1000); use sliding window", f"{patch_size=}", predictor.configuration_manager.patch_size)
167+
#
168+
# return sliding_nd_slices(img, patch_size=patch_size, overlap=overlap, fun=fun)[0], None
169+
170+
171+
def sliding_nd_slices(arr: np.ndarray, patch_size, overlap, fun):
    """Apply ``fun`` to overlapping N-D patches of ``arr`` and stitch the results.

    The array is tiled with patches of ``patch_size`` that overlap neighbouring
    patches by ``overlap`` elements along every axis. ``fun`` is called on each
    patch and must return an array of the same shape; half of the overlap on
    each interior border is discarded when writing back, so adjacent patches
    meet without seams (exact tiling for even ``overlap``).

    Args:
        arr: input array (any number of dimensions).
        patch_size: per-dimension patch extent; must exceed ``overlap`` for
            every dimension whose size is not 1.
        overlap: number of elements shared by neighbouring patches (per axis).
        fun: callable mapping a patch to an equally shaped array.

    Returns:
        Array of the same shape (and dtype) as ``arr`` assembled from the
        processed patches.

    Raises:
        ValueError: if ``patch_size[d] <= overlap`` for a dimension of size > 1.
            Without this guard a negative step makes ``range`` empty, so
            ``np.ndindex`` iterates zero patches and an all-zero result would
            be returned silently.
    """
    print("sliding window")
    step = tuple(p - overlap for p in patch_size)
    half_overlap = overlap // 2
    shape = arr.shape

    # Guard against silent all-zero output (empty ranges) / ValueError from range (step 0).
    for s, st, p in zip(shape, step, patch_size):
        if s != 1 and st <= 0:
            raise ValueError(f"patch_size entry {p} must be larger than overlap {overlap}")

    # Patch start offsets per dimension; singleton dimensions get a single
    # patch starting at 0 (e.g. the channel axis of a (1, X, Y, Z) volume).
    ranges = [range(0, max(s, 1), st) if s != 1 else [0] for s, st in zip(shape, step)]
    result = np.zeros_like(arr)
    for starts in np.ndindex(*[len(r) for r in ranges]):
        # Region read from `arr` for this patch (clamped to the array bounds).
        idx_start = [ranges[dim][i] for dim, i in enumerate(starts)]
        idx_end = [min(start + size, shape[dim]) for dim, (start, size) in enumerate(zip(idx_start, patch_size))]
        # Region written into `result`: trim half the overlap on every side
        # that has a neighbouring patch (interior borders only).
        idx_start2 = [start + half_overlap if start != 0 else 0 for start in idx_start]
        idx_end2 = [
            (start + size - half_overlap if start + size < shape[dim] else shape[dim])
            for dim, (start, size) in enumerate(zip(idx_start, patch_size))
        ]
        # Matching region inside the patch returned by `fun` (None == "to the end").
        idx_start3 = [half_overlap if start != 0 else 0 for start in idx_start]
        idx_end3 = [(-half_overlap if e2 != shape[dim] else None) for dim, e2 in enumerate(idx_end2)]

        slices = tuple(slice(s, e) for s, e in zip(idx_start, idx_end))
        slices2 = tuple(slice(s, e) for s, e in zip(idx_start2, idx_end2))
        slices3 = tuple(slice(s, e) for s, e in zip(idx_start3, idx_end3))
        print("sliding window", slices)  # progress feedback for very large volumes
        patch = fun(arr[slices])
        result[slices2] = patch[slices3]
    return result
200+
201+
202+
# if __name__ == "__main__":
203+
# np.zeros((1, 2243, 472, 622))
204+
# x = sliding_nd_slices()
205+
# max_v=2243, (1, 2243, 472, 622)
206+
# image very large ((1, 2243, 472, 622)>1000); use sliding window patch_size=<generator object predict_single_npy_array.<locals>.<genexpr> at 0x7f89c12dfac0> [160, 192, 192]

0 commit comments

Comments (0)