Source code for odl.util.utility

# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Utilities mainly for internal use."""

from __future__ import absolute_import, division, print_function
from future.moves.itertools import zip_longest

import contextlib
from collections import OrderedDict
from contextlib import contextmanager
from itertools import product

import numpy as np

# Public API of this module, i.e., the names exported by a star import.
# Names not listed here (e.g. those with a leading underscore) are
# considered private.
__all__ = (
    'REPR_PRECISION',
    'indent',
    'dedent',
    'npy_printoptions',
    'array_str',
    'dtype_repr',
    'dtype_str',
    'cache_arguments',
    'is_numeric_dtype',
    'is_int_dtype',
    'is_floating_dtype',
    'is_real_dtype',
    'is_real_floating_dtype',
    'is_complex_floating_dtype',
    'real_dtype',
    'complex_dtype',
    'is_string',
    'nd_iterator',
    'conj_exponent',
    'nullcontext',
    'writable_array',
    'signature_string',
    'signature_string_parts',
    'repr_string',
    'attribute_repr_string',
    'method_repr_string',
    'run_from_ipython',
    'npy_random_seed',
    'unique',
)


REPR_PRECISION = 4  # For printing scalars and array entries

# Map from real floating point dtypes to their complex counterparts.
# The scalar types are listed explicitly since `np.sctypes` was removed
# in NumPy 2.0. On platforms where `longdouble` is the same as `float64`,
# the extra entry is merely redundant.
TYPE_MAP_R2C = {np.dtype(dtype): np.result_type(dtype, 1j)
                for dtype in (np.float16, np.float32, np.float64,
                              np.longdouble)}

# Map from complex dtypes to their real counterparts, determined from the
# dtype of the real part of an (empty) array with the complex dtype.
TYPE_MAP_C2R = {cdt: np.empty(0, dtype=cdt).real.dtype
                for cdt in TYPE_MAP_R2C.values()}
# Real dtypes are their own real counterparts
TYPE_MAP_C2R.update({rdt: rdt for rdt in TYPE_MAP_R2C})


def indent(string, indent_str='    '):
    """Return a copy of ``string`` indented by ``indent_str``.

    Parameters
    ----------
    string : str
        Text that should be indented.
    indent_str : str, optional
        String to be inserted before each new line. The default is to
        indent by 4 spaces.

    Returns
    -------
    indented : str
        The indented text. Lines are joined with ``'\\n'``, hence a
        trailing line break in ``string`` is not preserved.

    Examples
    --------
    >>> text = '''This is line 1.
    ... Next line.
    ... And another one.'''
    >>> print(text)
    This is line 1.
    Next line.
    And another one.
    >>> print(indent(text))
        This is line 1.
        Next line.
        And another one.

    Indenting by random stuff:

    >>> print(indent(text, indent_str='<->'))
    <->This is line 1.
    <->Next line.
    <->And another one.
    """
    return '\n'.join(indent_str + row for row in string.splitlines())
def dedent(string, indent_str=' ', max_levels=None):
    """Revert the effect of indentation.

    The common (minimum) number of leading repetitions of ``indent_str``
    over all lines is removed from every line.

    Parameters
    ----------
    string : str
        Text that should be dedented.
    indent_str : str, optional
        String that makes up one level of indentation. For an empty
        string, ``string`` is returned unchanged.
    max_levels : int, optional
        Maximum number of indentation levels to remove. For ``None``,
        all common levels are removed.

    Returns
    -------
    dedented : str
        The dedented text. Lines are joined with ``'\\n'``, hence a
        trailing line break in ``string`` is not preserved.

    Examples
    --------
    Remove a simple one-level indentation:

    >>> text = '''<->This is line 1.
    ... <->Next line.
    ... <->And another one.'''
    >>> print(text)
    <->This is line 1.
    <->Next line.
    <->And another one.
    >>> print(dedent(text, '<->'))
    This is line 1.
    Next line.
    And another one.

    Multiple levels of indentation:

    >>> text = '''<->Level 1.
    ... <-><->Level 2.
    ... <-><-><->Level 3.'''
    >>> print(text)
    <->Level 1.
    <-><->Level 2.
    <-><-><->Level 3.
    >>> print(dedent(text, '<->'))
    Level 1.
    <->Level 2.
    <-><->Level 3.

    >>> text = '''<-><->Level 2.
    ... <-><-><->Level 3.'''
    >>> print(text)
    <-><->Level 2.
    <-><-><->Level 3.
    >>> print(dedent(text, '<->'))
    Level 2.
    <->Level 3.
    >>> print(dedent(text, '<->', max_levels=1))
    <->Level 2.
    <-><->Level 3.
    """
    indent_len = len(indent_str)
    if indent_len == 0:
        return string

    lines = string.splitlines()
    if not lines:
        # Empty input: nothing to dedent, and `min` below would raise
        return string

    def num_indents(line):
        """Return the number of leading repetitions of ``indent_str``."""
        count = 0
        # Terminates since `indent_str` is non-empty and `line` is finite.
        # Also counts correctly for lines consisting only of indentation.
        while line.startswith(indent_str, count * indent_len):
            count += 1
        return count

    # Determine common (minimum) number of indentation levels, capped at
    # `max_levels` if given
    num_levels = min(num_indents(line) for line in lines)
    if max_levels is not None:
        num_levels = min(num_levels, max_levels)

    # Dedent
    dedent_len = num_levels * indent_len
    return '\n'.join(line[dedent_len:] for line in lines)
@contextmanager
def npy_printoptions(**extra_opts):
    """Context manager to temporarily set NumPy print options.

    Inside the ``with`` block, the current print options are overridden
    by the given keyword arguments. On exit, also in case of an
    exception, the options that were active on entry are restored.

    Parameters
    ----------
    extra_opts :
        Keyword arguments as accepted by `numpy.set_printoptions`.

    See Also
    --------
    numpy.get_printoptions
    numpy.set_printoptions

    Examples
    --------
    >>> orig_precision = np.get_printoptions()['precision']
    >>> with npy_printoptions(precision=3):
    ...     print(np.get_printoptions()['precision'])
    3
    >>> np.get_printoptions()['precision'] == orig_precision
    True
    """
    saved_opts = np.get_printoptions()
    try:
        # Current options with the given ones taking precedence
        np.set_printoptions(**dict(saved_opts, **extra_opts))
        yield
    finally:
        np.set_printoptions(**saved_opts)
