# Source code for odl.tomo.backends.astra_cuda

# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Backend for ASTRA using CUDA."""

from __future__ import absolute_import, division, print_function

import warnings
from multiprocessing import Lock

import numpy as np
from packaging.version import parse as parse_version

from odl.discr import DiscretizedSpace
from odl.tomo.backends.astra_setup import (
    ASTRA_VERSION, astra_algorithm, astra_data, astra_projection_geometry,
    astra_projector, astra_supports, astra_versions_supporting,
    astra_volume_geometry)
from odl.tomo.backends.util import _add_default_complex_impl
from odl.tomo.geometry import (
    ConeBeamGeometry, FanBeamGeometry, Geometry, Parallel2dGeometry,
    Parallel3dAxisGeometry)

# Optional dependency: ASTRA may be missing entirely. The ``use_cuda()``
# query is kept inside the ``try`` on purpose so that the module still
# imports when ``astra`` cannot be imported.
try:
    import astra

    # NOTE(review): presumably True only for ASTRA builds with CUDA
    # support -- confirm against the ASTRA API.
    ASTRA_CUDA_AVAILABLE = astra.astra.use_cuda()
except ImportError:
    # No ASTRA installation at all, hence no CUDA backend.
    ASTRA_CUDA_AVAILABLE = False

# Only the availability flag is exported by ``from ... import *``.
__all__ = (
    'ASTRA_CUDA_AVAILABLE',
)


