
# Copyright 2014-2019 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""(Fast) Iterative shrinkage-thresholding algorithm."""

from __future__ import print_function, division, absolute_import
import numpy as np


__all__ = ('proximal_gradient', 'accelerated_proximal_gradient')


def proximal_gradient(x, f, g, gamma, niter, callback=None, **kwargs):
    r"""Proximal gradient algorithm for convex optimization.

    Also known as "Iterative Soft-Thresholding Algorithm" (ISTA).
    See `[Beck2009]`_ for more information.

    Solves the convex optimization problem::

        min_{x in X} f(x) + g(x)

    where the proximal operator of ``f`` is known and ``g`` is
    differentiable.

    Parameters
    ----------
    x : ``f.domain`` element
        Starting point of the iteration, updated in-place.
    f : `Functional`
        The function ``f`` in the problem definition. Needs to have
        ``f.proximal``.
    g : `Functional`
        The function ``g`` in the problem definition. Needs to have
        ``g.gradient``.
    gamma : positive float
        Step size parameter.
    niter : non-negative int
        Number of iterations.
    callback : callable, optional
        Function called with the current iterate after each iteration.

    Other Parameters
    ----------------
    lam : float or callable, optional
        Overrelaxation step size. If callable, it should take an index
        (starting at zero) and return the corresponding step size.
        Default: 1.0

    Notes
    -----
    The problem of interest is

    .. math::
        \min_{x \in X} f(x) + g(x),

    where the formal conditions are that :math:`f : X \to \mathbb{R}`
    is proper, convex and lower-semicontinuous, and
    :math:`g : X \to \mathbb{R}` is differentiable and :math:`\nabla g`
    is :math:`1 / \beta`-Lipschitz continuous.

    Convergence is only guaranteed if the step length :math:`\gamma`
    satisfies

    .. math::
        0 < \gamma < 2 \beta

    and the parameter :math:`\lambda` (``lam``) satisfies

    .. math::
        \sum_{k=0}^\infty \lambda_k (\delta - \lambda_k) = + \infty

    where :math:`\delta = \min \{1, \beta / \gamma\}`.

    References
    ----------
    .. _[Beck2009]: http://epubs.siam.org/doi/abs/10.1137/080716542
    """
    # Get and validate input
    if x not in f.domain:
        raise TypeError('`x` {!r} is not in the domain of `f` {!r}'
                        ''.format(x, f.domain))
    if x not in g.domain:
        raise TypeError('`x` {!r} is not in the domain of `g` {!r}'
                        ''.format(x, g.domain))

    gamma, gamma_in = float(gamma), gamma
    if gamma <= 0:
        raise ValueError('`gamma` must be positive, got {}'.format(gamma_in))

    if int(niter) != niter:
        raise ValueError('`niter` {} not understood'.format(niter))

    lam_in = kwargs.pop('lam', 1.0)
    lam = lam_in if callable(lam_in) else lambda _: float(lam_in)

    # Get the proximal and gradient operators
    f_prox = f.proximal(gamma)
    g_grad = g.gradient

    # Create temporary
    tmp = x.space.element()

    for k in range(niter):
        lam_k = lam(k)

        # tmp = x - gamma * grad_g(x)
        tmp.lincomb(1, x, -gamma, g_grad(x))

        # Update x by the relaxed proximal step
        x.lincomb(1 - lam_k, x, lam_k, f_prox(tmp))

        if callback is not None:
            callback(x)
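

# A minimal usage sketch (illustrative only, not part of this module):
# solve the LASSO-type problem min_x ||x - b||^2 + ||x||_1 with
# f = L1 norm (closed-form proximal) and g = translated squared L2 norm
# (differentiable; its gradient is 2-Lipschitz, so beta = 1/2 and any
# 0 < gamma < 1 is admissible). The names `space`, `b` and this helper
# are hypothetical.
def _example_proximal_gradient():
    import odl
    space = odl.rn(5)
    b = space.element([1.0, -2.0, 3.0, -4.0, 5.0])
    f = odl.solvers.L1Norm(space)
    g = odl.solvers.L2NormSquared(space).translated(b)
    x = space.zero()
    # An overrelaxation schedule could be passed via ``lam``,
    # e.g. ``lam=lambda k: 0.9``; here the default 1.0 is used.
    proximal_gradient(x, f, g, gamma=0.4, niter=200)
    return x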


def accelerated_proximal_gradient(x, f, g, gamma, niter, callback=None,
                                  **kwargs):
    r"""Accelerated proximal gradient algorithm for convex optimization.

    The method is known as "Fast Iterative Soft-Thresholding Algorithm"
    (FISTA). See `[Beck2009]`_ for more information.

    Solves the convex optimization problem::

        min_{x in X} f(x) + g(x)

    where the proximal operator of ``f`` is known and ``g`` is
    differentiable.

    Parameters
    ----------
    x : ``f.domain`` element
        Starting point of the iteration, updated in-place.
    f : `Functional`
        The function ``f`` in the problem definition. Needs to have
        ``f.proximal``.
    g : `Functional`
        The function ``g`` in the problem definition. Needs to have
        ``g.gradient``.
    gamma : positive float
        Step size parameter.
    niter : non-negative int
        Number of iterations.
    callback : callable, optional
        Function called with the current iterate after each iteration.

    Notes
    -----
    The problem of interest is

    .. math::
        \min_{x \in X} f(x) + g(x),

    where the formal conditions are that :math:`f : X \to \mathbb{R}`
    is proper, convex and lower-semicontinuous, and
    :math:`g : X \to \mathbb{R}` is differentiable and :math:`\nabla g`
    is :math:`1 / \beta`-Lipschitz continuous.

    Convergence is only guaranteed if the step length :math:`\gamma`
    satisfies

    .. math::
        0 < \gamma < 2 \beta.

    References
    ----------
    .. _[Beck2009]: http://epubs.siam.org/doi/abs/10.1137/080716542
    """
    # Get and validate input
    if x not in f.domain:
        raise TypeError('`x` {!r} is not in the domain of `f` {!r}'
                        ''.format(x, f.domain))
    if x not in g.domain:
        raise TypeError('`x` {!r} is not in the domain of `g` {!r}'
                        ''.format(x, g.domain))

    gamma, gamma_in = float(gamma), gamma
    if gamma <= 0:
        raise ValueError('`gamma` must be positive, got {}'.format(gamma_in))

    if int(niter) != niter:
        raise ValueError('`niter` {} not understood'.format(niter))

    # Get the proximal and gradient operators
    f_prox = f.proximal(gamma)
    g_grad = g.gradient

    # Create temporaries
    tmp = x.space.element()
    y = x.copy()
    t = 1

    for k in range(niter):
        # Update the momentum parameter:
        # t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2
        t, t_old = (1 + np.sqrt(1 + 4 * t ** 2)) / 2, t
        alpha = (t_old - 1) / t

        # tmp = y - gamma * grad_g(y)
        tmp.lincomb(1, y, -gamma, g_grad(y))

        # Store old x value in y
        y.assign(x)

        # Update x by the proximal step
        f_prox(tmp, out=x)

        # Update y: y = x + alpha * (x - x_old)
        y.lincomb(1 + alpha, x, -alpha, y)

        if callback is not None:
            callback(x)
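

# The same hedged sketch using the FISTA variant (hypothetical helper;
# only the solver call changes, and the restriction 0 < gamma < 2 * beta
# from the notes above still applies):
def _example_accelerated_proximal_gradient():
    import odl
    space = odl.rn(5)
    b = space.element([1.0, -2.0, 3.0, -4.0, 5.0])
    f = odl.solvers.L1Norm(space)
    g = odl.solvers.L2NormSquared(space).translated(b)
    x = space.zero()
    accelerated_proximal_gradient(x, f, g, gamma=0.4, niter=200)
    return x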


if __name__ == '__main__':
    from odl.util.testutils import run_doctests
    run_doctests()