You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

85 lines
2.7 KiB

from enum import Enum
from typing import Optional
import torch
class EffectType(Enum):
    """Kinds of side effect that can be attached to an operator."""

    # The only effect kind defined here; its value is the string "Ordered".
    ORDERED = "Ordered"
from torch._library.utils import RegistrationHandle
# ScriptObject classes listed here only store quantization params and so have
# no side effects; an argument of one of these types does not need to cause
# its op to be marked as ordered.
skip_classes = (
    "__torch__.torch.classes.quantized.Conv2dPackedParamsBase",
    "__torch__.torch.classes.quantized.Conv3dPackedParamsBase",
    "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase",
    "__torch__.torch.classes.quantized.LinearPackedParamsBase",
    "__torch__.torch.classes.xnnpack.Conv2dOpContext",
    "__torch__.torch.classes.xnnpack.LinearOpContext",
    "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext",
)
class EffectHolder:
    """A holder where one can register an effect impl to.

    Each holder is bound to one operator qualname. On construction it
    computes a default effect from the operator's schema; ``register``
    overrides that effect, and the handle it returns restores the default.
    """

    def __init__(self, qualname: str):
        """Store ``qualname`` and compute the default effect for it."""
        self.qualname: str = qualname
        self._set_default_effect()

    def _set_default_effect(self) -> None:
        """Reset ``self._effect`` to the default derived from the op schema.

        The default is ``EffectType.ORDERED`` when the operator's schema has
        at least one ``torch.ClassType`` argument whose type string is not
        listed in ``skip_classes``. It stays ``None`` otherwise, including
        for ops in the ``higher_order`` namespace and for ops whose overload
        can no longer be looked up.

        Raises:
            AssertionError: if the op name contains more than one ``'.'``.
        """
        self._effect: Optional[EffectType] = None

        # If the op contains a ScriptObject input, we want to mark it as having effects
        namespace, opname = torch._library.utils.parse_namespace(self.qualname)

        # The part after the namespace is either "name" or "name.overload".
        split = opname.split(".")
        if len(split) > 1:
            assert len(split) == 2, (
                f"Tried to split {opname} based on '.' but found more than 1 '.'"
            )
            opname, overload = split
        else:
            overload = ""

        if namespace == "higher_order":
            return

        opname = f"{namespace}::{opname}"
        if torch._C._get_operation_overload(opname, overload) is not None:
            # Since we call this when destroying the library, sometimes the
            # schema will be gone already at that time.
            schema = torch._C._get_schema(opname, overload)
            for arg in schema.arguments:
                if isinstance(arg.type, torch.ClassType):
                    type_str = arg.type.str()  # pyrefly: ignore[missing-attribute]
                    if type_str in skip_classes:
                        # Side-effect-free class; keep checking the other args.
                        continue
                    # One qualifying ScriptObject argument is enough.
                    self._effect = EffectType.ORDERED
                    return

    @property
    def effect(self) -> Optional[EffectType]:
        """The currently active effect, or ``None`` if there is none."""
        return self._effect

    @effect.setter
    def effect(self, _):
        """Reject direct assignment; effects are changed via ``register``."""
        # The message names the attribute actually being set ("effect").
        raise RuntimeError("Unable to directly set effect.")

    def register(self, effect: Optional[EffectType]) -> RegistrationHandle:
        """Register an effect

        Returns a RegistrationHandle that one can use to de-register this
        effect.
        """
        self._effect = effect

        def deregister_effect():
            # De-registering falls back to the schema-derived default rather
            # than to whatever effect was set before this registration.
            self._set_default_effect()

        handle = RegistrationHandle(deregister_effect)
        return handle