You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

183 lines
6.3 KiB

from functools import partial
from typing import Callable, Optional, Tuple, cast
from ..backends import Ops
from ..config import registry
from ..initializers import glorot_uniform_init, zero_init
from ..model import Model
from ..types import Floats1d, Floats2d, Floats4d, Padded, Ragged
from ..util import get_width
from .noop import noop
@registry.layers("LSTM.v1")
def LSTM(
    nO: Optional[int] = None,
    nI: Optional[int] = None,
    *,
    bi: bool = False,
    depth: int = 1,
    dropout: float = 0.0,
    init_W: Optional[Callable] = None,
    init_b: Optional[Callable] = None,
) -> Model[Padded, Padded]:
    """Create an LSTM layer mapping Padded batches to Padded batches.

    ``bi=True`` doubles the number of directions (``dirs`` dim). Weights are
    allocated lazily by ``init`` once the input/output widths are known.
    Dropout is applied to the weights at forward time, not here.
    """
    if depth == 0:
        raise ValueError(
            "LSTM depth must be at least 1. Maybe we should make this a noop?"
        )
    # Fall back to the library defaults when no initializers are given.
    weight_init = glorot_uniform_init if init_W is None else init_W
    bias_init = zero_init if init_b is None else init_b
    return Model(
        "lstm",
        forward,
        dims={"nO": nO, "nI": nI, "depth": depth, "dirs": 1 + int(bi)},
        attrs={"registry_name": "LSTM.v1", "dropout_rate": dropout},
        params={"LSTM": None, "HC0": None},
        init=partial(init, weight_init, bias_init),
    )
@registry.layers("PyTorchLSTM.v1")
def PyTorchLSTM(
    nO: int, nI: int, *, bi: bool = False, depth: int = 1, dropout: float = 0.0
) -> Model[Padded, Padded]:
    """Wrap ``torch.nn.LSTM`` as a Thinc model over Padded sequences."""
    import torch.nn

    from .pytorchwrapper import PyTorchRNNWrapper
    from .with_padded import with_padded

    # A zero-depth recurrent layer is just the identity.
    if depth == 0:
        return noop()  # type: ignore[misc]
    # In the bidirectional case each direction contributes half of nO.
    hidden_width = nO // 2 if bi else nO
    wrapped = PyTorchRNNWrapper(
        torch.nn.LSTM(nI, hidden_width, depth, bidirectional=bi, dropout=dropout)
    )
    wrapped.set_dim("nO", nO)
    wrapped.set_dim("nI", nI)
    return with_padded(wrapped)
def init(
    init_W: Callable,
    init_b: Callable,
    model: Model,
    X: Optional[Padded] = None,
    Y: Optional[Padded] = None,
) -> None:
    """Allocate and set the packed LSTM weights and the initial state.

    Widths are inferred from the sample input/output when provided. The
    weights are allocated chunk by chunk (so the initializers can be applied)
    and then concatenated into one flat parameter in the order CuDNN expects.
    """
    if X is not None:
        model.set_dim("nI", get_width(X))
    if Y is not None:
        model.set_dim("nO", get_width(Y))
    dirs = model.get_dim("dirs")
    # Per-direction hidden width: total output is split across directions.
    nH = int(model.get_dim("nO") / dirs)
    nI = model.get_dim("nI")
    depth = model.get_dim("depth")
    # Convenience: bind the ops argument once.
    make_W = partial(init_W, model.ops)
    make_b = partial(init_b, model.ops)
    chunks = []
    layer_width = nI
    for _layer in range(depth):
        for _direction in range(dirs):
            # Four input-to-gate weight matrices, then their four biases.
            chunks.extend(make_W((nH, layer_width)) for _ in range(4))
            chunks.extend(make_b((nH,)) for _ in range(4))
            # Four hidden-to-gate weight matrices, then their four biases.
            chunks.extend(make_W((nH, nH)) for _ in range(4))
            chunks.extend(make_b((nH,)) for _ in range(4))
        # Deeper layers receive the concatenated outputs of all directions.
        layer_width = nH * dirs
    model.set_param("LSTM", model.ops.xp.concatenate([c.ravel() for c in chunks]))
    # HC0 holds the initial hidden and cell states, zero-initialised.
    model.set_param("HC0", zero_init(model.ops, (2, depth, dirs, nH)))
    # Sanity-check the flattened size against the expected CuDNN layout.
    size = model.get_param("LSTM").size
    expected = 4 * dirs * nH * (nH + nI) + dirs * (8 * nH)
    for _ in range(1, depth):
        expected += 4 * dirs * (nH + nH * dirs) * nH + dirs * (8 * nH)
    assert size == expected, (size, expected)
