Source code for odl.util.graphics

# Copyright 2014-2017 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Functions for graphical output."""

from __future__ import print_function, division, absolute_import
import numpy as np
import warnings

from odl.util.testutils import run_doctests
from odl.util.utility import is_real_dtype


__all__ = ('show_discrete_data',)


[docs]def warning_free_pause(): """Issue a matplotlib pause without the warning.""" import matplotlib.pyplot as plt with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="Using default event loop until " "function specific to this GUI is " "implemented") plt.pause(0.0001)
def _safe_minmax(values): """Calculate min and max of array with guards for nan and inf.""" # Nan and inf guarded min and max isfinite = np.isfinite(values) if np.any(isfinite): # Only use finite values values = values[isfinite] minval = np.min(values) maxval = np.max(values) return minval, maxval def _colorbar_ticks(minval, maxval): """Return the ticks (values show) in the colorbar.""" if not (np.isfinite(minval) and np.isfinite(maxval)): return [0, 0, 0] elif minval == maxval: return [minval] else: # Add eps to ensure values stay inside the range of the colorbar. # Otherwise they may occationally not display. eps = (maxval - minval) / 1e5 return [minval + eps, (maxval + minval) / 2., maxval - eps] def _digits(minval, maxval): """Digits needed to comforatbly display values in [minval, maxval]""" if minval == maxval: return 3 else: return min(10, max(2, int(1 + abs(np.log10(maxval - minval))))) def _colorbar_format(minval, maxval): """Return the format string for the colorbar.""" if not (np.isfinite(minval) and np.isfinite(maxval)): return str(maxval) else: return '%.{}f'.format(_digits(minval, maxval)) def _axes_info(grid, npoints=5): result = [] min_pt = grid.min() max_pt = grid.max() for axis in range(grid.ndim): xmin = min_pt[axis] xmax = max_pt[axis] points = np.linspace(xmin, xmax, npoints) indices = np.linspace(0, grid.shape[axis] - 1, npoints, dtype=int) tick_values = grid.coord_vectors[axis][indices] # Do not use corner point in case of a partition, use outer corner tick_values[[0, -1]] = xmin, xmax format_str = '{:.' + str(_digits(xmin, xmax)) + 'f}' tick_labels = [format_str.format(f) for f in tick_values] result += [(points, tick_labels)] return result
[docs]def show_discrete_data(values, grid, title=None, method='', force_show=False, fig=None, **kwargs): """Display a discrete 1d or 2d function. Parameters ---------- values : `numpy.ndarray` The values to visualize. grid : `RectGrid` or `RectPartition` Grid of the values. title : string, optional Set the title of the figure. method : string, optional 1d methods: 'plot' : graph plot 'scatter' : scattered 2d points (2nd axis <-> value) 2d methods: 'imshow' : image plot with coloring according to value, including a colorbar. 'scatter' : cloud of scattered 3d points (3rd axis <-> value) 'wireframe', 'plot_wireframe' : surface plot force_show : bool, optional Whether the plot should be forced to be shown now or deferred until later. Note that some backends always displays the plot, regardless of this value. fig : `matplotlib.figure.Figure`, optional The figure to show in. Expected to be of same "style", as the figure given by this function. The most common usecase is that fig is the return value from an earlier call to this function. Default: New figure interp : {'nearest', 'linear'}, optional Interpolation method to use. Default: 'nearest' axis_labels : string, optional Axis labels, default: ['x', 'y'] update_in_place : bool, optional Update the content of the figure in-place. Intended for faster real time plotting, typically ~5 times faster. This is only performed for ``method == 'imshow'`` with real data and ``fig != None``. Otherwise this parameter is treated as False. Default: False axis_fontsize : int, optional Fontsize for the axes. Default: 16 colorbar : bool, optional Argument relevant for 2d plots using ``method='imshow'``. If ``True``, include a colorbar in the plot. Default: True kwargs : {'figsize', 'saveto', ...}, optional Extra keyword arguments passed on to display method See the Matplotlib functions for documentation of extra options. Returns ------- fig : `matplotlib.figure.Figure` The resulting figure. It is also shown to the user. See Also -------- matplotlib.pyplot.plot : Show graph plot matplotlib.pyplot.imshow : Show data as image matplotlib.pyplot.scatter : Show scattered 3d points """ # Importing pyplot takes ~2 sec, only import when needed. import matplotlib.pyplot as plt args_re = [] args_im = [] dsp_kwargs = {} sub_kwargs = {} arrange_subplots = (121, 122) # horzontal arrangement # Create axis labels which remember their original meaning axis_labels = kwargs.pop('axis_labels', ['x', 'y']) values_are_complex = not is_real_dtype(values.dtype) figsize = kwargs.pop('figsize', None) saveto = kwargs.pop('saveto', None) interp = kwargs.pop('interp', 'nearest') axis_fontsize = kwargs.pop('axis_fontsize', 16) colorbar = kwargs.pop('colorbar', True) # Normalize input interp, interp_in = str(interp).lower(), interp method, method_in = str(method).lower(), method # Check if we should and can update the plot in-place update_in_place = kwargs.pop('update_in_place', False) if (update_in_place and (fig is None or values_are_complex or values.ndim != 2 or (values.ndim == 2 and method not in ('', 'imshow')))): update_in_place = False if values.ndim == 1: # TODO: maybe a plotter class would be better if not method: if interp == 'nearest': method = 'step' dsp_kwargs['where'] = 'mid' elif interp == 'linear': method = 'plot' else: raise ValueError('`interp` {!r} not supported' ''.format(interp_in)) if method == 'plot' or method == 'step' or method == 'scatter': args_re += [grid.coord_vectors[0], values.real] args_im += [grid.coord_vectors[0], values.imag] else: raise ValueError('`method` {!r} not supported' ''.format(method_in)) elif values.ndim == 2: if not method: method = 'imshow' if method == 'imshow': args_re = [np.rot90(values.real)] args_im = [np.rot90(values.imag)] if values_are_complex else [] extent = [grid.min()[0], grid.max()[0], grid.min()[1], grid.max()[1]] if interp == 'nearest': interpolation = 'nearest' elif interp == 'linear': interpolation = 'bilinear' else: raise ValueError('`interp` {!r} not supported' ''.format(interp_in)) dsp_kwargs.update({'interpolation': interpolation, 'cmap': 'bone', 'extent': extent, 'aspect': 'auto'}) elif method == 'scatter': pts = grid.points() args_re = [pts[:, 0], pts[:, 1], values.ravel().real] args_im = ([pts[:, 0], pts[:, 1], values.ravel().imag] if values_are_complex else []) sub_kwargs.update({'projection': '3d'}) elif method in ('wireframe', 'plot_wireframe'): method = 'plot_wireframe' x, y = grid.meshgrid args_re = [x, y, np.rot90(values.real)] args_im = ([x, y, np.rot90(values.imag)] if values_are_complex else []) sub_kwargs.update({'projection': '3d'}) else: raise ValueError('`method` {!r} not supported' ''.format(method_in)) else: raise NotImplementedError('no method for {}d display implemented' ''.format(values.ndim)) # Additional keyword args are passed on to the display method dsp_kwargs.update(**kwargs) if fig is not None: # Reuse figure if given as input if not isinstance(fig, plt.Figure): raise TypeError('`fig` {} not a matplotlib figure'.format(fig)) if not plt.fignum_exists(fig.number): # If figure does not exist, user either closed the figure or # is using IPython, in this case we need a new figure. fig = plt.figure(figsize=figsize) updatefig = False else: # Set current figure to given input fig = plt.figure(fig.number) updatefig = True if values.ndim > 1 and not update_in_place: # If the figure is larger than 1d, we can clear it since we # dont reuse anything. Keeping it causes performance problems. fig.clf() else: fig = plt.figure(figsize=figsize) updatefig = False if values_are_complex: # Real if len(fig.axes) == 0: # Create new axis if needed sub_re = plt.subplot(arrange_subplots[0], **sub_kwargs) sub_re.set_title('Real part') sub_re.set_xlabel(axis_labels[0], fontsize=axis_fontsize) if values.ndim == 2: sub_re.set_ylabel(axis_labels[1], fontsize=axis_fontsize) else: sub_re.set_ylabel('value') else: sub_re = fig.axes[0] display_re = getattr(sub_re, method) csub_re = display_re(*args_re, **dsp_kwargs) # Axis ticks if method == 'imshow' and not grid.is_uniform: (xpts, xlabels), (ypts, ylabels) = _axes_info(grid) plt.xticks(xpts, xlabels) plt.yticks(ypts, ylabels) if method == 'imshow' and len(fig.axes) < 2: # Create colorbar if none seems to exist # Use clim from kwargs if given if 'clim' not in kwargs: minval_re, maxval_re = _safe_minmax(values.real) else: minval_re, maxval_re = kwargs['clim'] ticks_re = _colorbar_ticks(minval_re, maxval_re) fmt_re = _colorbar_format(minval_re, maxval_re) plt.colorbar(csub_re, orientation='horizontal', ticks=ticks_re, format=fmt_re) # Imaginary if len(fig.axes) < 3: sub_im = plt.subplot(arrange_subplots[1], **sub_kwargs) sub_im.set_title('Imaginary part') sub_im.set_xlabel(axis_labels[0], fontsize=axis_fontsize) if values.ndim == 2: sub_im.set_ylabel(axis_labels[1], fontsize=axis_fontsize) else: sub_im.set_ylabel('value') else: sub_im = fig.axes[2] display_im = getattr(sub_im, method) csub_im = display_im(*args_im, **dsp_kwargs) # Axis ticks if method == 'imshow' and not grid.is_uniform: (xpts, xlabels), (ypts, ylabels) = _axes_info(grid) plt.xticks(xpts, xlabels) plt.yticks(ypts, ylabels) if method == 'imshow' and len(fig.axes) < 4: # Create colorbar if none seems to exist # Use clim from kwargs if given if 'clim' not in kwargs: minval_im, maxval_im = _safe_minmax(values.imag) else: minval_im, maxval_im = kwargs['clim'] ticks_im = _colorbar_ticks(minval_im, maxval_im) fmt_im = _colorbar_format(minval_im, maxval_im) plt.colorbar(csub_im, orientation='horizontal', ticks=ticks_im, format=fmt_im) else: if len(fig.axes) == 0: # Create new axis object if needed sub = plt.subplot(111, **sub_kwargs) sub.set_xlabel(axis_labels[0], fontsize=axis_fontsize) if values.ndim == 2: sub.set_ylabel(axis_labels[1], fontsize=axis_fontsize) else: sub.set_ylabel('value') try: # For 3d plots sub.set_zlabel('z') except AttributeError: pass else: sub = fig.axes[0] if update_in_place: import matplotlib as mpl imgs = [obj for obj in sub.get_children() if isinstance(obj, mpl.image.AxesImage)] if len(imgs) > 0 and updatefig: imgs[0].set_data(args_re[0]) csub = imgs[0] # Update min-max if 'clim' not in kwargs: minval, maxval = _safe_minmax(values) else: minval, maxval = kwargs['clim'] csub.set_clim(minval, maxval) else: display = getattr(sub, method) csub = display(*args_re, **dsp_kwargs) else: display = getattr(sub, method) csub = display(*args_re, **dsp_kwargs) # Axis ticks if method == 'imshow' and not grid.is_uniform: (xpts, xlabels), (ypts, ylabels) = _axes_info(grid) plt.xticks(xpts, xlabels) plt.yticks(ypts, ylabels) if method == 'imshow' and colorbar: # Add colorbar # Use clim from kwargs if given if 'clim' not in kwargs: minval, maxval = _safe_minmax(values) else: minval, maxval = kwargs['clim'] ticks = _colorbar_ticks(minval, maxval) fmt = _colorbar_format(minval, maxval) if len(fig.axes) < 2: # Create colorbar if none seems to exist plt.colorbar(mappable=csub, ticks=ticks, format=fmt) elif update_in_place: # If it exists and we should update it csub.colorbar.set_clim(minval, maxval) csub.colorbar.set_ticks(ticks) if '%' not in fmt: labels = [fmt] * len(ticks) else: labels = [fmt % t for t in ticks] csub.colorbar.set_ticklabels(labels) csub.colorbar.draw_all() # Set title of window if title is not None: if not values_are_complex: # Do not overwrite title for complex values plt.title(title) fig.canvas.manager.set_window_title(title) # Fixes overlapping stuff at the expense of potentially squashed subplots if not update_in_place: fig.tight_layout() if updatefig or plt.isinteractive(): # If we are running in interactive mode, we can always show the fig # This causes an artifact, where users of `CallbackShow` without # interactive mode only shows the figure after the second iteration. plt.show(block=False) if not update_in_place: plt.draw() warning_free_pause() else: try: sub.draw_artist(csub) fig.canvas.blit(fig.bbox) fig.canvas.update() fig.canvas.flush_events() except AttributeError: plt.draw() warning_free_pause() if force_show: plt.show() if saveto is not None: fig.savefig(saveto) return fig
if __name__ == '__main__': run_doctests()