from enum import Enum from typing import Optional import torch class EffectType(Enum): ORDERED = "Ordered" from torch._library.utils import RegistrationHandle # These classes do not have side effects as they just store quantization # params, so we dont need to mark them as ordered skip_classes = ( "__torch__.torch.classes.quantized.Conv2dPackedParamsBase", "__torch__.torch.classes.quantized.Conv3dPackedParamsBase", "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase", "__torch__.torch.classes.quantized.LinearPackedParamsBase", "__torch__.torch.classes.xnnpack.Conv2dOpContext", "__torch__.torch.classes.xnnpack.LinearOpContext", "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext", ) class EffectHolder: """A holder where one can register an effect impl to.""" def __init__(self, qualname: str): self.qualname: str = qualname self._set_default_effect() def _set_default_effect(self) -> None: self._effect: Optional[EffectType] = None # If the op contains a ScriptObject input, we want to mark it as having effects namespace, opname = torch._library.utils.parse_namespace(self.qualname) split = opname.split(".") if len(split) > 1: assert len(split) == 2, ( f"Tried to split {opname} based on '.' but found more than 1 '.'" ) opname, overload = split else: overload = "" if namespace == "higher_order": return opname = f"{namespace}::{opname}" if torch._C._get_operation_overload(opname, overload) is not None: # Since we call this when destroying the library, sometimes the # schema will be gone already at that time. schema = torch._C._get_schema(opname, overload) for arg in schema.arguments: if isinstance(arg.type, torch.ClassType): type_str = arg.type.str() # pyrefly: ignore[missing-attribute] if type_str in skip_classes: continue self._effect = EffectType.ORDERED return @property def effect(self) -> Optional[EffectType]: return self._effect @effect.setter def effect(self, _): raise RuntimeError("Unable to directly set kernel.") def register(self, effect: Optional[EffectType]) -> RegistrationHandle: """Register an effect Returns a RegistrationHandle that one can use to de-register this effect. """ self._effect = effect def deregister_effect(): self._set_default_effect() handle = RegistrationHandle(deregister_effect) return handle