You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

72 lines
2.1 KiB

from typing import Callable, Optional, Tuple, cast
from ..backends import Ops
from ..config import registry
from ..model import Model
from ..types import Padded
# Input and output type aliases used in this module's annotations; both are
# Padded batches.
InT = Padded
OutT = Padded
@registry.layers("bidirectional.v1")
def bidirectional(
    l2r: Model[InT, OutT], r2l: Optional[Model[InT, OutT]] = None
) -> Model[InT, OutT]:
    """Combine two sequence models into a single bidirectional layer.

    l2r: The model used for the left-to-right direction.
    r2l: The model used for the right-to-left direction. When omitted, a
        copy of `l2r` (made with `l2r.copy()`) takes its place.
    RETURNS: A model named "bi" + `l2r.name` holding both sub-models as its
        layers. Expects squared sequences.
    """
    backward_layer = l2r.copy() if r2l is None else r2l
    return Model(
        f"bi{l2r.name}", forward, layers=[l2r, backward_layer], init=init
    )
def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
    """Run both directions over X and join their outputs along the last axis.

    The first layer sees X as given; the second sees X reversed along its
    leading axis. The second layer's output is flipped back before the
    concatenation, so that row t of both halves refers to the same row of X.

    model: The bidirectional model; its two layers are the l2r and r2l models.
    X: The padded input batch.
    is_train: Passed through to both sub-layers.
    RETURNS: The concatenated output and a callback that maps a gradient of
        that output to a gradient of X.
    """
    ops = model.ops
    l2r, r2l = model.layers
    l2r_Z, bp_l2r_Z = l2r(X, is_train)
    r2l_Z_rev, bp_r2l_Z = r2l(_reverse(ops, X), is_train)
    # Undo the reversal so both halves are aligned row-for-row with X.
    Z = _concatenate(ops, l2r_Z, _reverse(ops, r2l_Z_rev))

    def backprop(dZ: OutT) -> InT:
        d_l2r_Z, d_r2l_Z = _split(ops, dZ)
        dX_l2r = bp_l2r_Z(d_l2r_Z)
        # The r2l layer worked in reversed order: flip its incoming gradient
        # on the way in, and flip the resulting input gradient back before it
        # is added to the l2r contribution.
        dX_r2l_rev = bp_r2l_Z(_reverse(ops, d_r2l_Z))
        return _sum(dX_l2r, _reverse(ops, dX_r2l_rev))

    return Z, backprop


def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    """Initialize both sub-layers with the sample data.

    Each sub-layer receives X unchanged. When a sample output Y is given it
    is split in two along the last axis and each sub-layer gets one part;
    otherwise both are initialized with Y=None.
    """
    (Y1, Y2) = _split(model.ops, Y) if Y is not None else (None, None)
    model.layers[0].initialize(X=X, Y=Y1)
    model.layers[1].initialize(X=X, Y=Y2)


def _reverse(ops: Ops, Xp: Padded) -> Padded:
    """Return a Padded whose data is Xp's data reversed along the leading axis.

    The step must be -1: a step of 1 is the identity slice and reverses
    nothing.

    NOTE(review): size_at_t, lengths and indices are passed through
    unchanged, which presumably is only valid when every sequence in the
    batch has the same length (the "squared sequences" that `bidirectional`
    asks for) -- confirm.
    """
    # Copy into a contiguous buffer; a negative-stride view may not be
    # accepted by every downstream kernel.
    flipped = ops.xp.ascontiguousarray(Xp.data[::-1])
    return Padded(flipped, Xp.size_at_t, Xp.lengths, Xp.indices)
def _concatenate(ops: Ops, l2r: Padded, r2l: Padded) -> Padded:
    """Join the data of two Padded batches along the last axis.

    The size_at_t, lengths and indices of the result are taken from `l2r`.
    """
    joined = ops.xp.concatenate((l2r.data, r2l.data), axis=-1)
    return Padded(joined, l2r.size_at_t, l2r.lengths, l2r.indices)
def _split(ops: Ops, Xp: Padded) -> Tuple[Padded, Padded]:
    """Split Xp's data in two along the last axis.

    The first result holds columns [0, half) and the second holds columns
    [half, end), where half is the last-axis width floor-divided by two.
    Both results reuse Xp's size_at_t, lengths and indices.

    RETURNS: The (l2r, r2l) pair of Padded batches.
    """
    half = Xp.data.shape[-1] // 2
    # The casts are for the type checker only (apparently the array indexing
    # overloads do not cover an Ellipsis); at runtime these are plain
    # `data[..., :half]` and `data[..., half:]`.
    X_l2r = Xp.data[cast(Tuple[slice, slice], (..., slice(None, half)))]
    # slice(half) means "stop at half", i.e. the first half again; the
    # second half needs an explicit start and an open end.
    X_r2l = Xp.data[cast(Tuple[slice, slice], (..., slice(half, None)))]
    return (
        Padded(X_l2r, Xp.size_at_t, Xp.lengths, Xp.indices),
        Padded(X_r2l, Xp.size_at_t, Xp.lengths, Xp.indices),
    )
def _sum(Xp: Padded, Yp: Padded) -> Padded:
    """Add the data of two Padded batches.

    The size_at_t, lengths and indices of the result are taken from `Xp`.
    """
    total = Xp.data + Yp.data
    return Padded(total, Xp.size_at_t, Xp.lengths, Xp.indices)