def array_str(a, nprint=6):
    """Stringification of an array.

    Parameters
    ----------
    a : `array-like`
        The array to print.
    nprint : int, optional
        Maximum number of elements to print per axis in ``a``. For larger
        arrays, a summary is printed, with ``nprint // 2`` elements on
        each side and ``...`` in the middle (per axis).

    Returns
    -------
    a_str : str
        String representation of ``a`` with ``', '`` as separator between
        entries. Small deviations from round numbers are suppressed.

    Examples
    --------
    Printing 1D arrays:

    >>> print(array_str(np.arange(4)))
    [0, 1, 2, 3]
    >>> print(array_str(np.arange(10)))
    [0, 1, 2, ..., 7, 8, 9]
    >>> print(array_str(np.arange(10), nprint=10))
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    For 2D and higher, the ``nprint`` limitation applies per axis:

    >>> print(array_str(np.arange(24).reshape(4, 6)))
    [[ 0,  1,  2,  3,  4,  5],
     [ 6,  7,  8,  9, 10, 11],
     [12, 13, 14, 15, 16, 17],
     [18, 19, 20, 21, 22, 23]]
    >>> print(array_str(np.arange(32).reshape(4, 8)))
    [[ 0,  1,  2, ...,  5,  6,  7],
     [ 8,  9, 10, ..., 13, 14, 15],
     [16, 17, 18, ..., 21, 22, 23],
     [24, 25, 26, ..., 29, 30, 31]]

    Printing of empty arrays and 0D arrays:

    >>> print(array_str(np.array([])))  # 1D, size=0
    []
    >>> print(array_str(np.array(1.0)))  # 0D, size=1
    1.0
    """
    a = np.asarray(a)
    # Number of entries up to which the array is printed without summary:
    # at most `nprint` entries per axis
    capped_shape = tuple(min(n, nprint) for n in a.shape)
    print_opts = {
        'threshold': int(np.prod(capped_shape)),
        'edgeitems': nprint // 2,
        'suppress': True,
    }
    with npy_printoptions(**print_opts):
        return np.array2string(a, separator=', ')
def dtype_repr(dtype):
    """Stringify ``dtype`` for ``repr`` with default for int and float.

    Parameters
    ----------
    dtype :
        Data type specifier, given in any way the `numpy.dtype`
        constructor understands.

    Returns
    -------
    dtype_repr : str
        ``"'int'"``, ``"'float'"`` or ``"'complex'"`` for the dtypes
        corresponding to the Python builtin types, a string of the form
        ``"('<base>', <shape>)"`` for dtypes with shape, and the quoted
        name of the dtype otherwise.
    """
    dtype = np.dtype(dtype)

    # Dtypes of Python builtin types are represented by the builtin's name
    builtin_names = ((int, "'int'"), (float, "'float'"),
                     (complex, "'complex'"))
    for builtin, name in builtin_names:
        if dtype == np.dtype(builtin):
            return name

    if dtype.shape:
        return "('{}', {})".format(dtype.base, dtype.shape)
    return "'{}'".format(dtype)
def dtype_str(dtype):
    """Stringify ``dtype`` for ``str`` with default for int and float.

    Parameters
    ----------
    dtype :
        Data type specifier, given in any way the `numpy.dtype`
        constructor understands.

    Returns
    -------
    dtype_str : str
        ``'int'``, ``'float'`` or ``'complex'`` for the dtypes
        corresponding to the Python builtin types, a string of the form
        ``"('<base>', <shape>)"`` for dtypes with shape, and the name of
        the dtype otherwise.
    """
    dtype = np.dtype(dtype)

    # Dtypes of Python builtin types are represented by the builtin's name
    builtin_names = ((int, 'int'), (float, 'float'), (complex, 'complex'))
    for builtin, name in builtin_names:
        if dtype == np.dtype(builtin):
            return name

    if dtype.shape:
        return "('{}', {})".format(dtype.base, dtype.shape)
    return '{}'.format(dtype)
def cache_arguments(function):
    """Decorate function to cache the result with given arguments.

    This is equivalent to `functools.lru_cache` with Python 3, and
    currently does nothing with Python 2 but this may change at some later
    point.

    Parameters
    ----------
    function : `callable`
        Function that should be wrapped.

    Returns
    -------
    wrapped : `callable`
        The caching wrapper around ``function`` if `functools.lru_cache`
        can be imported, otherwise ``function`` itself.
    """
    try:
        from functools import lru_cache
        wrapped = lru_cache()(function)
    except ImportError:
        # No `lru_cache` available, fall back to the uncached function
        wrapped = function
    return wrapped
# NOTE: The checks below use `np.issubdtype` on the base (scalar) dtype.
# The formerly used `np.issubsctype` was removed in NumPy 2.0; for dtype
# arguments both perform the same scalar type subclass check.


@cache_arguments
def is_numeric_dtype(dtype):
    """Return ``True`` if ``dtype`` is a numeric type.

    ``dtype`` can be given in any way the `numpy.dtype` constructor
    understands. For dtypes with shape, the base dtype is checked.
    """
    dtype = np.dtype(dtype)
    return np.issubdtype(dtype.base, np.number)


@cache_arguments
def is_int_dtype(dtype):
    """Return ``True`` if ``dtype`` is an integer type.

    ``dtype`` can be given in any way the `numpy.dtype` constructor
    understands. For dtypes with shape, the base dtype is checked.
    """
    dtype = np.dtype(dtype)
    return np.issubdtype(dtype.base, np.integer)


@cache_arguments
def is_floating_dtype(dtype):
    """Return ``True`` if ``dtype`` is a floating point type.

    Both real and complex floating point types are considered.
    """
    return is_real_floating_dtype(dtype) or is_complex_floating_dtype(dtype)


@cache_arguments
def is_real_dtype(dtype):
    """Return ``True`` if ``dtype`` is a real (including integer) type."""
    return is_numeric_dtype(dtype) and not is_complex_floating_dtype(dtype)


@cache_arguments
def is_real_floating_dtype(dtype):
    """Return ``True`` if ``dtype`` is a real floating point type.

    ``dtype`` can be given in any way the `numpy.dtype` constructor
    understands. For dtypes with shape, the base dtype is checked.
    """
    dtype = np.dtype(dtype)
    return np.issubdtype(dtype.base, np.floating)


