You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

94 lines
3.2 KiB

from typing import Any, Callable, Optional
from ..compat import torch
from ..model import Model
from ..shims import PyTorchGradScaler, PyTorchShim, TorchScriptShim
from .pytorchwrapper import (
convert_pytorch_default_inputs,
convert_pytorch_default_outputs,
forward,
)
def TorchScriptWrapper_v1(
    torchscript_model: Optional["torch.jit.ScriptModule"] = None,
    convert_inputs: Optional[Callable] = None,
    convert_outputs: Optional[Callable] = None,
    mixed_precision: bool = False,
    grad_scaler: Optional[PyTorchGradScaler] = None,
    device: Optional["torch.device"] = None,
) -> Model[Any, Any]:
    """Wrap a TorchScript model so that it exposes the same API as Thinc models.

    torchscript_model:
        The TorchScript module to wrap. Passing `None` is also allowed; it
        constructs a shim that a serialized model can be deserialized into.
    convert_inputs:
        Function that converts the inputs and gradients passed to the model
        into Torch tensors. When `None`, `convert_pytorch_default_inputs`
        is used.
    convert_outputs:
        Function that converts model outputs and gradients from Torch
        tensors to Thinc arrays. When `None`,
        `convert_pytorch_default_outputs` is used.
    mixed_precision:
        Enable mixed precision. This changes whitelisted ops to run in half
        precision for better performance and lower memory use.
    grad_scaler:
        The gradient scaler to use for mixed-precision training. If this
        argument is `None` and mixed precision is enabled, a gradient scaler
        with the default configuration is used.
    device:
        The PyTorch device to run the model on. When this argument is
        `None`, the default device for the currently active Thinc ops is
        used.

    Returns a Thinc `Model` named "pytorch_script" that runs the shared
    PyTorch-wrapper `forward` through a single `TorchScriptShim`. Its `nI`
    and `nO` dimensions start out unset.
    """
    # Fall back to the same default converters that the PyTorch wrapper uses.
    input_converter = (
        convert_pytorch_default_inputs if convert_inputs is None else convert_inputs
    )
    output_converter = (
        convert_pytorch_default_outputs if convert_outputs is None else convert_outputs
    )

    script_shim = TorchScriptShim(
        model=torchscript_model,
        mixed_precision=mixed_precision,
        grad_scaler=grad_scaler,
        device=device,
    )

    converters = {
        "convert_inputs": input_converter,
        "convert_outputs": output_converter,
    }
    return Model(
        "pytorch_script",
        forward,
        attrs=converters,
        shims=[script_shim],
        dims={"nI": None, "nO": None},
    )
def pytorch_to_torchscript_wrapper(model: Model) -> Model[Any, Any]:
    """Convert a PyTorch wrapper to a TorchScript wrapper.

    The PyTorch `Module` held by the wrapper's first shim is compiled to a
    `ScriptModule` with `torch.jit.script`. The compiled module is then
    wrapped with `TorchScriptWrapper_v1`. The original wrapper's
    input/output converters (read from `model.attrs`) are carried over. So
    are its shim's mixed-precision flag, gradient scaler and device. No other
    attrs or dims of `model` are copied.

    model:
        A PyTorch wrapper model. Its first shim must be a `PyTorchShim`, and
        its attrs must contain "convert_inputs" and "convert_outputs".

    Returns the new TorchScript wrapper model.

    Raises ValueError if the model has no shims, if its first shim is not a
    `PyTorchShim`, or if that shim does not wrap a `torch.nn.Module`.
    """
    shims = model.shims
    # Guard the empty case explicitly. Indexing an empty shim list would
    # otherwise surface as a bare IndexError instead of this descriptive error.
    if not shims or not isinstance(shims[0], PyTorchShim):
        raise ValueError("Expected PyTorchShim when converting a PyTorch wrapper")
    shim = shims[0]

    convert_inputs = model.attrs["convert_inputs"]
    convert_outputs = model.attrs["convert_outputs"]

    pytorch_model = shim._model
    if not isinstance(pytorch_model, torch.nn.Module):
        raise ValueError("PyTorchShim does not wrap a PyTorch module")

    # Compile the eager module into a TorchScript ScriptModule.
    torchscript_model = torch.jit.script(pytorch_model)

    # Carry the shim's runtime configuration over to the new wrapper.
    grad_scaler = shim._grad_scaler
    mixed_precision = shim._mixed_precision
    device = shim.device
    return TorchScriptWrapper_v1(
        torchscript_model,
        convert_inputs=convert_inputs,
        convert_outputs=convert_outputs,
        mixed_precision=mixed_precision,
        grad_scaler=grad_scaler,
        device=device,
    )