You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
3.7 KiB

from typing import Callable, cast
import numpy
from .backends import Ops
from .config import registry
from .types import FloatsXd, Shape
from .util import partial
# TODO: Harmonize naming with Keras and fill in missing entries:
# https://keras.io/initializers/ (He/LeCun normal and uniform are covered
# below; other Keras initializers such as orthogonal are still missing.)
# Initialize via numpy, before copying to ops. This makes it easier to work with
# the different backends, because the backend won't affect the randomization.
def lecun_normal_init(ops: Ops, shape: Shape) -> FloatsXd:
    """LeCun normal initialization: draw from a normal distribution with
    mean 0 and standard deviation sqrt(1 / fan_in), where fan_in = shape[1].
    The draw happens in numpy, then the array is copied to the backend.
    """
    fan_in = shape[1]
    std = numpy.sqrt(1.0 / fan_in)
    drawn = numpy.random.normal(0, std, shape)
    return ops.asarray_f(cast(FloatsXd, drawn))
@registry.initializers("lecun_normal_init.v1")
def configure_lecun_normal_init() -> Callable[[Shape], FloatsXd]:
    """Registry factory: return a callable wrapping ``lecun_normal_init``."""
    init = partial(lecun_normal_init)
    return init
def he_normal_init(ops: Ops, shape: Shape) -> FloatsXd:
    """He normal initialization: draw from a normal distribution with mean 0
    and standard deviation sqrt(2 / fan_in), where fan_in = shape[1].
    The draw happens in numpy, then the array is copied to the backend.
    """
    fan_in = shape[1]
    std = numpy.sqrt(2.0 / fan_in)
    drawn = numpy.random.normal(0, std, shape)
    return ops.asarray_f(cast(FloatsXd, drawn))
@registry.initializers("he_normal_init.v1")
def configure_he_normal_init() -> Callable[[Shape], FloatsXd]:
    """Registry factory: return a callable wrapping ``he_normal_init``."""
    init = partial(he_normal_init)
    return init
def glorot_normal_init(ops: Ops, shape: Shape) -> FloatsXd:
    """Glorot (Xavier) normal initialization: draw from a normal distribution
    with mean 0 and standard deviation sqrt(2 / (fan_in + fan_out)), where
    fan_out = shape[0] and fan_in = shape[1].
    """
    fan_out, fan_in = shape[0], shape[1]
    std = numpy.sqrt(2.0 / (fan_in + fan_out))
    drawn = numpy.random.normal(0, std, shape)
    return ops.asarray_f(cast(FloatsXd, drawn))
@registry.initializers("glorot_normal_init.v1")
def configure_glorot_normal_init() -> Callable[[Shape], FloatsXd]:
    """Registry factory: return a callable wrapping ``glorot_normal_init``."""
    init = partial(glorot_normal_init)
    return init
def he_uniform_init(ops: Ops, shape: Shape) -> FloatsXd:
    """He uniform initialization: draw uniformly from [-b, b) with
    b = sqrt(6 / fan_in), where fan_in = shape[1].
    """
    bound = numpy.sqrt(6.0 / shape[1])
    drawn = numpy.random.uniform(-bound, bound, shape)
    return ops.asarray_f(cast(FloatsXd, drawn))
@registry.initializers("he_uniform_init.v1")
def configure_he_uniform_init() -> Callable[[Shape], FloatsXd]:
    """Registry factory: return a callable wrapping ``he_uniform_init``."""
    init = partial(he_uniform_init)
    return init
def lecun_uniform_init(ops: Ops, shape: Shape) -> FloatsXd:
    """LeCun uniform initialization: draw uniformly from [-b, b) with
    b = sqrt(3 / fan_in), where fan_in = shape[1].
    """
    bound = numpy.sqrt(3.0 / shape[1])
    drawn = numpy.random.uniform(-bound, bound, shape)
    return ops.asarray_f(cast(FloatsXd, drawn))
@registry.initializers("lecun_uniform_init.v1")
def configure_lecun_uniform_init() -> Callable[[Shape], FloatsXd]:
    """Registry factory: return a callable wrapping ``lecun_uniform_init``."""
    init = partial(lecun_uniform_init)
    return init