@cache_arguments
def is_complex_floating_dtype(dtype):
    """Return ``True`` if ``dtype`` is a complex floating point type.

    ``dtype`` can be given in any way the `numpy.dtype` constructor
    understands. For dtypes with shape, the base dtype is checked.
    """
    dtype = np.dtype(dtype)
    return np.issubdtype(dtype.base, np.complexfloating)
def real_dtype(dtype, default=None):
    """Return the real counterpart of ``dtype`` if existing.

    Parameters
    ----------
    dtype :
        Real or complex floating point data type. It can be given in any
        way the `numpy.dtype` constructor understands.
    default :
        Object to be returned if no real counterpart is found for
        ``dtype``, except for ``None``, in which case an error is raised.

    Returns
    -------
    real_dtype : `numpy.dtype`
        The real counterpart of ``dtype``.

    Raises
    ------
    ValueError
        if there is no real counterpart to the given data type and
        ``default == None``.

    See Also
    --------
    complex_dtype

    Examples
    --------
    Convert scalar dtypes:

    >>> real_dtype(complex)
    dtype('float64')
    >>> real_dtype('complex64')
    dtype('float32')
    >>> real_dtype(float)
    dtype('float64')

    Dtypes with shape are also supported:

    >>> real_dtype(np.dtype((complex, (3,))))
    dtype(('<f8', (3,)))
    >>> real_dtype(('complex64', (3,)))
    dtype(('<f4', (3,)))
    """
    # Keep the input around for the error message
    dtype_in = dtype
    dtype = np.dtype(dtype)

    if is_real_floating_dtype(dtype):
        return dtype

    try:
        real_base = TYPE_MAP_C2R[dtype.base]
    except KeyError:
        if default is None:
            raise ValueError('no real counterpart exists for `dtype` {}'
                             ''.format(dtype_repr(dtype_in)))
        return default

    # Re-attach the shape of the input dtype to the real base dtype
    return np.dtype((real_base, dtype.shape))
def complex_dtype(dtype, default=None):
    """Return complex counterpart of ``dtype`` if existing, else ``default``.

    Parameters
    ----------
    dtype :
        Real or complex floating point data type. It can be given in any
        way the `numpy.dtype` constructor understands.
    default :
        Object to be returned if no complex counterpart is found for
        ``dtype``, except for ``None``, in which case an error is raised.

    Returns
    -------
    complex_dtype : `numpy.dtype`
        The complex counterpart of ``dtype``.

    Raises
    ------
    ValueError
        if there is no complex counterpart to the given data type and
        ``default == None``.

    See Also
    --------
    real_dtype

    Examples
    --------
    Convert scalar dtypes:

    >>> complex_dtype(float)
    dtype('complex128')
    >>> complex_dtype('float32')
    dtype('complex64')
    >>> complex_dtype(complex)
    dtype('complex128')

    Dtypes with shape are also supported:

    >>> complex_dtype(np.dtype((float, (3,))))
    dtype(('<c16', (3,)))
    >>> complex_dtype(('float32', (3,)))
    dtype(('<c8', (3,)))
    """
    # Keep the input around for the error message
    dtype_in = dtype
    dtype = np.dtype(dtype)

    if is_complex_floating_dtype(dtype):
        return dtype

    try:
        complex_base = TYPE_MAP_R2C[dtype.base]
    except KeyError:
        if default is None:
            raise ValueError('no complex counterpart exists for `dtype` {}'
                             ''.format(dtype_repr(dtype_in)))
        return default

    # Re-attach the shape of the input dtype to the complex base dtype
    return np.dtype((complex_base, dtype.shape))
def is_string(obj):
    """Return ``True`` if ``obj`` behaves like a string, ``False`` else.

    An object is considered string-like if an empty string can be added
    to it without raising a ``TypeError``.
    """
    try:
        obj + ''
    except TypeError:
        return False
    return True
def nd_iterator(shape):
    """Iterator over n-d cube with shape.

    Parameters
    ----------
    shape : sequence of int
        The number of points per axis

    Returns
    -------
    nd_iterator : generator
        Generator returning tuples of integers of length ``len(shape)``.
        The last index varies fastest.

    Examples
    --------
    >>> for pt in nd_iterator([2, 2]):
    ...     print(pt)
    (0, 0)
    (0, 1)
    (1, 0)
    (1, 1)
    """
    axis_ranges = [range(num_pts) for num_pts in shape]
    return product(*axis_ranges)
def conj_exponent(exp):
    """Conjugate exponent ``exp / (exp - 1)``.

    Parameters
    ----------
    exp : positive float or inf
        Exponent for which to calculate the conjugate. Must be
        at least 1.0.

    Returns
    -------
    conj : positive float or inf
        Conjugate exponent. For ``exp=1``, return ``float('inf')``,
        for ``exp=float('inf')`` return 1. In all other cases, return
        ``exp / (exp - 1)``.
    """
    infinity = float('inf')
    # The two limit cases cannot be computed with the formula
    if exp == 1.0:
        return infinity
    if exp == infinity:
        return 1.0
    return exp / (exp - 1.0)
try:
    # Use the standard library implementation where it exists
    nullcontext = contextlib.nullcontext
except AttributeError:
    @contextmanager
    def nullcontext(enter_result=None):
        """Backport of the Python >=3.7 trivial context manager.

        It does nothing except for returning ``enter_result`` when the
        context is entered.

        See `the Python documentation
        <https://docs.python.org/3/library/contextlib.html#contextlib.nullcontext>`_
        for details.
        """
        yield enter_result
