Skip to content

Commit 198a79a

Browse files
committed
ENH: Add rtkspectralrooster python application
1 parent 9df7de2 commit 198a79a

File tree

7 files changed

+371
-0
lines changed

7 files changed

+371
-0
lines changed
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
#!/usr/bin/env python
2+
import argparse
3+
import itk
4+
import numpy as np
5+
from itk import RTK as rtk
6+
7+
8+
def build_parser():
9+
parser = rtk.RTKArgumentParser(
10+
description=(
11+
"Reconstructs a 3D + material vector volume from a vector projection stack,"
12+
" alternating between conjugate gradient optimization and regularization."
13+
)
14+
)
15+
16+
parser.add_argument(
17+
"--geometry", "-g", help="XML geometry file name", required=True
18+
)
19+
parser.add_argument("--output", "-o", help="Output file name", required=True)
20+
parser.add_argument(
21+
"--niter", "-n", help="Number of main loop iterations", type=int, default=5
22+
)
23+
parser.add_argument(
24+
"--cgiter",
25+
help="Number of conjugate gradient nested iterations",
26+
type=int,
27+
default=4,
28+
)
29+
parser.add_argument(
30+
"--cudacg",
31+
help="Perform conjugate gradient calculations on GPU",
32+
action="store_true",
33+
)
34+
parser.add_argument("--input", "-i", help="Input volume (materials) file")
35+
parser.add_argument(
36+
"--projection", "-p", help="Vector projections file", required=True
37+
)
38+
parser.add_argument(
39+
"--nodisplaced",
40+
help="Disable the displaced detector filter",
41+
action="store_true",
42+
)
43+
44+
# Regularization
45+
parser.add_argument(
46+
"--nopositivity", help="Do not enforce positivity", action="store_true"
47+
)
48+
parser.add_argument("--tviter", help="TV iterations", type=int, default=10)
49+
parser.add_argument(
50+
"--gamma_space", help="Spatial TV regularization parameter", type=float
51+
)
52+
parser.add_argument("--threshold", help="Wavelets soft threshold", type=float)
53+
parser.add_argument("--order", help="Wavelets order", type=int, default=5)
54+
parser.add_argument("--levels", help="Wavelets levels", type=int, default=3)
55+
parser.add_argument(
56+
"--gamma_time", help="Temporal TV regularization parameter", type=float
57+
)
58+
parser.add_argument(
59+
"--lambda_time", help="Temporal L0 regularization parameter", type=float
60+
)
61+
parser.add_argument("--l0iter", help="L0 iterations", type=int, default=5)
62+
parser.add_argument("--gamma_tnv", help="TNV regularization parameter", type=float)
63+
64+
# Projector choices
65+
rtk.add_rtkprojectors_group(parser)
66+
rtk.add_rtk3Doutputimage_group(parser)
67+
return parser
68+
69+
70+
def process(args_info: argparse.Namespace):
71+
PixelValueType = itk.F
72+
Dimension = 3
73+
74+
DecomposedProjectionType = itk.VectorImage[PixelValueType, Dimension]
75+
MaterialsVolumeType = itk.VectorImage[PixelValueType, Dimension]
76+
VolumeSeriesType = itk.Image[PixelValueType, Dimension + 1]
77+
ProjectionStackType = itk.Image[PixelValueType, Dimension]
78+
79+
# Projections reader
80+
if args_info.verbose:
81+
print(f"Reading decomposed projections from {args_info.projection}...")
82+
proj_reader = itk.ImageFileReader[DecomposedProjectionType].New()
83+
proj_reader.SetFileName(args_info.projection)
84+
proj_reader.Update()
85+
decomposedProjection = proj_reader.GetOutput()
86+
87+
NumberOfMaterials = decomposedProjection.GetVectorLength()
88+
89+
# Geometry
90+
if args_info.verbose:
91+
print(f"Reading geometry from {args_info.geometry}...")
92+
geometry = rtk.read_geometry(args_info.geometry)
93+
94+
# Create 4D input. Fill it either with an existing materials volume
95+
# read from a file or a blank image
96+
vecVol2VolSeries = rtk.VectorImageToImageFilter[
97+
MaterialsVolumeType, VolumeSeriesType
98+
].New()
99+
100+
if args_info.input is not None:
101+
if args_info.like is not None:
102+
print("WARNING: --like ignored since --input was given")
103+
if args_info.verbose:
104+
print(f"Reading input volume {args_info.input}...")
105+
reference = itk.imread(args_info.input)
106+
vecVol2VolSeries.SetInput(reference)
107+
vecVol2VolSeries.Update()
108+
input = vecVol2VolSeries.GetOutput()
109+
elif args_info.like is not None:
110+
if args_info.verbose:
111+
print(f"Reading reference volume {args_info.like} to infer geometry...")
112+
reference = itk.imread(args_info.like)
113+
vecVol2VolSeries.SetInput(reference)
114+
vecVol2VolSeries.UpdateOutputInformation()
115+
constantImageSource = rtk.ConstantImageSource[VolumeSeriesType].New()
116+
constantImageSource.SetInformationFromImage(vecVol2VolSeries.GetOutput())
117+
constantImageSource.Update()
118+
input = constantImageSource.GetOutput()
119+
else:
120+
# Create new empty volume
121+
constantImageSource = rtk.ConstantImageSource[VolumeSeriesType].New()
122+
123+
imageSize = itk.Size[4]()
124+
imageSize.Fill(int(args_info.size[0]))
125+
for i in range(min(len(args_info.size), Dimension)):
126+
imageSize[i] = int(args_info.size[i])
127+
imageSize[Dimension] = NumberOfMaterials
128+
129+
imageSpacing = itk.Vector[itk.D, 4]()
130+
imageSpacing.Fill(float(args_info.spacing[0]))
131+
for i in range(min(len(args_info.spacing), Dimension)):
132+
imageSpacing[i] = float(args_info.spacing[i])
133+
imageSpacing[Dimension] = 1.0
134+
135+
imageOrigin = itk.Point[itk.D, 4]()
136+
for i in range(Dimension):
137+
imageOrigin[i] = imageSpacing[i] * (int(imageSize[i]) - 1) * -0.5
138+
if args_info.origin is not None:
139+
for i in range(min(len(args_info.origin), Dimension)):
140+
imageOrigin[i] = float(args_info.origin[i])
141+
imageOrigin[Dimension] = 0.0
142+
143+
imageDirection = itk.Matrix[itk.D, 4, 4]()
144+
imageDirection.SetIdentity()
145+
if args_info.direction is not None:
146+
for i in range(Dimension):
147+
for j in range(Dimension):
148+
imageDirection[i][j] = float(args_info.direction[i * Dimension + j])
149+
150+
constantImageSource.SetOrigin(imageOrigin)
151+
constantImageSource.SetSpacing(imageSpacing)
152+
constantImageSource.SetDirection(imageDirection)
153+
constantImageSource.SetSize(imageSize)
154+
constantImageSource.SetConstant(0.0)
155+
constantImageSource.Update()
156+
input = constantImageSource.GetOutput()
157+
158+
# Duplicate geometry and transform the N M-vector projections into N*M scalar projections
159+
# Each material will occupy one frame of the 4D reconstruction, therefore all projections
160+
# of one material need to have the same phase.
161+
# Note : the 4D CG filter is optimized when projections with identical phases are packed together
162+
163+
# Geometry
164+
initialNumberOfProjections = int(
165+
decomposedProjection.GetLargestPossibleRegion().GetSize()[Dimension - 1]
166+
)
167+
for material in range(1, NumberOfMaterials):
168+
for proj in range(initialNumberOfProjections):
169+
geometry.AddProjectionInRadians(
170+
geometry.GetSourceToIsocenterDistances()[proj],
171+
geometry.GetSourceToDetectorDistances()[proj],
172+
geometry.GetGantryAngles()[proj],
173+
geometry.GetProjectionOffsetsX()[proj],
174+
geometry.GetProjectionOffsetsY()[proj],
175+
geometry.GetOutOfPlaneAngles()[proj],
176+
geometry.GetInPlaneAngles()[proj],
177+
geometry.GetSourceOffsetsX()[proj],
178+
geometry.GetSourceOffsetsY()[proj],
179+
)
180+
geometry.SetCollimationOfLastProjection(
181+
geometry.GetCollimationUInf()[proj],
182+
geometry.GetCollimationUSup()[proj],
183+
geometry.GetCollimationVInf()[proj],
184+
geometry.GetCollimationVSup()[proj],
185+
)
186+
187+
# Signal
188+
fakeSignal = []
189+
for material in range(NumberOfMaterials):
190+
v = round(float(material) / float(NumberOfMaterials) * 1000.0) / 1000.0
191+
for proj in range(initialNumberOfProjections):
192+
fakeSignal.append(v)
193+
194+
# Projections
195+
vproj2proj = rtk.VectorImageToImageFilter[
196+
DecomposedProjectionType, ProjectionStackType
197+
].New()
198+
vproj2proj.SetInput(decomposedProjection)
199+
vproj2proj.Update()
200+
201+
# Release the memory holding the stack of original projections
202+
decomposedProjection.ReleaseData()
203+
204+
# Compute the interpolation weights
205+
signalToInterpolationWeights = rtk.SignalToInterpolationWeights.New()
206+
signalToInterpolationWeights.SetSignal(fakeSignal)
207+
signalToInterpolationWeights.SetNumberOfReconstructedFrames(NumberOfMaterials)
208+
signalToInterpolationWeights.Update()
209+
210+
# Set the forward and back projection filters to be used
211+
# Instantiate ROOSTER with CUDA image types if available, otherwise CPU types
212+
if hasattr(itk, "CudaImage"):
213+
cudaVolumeSeriesType = itk.CudaImage[PixelValueType, Dimension + 1]
214+
cudaProjectionStackType = itk.CudaImage[PixelValueType, Dimension]
215+
rooster = rtk.FourDROOSTERConeBeamReconstructionFilter[
216+
cudaVolumeSeriesType, cudaProjectionStackType
217+
].New()
218+
rooster.SetInputVolumeSeries(itk.cuda_image_from_image(input))
219+
rooster.SetInputProjectionStack(
220+
itk.cuda_image_from_image(vproj2proj.GetOutput())
221+
)
222+
else:
223+
rooster = rtk.FourDROOSTERConeBeamReconstructionFilter[
224+
VolumeSeriesType, ProjectionStackType
225+
].New()
226+
rooster.SetInputVolumeSeries(input)
227+
rooster.SetInputProjectionStack(vproj2proj.GetOutput())
228+
229+
# Configure projectors from args
230+
rtk.SetForwardProjectionFromArgParse(args_info, rooster)
231+
rtk.SetBackProjectionFromArgParse(args_info, rooster)
232+
233+
rooster.SetCG_iterations(args_info.cgiter)
234+
rooster.SetMainLoop_iterations(args_info.niter)
235+
rooster.SetCudaConjugateGradient(args_info.cudacg)
236+
rooster.SetDisableDisplacedDetectorFilter(args_info.nodisplaced)
237+
238+
rooster.SetGeometry(geometry)
239+
rooster.SetWeights(signalToInterpolationWeights.GetOutput())
240+
rooster.SetSignal(fakeSignal)
241+
242+
# For each optional regularization step, set whether or not
243+
# it should be performed, and provide the necessary inputs
244+
245+
# Positivity
246+
rooster.SetPerformPositivity(not args_info.nopositivity)
247+
248+
# No motion mask is used, since there is no motion
249+
rooster.SetPerformMotionMask(False)
250+
251+
# Spatial TV
252+
if args_info.gamma_space is not None:
253+
rooster.SetGammaTVSpace(args_info.gamma_space)
254+
rooster.SetTV_iterations(args_info.tviter)
255+
rooster.SetPerformTVSpatialDenoising(True)
256+
else:
257+
rooster.SetPerformTVSpatialDenoising(False)
258+
259+
# Spatial wavelets
260+
if args_info.threshold is not None:
261+
rooster.SetSoftThresholdWavelets(args_info.threshold)
262+
rooster.SetOrder(args_info.order)
263+
rooster.SetNumberOfLevels(args_info.levels)
264+
rooster.SetPerformWaveletsSpatialDenoising(True)
265+
else:
266+
rooster.SetPerformWaveletsSpatialDenoising(False)
267+
268+
# Temporal TV
269+
if args_info.gamma_time is not None:
270+
rooster.SetGammaTVTime(args_info.gamma_time)
271+
rooster.SetTV_iterations(args_info.tviter)
272+
rooster.SetPerformTVTemporalDenoising(True)
273+
else:
274+
rooster.SetPerformTVTemporalDenoising(False)
275+
276+
# Temporal L0
277+
if args_info.lambda_time is not None:
278+
rooster.SetLambdaL0Time(args_info.lambda_time)
279+
rooster.SetL0_iterations(args_info.l0iter)
280+
rooster.SetPerformL0TemporalDenoising(True)
281+
else:
282+
rooster.SetPerformL0TemporalDenoising(False)
283+
284+
# Total nuclear variation
285+
if args_info.gamma_tnv is not None:
286+
rooster.SetGammaTNV(args_info.gamma_tnv)
287+
rooster.SetTV_iterations(args_info.tviter)
288+
rooster.SetPerformTNVDenoising(True)
289+
else:
290+
rooster.SetPerformTNVDenoising(False)
291+
292+
if args_info.verbose:
293+
print("Running ROOSTER reconstruction...")
294+
rooster.Update()
295+
296+
# Convert 4D volume series (itk.Image[...,4]) to a 3D VectorImage using NumPy
297+
vol4d = rooster.GetOutput()
298+
# Extract numpy array from ITK image. For a 4D image the returned array
299+
# shape is typically (t, z, y, x). We want a 3D vector image with shape
300+
# (z, y, x, components).
301+
arr4d = itk.GetArrayFromImage(vol4d)
302+
if arr4d.ndim != 4:
303+
raise RuntimeError(
304+
f"Expected 4D array from ROOSTER output, got shape {arr4d.shape}"
305+
)
306+
307+
# Detect which axis corresponds to the components (materials)
308+
if arr4d.shape[0] == NumberOfMaterials:
309+
# array is (components, z, y, x) -> transpose to (z, y, x, components)
310+
arr_vec = np.transpose(arr4d, (1, 2, 3, 0))
311+
elif arr4d.shape[-1] == NumberOfMaterials:
312+
# already (z, y, x, components)
313+
arr_vec = arr4d
314+
else:
315+
# Fallback: try to move the axis with length NumberOfMaterials to last
316+
comp_axis = None
317+
for ax in range(4):
318+
if arr4d.shape[ax] == NumberOfMaterials:
319+
comp_axis = ax
320+
break
321+
if comp_axis is None:
322+
raise RuntimeError(
323+
"Cannot locate materials/components axis in ROOSTER output array"
324+
)
325+
# move components axis to last
326+
order = [i for i in range(4) if i != comp_axis] + [comp_axis]
327+
arr_vec = np.transpose(arr4d, tuple(order))
328+
329+
# Create an itk.VectorImage from the numpy array
330+
vec_img = itk.image_from_array(arr_vec, is_vector=True)
331+
332+
# Preserve spacing/origin/direction for the spatial 3D axes
333+
spacing4 = vol4d.GetSpacing()
334+
origin4 = vol4d.GetOrigin()
335+
direction4 = vol4d.GetDirection()
336+
vec_img.SetSpacing(tuple(spacing4[0:3]))
337+
vec_img.SetOrigin(tuple(origin4[0:3]))
338+
# Build 3x3 direction matrix
339+
dir3 = itk.Matrix[itk.D, 3, 3]()
340+
for i in range(3):
341+
for j in range(3):
342+
dir3[i][j] = direction4[i][j]
343+
vec_img.SetDirection(dir3)
344+
345+
# Write
346+
if args_info.verbose:
347+
print(f"Writing output to {args_info.output}...")
348+
writer = itk.ImageFileWriter[MaterialsVolumeType].New()
349+
writer.SetFileName(args_info.output)
350+
writer.SetInput(vec_img)
351+
writer.Update()
352+
353+
354+
def main(argv=None):
355+
parser = build_parser()
356+
args_info = parser.parse_args(argv)
357+
process(args_info)
358+
359+
360+
if __name__ == "__main__":
361+
main()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ rtkprojectshepploganphantom= "itk.rtkprojectshepploganphantom:main"
6868
rtkshowgeometry = "itk.rtkshowgeometry:main"
6969
rtksart = "itk.rtksart:main"
7070
rtksimulatedgeometry = "itk.rtksimulatedgeometry:main"
71+
rtkspectralrooster = "itk.rtkspectralrooster:main"
7172
rtksubselect = "itk.rtksubselect:main"
7273
rtktotalvariationdenoising = "itk.rtktotalvariationdenoising:main"
7374
rtkvarianobigeometry = "itk.rtkvarianobigeometry:main"

