Source code for odl.tomo.operators.ray_trafo

# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Ray transforms."""

from __future__ import absolute_import, division, print_function

from collections import OrderedDict

import numpy as np

from odl.discr import DiscretizedSpace
from odl.operator import Operator
from odl.space.weighting import ConstWeighting
from odl.tomo.backends import (
    ASTRA_AVAILABLE, ASTRA_CUDA_AVAILABLE, SKIMAGE_AVAILABLE)
from odl.tomo.backends.astra_cpu import AstraCpuImpl
from odl.tomo.backends.astra_cuda import AstraCudaImpl
from odl.tomo.backends.skimage_radon import SkImageImpl
from odl.tomo.geometry import Geometry
from odl.util import is_string

# RAY_TRAFO_IMPLS are used by `RayTransform` when no `impl` is given.
# The last inserted implementation has highest priority.
RAY_TRAFO_IMPLS = OrderedDict()
if SKIMAGE_AVAILABLE:
    RAY_TRAFO_IMPLS['skimage'] = SkImageImpl
if ASTRA_AVAILABLE:
    RAY_TRAFO_IMPLS['astra_cpu'] = AstraCpuImpl
if ASTRA_CUDA_AVAILABLE:
    RAY_TRAFO_IMPLS['astra_cuda'] = AstraCudaImpl

__all__ = ('RayTransform',)