@contextmanager
def writable_array(obj, **kwargs):
    """Context manager that casts obj to a `numpy.array` and saves changes.

    Parameters
    ----------
    obj : `array-like`
        Object that should be made available as writable array.
        It must be valid as input to `numpy.asarray` and needs to
        support the syntax ``obj[:] = arr``.
    kwargs :
        Keyword arguments that should be passed to `numpy.asarray`.

    Examples
    --------
    Convert list to array and use with numpy:

    >>> lst = [1, 2, 3]
    >>> with writable_array(lst) as arr:
    ...     arr *= 2
    >>> lst
    [2, 4, 6]

    Usage with ODL vectors:

    >>> space = odl.uniform_discr(0, 1, 3)
    >>> x = space.element([1, 2, 3])
    >>> with writable_array(x) as arr:
    ...     arr += [1, 1, 1]
    >>> x
    uniform_discr(0.0, 1.0, 3).element([ 2., 3., 4.])

    Additional keyword arguments are passed to `numpy.asarray`:

    >>> lst = [1, 2, 3]
    >>> with writable_array(lst, dtype='complex') as arr:
    ...     print(arr)
    [ 1.+0.j 2.+0.j 3.+0.j]

    Note that the changes are only saved when the context manager
    exits. Before, the input object is unchanged:

    >>> lst = [1, 2, 3]
    >>> with writable_array(lst) as arr:
    ...     arr *= 2
    ...     print(lst)
    [1, 2, 3]
    >>> print(lst)
    [2, 4, 6]
    """
    # If the conversion fails, there is nothing to write back and the
    # error simply propagates
    array = np.asarray(obj, **kwargs)
    try:
        yield array
    finally:
        # Write back in any case, also if the `with` body raised
        obj[:] = array
def signature_string(posargs, optargs, sep=', ', mod='!r'):
    """Return a stringified signature from given arguments.

    Parameters
    ----------
    posargs : sequence
        Positional argument values, always included in the returned string.
        They appear in the string as (roughly)::

            sep.join(str(arg) for arg in posargs)

    optargs : sequence of 3-tuples
        Optional arguments with names and defaults, given in the form::

            [(name1, value1, default1), (name2, value2, default2), ...]

        Only those parameters that are different from the given default
        are included as ``name=value`` keyword pairs.

        **Note:** The comparison is done by using ``if value == default:``,
        which is not valid for, e.g., NumPy arrays.

    sep : string or sequence of strings, optional
        Separator(s) for the argument strings. A provided single string is
        used for all joining operations.
        A given sequence must have 3 entries ``pos_sep, opt_sep, part_sep``.
        The ``pos_sep`` and ``opt_sep`` strings are used for joining the
        respective sequences of argument strings, and ``part_sep`` joins
        these two joined strings.
    mod : string or callable or sequence, optional
        Format modifier(s) for the argument strings.
        In its most general form, ``mod`` is a sequence of 2 sequences
        ``pos_mod, opt_mod`` with ``len(pos_mod) == len(posargs)`` and
        ``len(opt_mod) == len(optargs)``. Each entry ``m`` in those
        sequences can be either a string, resulting in the following
        stringification of ``arg``::

            arg_fmt = '{{{}}}'.format(m)
            arg_str = arg_fmt.format(arg)

        For a callable ``to_str``, the stringification is simply
        ``arg_str = to_str(arg)``.

        The entries ``pos_mod, opt_mod`` of ``mod`` can also be strings
        or callables instead of sequences, in which case the modifier
        applies to all corresponding arguments.

        Finally, if ``mod`` is a string or callable, it is applied to
        all arguments.

        The default behavior is to apply the "{!r}" (``repr``) conversion.
        For floating point scalars, the number of digits printed is
        determined by the ``precision`` value in NumPy's printing options,
        which can be temporarily modified with `npy_printoptions`.

    Returns
    -------
    signature : string
        Stringification of a signature, typically used in the form::

            '{}({})'.format(self.__class__.__name__, signature)

    Examples
    --------
    Usage with non-trivial entries in both sequences, with a typical
    use case:

    >>> posargs = [1, 'hello', None]
    >>> optargs = [('dtype', 'float32', 'float64')]
    >>> signature_string(posargs, optargs)
    "1, 'hello', None, dtype='float32'"
    >>> '{}({})'.format('MyClass', signature_string(posargs, optargs))
    "MyClass(1, 'hello', None, dtype='float32')"

    Empty sequences and optargs values equal to default are omitted:

    >>> posargs = ['hello']
    >>> optargs = [('size', 1, 1)]
    >>> signature_string(posargs, optargs)
    "'hello'"
    >>> posargs = []
    >>> optargs = [('size', 2, 1)]
    >>> signature_string(posargs, optargs)
    'size=2'
    >>> posargs = []
    >>> optargs = [('size', 1, 1)]
    >>> signature_string(posargs, optargs)
    ''

    Using a different separator, globally or per argument "category":

    >>> posargs = [1, 'hello', None]
    >>> optargs = [('dtype', 'float32', 'float64'),
    ...            ('order', 'F', 'C')]
    >>> signature_string(posargs, optargs)
    "1, 'hello', None, dtype='float32', order='F'"
    >>> signature_string(posargs, optargs, sep=(',', ',', ', '))
    "1,'hello',None, dtype='float32',order='F'"

    Using format modifiers:

    >>> posargs = ['hello', 2.345]
    >>> optargs = [('extent', 1.442, 1.0), ('spacing', 0.0151, 1.0)]
    >>> signature_string(posargs, optargs)
    "'hello', 2.345, extent=1.442, spacing=0.0151"
    >>> # Print only two significant digits for all arguments.
    >>> # NOTE: this also affects the string!
    >>> mod = ':.2'
    >>> signature_string(posargs, optargs, mod=mod)
    'he, 2.3, extent=1.4, spacing=0.015'
    >>> mod = [['', ''], [':.3', ':.2']]  # one modifier per argument
    >>> signature_string(posargs, optargs, mod=mod)
    "'hello', 2.345, extent=1.44, spacing=0.015"

    Using callables for stringification:

    >>> posargs = ['arg1', np.arange(3)]
    >>> optargs = []
    >>> signature_string(posargs, optargs, mod=[['', array_str], []])
    "'arg1', [0, 1, 2]"

    The number of printed digits in floating point numbers can be changed
    with `npy_printoptions`:

    >>> posargs = ['hello', 0.123456789012345]
    >>> optargs = [('extent', 1.234567890123456, 1.0)]
    >>> signature_string(posargs, optargs)  # default is 8 digits
    "'hello', 0.12345679, extent=1.2345679"
    >>> with npy_printoptions(precision=2):
    ...     sig_str = signature_string(posargs, optargs)
    >>> sig_str
    "'hello', 0.12, extent=1.2"
    """
    # A single string serves as separator for all joining operations
    if is_string(sep):
        pos_sep, opt_sep, part_sep = sep, sep, sep
    else:
        pos_sep, opt_sep, part_sep = sep

    pos_strings, opt_strings = signature_string_parts(posargs, optargs, mod)

    # Join each non-empty group with its own separator, then join the
    # groups; empty groups are left out to avoid dangling separators
    groups = ((pos_sep, pos_strings), (opt_sep, opt_strings))
    joined_groups = [group_sep.join(strings)
                     for group_sep, strings in groups
                     if strings]
    return part_sep.join(joined_groups)
