|
| 1 | +#!/usr/bin/env python |
| 2 | +import argparse |
| 3 | +import itk |
| 4 | +import numpy as np |
| 5 | +from itk import RTK as rtk |
| 6 | + |
| 7 | + |
| 8 | +def build_parser(): |
| 9 | + parser = rtk.RTKArgumentParser( |
| 10 | + description=( |
| 11 | + "Reconstructs a 3D + material vector volume from a vector projection stack," |
| 12 | + " alternating between conjugate gradient optimization and regularization." |
| 13 | + ) |
| 14 | + ) |
| 15 | + |
| 16 | + parser.add_argument( |
| 17 | + "--geometry", "-g", help="XML geometry file name", required=True |
| 18 | + ) |
| 19 | + parser.add_argument("--output", "-o", help="Output file name", required=True) |
| 20 | + parser.add_argument( |
| 21 | + "--niter", "-n", help="Number of main loop iterations", type=int, default=5 |
| 22 | + ) |
| 23 | + parser.add_argument( |
| 24 | + "--cgiter", |
| 25 | + help="Number of conjugate gradient nested iterations", |
| 26 | + type=int, |
| 27 | + default=4, |
| 28 | + ) |
| 29 | + parser.add_argument( |
| 30 | + "--cudacg", |
| 31 | + help="Perform conjugate gradient calculations on GPU", |
| 32 | + action="store_true", |
| 33 | + ) |
| 34 | + parser.add_argument("--input", "-i", help="Input volume (materials) file") |
| 35 | + parser.add_argument( |
| 36 | + "--projection", "-p", help="Vector projections file", required=True |
| 37 | + ) |
| 38 | + parser.add_argument( |
| 39 | + "--nodisplaced", |
| 40 | + help="Disable the displaced detector filter", |
| 41 | + action="store_true", |
| 42 | + ) |
| 43 | + |
| 44 | + # Regularization |
| 45 | + parser.add_argument( |
| 46 | + "--nopositivity", help="Do not enforce positivity", action="store_true" |
| 47 | + ) |
| 48 | + parser.add_argument("--tviter", help="TV iterations", type=int, default=10) |
| 49 | + parser.add_argument( |
| 50 | + "--gamma_space", help="Spatial TV regularization parameter", type=float |
| 51 | + ) |
| 52 | + parser.add_argument("--threshold", help="Wavelets soft threshold", type=float) |
| 53 | + parser.add_argument("--order", help="Wavelets order", type=int, default=5) |
| 54 | + parser.add_argument("--levels", help="Wavelets levels", type=int, default=3) |
| 55 | + parser.add_argument( |
| 56 | + "--gamma_time", help="Temporal TV regularization parameter", type=float |
| 57 | + ) |
| 58 | + parser.add_argument( |
| 59 | + "--lambda_time", help="Temporal L0 regularization parameter", type=float |
| 60 | + ) |
| 61 | + parser.add_argument("--l0iter", help="L0 iterations", type=int, default=5) |
| 62 | + parser.add_argument("--gamma_tnv", help="TNV regularization parameter", type=float) |
| 63 | + |
| 64 | + # Projector choices |
| 65 | + rtk.add_rtkprojectors_group(parser) |
| 66 | + rtk.add_rtk3Doutputimage_group(parser) |
| 67 | + return parser |
| 68 | + |
| 69 | + |
| 70 | +def process(args_info: argparse.Namespace): |
| 71 | + PixelValueType = itk.F |
| 72 | + Dimension = 3 |
| 73 | + |
| 74 | + DecomposedProjectionType = itk.VectorImage[PixelValueType, Dimension] |
| 75 | + MaterialsVolumeType = itk.VectorImage[PixelValueType, Dimension] |
| 76 | + VolumeSeriesType = itk.Image[PixelValueType, Dimension + 1] |
| 77 | + ProjectionStackType = itk.Image[PixelValueType, Dimension] |
| 78 | + |
| 79 | + # Projections reader |
| 80 | + if args_info.verbose: |
| 81 | + print(f"Reading decomposed projections from {args_info.projection}...") |
| 82 | + proj_reader = itk.ImageFileReader[DecomposedProjectionType].New() |
| 83 | + proj_reader.SetFileName(args_info.projection) |
| 84 | + proj_reader.Update() |
| 85 | + decomposedProjection = proj_reader.GetOutput() |
| 86 | + |
| 87 | + NumberOfMaterials = decomposedProjection.GetVectorLength() |
| 88 | + |
| 89 | + # Geometry |
| 90 | + if args_info.verbose: |
| 91 | + print(f"Reading geometry from {args_info.geometry}...") |
| 92 | + geometry = rtk.read_geometry(args_info.geometry) |
| 93 | + |
| 94 | + # Create 4D input. Fill it either with an existing materials volume |
| 95 | + # read from a file or a blank image |
| 96 | + vecVol2VolSeries = rtk.VectorImageToImageFilter[ |
| 97 | + MaterialsVolumeType, VolumeSeriesType |
| 98 | + ].New() |
| 99 | + |
| 100 | + if args_info.input is not None: |
| 101 | + if args_info.like is not None: |
| 102 | + print("WARNING: --like ignored since --input was given") |
| 103 | + if args_info.verbose: |
| 104 | + print(f"Reading input volume {args_info.input}...") |
| 105 | + reference = itk.imread(args_info.input) |
| 106 | + vecVol2VolSeries.SetInput(reference) |
| 107 | + vecVol2VolSeries.Update() |
| 108 | + input = vecVol2VolSeries.GetOutput() |
| 109 | + elif args_info.like is not None: |
| 110 | + if args_info.verbose: |
| 111 | + print(f"Reading reference volume {args_info.like} to infer geometry...") |
| 112 | + reference = itk.imread(args_info.like) |
| 113 | + vecVol2VolSeries.SetInput(reference) |
| 114 | + vecVol2VolSeries.UpdateOutputInformation() |
| 115 | + constantImageSource = rtk.ConstantImageSource[VolumeSeriesType].New() |
| 116 | + constantImageSource.SetInformationFromImage(vecVol2VolSeries.GetOutput()) |
| 117 | + constantImageSource.Update() |
| 118 | + input = constantImageSource.GetOutput() |
| 119 | + else: |
| 120 | + # Create new empty volume |
| 121 | + constantImageSource = rtk.ConstantImageSource[VolumeSeriesType].New() |
| 122 | + |
| 123 | + imageSize = itk.Size[4]() |
| 124 | + imageSize.Fill(int(args_info.size[0])) |
| 125 | + for i in range(min(len(args_info.size), Dimension)): |
| 126 | + imageSize[i] = int(args_info.size[i]) |
| 127 | + imageSize[Dimension] = NumberOfMaterials |
| 128 | + |
| 129 | + imageSpacing = itk.Vector[itk.D, 4]() |
| 130 | + imageSpacing.Fill(float(args_info.spacing[0])) |
| 131 | + for i in range(min(len(args_info.spacing), Dimension)): |
| 132 | + imageSpacing[i] = float(args_info.spacing[i]) |
| 133 | + imageSpacing[Dimension] = 1.0 |
| 134 | + |
| 135 | + imageOrigin = itk.Point[itk.D, 4]() |
| 136 | + for i in range(Dimension): |
| 137 | + imageOrigin[i] = imageSpacing[i] * (int(imageSize[i]) - 1) * -0.5 |
| 138 | + if args_info.origin is not None: |
| 139 | + for i in range(min(len(args_info.origin), Dimension)): |
| 140 | + imageOrigin[i] = float(args_info.origin[i]) |
| 141 | + imageOrigin[Dimension] = 0.0 |
| 142 | + |
| 143 | + imageDirection = itk.Matrix[itk.D, 4, 4]() |
| 144 | + imageDirection.SetIdentity() |
| 145 | + if args_info.direction is not None: |
| 146 | + for i in range(Dimension): |
| 147 | + for j in range(Dimension): |
| 148 | + imageDirection[i][j] = float(args_info.direction[i * Dimension + j]) |
| 149 | + |
| 150 | + constantImageSource.SetOrigin(imageOrigin) |
| 151 | + constantImageSource.SetSpacing(imageSpacing) |
| 152 | + constantImageSource.SetDirection(imageDirection) |
| 153 | + constantImageSource.SetSize(imageSize) |
| 154 | + constantImageSource.SetConstant(0.0) |
| 155 | + constantImageSource.Update() |
| 156 | + input = constantImageSource.GetOutput() |
| 157 | + |
| 158 | + # Duplicate geometry and transform the N M-vector projections into N*M scalar projections |
| 159 | + # Each material will occupy one frame of the 4D reconstruction, therefore all projections |
| 160 | + # of one material need to have the same phase. |
| 161 | + # Note : the 4D CG filter is optimized when projections with identical phases are packed together |
| 162 | + |
| 163 | + # Geometry |
| 164 | + initialNumberOfProjections = int( |
| 165 | + decomposedProjection.GetLargestPossibleRegion().GetSize()[Dimension - 1] |
| 166 | + ) |
| 167 | + for material in range(1, NumberOfMaterials): |
| 168 | + for proj in range(initialNumberOfProjections): |
| 169 | + geometry.AddProjectionInRadians( |
| 170 | + geometry.GetSourceToIsocenterDistances()[proj], |
| 171 | + geometry.GetSourceToDetectorDistances()[proj], |
| 172 | + geometry.GetGantryAngles()[proj], |
| 173 | + geometry.GetProjectionOffsetsX()[proj], |
| 174 | + geometry.GetProjectionOffsetsY()[proj], |
| 175 | + geometry.GetOutOfPlaneAngles()[proj], |
| 176 | + geometry.GetInPlaneAngles()[proj], |
| 177 | + geometry.GetSourceOffsetsX()[proj], |
| 178 | + geometry.GetSourceOffsetsY()[proj], |
| 179 | + ) |
| 180 | + geometry.SetCollimationOfLastProjection( |
| 181 | + geometry.GetCollimationUInf()[proj], |
| 182 | + geometry.GetCollimationUSup()[proj], |
| 183 | + geometry.GetCollimationVInf()[proj], |
| 184 | + geometry.GetCollimationVSup()[proj], |
| 185 | + ) |
| 186 | + |
| 187 | + # Signal |
| 188 | + fakeSignal = [] |
| 189 | + for material in range(NumberOfMaterials): |
| 190 | + v = round(float(material) / float(NumberOfMaterials) * 1000.0) / 1000.0 |
| 191 | + for proj in range(initialNumberOfProjections): |
| 192 | + fakeSignal.append(v) |
| 193 | + |
| 194 | + # Projections |
| 195 | + vproj2proj = rtk.VectorImageToImageFilter[ |
| 196 | + DecomposedProjectionType, ProjectionStackType |
| 197 | + ].New() |
| 198 | + vproj2proj.SetInput(decomposedProjection) |
| 199 | + vproj2proj.Update() |
| 200 | + |
| 201 | + # Release the memory holding the stack of original projections |
| 202 | + decomposedProjection.ReleaseData() |
| 203 | + |
| 204 | + # Compute the interpolation weights |
| 205 | + signalToInterpolationWeights = rtk.SignalToInterpolationWeights.New() |
| 206 | + signalToInterpolationWeights.SetSignal(fakeSignal) |
| 207 | + signalToInterpolationWeights.SetNumberOfReconstructedFrames(NumberOfMaterials) |
| 208 | + signalToInterpolationWeights.Update() |
| 209 | + |
| 210 | + # Set the forward and back projection filters to be used |
| 211 | + # Instantiate ROOSTER with CUDA image types if available, otherwise CPU types |
| 212 | + if hasattr(itk, "CudaImage"): |
| 213 | + cudaVolumeSeriesType = itk.CudaImage[PixelValueType, Dimension + 1] |
| 214 | + cudaProjectionStackType = itk.CudaImage[PixelValueType, Dimension] |
| 215 | + rooster = rtk.FourDROOSTERConeBeamReconstructionFilter[ |
| 216 | + cudaVolumeSeriesType, cudaProjectionStackType |
| 217 | + ].New() |
| 218 | + rooster.SetInputVolumeSeries(itk.cuda_image_from_image(input)) |
| 219 | + rooster.SetInputProjectionStack( |
| 220 | + itk.cuda_image_from_image(vproj2proj.GetOutput()) |
| 221 | + ) |
| 222 | + else: |
| 223 | + rooster = rtk.FourDROOSTERConeBeamReconstructionFilter[ |
| 224 | + VolumeSeriesType, ProjectionStackType |
| 225 | + ].New() |
| 226 | + rooster.SetInputVolumeSeries(input) |
| 227 | + rooster.SetInputProjectionStack(vproj2proj.GetOutput()) |
| 228 | + |
| 229 | + # Configure projectors from args |
| 230 | + rtk.SetForwardProjectionFromArgParse(args_info, rooster) |
| 231 | + rtk.SetBackProjectionFromArgParse(args_info, rooster) |
| 232 | + |
| 233 | + rooster.SetCG_iterations(args_info.cgiter) |
| 234 | + rooster.SetMainLoop_iterations(args_info.niter) |
| 235 | + rooster.SetCudaConjugateGradient(args_info.cudacg) |
| 236 | + rooster.SetDisableDisplacedDetectorFilter(args_info.nodisplaced) |
| 237 | + |
| 238 | + rooster.SetGeometry(geometry) |
| 239 | + rooster.SetWeights(signalToInterpolationWeights.GetOutput()) |
| 240 | + rooster.SetSignal(fakeSignal) |
| 241 | + |
| 242 | + # For each optional regularization step, set whether or not |
| 243 | + # it should be performed, and provide the necessary inputs |
| 244 | + |
| 245 | + # Positivity |
| 246 | + rooster.SetPerformPositivity(not args_info.nopositivity) |
| 247 | + |
| 248 | + # No motion mask is used, since there is no motion |
| 249 | + rooster.SetPerformMotionMask(False) |
| 250 | + |
| 251 | + # Spatial TV |
| 252 | + if args_info.gamma_space is not None: |
| 253 | + rooster.SetGammaTVSpace(args_info.gamma_space) |
| 254 | + rooster.SetTV_iterations(args_info.tviter) |
| 255 | + rooster.SetPerformTVSpatialDenoising(True) |
| 256 | + else: |
| 257 | + rooster.SetPerformTVSpatialDenoising(False) |
| 258 | + |
| 259 | + # Spatial wavelets |
| 260 | + if args_info.threshold is not None: |
| 261 | + rooster.SetSoftThresholdWavelets(args_info.threshold) |
| 262 | + rooster.SetOrder(args_info.order) |
| 263 | + rooster.SetNumberOfLevels(args_info.levels) |
| 264 | + rooster.SetPerformWaveletsSpatialDenoising(True) |
| 265 | + else: |
| 266 | + rooster.SetPerformWaveletsSpatialDenoising(False) |
| 267 | + |
| 268 | + # Temporal TV |
| 269 | + if args_info.gamma_time is not None: |
| 270 | + rooster.SetGammaTVTime(args_info.gamma_time) |
| 271 | + rooster.SetTV_iterations(args_info.tviter) |
| 272 | + rooster.SetPerformTVTemporalDenoising(True) |
| 273 | + else: |
| 274 | + rooster.SetPerformTVTemporalDenoising(False) |
| 275 | + |
| 276 | + # Temporal L0 |
| 277 | + if args_info.lambda_time is not None: |
| 278 | + rooster.SetLambdaL0Time(args_info.lambda_time) |
| 279 | + rooster.SetL0_iterations(args_info.l0iter) |
| 280 | + rooster.SetPerformL0TemporalDenoising(True) |
| 281 | + else: |
| 282 | + rooster.SetPerformL0TemporalDenoising(False) |
| 283 | + |
| 284 | + # Total nuclear variation |
| 285 | + if args_info.gamma_tnv is not None: |
| 286 | + rooster.SetGammaTNV(args_info.gamma_tnv) |
| 287 | + rooster.SetTV_iterations(args_info.tviter) |
| 288 | + rooster.SetPerformTNVDenoising(True) |
| 289 | + else: |
| 290 | + rooster.SetPerformTNVDenoising(False) |
| 291 | + |
| 292 | + if args_info.verbose: |
| 293 | + print("Running ROOSTER reconstruction...") |
| 294 | + rooster.Update() |
| 295 | + |
| 296 | + # Convert 4D volume series (itk.Image[...,4]) to a 3D VectorImage using NumPy |
| 297 | + vol4d = rooster.GetOutput() |
| 298 | + # Extract numpy array from ITK image. For a 4D image the returned array |
| 299 | + # shape is typically (t, z, y, x). We want a 3D vector image with shape |
| 300 | + # (z, y, x, components). |
| 301 | + arr4d = itk.GetArrayFromImage(vol4d) |
| 302 | + if arr4d.ndim != 4: |
| 303 | + raise RuntimeError( |
| 304 | + f"Expected 4D array from ROOSTER output, got shape {arr4d.shape}" |
| 305 | + ) |
| 306 | + |
| 307 | + # Detect which axis corresponds to the components (materials) |
| 308 | + if arr4d.shape[0] == NumberOfMaterials: |
| 309 | + # array is (components, z, y, x) -> transpose to (z, y, x, components) |
| 310 | + arr_vec = np.transpose(arr4d, (1, 2, 3, 0)) |
| 311 | + elif arr4d.shape[-1] == NumberOfMaterials: |
| 312 | + # already (z, y, x, components) |
| 313 | + arr_vec = arr4d |
| 314 | + else: |
| 315 | + # Fallback: try to move the axis with length NumberOfMaterials to last |
| 316 | + comp_axis = None |
| 317 | + for ax in range(4): |
| 318 | + if arr4d.shape[ax] == NumberOfMaterials: |
| 319 | + comp_axis = ax |
| 320 | + break |
| 321 | + if comp_axis is None: |
| 322 | + raise RuntimeError( |
| 323 | + "Cannot locate materials/components axis in ROOSTER output array" |
| 324 | + ) |
| 325 | + # move components axis to last |
| 326 | + order = [i for i in range(4) if i != comp_axis] + [comp_axis] |
| 327 | + arr_vec = np.transpose(arr4d, tuple(order)) |
| 328 | + |
| 329 | + # Create an itk.VectorImage from the numpy array |
| 330 | + vec_img = itk.image_from_array(arr_vec, is_vector=True) |
| 331 | + |
| 332 | + # Preserve spacing/origin/direction for the spatial 3D axes |
| 333 | + spacing4 = vol4d.GetSpacing() |
| 334 | + origin4 = vol4d.GetOrigin() |
| 335 | + direction4 = vol4d.GetDirection() |
| 336 | + vec_img.SetSpacing(tuple(spacing4[0:3])) |
| 337 | + vec_img.SetOrigin(tuple(origin4[0:3])) |
| 338 | + # Build 3x3 direction matrix |
| 339 | + dir3 = itk.Matrix[itk.D, 3, 3]() |
| 340 | + for i in range(3): |
| 341 | + for j in range(3): |
| 342 | + dir3[i][j] = direction4[i][j] |
| 343 | + vec_img.SetDirection(dir3) |
| 344 | + |
| 345 | + # Write |
| 346 | + if args_info.verbose: |
| 347 | + print(f"Writing output to {args_info.output}...") |
| 348 | + writer = itk.ImageFileWriter[MaterialsVolumeType].New() |
| 349 | + writer.SetFileName(args_info.output) |
| 350 | + writer.SetInput(vec_img) |
| 351 | + writer.Update() |
| 352 | + |
| 353 | + |
| 354 | +def main(argv=None): |
| 355 | + parser = build_parser() |
| 356 | + args_info = parser.parse_args(argv) |
| 357 | + process(args_info) |
| 358 | + |
| 359 | + |
| 360 | +if __name__ == "__main__": |
| 361 | + main() |
0 commit comments