You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
1.6 KiB
54 lines
1.6 KiB
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, cast
|
|
|
|
from ..config import registry
|
|
from ..model import Model
|
|
from ..types import ArrayXd, ListXd
|
|
|
|
# Type of the individual items held inside each input sequence.
ItemT = TypeVar("ItemT")
# Outer input of the combinator: a batch of sequences of items.
InT = Sequence[Sequence[ItemT]]
# Outer output of the combinator: a list of arrays, one per input sequence.
OutT = ListXd
# Input of the wrapped layer: a single flat sequence of items.
InnerInT = Sequence[ItemT]
# Output of the wrapped layer: a single array covering all flattened items.
InnerOutT = ArrayXd
|
|
|
|
|
|
@registry.layers("with_flatten.v1")
def with_flatten(layer: Model[InnerInT[ItemT], InnerOutT]) -> Model[InT[ItemT], OutT]:
    """Wrap *layer* so it can be applied to a batch of sequences.

    The returned model flattens its nested input into one sequence, calls
    *layer* once on that flat sequence, and splits the layer's output back
    into one array per input sequence.

    layer: the model to wrap; it consumes a flat sequence of items.

    Returns a new ``Model`` named ``with_flatten(<layer name>)`` whose only
    child layer is *layer*.
    """
    wrapped_name = f"with_flatten({layer.name})"
    return Model(wrapped_name, forward, layers=[layer], init=init)
|
|
|
|
|
|
def forward(
    model: Model[InT, OutT], Xnest: InT, is_train: bool
) -> Tuple[OutT, Callable]:
    """Run the wrapped layer on the flattened batch and re-nest its output.

    model: the ``with_flatten`` model; its first child layer is the one run.
    Xnest: a batch of sequences. Every item is concatenated into one flat
        sequence, which is passed to the wrapped layer in a single call.
    is_train: forwarded unchanged to the wrapped layer.

    Returns a tuple of:
      * the per-sequence outputs, obtained by splitting the layer's output
        along its first axis at the cumulative sequence lengths (an empty
        list for an empty batch), and
      * a backprop callback that maps per-sequence output gradients to
        per-sequence input gradients, split the same way.
    """
    layer: Model[InnerInT, InnerOutT] = model.layers[0]
    Xflat = _flatten(Xnest)
    Yflat, backprop_layer = layer(Xflat, is_train)
    # Get the split points. We want n-1 splits for n items: the cumulative
    # lengths of every sequence except the last.
    arr = layer.ops.asarray1i([len(x) for x in Xnest[:-1]])
    splits = arr.cumsum()
    # With zero split points, xp.split returns [Yflat] (one element), which
    # would not line up with an empty batch, so that case yields [].
    Ynest = layer.ops.xp.split(Yflat, splits, axis=0) if len(Xnest) else []

    def backprop(dYnest: OutT) -> InT:
        dYflat = model.ops.flatten(dYnest)  # type: ignore[arg-type, var-annotated]
        # type ignore necessary for older versions of Mypy/Pydantic
        dXflat = backprop_layer(dYflat)
        if not len(Xnest):
            # Mirror the forward pass: no sequences in, no gradients out.
            return []
        # Split along the same (first) axis as the forward pass: rows of the
        # flat gradient correspond to the flattened input items.
        dXnest = layer.ops.xp.split(dXflat, splits, axis=0)
        return dXnest

    return Ynest, backprop
|
|
|
|
|
|
def _flatten(nested: InT) -> InnerInT:
    """Concatenate the items of every inner sequence into one flat list.

    Item order is preserved: all items of ``nested[0]`` come first, then
    those of ``nested[1]``, and so on. An empty outer sequence gives an
    empty list.
    """
    flat: List = [item for inner in nested for item in inner]
    # The result is a flat sequence of items, so cast to InnerInT to match
    # the declared return type (not the nested InT).
    return cast(InnerInT, flat)
|
|
|
|
|
|
def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    """Initialize the wrapped layer from optional sample data.

    model: the ``with_flatten`` model; its first child layer is initialized.
    X: optional batch of input sequences, flattened into a single sequence
        before being passed to the wrapped layer.
    Y: optional list of per-sequence output arrays, concatenated along the
        first axis so that their rows line up with the flattened inputs.
    """
    layer = model.layers[0]
    Xflat = _flatten(X) if X is not None else None
    # forward() splits the wrapped layer's output along axis 0, so the
    # per-sequence arrays must be re-joined along axis 0 as well. hstack
    # joins 2-D arrays along axis 1, mixing rows and columns.
    Yflat = layer.ops.xp.concatenate(Y, axis=0) if Y is not None else None
    layer.initialize(Xflat, Yflat)
|