def _expand_modifier(modifier, args):
    """Return a sequence with one format modifier per entry of ``args``."""
    if is_string(modifier) or callable(modifier):
        return [modifier] * len(args)
    if len(modifier) == 1:
        return modifier * len(args)
    if len(modifier) == len(args):
        return modifier
    raise ValueError('sequence length mismatch: '
                     'len({}) != len({})'.format(modifier, args))


def _stringify_value(value, modifier, precision):
    """Return the stringification of one argument value."""
    if callable(modifier):
        return modifier(value)

    if is_string(value):
        # Preserve single quotes for strings by default
        fmt = '{{{}}}'.format(modifier) if modifier else "'{}'"
        return fmt.format(value)

    if np.isscalar(value):
        if str(value) in ('inf', 'nan'):
            # Make sure the string quotes are added
            return "'{}'".format(value)
        if (np.array(value).real.astype('int64') != value
                and modifier in ('', '!s', '!r')):
            # Floating point value, use numpy print option 'precision'
            return '{{:.{}}}'.format(precision).format(value)

    # All other values are passed through a format conversion
    return '{{{}}}'.format(modifier).format(value)


def signature_string_parts(posargs, optargs, mod='!r'):
    """Return stringified arguments as tuples.

    Parameters
    ----------
    posargs : sequence
        Positional argument values, always included in the returned string
        tuple.
    optargs : sequence of 3-tuples
        Optional arguments with names and defaults, given in the form::

            [(name1, value1, default1), (name2, value2, default2), ...]

        Only those parameters that are different from the given default
        are included as ``name=value`` keyword pairs.

        **Note:** The comparison is done by using ``if value == default:``,
        which is not valid for, e.g., NumPy arrays.

    mod : string or callable or sequence, optional
        Format modifier(s) for the argument strings.
        In its most general form, ``mod`` is a sequence of 2 sequences
        ``pos_mod, opt_mod`` with ``len(pos_mod) == len(posargs)`` and
        ``len(opt_mod) == len(optargs)``. Each entry ``m`` in those
        sequences can be a string, resulting in the following
        stringification of ``arg``::

            arg_fmt = '{{{}}}'.format(m)
            arg_str = arg_fmt.format(arg)

        For a callable ``to_str``, the stringification is simply
        ``arg_str = to_str(arg)``.

        The entries ``pos_mod, opt_mod`` of ``mod`` can also be strings
        or callables instead of sequences, in which case the modifier
        applies to all corresponding arguments.

        Finally, if ``mod`` is a string or callable, it is applied to
        all arguments.

        The default behavior is to apply the "{!r}" (``repr``) conversion.
        For floating point scalars, the number of digits printed is
        determined by the ``precision`` value in NumPy's printing options,
        which can be temporarily modified with `npy_printoptions`.

    Returns
    -------
    pos_strings : tuple of str
        The stringified positional arguments.
    opt_strings : tuple of str
        The stringified optional arguments, not including the ones
        equal to their respective defaults.

    Raises
    ------
    ValueError
        If a modifier sequence has neither length 1 nor the same length
        as the corresponding argument sequence.
    """
    # Bring the modifiers into the form of one modifier per argument
    if is_string(mod) or callable(mod):
        pos_mod = opt_mod = mod
    else:
        pos_mod, opt_mod = mod
    pos_mod = _expand_modifier(pos_mod, posargs)
    opt_mod = _expand_modifier(opt_mod, optargs)

    precision = np.get_printoptions()['precision']

    pos_strings = tuple(
        _stringify_value(arg, modifier, precision)
        for arg, modifier in zip(posargs, pos_mod)
    )

    # Build 'key=value' strings for values that are not equal to default
    opt_strings = tuple(
        '{}={}'.format(name, _stringify_value(value, modifier, precision))
        for (name, value, default), modifier in zip(optargs, opt_mod)
        if not value == default
    )
    return pos_strings, opt_strings
