You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

94 lines
2.7 KiB

from typing import Callable, Optional, TypeVar
from ..config import registry
from ..model import Model
from ..types import Floats2d
InT = TypeVar("InT")
OutT = TypeVar("OutT")
@registry.layers("resizable.v1")
def resizable(layer, resize_layer: Callable) -> Model[InT, OutT]:
    """Wrap a single layer so that it can later be swapped for a resized one.

    layer: The layer to wrap; it becomes the only entry in the wrapper's
        `layers` list.
    resize_layer: Callable stored in the wrapper's attrs under
        "resize_layer". `resize_model` invokes it as
        `resize_layer(old_layer, new_nO)` and installs whatever it returns.

    The wrapper's dims are a snapshot of the wrapped layer's dims taken here,
    at construction time.
    """
    wrapper_name = f"resizable({layer.name})"
    # Mirror every dimension the wrapped layer declares, using whatever
    # value `maybe_get_dim` reports for it right now.
    mirrored_dims = {dim: layer.maybe_get_dim(dim) for dim in layer.dim_names}
    return Model(
        wrapper_name,
        forward,
        init=init,
        layers=[layer],
        attrs={"resize_layer": resize_layer},
        dims=mirrored_dims,
    )
def forward(model: Model[InT, OutT], X: InT, is_train: bool):
    """Run the wrapped layer on X and return its output plus a backprop callback.

    The wrapper adds no computation of its own: both the forward pass and the
    returned callback delegate to `model.layers[0]`.
    """
    wrapped = model.layers[0]
    output, wrapped_backprop = wrapped(X, is_train=is_train)

    def backprop(dY: OutT) -> InT:
        # Hand the gradient straight to the wrapped layer's own callback.
        return wrapped_backprop(dY)

    return output, backprop
def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    """Initialize the wrapped layer, forwarding the optional sample X and Y."""
    model.layers[0].initialize(X, Y)
def resize_model(model: Model[InT, OutT], new_nO):
    """Replace the wrapped layer with one resized to `new_nO` and return `model`.

    The replacement is produced by the callable stored in
    `model.attrs["resize_layer"]`, called with the current wrapped layer and
    `new_nO`. The wrapper is modified in place; its own dims are not touched
    here.
    """
    current = model.layers[0]
    resize = model.attrs["resize_layer"]
    model.layers[0] = resize(current, new_nO)
    return model
def resize_linear_weighted(
    layer: Model[Floats2d, Floats2d], new_nO, *, fill_defaults=None
) -> Model[Floats2d, Floats2d]:
    """Create a resized copy of a layer that has parameters W and b and dimensions nO and nI.

    layer: A leaf layer; it must have no sublayers, no refs and no shims.
    new_nO: The requested value for the "nO" dimension.
    fill_defaults: Optional mapping from parameter name to the value used for
        the entries that did not exist in the original parameter. Names that
        are missing from the mapping (or a falsy mapping) use 0.

    RETURNS: The original `layer` when no copy is needed (see the guard
        clauses below), otherwise a newly built and initialized layer whose
        parameters were carried over from the original.
    """
    assert not layer.layers
    assert not layer.ref_names
    assert not layer.shims

    # Cases where the original layer is handed back instead of a copy.
    # `has_dim(...) is None` is presumably "declared but not set yet" --
    # confirm against the Model API.
    if layer.has_dim("nO") is None:
        # nO has no value yet: just record the requested size.
        layer.set_dim("nO", new_nO)
        return layer
    if new_nO == layer.get_dim("nO"):
        # Already the requested size.
        return layer
    if layer.has_dim("nI") is None:
        # nI has no value yet: overwrite nO in place rather than copying.
        layer.set_dim("nO", new_nO, force=True)
        return layer

    # Build a fresh layer with the same configuration but the new nO.
    resized_dims = {dim: layer.maybe_get_dim(dim) for dim in layer.dim_names}
    resized_dims["nO"] = new_nO
    resized: Model[Floats2d, Floats2d] = Model(
        layer.name,
        layer._func,
        dims=resized_dims,
        params=dict.fromkeys(layer.param_names),
        init=layer.init,
        attrs=layer.attrs,
        refs={},
        ops=layer.ops,
    )
    resized.initialize()

    # Carry over every parameter the original actually holds a value for.
    fillers = fill_defaults or {}
    for param in layer.param_names:
        if not layer.has_param(param):
            continue
        _resize_parameter(param, layer, resized, filler=fillers.get(param, 0))
    return resized
def _resize_parameter(name, layer, new_layer, filler=0):
    """Carry parameter `name` over from `layer` into the resized `new_layer`.

    Entries are matched by position along the leading axis (the one `len()`
    measures). The overlapping leading entries are copied from the original;
    any entries of the new parameter beyond the original's length are set to
    `filler`. When the new parameter is shorter than the original, the
    original's trailing entries are dropped instead of raising a shape
    mismatch error.

    name: Name of the parameter to transfer.
    layer: The original layer; only read via `get_param(name)`.
    new_layer: The resized layer; its parameter is updated via
        `get_param(name)` / `set_param(name, ...)`.
    filler: Value assigned to entries that have no counterpart in `layer`.
    """
    resized = new_layer.get_param(name)
    original = layer.get_param(name)

    # Only the overlap can be copied: bounding by both lengths keeps the
    # growing case unchanged and makes the shrinking case a truncation.
    keep = min(len(original), len(resized))
    # copy the original weights
    resized[:keep] = original[:keep]
    # set the new weights (an empty slice, hence a no-op, when not growing)
    resized[keep:] = filler
    new_layer.set_param(name, resized)