# Copyright 2014-2020 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""Callback objects for per-iterate actions in iterative methods."""
from __future__ import absolute_import, division, print_function
import contextlib
import copy
import os
import time
import warnings
from builtins import object
import numpy as np
from odl.util import signature_string
# Public names exported by ``from ... import *``.
__all__ = (
    'Callback',
    'CallbackStore',
    'CallbackApply',
    'CallbackPrintTiming',
    'CallbackPrintIteration',
    'CallbackPrint',
    'CallbackPrintNorm',
    'CallbackShow',
    'CallbackSaveToDisk',
    'CallbackSleep',
    'CallbackShowConvergence',
    'CallbackPrintHardwareUsage',
    'CallbackProgressBar',
    'save_animation',
)
class Callback(object):

    """Abstract base class for handling iterates of solvers.

    Subclasses implement `__call__` and, if they keep internal state such
    as an iteration counter, `reset`. Callbacks can be chained with ``&``
    (call in sequence) and composed with an operator with ``*`` (apply the
    operator first).
    """

    def __call__(self, iterate):
        """Apply the callback object to result.

        The base class implementation does nothing.

        Parameters
        ----------
        iterate : `LinearSpaceElement`
            Partial result after n iterations.

        Returns
        -------
        None
        """

    def __and__(self, other):
        """Return ``self & other``.

        Compose callbacks, calls both in sequence.

        Parameters
        ----------
        other : callable
            The other callback to compose with. Plain callables that are
            not `Callback` instances are wrapped in `CallbackApply`.

        Returns
        -------
        result : `Callback`
            A callback whose `__call__` method calls both constituents
            `__call__`.

        Examples
        --------
        >>> store = CallbackStore()
        >>> printer = CallbackPrintIteration()
        >>> store & printer
        CallbackStore() & CallbackPrintIteration()
        """
        return _CallbackAnd(self, other)

    def __mul__(self, other):
        """Return ``self * other``.

        Compose callback with operator, calls the callback after calling the
        operator.

        Parameters
        ----------
        other : `Operator`
            The operator to compose with.

        Returns
        -------
        result : `Callback`
            A callback whose `__call__` method calls first the operator, and
            then applies the callback to the result.

        Examples
        --------
        >>> r3 = odl.rn(3)
        >>> callback = odl.solvers.CallbackPrint()
        >>> operator = odl.ScalingOperator(r3, 2.0)
        >>> composed_callback = callback * operator
        >>> composed_callback([1, 2, 3])
        rn(3).element([ 2., 4., 6.])
        """
        return _CallbackCompose(self, other)

    def reset(self):
        """Reset the callback to its initial state.

        Should be overridden by subclasses that keep state. The base class
        implementation does nothing.
        """
        pass

    def __repr__(self):
        """Return ``repr(self)``."""
        return '{}()'.format(self.__class__.__name__)
class _CallbackAnd(Callback):

    """Callback that calls several callbacks one after the other."""

    def __init__(self, *callbacks):
        """Initialize a new instance.

        Parameters
        ----------
        callback1, ..., callbackN : callable
            Callables to be called in sequence as listed. Callables that
            are not `Callback` instances are wrapped in `CallbackApply`.
        """
        wrapped = []
        for cb in callbacks:
            if not isinstance(cb, Callback):
                cb = CallbackApply(cb)
            wrapped.append(cb)
        self.callbacks = wrapped

    def __call__(self, result):
        """Pass ``result`` to every contained callback, in order."""
        for cb in self.callbacks:
            cb(result)

    def reset(self):
        """Reset every contained callback to its initial state."""
        for cb in self.callbacks:
            cb.reset()

    def __repr__(self):
        """Return ``repr(self)``."""
        return ' & '.join(repr(cb) for cb in self.callbacks)
class _CallbackCompose(Callback):

    """Callback that applies an operator before calling another callback."""

    def __init__(self, callback, operator):
        """Initialize a new instance.

        Parameters
        ----------
        callback : callable
            The callback to call.
        operator : `Operator`
            Operator to apply before calling the callback.
        """
        self.callback, self.operator = callback, operator

    def __call__(self, result):
        """Call the wrapped callback on ``operator(result)``."""
        transformed = self.operator(result)
        self.callback(transformed)

    def reset(self):
        """Reset the wrapped callback to its initial state."""
        self.callback.reset()

    def __repr__(self):
        """Return ``repr(self)``.

        Examples
        --------
        >>> r3 = odl.rn(3)
        >>> callback = odl.solvers.CallbackPrint()
        >>> operator = odl.ScalingOperator(r3, 2.0)
        >>> callback * operator
        CallbackPrint() * ScalingOperator(rn(3), 2.0)
        """
        return repr(self.callback) + ' * ' + repr(self.operator)
