You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
23 lines
542 B
23 lines
542 B
from typing import Callable, Tuple
|
|
|
|
from ..config import registry
|
|
from ..model import Model
|
|
from ..types import Floats2d
|
|
|
|
# Input type of the layer: `forward` annotates its `X` argument (and the
# gradient returned by its backprop callback) with this alias of Floats2d.
InT = Floats2d
|
|
# Output type of the layer: `forward` annotates its returned activations (and
# the incoming gradient `dY` of its backprop callback) with this alias of
# Floats2d.
OutT = Floats2d
|
|
|
|
|
|
@registry.layers("softmax_activation.v1")
def softmax_activation() -> Model[InT, OutT]:
    """Create the softmax activation layer.

    Takes no arguments and returns a ``Model`` named "softmax_activation"
    whose forward function is the module-level ``forward``. The factory is
    registered with ``registry.layers`` under "softmax_activation.v1".
    """
    layer: Model[InT, OutT] = Model("softmax_activation", forward)
    return layer
|
|
|
|
|
|
def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
    """Run softmax over ``X`` and return the result with its backprop callback.

    model: The layer; only its ``ops`` backend is used here.
    X: The input array handed to ``model.ops.softmax``.
    is_train: Part of the forward signature; not read by this function.
    RETURNS: A ``(Y, backprop)`` pair, where ``backprop(dY)`` returns
        ``model.ops.backprop_softmax(Y, dY, axis=-1)``.
    """
    # inplace=False asks the backend for a non-in-place computation
    # (presumably leaving X unmodified -- confirm against the ops backend).
    probs = model.ops.softmax(X, inplace=False)

    def backprop(dY: OutT) -> InT:
        # Closes over the forward output; `model.ops` is looked up again at
        # call time rather than captured during the forward pass.
        return model.ops.backprop_softmax(probs, dY, axis=-1)

    return probs, backprop
|