You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

59 lines
1.4 KiB

from typing import Any, Callable, Tuple
import numpy
from thinc.backends import Ops
from ..config import registry
from ..model import Model
@registry.layers("with_cpu.v1")
def with_cpu(layer: Model, ops: Ops) -> Model:
layer.to_cpu()
return Model(
f"with_cpu({layer.name})",
forward,
layers=[layer],
ops=ops,
init=init,
dims={name: layer.maybe_get_dim(name) for name in layer.dim_names},
)
def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
cpu_outputs, backprop = model.layers[0].begin_update(_to_cpu(X))
gpu_outputs = _to_device(model.ops, cpu_outputs)
def with_cpu_backprop(d_outputs):
cpu_d_outputs = _to_cpu(d_outputs)
return backprop(cpu_d_outputs)
return gpu_outputs, with_cpu_backprop
def init(model: Model, X: Any, Y: Any) -> None:
model.layers[0].initialize(X, Y)
def _to_cpu(X):
if isinstance(X, numpy.ndarray):
return X
elif isinstance(X, tuple):
return tuple([_to_cpu(x) for x in X])
elif isinstance(X, list):
return [_to_cpu(x) for x in X]
elif hasattr(X, "get"):
return X.get()
else:
return X
def _to_device(ops, X):
if isinstance(X, tuple):
return tuple([_to_device(ops, x) for x in X])
elif isinstance(X, list):
return [_to_device(ops, x) for x in X]
else:
return ops.asarray(X)