You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

68 lines
2.4 KiB

from io import BytesIO
from typing import Any, Optional
import srsly
from ..compat import torch
from ..util import get_torch_default_device
from .pytorch import PyTorchShim
from .pytorch_grad_scaler import PyTorchGradScaler
class TorchScriptShim(PyTorchShim):
"""A Thinc shim that wraps a TorchScript module.
model:
The TorchScript module. A value of `None` is also possible to
construct a shim to deserialize into.
mixed_precision:
Enable mixed-precision. This changes whitelisted ops to run
in half precision for better performance and lower memory use.
grad_scaler:
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""
def __init__(
self,
model: Optional["torch.jit.ScriptModule"],
config=None,
optimizer: Any = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
):
if model is not None and not isinstance(model, torch.jit.ScriptModule):
raise ValueError(
"PyTorchScriptShim must be initialized with ScriptModule or None (for deserialization)"
)
super().__init__(model, config, optimizer, mixed_precision, grad_scaler, device)
def to_bytes(self):
filelike = BytesIO()
torch.jit.save(self._model, filelike)
filelike.seek(0)
model_bytes = filelike.getvalue()
msg = {"config": self.cfg, "model": model_bytes}
return srsly.msgpack_dumps(msg)
def from_bytes(self, bytes_data):
device = get_torch_default_device()
msg = srsly.msgpack_loads(bytes_data)
self.cfg = msg["config"]
filelike = BytesIO(msg["model"])
filelike.seek(0)
# As of Torch 2.0.0, loading TorchScript models directly to
# an MPS device is not supported.
map_location = torch.device("cpu") if device.type == "mps" else device
self._model = torch.jit.load(filelike, map_location=map_location)
self._model.to(device)
self._grad_scaler.to_(device)
return self