BacktrackingLineSearch
- class odl.solvers.util.steplen.BacktrackingLineSearch(function, tau=0.5, discount=0.01, alpha=1.0, max_num_iter=None, estimate_step=False)
Bases: LineSearch
Backtracking line search for step length calculation.
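This method approximately finds the longest step length that fulfills the Armijo-Goldstein (sufficient decrease) condition: it tries the step lengths alpha, alpha * tau, alpha * tau**2, ... and returns the first one that satisfies the condition.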
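For reference, the backtracking rule of [BV2004] can be written as follows, with this class's parameter names substituted (c = discount, τ = tau, initial t = alpha; this mapping is our reading of the parameter descriptions below). It assumes d is a descent direction, i.e. the directional derivative ∇f(x)ᵀd is negative.

```latex
t \leftarrow \alpha, \qquad
\text{while } f(x + t\,d) > f(x) + c\, t\, \nabla f(x)^{\mathsf{T}} d :
\quad t \leftarrow \tau\, t
```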
The line search algorithm is described in [BV2004], page 464 (book available online) and [GNS2009], pages 378-379. See also Backtracking_line_search.
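A minimal sketch of the backtracking loop in plain NumPy, independent of ODL. The helper backtracking_line_search is hypothetical, not ODL's implementation: it assumes direction is a descent direction (negative dir_derivative), reuses this class's defaults for tau, discount and alpha, and uses an arbitrary cap of 100 iterations.

```python
import numpy as np


def backtracking_line_search(f, x, direction, dir_derivative,
                             tau=0.5, discount=0.01, alpha=1.0,
                             max_num_iter=100):
    """Hypothetical helper: return a step satisfying the Armijo test."""
    fx = f(x)
    step = alpha
    for _ in range(max_num_iter):
        # Accept if f lies below the discounted linear prediction.
        if f(x + step * direction) <= fx + discount * step * dir_derivative:
            return step
        # Otherwise shrink the step geometrically by the factor tau.
        step *= tau
    return step


def f(x):
    # Squared L2 norm, matching the L2NormSquared example below.
    return float(np.dot(x, x))


x = np.array([1.0, 2.0, 3.0])
d = np.array([-1.0, -1.0, -1.0])
grad = 2 * x  # gradient of the squared L2 norm at x
step = backtracking_line_search(f, x, d, float(np.dot(grad, d)))
print(step)  # 1.0
```

With the steepest-descent direction -2 * x instead, the full step overshoots to the mirrored point (f stays at 14), so the loop halves the step once and returns 0.5.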
References
[BV2004] Boyd, S, and Vandenberghe, L. Convex Optimization. Cambridge University Press, 2004.
[GNS2009] Griva, I, Nash, S G, and Sofer, A. Linear and Nonlinear Optimization. SIAM, 2009.
Methods
__call__(x, direction[, dir_derivative])
Calculate a step length along a line that satisfies the sufficient decrease condition.
- __init__(function, tau=0.5, discount=0.01, alpha=1.0, max_num_iter=None, estimate_step=False)
Initialize a new instance.
- Parameters:
- function : callable
The cost function of the optimization problem to be solved. If function is not a Functional, calling this class later requires a value for the dir_derivative argument.
- tau : float, optional
The factor by which the step length is decreased in each iteration, as long as it does not fulfill the decrease condition. The step length is updated as step_length *= tau.
- discount : float, optional
The "discount factor" on step length * directional derivative, yielding the threshold under which the function value must lie to be accepted (see the references).
- alpha : float, optional
The initial guess for the step length.
- max_num_iter : int, optional
Maximum number of iterations allowed each time the line search method is called. If None, this number is calculated to allow a shortest step length of 10 times machine epsilon.
- estimate_step : bool, optional
Whether the last step length should be used as an estimate for the next step.
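One way to read the default for max_num_iter: after n iterations the step length is alpha * tau**n, so the smallest n that brings it down to 10 times machine epsilon is ceil(log(10 * eps / alpha) / log(tau)). The helper default_max_num_iter below is hypothetical and uses NumPy's float64 machine epsilon; ODL's exact formula (which epsilon, which rounding) may differ.

```python
import math

import numpy as np


def default_max_num_iter(tau=0.5, alpha=1.0, dtype=float):
    """Hypothetical helper: smallest n with alpha * tau**n <= 10 * eps."""
    min_step = 10 * np.finfo(dtype).eps  # shortest step length allowed
    # Solve alpha * tau**n = min_step for n and round up.
    return int(math.ceil(math.log(min_step / alpha) / math.log(tau)))


n = default_max_num_iter()
print(n)  # 49 for float64 with tau=0.5, alpha=1.0
```

A smaller tau shrinks the step faster, so fewer iterations are needed: with tau=0.1 the same helper gives 15.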
Examples
Create a line search:
>>> r3 = odl.rn(3)
>>> func = odl.solvers.L2NormSquared(r3)
>>> line_search = BacktrackingLineSearch(func)
Find a step length at the point x along the direction d that decreases the function value:
>>> x = r3.element([1, 2, 3])
>>> d = r3.element([-1, -1, -1])
>>> step_len = line_search(x, d)
>>> step_len
1.0
>>> func(x + step_len * d) < func(x)
True
The line search also works with a function that is not a Functional, but then the dir_derivative argument is mandatory:
>>> r3 = odl.rn(3)
>>> func = lambda x: x[0] ** 2 + x[1] ** 2 + x[2] ** 2
>>> line_search = BacktrackingLineSearch(func)
>>> x = r3.element([1, 2, 3])
>>> d = r3.element([-1, -1, -1])
>>> dir_derivative = -12
>>> step_len = line_search(x, d, dir_derivative=dir_derivative)
>>> step_len
1.0
>>> func(x + step_len * d) < func(x)
True
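The value dir_derivative = -12 in the last example is the directional derivative of func at x along d, i.e. the inner product of the gradient 2 * x with d. The short NumPy check below (independent of ODL; the names are illustrative) computes it analytically and cross-checks it with a central finite difference.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
d = np.array([-1.0, -1.0, -1.0])


def f(v):
    # Same function as the lambda in the example: sum of squares.
    return float(np.sum(v ** 2))


# Analytic: the gradient of f is 2 * x, so the derivative along d is <2x, d>.
analytic = float(np.dot(2 * x, d))

# Numerical cross-check: central difference of t -> f(x + t * d) at t = 0.
h = 1e-6
numeric = (f(x + h * d) - f(x - h * d)) / (2 * h)

print(analytic)  # -12.0
```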