22from __future__
import annotations
32 "DEFAULT_ADAPROX_FACTOR",
35from copy
import deepcopy
36from typing
import Any, Callable, Sequence, cast
39import numpy.typing
as npt
44DEFAULT_ADAPROX_FACTOR = 1e-2
48 """Wrapper to make a numerical step into a step function
53 The step to take for a given array.
58 The step function that takes an array and returns the
65 """A parameter in a `Component`
70 The array of values that is being fit.
72 A dictionary of helper arrays that are used by an optimizer to
73 persist values like the gradient of `x`, the Hessian of `x`, etc.
75 A numerical step value or function to calculate the step for a
78 A function to calculate the gradient of `x`.
80 A function to take the proximal operator of `x`.
86 helpers: dict[str, np.ndarray],
87 step: Callable | float,
88 grad: Callable |
None =
None,
89 prox: Callable |
None =
None,
94 if isinstance(step, float):
105 """Calculate the step
110 The numerical step if no iteration is given.
116 """The shape of the array that is being fit."""
121 """The numpy dtype of the array that is being fit."""
125 """Create a shallow copy of this parameter.
130 A shallow copy of this parameter.
132 helpers = {k: v.copy()
for k, v
in self.
helpers.items()}
135 def __deepcopy__(self, memo: dict[int, Any] |
None =
None) -> Parameter:
136 """Create a deep copy of this parameter.
141 A memoization dictionary used by `copy.deepcopy`.
145 A deep copy of this parameter.
147 helpers = {k: deepcopy(v, memo)
for k, v
in self.
helpers.items()}
148 return Parameter(deepcopy(self.
x, memo), helpers, 0)
150 def copy(self, deep: bool =
False) -> Parameter:
151 """Copy this parameter, including all of the helper arrays.
156 If `True`, a deep copy is made.
157 If `False`, a shallow copy is made.
162 A copy of this parameter.
168 def update(self, it: int, input_grad: np.ndarray, *args):
169 """Update the parameter in one iteration.
171 This includes the gradient update, proximal update,
172 and any meta parameters that are stored as class
173 attributes to update the parameter.
178 The current iteration
180 The gradient from the full model, passed to the parameter.
182 raise NotImplementedError(
"Base Parameters cannot be updated")
184 def resize(self, old_box: Box, new_box: Box):
185 """Grow the parameter and all of the helper parameters
190 The old bounding box for the parameter.
192 The new bounding box for the parameter.
194 slices = new_box.overlapped_slices(old_box)
195 x = np.zeros(new_box.shape, dtype=self.
dtype)
196 x[slices[0]] = self.
x[slices[1]]
199 for name, value
in self.
helpers.items():
200 result = np.zeros(new_box.shape, dtype=self.
dtype)
201 result[slices[0]] = value[slices[1]]
206 """Convert a `np.ndarray` into a `Parameter`.
211 The array or parameter to convert into a `Parameter`.
216 `x`, converted into a `Parameter` if necessary.
218 if isinstance(x, Parameter):
224 """A `Parameter` that updates itself using the Beck-Teboulle 2009
225 FISTA proximal gradient method.
227 See https://www.ceremade.dauphine.fr/~carlier/FISTA
234 grad: Callable |
None =
None,
235 prox: Callable |
None =
None,
237 z0: np.ndarray |
None =
None,
251 def update(self, it: int, input_grad: np.ndarray, *args):
252 """Update the parameter and meta-parameters using the PGM
254 See `Parameter` for the full description.
259 step = self.
step / np.sum(args[0] * args[0])
263 y = _z - step * cast(Callable, self.
grad)(input_grad, _x, *args)
264 if self.
prox is not None:
268 t = 0.5 * (1 + np.sqrt(1 + 4 * self.
t**2))
269 omega = 1 + (self.
t - 1) / t
270 self.
helpers[
"z"] = _x + omega * (x - _x)
274 def __deepcopy__(self, memo: dict[int, Any] |
None =
None) -> FistaParameter:
275 """Create a deep copy of this parameter.
280 A memoization dictionary used by `copy.deepcopy`.
284 A deep copy of this parameter.
287 deepcopy(self.
x, memo),
292 deepcopy(self.
helpers[
"z"], memo),
296 """Create a shallow copy of this parameter.
301 A shallow copy of this parameter.
323 m[:] = (1 - b1[it]) * g + b1[it] * m
324 v[:] = (1 - b2) * (g**2) + b2 * v
328 phi = m / (1 - b1[it] ** t)
329 psi = np.sqrt(v / (1 - b2**t)) + eps
336 m[:] = (1 - b1[it]) * g + b1[it] * m
337 v[:] = (1 - b2) * (g**2) + b2 * v
341 phi = (b1[it] * m[:] + (1 - b1[it]) * g) / (1 - b1[it] ** t)
342 psi = np.sqrt(v / (1 - b2**t)) + eps
349 m[:] = (1 - b1[it]) * g + b1[it] * m
350 v[:] = (1 - b2) * (g**2) + b2 * v
353 vhat[:] = np.maximum(vhat, v)
356 vhat = np.maximum(vhat, eps)
363 m[:] = (1 - b1[it]) * g + b1[it] * m
364 v[:] = (1 - b2) * (g**2) + b2 * v
367 vhat[:] = np.maximum(vhat, v)
370 vhat = np.maximum(vhat, eps)
378 m[:] = (1 - b1[it]) * g + b1[it] * m
379 v[:] = (1 - b2) * (g**2) + b2 * v
382 factor = (1 - b1[it]) ** 2 / (1 - b1[it - 1]) ** 2
383 vhat[:] = np.maximum(factor * vhat, v)
386 vhat = np.maximum(vhat, eps)
393 rho_inf = 2 / (1 - b2) - 1
396 m[:] = (1 - b1[it]) * g + b1[it] * m
397 v[:] = (1 - b2) * (g**2) + b2 * v
401 phi = m / (1 - b1[it] ** t)
402 rho = rho_inf - 2 * t * b2**t / (1 - b2**t)
405 psi = np.sqrt(v / (1 - b2**t))
406 r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
409 psi = np.ones(g.shape, g.dtype)
412 psi = np.maximum(psi, np.sqrt(eps))
418 "adam": _adam_phi_psi,
419 "nadam": _nadam_phi_psi,
420 "amsgrad": _amsgrad_phi_psi,
421 "padam": _padam_phi_psi,
422 "adamx": _adamx_phi_psi,
423 "radam": _radam_phi_psi,
428 """Mock an array with only a single item"""
438 """Operator updated using te Proximal ADAM algorithm
440 Uses multiple variants of adaptive quasi-Newton gradient descent
441 * Adam (Kingma & Ba 2015)
443 * AMSGrad (Reddi, Kale & Kumar 2018)
444 * PAdam (Chen & Gu 2018)
445 * AdamX (Phuong & Phong 2019)
446 * RAdam (Liu et al. 2019)
447 See details of the algorithms in the respective papers.
453 step: Callable | float,
454 grad: Callable |
None =
None,
455 prox: Callable |
None =
None,
456 b1: float | SingleItemArray = 0.9,
460 m0: np.ndarray |
None =
None,
461 v0: np.ndarray |
None =
None,
462 vhat0: np.ndarray |
None =
None,
463 scheme: str =
"amsgrad",
464 prox_e_rel: float = 1e-6,
469 m0 = np.zeros(shape, dtype=dtype)
472 v0 = np.zeros(shape, dtype=dtype)
475 vhat0 = np.ones(shape, dtype=dtype) * -np.inf
489 if isinstance(b1, float):
503 def update(self, it: int, input_grad: np.ndarray, *args):
504 """Update the parameter and meta-parameters using the PGM
506 See `~Parameter` for more.
510 grad = cast(Callable, self.
grad)(input_grad, _x, *args)
526 _x += -step * phi / psi
531 _x += -step * phi / psi / 10
533 self.
x = cast(Callable, self.
prox)(_x)
535 def __deepcopy__(self, memo: dict[int, Any] |
None =
None) -> AdaproxParameter:
536 """Create a deep copy of this parameter.
541 A memoization dictionary used by `copy.deepcopy`.
545 A deep copy of this parameter.
548 deepcopy(self.
x, memo),
556 deepcopy(self.
helpers[
"m"], memo),
557 deepcopy(self.
helpers[
"v"], memo),
558 deepcopy(self.
helpers[
"vhat"], memo),
560 prox_e_rel=self.
e_rel,
564 """Create a shallow copy of this parameter.
569 A shallow copy of this parameter.
584 prox_e_rel=self.
e_rel,
589 """A parameter that is not updated"""
594 def update(self, it: int, input_grad: np.ndarray, *args):
598 """Create a shallow copy of this parameter.
603 A shallow copy of this parameter.
607 def __deepcopy__(self, memo: dict[int, Any] |
None =
None) -> FixedParameter:
608 """Create a deep copy of this parameter.
613 A memoization dictionary used by `copy.deepcopy`.
618 A deep copy of this parameter.
627 axis: int | Sequence[int] |
None =
None,
629 """Step size set at `factor` times the mean of `X` in direction `axis`"""
630 return np.maximum(minimum, factor * x.mean(axis=axis))
update(self, int it, np.ndarray input_grad, *args)
__init__(self, np.ndarray x, Callable|float step, Callable|None grad=None, Callable|None prox=None, float|SingleItemArray b1=0.9, float b2=0.999, float eps=1e-8, float p=0.25, np.ndarray|None m0=None, np.ndarray|None v0=None, np.ndarray|None vhat0=None, str scheme="amsgrad", float prox_e_rel=1e-6)
AdaproxParameter __deepcopy__(self, dict[int, Any]|None memo=None)
AdaproxParameter __copy__(self)
FistaParameter __deepcopy__(self, dict[int, Any]|None memo=None)
__init__(self, np.ndarray x, float step, Callable|None grad=None, Callable|None prox=None, float t0=1, np.ndarray|None z0=None)
update(self, int it, np.ndarray input_grad, *args)
FistaParameter __copy__(self)
FixedParameter __deepcopy__(self, dict[int, Any]|None memo=None)
__init__(self, np.ndarray x)
update(self, int it, np.ndarray input_grad, *args)
FixedParameter __copy__(self)
resize(self, Box old_box, Box new_box)
__init__(self, np.ndarray x, dict[str, np.ndarray] helpers, Callable|float step, Callable|None grad=None, Callable|None prox=None)
Parameter __deepcopy__(self, dict[int, Any]|None memo=None)
npt.DTypeLike dtype(self)
Parameter copy(self, bool deep=False)
tuple[int,...] shape(self)
update(self, int it, np.ndarray input_grad, *args)
_padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
_radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
Callable step_function_wrapper(float step)
_amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
_nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
_adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)
Parameter parameter(np.ndarray|Parameter x)
relative_step(np.ndarray x, float factor=0.1, float minimum=0, int|Sequence[int]|None axis=None)
_adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p)