from typing import Callable, List, Optional, Tuple, TypeVar, Union, cast
from ..config import registry
from ..model import Model
from ..types import Array2d, Floats2d, Ints2d, List2d, Padded, Ragged
SeqT = TypeVar("SeqT", Padded, Ragged, List2d, List[Floats2d], List[Ints2d])


@registry.layers("with_list.v1")
def with_list(layer: Model[List2d, List2d]) -> Model[SeqT, SeqT]:
    """Wrap a layer that expects a list of 2d arrays, so that it can be
    called with Padded, Ragged, or list-of-array sequence data: inputs are
    converted to a list of 2d arrays for the wrapped layer, and outputs are
    converted back to the caller's sequence type.
    """
    # Mirror whatever dimensions the wrapped layer already has.
    dims = {name: layer.maybe_get_dim(name) for name in layer.dim_names}
    return Model(
        f"with_list({layer.name})",
        forward,
        init=init,
        layers=[layer],
        dims=dims,
    )


def forward(
    model: Model[SeqT, SeqT], Xseq: SeqT, is_train: bool
) -> Tuple[SeqT, Callable]:
    """Dispatch to a type-specific path so the wrapped layer always sees a
    list of 2d arrays, and return the output (plus backprop callback) in the
    same sequence type the caller passed in.
    """
    layer: Model[List2d, List2d] = model.layers[0]
    if isinstance(Xseq, Padded):
        return _padded_forward(layer, Xseq, is_train)
    if isinstance(Xseq, Ragged):
        return _ragged_forward(layer, Xseq, is_train)
    # Already a list of arrays: the wrapped layer can take it directly.
    return cast(Tuple[SeqT, Callable], layer(cast(List2d, Xseq), is_train))


def init(
    model: Model[SeqT, SeqT], X: Optional[SeqT] = None, Y: Optional[SeqT] = None
) -> None:
    """Initialize the wrapped layer, converting any sample data from its
    sequence type to the list-of-arrays form the wrapped layer expects.
    """
    X_list = None if X is None else _get_list(model, X)
    Y_list = None if Y is None else _get_list(model, Y)
    model.layers[0].initialize(X=X_list, Y=Y_list)


def _get_list(model, seq):
    """Coerce a Padded or Ragged sequence to a list of arrays; any other
    input is assumed to already be a list and is passed through unchanged.
    """
    if isinstance(seq, Padded):
        return model.ops.padded2list(seq)
    if isinstance(seq, Ragged):
        return model.ops.unflatten(seq.data, seq.lengths)
    return seq


def _ragged_forward(
    layer: Model[List2d, List2d], Xr: Ragged, is_train: bool
) -> Tuple[Ragged, Callable]:
    """Run the wrapped layer on a Ragged input: unflatten the data into a
    list of arrays on the way in, flatten the result on the way out.
    """
    ops = layer.ops
    # The activations can be large on GPU, so we nest the conversion calls
    # instead of binding temporaries, letting memory be reclaimed sooner.
    Ys, get_dXs = layer(ops.unflatten(Xr.data, Xr.lengths), is_train)

    def backprop(dYr: Ragged) -> Ragged:
        return Ragged(
            ops.flatten(get_dXs(ops.unflatten(dYr.data, dYr.lengths))),
            dYr.lengths,
        )

    return Ragged(ops.flatten(Ys), Xr.lengths), backprop


def _padded_forward(
    layer: Model[List2d, List2d], Xp: Padded, is_train: bool
) -> Tuple[Padded, Callable]:
    """Run the wrapped layer on a Padded input: convert to a list of arrays
    on the way in, re-pad the result on the way out.
    """
    ops = layer.ops
    # The activations can be large on GPU, so we nest the conversion calls
    # instead of binding temporaries, letting memory be reclaimed sooner.
    Ys, get_dXs = layer(ops.padded2list(Xp), is_train)

    def backprop(dYp: Padded) -> Padded:
        return ops.list2padded(get_dXs(ops.padded2list(dYp)))

    return ops.list2padded(Ys), backprop