# Source code for odl.trafos.wavelet

# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Discrete wavelet transformation on L2 spaces."""

from __future__ import absolute_import, division, print_function

import numpy as np

from odl.discr import DiscretizedSpace
from odl.operator import Operator
from odl.trafos.backends.pywt_bindings import (
    PYWT_AVAILABLE, precompute_raveled_slices, pywt_pad_mode, pywt_wavelet)

__all__ = ('WaveletTransform', 'WaveletTransformInverse')


# Back-ends usable in this session. ``pywt`` is imported and offered only
# when the bindings module reports PyWavelets as available.
if PYWT_AVAILABLE:
    import pywt
    _SUPPORTED_WAVELET_IMPLS = ('pywt',)
else:
    _SUPPORTED_WAVELET_IMPLS = ()


class WaveletTransformBase(Operator):

    """Base class for discrete wavelet transforms.

    This abstract class is intended to share code between the forward,
    inverse and adjoint wavelet transforms.
    """

    def __init__(self, space, wavelet, nlevels, variant, pad_mode='constant',
                 pad_const=0, impl='pywt', axes=None):
        """Initialize a new instance.

        Parameters
        ----------
        space : `DiscretizedSpace`
            Domain of the forward wavelet transform (the "image domain").
            In the case of ``variant in ('inverse', 'adjoint')``, this
            space is the range of the operator.
        wavelet : string or `pywt.Wavelet`
            Specification of the wavelet to be used in the transform.
            If a string is given, it is converted to a `pywt.Wavelet`.
            Use `pywt.wavelist` to get a list of available wavelets.

            Possible wavelet families are:

            ``'haar'``: Haar

            ``'db'``: Daubechies

            ``'sym'``: Symlets

            ``'coif'``: Coiflets

            ``'bior'``: Biorthogonal

            ``'rbio'``: Reverse biorthogonal

            ``'dmey'``: Discrete FIR approximation of the Meyer wavelet

        nlevels : positive int or None
            Number of scaling levels to be used in the decomposition. The
            maximum number of levels can be calculated with
            `pywt.dwtn_max_level`. ``None`` means: use the maximum number
            of levels.
        variant : {'forward', 'inverse', 'adjoint'}
            Wavelet transform variant to be created.
        pad_mode : string, optional
            Method to be used to extend the signal.

            ``'constant'``: Fill with ``pad_const``.

            ``'symmetric'``: Reflect at the boundaries, not repeating the
            outmost values.

            ``'periodic'``: Fill in values from the other side, keeping
            the order.

            ``'order0'``: Extend constantly with the outmost values
            (ensures continuity).

            ``'order1'``: Extend with constant slope (ensures continuity of
            the first derivative). This requires at least 2 values along
            each axis where padding is applied.

            ``'pywt_periodic'``: like ``'periodic'``-padding but gives the
            smallest possible number of decomposition coefficients.
            Only available with ``impl='pywt'``, see ``pywt.Modes.modes``.

            ``'reflect'``: Reflect at the boundary, without repeating the
            outmost values.

            ``'antisymmetric'``: Anti-symmetric variant of ``symmetric``.

            ``'antireflect'``: Anti-symmetric variant of ``reflect``.

            For reference, the following table compares the naming
            conventions for the modes in ODL vs. PyWavelets::

                ======================= ==================
                ODL                     PyWavelets
                ======================= ==================
                symmetric               symmetric
                reflect                 reflect
                order1                  smooth
                order0                  constant
                constant, pad_const=0   zero
                periodic                periodic
                pywt_periodic           periodization
                antisymmetric           antisymmetric
                antireflect             antireflect
                ======================= ==================

            See `signal extension modes`_ for an illustration of the modes
            (under the PyWavelets naming conventions).
        pad_const : float, optional
            Constant value to use if ``pad_mode == 'constant'``. Ignored
            otherwise. Constants other than 0 are not supported by the
            ``pywt`` back-end.
        impl : {'pywt'}, optional
            Back-end for the wavelet transform.
        axes : sequence of ints, optional
            Axes over which the DWT that created ``coeffs`` was performed.
            The default value of ``None`` corresponds to all axes. When not
            all axes are included this is analogous to a batch transform in
            ``len(axes)`` dimensions looped over the non-transformed axes.
            In other words, filtering and decimation does not occur along
            any axes not in ``axes``.

        Raises
        ------
        TypeError
            If ``space`` is not a `DiscretizedSpace`.
        ValueError
            If ``impl`` is not supported, ``axes`` has more entries than
            ``space.ndim``, ``nlevels`` is not an integer, or ``variant``
            is not one of the recognized values.

        References
        ----------
        .. _signal extension modes:
           https://pywavelets.readthedocs.io/en/latest/ref/\
signal-extension-modes.html
        """
        if not isinstance(space, DiscretizedSpace):
            raise TypeError(
                '`space` {!r} is not a `DiscretizedSpace` instance'
                ''.format(space)
            )

        # Validate the back-end once, before any back-end specific work.
        self.__impl, impl_in = str(impl).lower(), impl
        if self.impl not in _SUPPORTED_WAVELET_IMPLS:
            raise ValueError("`impl` '{}' not supported".format(impl_in))

        # Normalize `axes` to a tuple: None -> all axes, scalar -> 1-tuple.
        if axes is None:
            axes = tuple(range(space.ndim))
        elif np.isscalar(axes):
            axes = (axes,)
        elif len(axes) > space.ndim:
            raise ValueError("too many axes")
        self.axes = tuple(axes)

        # `nlevels=None` selects the deepest decomposition that pywt allows
        # for this shape along the chosen axes.
        if nlevels is None:
            nlevels = pywt.dwtn_max_level(space.shape, wavelet, self.axes)
        # Round-trip through int() to reject non-integral values like 2.5.
        self.__nlevels, nlevels_in = int(nlevels), nlevels
        if self.nlevels != nlevels_in:
            raise ValueError('`nlevels` must be integer, got {}'
                             ''.format(nlevels_in))

        self.__wavelet = getattr(wavelet, 'name', str(wavelet).lower())
        self.__pad_mode = str(pad_mode).lower()
        self.__pad_const = space.field.element(pad_const)

        if self.impl == 'pywt':
            self.pywt_pad_mode = pywt_pad_mode(pad_mode, pad_const)
            self.pywt_wavelet = pywt_wavelet(self.wavelet)
            # determine coefficient shapes (without running wavedecn)
            self._coeff_shapes = pywt.wavedecn_shapes(
                space.shape, wavelet, mode=self.pywt_pad_mode,
                level=self.nlevels, axes=self.axes)
            # precompute slices into the (raveled) coeffs
            self._coeff_slices = precompute_raveled_slices(self._coeff_shapes)
            # The coefficient space is a flat vector holding all raveled
            # coefficients of the decomposition.
            coeff_size = pywt.wavedecn_size(self._coeff_shapes)
            coeff_space = space.tspace_type(coeff_size, dtype=space.dtype)
        else:
            raise RuntimeError("bad `impl` '{}'".format(self.impl))

        variant, variant_in = str(variant).lower(), variant
        if variant not in ('forward', 'inverse', 'adjoint'):
            raise ValueError("`variant` '{}' not understood"
                             "".format(variant_in))
        self.__variant = variant

        # The forward transform maps image -> coefficients; all other
        # variants map coefficients -> image.
        if variant == 'forward':
            super(WaveletTransformBase, self).__init__(
                domain=space, range=coeff_space, linear=True)
        else:
            super(WaveletTransformBase, self).__init__(
                domain=coeff_space, range=space, linear=True)

    @property
    def impl(self):
        """Implementation back-end of this wavelet transform."""
        return self.__impl

    @property
    def nlevels(self):
        """Number of scaling levels in this wavelet transform."""
        return self.__nlevels

    @property
    def wavelet(self):
        """Name of the wavelet used in this wavelet transform."""
        return self.__wavelet

    @property
    def pad_mode(self):
        """Padding mode used for extending input beyond its boundary."""
        return self.__pad_mode

    @property
    def pad_const(self):
        """Value for extension used in ``'constant'`` padding mode."""
        return self.__pad_const

    @property
    def is_orthogonal(self):
        """Whether or not the wavelet basis is orthogonal."""
        return self.pywt_wavelet.orthogonal

    @property
    def is_biorthogonal(self):
        """Whether or not the wavelet basis is bi-orthogonal."""
        return self.pywt_wavelet.biorthogonal

    def scales(self):
        """Get the scales of each coefficient.

        Returns
        -------
        scales : element of the coefficient space
            The scale of each coefficient, given by an integer. 0 for the
            lowest resolution and ``self.nlevels`` for the highest.
        """
        if self.impl == 'pywt':
            # Pick out the image space and the coefficient space depending
            # on the direction of this operator.
            if self.__variant == 'forward':
                discr_space = self.domain
                wavelet_space = self.range
            else:
                discr_space = self.range
                wavelet_space = self.domain

            shapes = pywt.wavedecn_shapes(discr_space.shape,
                                          self.pywt_wavelet,
                                          mode=self.pywt_pad_mode,
                                          level=self.nlevels,
                                          axes=self.axes)
            # Fill every coefficient array with its level index, then ravel
            # exactly like the forward transform does.
            coeff_list = [np.full(shapes[0], 0)]
            for level in range(1, len(shapes)):
                coeff_list.append({key: np.full(shape, level)
                                   for key, shape in shapes[level].items()})

            coeffs = pywt.ravel_coeffs(coeff_list, axes=self.axes)[0]
            return wavelet_space.element(coeffs)
        else:
            raise RuntimeError("bad `impl` '{}'".format(self.impl))
