# Source code for odl.diagnostics.operator

# Copyright 2014-2019 The ODL contributors
# This file is part of ODL.
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Standardized tests for `Operator` implementations."""

from __future__ import absolute_import, division, print_function

from builtins import object

import numpy as np

from odl.diagnostics.examples import samples
from odl.operator import power_method_opnorm
from odl.util.testutils import fail_counter

__all__ = ('OperatorTest',)

class OperatorTest(object):

    """Automated tests for `Operator` implementations.

    This class allows users to automatically test various features of
    an Operator such as linearity, the adjoint definition and
    definition of the derivative.
    """

    def __init__(self, operator, operator_norm=None, verbose=True, tol=1e-5):
        """Initialize a new instance.

        Parameters
        ----------
        operator : `Operator`
            The operator to run tests on
        operator_norm : float, optional
            The norm of the operator, used for error estimates. If
            ``None`` is given, the norm is estimated during
            initialization.
        verbose : bool, optional
            If ``True``, print additional info text.
        tol : float, optional
            Tolerance parameter used as a base for the actual tolerance
            in the tests. Depending on the expected accuracy, the actual
            tolerance used in a test can be a factor times this number.
        """
        self.operator = operator
        # Silence logging while the initial norm estimate is computed;
        # the requested verbosity is only applied afterwards.
        self.verbose = False
        if operator_norm is None:
            self.operator_norm = self.norm()
        else:
            self.operator_norm = float(operator_norm)

        self.verbose = bool(verbose)
        self.tol = float(tol)

    def log(self, message):
        """Print message if ``self.verbose == True``."""
        if self.verbose:
            print(message)

    def norm(self):
        """Estimate the operator norm of the operator.

        For each sample element ``x`` of the domain (except the one named
        ``'Zero'``), two iterations of the power method
        (`power_method_opnorm`) are run with ``x`` as starting point. The
        largest of these estimates is returned and also stored in
        ``self.operator_norm``.

        Returns
        -------
        norm : float
            Estimate of operator norm

        References
        ----------
        Wikipedia article on `Operator norm
        <https://en.wikipedia.org/wiki/Operator_norm>`_.
        """
        self.log('\n== Calculating operator norm ==\n')

        # The zero element is skipped since the power iteration cannot be
        # started from the zero vector.
        operator_norm = max(
            power_method_opnorm(self.operator, maxiter=2, xstart=x)
            for name, x in samples(self.operator.domain)
            if name != 'Zero')

        self.log('Norm is at least: {}'.format(operator_norm))
        self.operator_norm = operator_norm
        return operator_norm

    def self_adjoint(self):
        """Verify ``<Ax, y> == <x, Ay>``.

        For all pairs of samples ``x`` (domain) and ``y`` (range), the
        relative error ``|<Ax, y> - <x, Ay>| / (||A|| ||x|| ||y||)`` is
        compared against ``self.tol``. Afterwards, the slope of a linear
        fit of ``<x, Ay>`` against ``<Ax, y>`` is logged; it should be 1.
        """
        left_inner_vals = []
        right_inner_vals = []

        with fail_counter(
            test_name='Verifying the identity <Ax, y> = <x, Ay>',
            err_msg='error = |<Ax, y> - <x, Ay>| / ||A|| ||x|| ||y||',
            logger=self.log
        ) as counter:

            for [name_x, x], [name_y, y] in samples(self.operator.domain,
                                                    self.operator.range):
                x_norm = x.norm()
                y_norm = y.norm()

                l_inner = self.operator(x).inner(y)
                r_inner = x.inner(self.operator(y))

                denom = self.operator_norm * x_norm * y_norm
                error = 0 if denom == 0 else abs(l_inner - r_inner) / denom

                if error > self.tol:
                    counter.fail('x={:25s} y={:25s} : error={:6.5f}'
                                 ''.format(name_x, name_y, error))

                left_inner_vals.append(l_inner)
                right_inner_vals.append(r_inner)

        # Slope of the best linear fit reveals a constant scaling mismatch
        scale = np.polyfit(left_inner_vals, right_inner_vals, 1)[0]
        self.log('\nThe adjoint seems to be scaled according to:')
        self.log('(x, Ay) / (Ax, y) = {}. Should be 1.0'.format(scale))

    def _adjoint_definition(self):
        """Verify ``<Ax, y> == <x, A^* y>``.

        Works like `self_adjoint`, but with ``A.adjoint`` applied to
        ``y`` on the right-hand side.
        """
        adjoint = self.operator.adjoint
        left_inner_vals = []
        right_inner_vals = []

        with fail_counter(
            test_name='Verifying the identity <Ax, y> = <x, A^T y>',
            err_msg='error = |<Ax, y> - <x, A^* y>| / ||A|| ||x|| ||y||',
            logger=self.log
        ) as counter:

            for [name_x, x], [name_y, y] in samples(self.operator.domain,
                                                    self.operator.range):
                x_norm = x.norm()
                y_norm = y.norm()

                l_inner = self.operator(x).inner(y)
                r_inner = x.inner(adjoint(y))

                denom = self.operator_norm * x_norm * y_norm
                error = 0 if denom == 0 else abs(l_inner - r_inner) / denom

                if error > self.tol:
                    counter.fail('x={:25s} y={:25s} : error={:6.5f}'
                                 ''.format(name_x, name_y, error))

                left_inner_vals.append(l_inner)
                right_inner_vals.append(r_inner)

        # Slope of the best linear fit reveals a constant scaling mismatch
        scale = np.polyfit(left_inner_vals, right_inner_vals, 1)[0]
        self.log('\nThe adjoint seems to be scaled according to:')
        self.log('(x, A^T y) / (Ax, y) = {}. Should be 1.0'.format(scale))

    def _adjoint_of_adjoint(self):
        """Verify ``(A^*)^* == A``.

        If ``A.adjoint.adjoint`` is the very same object as ``A``, the
        check passes trivially. Otherwise ``Ax`` and ``(A^*)^* x`` are
        compared for all domain samples ``x``.
        """
        try:
            adj_adj = self.operator.adjoint.adjoint
        except (AttributeError, NotImplementedError):
            # A missing adjoint is signalled by NotImplementedError (as
            # handled in `adjoint`); AttributeError is kept for operators
            # lacking the attribute altogether.
            print('A^* has no adjoint')
            return

        if adj_adj is self.operator:
            self.log('(A^*)^* == A')
            return

        with fail_counter(
            test_name='\nVerifying the identity Ax = (A^*)^* x',
            err_msg='error = ||Ax - (A^*)^* x|| / ||A|| ||x||',
            logger=self.log
        ) as counter:
            for [name_x, x] in self.operator.domain.examples:
                opx = self.operator(x)
                op_adj_adj_x = adj_adj(x)

                denom = self.operator_norm * x.norm()
                if denom == 0:
                    error = 0
                else:
                    error = (opx - op_adj_adj_x).norm() / denom

                if error > self.tol:
                    counter.fail('x={:25s} : error={:6.5f}'
                                 ''.format(name_x, error))

    def adjoint(self):
        """Verify that `Operator.adjoint` works appropriately.

        Checks that domain and range of the adjoint match range and
        domain of the operator, then runs the adjoint-definition and
        adjoint-of-adjoint tests.

        References
        ----------
        Wikipedia article on `Adjoint
        <https://en.wikipedia.org/wiki/Hermitian_adjoint>`_.
        """
        try:
            adjoint = self.operator.adjoint
        except NotImplementedError:
            print('Operator has no adjoint')
            return

        self.log('\n== Verifying operator adjoint ==\n')

        domain_range_ok = True
        if self.operator.domain != adjoint.range:
            print('*** ERROR: A.domain != A.adjoint.range ***')
            domain_range_ok = False

        if self.operator.range != adjoint.domain:
            print('*** ERROR: A.range != A.adjoint.domain ***')
            domain_range_ok = False

        if domain_range_ok:
            self.log('Domain and range of adjoint are OK.')
        else:
            print('Domain and range of adjoint are not OK, exiting.')
            return

        self._adjoint_definition()
        self._adjoint_of_adjoint()

    def _derivative_convergence(self):
        """Verify that the derivative is a first-order approximation.

        For all pairs of domain samples ``x`` and ``p``, the quotient
        ``||A(x+c*p) - A(x) - A'(x)(c*p)|| / c`` is evaluated for the
        step sizes ``c = 1e-4, 1e-5, ...`` down to about ``1e-14``. The
        pair passes as soon as the quotient drops below ``10 * self.tol``,
        i.e., the remainder behaves like ``o(c)`` for ``c --> 0``.
        Otherwise the smallest observed quotient is reported.
        """
        with fail_counter(
            test_name='Verifying that derivative is a first-order '
                      'approximation',
            err_msg="error = inf_c ||A(x+c*p)-A(x)-A'(x)(c*p)|| / c",
            logger=self.log
        ) as counter:
            for [name_x, x], [name_dx, dx] in samples(self.operator.domain,
                                                      self.operator.domain):
                # Precompute some values
                deriv = self.operator.derivative(x)
                derivdx = deriv(dx)
                opx = self.operator(x)

                c = 1e-4  # initial step
                derivative_ok = False

                minerror = float('inf')
                while c > 1e-14:
                    exact_step = self.operator(x + dx * c) - opx
                    expected_step = c * derivdx
                    err = (exact_step - expected_step).norm() / c

                    # Need to be slightly more generous here due to possible
                    # numerical instabilities.
                    # TODO: perform more tests to find a good threshold here.
                    if err < 10 * self.tol:
                        derivative_ok = True
                        break
                    else:
                        minerror = min(minerror, err)

                    c /= 10.0

                if not derivative_ok:
                    counter.fail('x={:15s} p={:15s}, error={}'
                                 ''.format(name_x, name_dx, minerror))

    def derivative(self):
        """Verify that `Operator.derivative` works appropriately.

        The derivative at the zero element must be a linear operator. If
        the operator is linear and returns itself as derivative, nothing
        else is checked. Otherwise it is verified that
        ``||A(x+c*p) - A(x) - A'(x)(c*p)|| = o(c)`` for ``c --> 0`` using
        a selection of elements ``x`` and ``p``.

        References
        ----------
        Wikipedia article on `Derivative
        <https://en.wikipedia.org/wiki/Derivative>`_.
        Wikipedia article on `Frechet derivative
        <https://en.wikipedia.org/wiki/Fr%C3%A9chet_derivative>`_.
        """
        self.log('\n== Verifying operator derivative ==')

        try:
            deriv = self.operator.derivative(self.operator.domain.zero())

            if not deriv.is_linear:
                print('Derivative is not a linear operator')
                return
        except NotImplementedError:
            print('Operator has no derivative')
            return

        # A linear operator returning itself is trivially its own derivative
        if self.operator.is_linear and deriv is self.operator:
            self.log('A is linear and A.derivative is A')
            return

        self._derivative_convergence()

    def _scale_invariance(self):
        """Verify ``A(c*x) = c * A(x)``.

        The error is normalized by ``|c| ||A|| ||x||``. The absolute value
        keeps the error non-negative for negative scalars and real-valued
        (hence comparable to ``self.tol``) for complex scalars.
        """
        with fail_counter(
            test_name='Verifying homogeneity under scalar multiplication',
            err_msg='error = ||A(c*x)-c*A(x)|| / |c| ||A|| ||x||',
            logger=self.log
        ) as counter:
            for [name_x, x], [_, scale] in samples(self.operator.domain,
                                                   self.operator.domain.field):
                opx = self.operator(x)
                scaled_opx = self.operator(scale * x)

                denom = self.operator_norm * abs(scale) * x.norm()
                error = (0 if denom == 0
                         else (scaled_opx - opx * scale).norm() / denom)

                if error > self.tol:
                    counter.fail('x={:25s} scale={:7.2f} error={:6.5f}'
                                 ''.format(name_x, scale, error))

    def _addition_invariance(self):
        """Verify ``A(x+y) = A(x) + A(y)``.

        The error is normalized by ``||A|| (||x|| + ||y||)``.
        """
        with fail_counter(
            test_name='Verifying distributivity under vector addition',
            err_msg='error = ||A(x+y) - A(x) - A(y)|| / '
                    '||A||(||x|| + ||y||)',
            logger=self.log
        ) as counter:
            for [name_x, x], [name_y, y] in samples(self.operator.domain,
                                                    self.operator.domain):
                opx = self.operator(x)
                opy = self.operator(y)
                opxy = self.operator(x + y)

                denom = self.operator_norm * (x.norm() + y.norm())
                error = (0 if denom == 0
                         else (opxy - opx - opy).norm() / denom)

                if error > self.tol:
                    counter.fail('x={:25s} y={:25s} error={:6.5f}'
                                 ''.format(name_x, name_y, error))

    def linear(self):
        """Verify that the operator is actually linear.

        Checks ``A(0) == 0``, homogeneity and additivity. Operators not
        flagged as linear are skipped.
        """
        if not self.operator.is_linear:
            print('Operator is not linear')
            return

        self.log('\n== Verifying operator linearity ==\n')

        # Test if zero gives zero
        result = self.operator(self.operator.domain.zero())
        result_norm = result.norm()
        if result_norm != 0.0:
            print("||A(0)||={:6.5f}. Should be 0.0000".format(result_norm))

        self._scale_invariance()
        self._addition_invariance()

    def run_tests(self):
        """Run all tests on this operator.

        Linear operators are tested for linearity and their adjoint,
        nonlinear ones for their derivative.
        """
        print('\n== RUNNING ALL TESTS ==')
        print('Operator = {}'.format(self.operator))

        self.norm()

        if self.operator.is_linear:
            self.linear()
            self.adjoint()
        else:
            self.derivative()

    def __str__(self):
        """Return ``'OperatorTest(<str of operator>)'``."""
        return '{}({})'.format(self.__class__.__name__, self.operator)

    def __repr__(self):
        """Return ``'OperatorTest(<repr of operator>)'``."""
        return '{}({!r})'.format(self.__class__.__name__, self.operator)
if __name__ == '__main__':
    # Small demo: run the full test suite on a linear and a nonlinear
    # operator on a 3x3 uniformly discretized unit square.
    import odl
    space = odl.uniform_discr([0, 0], [1, 1], [3, 3])

    # Linear operator
    identity = odl.IdentityOperator(space)
    OperatorTest(identity, verbose=False).run_tests()

    # Nonlinear operator op(x) = x**4
    op = odl.PowerOperator(space, 4)
    OperatorTest(op).run_tests()