class AstraCudaImpl:

    """`RayTransform` implementation for CUDA algorithms in ASTRA."""

    # Handles of the ASTRA objects owned by an instance. They are declared
    # at class level so that ``__del__`` can inspect them safely even when
    # ``__init__`` or ``create_ids`` failed before assigning them.
    algo_forward_id = None
    algo_backward_id = None
    vol_id = None
    sino_id = None
    proj_id = None

    def __init__(self, geometry, vol_space, proj_space):
        """Initialize a new instance.

        Parameters
        ----------
        geometry : `Geometry`
            Geometry defining the tomographic setup.
        vol_space : `DiscretizedSpace`
            Reconstruction space, the space of the images to be forward
            projected.
        proj_space : `DiscretizedSpace`
            Projection space, the space of the result.

        Raises
        ------
        TypeError
            If ``geometry`` is not a `Geometry`, or if ``vol_space`` or
            ``proj_space`` is not a `DiscretizedSpace`.
        """
        if not isinstance(geometry, Geometry):
            raise TypeError(
                '`geometry` must be a `Geometry` instance, got {!r}'
                ''.format(geometry)
            )
        if not isinstance(vol_space, DiscretizedSpace):
            raise TypeError(
                '`vol_space` must be a `DiscretizedSpace` instance, got {!r}'
                ''.format(vol_space)
            )
        if not isinstance(proj_space, DiscretizedSpace):
            raise TypeError(
                '`proj_space` must be a `DiscretizedSpace` instance, got {!r}'
                ''.format(proj_space)
            )

        # Print a warning if the detector midpoint normal vector at any
        # angle is perpendicular to the geometry axis in parallel 3d
        # single-axis geometry -- this is broken in some ASTRA versions
        if (
            isinstance(geometry, Parallel3dAxisGeometry)
            and not astra_supports('par3d_det_mid_pt_perp_to_axis')
        ):
            req_ver = astra_versions_supporting(
                'par3d_det_mid_pt_perp_to_axis'
            )
            axis = geometry.axis
            mid_pt = geometry.det_params.mid_pt
            for i, angle in enumerate(geometry.angles):
                # Computed once and reused in the warning message.
                det_to_src = geometry.det_to_src(angle, mid_pt)
                if abs(np.dot(axis, det_to_src)) < 1e-4:
                    warnings.warn(
                        'angle {}: detector midpoint normal {} is '
                        'perpendicular to the geometry axis {} in '
                        '`Parallel3dAxisGeometry`; this is broken in '
                        'ASTRA {}, please upgrade to ASTRA {}'
                        ''.format(i, det_to_src, axis,
                                  ASTRA_VERSION, req_ver),
                        RuntimeWarning)
                    # One warning is enough; stop at the first bad angle.
                    break

        self.geometry = geometry
        self._vol_space = vol_space
        self._proj_space = proj_space

        self.create_ids()

        # ASTRA projectors are not thread-safe, thus we need to lock manually
        self._mutex = Lock()

    @property
    def vol_space(self):
        """Reconstruction (volume) space given at construction."""
        return self._vol_space

    @property
    def proj_space(self):
        """Projection (data) space given at construction."""
        return self._proj_space

    def create_ids(self):
        """Create ASTRA objects.

        Allocates host buffers ``vol_array`` and ``proj_array`` (float32,
        C order), registers them with ASTRA without copying, and creates
        the CUDA projector plus forward/backward algorithm objects whose
        handles are stored on the instance.

        Raises
        ------
        RuntimeError
            If the projection data would have a dimension other than
            2 or 3.
        """
        # Create input and output arrays
        if self.geometry.motion_partition.ndim == 1:
            motion_shape = self.geometry.motion_partition.shape
        else:
            # Need to flatten 2- or 3-dimensional angles into one axis
            motion_shape = (
                int(np.prod(self.geometry.motion_partition.shape)),
            )

        proj_shape = motion_shape + self.geometry.det_partition.shape
        proj_ndim = len(proj_shape)

        if proj_ndim == 2:
            astra_proj_shape = proj_shape
        elif proj_ndim == 3:
            # The `u` and `v` axes of the projection data are swapped,
            # see explanation in `astra_*_3d_geom_to_vec`.
            astra_proj_shape = (proj_shape[1], proj_shape[0], proj_shape[2])
        else:
            raise RuntimeError(
                'unsupported projection data dimension {}, expected 2 or 3'
                ''.format(proj_ndim)
            )
        astra_vol_shape = self.vol_space.shape

        self.vol_array = np.empty(astra_vol_shape, dtype='float32', order='C')
        self.proj_array = np.empty(astra_proj_shape, dtype='float32',
                                   order='C')

        # Create ASTRA data structures
        vol_geom = astra_volume_geometry(self.vol_space)
        proj_geom = astra_projection_geometry(self.geometry)
        self.vol_id = astra_data(
            vol_geom,
            datatype='volume',
            ndim=self.vol_space.ndim,
            data=self.vol_array,
            allow_copy=False,
        )

        proj_type = 'cuda' if proj_ndim == 2 else 'cuda3d'
        self.proj_id = astra_projector(
            proj_type, vol_geom, proj_geom, proj_ndim
        )

        self.sino_id = astra_data(
            proj_geom,
            datatype='projection',
            ndim=proj_ndim,
            data=self.proj_array,
            allow_copy=False,
        )

        # Create forward projection algorithm
        self.algo_forward_id = astra_algorithm(
            'forward',
            proj_ndim,
            self.vol_id,
            self.sino_id,
            self.proj_id,
            impl='cuda',
        )

        # Create back-projection algorithm
        self.algo_backward_id = astra_algorithm(
            'backward',
            proj_ndim,
            self.vol_id,
            self.sino_id,
            self.proj_id,
            impl='cuda',
        )

    @_add_default_complex_impl
    def call_forward(self, x, out=None, **kwargs):
        """Forward-project ``x``, optionally writing into ``out``.

        The real-valued work is done by ``_call_forward_real``; the
        ``_add_default_complex_impl`` decorator wraps this method
        (presumably to handle complex spaces -- see
        ``odl.tomo.backends.util``).
        """
        return self._call_forward_real(x, out, **kwargs)

    def _call_forward_real(self, vol_data, out=None, **kwargs):
        """Run an ASTRA forward projection on the given data using the GPU.

        Parameters
        ----------
        vol_data : ``vol_space.real_space`` element
            Volume data to which the projector is applied. Although
            ``vol_space`` may be complex, this element needs to be real.
        out : ``proj_space`` element, optional
            Element of the projection space to which the result is written.
            If ``None``, an element in `proj_space` is created.
        kwargs :
            Accepted for interface compatibility and ignored.

        Returns
        -------
        out : ``proj_space`` element
            Projection data resulting from the application of the
            projector. If ``out`` was provided, the returned object is a
            reference to it.
        """
        with self._mutex:
            assert vol_data in self.vol_space.real_space

            if out is not None:
                assert out in self.proj_space
            else:
                out = self.proj_space.element()

            # Copy data to GPU memory
            if self.geometry.ndim == 2:
                astra.data2d.store(self.vol_id, vol_data.asarray())
            elif self.geometry.ndim == 3:
                astra.data3d.store(self.vol_id, vol_data.asarray())
            else:
                raise RuntimeError('unknown ndim')

            # Run algorithm
            astra.algorithm.run(self.algo_forward_id)

            # Copy result to host; `proj_array` is the buffer registered
            # with ASTRA for the projection data.
            if self.geometry.ndim == 2:
                out[:] = self.proj_array
            elif self.geometry.ndim == 3:
                # Undo the axis swap introduced in `create_ids`
                out[:] = np.swapaxes(self.proj_array, 0, 1).reshape(
                    self.proj_space.shape)

            # Fix scaling to weight by pixel size
            if (
                isinstance(self.geometry, Parallel2dGeometry)
                and parse_version(ASTRA_VERSION) < parse_version('1.9.9.dev')
            ):
                # parallel2d scales with pixel stride
                out *= 1 / float(self.geometry.det_partition.cell_volume)

            return out

    @_add_default_complex_impl
    def call_backward(self, x, out=None, **kwargs):
        """Back-project ``x``, optionally writing into ``out``.

        The real-valued work is done by ``_call_backward_real``; the
        ``_add_default_complex_impl`` decorator wraps this method
        (presumably to handle complex spaces -- see
        ``odl.tomo.backends.util``).
        """
        return self._call_backward_real(x, out, **kwargs)

    def _call_backward_real(self, proj_data, out=None, **kwargs):
        """Run an ASTRA back-projection on the given data using the GPU.

        Parameters
        ----------
        proj_data : ``proj_space.real_space`` element
            Projection data to which the back-projector is applied.
            Although ``proj_space`` may be complex, this element needs to
            be real.
        out : ``vol_space`` element, optional
            Element of the reconstruction space to which the result is
            written. If ``None``, an element in ``vol_space`` is created.
        kwargs :
            Accepted for interface compatibility and ignored.

        Returns
        -------
        out : ``vol_space`` element
            Reconstruction data resulting from the application of the
            back-projector. If ``out`` was provided, the returned object is a
            reference to it.

        Raises
        ------
        RuntimeError
            If the geometry is neither 2- nor 3-dimensional.
        """
        with self._mutex:
            assert proj_data in self.proj_space.real_space

            if out is not None:
                assert out in self.vol_space
            else:
                out = self.vol_space.element()

            # Copy data to GPU memory
            if self.geometry.ndim == 2:
                astra.data2d.store(self.sino_id, proj_data.asarray())
            elif self.geometry.ndim == 3:
                # Flatten the angle axes and swap to ASTRA's (v, angle, u)
                # ordering, mirroring the layout chosen in `create_ids`.
                shape = (-1,) + self.geometry.det_partition.shape
                reshaped_proj_data = proj_data.asarray().reshape(shape)
                swapped_proj_data = np.ascontiguousarray(
                    np.swapaxes(reshaped_proj_data, 0, 1)
                )
                astra.data3d.store(self.sino_id, swapped_proj_data)
            else:
                # Same guard as in `_call_forward_real`
                raise RuntimeError('unknown ndim')

            # Run algorithm
            astra.algorithm.run(self.algo_backward_id)

            # Copy result to CPU memory
            out[:] = self.vol_array

            # Fix scaling to weight by pixel/voxel size
            out *= astra_cuda_bp_scaling_factor(
                self.proj_space, self.vol_space, self.geometry
            )

            return out

    def __del__(self):
        """Delete ASTRA objects.

        Safe to call on partially initialized instances: if no ASTRA
        handle was ever allocated (e.g. ``__init__`` raised on an invalid
        argument), nothing is touched.
        """
        handles = (self.algo_forward_id, self.algo_backward_id,
                   self.vol_id, self.sino_id, self.proj_id)
        if all(handle is None for handle in handles):
            # `self.geometry` may not exist yet, so bail out before using it.
            return

        if self.geometry.ndim == 2:
            adata, aproj = astra.data2d, astra.projector
        else:
            adata, aproj = astra.data3d, astra.projector3d

        if self.algo_forward_id is not None:
            astra.algorithm.delete(self.algo_forward_id)
            self.algo_forward_id = None
        if self.algo_backward_id is not None:
            astra.algorithm.delete(self.algo_backward_id)
            self.algo_backward_id = None
        if self.vol_id is not None:
            adata.delete(self.vol_id)
            self.vol_id = None
        if self.sino_id is not None:
            adata.delete(self.sino_id)
            self.sino_id = None
        if self.proj_id is not None:
            aproj.delete(self.proj_id)
            self.proj_id = None