class WaveletTransform(WaveletTransformBase):

    """Discrete wavelet transform between discretized Lp spaces."""

    def __init__(self, domain, wavelet, nlevels=None, pad_mode='constant',
                 pad_const=0, impl='pywt', axes=None):
        """Initialize a new instance.

        Parameters
        ----------
        domain : `DiscretizedSpace`
            Domain of the wavelet transform (the "image domain").
        wavelet : string or `pywt.Wavelet`
            Specification of the wavelet to be used in the transform.
            If a string is given, it is converted to a `pywt.Wavelet`.
            Use `pywt.wavelist` to get a list of available wavelets.

            Possible wavelet families are:

            ``'haar'``: Haar

            ``'db'``: Daubechies

            ``'sym'``: Symlets

            ``'coif'``: Coiflets

            ``'bior'``: Biorthogonal

            ``'rbio'``: Reverse biorthogonal

            ``'dmey'``: Discrete FIR approximation of the Meyer wavelet

        nlevels : positive int, optional
            Number of scaling levels to be used in the decomposition. The
            maximum number of levels can be calculated with
            `pywt.dwtn_max_level`.
            Default: Use maximum number of levels.
        pad_mode : string, optional
            Method to be used to extend the signal.

            ``'constant'``: Fill with ``pad_const``.

            ``'symmetric'``: Reflect at the boundaries, not repeating the
            outmost values.

            ``'periodic'``: Fill in values from the other side, keeping
            the order.

            ``'order0'``: Extend constantly with the outmost values
            (ensures continuity).

            ``'order1'``: Extend with constant slope (ensures continuity of
            the first derivative). This requires at least 2 values along
            each axis where padding is applied.

            ``'pywt_periodic'``: like ``'periodic'``-padding but gives the
            smallest possible number of decomposition coefficients.
            Only available with ``impl='pywt'``, see ``pywt.Modes.modes``.

            ``'reflect'``: Reflect at the boundary, without repeating the
            outmost values.

            ``'antisymmetric'``: Anti-symmetric variant of ``symmetric``.

            ``'antireflect'``: Anti-symmetric variant of ``reflect``.

            For reference, the following table compares the naming
            conventions for the modes in ODL vs. PyWavelets::

                ======================= ==================
                ODL                     PyWavelets
                ======================= ==================
                symmetric               symmetric
                reflect                 reflect
                order1                  smooth
                order0                  constant
                constant, pad_const=0   zero
                periodic                periodic
                pywt_periodic           periodization
                antisymmetric           antisymmetric
                antireflect             antireflect
                ======================= ==================

            See `signal extension modes`_ for an illustration of the modes
            (under the PyWavelets naming conventions).
        pad_const : float, optional
            Constant value to use if ``pad_mode == 'constant'``. Ignored
            otherwise. Constants other than 0 are not supported by the
            ``pywt`` back-end.
        impl : {'pywt'}, optional
            Backend for the wavelet transform.
        axes : sequence of ints, optional
            Axes over which the DWT is performed. The default value of
            ``None`` corresponds to all axes. When not all axes are included
            this is analogous to a batch transform in ``len(axes)``
            dimensions looped over the non-transformed axes. In other words,
            filtering and decimation does not occur along any axes not in
            ``axes``.

        Examples
        --------
        Compute a very simple wavelet transform in a discrete 2D space with
        4 sampling points per axis:

        >>> space = odl.uniform_discr([0, 0], [1, 1], (4, 4))
        >>> wavelet_trafo = odl.trafos.WaveletTransform(
        ...     domain=space, nlevels=1, wavelet='haar')
        >>> wavelet_trafo.is_biorthogonal
        True
        >>> data = [[1, 1, 1, 1],
        ...         [0, 0, 0, 0],
        ...         [0, 0, 1, 1],
        ...         [1, 0, 1, 0]]
        >>> decomp = wavelet_trafo(data)
        >>> decomp.shape
        (16,)

        It is also possible to apply the transform only along a subset of
        the axes. Here, we apply a 1D wavelet transform along axis 0 for
        each index along axis 1:

        >>> wavelet_trafo = odl.trafos.WaveletTransform(
        ...     domain=space, nlevels=1, wavelet='haar', axes=(0,))
        >>> decomp = wavelet_trafo(data)
        >>> decomp.shape
        (16,)

        In general, the size of the coefficients may exceed the size of the
        input data when the wavelet is longer than the Haar wavelet. This is
        due to extra coefficients that must be kept for perfect
        reconstruction. No extra boundary coefficients are needed when the
        edge mode is ``"pywt_periodic"`` and the size along each transformed
        axis is a multiple of ``2**nlevels``.

        >>> space = odl.uniform_discr([0, 0], [1, 1], (16, 16))
        >>> space.size
        256
        >>> wavelet_trafo = odl.trafos.WaveletTransform(
        ...     domain=space, nlevels=2, wavelet='db2',
        ...     pad_mode='pywt_periodic')
        >>> decomp = wavelet_trafo(np.ones(space.shape))
        >>> decomp.shape
        (256,)
        >>> wavelet_trafo = odl.trafos.WaveletTransform(
        ...     domain=space, nlevels=2, wavelet='db2', pad_mode='symmetric')
        >>> decomp = wavelet_trafo(np.ones(space.shape))
        >>> decomp.shape
        (387,)

        References
        ----------
        .. _signal extension modes:
           https://pywavelets.readthedocs.io/en/latest/ref/\
signal-extension-modes.html
        """
        super(WaveletTransform, self).__init__(
            space=domain, wavelet=wavelet, nlevels=nlevels,
            variant='forward', pad_mode=pad_mode, pad_const=pad_const,
            impl=impl, axes=axes)

    def _call(self, x):
        """Return wavelet transform of ``x``."""
        if self.impl == 'pywt':
            coeffs = pywt.wavedecn(
                x, wavelet=self.pywt_wavelet, level=self.nlevels,
                mode=self.pywt_pad_mode, axes=self.axes)
            # Flatten the nested coefficient structure into one vector that
            # lives in ``self.range``.
            return pywt.ravel_coeffs(coeffs, axes=self.axes)[0]
        else:
            raise RuntimeError("bad `impl` '{}'".format(self.impl))

    @property
    def adjoint(self):
        """Adjoint wavelet transform.

        Returns
        -------
        adjoint : `Operator`
            If the transform is orthogonal, the adjoint is the inverse
            transform scaled by ``1 / cell_volume`` of the domain
            partition. The scaling compensates the cell-volume weighting
            of the domain inner product.

        Raises
        ------
        OpNotImplementedError
            if `is_orthogonal` is ``False``
        """
        if self.is_orthogonal:
            scale = 1 / self.domain.partition.cell_volume
            return scale * self.inverse
        else:
            # TODO: put adjoint here
            return super(WaveletTransform, self).adjoint

    @property
    def inverse(self):
        """Inverse wavelet transform.

        Returns
        -------
        inverse : `WaveletTransformInverse`

        See Also
        --------
        adjoint
        """
        return WaveletTransformInverse(
            range=self.domain, wavelet=self.pywt_wavelet, nlevels=self.nlevels,
            pad_mode=self.pad_mode, pad_const=self.pad_const, impl=self.impl,
            axes=self.axes)
