You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

65 lines
2.0 KiB

from typing import Callable, List, Optional, Tuple, TypeVar
from ..config import registry
from ..model import Model
from ..types import Floats1d, Floats2d, Floats3d, Floats4d, FloatsXd, Padded, Ragged
# Constrained type variable used to annotate the residual layer below: both
# the input and the output slot of `Model[InT, InT]` use it, so the wrapped
# layer is expected to hand back the same kind of value it was given.
# The allowed types are lists of 1d-4d float arrays, `Ragged`, `Padded`, and
# bare float arrays (`FloatsXd` and the fixed-rank variants).
# The declaration is kept compact by switching the auto-formatter off around it.
# fmt: off
InT = TypeVar(
"InT", List[Floats1d], List[Floats2d], List[Floats3d], List[Floats4d],
Ragged, Padded, FloatsXd, Floats1d, Floats2d, Floats3d, Floats4d)
# fmt: on
@registry.layers("residual.v1")
def residual(layer: Model[InT, InT]) -> Model[InT, InT]:
    """Wrap `layer` in a residual connection.

    The returned model runs `layer` and adds the layer's output to its input
    (see `forward`). It is registered in the layer registry as "residual.v1".

    layer: The model to wrap; it becomes the wrapper's only child layer.
    RETURNS: A new `Model` named "residual(<layer name>)".
    """
    wrapper_name = f"residual({layer.name})"
    # Mirror the wrapped layer's output/input widths when it already has
    # them; otherwise declare the dimension with no value (None) so that
    # `init` can fill it in later.
    inherited_dims = {
        dim_name: (layer.get_dim(dim_name) if layer.has_dim(dim_name) else None)
        for dim_name in ("nO", "nI")
    }
    return Model(
        wrapper_name,
        forward,
        init=init,
        layers=[layer],
        dims=inherited_dims,
    )
def forward(model: Model[InT, InT], X: InT, is_train: bool) -> Tuple[InT, Callable]:
    """Apply the wrapped layer and add its output to the input.

    model: The residual wrapper; its first child layer is the one applied.
    X: The input. Lists are summed element by element, `Ragged` inputs are
        summed on their `.data` and rebuilt, `Padded` inputs are summed in
        place into the layer's output, anything else uses plain `+`.
    is_train: Passed through unchanged to the wrapped layer.
    RETURNS: A tuple of the summed output and the backprop callback.
    """
    wrapped_layer = model.layers[0]
    Y, backprop_layer = wrapped_layer(X, is_train)

    def backprop(d_output: InT) -> InT:
        """Return the input gradient: the incoming gradient plus the
        gradient that flows back through the wrapped layer."""
        d_inner = backprop_layer(d_output)
        if isinstance(d_output, list):
            return [grad + d_inner[idx] for idx, grad in enumerate(d_output)]
        if isinstance(d_output, Ragged):
            return Ragged(d_output.data + d_inner.data, d_inner.lengths)
        # NOTE: this branch is selected by the type of the forward input X,
        # not by the type of d_output.
        if isinstance(X, Padded):
            # Accumulate in place into the gradient returned by the layer.
            d_inner.data += d_output.data
            return d_inner
        return d_output + d_inner

    if isinstance(X, list):
        summed = [item + Y[idx] for idx, item in enumerate(X)]
    elif isinstance(X, Ragged):
        summed = Ragged(X.data + Y.data, X.lengths)
    elif isinstance(X, Padded):
        # Accumulate in place into the layer's output and return that object.
        Y.data += X.data
        summed = Y
    else:
        summed = X + Y
    return summed, backprop
def init(
    model: Model[InT, InT], X: Optional[InT] = None, Y: Optional[InT] = None
) -> None:
    """Initialize the wrapped layer, then copy its dimensions to the wrapper.

    model: The residual wrapper; its first child layer is initialized.
    X: Optional sample input, always forwarded to the wrapped layer.
    Y: Optional sample output, forwarded only when the wrapped layer's "nO"
        check comes back as None.
    """
    wrapped_layer = model.layers[0]
    init_kwargs = {"X": X}
    # `has_dim` is compared against None on purpose: per Thinc's Model API,
    # None means the dimension is declared but not set yet. Only then is the
    # sample output passed along so the layer can work out its output width.
    if wrapped_layer.has_dim("nO") is None:
        init_kwargs["Y"] = Y
    wrapped_layer.initialize(**init_kwargs)
    # Copy across whichever widths the wrapped layer now reports as set.
    for dim_name in ("nO", "nI"):
        if wrapped_layer.has_dim(dim_name):
            model.set_dim(dim_name, wrapped_layer.get_dim(dim_name))