wrapping/__init_rtk__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"rtkshowgeometry",
5151
"rtksart",
5252
"rtksimulatedgeometry",
53+
"rtkspectralrooster",
5354
"rtksubselect",
5455
"rtktotalvariationdenoising",
5556
"rtkvarianobigeometry",
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
itk_wrap_include(itkCSVFileReaderBase.h)
2+
itk_wrap_simple_class("itk::CSVFileReaderBase" itkCSVFileReaderBase)
3+
itk_end_wrap_class()

wrapping/itkImageToImageFilterRTK.wrap

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ itk_wrap_class("itk::ImageToImageFilter" POINTER)
139139
itk_wrap_template("I${ITKM_${t}}2I${ITKM_${t}}1" "itk::Image<${ITKT_${t}}, 2>, itk::Image<${ITKT_${t}}, 1>")
140140
endif()
141141
itk_wrap_template("I${ITKM_${t}}3VI${ITKM_${t}}2" "itk::Image<${ITKT_${t}}, 3>, itk::VectorImage<${ITKT_${t}}, 2>")
142+
itk_wrap_template("I${ITKM_${t}}4VI${ITKM_${t}}3" "itk::Image<${ITKT_${t}}, 4>, itk::VectorImage<${ITKT_${t}}, 3>")
142143
itk_wrap_template("VI${ITKM_${t}}2I${ITKM_${t}}3" "itk::VectorImage<${ITKT_${t}}, 2>, itk::Image<${ITKT_${t}}, 3>")
143144
itk_wrap_template("VI${ITKM_${t}}3I${ITKM_${t}}4" "itk::VectorImage<${ITKT_${t}}, 3>, itk::Image<${ITKT_${t}}, 4>")
144145
endforeach()

wrapping/rtkImageToVectorImageFilter.wrap

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ itk_wrap_class("rtk::ImageToVectorImageFilter" POINTER)
22
foreach(t ${WRAP_ITK_REAL})
33
itk_wrap_template("I${ITKM_${t}}2VI${ITKM_${t}}2" "itk::Image<${ITKT_${t}}, 2>, itk::VectorImage<${ITKT_${t}}, 2>")
44
itk_wrap_template("I${ITKM_${t}}3VI${ITKM_${t}}2" "itk::Image<${ITKT_${t}}, 3>, itk::VectorImage<${ITKT_${t}}, 2>")
5+
itk_wrap_template("I${ITKM_${t}}4VI${ITKM_${t}}3" "itk::Image<${ITKT_${t}}, 4>, itk::VectorImage<${ITKT_${t}}, 3>")
56
endforeach()
67
itk_end_wrap_class()

0 commit comments

Comments
 (0)