class CallbackStore(Callback):

    """Callback for storing all iterates of a solver.

    Can optionally apply a function, for example the norm or calculating the
    residual.

    By default, calls the ``copy()`` method on the iterates before storing.
    """

    def __init__(self, results=None, function=None, step=1):
        """Initialize a new instance.

        Parameters
        ----------
        results : list, optional
            List in which to store the iterates. The list is appended to
            (and cleared by `reset`) in place, so the caller's reference
            stays valid.
            Default: new list (``[]``)
        function : callable, optional
            Deprecated, use composition instead. See examples.
            Function to be called on all incoming results before storage.
            Default: copy
        step : int, optional
            Number of iterates between storing iterates.

        Examples
        --------
        Store results as-is:

        >>> callback = CallbackStore()

        Provide list to store iterates in:

        >>> results = []
        >>> callback = CallbackStore(results=results)

        Store the norm of the results:

        >>> norm_function = lambda x: x.norm()
        >>> callback = CallbackStore() * norm_function
        """
        self.results = [] if results is None else results
        self.function = function
        if function is not None:
            warnings.warn('`function` argument is deprecated and will be '
                          'removed in a future release. Use composition '
                          'instead. '
                          'See Examples in the documentation.',
                          DeprecationWarning)
        self.step = int(step)
        self.iter = 0

    def __call__(self, result):
        """Append result to results list every ``step``-th call."""
        if self.iter % self.step == 0:
            if self.function is not None:
                self.results.append(self.function(result))
            else:
                # Store a shallow copy (presumably because solvers may
                # update the iterate in place).
                self.results.append(copy.copy(result))
        self.iter += 1

    def reset(self):
        """Clear the results list and set `iter` to 0.

        The list is emptied in place rather than replaced, so a list passed
        as ``results`` remains connected to this callback.
        """
        del self.results[:]
        self.iter = 0

    def __iter__(self):
        """Allow iteration over the results."""
        return iter(self.results)

    def __getitem__(self, index):
        """Return ``self[index]``.

        Get iterates by index.
        """
        return self.results[index]

    def __len__(self):
        """Number of results stored."""
        return len(self.results)

    def __repr__(self):
        """Return ``repr(self)``."""
        optargs = [('results', self.results, []),
                   ('function', self.function, None),
                   ('step', self.step, 1)]
        inner_str = signature_string([], optargs)
        return '{}({})'.format(self.__class__.__name__, inner_str)