def _separators(strings, linewidth): """Return separators that keep joined strings within the line width.""" if len(strings) <= 1: return () indent_len = 4 separators = [] cur_line_len = indent_len + len(strings[0]) + 1 if cur_line_len + 2 <= linewidth and '\n' not in strings[0]: # Next string might fit on same line separators.append(', ') cur_line_len += 1 # for the extra space else: # Use linebreak if string contains newline or doesn't fit separators.append(',\n') cur_line_len = indent_len for i, s in enumerate(strings[1:-1]): cur_line_len += len(s) + 1 if '\n' in s: # Use linebreak before and after if string contains newline separators[i] = ',\n' cur_line_len = indent_len separators.append(',\n') elif cur_line_len + 2 <= linewidth: # This string fits, next one might also fit on same line separators.append(', ') cur_line_len += 1 # for the extra space elif cur_line_len <= linewidth: # This string fits, but next one won't separators.append(',\n') cur_line_len = indent_len else: # This string doesn't fit but has no newlines in it separators[i] = ',\n' cur_line_len = indent_len + len(s) + 1 # Need to determine again what should come next if cur_line_len + 2 <= linewidth: # Next string might fit on same line separators.append(', ') else: separators.append(',\n') cur_line_len += len(strings[-1]) if cur_line_len + 1 > linewidth or '\n' in strings[-1]: # This string and a comma don't fit on this line separators[-1] = ',\n' return tuple(separators)
def repr_string(outer_string, inner_strings, allow_mixed_seps=True):
    r"""Return a pretty string for ``repr``.

    The returned string is formatted such that it does not extend
    beyond the line boundary if avoidable. The line width is taken from
    NumPy's printing options that can be retrieved with
    `numpy.get_printoptions`. They can be temporarily overridden
    using the `npy_printoptions` context manager. See Examples for details.

    Parameters
    ----------
    outer_string : str
        Name of the class or function that should be printed outside
        the parentheses.
    inner_strings : sequence of sequence of str
        Stringifications of the positional and optional arguments, in
        this order (exactly 2 entries). This is usually the return
        value of `signature_string_parts`. The two parts may be of
        different types, e.g., a list and a tuple.
    allow_mixed_seps : bool, optional
        If ``False`` and the string does not fit on one line, use
        ``',\n'`` to separate all strings.
        By default, a mixture of ``', '`` and ``',\n'`` is used to fit
        as much on one line as possible.

        In case some of the ``inner_strings`` span multiple lines, it is
        usually advisable to set ``allow_mixed_seps`` to ``False`` since
        the result tends to be more readable that way.

    Returns
    -------
    repr_string : str
        Full string that can be returned by a class' ``__repr__`` method.

    Examples
    --------
    Things that fit into one line are printed on one line:

    >>> outer_string = 'MyClass'
    >>> inner_strings = [('1', "'hello'", 'None'),
    ...                  ("dtype='float32'",)]
    >>> print(repr_string(outer_string, inner_strings))
    MyClass(1, 'hello', None, dtype='float32')

    Otherwise, if a part of ``inner_strings`` fits on a line of its own,
    it is printed on one line, but separated from the other part with
    a line break:

    >>> outer_string = 'MyClass'
    >>> inner_strings = [('2.0', "'this_is_a_very_long_argument_string'"),
    ...                  ("long_opt_arg='another_quite_long_string'",)]
    >>> print(repr_string(outer_string, inner_strings))
    MyClass(
        2.0, 'this_is_a_very_long_argument_string',
        long_opt_arg='another_quite_long_string'
    )

    If those parts are themselves too long, they are broken down into
    several lines:

    >>> outer_string = 'MyClass'
    >>> inner_strings = [("'this_is_a_very_long_argument_string'",
    ...                   "'another_very_long_argument_string'"),
    ...                  ("long_opt_arg='another_quite_long_string'",
    ...                   "long_opt2_arg='this_wont_fit_on_one_line_either'")]
    >>> print(repr_string(outer_string, inner_strings))
    MyClass(
        'this_is_a_very_long_argument_string',
        'another_very_long_argument_string',
        long_opt_arg='another_quite_long_string',
        long_opt2_arg='this_wont_fit_on_one_line_either'
    )

    The usage of mixed separators to optimally use horizontal space can
    be disabled by setting ``allow_mixed_seps=False``:

    >>> outer_string = 'MyClass'
    >>> inner_strings = [('2.0', "'this_is_a_very_long_argument_string'"),
    ...                  ("long_opt_arg='another_quite_long_string'",)]
    >>> print(repr_string(outer_string, inner_strings, allow_mixed_seps=False))
    MyClass(
        2.0,
        'this_is_a_very_long_argument_string',
        long_opt_arg='another_quite_long_string'
    )

    With the ``npy_printoptions`` context manager, the available line
    width can be changed:

    >>> outer_string = 'MyClass'
    >>> inner_strings = [('1', "'hello'", 'None'),
    ...                  ("dtype='float32'",)]
    >>> with npy_printoptions(linewidth=20):
    ...     print(repr_string(outer_string, inner_strings))
    MyClass(
        1, 'hello',
        None,
        dtype='float32'
    )
    """
    linewidth = np.get_printoptions()['linewidth']

    # Materialize both parts as tuples so that they can be measured,
    # sliced and concatenated regardless of the container types passed in
    # (e.g. a list for one part and a tuple for the other)
    pos_strings, opt_strings = (tuple(part) for part in inner_strings)

    def oneline_len(strings):
        """Length of ``strings`` joined with ``', '``."""
        return (sum(len(s) for s in strings) +
                2 * max(len(strings) - 1, 0))

    def join_part(strings):
        """Join ``strings`` using line-width-aware separators."""
        if not strings:
            return ''
        if allow_mixed_seps:
            seps = _separators(strings, linewidth)
        else:
            seps = [',\n'] * (len(strings) - 1)

        pieces = [strings[0]]
        for sep, s in zip(seps, strings[1:]):
            pieces.extend((sep, s))
        return ''.join(pieces)

    # Length of the one-line string, including 2 for the parentheses and
    # 2 for the joining ', '
    repr_len = (len(outer_string) + 2 + oneline_len(pos_strings) + 2 +
                oneline_len(opt_strings))
    has_newline = any('\n' in s for s in pos_strings + opt_strings)

    if repr_len <= linewidth and not has_newline:
        # Everything fits on one line
        fmt = '{}({})'
        pos_str = ', '.join(pos_strings)
        opt_str = ', '.join(opt_strings)
        parts_sep = ', '
    else:
        # Need to split lines in some way
        fmt = '{}(\n{}\n)'
        pos_str = join_part(pos_strings)
        opt_str = join_part(opt_strings)

        # Check if we can put both parts on one line. This requires their
        # concatenation including 4 for indentation and 2 for ', ' to
        # be less than the line width. And they should contain no newline.
        if pos_str and opt_str:
            inner_len = 4 + len(pos_str) + 2 + len(opt_str)
        elif pos_str or opt_str:
            # Only one part present, no joining ', ' needed
            inner_len = 4 + len(pos_str) + len(opt_str)
        else:
            inner_len = 0

        if (not allow_mixed_seps or
                '\n' in pos_str or
                '\n' in opt_str or
                inner_len > linewidth):
            parts_sep = ',\n'
            pos_str = indent(pos_str)
            opt_str = indent(opt_str)
        else:
            parts_sep = ', '
            pos_str = indent(pos_str)
            # `opt_str` follows on the same line, hence no indentation

    # Ignore empty parts when joining
    parts = [s for s in [pos_str, opt_str] if s.strip()]
    inner_string = parts_sep.join(parts)
    return fmt.format(outer_string, inner_string)
