from typing import Any, Callable, Dict, Optional, Tuple, cast

from ..compat import torch
from ..config import registry
from ..model import Model
from ..shims import PyTorchGradScaler, PyTorchShim
from ..types import ArgsKwargs, Floats3d, Padded
from ..util import (
    convert_recursive,
    is_torch_array,
    is_xp_array,
    partial,
    torch2xp,
    xp2torch,
)

@registry.layers("PyTorchRNNWrapper.v1")
|
|
def PyTorchRNNWrapper(
|
|
pytorch_model: Any,
|
|
convert_inputs: Optional[Callable] = None,
|
|
convert_outputs: Optional[Callable] = None,
|
|
) -> Model[Padded, Padded]:
|
|
"""Wrap a PyTorch RNN model for use in Thinc."""
|
|
if convert_inputs is None:
|
|
convert_inputs = convert_rnn_inputs
|
|
if convert_outputs is None:
|
|
convert_outputs = convert_rnn_outputs
|
|
return cast(
|
|
Model[Padded, Padded],
|
|
PyTorchWrapper(
|
|
pytorch_model,
|
|
convert_inputs=convert_inputs,
|
|
convert_outputs=convert_outputs,
|
|
),
|
|
)
|
|
|
|
|
|
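# A minimal usage sketch (illustrative, not part of the library; assumes
# torch is installed and the sizes are arbitrary). Note that the LSTM's
# hidden size sets the output width, so nO == nI here; a bidirectional LSTM
# would need half the width to produce the same output size.
def _example_rnn_wrapper() -> None:
    model = PyTorchRNNWrapper(torch.nn.LSTM(8, 8))
    # Build a Padded batch from two variable-length sequences of width 8.
    Xs = [model.ops.alloc2f(5, 8), model.ops.alloc2f(3, 8)]
    Xp = model.ops.list2padded(Xs)
    Yp, backprop = model(Xp, is_train=True)
    dXp = backprop(Yp)
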
@registry.layers("PyTorchWrapper.v1")
|
|
def PyTorchWrapper(
|
|
pytorch_model: Any,
|
|
convert_inputs: Optional[Callable] = None,
|
|
convert_outputs: Optional[Callable] = None,
|
|
) -> Model[Any, Any]:
|
|
"""Wrap a PyTorch model, so that it has the same API as Thinc models.
|
|
To optimize the model, you'll need to create a PyTorch optimizer and call
|
|
optimizer.step() after each batch. See examples/wrap_pytorch.py
|
|
|
|
Your PyTorch model's forward method can take arbitrary args and kwargs,
|
|
but must return either a single tensor as output or a tuple. You may find the
|
|
PyTorch register_forward_hook helpful if you need to adapt the output.
|
|
|
|
The convert functions are used to map inputs and outputs to and from your
|
|
PyTorch model. Each function should return the converted output, and a callback
|
|
to use during the backward pass. So:
|
|
|
|
Xtorch, get_dX = convert_inputs(X)
|
|
Ytorch, torch_backprop = model.shims[0](Xtorch, is_train)
|
|
Y, get_dYtorch = convert_outputs(Ytorch)
|
|
|
|
To allow maximum flexibility, the PyTorchShim expects ArgsKwargs objects
|
|
on the way into the forward and backward passed. The ArgsKwargs objects
|
|
will be passed straight into the model in the forward pass, and straight
|
|
into `torch.autograd.backward` during the backward pass.
|
|
"""
|
|
if convert_inputs is None:
|
|
convert_inputs = convert_pytorch_default_inputs
|
|
if convert_outputs is None:
|
|
convert_outputs = convert_pytorch_default_outputs
|
|
return Model(
|
|
"pytorch",
|
|
forward,
|
|
attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
|
|
shims=[PyTorchShim(pytorch_model)],
|
|
dims={"nI": None, "nO": None},
|
|
)
|
|
|
|
|
|
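# A minimal usage sketch (illustrative; the layer sizes and batch shape are
# assumptions). The wrapped module's parameters still live in PyTorch, so a
# torch optimizer (e.g. torch.optim.SGD over pytorch_model.parameters())
# performs the actual weight update after backprop.
def _example_pytorch_wrapper() -> None:
    model = PyTorchWrapper(torch.nn.Linear(4, 2))
    X = model.ops.alloc2f(8, 4)
    Y, backprop = model(X, is_train=True)
    dX = backprop(model.ops.alloc2f(8, 2))
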
@registry.layers("PyTorchWrapper.v2")
|
|
def PyTorchWrapper_v2(
|
|
pytorch_model: Any,
|
|
convert_inputs: Optional[Callable] = None,
|
|
convert_outputs: Optional[Callable] = None,
|
|
mixed_precision: bool = False,
|
|
grad_scaler: Optional[PyTorchGradScaler] = None,
|
|
device: Optional["torch.device"] = None,
|
|
) -> Model[Any, Any]:
|
|
"""Wrap a PyTorch model, so that it has the same API as Thinc models.
|
|
To optimize the model, you'll need to create a PyTorch optimizer and call
|
|
optimizer.step() after each batch. See examples/wrap_pytorch.py
|
|
|
|
Your PyTorch model's forward method can take arbitrary args and kwargs,
|
|
but must return either a single tensor as output or a tuple. You may find the
|
|
PyTorch register_forward_hook helpful if you need to adapt the output.
|
|
|
|
The convert functions are used to map inputs and outputs to and from your
|
|
PyTorch model. Each function should return the converted output, and a callback
|
|
to use during the backward pass. So:
|
|
|
|
Xtorch, get_dX = convert_inputs(X)
|
|
Ytorch, torch_backprop = model.shims[0](Xtorch, is_train)
|
|
Y, get_dYtorch = convert_outputs(Ytorch)
|
|
|
|
To allow maximum flexibility, the PyTorchShim expects ArgsKwargs objects
|
|
on the way into the forward and backward passed. The ArgsKwargs objects
|
|
will be passed straight into the model in the forward pass, and straight
|
|
into `torch.autograd.backward` during the backward pass.
|
|
|
|
mixed_precision:
|
|
Enable mixed-precision. This changes whitelisted ops to run
|
|
in half precision for better performance and lower memory use.
|
|
grad_scaler:
|
|
The gradient scaler to use for mixed-precision training. If this
|
|
argument is set to "None" and mixed precision is enabled, a gradient
|
|
scaler with the default configuration is used.
|
|
device:
|
|
The PyTorch device to run the model on. When this argument is
|
|
set to "None", the default device for the currently active Thinc
|
|
ops is used.
|
|
"""
|
|
if convert_inputs is None:
|
|
convert_inputs = convert_pytorch_default_inputs
|
|
if convert_outputs is None:
|
|
convert_outputs = convert_pytorch_default_outputs
|
|
return Model(
|
|
"pytorch",
|
|
forward,
|
|
attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
|
|
shims=[
|
|
PyTorchShim(
|
|
pytorch_model,
|
|
mixed_precision=mixed_precision,
|
|
grad_scaler=grad_scaler,
|
|
device=device,
|
|
)
|
|
],
|
|
dims={"nI": None, "nO": None},
|
|
)
|
|
|
|
|
|
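# A sketch of enabling mixed precision (illustrative: assumes a CUDA device
# with CuPy available so arrays can round-trip to the GPU; sizes are
# arbitrary). Passing an explicit grad_scaler is optional; with
# mixed_precision=True and grad_scaler=None, a default-configured scaler is
# created for you.
def _example_mixed_precision() -> None:
    model = PyTorchWrapper_v2(
        torch.nn.Linear(16, 16),
        mixed_precision=True,
        grad_scaler=PyTorchGradScaler(enabled=True),
        device=torch.device("cuda:0"),
    )
    Y, backprop = model(model.ops.alloc2f(4, 16), is_train=True)
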
@registry.layers("PyTorchWrapper.v3")
|
|
def PyTorchWrapper_v3(
|
|
pytorch_model: "torch.nn.Module",
|
|
convert_inputs: Optional[Callable] = None,
|
|
convert_outputs: Optional[Callable] = None,
|
|
mixed_precision: bool = False,
|
|
grad_scaler: Optional[PyTorchGradScaler] = None,
|
|
device: Optional["torch.device"] = None,
|
|
serialize_model: Optional[Callable[[Any], bytes]] = None,
|
|
deserialize_model: Optional[Callable[[Any, bytes, "torch.device"], Any]] = None,
|
|
) -> Model[Any, Any]:
|
|
"""Wrap a PyTorch model, so that it has the same API as Thinc models.
|
|
To optimize the model, you'll need to create a PyTorch optimizer and call
|
|
optimizer.step() after each batch. See examples/wrap_pytorch.py
|
|
|
|
Your PyTorch model's forward method can take arbitrary args and kwargs,
|
|
but must return either a single tensor or a tuple. You may find the
|
|
PyTorch register_forward_hook helpful if you need to adapt the output.
|
|
|
|
The convert functions are used to map inputs and outputs to and from your
|
|
PyTorch model. Each function should return the converted output, and a callback
|
|
to use during the backward pass. So:
|
|
|
|
Xtorch, get_dX = convert_inputs(X)
|
|
Ytorch, torch_backprop = model.shims[0](Xtorch, is_train)
|
|
Y, get_dYtorch = convert_outputs(Ytorch)
|
|
|
|
To allow maximum flexibility, the PyTorchShim expects ArgsKwargs objects
|
|
on the way into the forward and backward passed. The ArgsKwargs objects
|
|
will be passed straight into the model in the forward pass, and straight
|
|
into `torch.autograd.backward` during the backward pass.
|
|
|
|
mixed_precision:
|
|
Enable mixed-precision. This changes whitelisted ops to run
|
|
in half precision for better performance and lower memory use.
|
|
grad_scaler:
|
|
The gradient scaler to use for mixed-precision training. If this
|
|
argument is set to "None" and mixed precision is enabled, a gradient
|
|
scaler with the default configuration is used.
|
|
device:
|
|
The PyTorch device to run the model on. When this argument is
|
|
set to "None", the default device for the currently active Thinc
|
|
ops is used.
|
|
serialize_model:
|
|
Callback that receives the wrapped PyTorch model as its argument and
|
|
returns a "bytes" representation of the same. The representation should
|
|
contain all the necessary information to fully deserialize the model.
|
|
When set to "None", the default serializer serializes the model's parameters.
|
|
deserialize_model:
|
|
Callback that receives the default PyTorch model (passed to the constructor), the
|
|
serialized "bytes" representation and a PyTorch device. It should return a
|
|
fully deserialized model on the target device as its result.
|
|
When set to "None", the default deserializer deserializes the model's parameters.
|
|
"""
|
|
if convert_inputs is None:
|
|
convert_inputs = convert_pytorch_default_inputs
|
|
if convert_outputs is None:
|
|
convert_outputs = convert_pytorch_default_outputs
|
|
return Model(
|
|
"pytorch",
|
|
forward,
|
|
attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
|
|
shims=[
|
|
PyTorchShim(
|
|
pytorch_model,
|
|
mixed_precision=mixed_precision,
|
|
grad_scaler=grad_scaler,
|
|
device=device,
|
|
serialize_model=serialize_model,
|
|
deserialize_model=deserialize_model,
|
|
)
|
|
],
|
|
dims={"nI": None, "nO": None},
|
|
)
|
|
|
|
|
|
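# A sketch of custom (de)serialization callbacks (illustrative; this mirrors
# the default parameter-based behaviour and mainly demonstrates the expected
# callback signatures). The helper names are made up for the example.
def _example_custom_serialization() -> None:
    import io

    def serialize(torch_model: Any) -> bytes:
        buffer = io.BytesIO()
        torch.save(torch_model.state_dict(), buffer)
        return buffer.getvalue()

    def deserialize(torch_model: Any, data: bytes, device: "torch.device") -> Any:
        state_dict = torch.load(io.BytesIO(data), map_location=device)
        torch_model.load_state_dict(state_dict)
        return torch_model.to(device)

    model = PyTorchWrapper_v3(
        torch.nn.Linear(4, 2),
        serialize_model=serialize,
        deserialize_model=deserialize,
    )
    # Round-trip the shim through bytes to exercise both callbacks.
    data = model.shims[0].to_bytes()
    model.shims[0].from_bytes(data)
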
def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
    """Return the output of the wrapped PyTorch model for the given input,
    along with a callback to handle the backward pass.
    """
    convert_inputs = model.attrs["convert_inputs"]
    convert_outputs = model.attrs["convert_outputs"]

    Xtorch, get_dX = convert_inputs(model, X, is_train)
    Ytorch, torch_backprop = model.shims[0](Xtorch, is_train)
    Y, get_dYtorch = convert_outputs(model, (X, Ytorch), is_train)

    def backprop(dY: Any) -> Any:
        dYtorch = get_dYtorch(dY)
        dXtorch = torch_backprop(dYtorch)
        dX = get_dX(dXtorch)
        return dX

    return Y, backprop


# Default conversion functions

def convert_pytorch_default_inputs(
    model: Model, X: Any, is_train: bool
) -> Tuple[ArgsKwargs, Callable[[ArgsKwargs], Any]]:
    shim = cast(PyTorchShim, model.shims[0])
    xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train, device=shim.device)
    converted = convert_recursive(is_xp_array, xp2torch_, X)
    if isinstance(converted, ArgsKwargs):

        def reverse_conversion(dXtorch):
            return convert_recursive(is_torch_array, torch2xp, dXtorch)

        return converted, reverse_conversion
    elif isinstance(converted, dict):

        def reverse_conversion(dXtorch):
            dX = convert_recursive(is_torch_array, torch2xp, dXtorch)
            return dX.kwargs

        return ArgsKwargs(args=tuple(), kwargs=converted), reverse_conversion
    elif isinstance(converted, (tuple, list)):

        def reverse_conversion(dXtorch):
            dX = convert_recursive(is_torch_array, torch2xp, dXtorch)
            return dX.args

        return ArgsKwargs(args=tuple(converted), kwargs={}), reverse_conversion
    else:

        def reverse_conversion(dXtorch):
            dX = convert_recursive(is_torch_array, torch2xp, dXtorch)
            return dX.args[0]

        return ArgsKwargs(args=(converted,), kwargs={}), reverse_conversion

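# An illustrative check of the default input conversion (not library code;
# the dict key is arbitrary): a dict becomes keyword arguments, tuples and
# lists become positional args, and a bare array becomes the single
# positional argument.
def _example_default_input_conversion() -> None:
    model = PyTorchWrapper(torch.nn.Linear(4, 2))
    X = {"input": model.ops.alloc2f(3, 4)}
    converted, get_dX = convert_pytorch_default_inputs(model, X, is_train=False)
    assert converted.args == tuple()
    assert is_torch_array(converted.kwargs["input"])
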
def convert_pytorch_default_outputs(model: Model, X_Ytorch: Any, is_train: bool):
    shim = cast(PyTorchShim, model.shims[0])
    X, Ytorch = X_Ytorch
    Y = convert_recursive(is_torch_array, torch2xp, Ytorch)

    def reverse_conversion(dY: Any) -> ArgsKwargs:
        dYtorch = convert_recursive(
            is_xp_array, partial(xp2torch, device=shim.device), dY
        )
        return ArgsKwargs(args=((Ytorch,),), kwargs={"grad_tensors": dYtorch})

    return Y, reverse_conversion

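# Note (added for clarity): the ArgsKwargs built by reverse_conversion above
# is unpacked into `torch.autograd.backward`, so during backprop the call is
# effectively torch.autograd.backward((Ytorch,), grad_tensors=dYtorch).
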
# BiLSTM conversion functions


def convert_rnn_inputs(model: Model, Xp: Padded, is_train: bool):
    shim = cast(PyTorchShim, model.shims[0])
    size_at_t = Xp.size_at_t
    lengths = Xp.lengths
    indices = Xp.indices

    def convert_from_torch_backward(d_inputs: ArgsKwargs) -> Padded:
        dX = torch2xp(d_inputs.args[0])
        return Padded(dX, size_at_t, lengths, indices)  # type: ignore

    output = ArgsKwargs(
        args=(xp2torch(Xp.data, requires_grad=True, device=shim.device), None),
        kwargs={},
    )
    return output, convert_from_torch_backward

def convert_rnn_outputs(model: Model, inputs_outputs: Tuple, is_train):
    shim = cast(PyTorchShim, model.shims[0])
    Xp, (Ytorch, _) = inputs_outputs

    def convert_for_torch_backward(dYp: Padded) -> ArgsKwargs:
        dYtorch = xp2torch(dYp.data, requires_grad=True, device=shim.device)
        return ArgsKwargs(args=(Ytorch,), kwargs={"grad_tensors": dYtorch})

    Y = cast(Floats3d, torch2xp(Ytorch))
    Yp = Padded(Y, Xp.size_at_t, Xp.lengths, Xp.indices)
    return Yp, convert_for_torch_backward
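# Shape note (added for clarity): Padded.data is time-major, i.e. a float
# array of shape (T, B, W). PyTorch RNNs with the default batch_first=False
# consume exactly this layout, which is why the converters above can hand the
# data tensor over without transposing.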