Source code for odl.trafos.backends.pyfftw_bindings

# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Bindings to the ``pyFFTW`` back-end for Fourier transforms.

The `pyFFTW <https://pyfftw.readthedocs.io>`_  package is a Python
wrapper around the well-known `FFTW <http://fftw.org/>`_ library for fast
Fourier transforms.
"""

from __future__ import print_function, division, absolute_import
from multiprocessing import cpu_count
import numpy as np
from packaging.version import parse as parse_version
import warnings

try:
    import pyfftw
    PYFFTW_AVAILABLE = True
except ImportError:
    PYFFTW_AVAILABLE = False
else:
    if parse_version(pyfftw.__version__) < parse_version('0.10.3'):
        warnings.warn('PyFFTW < 0.10.3 is known to cause problems with some '
                      'ODL functionality, see issue #1002.',
                      RuntimeWarning)

from odl.util import (
    is_real_dtype, dtype_repr, complex_dtype, normalized_axes_tuple)

__all__ = ('pyfftw_call', 'PYFFTW_AVAILABLE')


def _pyfftw_read_last_wisdom(wfile):
    """Return the last wisdom record pickled into ``wfile``.

    ``pyfftw_call`` exports wisdom by *appending* one pickle record per
    call, so a wisdom file may hold several records. The most recently
    written one is the last, hence all records are read and only the
    final one is returned.

    Parameters
    ----------
    wfile : file handle
        Readable binary file object positioned at the first record.

    Returns
    -------
    wisdom
        The last unpickled object, or ``[]`` if ``wfile`` holds none.
    """
    import pickle

    wisdom = []
    while True:
        try:
            # SECURITY: unpickling can execute arbitrary code. Only load
            # wisdom files from trusted sources.
            wisdom = pickle.load(wfile)
        except EOFError:
            # No more records
            return wisdom


def pyfftw_call(array_in, array_out, direction='forward', axes=None,
                halfcomplex=False, **kwargs):
    """Calculate the DFT with pyfftw.

    The discrete Fourier (forward) transform calculates the sum::

        f_hat[k] = sum_j( f[j] * exp(-2*pi*1j * j*k/N) )

    where the summation is taken over all indices
    ``j = (j[0], ..., j[d-1])`` in the range ``0 <= j < N``
    (component-wise), with ``N`` being the shape of the input array.

    The output indices ``k`` lie in the same range, except
    for half-complex transforms, where the last axis ``i`` in ``axes``
    is shortened to ``0 <= k[i] < floor(N[i]/2) + 1``.

    In the backward transform, the sign of the exponential argument
    is flipped.

    Parameters
    ----------
    array_in : `numpy.ndarray`
        Array to be transformed
    array_out : `numpy.ndarray`
        Output array storing the transformed values, may be aliased
        with ``array_in``.
    direction : {'forward', 'backward'}, optional
        Direction of the transform
    axes : int or sequence of ints, optional
        Dimensions along which to take the transform. ``None`` means
        using all axes and is equivalent to ``np.arange(ndim)``.
    halfcomplex : bool, optional
        If ``True``, calculate only the negative frequency part along the
        last axis. If ``False``, calculate the full complex FFT.
        This option can only be used with real input data.

    Other Parameters
    ----------------
    fftw_plan : ``pyfftw.FFTW``, optional
        Use this plan instead of calculating a new one. If specified,
        the options ``planning_effort``, ``planning_timelimit`` and
        ``threads`` have no effect.
    planning_effort : str, optional
        Flag for the amount of effort put into finding an optimal
        FFTW plan. See the `FFTW doc on planner flags
        <http://www.fftw.org/fftw3_doc/Planner-Flags.html>`_.
        Available options: {'estimate', 'measure', 'patient', 'exhaustive'}
        Default: 'estimate'
    planning_timelimit : float or ``None``, optional
        Limit planning time to roughly this many seconds.
        Default: ``None`` (no limit)
    threads : int, optional
        Number of threads to use.
        Default: Number of CPUs if the number of data points is larger
        than 4096, else 1.
    normalise_idft : bool, optional
        If ``True``, the result of the backward transform is divided by
        ``N``, where ``N`` is the total number of points in
        ``array_in[axes]``. This ensures that the IDFT is the true
        inverse of the forward DFT.
        Default: ``False``
    import_wisdom : filename or file handle, optional
        File to load FFTW wisdom from. If the file does not exist,
        it is ignored. If the file holds several appended records, the
        most recently written one is used. The file is unpickled, so it
        must come from a trusted source.
    export_wisdom : filename or file handle, optional
        File to append the accumulated FFTW wisdom to

    Returns
    -------
    fftw_plan : ``pyfftw.FFTW``
        The plan object created from the input arguments. It can be
        reused for transforms of the same size with the same data types.
        Note that reuse only gives a speedup if the initial plan
        used a planner flag other than ``'estimate'``.
        If ``fftw_plan`` was specified, the returned object is a
        reference to it.

    Raises
    ------
    ValueError
        If ``array_in`` or ``array_out`` is not aligned.

    Notes
    -----
    * The planning and direction flags can also be specified as
      capitalized and prepended by ``'FFTW_'``, i.e. in the original
      FFTW form.
    * For a ``halfcomplex`` forward transform, the arrays must fulfill
      ``array_out.shape[axes[-1]] == array_in.shape[axes[-1]] // 2 + 1``,
      and vice versa for backward transforms.
    * All planning schemes except ``'estimate'`` require an internal copy
      of the input array but are often several times faster after the
      first call (measuring results are cached). Typically,
      'measure' is a good compromise. If you cannot afford the copy,
      use ``'estimate'``.
    * If a plan is provided via the ``fftw_plan`` parameter, no copy
      is needed internally.
    """
    import pickle

    if not array_in.flags.aligned:
        raise ValueError('input array not aligned')

    if not array_out.flags.aligned:
        raise ValueError('output array not aligned')

    if axes is None:
        axes = tuple(range(array_in.ndim))

    axes = normalized_axes_tuple(axes, array_in.ndim)

    direction = _flag_pyfftw_to_odl(direction)
    fftw_plan_in = kwargs.pop('fftw_plan', None)
    planning_effort = _flag_pyfftw_to_odl(
        kwargs.pop('planning_effort', 'estimate')
    )
    planning_timelimit = kwargs.pop('planning_timelimit', None)
    threads = kwargs.pop('threads', None)
    normalise_idft = kwargs.pop('normalise_idft', False)
    wimport = kwargs.pop('import_wisdom', '')
    wexport = kwargs.pop('export_wisdom', '')

    # Cast input to complex if necessary
    array_in_copied = False
    if is_real_dtype(array_in.dtype) and not halfcomplex:
        # Need to cast array_in to complex dtype
        array_in = array_in.astype(complex_dtype(array_in.dtype))
        array_in_copied = True

    # Do consistency checks on the arguments
    _pyfftw_check_args(array_in, array_out, axes, halfcomplex, direction)

    # Import wisdom if possible
    if wimport:
        try:
            with open(wimport, 'rb') as wfile:
                wisdom = _pyfftw_read_last_wisdom(wfile)
        except IOError:
            # Missing or unreadable file is silently ignored
            wisdom = []
        except TypeError:  # Got file handle
            wisdom = _pyfftw_read_last_wisdom(wimport)

        if wisdom:
            pyfftw.import_wisdom(wisdom)

    # The planner may overwrite the arrays it is given. If we need to make
    # a new plan with such a planner, let it work on a scratch array so
    # that the data to be transformed survives planning. This also applies
    # to our own complex copy of a real input: it holds the data that is
    # transformed below. If we already have a plan, we don't have to worry.
    planner_destroys = _pyfftw_destroys_input(
        [planning_effort], direction, halfcomplex, array_in.ndim)
    needs_scratch_input = fftw_plan_in is None and planner_destroys

    flags = [_flag_odl_to_pyfftw(planning_effort)]
    if needs_scratch_input:
        plan_arr_in = np.empty_like(array_in)
        if not array_in_copied:
            flags.append('FFTW_DESTROY_INPUT')
    else:
        plan_arr_in = array_in

    if fftw_plan_in is None:
        if threads is None:
            if plan_arr_in.size <= 4096:  # Trade-off wrt threading overhead
                threads = 1
            else:
                threads = cpu_count()

        fftw_plan = pyfftw.FFTW(
            plan_arr_in, array_out, direction=_flag_odl_to_pyfftw(direction),
            flags=flags, planning_timelimit=planning_timelimit,
            threads=threads, axes=axes)
    else:
        fftw_plan = fftw_plan_in

    fftw_plan(array_in, array_out, normalise_idft=normalise_idft)

    if wexport:
        try:
            with open(wexport, 'ab') as wfile:
                pickle.dump(pyfftw.export_wisdom(), wfile)
        except TypeError:  # Got file handle
            pickle.dump(pyfftw.export_wisdom(), wexport)

    return fftw_plan
def _flag_pyfftw_to_odl(flag):
    """Convert an FFTW-style flag to the lowercase internal form.

    A single leading ``'FFTW_'`` prefix is removed (if present) and the
    remainder is lowercased, e.g. ``'FFTW_FORWARD'`` becomes ``'forward'``.
    Flags without the prefix are only lowercased.

    Parameters
    ----------
    flag : str
        Flag in FFTW form (``'FFTW_MEASURE'``) or short form
        (``'measure'``).

    Returns
    -------
    odl_flag : str
        Lowercase flag without the ``'FFTW_'`` prefix.
    """
    prefix = 'FFTW_'
    # Do not use ``str.lstrip`` here: it strips a *set of characters*, so
    # it would also eat the leading 'F' of 'FORWARD'.
    if flag.upper().startswith(prefix):
        flag = flag[len(prefix):]
    return flag.lower()


def _flag_odl_to_pyfftw(flag):
    """Convert a short flag to FFTW form, e.g. ``'measure'`` to
    ``'FFTW_MEASURE'``."""
    return 'FFTW_' + flag.upper()


def _pyfftw_destroys_input(flags, direction, halfcomplex, ndim):
    """Return ``True`` if FFTW destroys an input array, ``False`` otherwise.

    Parameters
    ----------
    flags : sequence of str
        Planner flags, either in FFTW form or in lowercase short form.
    direction : str
        Transform direction, ``'backward'`` or ``'FFTW_BACKWARD'`` for
        the backward transform.
    halfcomplex : bool
        Whether the transform is half-complex.
    ndim : int
        Number of dimensions of the input array.
    """
    destructive_flags = ('FFTW_MEASURE', 'FFTW_PATIENT', 'FFTW_EXHAUSTIVE',
                         'FFTW_DESTROY_INPUT')
    if any(flag in flags or _flag_pyfftw_to_odl(flag) in flags
           for flag in destructive_flags):
        return True
    elif (direction in ('backward', 'FFTW_BACKWARD') and halfcomplex and
          ndim != 1):
        # Half-complex backward transforms in more than one dimension
        return True
    else:
        return False


def _pyfftw_halfcomplex_shape(full_shape, axes):
    """Return ``full_shape`` with the last axis in ``axes`` shortened.

    The entry at ``axes[-1]`` is replaced by ``n // 2 + 1``, where ``n``
    is its original value.

    Raises
    ------
    IndexError
        If ``axes[-1]`` is not a valid index into ``full_shape``.
    """
    shape = list(full_shape)
    try:
        shape[axes[-1]] = full_shape[axes[-1]] // 2 + 1
    except IndexError:
        raise IndexError('axis index {} out of range for array '
                         'with {} axes'
                         ''.format(axes[-1], len(full_shape)))
    return tuple(shape)


def _pyfftw_check_args(arr_in, arr_out, axes, halfcomplex, direction):
    """Raise an error if anything is not ok with in and out.

    Parameters
    ----------
    arr_in, arr_out : `numpy.ndarray`
        Input and output arrays of the transform.
    axes : sequence of ints
        Axes along which the transform is taken.
    halfcomplex : bool
        Whether the transform is half-complex.
    direction : {'forward', 'backward'}
        Direction of the transform, in lowercase short form.

    Raises
    ------
    ValueError
        If ``axes`` contains duplicates, if shapes or data types of the
        arrays do not match, or if ``halfcomplex`` is combined with
        complex input (forward) or complex output (backward).
    IndexError
        If the last axis is out of range in a half-complex transform.
    RuntimeError
        If ``direction`` is neither ``'forward'`` nor ``'backward'``.
    """
    if len(set(axes)) != len(axes):
        raise ValueError('duplicate axes are not allowed')

    if direction == 'forward':
        # Output shape and dtype are determined by the input
        if halfcomplex:
            out_shape = _pyfftw_halfcomplex_shape(arr_in.shape, axes)
        else:
            out_shape = tuple(arr_in.shape)

        if arr_out.shape != out_shape:
            raise ValueError('expected output shape {}, got {}'
                             ''.format(out_shape, arr_out.shape))

        if is_real_dtype(arr_in.dtype):
            out_dtype = complex_dtype(arr_in.dtype)
        elif halfcomplex:
            raise ValueError('cannot combine halfcomplex forward transform '
                             'with complex input')
        else:
            out_dtype = arr_in.dtype

        if arr_out.dtype != out_dtype:
            raise ValueError('expected output dtype {}, got {}'
                             ''.format(dtype_repr(out_dtype),
                                       dtype_repr(arr_out.dtype)))

    elif direction == 'backward':
        # Input shape and dtype are determined by the output
        if halfcomplex:
            in_shape = _pyfftw_halfcomplex_shape(arr_out.shape, axes)
        else:
            in_shape = tuple(arr_out.shape)

        if arr_in.shape != in_shape:
            raise ValueError('expected input shape {}, got {}'
                             ''.format(in_shape, arr_in.shape))

        if is_real_dtype(arr_out.dtype):
            in_dtype = complex_dtype(arr_out.dtype)
        elif halfcomplex:
            raise ValueError('cannot combine halfcomplex backward transform '
                             'with complex output')
        else:
            in_dtype = arr_out.dtype

        if arr_in.dtype != in_dtype:
            raise ValueError('expected input dtype {}, got {}'
                             ''.format(dtype_repr(in_dtype),
                                       dtype_repr(arr_in.dtype)))

    else:
        raise RuntimeError("invalid direction {!r}, expected 'forward' or "
                           "'backward'".format(direction))


if __name__ == '__main__':
    from odl.util.testutils import run_doctests
    run_doctests(skip_if=not PYFFTW_AVAILABLE)