def forward(
    model: Model[Padded, Padded], Xp: Padded, is_train: bool
) -> Tuple[Padded, Callable]:
    """Run the LSTM over a padded batch.

    Returns the padded outputs and a backprop callback. The callback is only
    valid after a training-mode forward pass (it asserts on ``fwd_state`` and
    closes over the training-only dropout ``mask``).
    """
    dropout = model.attrs["dropout_rate"]
    # Strip the padding so the kernel works on packed data.
    Xr = _padded_to_packed(model.ops, Xp)
    LSTM = cast(Floats1d, model.get_param("LSTM"))
    HC0 = cast(Floats4d, model.get_param("HC0"))
    H0 = HC0[0]  # initial hidden state
    C0 = HC0[1]  # initial cell state
    if is_train:
        # Apply dropout over *weights*, not *activations*. RNN activations are
        # heavily correlated, so dropout over the activations is less effective.
        # This trick was explained in Yarin Gal's thesis, and popularised by
        # Smerity in the AWD-LSTM. It also means we can do the dropout outside
        # of the backend, improving compatibility.
        mask = cast(Floats1d, model.ops.get_dropout_mask(LSTM.shape, dropout))
        LSTM = LSTM * mask
        Y, fwd_state = model.ops.lstm_forward_training(
            LSTM, H0, C0, cast(Floats2d, Xr.data), Xr.lengths
        )
    else:
        Y = model.ops.lstm_forward_inference(
            LSTM, H0, C0, cast(Floats2d, Xr.data), Xr.lengths
        )
        fwd_state = tuple()
    assert Y.shape == (Xr.data.shape[0], Y.shape[1]), (Xr.data.shape, Y.shape)
    # Restore the padded layout for the caller.
    Yp = _packed_to_padded(model.ops, Ragged(Y, Xr.lengths), Xp)

    def backprop(dYp: Padded) -> Padded:
        # Empty fwd_state means the forward pass ran in inference mode.
        assert fwd_state
        dYr = _padded_to_packed(model.ops, dYp)
        dX, dLSTM = model.ops.backprop_lstm(
            cast(Floats2d, dYr.data), dYr.lengths, LSTM, fwd_state
        )
        # Re-apply the dropout mask so gradients flow only through kept weights.
        dLSTM *= mask
        model.inc_grad("LSTM", dLSTM)
        return _packed_to_padded(model.ops, Ragged(dX, dYr.lengths), dYp)

    return Yp, backprop
def _padded_to_packed(ops: Ops, Xp: Padded) -> Ragged:
    """Strip padding from a padded sequence."""
    # The per-sequence lengths and the per-timestep batch sizes must agree.
    assert Xp.lengths.sum() == Xp.size_at_t.sum(), (
        Xp.lengths.sum(),
        Xp.size_at_t.sum(),
    )
    packed = ops.alloc2f(Xp.lengths.sum(), Xp.data.shape[2])
    offset = 0
    # Copy the active rows of each timestep into one contiguous buffer.
    for timestep, batch_size in enumerate(Xp.size_at_t):
        packed[offset : offset + batch_size] = Xp.data[timestep, :batch_size]  # type: ignore[assignment]
        offset += batch_size
    return Ragged(packed, Xp.size_at_t)
def _packed_to_padded(ops: Ops, Xr: Ragged, Xp: Padded) -> Padded:
    """Re-pad packed data, using the layout recorded in the reference Padded."""
    padded = ops.alloc3f(Xp.data.shape[0], Xp.data.shape[1], Xr.data.shape[1])
    packed = cast(Floats2d, Xr.data)
    offset = 0
    # Scatter each timestep's contiguous rows back into the padded buffer.
    for timestep, batch_size in enumerate(Xp.size_at_t):
        padded[timestep, :batch_size] = packed[offset : offset + batch_size]
        offset += batch_size
    return Padded(
        padded, size_at_t=Xp.size_at_t, lengths=Xp.lengths, indices=Xp.indices
    )