import math
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union, cast

from .backends import get_array_ops
from .config import registry
from .types import FloatsXd, Generator

KeyT = Tuple[int, str]
FloatOrSeq = Union[float, List[float], Generator]
IntOrSeq = Union[int, List[int], Generator]

SGD_DEFAULTS: Dict[str, Union[float, bool, int]] = {
    "L2": 0.0,
    "L2_is_weight_decay": True,
    "grad_clip": 1.0,
}

ADAM_DEFAULTS: Dict[str, Union[float, bool, int]] = {
    "learn_rate": 0.001,
    "beta1": 0.9,
    "beta2": 0.999,
    "eps": 1e-08,
    "L2": SGD_DEFAULTS["L2"],
    "grad_clip": SGD_DEFAULTS["grad_clip"],
    "L2_is_weight_decay": True,
}


@registry.optimizers("RAdam.v1")
def RAdam(
    learn_rate: FloatOrSeq = ADAM_DEFAULTS["learn_rate"],
    *,
    beta1: FloatOrSeq = ADAM_DEFAULTS["beta1"],
    beta2: FloatOrSeq = ADAM_DEFAULTS["beta2"],
    eps: FloatOrSeq = ADAM_DEFAULTS["eps"],
    L2: FloatOrSeq = ADAM_DEFAULTS["L2"],
    L2_is_weight_decay: bool = cast(bool, ADAM_DEFAULTS["L2_is_weight_decay"]),
    grad_clip: FloatOrSeq = ADAM_DEFAULTS["grad_clip"],
    use_averages: bool = True,
):
    return Optimizer(
        learn_rate,
        beta1=beta1,
        beta2=beta2,
        eps=eps,
        grad_clip=grad_clip,
        L2_is_weight_decay=L2_is_weight_decay,
        L2=L2,
        use_averages=use_averages,
        use_radam=True,
    )
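

# Registry sketch: the decorator above registers the factory under the string
# name "RAdam.v1", so a config can select it without importing this module.
# The fragment below is illustrative only (section name and values assumed):
#
#     [optimizer]
#     @optimizers = "RAdam.v1"
#     learn_rate = 0.001
#     beta1 = 0.9
#     beta2 = 0.999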


@registry.optimizers("Adam.v1")
def Adam(
    learn_rate: FloatOrSeq = ADAM_DEFAULTS["learn_rate"],
    *,
    L2: FloatOrSeq = ADAM_DEFAULTS["L2"],
    beta1: FloatOrSeq = ADAM_DEFAULTS["beta1"],
    beta2: FloatOrSeq = ADAM_DEFAULTS["beta2"],
    eps: FloatOrSeq = ADAM_DEFAULTS["eps"],
    grad_clip: FloatOrSeq = ADAM_DEFAULTS["grad_clip"],
    L2_is_weight_decay: bool = cast(bool, ADAM_DEFAULTS["L2_is_weight_decay"]),
    use_averages: bool = True,
):
    return Optimizer(
        learn_rate,
        L2=L2,
        beta1=beta1,
        beta2=beta2,
        eps=eps,
        grad_clip=grad_clip,
        L2_is_weight_decay=L2_is_weight_decay,
        use_averages=use_averages,
        use_radam=False,
    )


@registry.optimizers("SGD.v1")
def SGD(
    learn_rate: FloatOrSeq,
    *,
    L2: FloatOrSeq = SGD_DEFAULTS["L2"],
    grad_clip: FloatOrSeq = SGD_DEFAULTS["grad_clip"],
    L2_is_weight_decay: bool = cast(bool, SGD_DEFAULTS["L2_is_weight_decay"]),
    use_averages: bool = True,
):
    return Optimizer(
        learn_rate,
        L2=L2,
        grad_clip=grad_clip,
        L2_is_weight_decay=L2_is_weight_decay,
        beta1=0.0,
        beta2=0.0,
        use_averages=use_averages,
    )
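

# Usage sketch: each factory returns a configured Optimizer, so it can also be
# called directly from Python. Values are illustrative, and the import path
# assumes the usual thinc.api re-exports:
#
#     from thinc.api import Adam
#
#     optimizer = Adam(0.001, L2=1e-6, grad_clip=1.0)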


class Optimizer(object):
    """Do various flavours of stochastic gradient descent, with first and
    second order momentum. Currently supports 'vanilla' SGD, Adam, and RAdam.
    """
    mom1: Dict[KeyT, FloatsXd]
    mom2: Dict[KeyT, FloatsXd]
    averages: Optional[Dict[KeyT, FloatsXd]]
    schedules: Dict[str, Generator]
    nr_update: Dict[KeyT, int]
    last_seen: Dict[KeyT, int]
    grad_clip: float
    learn_rate: float
    b1: float
    b2: float
    eps: float
    L2: float
    use_radam: bool
    L2_is_weight_decay: bool
    _radam_buffer: List[List[Optional[FloatsXd]]]

    # This "locks" the class, so we get an error if you try to assign to
    # an unexpected variable.
    __slots__ = [
        "mom1",
        "mom2",
        "averages",
        "schedules",
        "nr_update",
        "last_seen",
        "grad_clip",
        "learn_rate",
        "b1",
        "b2",
        "eps",
        "L2",
        "use_radam",
        "L2_is_weight_decay",
        "_radam_buffer",
    ]

    def __init__(
        self,
        learn_rate: FloatOrSeq,
        *,
        L2: FloatOrSeq = ADAM_DEFAULTS["L2"],
        beta1: FloatOrSeq = ADAM_DEFAULTS["beta1"],
        beta2: FloatOrSeq = ADAM_DEFAULTS["beta2"],
        eps: FloatOrSeq = ADAM_DEFAULTS["eps"],
        grad_clip: FloatOrSeq = ADAM_DEFAULTS["grad_clip"],
        use_averages: bool = True,
        use_radam: bool = False,
        L2_is_weight_decay: bool = True,
    ):
        """
        Initialize an optimizer.

        learn_rate (float): The initial learning rate.
        L2 (float): The L2 regularization term.
        beta1 (float): First-order momentum.
        beta2 (float): Second-order momentum.
        eps (float): Epsilon term for Adam etc.
        grad_clip (float): Gradient clipping.
        use_averages (bool): Whether to track moving averages of the parameters.
        use_radam (bool): Whether to use the RAdam optimizer.
        L2_is_weight_decay (bool): Whether to interpret the L2 parameter as a
            weight decay term, in the style of the AdamW optimizer.
        """
        self.mom1 = {}
        self.mom2 = {}
        if use_averages:
            self.averages = {}
        else:
            self.averages = None
        self.schedules = {}
        self.nr_update = defaultdict(int)
        self.last_seen = defaultdict(int)
        self._set_attr_or_schedule("grad_clip", grad_clip)
        self._set_attr_or_schedule("learn_rate", learn_rate)
        self._set_attr_or_schedule("b1", beta1)
        self._set_attr_or_schedule("b2", beta2)
        self._set_attr_or_schedule("eps", eps)
        self._set_attr_or_schedule("L2", L2)
        self.use_radam = use_radam
        self.L2_is_weight_decay = L2_is_weight_decay
        self._radam_buffer = [[None, None, None] for _ in range(10)]

    def _set_attr_or_schedule(self, name, value):
        if isinstance(value, (float, bool, int)):
            setattr(self, name, value)
        else:
            if isinstance(value, list):
                value = iter(value)
            self.schedules[name] = value
            try:
                setattr(self, name, next(value))
            except (StopIteration, TypeError) as e:
                err = f"Invalid schedule for '{name}' ({type(value)})\n{e}"
                raise ValueError(err)

    def step_schedules(self):
        for key, schedule in self.schedules.items():
            try:
                value = next(schedule)
            except StopIteration:  # schedule exhausted, use last value
                value = getattr(self, key)
            setattr(self, key, value)
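
    # Schedule sketch (illustrative values): any hyperparameter accepted by
    # __init__ may be given as a list or generator instead of a float. The
    # first value is consumed at construction; step_schedules() advances to
    # the next value and keeps the last one once the schedule is exhausted.
    #
    #     optimizer = Optimizer([0.001, 0.0005, 0.00025])
    #     for epoch in range(3):
    #         ...  # run the updates for this epoch
    #         optimizer.step_schedules()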

    def __call__(
        self,
        key: Tuple[int, str],
        weights: FloatsXd,
        gradient: FloatsXd,
        *,
        lr_scale: float = 1.0,
    ):
        """Call the optimizer with weights and a gradient. The key is the
        identifier for the parameter, usually the node ID and parameter name.
        """
        if len(gradient) < 1:
            return weights, gradient
        ops = get_array_ops(weights)
        self.nr_update[key] += 1
        nr_upd = self.nr_update[key]
        if self.L2 != 0 and not self.L2_is_weight_decay:
            gradient += self.L2 * weights
        if self.grad_clip:
            gradient = ops.clip_gradient(gradient, self.grad_clip)
        if self.use_radam:
            weights, gradient = self._radam(
                ops, weights, gradient, lr_scale, key, nr_upd
            )
        elif self.b1 > 0.0 and self.b2 > 0.0:
            weights, gradient = self._adam(
                ops, weights, gradient, lr_scale, key, nr_upd
            )
        elif self.b2 > 0.0:  # pragma: no cover
            raise NotImplementedError  # TODO: error message
        else:
            weights -= lr_scale * self.learn_rate * gradient
        gradient *= 0
        if self.L2 != 0 and self.L2_is_weight_decay:
            weights -= lr_scale * self.learn_rate * self.L2 * weights
        if self.averages is not None:
            if key not in self.averages:
                self.averages[key] = ops.alloc(weights.shape, dtype="float32")
            ops.update_averages(self.averages[key], weights, nr_upd)
        return weights, gradient
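
    # Call sketch (illustrative key, shapes and arrays): the optimizer is
    # called once per parameter per update and returns the new weights along
    # with a zeroed gradient. Plain numpy arrays are assumed for simplicity.
    #
    #     import numpy
    #
    #     W = numpy.zeros((5, 3), dtype="float32")
    #     dW = numpy.ones((5, 3), dtype="float32")
    #     W, dW = optimizer((model_id, "W"), W, dW)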

    def _radam(self, ops, weights, grad, lr_scale, key, nr_upd):
        if key not in self.mom1:
            self.mom1[key] = ops.alloc1f(weights.size)
        if key not in self.mom2:
            self.mom2[key] = ops.alloc1f(weights.size)
        weights_1D = ops.reshape1f(weights, weights.size)
        gradient_1D = ops.reshape1f(grad, grad.size)
        # While porting from the PyTorch implementation, keep some of the
        # same naming.
        state = {
            "step": self.nr_update[key],
            "exp_avg": self.mom1[key],
            "exp_avg_sq": self.mom2[key],
        }
        group = {
            "lr": self.learn_rate,
            "betas": [self.b1, self.b2],
            "eps": self.eps,
            "weight_decay": 0.0,
            "buffer": self._radam_buffer,
        }
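        # Rectification term from the RAdam paper (Liu et al.): N_sma
        # approximates the length of the SMA of the adaptive learning rate.
        # While it is small (< 5), the variance of the adaptive term is not
        # trusted and the update degenerates to bias-corrected momentum SGD.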
        degenerated_to_sgd = True
        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        beta1, beta2 = group["betas"]
        # exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
        exp_avg_sq *= beta2
        exp_avg_sq += (1 - beta2) * (gradient_1D**2)
        # exp_avg.mul_(beta1).add_(1 - beta1, grad)
        exp_avg *= beta1
        exp_avg += (1 - beta1) * gradient_1D
        state["step"] += 1
        buffered = group["buffer"][int(state["step"] % 10)]
        if state["step"] == buffered[0]:
            N_sma, step_size = buffered[1], buffered[2]
        else:
            buffered[0] = state["step"]
            beta2_t = beta2 ** state["step"]
            N_sma_max = 2 / (1 - beta2) - 1
            N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
            buffered[1] = N_sma
            # more conservative since it's an approximated value
            if N_sma >= 5:
                step_size = math.sqrt(
                    (1 - beta2_t)
                    * (N_sma - 4)
                    / (N_sma_max - 4)
                    * (N_sma - 2)
                    / N_sma
                    * N_sma_max
                    / (N_sma_max - 2)
                ) / (1 - beta1 ** state["step"])
            elif degenerated_to_sgd:
                step_size = 1.0 / (1 - beta1 ** state["step"])
            else:
                step_size = -1
            buffered[2] = step_size
        # more conservative since it's an approximated value
        if N_sma >= 5:
            if group["weight_decay"] != 0:
                weights_1D += -group["weight_decay"] * group["lr"] * weights_1D
            denom = ops.xp.sqrt(exp_avg_sq) + group["eps"]
            weights_1D += -step_size * group["lr"] * (exp_avg / denom)
        elif step_size > 0:
            if group["weight_decay"] != 0:
                weights_1D += -group["weight_decay"] * group["lr"] * weights_1D
            weights_1D += -step_size * group["lr"] * exp_avg
        return (
            ops.reshape_f(weights_1D, weights.shape),
            ops.reshape_f(gradient_1D, grad.shape),
        )

    def _adam(self, ops, weights, gradient, lr_scale, key, nr_upd):
        weights_1D = ops.reshape1f(weights, weights.size)
        gradient_1D = ops.reshape1f(gradient, gradient.size)
        if key not in self.mom1:
            self.mom1[key] = ops.alloc1f(weights.size)
        if key not in self.mom2:
            self.mom2[key] = ops.alloc1f(weights.size)
        mom1 = self.mom1[key]
        mom2 = self.mom2[key]
        b1 = self.b1
        b2 = self.b2
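        # Fold the standard Adam bias corrections into the step size:
        #     lr_t = learn_rate * sqrt(1 - b2**t) / (1 - b1**t)
        # so the per-step correction does not need to be applied to the
        # moment estimates themselves.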
        fix1 = 1.0 - (b1**nr_upd)
        fix2 = 1.0 - (b2**nr_upd)
        lr = self.learn_rate * fix2**0.5 / fix1
        eps = self.eps
        # needs to be 1D going into the adam function
        weights_1D, gradient_1D, mom1, mom2 = ops.adam(
            weights_1D, gradient_1D, mom1, mom2, b1, b2, eps, lr * lr_scale
        )
        self.mom1[key] = mom1
        self.mom2[key] = mom2
        return (
            ops.reshape_f(weights_1D, weights.shape),
            ops.reshape_f(gradient_1D, gradient.shape),
        )

__all__ = ["Adam", "RAdam", "SGD", "Optimizer", "ADAM_DEFAULTS", "SGD_DEFAULTS"]