def astra_cuda_bp_scaling_factor(proj_space, vol_space, geometry):
    """Volume scaling accounting for differing adjoint definitions.

    ASTRA defines the adjoint operator in terms of a fully discrete
    setting (transposed "projection matrix") without any relation to
    physical dimensions, which makes a re-scaling necessary to
    translate it to spaces with physical dimensions.

    Behavior of ASTRA changes slightly between versions, so we keep
    track of it and adapt the scaling accordingly.

    Parameters
    ----------
    proj_space : `DiscretizedSpace`
        Projection space; its partition extent/size and constant
        weighting enter the factor.
    vol_space : `DiscretizedSpace`
        Reconstruction space; its cell volume/sides and constant
        weighting enter the factor.
    geometry : `Geometry`
        Geometry of the setup; its motion partition provides the angular
        weight, and its type selects the version-specific correction.

    Returns
    -------
    scaling_factor : float-like scalar
        Factor by which the raw ASTRA back-projection is multiplied.
        Geometries matching none of the version-specific branches only
        receive the generic angular/weighting part.
    """
    # Angular integration weighting factor
    # angle interval weight by approximate cell volume
    angle_extent = geometry.motion_partition.extent
    num_angles = geometry.motion_partition.shape
    # TODO: this gives the wrong factor for Parallel3dEulerGeometry with
    # 2 angles
    scaling_factor = (angle_extent / num_angles).prod()

    # Correct in case of non-weighted spaces
    proj_extent = float(proj_space.partition.extent.prod())
    proj_size = float(proj_space.partition.size)
    proj_weighting = proj_extent / proj_size

    scaling_factor *= (
        proj_space.weighting.const / proj_weighting
    )
    scaling_factor /= (
        vol_space.weighting.const / vol_space.cell_volume
    )

    # NOTE: the order of the multiplications/divisions below is kept as
    # is on purpose; reordering would change floating-point results.
    # In the branches below, "Scales with X" describes ASTRA's behaviour;
    # the code applies the compensating correction.
    if parse_version(ASTRA_VERSION) < parse_version('1.8rc1'):
        # Scaling for the old, pre-1.8 behaviour
        if isinstance(geometry, Parallel2dGeometry):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
        elif (isinstance(geometry, FanBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
            # Additional magnification correction
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius)
        elif isinstance(geometry, Parallel3dAxisGeometry):
            # Scales with voxel stride
            # In 1.7, only cubic voxels are supported
            voxel_stride = vol_space.cell_sides[0]
            scaling_factor /= float(voxel_stride)
        elif (isinstance(geometry, ConeBeamGeometry)
              and geometry.det_curvature_radius is None):
            # NOTE(review): this branch was commented "Scales with
            # 1 / cell_volume", but the code corrects for the voxel
            # stride like the Parallel3dAxisGeometry branch -- confirm
            # which description is intended.
            # In 1.7, only cubic voxels are supported
            voxel_stride = vol_space.cell_sides[0]
            scaling_factor /= float(voxel_stride)
            # Magnification correction
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius) ** 2
    elif parse_version(ASTRA_VERSION) < parse_version('1.9.0dev'):
        # Scaling for the 1.8.x releases
        if isinstance(geometry, Parallel2dGeometry):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
        elif (isinstance(geometry, FanBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
            # Magnification correction
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius)
        elif isinstance(geometry, Parallel3dAxisGeometry):
            # Scales with cell volume
            # currently only square voxels are supported
            scaling_factor /= vol_space.cell_volume
        elif (isinstance(geometry, ConeBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with cell volume
            scaling_factor /= vol_space.cell_volume
            # Magnification correction (scaling = 1 / magnification ** 2)
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius) ** 2

            # Correction for scaled 1/r^2 factor in ASTRA's density weighting.
            # This compensates for scaled voxels and pixels, as well as a
            # missing factor src_radius ** 2 in the ASTRA BP with
            # density weighting.
            det_px_area = geometry.det_partition.cell_volume
            scaling_factor *= (
                src_radius ** 2 * det_px_area ** 2 / vol_space.cell_volume ** 2
            )
    elif parse_version(ASTRA_VERSION) < parse_version('1.9.9.dev'):
        # Scaling for intermediate dev releases between 1.8.3 and 1.9.9.dev
        # (identical to 1.8.x except for the last cone-beam correction,
        # which no longer divides by cell_volume ** 2)
        if isinstance(geometry, Parallel2dGeometry):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
        elif (isinstance(geometry, FanBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with 1 / cell_volume
            scaling_factor *= float(vol_space.cell_volume)
            # Magnification correction
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius)
        elif isinstance(geometry, Parallel3dAxisGeometry):
            # Scales with cell volume
            # currently only square voxels are supported
            scaling_factor /= vol_space.cell_volume
        elif (isinstance(geometry, ConeBeamGeometry)
              and geometry.det_curvature_radius is None):
            # Scales with cell volume
            scaling_factor /= vol_space.cell_volume
            # Magnification correction (scaling = 1 / magnification ** 2)
            src_radius = geometry.src_radius
            det_radius = geometry.det_radius
            scaling_factor *= ((src_radius + det_radius) / src_radius) ** 2

            # Correction for scaled 1/r^2 factor in ASTRA's density weighting.
            # This compensates for scaled voxels and pixels, as well as a
            # missing factor src_radius ** 2 in the ASTRA BP with
            # density weighting.
            det_px_area = geometry.det_partition.cell_volume
            scaling_factor *= (src_radius ** 2 * det_px_area ** 2)
    else:
        # Scaling for versions since 1.9.9.dev: a uniform correction by
        # the ratio of detector cell volume to volume cell volume,
        # independent of the geometry type.
        scaling_factor /= float(vol_space.cell_volume)
        scaling_factor *= float(geometry.det_partition.cell_volume)

    return scaling_factor
if __name__ == '__main__':
    # Executing this file directly runs the doctests it contains.
    from odl.util import testutils
    testutils.run_doctests()