"""Softmax activation layer that declares no parameters of its own."""
from typing import Callable, Tuple

from ..config import registry
from ..model import Model
from ..types import Floats2d

InT = Floats2d
OutT = Floats2d

# The forward softmax and its gradient must be taken over the same axis for the
# backward pass to be correct. Sharing one constant keeps the two calls in sync
# instead of relying on ``softmax``'s implicit default matching the explicit
# value passed to ``backprop_softmax``.
_SOFTMAX_AXIS = -1


@registry.layers("softmax_activation.v1")
def softmax_activation() -> Model[InT, OutT]:
    """Create a layer that applies softmax along the last axis of its input.

    The model is built only from a name and ``forward``; no parameters,
    dimensions or init function are declared here. It is registered with
    ``registry.layers`` under the name ``"softmax_activation.v1"``.
    """
    return Model("softmax_activation", forward)


def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
    """Apply softmax to ``X`` and return the output with its backprop callback.

    Args:
        model: The layer being run; only ``model.ops`` is used.
        X: Input scores (2d). Softmax is taken over the last axis, i.e. each
            row is normalized independently.
        is_train: Unused; the computation is identical in training and
            inference.

    Returns:
        A tuple ``(Y, backprop)`` where ``Y`` is the softmax of ``X`` and
        ``backprop`` maps the gradient w.r.t. ``Y`` to the gradient w.r.t. ``X``.
    """
    # inplace=False leaves X unmodified; the result Y gets its own array and is
    # retained by the closure below, since the softmax gradient is computed from
    # the output Y rather than from the input X.
    Y = model.ops.softmax(X, inplace=False, axis=_SOFTMAX_AXIS)

    def backprop(dY: OutT) -> InT:
        """Map the output gradient ``dY`` to the input gradient."""
        return model.ops.backprop_softmax(Y, dY, axis=_SOFTMAX_AXIS)

    return Y, backprop