You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

45 lines
1.3 KiB

from typing import Any, Callable, Optional, Tuple, TypeVar
from ..model import Model
_ModelT = TypeVar("_ModelT", bound=Model)
do_nothing = lambda *args, **kwargs: None
def with_debug(
layer: _ModelT,
name: Optional[str] = None,
*,
on_init: Callable[[Model, Any, Any], None] = do_nothing,
on_forward: Callable[[Model, Any, bool], None] = do_nothing,
on_backprop: Callable[[Any], None] = do_nothing,
) -> _ModelT:
"""Debugging layer that wraps any layer and allows executing callbacks
during the forward pass, backward pass and initialization. The callbacks
will receive the same arguments as the functions they're called in.
"""
name = layer.name if name is None else name
orig_forward = layer._func
orig_init = layer.init
def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
on_forward(model, X, is_train)
layer_Y, layer_callback = orig_forward(layer, X, is_train=is_train)
def backprop(dY: Any) -> Any:
on_backprop(dY)
return layer_callback(dY)
return layer_Y, backprop
def init(model: Model, X: Any, Y: Any) -> None:
on_init(model, X, Y)
if orig_init is not None:
orig_init(layer, X, Y)
layer.replace_callbacks(forward, init=init)
return layer