class CallbackApply(Callback):

    """Callback for applying a custom function to iterates."""

    def __init__(self, function, step=1):
        """Initialize a new instance.

        Parameters
        ----------
        function : callable
            Function to call on the current iterate.
        step : int, optional
            Number of iterates between applications of ``function``.

        Raises
        ------
        TypeError
            If ``function`` is not callable.

        Examples
        --------
        By default, the function is called on each iterate:

        >>> def func(x):
        ...     print(np.max(x))
        >>> callback = CallbackApply(func)
        >>> x = odl.rn(3).element([1, 2, 3])
        >>> callback(x)
        3.0
        >>> callback(x)
        3.0

        To apply only to each n-th iterate, supply ``step=n``:

        >>> callback = CallbackApply(func, step=2)
        >>> callback(x)
        3.0
        >>> callback(x)  # no output
        >>> callback(x)  # next output
        3.0
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if not callable(function):
            raise TypeError('`function` must be callable, got {!r}'
                            ''.format(function))
        self.function = function
        self.step = int(step)
        self.iter = 0

    def __call__(self, result):
        """Apply function to result every ``step``-th call."""
        if self.iter % self.step == 0:
            self.function(result)
        self.iter += 1

    def reset(self):
        """Set `iter` to 0."""
        self.iter = 0

    def __str__(self):
        """Return ``str(self)``."""
        return repr(self)

    def __repr__(self):
        """Return ``repr(self)``."""
        posargs = [self.function]
        optargs = [('step', self.step, 1)]
        inner_str = signature_string(posargs, optargs)
        return '{}({})'.format(self.__class__.__name__, inner_str)
class CallbackPrintIteration(Callback):

    """Callback for printing the iteration count."""

    def __init__(self, fmt='iter = {}', step=1, **kwargs):
        """Initialize a new instance.

        Parameters
        ----------
        fmt : string, optional
            Format string for the text to be printed. The text is printed as::

                print(fmt.format(cur_iter_num))

            where ``cur_iter_num`` is the current iteration number.
        step : positive int, optional
            Number of iterations between output.

        Other Parameters
        ----------------
        kwargs :
            Key word arguments passed to the print function.

        Examples
        --------
        Create simple callback that prints iteration count:

        >>> callback = CallbackPrintIteration()
        >>> callback(None)
        iter = 0
        >>> callback(None)
        iter = 1

        Create callback that every 2nd iterate prints iteration count with
        a custom string:

        >>> callback = CallbackPrintIteration(fmt='Current iter is {}.',
        ...                                   step=2)
        >>> callback(None)
        Current iter is 0.
        >>> callback(None)  # prints nothing
        >>> callback(None)
        Current iter is 2.
        """
        self.fmt = str(fmt)
        self.step = int(step)
        self.iter = 0
        self.kwargs = kwargs

    def __call__(self, _):
        """Print the current iteration every ``step``-th call."""
        if self.iter % self.step == 0:
            print(self.fmt.format(self.iter), **self.kwargs)
        self.iter += 1

    def reset(self):
        """Set `iter` to 0."""
        self.iter = 0

    def __repr__(self):
        """Return ``repr(self)``.

        Print keyword arguments are included so the repr reflects the full
        configuration.

        Examples
        --------
        >>> CallbackPrintIteration(fmt='Current iter is {}.', step=2)
        CallbackPrintIteration(fmt='Current iter is {}.', step=2)
        """
        optargs = [('fmt', self.fmt, 'iter = {}'),
                   ('step', self.step, 1)]
        for kwarg, value in self.kwargs.items():
            optargs.append((kwarg, value, None))
        inner_str = signature_string([], optargs)
        return '{}({})'.format(self.__class__.__name__, inner_str)
class CallbackPrintTiming(Callback):

    """Callback for printing the time elapsed since the previous iteration."""

    def __init__(self, fmt='Time elapsed = {:<5.03f} s', step=1,
                 cumulative=False, **kwargs):
        """Initialize a new instance.

        Parameters
        ----------
        fmt : string, optional
            Formatting that should be applied. The time is printed as ::

                print(fmt.format(runtime))

            where ``runtime`` is the runtime since the last print.
        step : positive int, optional
            Number of iterations between prints.
        cumulative : boolean, optional
            Print the time since the initialization (or the last `reset`)
            instead of the time since the last print.

        Other Parameters
        ----------------
        kwargs :
            Key word arguments passed to the print function.
        """
        self.fmt = str(fmt)
        self.step = int(step)
        self.iter = 0
        self.cumulative = cumulative
        self.start_time = time.time()
        self.kwargs = kwargs

    def __call__(self, _):
        """Print time elapsed every ``step``-th call."""
        if self.iter % self.step == 0:
            current_time = time.time()
            print(self.fmt.format(current_time - self.start_time),
                  **self.kwargs)
            # In non-cumulative mode the reference point moves to now.
            if not self.cumulative:
                self.start_time = current_time
        self.iter += 1

    def reset(self):
        """Restart the timer and set `iter` to 0."""
        self.start_time = time.time()
        self.iter = 0

    def __repr__(self):
        """Return ``repr(self)``, including print keyword arguments."""
        optargs = [('fmt', self.fmt, 'Time elapsed = {:<5.03f} s'),
                   ('step', self.step, 1),
                   ('cumulative', self.cumulative, False)]
        for kwarg, value in self.kwargs.items():
            optargs.append((kwarg, value, None))
        inner_str = signature_string([], optargs)
        return '{}({})'.format(self.__class__.__name__, inner_str)
class CallbackPrint(Callback):

    """Callback for printing the current value."""

    def __init__(self, func=None, fmt='{!r}', step=1, **kwargs):
        """Initialize a new instance.

        Parameters
        ----------
        func : callable, optional
            Deprecated, use composition instead. See examples.
            Functional that should be called on the current iterate before
            printing. Default: print current iterate.
        fmt : string, optional
            Formatting that should be applied. Will be used as ::

                print(fmt.format(x))

            where ``x`` is the input to the callback.
        step : positive int, optional
            Number of iterations between prints.

        Other Parameters
        ----------------
        kwargs :
            Key word arguments passed to the print function.

        Raises
        ------
        TypeError
            If ``func`` is neither callable nor ``None``.

        Examples
        --------
        Callback for simply printing the current iterate:

        >>> callback = CallbackPrint()
        >>> callback([1, 2])
        [1, 2]

        Apply function before printing via composition:

        >>> callback = CallbackPrint() * np.sum
        >>> callback([1, 2])
        3

        Format to two decimal points:

        >>> callback = CallbackPrint(fmt='{0:.2f}') * np.sum
        >>> callback([1, 2])
        3.00
        """
        # Validate first so an invalid `func` does not also emit a
        # deprecation warning.
        if func is not None and not callable(func):
            raise TypeError('`func` must be `callable` or `None`')
        self.func = func
        if func is not None:
            warnings.warn('`func` argument is deprecated and will be removed '
                          'in a future release. Use composition instead. '
                          'See Examples in the documentation.',
                          DeprecationWarning)
        self.fmt = str(fmt)
        self.step = int(step)
        self.iter = 0
        self.kwargs = kwargs

    def __call__(self, result):
        """Print the current value every ``step``-th call."""
        if self.iter % self.step == 0:
            if self.func is not None:
                result = self.func(result)
            print(self.fmt.format(result), **self.kwargs)
        self.iter += 1

    def reset(self):
        """Set `iter` to 0."""
        self.iter = 0

    def __repr__(self):
        """Return ``repr(self)``, including print keyword arguments."""
        optargs = [('func', self.func, None),
                   ('fmt', self.fmt, '{!r}'),
                   ('step', self.step, 1)]
        for kwarg, value in self.kwargs.items():
            optargs.append((kwarg, value, None))
        inner_str = signature_string([], optargs)
        return '{}({})'.format(self.__class__.__name__, inner_str)
class CallbackPrintNorm(Callback):

    """Callback for printing the current norm."""

    def __call__(self, result):
        """Print the current norm.

        Parameters
        ----------
        result : object
            Current iterate; it must provide a ``norm()`` method.
        """
        print("norm = {}".format(result.norm()))

    def __repr__(self):
        """Return ``repr(self)``."""
        return '{}()'.format(self.__class__.__name__)
class CallbackShow(Callback):

    """Callback for showing iterates.

    See Also
    --------
    odl.discr.discr_space.DiscretizedSpaceElement.show
    odl.space.base_tensors.Tensor.show
    """

    def __init__(self, title=None, step=1, saveto=None, **kwargs):
        """Initialize a new instance.

        Additional parameters are passed through to the ``show`` method.

        Parameters
        ----------
        title : str, optional
            Format string for the title of the displayed figure.
            The title name is generated as ::

                title = title.format(cur_iter_num)

            where ``cur_iter_num`` is the current iteration number.
            For the default ``None``, the title format ``'Iterate {}'``
            is used.
        step : positive int, optional
            Number of iterations between plots.
        saveto : str or callable, optional
            Format string for the name of the file(s) where
            iterates are saved.

            If ``saveto`` is a string, the file name is generated as ::

                filename = saveto.format(cur_iter_num)

            where ``cur_iter_num`` is the current iteration number.

            If ``saveto`` is a callable, the file name is generated as ::

                filename = saveto(cur_iter_num)

            If the directory name does not exist, a ``ValueError`` is raised.
            If ``saveto is None``, the figures are not saved.

        Other Parameters
        ----------------
        kwargs :
            Optional keyword arguments passed on to ``x.show``. A ``fig``
            keyword is taken as the initial figure to draw into.

        Examples
        --------
        Show the result of each iterate:

        >>> callback = CallbackShow()

        Show and save every fifth iterate in ``png`` format, overwriting the
        previous one:

        >>> callback = CallbackShow(step=5,
        ...                         saveto='my_path/my_iterate.png')

        Show and save each fifth iterate in ``png`` format, indexing the files
        with the iteration number:

        >>> callback = CallbackShow(step=5,
        ...                         saveto='my_path/my_iterate_{}.png')

        Pass additional arguments to ``show``:

        >>> callback = CallbackShow(step=5, clim=[0, 1])
        """
        if title is None:
            self.title = 'Iterate {}'
        else:
            self.title = str(title)
        self.title_formatter = self.title.format

        self.saveto = saveto
        # Strings are formatted with the iteration number; callables are
        # called with it directly.
        self.saveto_formatter = getattr(self.saveto, 'format', self.saveto)

        # Convert like all other callbacks in this module do.
        self.step = int(step)
        self.fig = kwargs.pop('fig', None)
        self.iter = 0
        self.space_of_last_x = None
        self.kwargs = kwargs

    def __call__(self, x):
        """Show the current iterate every ``step``-th call.

        Parameters
        ----------
        x : object
            Current iterate; it must have a ``space`` attribute and a
            ``show`` method.
        """
        # Update the figure in place if the iterate lives in the same space
        # as the previous one.
        x_space = x.space
        update_in_place = (self.space_of_last_x == x_space)
        self.space_of_last_x = x_space

        if self.iter % self.step == 0:
            title = self.title_formatter(self.iter)
            if self.saveto is None:
                self.fig = x.show(title, fig=self.fig,
                                  update_in_place=update_in_place,
                                  **self.kwargs)
            else:
                saveto = self.saveto_formatter(self.iter)
                self.fig = x.show(title, fig=self.fig,
                                  update_in_place=update_in_place,
                                  saveto=saveto, **self.kwargs)

        self.iter += 1

    def reset(self):
        """Set `iter` to 0 and create a new figure."""
        self.iter = 0
        self.fig = None
        self.space_of_last_x = None

    def __repr__(self):
        """Return ``repr(self)``."""
        posargs = []
        if self.title != 'Iterate {}':
            posargs.append(self.title)
        optargs = [('step', self.step, 1),
                   ('saveto', self.saveto, None)]
        for kwarg, value in self.kwargs.items():
            optargs.append((kwarg, value, None))
        inner_str = signature_string(posargs, optargs)
        return '{}({})'.format(self.__class__.__name__, inner_str)
class CallbackSaveToDisk(Callback):

    """Callback for saving iterates to disk."""

    def __init__(self, saveto, step=1, impl='pickle', **kwargs):
        """Initialize a new instance.

        Parameters
        ----------
        saveto : str or callable
            Format string for the name of the file(s) where
            iterates are saved. If ``saveto`` is a string, the file name
            is generated as ::

                filename = saveto.format(cur_iter_num)

            where ``cur_iter_num`` is the current iteration number. If
            ``saveto`` is a callable, the file name is generated as ::

                filename = saveto(cur_iter_num)

            Missing parent directories are created.
        step : positive int, optional
            Number of iterations between saves.
        impl : {'pickle', 'numpy', 'numpy_txt'}, optional
            The format to store the iterates in. Numpy formats are only usable
            if the data can be converted to an array via `numpy.asarray`.

        Other Parameters
        ----------------
        kwargs :
            Optional arguments passed to the save function.

        Examples
        --------
        Save each iterate to the same file, overwriting the previous one:

        >>> callback = CallbackSaveToDisk('my_path/my_iterate')

        Save every fifth iterate, overwriting the previous one:

        >>> callback = CallbackSaveToDisk(saveto='my_path/my_iterate',
        ...                               step=5)

        Save each fifth iterate in ``numpy`` format, indexing the files with
        the iteration number:

        >>> callback = CallbackSaveToDisk(saveto='my_path/my_iterate_{}',
        ...                               step=5, impl='numpy')
        """
        self.saveto = saveto
        # Strings are formatted with the iteration number; callables are
        # called with it directly.
        self.saveto_formatter = getattr(self.saveto, 'format', self.saveto)
        self.step = int(step)
        self.impl = str(impl).lower()
        self.kwargs = kwargs
        self.iter = 0

    def __call__(self, x):
        """Save the current iterate every ``step``-th call.

        Raises
        ------
        RuntimeError
            If `impl` is not one of the supported formats.
        """
        if self.iter % self.step == 0:
            file_path = self.saveto_formatter(self.iter)
            folder_path = os.path.dirname(os.path.realpath(file_path))
            # Create the directory without a separate existence check,
            # which would race with concurrent creators.
            try:
                os.makedirs(folder_path)
            except OSError:
                if not os.path.isdir(folder_path):
                    raise

            if self.impl == 'pickle':
                import pickle
                with open(file_path, 'wb') as f:
                    pickle.dump(x, f, **self.kwargs)
            elif self.impl == 'numpy':
                np.save(file_path, np.asarray(x), **self.kwargs)
            elif self.impl == 'numpy_txt':
                np.savetxt(file_path, np.asarray(x), **self.kwargs)
            else:
                raise RuntimeError('unknown `impl` {}'.format(self.impl))

        self.iter += 1

    def reset(self):
        """Set `iter` to 0."""
        self.iter = 0

    def __repr__(self):
        """Return ``repr(self)``."""
        posargs = [self.saveto]
        optargs = [('step', self.step, 1),
                   ('impl', self.impl, 'pickle')]
        for kwarg, value in self.kwargs.items():
            optargs.append((kwarg, value, None))
        inner_str = signature_string(posargs, optargs)
        return '{}({})'.format(self.__class__.__name__, inner_str)
class CallbackSleep(Callback):

    """Callback for sleeping for a specific time span."""

    def __init__(self, seconds=1.0):
        """Initialize a new instance.

        Parameters
        ----------
        seconds : float, optional
            Number of seconds to sleep, can be float for subsecond precision.

        Examples
        --------
        Sleep 1 second between consecutive iterates:

        >>> callback = CallbackSleep(seconds=1)

        Sleep 10 ms between consecutive iterates:

        >>> callback = CallbackSleep(seconds=0.01)
        """
        self.seconds = float(seconds)

    def __call__(self, x):
        """Sleep for ``self.seconds`` seconds; the iterate is ignored."""
        time.sleep(self.seconds)

    def __repr__(self):
        """Return ``repr(self)``."""
        optargs = [('seconds', self.seconds, 1.0)]
        inner_str = signature_string([], optargs)
        return '{}({})'.format(self.__class__.__name__, inner_str)
class CallbackShowConvergence(Callback):

    """Displays a convergence plot."""

    def __init__(self, functional, title='convergence', logx=False, logy=False,
                 **kwargs):
        """Initialize a new instance.

        Parameters
        ----------
        functional : callable
            Function that is called with the current iterate and returns the
            function value.
        title : str, optional
            Title of the plot.
        logx : bool, optional
            If true, the x axis is logarithmic.
        logy : bool, optional
            If true, the y axis is logarithmic.

        Other Parameters
        ----------------
        kwargs :
            Additional parameters passed to the scatter-plotting function.
        """
        self.functional = functional
        self.title = title
        self.logx = logx
        self.logy = logy
        self.kwargs = kwargs
        self.iter = 0

        import matplotlib.pyplot as plt
        self.fig = plt.figure(title)
        self.ax = self.fig.add_subplot(111)
        self.ax.set_xlabel('iteration')
        self.ax.set_ylabel('function value')
        self.ax.set_title(title)
        if logx:
            self._set_log_scale(self.ax.set_xscale, 'nonposx')
        if logy:
            self._set_log_scale(self.ax.set_yscale, 'nonposy')

    @staticmethod
    def _set_log_scale(set_scale, legacy_kwarg):
        """Set a log scale that clips non-positive values.

        matplotlib >= 3.3 uses the keyword ``nonpositive``; the older
        ``nonposx``/``nonposy`` keywords were removed later. Old versions
        reject the new keyword (``ValueError``), new ones reject the old
        keyword (``TypeError``), so try the new name first.
        """
        try:
            set_scale('log', nonpositive='clip')
        except (TypeError, ValueError):
            set_scale('log', **{legacy_kwarg: 'clip'})

    def __call__(self, x):
        """Add the functional value of ``x`` to the convergence plot."""
        # On a log x axis, iteration 0 cannot be shown, so shift by one.
        if self.logx:
            it = self.iter + 1
        else:
            it = self.iter
        self.ax.scatter(it, self.functional(x), **self.kwargs)
        self.iter += 1

    def reset(self):
        """Set `iter` to 0."""
        self.iter = 0

    def __repr__(self):
        """Return ``repr(self)``."""
        return '{}(functional={!r}, title={!r}, logx={}, logy={})'.format(
            self.__class__.__name__,
            self.functional,
            self.title,
            self.logx,
            self.logy)
class CallbackPrintHardwareUsage(Callback):

    """Callback for printing memory and CPU usage.

    This callback requires the ``psutil`` package.
    """

    def __init__(self, step=1, fmt_cpu='CPU usage (% each core): {}',
                 fmt_mem='RAM usage: {}', fmt_swap='SWAP usage: {}', **kwargs):
        """Initialize a new instance.

        Parameters
        ----------
        step : positive int, optional
            Number of iterations between output.
        fmt_cpu : string, optional
            Formatting that should be applied. The CPU usage is printed as ::

                print(fmt_cpu.format(cpu))

            where ``cpu`` is a vector with the percentage of current CPU usage
            for each core. An empty format string disables printing of CPU
            usage.
        fmt_mem : string, optional
            Formatting that should be applied. The RAM usage is printed as ::

                print(fmt_mem.format(mem))

            where ``mem`` is the current RAM memory usage. An empty format
            string disables printing of RAM memory usage.
        fmt_swap : string, optional
            Formatting that should be applied. The SWAP usage is printed as ::

                print(fmt_swap.format(swap))

            where ``swap`` is the current SWAP memory usage. An empty format
            string disables printing of SWAP memory usage.

        Other Parameters
        ----------------
        kwargs :
            Key word arguments passed to the print function.

        Examples
        --------
        Print memory and CPU usage

        >>> callback = CallbackPrintHardwareUsage()

        Only print every tenth step

        >>> callback = CallbackPrintHardwareUsage(step=10)

        Only print the RAM memory usage in every step, and with a non-default
        formatting

        >>> callback = CallbackPrintHardwareUsage(step=1, fmt_cpu='',
        ...                                       fmt_mem='RAM {}',
        ...                                       fmt_swap='')
        """
        self.step = int(step)
        self.fmt_cpu = str(fmt_cpu)
        self.fmt_mem = str(fmt_mem)
        self.fmt_swap = str(fmt_swap)
        # Previously never stored, which made every print in `__call__`
        # fail with AttributeError.
        self.kwargs = kwargs
        self.iter = 0

    def __call__(self, _):
        """Print the memory and CPU usage every ``step``-th call."""
        import psutil

        if self.iter % self.step == 0:
            if self.fmt_cpu:
                print(self.fmt_cpu.format(psutil.cpu_percent(percpu=True)),
                      **self.kwargs)
            if self.fmt_mem:
                print(self.fmt_mem.format(psutil.virtual_memory()),
                      **self.kwargs)
            if self.fmt_swap:
                print(self.fmt_swap.format(psutil.swap_memory()),
                      **self.kwargs)

        self.iter += 1

    def reset(self):
        """Set `iter` to 0."""
        self.iter = 0

    def __repr__(self):
        """Return ``repr(self)``, including print keyword arguments."""
        optargs = [('step', self.step, 1),
                   ('fmt_cpu', self.fmt_cpu, 'CPU usage (% each core): {}'),
                   ('fmt_mem', self.fmt_mem, 'RAM usage: {}'),
                   ('fmt_swap', self.fmt_swap, 'SWAP usage: {}')]
        for kwarg, value in self.kwargs.items():
            optargs.append((kwarg, value, None))
        inner_str = signature_string([], optargs)
        return '{}({})'.format(self.__class__.__name__, inner_str)
class CallbackProgressBar(Callback):

    """Callback for displaying a progress bar.

    This callback requires the ``tqdm`` package.
    """

    def __init__(self, niter, step=1, **kwargs):
        """Initialize a new instance.

        Parameters
        ----------
        niter : positive int
            Total number of iterations.
        step : positive int, optional
            Number of iterations between output.

        Other Parameters
        ----------------
        kwargs :
            Further parameters passed to ``tqdm.tqdm``.
        """
        self.niter = int(niter)
        self.step = int(step)
        self.kwargs = kwargs
        self.pbar = None
        self.reset()

    def __call__(self, _):
        """Advance the progress bar by ``step`` every ``step``-th call."""
        self.iter += 1
        if self.iter % self.step == 0:
            self.pbar.update(self.step)

    def reset(self):
        """Set `iter` to 0 and start a fresh progress bar.

        Any previously created bar is closed first so it is not left open.
        """
        import tqdm

        old_pbar = getattr(self, 'pbar', None)
        if old_pbar is not None:
            old_pbar.close()

        self.iter = 0
        self.pbar = tqdm.tqdm(total=self.niter, **self.kwargs)

    def __repr__(self):
        """Return ``repr(self)``."""
        posargs = [self.niter]
        optargs = [('step', self.step, 1)]
        inner_str = signature_string(posargs, optargs)
        if self.kwargs:
            return '{}({}, **{})'.format(self.__class__.__name__,
                                         inner_str, self.kwargs)
        else:
            return '{}({})'.format(self.__class__.__name__,
                                   inner_str)
@contextlib.contextmanager
def save_animation(filename,
                   writer=None,
                   writer_kwargs=None,
                   dpi=None,
                   saving_kwargs=None,
                   fig=None,
                   step=1):
    """Context manager for creating animations from a series of plots.

    The context manager uses `matplotlib.animation` to generate the
    animation.

    Parameters
    ----------
    filename : str
        Name of the generated output file.
    writer : str, optional
        Back-end for generating the movie file. Available writers can be
        checked with the command ``matplotlib.animation.writers.list()``.
        See the `matplotlib animation writers doc`_ for details.
        For the default ``None``, the first writer from the list of available
        ones is chosen.
    writer_kwargs : dict, optional
        Keyword arguments passed to the writer class constructor.
        See the `matplotlib animation writers doc`_ for details.
    dpi : float, optional
        Resolution of the saved frames in DPI. For ``None``, the figure
        resolution ``fig.dpi`` is used, which is the default resolution if
        ``fig is None``.
    saving_kwargs : dict, optional
        Keyword arguments passed to the ``saving`` method of the writer
        instance. See the `matplotlib animation writers doc`_ for details.
    fig : matplotlib.figure.Figure, optional
        Matplotlib figure used for plotting. For the default ``None``, a new
        figure is created. The last axes of the figure are used; if it has
        no axes, a new subplot is added.
    step : positive int, optional
        Number of iterations between frames.

    Yields
    ------
    callback : `Callback`
        Callback that draws 2D iterates with ``imshow`` and appends a frame
        to the animation on every ``step``-th call. Inputs without an
        ``ndim`` attribute raise ``TypeError``, inputs with ``ndim != 2``
        raise ``NotImplementedError``.

    Raises
    ------
    RuntimeError
        If ``writer is None`` and no animation writer is available.

    .. _matplotlib animation writers doc:
        https://matplotlib.org/api/animation_api.html#writer-classes
    """
    import matplotlib.animation
    import matplotlib.pyplot as plt

    if writer_kwargs is None:
        writer_kwargs = {}
    if saving_kwargs is None:
        saving_kwargs = {}
    step = int(step)

    if writer is None:
        try:
            writer = matplotlib.animation.writers.list()[0]
        except IndexError:
            raise RuntimeError('no animation writer available')

    writer_cls = matplotlib.animation.writers[writer]
    moviewriter = writer_cls(**writer_kwargs)

    if fig is None:
        fig, ax = plt.subplots(dpi=dpi)
    elif fig.axes:
        ax = fig.axes[-1]
    else:
        # Indexing an empty axes list would raise IndexError.
        ax = fig.add_subplot(111)

    dpi = fig.dpi if dpi is None else float(dpi)

    class CallbackAppendMovieFrame(Callback):

        """Callback for appending frames to an animation."""

        def __init__(self):
            """Initialize a new instance with the call counter at 0."""
            # Per-instance counter; the previous closure variable was never
            # incremented, so `step` had no effect.
            self.iter = 0

        @staticmethod
        def _update_plot_2d(x):
            """Helper for updating a 2D plot (``imshow``)."""
            try:
                im = ax.get_images()[-1]
            except IndexError:
                im = ax.imshow(x)
            im.set_array(x)

        def __call__(self, x):
            """Draw ``x`` and grab a frame every ``step``-th call."""
            if not hasattr(x, 'ndim'):
                raise TypeError(
                    'input must have an `ndim` attribute, but got '
                    'input of type {}'.format(type(x).__name__)
                )
            if x.ndim != 2:
                raise NotImplementedError(
                    'currently only 2D plots (`imshow`) are supported, got '
                    'input with `ndim` {}'.format(x.ndim)
                )

            if self.iter % step == 0:
                self._update_plot_2d(x)
                moviewriter.grab_frame()
            self.iter += 1

        def reset(self):
            """Set `iter` to 0."""
            self.iter = 0

    with moviewriter.saving(fig, filename, dpi=dpi, **saving_kwargs):
        yield CallbackAppendMovieFrame()
if __name__ == '__main__':
    # Run the doctests of this module when it is executed as a script.
    from odl.util import testutils
    testutils.run_doctests()