You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

54 lines
1.6 KiB

from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, cast
from ..config import registry
from ..model import Model
from ..types import ArrayXd, ListXd
# Type of a single item inside one of the inner sequences.
ItemT = TypeVar("ItemT")

# Wrapped-layer side: one flat sequence of items in, one array out.
InnerInT = Sequence[ItemT]
InnerOutT = ArrayXd

# Combinator side: a sequence of inner sequences in, a ListXd out
# (one output piece per inner sequence).
InT = Sequence[InnerInT]
OutT = ListXd
@registry.layers("with_flatten.v1")
def with_flatten(layer: Model[InnerInT[ItemT], InnerOutT]) -> Model[InT[ItemT], OutT]:
    """Wrap ``layer`` so it can be applied to a sequence of sequences.

    The returned model flattens its nested input into one sequence of items,
    runs ``layer`` on it in a single call, and splits the result back into one
    output per inner sequence.

    Args:
        layer: The model applied to the flattened items.

    Returns:
        A new model named ``with_flatten(<layer name>)`` whose only child is
        ``layer`` and which uses this module's ``forward`` and ``init``.
    """
    wrapped_name = f"with_flatten({layer.name})"
    return Model(wrapped_name, forward, init=init, layers=[layer])
def forward(
    model: Model[InT, OutT], Xnest: InT, is_train: bool
) -> Tuple[OutT, Callable]:
    """Apply the wrapped layer to all inner sequences in a single call.

    The inner sequences of ``Xnest`` are concatenated into one flat sequence,
    the wrapped layer runs once on it, and its output is cut back into one
    piece per inner sequence along the first axis.

    Args:
        model: The ``with_flatten`` model; its first child layer is applied.
        Xnest: A sequence of inner sequences of items.
        is_train: Passed through unchanged to the wrapped layer.

    Returns:
        A ``(Ynest, backprop)`` pair. ``Ynest`` holds one output slice per
        inner sequence of ``Xnest`` (an empty list when ``Xnest`` is empty).
        ``backprop`` maps per-sequence output gradients to per-sequence input
        gradients, split the same way.
    """
    layer: Model[InnerInT, InnerOutT] = model.layers[0]
    Xflat = _flatten(Xnest)
    Yflat, backprop_layer = layer(Xflat, is_train)
    # Number of items each inner sequence contributed to ``Xflat``. Slicing by
    # explicit lengths (instead of ``xp.split`` at n-1 cut points) yields
    # exactly one piece per inner sequence, and zero pieces for an empty
    # ``Xnest``, where ``xp.split`` would have returned ``[Yflat]``.
    lengths = [len(x) for x in Xnest]
    Ynest = _split_by_lengths(Yflat, lengths)

    def backprop(dYnest: OutT) -> InT:
        dYflat = model.ops.flatten(dYnest)  # type: ignore[arg-type, var-annotated]
        # type ignore necessary for older versions of Mypy/Pydantic
        dXflat = backprop_layer(dYflat)
        # The input gradient corresponds to ``Xflat`` item by item, i.e. along
        # the first axis, so split it exactly like the forward output.
        # Splitting along the last axis would cut feature columns of a
        # multi-dimensional gradient instead of per-sequence rows.
        dXnest = _split_by_lengths(dXflat, lengths)
        return dXnest

    return Ynest, backprop


def _split_by_lengths(flat: Any, lengths: List[int]) -> List[Any]:
    """Cut ``flat`` along its first axis into consecutive pieces of ``lengths``.

    Uses plain slicing, so it works for arrays and lists alike and does not
    coerce a list-valued gradient into an array.
    """
    pieces: List[Any] = []
    start = 0
    for length in lengths:
        end = start + length
        pieces.append(flat[start:end])
        start = end
    return pieces
def _flatten(nested: InT) -> InnerInT:
    """Concatenate the inner sequences of ``nested`` into one flat list.

    Order is preserved: all items of the first inner sequence, then all items
    of the second, and so on.
    """
    flat: List = [item for inner in nested for item in inner]
    # Cast to the *flat* alias: the result is a single sequence of items,
    # not a sequence of sequences, matching the declared return type.
    return cast(InnerInT, flat)
def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    """Initialize the wrapped layer from flattened sample data.

    Args:
        model: The ``with_flatten`` model; its first child layer is
            initialized.
        X: Optional sample input, a sequence of inner sequences. It is
            flattened into one sequence of items, the same way ``forward``
            feeds the wrapped layer.
        Y: Optional sample output, one array per inner sequence. The arrays
            are concatenated along the first axis to form the flat sample
            output of the wrapped layer.
    """
    layer = model.layers[0]
    Xflat = _flatten(X) if X is not None else None
    # Join along axis 0, mirroring ``forward``, which splits the wrapped
    # layer's output along axis 0. ``hstack`` agrees with this only for 1-D
    # arrays; for 2-D outputs it would join along the feature axis instead.
    Yflat = layer.ops.xp.concatenate(Y, axis=0) if Y is not None else None
    layer.initialize(Xflat, Yflat)