def attribute_repr_string(inst_str, attr_str):
    """Return a repr string for an attribute that respects line width.

    If ``inst_str + '.' + attr_str`` is too long for one line (according
    to the ``linewidth`` NumPy print option) and ``inst_str`` has the
    form ``<left>(<middle>)<right>`` with non-empty ``<middle>``, the
    part between the outermost parentheses is moved to an indented line
    of its own. In all other cases, ``inst_str`` is used as-is.

    Parameters
    ----------
    inst_str : str
        Stringification of a class instance.
    attr_str : str
        Name of the attribute (not including the ``'.'``).

    Returns
    -------
    attr_repr_str : str
        Concatenation of the two strings in a way that the line width
        is respected.

    Examples
    --------
    >>> inst_str = 'rn((2, 3))'
    >>> attr_str = 'byaxis'
    >>> print(attribute_repr_string(inst_str, attr_str))
    rn((2, 3)).byaxis
    >>> inst_str = 'MyClass()'
    >>> attr_str = 'attr_name'
    >>> print(attribute_repr_string(inst_str, attr_str))
    MyClass().attr_name
    >>> inst_str = 'MyClass'
    >>> attr_str = 'class_attr'
    >>> print(attribute_repr_string(inst_str, attr_str))
    MyClass.class_attr
    >>> long_inst_str = (
    ...     "MyClass('long string that will definitely trigger a line break')"
    ... )
    >>> long_attr_str = 'long_attribute_name'
    >>> print(attribute_repr_string(long_inst_str, long_attr_str))
    MyClass(
        'long string that will definitely trigger a line break'
    ).long_attribute_name
    """
    linewidth = np.get_printoptions()['linewidth']
    fits_on_line = len(inst_str) + 1 + len(attr_str) <= linewidth

    if not fits_on_line and '(' in inst_str:
        # Split at the first '(' and the last ')' to get the content of
        # the outermost parentheses
        left, _, rest = inst_str.partition('(')
        middle, closing, right = rest.rpartition(')')
        already_multiline = (middle.startswith('\n') and
                             middle.endswith('\n'))

        # Only rewrap if there is something to wrap: with empty or
        # unbalanced parentheses, or content that already sits on its own
        # lines, the instance string is kept unchanged
        if closing and middle and not already_multiline:
            inst_str = left + '(\n' + indent(middle) + '\n)' + right

    return inst_str + '.' + attr_str
def method_repr_string(inst_str, meth_str, arg_strs=None,
                       allow_mixed_seps=True):
    r"""Return a repr string for a method that respects line width.

    This function is useful to generate a ``repr`` string for a derived
    class that is created through a method, for instance ::

        functional.translated(x)

    as a better way of representing ::

        FunctionalTranslation(functional, x)

    Parameters
    ----------
    inst_str : str
        Stringification of a class instance.
    meth_str : str
        Name of the method (not including the ``'.'``).
    arg_strs : sequence of str, optional
        Stringification of the arguments to the method. ``None`` (the
        default) is equivalent to an empty sequence, i.e., a call
        without arguments.
    allow_mixed_seps : bool, optional
        If ``False`` and the argument strings do not fit on one line, use
        ``',\n'`` to separate all strings.
        By default, a mixture of ``', '`` and ``',\n'`` is used to fit
        as much on one line as possible.

        In case some of the ``arg_strs`` span multiple lines, it is
        usually advisable to set ``allow_mixed_seps`` to ``False`` since
        the result tends to be more readable that way.

    Returns
    -------
    meth_repr_str : str
        Concatenation of all strings in a way that the line width
        is respected.

    Examples
    --------
    >>> inst_str = 'MyClass'
    >>> meth_str = 'empty'
    >>> print(method_repr_string(inst_str, meth_str))
    MyClass.empty()
    >>> inst_str = 'MyClass'
    >>> meth_str = 'fromfile'
    >>> arg_strs = ["'tmpfile.txt'"]
    >>> print(method_repr_string(inst_str, meth_str, arg_strs))
    MyClass.fromfile('tmpfile.txt')
    >>> inst_str = "MyClass('init string')"
    >>> meth_str = 'method'
    >>> arg_strs = ['2.0']
    >>> print(method_repr_string(inst_str, meth_str, arg_strs))
    MyClass('init string').method(2.0)

    If the instance string together with ``'.<method>('`` does not fit on
    one line, the content of its parentheses is moved to a separate line:

    >>> long_inst_str = (
    ...     "MyClass('long string that will definitely trigger a line break')"
    ... )
    >>> print(method_repr_string(long_inst_str, 'long_method_name', ['2.0']))
    MyClass(
        'long string that will definitely trigger a line break'
    ).long_method_name(2.0)

    Arguments that do not fit on the line of the method name are put
    on indented lines of their own:

    >>> meth_str = 'method'
    >>> long_arg1 = "'long argument string that should come on the next line'"
    >>> arg2 = 'param1=1'
    >>> arg3 = 'param2=2.0'
    >>> arg_strs = [long_arg1, arg2, arg3]
    >>> print(method_repr_string(long_inst_str, meth_str, arg_strs))
    MyClass('long string that will definitely trigger a line break').method(
        'long argument string that should come on the next line', param1=1,
        param2=2.0
    )
    >>> print(method_repr_string(long_inst_str, meth_str, arg_strs,
    ...                          allow_mixed_seps=False))
    MyClass('long string that will definitely trigger a line break').method(
        'long argument string that should come on the next line',
        param1=1,
        param2=2.0
    )
    """
    linewidth = np.get_printoptions()['linewidth']

    # `None` means "no arguments"; a list is needed since the strings are
    # both joined and handed on to the separator computation
    arg_strs = [] if arg_strs is None else list(arg_strs)

    # Part up to the method name. By default the instance string is kept
    # as-is, and the line holding the method name has this length:
    meth_line_start_len = len(inst_str) + 1 + len(meth_str)

    if meth_line_start_len + 1 > linewidth and '(' in inst_str:
        # Too long including the opening '(' of the call. Split at the
        # first '(' and the last ')' of the instance string.
        left, _, rest = inst_str.partition('(')
        middle, closing, right = rest.rpartition(')')
        already_multiline = (middle.startswith('\n') and
                             middle.endswith('\n'))

        # With empty or unbalanced parentheses there is nothing to wrap,
        # and the instance string stays on one line
        if closing and middle:
            if not already_multiline:
                inst_str = left + '(\n' + indent(middle) + '\n)' + right

            # Length of the line to the end of the method name, consisting
            # of ')' + '.' + <method name>
            meth_line_start_len = 1 + 1 + len(meth_str)

    # Method call part
    arg_str_oneline = ', '.join(arg_strs)
    if meth_line_start_len + 1 + len(arg_str_oneline) + 1 <= linewidth:
        # Parentheses and all arguments fit on the method name line
        meth_call_str = '(' + arg_str_oneline + ')'
    elif not arg_str_oneline:
        meth_call_str = '(\n)'
    else:
        if allow_mixed_seps:
            # Arguments will be indented by 4, hence less width available
            arg_seps = _separators(arg_strs, linewidth - 4)
        else:
            arg_seps = [',\n'] * (len(arg_strs) - 1)

        # There is one separator less than arguments, so the last
        # argument is paired with an empty fill value
        full_arg_str = ''.join(
            arg_str + sep
            for arg_str, sep in zip_longest(arg_strs, arg_seps, fillvalue='')
        )
        meth_call_str = '(\n' + indent(full_arg_str) + '\n)'

    return inst_str + '.' + meth_str + meth_call_str
