You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
205 lines
6.8 KiB
205 lines
6.8 KiB
"""
|
|
Note [Opaque Objects]
|
|
|
|
Opaque objects are the way we allow custom operators to accept a user-defined
|
|
"black box" object as an input.
|
|
|
|
There are two kinds of opaque types: VALUE type and REFERENCE type.
|
|
The distinction determines how torch.compile handles the object.
|
|
|
|
REFERENCE TYPES (default):
|
|
|
|
Reference-typed opaque objects represent mutable stateful objects and are
treated as black boxes. Since torch.compile cannot optimize anything
(including tensors) within the object, the object must be an
input to the graph.
|
|
|
|
You can register a custom class as being a reference-based opaque object class
|
|
through `register_opaque_type(MyClass, typ="reference")`.
|
|
|
|
VALUE TYPES:
|
|
|
|
Value-typed opaque objects represent constant values.
|
|
In torch.compile, the graph specializes on the object just as it does on other
constants. Therefore, a few methods must be implemented on the class before it
can be registered as a value-typed opaque object class:
|
|
- __eq__: torch.compile will create guards based on the equality of this
|
|
object, meaning that a recompilation will happen if __eq__ returns False.
|
|
- __hash__: This must be implemented for FakeTensor caching.
|
|
- __repr__: This must be implemented because it is used in the FX graph's
  codegen to reconstruct the object. The returned string must be able to
  construct the object again through its __init__ method.
|
|
|
|
You can register a custom class as being a value-based opaque object class
through `register_opaque_type(MyClass, typ="value")`.
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Literal, NewType
|
|
from weakref import WeakKeyDictionary
|
|
|
|
import torch
|
|
|
|
from .fake_class_registry import register_fake_class
|
|
|
|
|
|
@register_fake_class("aten::OpaqueObject")
class FakeOpaqueObject:
    """
    Fake-class counterpart registered for ``aten::OpaqueObject`` via
    ``register_fake_class``.

    It carries no state. Instances are not meant to be produced through the
    ``__obj_unflatten__`` hook, which always raises.
    """

    def __init__(self) -> None:
        """Create a stateless placeholder; there is nothing to initialize."""

    @classmethod
    def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None:
        # Opaque objects are expected to be handled by a dedicated code path,
        # so reaching this hook indicates a missing special case upstream.
        message = (
            "FakeOpaqueObject should not be created through __obj_unflatten__ "
            "and should be special handled. Please file an issue to Github."
        )
        raise RuntimeError(message)
|
|
|
|
|
|
# Qualified script-class name string that corresponds to the
# "aten::OpaqueObject" class (naming pattern presumably
# "__torch__.torch.classes.<namespace>.<name>" — TODO confirm).
OpaqueTypeStr: str = "__torch__.torch.classes.aten.OpaqueObject"

# Static-typing alias for opaque script objects. typing.NewType returns its
# argument unchanged at runtime, so values are plain torch._C.ScriptObject.
OpaqueType = NewType("OpaqueType", torch._C.ScriptObject)
|
|
|
|
|
|
@dataclass()
class _OpaqueTypeInfo:
    """Bookkeeping record stored for each class passed to ``register_opaque_type``."""

    # Fully qualified "<module>.<qualname>" name the class was registered under.
    class_name: str
    # Which flavor of opaque type the class was registered as.
    opaque_typ: Literal["reference", "value"]
|
|
|
|
|
|
# Mapping of registered class -> _OpaqueTypeInfo (registered name and
# "reference"/"value" kind). Keys are held weakly, so an entry disappears
# once the class object itself is garbage collected.
_OPAQUE_TYPES: WeakKeyDictionary[Any, _OpaqueTypeInfo] = WeakKeyDictionary()
# Mapping of registered fully qualified class name -> the same _OpaqueTypeInfo
# record (it stores the name and kind, NOT the class object). This is a regular
# dict, so entries are not removed when the class is collected.
_OPAQUE_TYPES_BY_NAME: dict[str, _OpaqueTypeInfo] = {}
|
|
|
|
|
|
def get_opaque_type_name(cls: Any) -> str:
    """
    Gets the registered opaque type name for a given class.

    The name is the string recorded for the class when it was registered
    through ``register_opaque_type``.

    Args:
        cls (type): The class to get the type name for.

    Returns:
        str: The registered type name for the class.

    Raises:
        ValueError: If the class is not registered as an opaque type.
    """
    if cls not in _OPAQUE_TYPES:
        # ``cls`` may not be a class at all (e.g. an instance passed by
        # mistake). Fall back to repr() so we still raise the documented
        # ValueError rather than an AttributeError from ``cls.__name__``.
        display_name = getattr(cls, "__name__", repr(cls))
        raise ValueError(
            f"Class {cls} is not registered as an opaque type. "
            f"Call register_opaque_type({display_name}, "
            "typ='reference' or typ='value') first."
        )
    return _OPAQUE_TYPES[cls].class_name
|
|
|
|
|
|
def _check_value_opaque_type_requirements(cls: Any) -> None:
    """
    Raise TypeError unless ``cls`` provides what a value-type opaque class needs.

    torch.compile guards on ``__eq__``, relies on ``__hash__`` for FakeTensor
    caching, and uses ``__repr__`` to reconstruct the object in FX codegen.
    """
    if cls.__eq__ is object.__eq__:  # type: ignore[comparison-overlap]
        raise TypeError(
            f"Value-type opaque object of type {cls} is "
            "expected to have a non-default `__eq__` "
            "implementation as we will use this in torch.compile "
            "to guard on the equality of objects."
        )

    # Class with a custom `__eq__` without `__hash__` won't inherit the default
    # `__hash__` from object; see https://stackoverflow.com/a/1608907.
    if cls.__hash__ is None:  # type: ignore[comparison-overlap]
        raise TypeError(
            f"Value-type opaque object of type {cls} is "
            "expected to have a non-default `__hash__` "
            "implementation as we will use this in torch.compile "
            "for FakeTensor caching."
        )

    if cls.__repr__ is object.__repr__:  # type: ignore[comparison-overlap]
        raise TypeError(
            f"Value-type opaque object of type {cls} is "
            "expected to have a non-default `__repr__` "
            "implementation as we will use this to reconstruct "
            "the object in the FX codegen."
        )


def register_opaque_type(cls: Any, *, typ: str) -> None:
    """
    Registers the given type as an opaque type which allows this to be consumed
    by a custom operator.

    The type name will be automatically generated from the class's fully
    qualified name (ex. my_module.MyClass).

    Args:
        cls (type): The class to register as an opaque type.
        typ (str): Either "reference" or "value". See Note [Opaque Objects] for
            more details.

    Raises:
        ValueError: If ``cls`` is a built-in type, ``torch.Tensor``, or is
            registered as a pytree node.
        AssertionError: If ``typ`` is neither "reference" nor "value".
        TypeError: If ``typ == "value"`` and ``cls`` lacks a custom ``__eq__``,
            a ``__hash__``, or a custom ``__repr__``.
    """
    # Imported locally; presumably to avoid a circular import at module load
    # time (unverified).
    import torch.utils._pytree as pytree

    # Prevent registration of built-in types (int, str, list, dict, etc.) and torch.Tensor
    if cls.__module__ == "builtins" or cls is torch.Tensor:
        raise ValueError(
            f"Unable to register built-in type {cls} as an opaque type. "
            "Please wrap it in a custom class and register the custom class as opaque."
        )

    if cls in pytree.SUPPORTED_NODES:
        raise ValueError(
            f"{cls} cannot be registered as an opaque object as it has been "
            "registered as a pytree. Opaque objects must be pytree leaves."
        )

    # An explicit raise (instead of `assert`) keeps this validation active under
    # `python -O`; AssertionError is preserved so callers see the same type.
    if typ not in ("reference", "value"):
        raise AssertionError("Opaque type must be either 'reference' or 'value'")

    if typ == "value":
        _check_value_opaque_type_requirements(cls)

    # Generate a fully qualified name by combining module and qualname
    name = f"{cls.__module__}.{cls.__qualname__}"

    # The same record is shared by the class-keyed and the name-keyed registry.
    type_info = _OpaqueTypeInfo(name, typ)
    _OPAQUE_TYPES[cls] = type_info
    _OPAQUE_TYPES_BY_NAME[name] = type_info

    torch._C._register_opaque_type(name)
|
|
|
|
|
|
def is_opaque_type(cls: Any) -> bool:
    """
    Report whether ``cls`` is a registered opaque type.

    ``cls`` may be a class, or the fully qualified name string it was
    registered under. A string is checked directly against the
    ``torch._C`` registry. A class must first be known to the Python-side
    registry, and then its recorded name is checked against ``torch._C``.
    """
    if isinstance(cls, str):
        registered_name = cls
    elif cls in _OPAQUE_TYPES:
        registered_name = _OPAQUE_TYPES[cls].class_name
    else:
        return False
    return torch._C._is_opaque_type_registered(registered_name)
|
|
|
|
|
|
def is_opaque_value_type(cls: Any) -> bool:
    """
    Checks if the given type is an opaque **value** type.

    See Note [Opaque Objects] for more information.

    Args:
        cls: A class, or the fully qualified name string under which a class
            was registered.

    Returns:
        bool: True only if ``cls`` is a registered opaque type whose recorded
        kind is "value".
    """
    if not is_opaque_type(cls):
        return False

    if isinstance(cls, str):
        # For strings, is_opaque_type() only consults the torch._C registry,
        # which may contain names never recorded in the Python-side map (e.g.
        # if torch._C._register_opaque_type was called directly). Treat such
        # names as "not a value type" instead of raising KeyError.
        info = _OPAQUE_TYPES_BY_NAME.get(cls)
    else:
        info = _OPAQUE_TYPES[cls]
    return info is not None and info.opaque_typ == "value"
|
|
|
|
|
|
def is_opaque_reference_type(cls: Any) -> bool:
    """
    Checks if the given type is an opaque **reference** type.

    See Note [Opaque Objects] for more information.

    Args:
        cls: A class, or the fully qualified name string under which a class
            was registered.

    Returns:
        bool: True only if ``cls`` is a registered opaque type whose recorded
        kind is "reference".
    """
    if not is_opaque_type(cls):
        return False

    if isinstance(cls, str):
        # For strings, is_opaque_type() only consults the torch._C registry,
        # which may contain names never recorded in the Python-side map (e.g.
        # if torch._C._register_opaque_type was called directly). Treat such
        # names as "not a reference type" instead of raising KeyError.
        info = _OPAQUE_TYPES_BY_NAME.get(cls)
    else:
        info = _OPAQUE_TYPES[cls]
    return info is not None and info.opaque_typ == "reference"
|