You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

58 lines
1.6 KiB

from typing import Callable, Tuple, TypeVar, Union, cast
from ..config import registry
from ..model import Model
from ..types import Floats2d, Ragged
# Type variable tying the layer's input to its output: the layer accepts either
# a 2d float array or a Ragged batch and hands back a value of the same type.
InT = TypeVar("InT", Floats2d, Ragged)


@registry.layers("expand_window.v1")
def expand_window(window_size: int = 1) -> Model[InT, InT]:
    """Create a layer that widens every input row with its neighbouring rows.

    Each output vector is the concatenation of an input vector and the window
    of vectors around it, i.e. one step of a convolution. The window size is
    stored on the model as ``attrs["window_size"]`` and read by ``forward``.

    window_size (int): How many neighbouring rows to include on each side.
    RETURNS (Model[InT, InT]): The window-expansion layer.
    """
    layer_attrs = {"window_size": window_size}
    return Model("expand_window", forward, attrs=layer_attrs)
def forward(
    model: Model[InT, InT], X: InT, is_train: bool
) -> Tuple[InT, Callable]:
    """Expand the window around each row of ``X``, choosing a path by input type.

    Ragged batches go through the length-aware path; anything else is handled
    as a single 2d float array. ``is_train`` is not consulted, so training and
    inference run the same computation.

    RETURNS (Tuple[InT, Callable]): The expanded output (same container type
        as ``X``) and its backprop callback.
    """
    if not isinstance(X, Ragged):
        return _expand_window_floats(model, X)
    return _expand_window_ragged(model, X)
def _expand_window_floats(
    model: Model[InT, InT], X: Floats2d
) -> Tuple[Floats2d, Callable]:
    """Window-expand a plain 2d array, treating all of its rows as one sequence.

    Non-empty input goes through ``ops.seq2col``. An empty input is instead
    tiled ``2 * window_size + 1`` times, presumably so the output width still
    matches what ``seq2col`` would produce (TODO confirm ``ops.tile`` semantics).

    RETURNS (Tuple[Floats2d, Callable]): The expanded array and its backprop.
    """
    window = model.attrs["window_size"]
    if not len(X):
        Y = model.ops.tile(X, (window * 2) + 1)
    else:
        Y = model.ops.seq2col(X, window)

    def backprop(dY: Floats2d) -> Floats2d:
        # NOTE(review): the empty-input path used tile in the forward pass, but
        # the gradient always goes through backprop_seq2col; presumably fine
        # because dY is then empty too — confirm.
        return model.ops.backprop_seq2col(dY, window)

    return Y, backprop
def _expand_window_ragged(
    model: Model[InT, InT], Xr: Ragged
) -> Tuple[Ragged, Callable]:
    """Window-expand the flattened data of a Ragged batch.

    The batch's ``lengths`` are passed to ``ops.seq2col`` alongside the data and
    reused unchanged as the lengths of both the output and the input gradient.

    RETURNS (Tuple[Ragged, Callable]): The expanded Ragged batch and its backprop.
    """
    window = model.attrs["window_size"]
    lengths = Xr.lengths
    expanded = model.ops.seq2col(cast(Floats2d, Xr.data), window, lengths=lengths)
    Y = Ragged(expanded, lengths)

    def backprop(dYr: Ragged) -> Ragged:
        # The gradient keeps the input's per-sequence lengths.
        d_data = model.ops.backprop_seq2col(
            cast(Floats2d, dYr.data), window, lengths=Xr.lengths
        )
        return Ragged(d_data, Xr.lengths)

    return Y, backprop