[docs]class RayTransform(Operator): """Linear X-Ray (Radon) transform operator between L^p spaces."""
[docs] def __init__(self, vol_space, geometry, **kwargs): """Initialize a new instance. Parameters ---------- vol_space : `DiscretizedSpace` Discretized reconstruction space, the domain of the forward operator or the range of the adjoint (back-projection). geometry : `Geometry` Geometry of the transform that contains information about the data structure. Other Parameters ---------------- impl : {`None`, 'astra_cuda', 'astra_cpu', 'skimage'}, optional Implementation back-end for the transform. Supported back-ends: - ``'astra_cuda'``: ASTRA toolbox, using CUDA, 2D or 3D - ``'astra_cpu'``: ASTRA toolbox using CPU, only 2D - ``'skimage'``: scikit-image, only 2D parallel with square reconstruction space. For the default ``None``, the fastest available back-end is used. proj_space : `DiscretizedSpace`, optional Discretized projection (sinogram) space, the range of the forward operator or the domain of the adjoint (back-projection). Default: Inferred from parameters. use_cache : bool, optional If ``True``, data is cached. This gives a significant speed-up at the expense of a notable memory overhead, both on the GPU and on the CPU, since a full volume and a projection dataset are stored. That may be prohibitive in 3D. Default: True kwargs Further keyword arguments passed to the projector backend. Notes ----- The ASTRA backend is faster if data are given with ``dtype='float32'`` and storage order 'C'. Otherwise copies will be needed. """ if not isinstance(vol_space, DiscretizedSpace): raise TypeError( '`vol_space` must be a `DiscretizedSpace` instance, got ' '{!r}'.format(vol_space)) if not isinstance(geometry, Geometry): raise TypeError( '`geometry` must be a `Geometry` instance, got {!r}' ''.format(geometry) ) # Generate or check projection space proj_space = kwargs.pop('proj_space', None) if proj_space is None: dtype = vol_space.dtype if not vol_space.is_weighted: weighting = None elif ( isinstance(vol_space.weighting, ConstWeighting) and np.isclose( vol_space.weighting.const, vol_space.cell_volume ) ): # Approximate cell volume # TODO: find a way to treat angles and detector differently # regarding weighting. While the detector should be uniformly # discretized, the angles do not have to and often are not. # The needed partition property is available since # commit a551190d, but weighting is not adapted yet. # See also issue #286 extent = float(geometry.partition.extent.prod()) size = float(geometry.partition.size) weighting = extent / size else: raise NotImplementedError('unknown weighting of domain') proj_tspace = vol_space.tspace_type( geometry.partition.shape, weighting=weighting, dtype=dtype, ) if geometry.motion_partition.ndim == 0: angle_labels = [] elif geometry.motion_partition.ndim == 1: angle_labels = ['$\\varphi$'] elif geometry.motion_partition.ndim == 2: # TODO: check order angle_labels = ['$\\vartheta$', '$\\varphi$'] elif geometry.motion_partition.ndim == 3: # TODO: check order angle_labels = ['$\\vartheta$', '$\\varphi$', '$\\psi$'] else: angle_labels = None if geometry.det_partition.ndim == 1: det_labels = ['$s$'] elif geometry.det_partition.ndim == 2: det_labels = ['$u$', '$v$'] else: det_labels = None if angle_labels is None or det_labels is None: # Fallback for unknown configuration axis_labels = None else: axis_labels = angle_labels + det_labels proj_space = DiscretizedSpace( geometry.partition, proj_tspace, axis_labels=axis_labels ) else: # proj_space was given, checking some stuff if not isinstance(proj_space, DiscretizedSpace): raise TypeError( '`proj_space` must be a `DiscretizedSpace` instance, ' 'got {!r}'.format(proj_space) ) if proj_space.shape != geometry.partition.shape: raise ValueError( '`proj_space.shape` not equal to `geometry.shape`: ' '{} != {}' ''.format(proj_space.shape, geometry.partition.shape) ) if proj_space.dtype != vol_space.dtype: raise ValueError( '`proj_space.dtype` not equal to `vol_space.dtype`: ' '{} != {}'.format(proj_space.dtype, vol_space.dtype) ) if vol_space.ndim != geometry.ndim: raise ValueError( '`vol_space.ndim` not equal to `geometry.ndim`: ' '{} != {}'.format(vol_space.ndim, geometry.ndim) ) # Cache for input/output arrays of transforms self.use_cache = kwargs.pop('use_cache', True) # Check `impl` impl = kwargs.pop('impl', None) impl_type, self.__cached_impl = self._initialize_impl(impl) self._impl_type = impl_type if is_string(impl): self.__impl = impl.lower() else: self.__impl = impl_type.__name__ self._geometry = geometry # Reserve name for cached properties (used for efficiency reasons) self._adjoint = None # Extra kwargs that can be reused for adjoint etc. These must # be retrieved with `get` instead of `pop` above. self._extra_kwargs = kwargs # Finally, initialize the Operator structure super(RayTransform, self).__init__( domain=vol_space, range=proj_space, linear=True )
@staticmethod def _initialize_impl(impl): """Internal method to verify the validity of the `impl` kwarg.""" impl_instance = None if impl is None: # User didn't specify a backend if not RAY_TRAFO_IMPLS: raise RuntimeError( 'No `RayTransform` back-end available; this requires ' '3rd party packages, please check the install docs.' ) # Select fastest available impl_type = next(reversed(RAY_TRAFO_IMPLS.values())) else: # User did specify `impl` if is_string(impl): if impl.lower() not in RAY_TRAFO_IMPLS.keys(): raise ValueError( 'The {!r} `impl` is not found. This `impl` is either ' 'not supported, it may be misspelled, or external ' 'packages required are not available. Consult ' '`RAY_TRAFO_IMPLS` to find the run-time available ' 'implementations.'.format(impl) ) impl_type = RAY_TRAFO_IMPLS[impl.lower()] elif isinstance(impl, type) or isinstance(impl, object): # User gave the type and leaves instantiation to us forward = getattr(impl, "call_forward", None) backward = getattr(impl, "call_backward", None) if not callable(forward) and not callable(backward): raise TypeError( 'Type {!r} must have a `call_forward()` ' 'and/or `call_backward()`.'.format(impl) ) if isinstance(impl, type): impl_type = impl else: # User gave an object for `impl`, meaning to set the # backend cache to an already initiated object impl_type = type(impl) impl_instance = impl else: raise TypeError( '`impl` {!r} should be a string, or an object or type ' 'having a `call_forward()` and/or `call_backward()`. ' ''.format(type(impl)) ) return impl_type, impl_instance @property def impl(self): """Implementation name string. If a custom ``impl`` was provided this method returns a ``str`` of the type.""" return self.__impl
[docs] def get_impl(self, use_cache=True): """Fetches or instantiates implementation backend for evaluation. Parameters ---------- bool : use_cache If ``True`` returns the cached implementation backend, if it was generated in a previous call (or given with ``__init__``). If ``False`` a new instance of the backend will be generated, freeing up GPU memory and RAM used by the backend. """ # Use impl creation (__cached_impl) when `use_cache` is True if not use_cache or self.__cached_impl is None: # Lazily (re)instantiate the backend self.__cached_impl = self._impl_type( self.geometry, vol_space=self.domain, proj_space=self.range) return self.__cached_impl
[docs] def _call(self, x, out=None, **kwargs): """Forward projection. Parameters ---------- x : DiscretizedSpaceElement A volume. Must be an element of `RayTransform.domain`. out : `RayTransform.range` element, optional Element to which the result of the operator evaluation is written. **kwargs Extra keyword arguments, passed on to the implementation backend. Returns ------- DiscretizedSpaceElement Result of the transform, an element of the range. """ return self.get_impl(self.use_cache).call_forward(x, out, **kwargs)
@property def geometry(self): return self._geometry @property def adjoint(self): """Adjoint of this operator. The adjoint of the `RayTransform` is the linear `RayBackProjection` operator. It uses the same geometry and shares the implementation backend whenever `RayTransform.use_cache` is `True`. Returns ------- adjoint : `RayBackProjection` """ if self._adjoint is None: # bring `self` into scope to prevent shadowing in inline class ray_trafo = self class RayBackProjection(Operator): """Adjoint of the discrete Ray transform between L^p spaces.""" def _call(self, x, out=None, **kwargs): """Backprojection. Parameters ---------- x : DiscretizedSpaceElement A sinogram. Must be an element of `RayTransform.range` (domain of `RayBackProjection`). out : `RayBackProjection.domain` element, optional A volume to which the result of this evaluation is written. **kwargs Extra keyword arguments, passed on to the implementation backend. Returns ------- DiscretizedSpaceElement Result of the transform in the domain of `RayProjection`. """ return ray_trafo.get_impl( ray_trafo.use_cache ).call_backward(x, out, **kwargs) @property def geometry(self): return ray_trafo.geometry @property def adjoint(self): return ray_trafo kwargs = self._extra_kwargs.copy() kwargs['domain'] = self.range self._adjoint = RayBackProjection( range=self.domain, linear=True, **kwargs ) return self._adjoint
if __name__ == '__main__': from odl.util.testutils import run_doctests run_doctests()