def glorot_uniform_init(ops: Ops, shape: Shape) -> FloatsXd:
    """Glorot (Xavier) uniform initialization: draw uniformly from [-b, b)
    with b = sqrt(6 / (fan_in + fan_out)), where fan_out = shape[0] and
    fan_in = shape[1].
    """
    fan_sum = shape[0] + shape[1]
    bound = numpy.sqrt(6.0 / fan_sum)
    drawn = numpy.random.uniform(-bound, bound, shape)
    return ops.asarray_f(cast(FloatsXd, drawn))
@registry.initializers("glorot_uniform_init.v1")
def configure_glorot_uniform_init() -> Callable[[Shape], FloatsXd]:
    """Registry factory: return a callable wrapping ``glorot_uniform_init``."""
    init = partial(glorot_uniform_init)
    return init
def zero_init(ops: Ops, shape: Shape) -> FloatsXd:
    """Return an all-zeros float array of the given shape, allocated
    directly on the backend (no numpy draw is needed here).
    """
    return ops.alloc_f(shape)
@registry.initializers("zero_init.v1")
def configure_zero_init() -> Callable[[Shape], FloatsXd]:
    """Registry factory: return a callable wrapping ``zero_init``.

    Fixed: the return annotation previously read
    ``Callable[[FloatsXd], FloatsXd]``, which disagreed with every other
    ``configure_*`` factory in this module; it now uses the module-wide
    convention ``Callable[[Shape], FloatsXd]``. (Note that ``partial(zero_init)``
    binds nothing, so the returned callable still expects ``(ops, shape)``.)
    """
    return partial(zero_init)
def uniform_init(
    ops: Ops, shape: Shape, *, lo: float = -0.1, hi: float = 0.1
) -> FloatsXd:
    """Draw uniformly from [lo, hi), cast to float32, and copy the result
    to the backend.
    """
    drawn = numpy.random.uniform(lo, hi, shape)
    as_f32 = drawn.astype("float32")
    return ops.asarray_f(cast(FloatsXd, as_f32))
@registry.initializers("uniform_init.v1")
def configure_uniform_init(
    *, lo: float = -0.1, hi: float = 0.1
) -> Callable[[Shape], FloatsXd]:
    """Registry factory: return ``uniform_init`` with ``lo``/``hi`` bound.

    Fixed: the return annotation previously read
    ``Callable[[FloatsXd], FloatsXd]``, inconsistent with the other
    ``configure_*`` factories; it now follows the module-wide convention
    ``Callable[[Shape], FloatsXd]``.
    """
    return partial(uniform_init, lo=lo, hi=hi)
def normal_init(ops: Ops, shape: Shape, *, mean: float = 0) -> FloatsXd:
    """Initialize from a normal distribution, drawing a flat float32 vector
    in numpy and reshaping it on the backend.

    NOTE(review): ``mean`` is passed to ``numpy.random.normal`` as ``scale``
    (the standard deviation), not as ``loc`` — so with the default
    ``mean=0`` every value comes out 0.0. This looks like a naming bug, but
    the keyword is part of the public/registry interface, so it is flagged
    here rather than changed; confirm intent before renaming.
    """
    # Total element count, computed on the backend's array module.
    size = int(ops.xp.prod(ops.xp.asarray(shape)))
    inits = cast(FloatsXd, numpy.random.normal(scale=mean, size=size).astype("float32"))
    inits = ops.reshape_f(inits, shape)
    return ops.asarray_f(inits)
@registry.initializers("normal_init.v1")
def configure_normal_init(*, mean: float = 0) -> Callable[[Shape], FloatsXd]:
    """Registry factory: return ``normal_init`` with ``mean`` bound.

    Fixed: the return annotation previously read
    ``Callable[[FloatsXd], FloatsXd]``, inconsistent with the other
    ``configure_*`` factories; it now follows the module-wide convention
    ``Callable[[Shape], FloatsXd]``.
    """
    return partial(normal_init, mean=mean)
# Explicit public API of this module: the initializer functions themselves.
# The configure_* factories are not listed; they are exposed through the
# @registry.initializers decorators instead.
__all__ = [
    "normal_init",
    "uniform_init",
    "glorot_uniform_init",
    "zero_init",
    "lecun_uniform_init",
    "he_uniform_init",
    "glorot_normal_init",
    "he_normal_init",
    "lecun_normal_init",
]