def run_from_ipython():
    """Return ``True`` if the process is run from IPython.

    Returns
    -------
    in_ipython : bool
        ``True`` if the name ``__IPYTHON__`` is defined, ``False``
        otherwise.
    """
    # IPython makes `__IPYTHON__` available as a builtin, so it does not
    # show up in this module's `globals()`. A plain name lookup searches
    # the module globals first and the builtins after that.
    try:
        __IPYTHON__  # noqa: F821
    except NameError:
        return False
    return True
def pkg_supports(feature, pkg_version, pkg_feat_dict):
    """Return bool indicating whether a package supports ``feature``.

    Parameters
    ----------
    feature : str
        Name of a potential feature of a package.
    pkg_version : str
        Version of the package that should be checked for presence of the
        feature.
    pkg_feat_dict : dict
        Specification of features of a package. Each item has the
        following form::

            feature_name: version_specification

        Here, ``feature_name`` is a string that is matched against
        ``feature``, and ``version_specification`` is a string or a
        sequence of strings that specifies version sets. These
        specifications are the same as for ``setuptools`` requirements,
        just without the package name.
        A ``None`` entry signals "no support in any version", i.e.,
        always ``False``.
        If a sequence of requirements are given, they are OR-ed together.
        See ``Examples`` for details.

    Returns
    -------
    supports : bool
        ``True`` if ``pkg_version`` of the package in question supports
        ``feature``, ``False`` otherwise. Features that are not in
        ``pkg_feat_dict`` are reported as unsupported.

    Examples
    --------
    >>> feat_dict = {
    ...     'feat1': '==0.5.1',
    ...     'feat2': '>0.6, <=0.9',  # both required simultaneously
    ...     'feat3': ['>0.6', '<=0.9'],  # only one required, i.e. always True
    ...     'feat4': ['==0.5.1', '>0.6, <=0.9'],
    ...     'feat5': None
    ... }
    >>> pkg_supports('feat1', '0.5.1', feat_dict)
    True
    >>> pkg_supports('feat1', '0.4', feat_dict)
    False
    >>> pkg_supports('feat2', '0.5.1', feat_dict)
    False
    >>> pkg_supports('feat2', '0.6.1', feat_dict)
    True
    >>> pkg_supports('feat2', '0.9', feat_dict)
    True
    >>> pkg_supports('feat2', '1.0', feat_dict)
    False
    >>> pkg_supports('feat3', '0.4', feat_dict)
    True
    >>> pkg_supports('feat3', '1.0', feat_dict)
    True
    >>> pkg_supports('feat4', '0.5.1', feat_dict)
    True
    >>> pkg_supports('feat4', '0.6', feat_dict)
    False
    >>> pkg_supports('feat4', '0.6.1', feat_dict)
    True
    >>> pkg_supports('feat4', '1.0', feat_dict)
    False
    >>> pkg_supports('feat5', '0.6.1', feat_dict)
    False
    >>> pkg_supports('feat5', '1.0', feat_dict)
    False
    """
    from pkg_resources import parse_requirements

    feature_name = str(feature)
    version = str(pkg_version)

    spec = pkg_feat_dict.get(feature_name)
    if spec is None:
        # Unknown feature or explicitly unsupported
        return False

    # A single specification string is treated as a 1-element sequence
    specs = [spec] if is_string(spec) else spec

    # Prepend a dummy package name to get valid requirement strings. Since
    # each string names only one package, parsing yields exactly one
    # requirement per specification. All of them are parsed up front.
    requirements = [list(parse_requirements('pkg' + s))[0] for s in specs]

    # The specifications are alternatives, one match is enough
    return any(req.specifier.contains(version, prereleases=True)
               for req in requirements)
@contextmanager
def npy_random_seed(seed):
    """Context manager to temporarily set the NumPy random generator seed.

    On entering, the state of the global NumPy random generator is
    stored and the generator is seeded with ``seed``. On exit, the stored
    state is put back, also if an exception was raised in the meantime.

    Parameters
    ----------
    seed : int or None
        Seed value for the random number generator.
        ``None`` is interpreted as keeping the current seed, i.e., the
        generator state is neither changed nor restored.

    Examples
    --------
    Use this to make drawing pseudo-random numbers repeatable:

    >>> with npy_random_seed(42):
    ...     rand_int = np.random.randint(10)
    >>> with npy_random_seed(42):
    ...     same_rand_int = np.random.randint(10)
    >>> rand_int == same_rand_int
    True
    """
    if seed is None:
        # Nothing to set and nothing to restore
        yield
        return

    saved_state = np.random.get_state()
    try:
        np.random.seed(seed)
        yield
    finally:
        np.random.set_state(saved_state)
def unique(seq):
    """Return the unique values in a sequence.

    Parameters
    ----------
    seq : sequence
        Sequence with (possibly duplicate) elements. One-shot iterables
        like generators are also accepted.

    Returns
    -------
    unique : list
        Unique elements of ``seq``.
        Order is guaranteed to be the same as in seq.

    Examples
    --------
    Determine unique elements in list

    >>> unique([1, 2, 3, 3])
    [1, 2, 3]

    >>> unique((1, 'str', 'str'))
    [1, 'str']

    The utility also works with unhashable types:

    >>> unique((1, [1], [1]))
    [1, [1]]
    """
    # Store the elements first: the hash-based attempt below may fail
    # half-way through, and the fallback must still see every element,
    # which would not be the case for a partially consumed iterator
    items = list(seq)

    # First check if all elements are hashable, if so O(n) can be done
    try:
        return list(OrderedDict.fromkeys(items))
    except TypeError:
        # Non-hashable, resort to O(n^2)
        unique_values = []
        for item in items:
            if item not in unique_values:
                unique_values.append(item)
        return unique_values
if __name__ == '__main__':
    # Only when executed as a script: call ODL's `run_doctests` test
    # helper (presumably runs the doctests of this module)
    from odl.util.testutils import run_doctests
    run_doctests()