Source code for odl.util.testutils

# Copyright 2014-2019 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Testing utilities."""

from __future__ import absolute_import, division, print_function

import os
import sys
import warnings
from builtins import object
from contextlib import contextmanager
from time import time

import numpy as np
from future.moves.itertools import zip_longest

from odl.util.utility import is_string, run_from_ipython

__all__ = (
    'dtype_ndigits',
    'dtype_tol',
    'all_equal',
    'all_almost_equal',
    'is_subdict',
    'skip_if_no_pyfftw',
    'skip_if_no_pywavelets',
    'simple_fixture',
    'noise_array',
    'noise_element',
    'noise_elements',
    'fail_counter',
    'timer',
    'timeit',
    'ProgressBar',
    'ProgressRange',
    'test',
    'run_doctests',
    'test_file',
)


def _ndigits(a, b, default=None):
    """Return number of expected correct digits comparing ``a`` and ``b``.

    The returned number is the minimum `dtype_ndigits` of the two objects.

    See Also
    --------
    dtype_ndigits
    """
    dtype1 = getattr(a, 'dtype', object)
    dtype2 = getattr(b, 'dtype', object)
    return min(dtype_ndigits(dtype1, default), dtype_ndigits(dtype2, default))


[docs]def dtype_ndigits(dtype, default=None): """Return the number of correct digits expected for a given dtype. This is intended as a somewhat generous default (relative) precision for results of more or less stable computations. Returned numbers: - ``np.float16``: ``1`` - ``np.float32`` or ``np.complex64``: ``3`` - Others: ``default`` if given, otherwise ``5`` See Also -------- dtype_tol : Same precision expressed as tolerance """ small_dtypes = [np.float32, np.complex64] tiny_dtypes = [np.float16] if dtype in tiny_dtypes: return 1 elif dtype in small_dtypes: return 3 else: return default if default is not None else 5
[docs]def dtype_tol(dtype, default=None): """Return a tolerance for a given dtype. This is intended as a somewhat generous default (relative) tolerance for results of more or less stable computations. Returned numbers: - ``np.float16``: ``1e-1`` - ``np.float32`` or ``np.complex64``: ``1e-3`` - Others: ``default`` if given, otherwise ``1e-5`` See Also -------- dtype_ndigits : Same tolerance expressed in number of digits. """ return 10 ** -dtype_ndigits(dtype, default)
[docs]def all_equal(iter1, iter2): """Return ``True`` if all elements in ``a`` and ``b`` are equal.""" # Direct comparison for scalars, tuples or lists try: if iter1 == iter2: return True except ValueError: # Raised by NumPy when comparing arrays pass # Special case for None if iter1 is None and iter2 is None: return True # If one nested iterator is exhausted, go to direct comparison try: it1 = iter(iter1) it2 = iter(iter2) except TypeError: try: return iter1 == iter2 except ValueError: # Raised by NumPy when comparing arrays return False diff_length_sentinel = object() # Compare element by element and return False if the sequences have # different lengths for [ip1, ip2] in zip_longest(it1, it2, fillvalue=diff_length_sentinel): # Verify that none of the lists has ended (then they are not the # same size) if ip1 is diff_length_sentinel or ip2 is diff_length_sentinel: return False if not all_equal(ip1, ip2): return False return True
[docs]def all_almost_equal_array(v1, v2, ndigits): return np.allclose(v1, v2, rtol=10 ** -ndigits, atol=10 ** -ndigits, equal_nan=True)
[docs]def all_almost_equal(iter1, iter2, ndigits=None): """Return ``True`` if all elements in ``a`` and ``b`` are almost equal.""" try: if iter1 is iter2 or iter1 == iter2: return True except ValueError: pass if iter1 is None and iter2 is None: return True if hasattr(iter1, '__array__') and hasattr(iter2, '__array__'): # Only get default ndigits if comparing arrays, need to keep `None` # otherwise for recursive calls. if ndigits is None: ndigits = _ndigits(iter1, iter2, None) return all_almost_equal_array(iter1, iter2, ndigits) try: it1 = iter(iter1) it2 = iter(iter2) except TypeError: if ndigits is None: ndigits = _ndigits(iter1, iter2, None) return np.isclose(iter1, iter2, atol=10 ** -ndigits, rtol=10 ** -ndigits, equal_nan=True) diff_length_sentinel = object() for [ip1, ip2] in zip_longest(it1, it2, fillvalue=diff_length_sentinel): # Verify that none of the lists has ended (then they are not the # same size) if ip1 is diff_length_sentinel or ip2 is diff_length_sentinel: return False if not all_almost_equal(ip1, ip2, ndigits): return False return True
[docs]def is_subdict(subdict, dictionary): """Return ``True`` if all items of ``subdict`` are in ``dictionary``.""" return all(item in dictionary.items() for item in subdict.items())
try: import pytest except ImportError: def identity(*args, **kwargs): if args and callable(args[0]): return args[0] else: return identity skip_if_no_pyfftw = identity skip_if_no_pywavelets = identity else: # Mark decorators for test parameters skip_if_no_pyfftw = pytest.mark.skipif( 'not odl.trafos.PYFFTW_AVAILABLE', reason='pyFFTW not available', ) skip_if_no_pywavelets = pytest.mark.skipif( 'not odl.trafos.PYWT_AVAILABLE', reason='PyWavelets not available', )
[docs]def simple_fixture(name, params, fmt=None): """Helper to create a pytest fixture using only name and params. Parameters ---------- name : str Name of the parameters used for the ``ids`` argument to `pytest.fixture`. params : sequence Values to be taken as parameters in the fixture. They are used as ``params`` argument to `_pytest.fixtures.fixture`. Arguments wrapped in a ``pytest.skipif`` decorator are unwrapped for the generation of the test IDs. fmt : str, optional Use this format string for the generation of the ``ids``. For each value, the id string is generated as :: fmt.format(name=name, value=value) hence the format string must use ``{name}`` and ``{value}``. Default format strings are: - ``" {name}='{value}' "`` for string parameters, - ``" {name}={value} "`` for other types. """ import _pytest if fmt is None: # Use some intelligence to make good format strings fmt_str = " {name}='{value}' " fmt_default = " {name}={value} " ids = [] for p in params: # TODO: other types of decorators? if ( isinstance(p, _pytest.mark.MarkDecorator) and p.name == 'skipif' ): # Unwrap the wrapped object in the decorator if is_string(p.args[1]): ids.append(fmt_str.format(name=name, value=p.args[1])) else: ids.append(fmt_default.format(name=name, value=p.args[1])) else: if is_string(p): ids.append(fmt_str.format(name=name, value=p)) else: ids.append(fmt_default.format(name=name, value=p)) else: # Use provided `fmt` for everything ids = [fmt.format(name=name, value=p) for p in params] wrapper = pytest.fixture(scope='module', ids=ids, params=params) return wrapper(lambda request: request.param)
# Helpers to generate data
[docs]def noise_array(space): """Generate a white noise array that is compatible with ``space``. The array contains white noise with standard deviation 1 in the case of floating point dtypes and uniformly spaced values between -10 and 10 in the case of integer dtypes. For product spaces the method is called recursively for all sub-spaces. Notes ----- This method is intended for internal testing purposes. For more explicit example elements see ``odl.phantoms`` and ``LinearSpaceElement.examples``. Parameters ---------- space : `LinearSpace` Space from which to derive the array data type and size. Returns ------- noise_array : `numpy.ndarray` element Array with white noise such that ``space.element``'s can be created from it. Examples -------- Create single noise array: >>> space = odl.rn(3) >>> array = noise_array(space) See Also -------- noise_element noise_elements odl.set.space.LinearSpace.examples : Examples of elements typical to the space. """ from odl.space import ProductSpace if isinstance(space, ProductSpace): return np.array([noise_array(si) for si in space]) else: if space.dtype == bool: arr = np.random.randint(0, 2, size=space.shape, dtype=bool) elif np.issubdtype(space.dtype, np.unsignedinteger): arr = np.random.randint(0, 10, space.shape) elif np.issubdtype(space.dtype, np.signedinteger): arr = np.random.randint(-10, 10, space.shape) elif np.issubdtype(space.dtype, np.floating): arr = np.random.randn(*space.shape) elif np.issubdtype(space.dtype, np.complexfloating): arr = ( np.random.randn(*space.shape) + 1j * np.random.randn(*space.shape) ) / np.sqrt(2.0) else: raise ValueError('bad dtype {}'.format(space.dtype)) return arr.astype(space.dtype, copy=False)
[docs]def noise_element(space): """Create a white noise element in ``space``. The element contains white noise with standard deviation 1 in the case of floating point dtypes and uniformly spaced values between -10 and 10 in the case of integer dtypes. For product spaces the method is called recursively for all sub-spaces. Notes ----- This method is intended for internal testing purposes. For more explicit example elements see ``odl.phantoms`` and ``LinearSpaceElement.examples``. Parameters ---------- space : `LinearSpace` Space in which to create an element. The `odl.set.space.LinearSpace.element` method of the space needs to accept input of `numpy.ndarray` type. Returns ------- noise_element : ``space`` element Examples -------- Create single noise element: >>> space = odl.rn(3) >>> vector = noise_element(space) See Also -------- noise_array noise_elements odl.set.space.LinearSpace.examples : Examples of elements typical to the space. """ return space.element(noise_array(space))
[docs]def noise_elements(space, n=1): """Create a list of ``n`` noise arrays and elements in ``space``. The arrays contain white noise with standard deviation 1 in the case of floating point dtypes and uniformly spaced values between -10 and 10 in the case of integer dtypes. The returned elements have the same values as the arrays. For product spaces the method is called recursively for all sub-spaces. Notes ----- This method is intended for internal testing purposes. For more explicit example elements see ``odl.phantoms`` and ``LinearSpaceElement.examples``. Parameters ---------- space : `LinearSpace` Space in which to create an element. The `odl.set.space.LinearSpace.element` method of the space needs to accept input of `numpy.ndarray` type. n : int, optional Number of elements to create. Returns ------- arrays : `numpy.ndarray` or tuple of `numpy.ndarray` A single array if ``n == 1``, otherwise a tuple of arrays. elements : ``space`` element or tuple of ``space`` elements A single element if ``n == 1``, otherwise a tuple of elements. Examples -------- Create single noise element: >>> space = odl.rn(3) >>> arr, vector = noise_elements(space) Create multiple noise elements: >>> [arr1, arr2], [vector1, vector2] = noise_elements(space, n=2) See Also -------- noise_array noise_element """ arrs = tuple(noise_array(space) for _ in range(n)) # Make space elements from arrays elems = tuple(space.element(arr.copy()) for arr in arrs) if n == 1: return tuple(arrs + elems) else: return arrs, elems
[docs]@contextmanager def fail_counter(test_name, err_msg=None, logger=print): """Used to count the number of failures of something. Usage:: with fail_counter("my_test") as counter: # Do stuff counter.fail() When done, it prints :: my_test *** FAILED 1 TEST CASE(S) *** """ class _FailCounter(object): def __init__(self): self.num_failed = 0 self.fail_strings = [] def fail(self, string=None): """Add failure with reason as string.""" # TODO: possibly limit number of printed strings self.num_failed += 1 if string is not None: self.fail_strings.append(str(string)) try: counter = _FailCounter() yield counter finally: if counter.num_failed == 0: logger('{:<70}: Completed all test cases.'.format(test_name)) else: print(test_name) for fail_string in counter.fail_strings: print(fail_string) if err_msg is not None: print(err_msg) print('*** FAILED {} TEST CASE(S) ***'.format(counter.num_failed))
[docs]@contextmanager def timer(name=None): """A timer context manager. Usage:: with timer('name'): # Do stuff Prints the time stuff took to execute. """ if name is None: name = "Elapsed" try: tstart = time() yield finally: time_str = '{:.3f}'.format(time() - tstart) print('{:>30s} : {:>10s} '.format(name, time_str))
[docs]def timeit(arg): """A timer decorator. Usage:: @timeit def myfunction(...): ... @timeit('info string') def myfunction(...): ... """ if callable(arg): def timed_function(*args, **kwargs): with timer(str(arg)): return arg(*args, **kwargs) return timed_function else: def _timeit_helper(func): def timed_function(*args, **kwargs): with timer(arg): return func(*args, **kwargs) return timed_function return _timeit_helper
[docs]class ProgressBar(object): """A simple command-line progress bar. Usage: >>> progress = ProgressBar('Reading data', 10) \rReading data: [ ] Starting >>> progress.update(4) #halfway, zero indexing \rReading data: [############### ] 50.0% Multi-indices, from slowest to fastest: >>> progress = ProgressBar('Reading data', 10, 10) \rReading data: [ ] Starting >>> progress.update(9, 8) \rReading data: [############################# ] 99.0% Supports simply calling update, which moves the counter forward: >>> progress = ProgressBar('Reading data', 10, 10) \rReading data: [ ] Starting >>> progress.update() \rReading data: [ ] 1.0% """
[docs] def __init__(self, text='progress', *njobs): """Initialize a new instance.""" self.text = str(text) if len(njobs) == 0: raise ValueError('need to provide at least one job') self.njobs = njobs self.current_progress = 0.0 self.index = 0 self.done = False self.start()
[docs] def start(self): """Print the initial bar.""" sys.stdout.write('\r{0}: [{1:30s}] Starting'.format(self.text, ' ' * 30)) sys.stdout.flush()
[docs] def update(self, *indices): """Update the bar according to ``indices``.""" if indices: if len(indices) != len(self.njobs): raise ValueError('number of indices not correct') self.index = np.ravel_multi_index(indices, self.njobs) + 1 else: self.index += 1 # Find progress as ratio between 0 and 1 # offset by 1 for zero indexing progress = self.index / np.prod(self.njobs) # Write a progressbar and percent if progress < 1.0: # Only update on 0.1% intervals if progress > self.current_progress + 0.001: sys.stdout.write('\r{0}: [{1:30s}] {2:4.1f}% '.format( self.text, '#' * int(30 * progress), 100 * progress)) self.current_progress = progress else: # Special message when done if not self.done: sys.stdout.write('\r{0}: [{1:30s}] Done \n'.format( self.text, '#' * 30)) self.done = True sys.stdout.flush()
[docs]class ProgressRange(object): """Simple range sequence with progress bar output"""
[docs] def __init__(self, text, n): """Initialize a new instance.""" self.current = 0 self.n = n self.bar = ProgressBar(text, n)
def __iter__(self): return self def __next__(self): if self.current < self.n: val = self.current self.current += 1 self.bar.update() return val else: raise StopIteration()
[docs]def test(arguments=None): """Run ODL tests given by arguments.""" try: import pytest except ImportError: raise ImportError( 'ODL tests cannot be run without `pytest` installed.\n' 'Run `$ pip install [--user] odl[testing]` in order to install ' '`pytest`.' ) from .pytest_config import collect_ignore this_dir = os.path.dirname(__file__) odl_root = os.path.abspath(os.path.join(this_dir, os.pardir, os.pardir)) args = ['{root}/odl'.format(root=odl_root)] ignores = ['--ignore={}'.format(file) for file in collect_ignore] args.extend(ignores) if arguments is not None: args.extend(arguments) pytest.main(args)
[docs]def run_doctests(skip_if=False, **kwargs): """Run all doctests in the current module. This function calls ``doctest.testmod()``, by default with the options ``optionflags=doctest.NORMALIZE_WHITESPACE`` and ``extraglobs={'odl': odl, 'np': np}``. This can be changed with keyword arguments. Parameters ---------- skip_if : bool For ``True``, skip the doctests in this module. kwargs : Extra keyword arguments passed on to the ``doctest.testmod`` function. """ from doctest import testmod, NORMALIZE_WHITESPACE, SKIP from packaging.version import parse as parse_version import odl import numpy as np optionflags = kwargs.pop('optionflags', NORMALIZE_WHITESPACE) if skip_if: optionflags |= SKIP extraglobs = kwargs.pop('extraglobs', {'odl': odl, 'np': np}) if run_from_ipython(): try: import spyder except ImportError: pass else: if parse_version(spyder.__version__) < parse_version('3.1.4'): warnings.warn('A bug with IPython and Spyder < 3.1.4 ' 'sometimes causes doctests to fail to run. ' 'Please upgrade Spyder or use another ' 'interpreter if the doctests do not work.', RuntimeWarning) testmod(optionflags=optionflags, extraglobs=extraglobs, **kwargs)
[docs]def test_file(file, args=None): """Run tests in file with proper default arguments.""" try: import pytest except ImportError: raise ImportError('ODL tests cannot be run without `pytest` installed.' '\nRun `$ pip install [--user] odl[testing]` in ' 'order to install `pytest`.') if args is None: args = [] args.extend([str(file.replace('\\', '/')), '-v', '--capture=sys']) pytest.main(args)
if __name__ == '__main__': run_doctests()