You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.0 KiB

from typing import Callable, Optional, Tuple
import numpy
from ..config import registry
from ..model import Model
from ..types import Floats2d, Ints2d
InT = Ints2d
OutT = Floats2d
@registry.layers("uniqued.v1")
def uniqued(layer: Model, *, column: int = 0) -> Model[InT, OutT]:
    """Wrap `layer` so it is evaluated once per distinct key in a batch.

    Rows of the input are grouped by the value found in `column`; the
    wrapped layer only sees one representative row per distinct value, and
    its output is expanded back so there is one output row per input row.
    The gradient is routed through the same mapping on the way back. In
    effect this behaves like a cache that lives for a single minibatch.

    layer (Model): The layer to wrap.
    column (int): Index of the input column whose values are used as keys.
    RETURNS (Model): The wrapping model, named "uniqued(<layer name>)".
    """
    wrapper_name = f"uniqued({layer.name})"
    # Both dims start unset; `init` copies them from the wrapped layer.
    unset_dims = {"nO": None, "nI": None}
    wrapper_attrs = {"column": column}
    return Model(
        wrapper_name,
        forward,
        init=init,
        layers=[layer],
        dims=unset_dims,
        attrs=wrapper_attrs,
    )
def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
    """Run the wrapped layer on one row per distinct key and expand the result.

    The key of each row is the value in the column given by the model's
    "column" attribute. Inputs with fewer than two elements are passed
    straight to the wrapped layer, and its (output, backprop) pair is
    returned unchanged.

    model (Model): The `uniqued` wrapper; its first sub-layer is the
        wrapped layer.
    X: The input batch, indexed as a 2d array.
    is_train (bool): Forwarded to the wrapped layer.
    RETURNS: The output with one row per row of `X`, and a callback that
        maps an output gradient back to one row per input row.
    """
    key_column: int = model.attrs["column"]
    wrapped = model.layers[0]

    # Nothing to de-duplicate: delegate entirely to the wrapped layer.
    if X.size < 2:
        return wrapped(X, is_train)

    row_keys = X[:, key_column]
    if not isinstance(row_keys, numpy.ndarray):
        # NOTE(review): presumably a device array being copied to host
        # via `.get()` -- confirm against the array backends in use.
        row_keys = row_keys.get()  # pragma: no cover

    # first_rows: index of the first occurrence of each distinct key.
    # inverse:    for every input row, the position of its key among the
    #             distinct keys.
    # counts:     number of input rows sharing each distinct key.
    _, first_rows, inverse, counts = wrapped.ops.xp.unique(
        row_keys,
        return_index=True,
        return_inverse=True,
        return_counts=True,
    )
    # Column vector, so it broadcasts across gradient rows in `backprop`.
    counts = model.ops.reshape2i(counts, -1, 1)

    representatives = X[first_rows]
    Y_uniq, backprop_uniq = wrapped(representatives, is_train)

    # Expand back to one output row per input row.
    expanded = Y_uniq[inverse]
    Y = expanded.reshape((X.shape[0],) + Y_uniq.shape[1:])
    uniq_shape = tuple(Y_uniq.shape)

    def backprop(dY: OutT) -> InT:
        ops = wrapped.ops
        # Sum the gradients of all rows that shared a key into that key's slot.
        dY_uniq = ops.alloc2f(*uniq_shape)
        ops.scatter_add(dY_uniq, ops.asarray_i(inverse), dY)
        d_uniq = backprop_uniq(dY_uniq)
        # Divide each distinct row's gradient by how many input rows shared
        # its key, then index by `inverse` to get one row per input row.
        scaled = d_uniq / counts
        return scaled[inverse]

    return Y, backprop
def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    """Initialize the wrapped layer and mirror its dimensions on the wrapper.

    The sample data is handed to the wrapped layer's `initialize` as-is.
    Afterwards, each of "nI" and "nO" that the wrapped layer reports via
    `has_dim` is copied onto `model`.

    model (Model): The `uniqued` wrapper being initialized.
    X: Optional sample input, forwarded to the wrapped layer.
    Y: Optional sample output, forwarded to the wrapped layer.
    """
    wrapped = model.layers[0]
    wrapped.initialize(X=X, Y=Y)
    # Same order as the dims are checked one by one: input first, then output.
    for dim_name in ("nI", "nO"):
        if wrapped.has_dim(dim_name):
            model.set_dim(dim_name, wrapped.get_dim(dim_name))