class WaveletTransformInverse(WaveletTransformBase):

    """Discrete inverse wavelet trafo between discrete L2 spaces.

    See Also
    --------
    WaveletTransform
    """

    def __init__(self, range, wavelet, nlevels=None, pad_mode='constant',
                 pad_const=0, impl='pywt', axes=None):
        """Initialize a new instance.

        Parameters
        ----------
        range : `DiscretizedSpace`
            Domain of the forward wavelet transform (the "image domain"),
            which is the range of this inverse transform.
        wavelet : string or `pywt.Wavelet`
            Specification of the wavelet to be used in the transform.
            If a string is given, it is converted to a `pywt.Wavelet`.
            Use `pywt.wavelist` to get a list of available wavelets.

            Possible wavelet families are:

            ``'haar'``: Haar

            ``'db'``: Daubechies

            ``'sym'``: Symlets

            ``'coif'``: Coiflets

            ``'bior'``: Biorthogonal

            ``'rbio'``: Reverse biorthogonal

            ``'dmey'``: Discrete FIR approximation of the Meyer wavelet

        nlevels : positive int, optional
            Number of scaling levels to be used in the decomposition. The
            maximum number of levels can be calculated with
            `pywt.dwtn_max_level`.
            Default: Use maximum number of levels.
        pad_mode : string, optional
            Method to be used to extend the signal.

            ``'constant'``: Fill with ``pad_const``.

            ``'symmetric'``: Reflect at the boundaries, not repeating the
            outmost values.

            ``'periodic'``: Fill in values from the other side, keeping
            the order.

            ``'order0'``: Extend constantly with the outmost values
            (ensures continuity).

            ``'order1'``: Extend with constant slope (ensures continuity of
            the first derivative). This requires at least 2 values along
            each axis where padding is applied.

            ``'pywt_periodic'``: like ``'periodic'``-padding but gives the
            smallest possible number of decomposition coefficients.
            Only available with ``impl='pywt'``, see ``pywt.Modes.modes``.

            ``'reflect'``: Reflect at the boundary, without repeating the
            outmost values.

            ``'antisymmetric'``: Anti-symmetric variant of ``symmetric``.

            ``'antireflect'``: Anti-symmetric variant of ``reflect``.

            For reference, the following table compares the naming
            conventions for the modes in ODL vs. PyWavelets::

                ======================= ==================
                ODL                     PyWavelets
                ======================= ==================
                symmetric               symmetric
                reflect                 reflect
                order1                  smooth
                order0                  constant
                constant, pad_const=0   zero
                periodic                periodic
                pywt_periodic           periodization
                antisymmetric           antisymmetric
                antireflect             antireflect
                ======================= ==================

            See `signal extension modes`_ for an illustration of the modes
            (under the PyWavelets naming conventions).
        pad_const : float, optional
            Constant value to use if ``pad_mode == 'constant'``. Ignored
            otherwise. Constants other than 0 are not supported by the
            ``pywt`` back-end.
        impl : {'pywt'}, optional
            Back-end for the wavelet transform.
        axes : sequence of ints, optional
            Axes over which the DWT that created ``coeffs`` was performed.
            The default value of ``None`` corresponds to all axes. When not
            all axes are included this is analogous to a batch transform in
            ``len(axes)`` dimensions looped over the non-transformed axes.
            In other words, filtering and decimation does not occur along
            any axes not in ``axes``.

        Examples
        --------
        Check that the inverse is the actual inverse on a simple example on
        a discrete 2D space with 4 sampling points per axis:

        >>> space = odl.uniform_discr([0, 0], [1, 1], (4, 4))
        >>> wavelet_trafo = odl.trafos.WaveletTransform(
        ...     domain=space, nlevels=1, wavelet='haar')
        >>> orig_array = np.array([[1, 1, 1, 1],
        ...                        [0, 0, 0, 0],
        ...                        [0, 0, 1, 1],
        ...                        [1, 0, 1, 0]])
        >>> decomp = wavelet_trafo(orig_array)
        >>> recon = wavelet_trafo.inverse(decomp)
        >>> np.allclose(recon, orig_array)
        True

        References
        ----------
        .. _signal extension modes:
           https://pywavelets.readthedocs.io/en/latest/ref/\
signal-extension-modes.html
        """
        super(WaveletTransformInverse, self).__init__(
            space=range, wavelet=wavelet, variant='inverse', nlevels=nlevels,
            pad_mode=pad_mode, pad_const=pad_const, impl=impl, axes=axes)

    def _call(self, coeffs):
        """Return the inverse wavelet transform of ``coeffs``."""
        if self.impl == 'pywt':
            # Rebuild the nested pywt coefficient structure from the flat
            # vector using the slices/shapes precomputed in __init__.
            coeffs = pywt.unravel_coeffs(coeffs,
                                         coeff_slices=self._coeff_slices,
                                         coeff_shapes=self._coeff_shapes,
                                         output_format='wavedecn')
            recon = pywt.waverecn(
                coeffs, wavelet=self.pywt_wavelet, mode=self.pywt_pad_mode,
                axes=self.axes)
            recon_shape = self.range.shape
            if recon.shape != recon_shape:
                # If the original shape was odd along any transformed axes it
                # will have been rounded up to the next even size after the
                # reconstruction. The extra sample should be discarded.
                # The underlying reason is decimation by two in reconstruction
                # must keep ceil(N/2) samples in each band for perfect
                # reconstruction. Reconstruction then upsamples by two.
                # When N is odd, (2 * np.ceil(N/2)) != N.
                recon_slc = []
                for i, (n_recon, n_intended) in enumerate(zip(recon.shape,
                                                               recon_shape)):
                    if n_recon == n_intended + 1:
                        # Upsampling added one entry too much in this axis,
                        # drop last one
                        recon_slc.append(slice(-1))
                    elif n_recon == n_intended:
                        recon_slc.append(slice(None))
                    else:
                        raise ValueError(
                            'in axis {}: expected size {} or {} in '
                            '`recon_shape`, got {}'
                            ''.format(i, n_recon - 1, n_recon, n_intended))
                recon = recon[tuple(recon_slc)]
            return recon
        else:
            raise RuntimeError("bad `impl` '{}'".format(self.impl))

    @property
    def adjoint(self):
        """Adjoint of this operator.

        Returns
        -------
        adjoint : `Operator`
            If the transform is orthogonal, the adjoint is the forward
            transform scaled by ``cell_volume`` of the range partition.
            The scaling compensates the cell-volume weighting of the range
            inner product.

        Raises
        ------
        OpNotImplementedError
            if `is_orthogonal` is ``False``

        See Also
        --------
        inverse
        """
        if self.is_orthogonal:
            scale = self.range.partition.cell_volume
            return scale * self.inverse
        else:
            # TODO: put adjoint here
            return super(WaveletTransformInverse, self).adjoint

    @property
    def inverse(self):
        """Inverse of this operator.

        Returns
        -------
        inverse : `WaveletTransform`

        See Also
        --------
        adjoint
        """
        return WaveletTransform(
            domain=self.range, wavelet=self.pywt_wavelet, nlevels=self.nlevels,
            pad_mode=self.pad_mode, pad_const=self.pad_const, impl=self.impl,
            axes=self.axes)
if __name__ == '__main__':
    # Execute this module's doctests; they all need PyWavelets, so they are
    # skipped when the bindings report it as unavailable.
    from odl.util.testutils import run_doctests as _run_doctests
    _skip_doctests = not PYWT_AVAILABLE
    _run_doctests(skip_if=_skip_doctests)