""" Dictionary-related variable tracking classes for PyTorch Dynamo. This module implements variable tracking for different types of dictionary-like objects: - Regular Python dictionaries (dict) - Ordered dictionaries (collections.OrderedDict) - Default dictionaries (collections.defaultdict) - Dictionary views (keys and values) - Sets and frozensets (implemented internally using dictionaries) These classes are responsible for tracking dictionary operations during graph compilation, maintaining proper guards for dictionary mutations and key existence checks. They handle dictionary creation, modification, key/value access, and view operations while ensuring correct behavior in the compiled code through appropriate guard installation. The implementation uses a special _HashableTracker wrapper to handle dictionary keys while preserving proper aliasing semantics. Sets are implemented as dictionaries with None values for efficiency and code reuse. """ import collections import functools import operator import types from collections.abc import Sequence from typing import Any, Optional, TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard from ..source import is_constant_source, is_from_local_source from ..utils import ( cmp_name_to_op_mapping, dict_items, dict_keys, dict_values, istype, raise_args_mismatch, specialize_symnode, ) from .base import ValueMutationNew, VariableTracker from .constant import ConstantVariable from .lists import ListIteratorVariable if TYPE_CHECKING: from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator from .functions import UserFunctionVariable # [Adding a new supported class within the keys of ConstDictVariable] # - Implement is_python_hashable() method in the VariableTracker subclass # - Implement get_python_hash() and is_python_equal() methods for hashable types def was_instancecheck_override(obj: Any) -> bool: return type(obj).__dict__.get("__instancecheck__", False) def raise_unhashable( arg: VariableTracker, tx: Optional["InstructionTranslator"] = None ) -> None: if tx is None: from torch._dynamo.symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() try: arg_type = arg.python_type() except Exception: arg_type = type(arg) raise_observed_exception( TypeError, tx, args=[ ConstantVariable( f"unhashable type: {arg_type!r} and variable tracker = {type(arg.realize())}" ) ], ) def is_hashable(x: VariableTracker) -> bool: # NB - performing isinstance check on a LazVT realizes the VT, accidentally # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at # the underlying value without realizing the VT. Consider updating the # lazyVT `is_hashable` method if you see unnecessary guarding for a key VT. if ( isinstance(x, variables.LazyVariableTracker) and not x.is_realized() and x.is_hashable() ): return True return x.is_python_hashable() class ConstDictVariable(VariableTracker): CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS _nonvar_fields = { "user_cls", *VariableTracker._nonvar_fields, } class _HashableTracker: """ Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable This should not be seen or touched by anything outside of ConstDictVariable and its children Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing """ def __init__(self, vt: VariableTracker) -> None: # We specialize SymNodes vt = specialize_symnode(vt) # If Dynamo does not know the hashability of the vt, it will raise unsupported here if not is_hashable(vt): raise_unhashable(vt) self.vt = vt def __hash__(self) -> int: """ Computes the hash value for the wrapped VariableTracker. For unrealized LazyVariableTrackers, uses the hash of the original value to avoid realizing the tracker and inserting unnecessary guards. For all other cases, delegates to the VariableTracker's get_python_hash method. Returns: The hash value of the underlying variable tracker """ if ( isinstance(self.vt, variables.LazyVariableTracker) and not self.vt.is_realized() and self.vt.is_hashable() ): return hash(self.vt.original_value()) return self.vt.get_python_hash() def __eq__(self, other) -> bool: """ Checks equality between two _HashableTracker instances. Delegates to the VariableTracker's is_python_equal method to compare the underlying variable trackers for Python-level equality. Args: other: Another _HashableTracker instance to compare with Returns: True if the underlying variable trackers are Python-equal, False otherwise """ if self.vt is other.vt: return True return self.vt.is_python_equal(other.vt) def __init__( self, items: dict[VariableTracker, VariableTracker], user_cls: type = dict, **kwargs: Any, ) -> None: # .clone() pass these arguments in kwargs but they're recreated a few # lines below if "original_items" in kwargs: kwargs.pop("original_items") if "should_reconstruct_all" in kwargs: kwargs.pop("should_reconstruct_all") super().__init__(**kwargs) Hashable = ConstDictVariable._HashableTracker # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers assert all( isinstance(x, (VariableTracker, Hashable)) and isinstance(v, VariableTracker) for x, v in items.items() ) def make_hashable( key: Union[VariableTracker, "ConstDictVariable._HashableTracker"], ) -> "ConstDictVariable._HashableTracker": return key if isinstance(key, Hashable) else Hashable(key) dict_cls = self._get_dict_cls_from_user_cls(user_cls) self.items = dict_cls({make_hashable(x): v for x, v in items.items()}) # need to reconstruct everything if the dictionary is an intermediate value # or if a pop/delitem was executed self.should_reconstruct_all = ( not is_from_local_source(self.source) if self.source else True ) self.original_items = items.copy() self.user_cls = user_cls def _get_dict_cls_from_user_cls(self, user_cls: type) -> type: accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict) # avoid executing user code if user_cls is a dict subclass if user_cls in accepted_dict_types: dict_cls = user_cls else: # dict_cls = next( base for base in user_cls.__mro__ if base in accepted_dict_types ) assert dict_cls in accepted_dict_types, dict_cls # Use a dict instead as the call "defaultdict({make_hashable(x): v ..})" # would fail as defaultdict expects a callable as first argument if dict_cls is collections.defaultdict: dict_cls = dict return dict_cls def as_proxy(self) -> dict[Any, Any]: return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} def debug_repr(self) -> str: return ( "{" + ", ".join( f"{k.vt.debug_repr()}: {v.debug_repr()}" for k, v in self.items.items() ) + "}" ) def as_python_constant(self) -> dict[Any, Any]: return { k.vt.as_python_constant(): v.as_python_constant() for k, v in self.items.items() } def keys_as_python_constant(self) -> dict[Any, VariableTracker]: self.install_dict_keys_match_guard() return {k.vt.as_python_constant(): v for k, v in self.items.items()} def python_type(self) -> type: return self.user_cls def __contains__(self, vt: VariableTracker) -> bool: assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker return ( vt.is_python_hashable() and Hashable(vt) in self.items and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) ) def call_tree_map_branch( self, tx: "InstructionTranslator", tree_map_fn: "UserFunctionVariable", map_fn: VariableTracker, rest: Sequence[VariableTracker], tree_map_kwargs: dict[str, VariableTracker], ) -> VariableTracker: other_dicts: list[ConstDictVariable] = [] for candidate in rest: candidate = candidate.realize() if not isinstance(candidate, ConstDictVariable) or len( candidate.items ) != len(self.items): return self._tree_map_fallback( tx, tree_map_fn, map_fn, rest, tree_map_kwargs ) other_dicts.append(candidate) new_items_hashed = type(self.items)() for key_tracker, value in self.items.items(): sibling_leaves: list[VariableTracker] = [] for candidate in other_dicts: try: sibling_leaves.append(candidate.items[key_tracker]) except KeyError: return self._tree_map_fallback( tx, tree_map_fn, map_fn, rest, tree_map_kwargs ) new_items_hashed[key_tracker] = value.call_tree_map( tx, tree_map_fn, map_fn, sibling_leaves, tree_map_kwargs, ) updated_original_items = { key_tracker.vt: new_items_hashed[key_tracker] for key_tracker in new_items_hashed } return self.clone( items=new_items_hashed, original_items=updated_original_items, should_reconstruct_all=True, source=None, mutation_type=ValueMutationNew(), ) def len(self) -> int: return sum( not isinstance(x, variables.DeletedVariable) for x in self.items.values() ) def has_new_items(self) -> bool: return self.should_reconstruct_all or any( self.is_new_item(self.original_items.get(key.vt), value) for key, value in self.items.items() ) def is_new_item( self, value: Optional[VariableTracker], other: VariableTracker ) -> bool: # compare the id of the realized values if both values are not lazy VTs if value and value.is_realized() and other.is_realized(): return id(value.realize()) != id(other.realize()) return id(value) != id(other) def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None: # Build a dictionary that contains the keys and values. num_args = 0 for key, value in self.items.items(): # We can safely call realize() here as it won't introduce any new guards item = self.original_items.get(key.vt) if self.is_new_item(item, value) or self.should_reconstruct_all: codegen(key.vt) codegen(value) num_args += 1 codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) def reconstruct(self, codegen: "PyCodegen") -> None: if self.user_cls is collections.OrderedDict: # emit `OrderedDict(constructed_dict)` codegen.add_push_null( lambda: codegen.extend_output( [ codegen.create_load_python_module(collections), codegen.create_load_attr("OrderedDict"), ] ) ) self.reconstruct_kvs_into_new_dict(codegen) codegen.extend_output(create_call_function(1, False)) else: self.reconstruct_kvs_into_new_dict(codegen) def getitem_const_raise_exception_if_absent( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: try: error_message = ( f"Dict key lookup failed for {str(arg)}. " f"Debug representation of the key is {arg.debug_repr()!r}" ) except Exception: error_message = ConstantVariable.create( f"Dict key lookup failed for {str(arg)}" ) raise_observed_exception(KeyError, tx, args=[error_message]) return self.items[key] def getitem_const( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined] unimplemented( gb_type="key not found in dict", context=f"Key {arg.value}", # type: ignore[attr-defined] explanation=msg, hints=[ "Check if the key exists in the dictionary before accessing it.", *graph_break_hints.USER_ERROR, ], ) return self.items[key] def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: return None return self.items[key] def realize_key_vt(self, arg: VariableTracker) -> None: # Realize the LazyVT on a particular index assert arg in self key = ConstDictVariable._HashableTracker(arg) index = tuple(self.items.keys()).index(key) original_key_vt = tuple(self.original_items.keys())[index] if isinstance(original_key_vt, variables.LazyVariableTracker): original_key_vt.realize() def install_dict_keys_match_guard(self) -> None: if self.source: install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH)) def install_dict_contains_guard( self, tx: "InstructionTranslator", args: list[VariableTracker] ) -> None: # Key guarding - These are the cases to consider # 1) The dict has been mutated. In this case, we would have already # inserted a DICT_KEYS_MATCH guard, so we can skip. # # 2) args[0].source is None. This happens for const keys. Here, we # have to insert the DICT_CONTAINS guard. # # 3) args[0].source is not None. This can happen for non-const VTs. # 3a) contains=True. In this case, we can access the lazyVT from # original_items and selectively realize it. # 3b) contains=False. There is no easy way to selectively apply this # DICT_NOT_CONTAINS guard because our guard are represented via trees. # Be conservative and add DICT_KEYS_MATCH guard. if not self.source: return if tx.output.side_effects.is_modified(self): return contains = args[0] in self if args[0].source is None and args[0].is_python_constant(): install_guard( self.make_guard( functools.partial( type(self).CONTAINS_GUARD, key=args[0].as_python_constant(), invert=not contains, ) ) ) elif args[0].source: if contains: self.realize_key_vt(args[0]) else: self.install_dict_keys_match_guard() def call_method( self, tx: "InstructionTranslator", name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: # NB - Both key and value are LazyVariableTrackers in the beginning. So, # we have to insert guards when a dict method is accessed. For this to # be simple, we are conservative and overguard. We skip guard only for # get/__getitem__ because the key guard will be inserted by the # corresponding value VT. For __contains__, we add a DICT_CONTAINS # guard. But for all the other methods, we insert the DICT_KEYS_MATCH # guard to be conservative. from . import BuiltinVariable, ConstantVariable Hashable = ConstDictVariable._HashableTracker if name == "__init__": temp_dict_vt = variables.BuiltinVariable(dict).call_dict( tx, *args, **kwargs ) tx.output.side_effects.mutation(self) self.items.update(temp_dict_vt.items) # type: ignore[attr-defined] return ConstantVariable.create(None) elif name == "__getitem__": # Key guarding - Nothing to do. LazyVT for value will take care. if len(args) != 1: raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") return self.getitem_const_raise_exception_if_absent(tx, args[0]) elif name == "items": if args or kwargs: raise_args_mismatch( tx, name, "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) self.install_dict_keys_match_guard() if self.source: tx.output.guard_on_key_order.add(self.source) return DictItemsVariable(self) elif name == "keys": if len(args): raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") self.install_dict_keys_match_guard() if self.source: tx.output.guard_on_key_order.add(self.source) return DictKeysVariable(self) elif name == "values": if args or kwargs: raise_args_mismatch( tx, name, "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) self.install_dict_keys_match_guard() if self.source: tx.output.guard_on_key_order.add(self.source) if args or kwargs: raise_observed_exception(TypeError, tx) return DictValuesVariable(self) elif name == "copy": self.install_dict_keys_match_guard() if args or kwargs: raise_args_mismatch( tx, name, "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) return self.clone( items=self.items.copy(), mutation_type=ValueMutationNew(), source=None ) elif name == "__len__": if args or kwargs: raise_args_mismatch( tx, name, "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) self.install_dict_keys_match_guard() return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and self.is_mutable(): arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) self.install_dict_keys_match_guard() if kwargs or len(args) != 2: raise_args_mismatch( tx, name, "2 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) tx.output.side_effects.mutation(self) self.items[Hashable(args[0])] = args[1] return ConstantVariable.create(None) elif name == "__delitem__" and self.is_mutable(): arg_hashable = args and is_hashable(args[0]) if arg_hashable: self.install_dict_keys_match_guard() self.should_reconstruct_all = True tx.output.side_effects.mutation(self) self.items.__delitem__(Hashable(args[0])) return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) elif name == "get": if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) if args[0] not in self: self.install_dict_contains_guard(tx, args) if len(args) == 1: # if default is not given, return None return ConstantVariable.create(None) return args[1] # Key guarding - Nothing to do. return self.getitem_const(tx, args[0]) elif name == "pop" and self.is_mutable(): if len(args) not in (1, 2): raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) if args[0] not in self: # missing item, return the default value. Install no DICT_CONTAINS guard. self.install_dict_contains_guard(tx, args) if len(args) == 1: # if default is not given, raise KeyError raise_observed_exception(KeyError, tx) return args[1] self.should_reconstruct_all = True tx.output.side_effects.mutation(self) return self.items.pop(Hashable(args[0])) elif name == "popitem" and self.is_mutable(): if ( issubclass(self.user_cls, dict) and not issubclass(self.user_cls, collections.OrderedDict) and len(args) ): raise_args_mismatch(tx, name) if not self.items: msg = ConstantVariable.create("popitem(): dictionary is empty") raise_observed_exception(KeyError, tx, args=[msg]) if self.user_cls is collections.OrderedDict and ( len(args) == 1 or "last" in kwargs ): if len(args) == 1 and args[0].is_python_constant(): last = args[0].as_python_constant() elif (v := kwargs.get("last")) and v.is_python_constant(): last = v.as_python_constant() else: raise_args_mismatch(tx, name) k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined] else: k, v = self.items.popitem() self.should_reconstruct_all = True tx.output.side_effects.mutation(self) return variables.TupleVariable([k.vt, v]) elif name == "clear": if args or kwargs: raise_args_mismatch( tx, name, "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) self.should_reconstruct_all = True tx.output.side_effects.mutation(self) self.items.clear() return ConstantVariable.create(None) elif name == "update" and self.is_mutable(): # In general, this call looks like `a.update(b, x=1, y=2, ...)`. # Either `b` or the kwargs is omittable, but not both. self.install_dict_keys_match_guard() has_arg = len(args) == 1 has_kwargs = len(kwargs) > 0 if has_arg or has_kwargs: tx.output.side_effects.mutation(self) if has_arg: if isinstance(args[0], ConstDictVariable): # NB - Guard on all the keys of the other dict to ensure # correctness. args[0].install_dict_keys_match_guard() dict_vt: ConstDictVariable = args[0] else: dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment] self.items.update(dict_vt.items) # type: ignore[attr-defined] if has_kwargs: # Handle kwargs kwargs_hashable = { Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() } self.items.update(kwargs_hashable) return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) elif name == "__contains__": if not len(args): raise_args_mismatch( tx, name, "more than 1 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) self.install_dict_contains_guard(tx, args) contains = args[0] in self return ConstantVariable.create(contains) elif name == "setdefault" and self.is_mutable(): if len(args) not in (1, 2): raise_args_mismatch( tx, name, "1 or 2 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) arg_hashable = args and is_hashable(args[0]) if not arg_hashable: raise_unhashable(args[0], tx) self.install_dict_keys_match_guard() if kwargs or len(args) > 2: raise_args_mismatch( tx, name, "at most 2 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) value = self.maybe_getitem_const(args[0]) if value is not None: return value else: if len(args) == 1: x = ConstantVariable.create(None) else: x = args[1] tx.output.side_effects.mutation(self) self.items[Hashable(args[0])] = x return x elif name == "move_to_end": self.install_dict_keys_match_guard() tx.output.side_effects.mutation(self) if args[0] not in self: raise_observed_exception(KeyError, tx) last = True if len(args) == 2 and args[1].is_python_constant(): last = args[1].as_python_constant() if kwargs and "last" in kwargs and kwargs["last"].is_python_constant(): last = kwargs.get("last").as_python_constant() # type: ignore[union-attr] key = Hashable(args[0]) self.items.move_to_end(key, last=last) return ConstantVariable.create(None) elif name == "__eq__" and istype( self, ConstDictVariable ): # don't let Set use this function if len(args) != 1: raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") return variables.UserFunctionVariable(polyfills.dict___eq__).call_function( tx, [self, args[0]], {} ) elif name == "__ne__": return ConstantVariable.create( not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined] ) elif name == "__or__": if len(args) != 1: raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") other = args[0] # Method resolution for binops works as follow (using __or__ as example): # (1) dict.__or__(dict) => dict # (2) dict.__or__(subclass): return NotImplemented # (3) Check if subclass implements __ror__ => forward the call # to subclass.__ror__(dict) # Let's not forward the call to __ror__ yet because __ror__ can be # implemented in C (i.e. OrderedDict subclass) which Dynamo cannot # trace # if istype(other, variables.UserDefinedDictVariable): # if other.call_obj_hasattr(tx, "__ror__").value: # return other.call_method(tx, "__ror__", [self], kwargs) # The three dict types Dynamo can handle are dict, OrderedDict and # defaultdict. # TODO(guilhermeleobas): this check should be on builtin.py::call_or_ if not istype( other, (ConstDictVariable, variables.UserDefinedDictVariable) ): err_msg = ( f"unsupported operand type(s) for |: '{self.python_type().__name__}'" f"and '{other.python_type().__name__}'" ) raise_observed_exception(TypeError, tx, args=[err_msg]) # OrderedDict overloads __ror__ ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined] user_cls = ( collections.OrderedDict if any(issubclass(t, collections.OrderedDict) for t in ts) else dict ) self.install_dict_keys_match_guard() new_dict_vt = self.clone( items=self.items.copy(), mutation_type=ValueMutationNew(), source=None, user_cls=user_cls, ) # NB - Guard on all the keys of the other dict to ensure # correctness. args[0].install_dict_keys_match_guard() # type: ignore[attr-defined] new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined] return new_dict_vt elif name == "__ior__": self.call_method(tx, "update", args, kwargs) return self elif name == "__iter__": if self.source and not is_constant_source(self.source): tx.output.guard_on_key_order.add(self.source) return ListIteratorVariable( self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() ) else: return super().call_method(tx, name, args, kwargs) def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: self.install_dict_keys_match_guard() return [x.vt for x in self.items] def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> ConstantVariable: # dict not allow setting arbitrary attributes. OrderedDict and # defaultdict allow arbitrary setattr, but not deletion of default attrs if any( self.user_cls is t for t in (dict, collections.OrderedDict, collections.defaultdict) ): if hasattr(self.user_cls, name): return ConstantVariable.create(True) if self.user_cls is dict: return ConstantVariable.create(False) msg = f"hasattr on {self.user_cls} is not supported" unimplemented( gb_type="unsupported hasattr operation", context=f"Class {self.user_cls}", explanation=msg, hints=[ "Consider using a regular dictionary instead", *graph_break_hints.SUPPORTABLE, ], ) def clone(self, **kwargs: Any) -> VariableTracker: self.install_dict_keys_match_guard() return super().clone(**kwargs) def is_python_hashable(self): """ Dictionaries are mutable and therefore not hashable in Python. """ return False class MappingProxyVariable(VariableTracker): # proxies to the original dict_vt def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: super().__init__(**kwargs) assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict def python_type(self) -> type: return types.MappingProxyType def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return self.dv_dict.unpack_var_sequence(tx) def reconstruct(self, codegen: "PyCodegen") -> None: # load types.MappingProxyType if self.source: msg = ( f"Preexisting MappingProxyVariable (source: {self.source}) cannot be reconstructed " "because the connection to the original dict will be lost." ) unimplemented( gb_type="mapping proxy cannot be reconstructed", context=f"Source: {self.source}", explanation=msg, hints=[ "Use a mapping proxy constructed in the same `torch.compile` region.", *graph_break_hints.SUPPORTABLE, ], ) codegen.add_push_null( lambda: codegen.extend_output( [ codegen.create_load_python_module(types), codegen.create_load_attr("MappingProxyType"), ] ) ) codegen(self.dv_dict) codegen.extend_output(create_call_function(1, False)) def call_method( self, tx: "InstructionTranslator", name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: if self.source and tx.output.side_effects.has_existing_dict_mutation(): msg = ( "A dict has been modified while we have an existing mappingproxy object. " "A mapping proxy object, as the name suggest, proxies a mapping " "object (usually a dict). If the original dict object mutates, it " "is reflected in the proxy object as well. For an existing proxy " "object, we do not know the original dict it points to. Therefore, " "for correctness we graph break when there is dict mutation and we " "are trying to access a proxy object." ) unimplemented( gb_type="mapping proxy affected by dictionary mutation", context=f"Source: {self.source}, Dict mutation detected", explanation=msg, hints=[ "Avoid modifying dictionaries that might be referenced by mapping proxy objects", "Or avoid using the mapping proxy objects after modifying its underlying dictionary", ], ) return self.dv_dict.call_method(tx, name, args, kwargs) def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> ConstantVariable: if self.python_type() is types.MappingProxyType: return ConstantVariable.create(name in types.MappingProxyType.__dict__) return super().call_obj_hasattr(tx, name) class NNModuleHooksDictVariable(ConstDictVariable): # Special class to avoid adding any guards on the nn module hook ids. def install_dict_keys_match_guard(self) -> None: pass def install_dict_contains_guard( self, tx: "InstructionTranslator", args: list[VariableTracker] ) -> None: pass class DefaultDictVariable(ConstDictVariable): def __init__( self, items: dict[VariableTracker, VariableTracker], user_cls: type, default_factory: Optional[VariableTracker] = None, **kwargs: Any, ) -> None: super().__init__(items, user_cls, **kwargs) assert user_cls is collections.defaultdict if default_factory is None: default_factory = ConstantVariable.create(None) self.default_factory = default_factory def is_python_constant(self) -> bool: # Return false for unsupported defaults. This ensures that a bad handler # path is not taken in BuiltinVariable for getitem. if self.default_factory not in [list, tuple, dict] and not self.items: return False return super().is_python_constant() def debug_repr(self) -> str: assert self.default_factory is not None return ( f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})" ) @staticmethod def is_supported_arg(arg: VariableTracker) -> bool: if isinstance(arg, variables.BuiltinVariable): return arg.fn in (list, tuple, dict, set) else: return isinstance( arg, ( variables.functions.BaseUserFunctionVariable, variables.functions.PolyfilledFunctionVariable, ), ) def call_method( self, tx: "InstructionTranslator", name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: if name == "__getitem__": if len(args) != 1: raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") if args[0] in self: return self.getitem_const(tx, args[0]) else: if ( istype(self.default_factory, ConstantVariable) and self.default_factory.value is None ): raise_observed_exception(KeyError, tx, args=[args[0]]) else: default_var = self.default_factory.call_function(tx, [], {}) super().call_method( tx, "__setitem__", [args[0], default_var], kwargs ) return default_var else: return super().call_method(tx, name, args, kwargs) def reconstruct(self, codegen: "PyCodegen") -> None: # emit `defaultdict(default_factory, new_dict)` codegen.add_push_null( lambda: codegen.extend_output( [ codegen.create_load_python_module(collections), codegen.create_load_attr("defaultdict"), ] ) ) codegen(self.default_factory) self.reconstruct_kvs_into_new_dict(codegen) codegen.extend_output(create_call_function(2, False)) # TODO: Implementing this via inheritance rather than composition is a # footgun, because self method calls in dict will route back to the set # implementation, which is almost assuredly wrong class SetVariable(ConstDictVariable): """We model a sets as dictionary with None values""" CONTAINS_GUARD = GuardBuilder.SET_CONTAINS def __init__( self, items: list[VariableTracker], **kwargs: Any, ) -> None: # pyrefly: ignore[bad-assignment] items = dict.fromkeys(items, SetVariable._default_value()) # pyrefly: ignore[bad-argument-type] super().__init__(items, **kwargs) def debug_repr(self) -> str: if not self.items: return "set()" else: return "{" + ",".join(k.vt.debug_repr() for k in self.items) + "}" @property def set_items(self) -> set["ConstDictVariable._HashableTracker"]: return set(self.items.keys()) @staticmethod def _default_value() -> VariableTracker: # Variable to fill in he keys of the dictionary return ConstantVariable.create(None) def as_proxy(self) -> Any: return {k.vt.as_proxy() for k in self.set_items} def python_type(self) -> type: return set def as_python_constant(self) -> Any: return {k.vt.as_python_constant() for k in self.set_items} def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach([x.vt for x in self.set_items]) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) def _fast_set_method( self, tx: "InstructionTranslator", fn: Any, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: try: res = fn( *[x.as_python_constant() for x in [self, *args]], **{k: v.as_python_constant() for k, v in kwargs.items()}, ) except Exception as exc: raise_observed_exception( type(exc), tx, args=list(map(ConstantVariable.create, exc.args)) ) # pyrefly: ignore[unbound-name] return VariableTracker.build(tx, res) def call_method( self, tx: "InstructionTranslator", name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: # We forward the calls to the dictionary model from ..utils import check_constant_args if ( name in ( "isdisjoint", "union", "intersection", "difference", "symmetric_difference", ) and check_constant_args(args, kwargs) and self.python_type() is set ): py_type = self.python_type() return self._fast_set_method(tx, getattr(py_type, name), args, kwargs) if name == "__init__": temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs) tx.output.side_effects.mutation(self) self.items.clear() self.items.update(temp_set_vt.items) # type: ignore[attr-defined] return ConstantVariable.create(None) elif name == "add": if kwargs or len(args) != 1: raise_args_mismatch( tx, name, "1 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) name = "__setitem__" args = [args[0], SetVariable._default_value()] elif name == "pop": if kwargs or args: raise_args_mismatch( tx, name, "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) # Choose an item at random and pop it via the Dict.pop method try: result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment] except KeyError as e: raise_observed_exception( KeyError, tx, args=list(map(ConstantVariable.create, e.args)) ) # pyrefly: ignore[unbound-name] super().call_method(tx, name, [result], kwargs) # pyrefly: ignore[unbound-name] return result elif name == "isdisjoint": if kwargs or len(args) != 1: raise_args_mismatch( tx, name, "1 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) return variables.UserFunctionVariable( polyfills.set_isdisjoint ).call_function(tx, [self, args[0]], {}) elif name == "intersection": if kwargs: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") return variables.UserFunctionVariable( polyfills.set_intersection ).call_function(tx, [self, *args], {}) elif name == "intersection_update": if kwargs: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") return variables.UserFunctionVariable( polyfills.set_intersection_update ).call_function(tx, [self, *args], {}) elif name == "union": if kwargs: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") return variables.UserFunctionVariable(polyfills.set_union).call_function( tx, [self, *args], {} ) elif name == "difference": if kwargs: raise_args_mismatch( tx, name, f"Expect: 0 kwargs, Actual: {len(kwargs)} kwargs" ) return variables.UserFunctionVariable( polyfills.set_difference ).call_function(tx, [self, *args], {}) elif name == "difference_update": if kwargs: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") return variables.UserFunctionVariable( polyfills.set_difference_update ).call_function(tx, [self, *args], {}) elif name == "symmetric_difference": if kwargs or len(args) != 1: raise_args_mismatch( tx, name, "1 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) return variables.UserFunctionVariable( polyfills.set_symmetric_difference ).call_function(tx, [self, *args], {}) elif name == "symmetric_difference_update": if kwargs or len(args) != 1: raise_args_mismatch( tx, name, "1 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) return variables.UserFunctionVariable( polyfills.set_symmetric_difference_update ).call_function(tx, [self, *args], {}) elif name == "update" and self.is_mutable(): if kwargs: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") return variables.UserFunctionVariable(polyfills.set_update).call_function( tx, [self, *args], {} ) elif name == "remove": if kwargs or len(args) != 1: raise_args_mismatch( tx, name, "1 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) if args[0] not in self: raise_observed_exception(KeyError, tx, args=args) return super().call_method(tx, "pop", args, kwargs) elif name == "discard": if kwargs or len(args) != 1: raise_args_mismatch( tx, name, "1 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) if args[0] in self: return super().call_method(tx, "pop", args, kwargs) else: return ConstantVariable.create(value=None) elif name in ("issubset", "issuperset"): if len(args) != 1: raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") op = { "issubset": operator.le, "issuperset": operator.ge, } other = args[0].realize() if not istype(other, SetVariable): other = variables.BuiltinVariable(set).call_function(tx, [other], {}) return variables.BuiltinVariable(op.get(name)).call_function( tx, [self, other], {} ) elif name in ("__and__", "__or__", "__xor__", "__sub__"): m = { "__and__": "intersection", "__or__": "union", "__xor__": "symmetric_difference", "__sub__": "difference", }.get(name) if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): msg = ConstantVariable.create( f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" ) raise_observed_exception(TypeError, tx, args=[msg]) assert m is not None return self.call_method(tx, m, args, kwargs) elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"): if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): msg = ConstantVariable.create( f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" ) raise_observed_exception(TypeError, tx, args=[msg]) m = { "__iand__": "intersection_update", "__ior__": "update", "__ixor__": "symmetric_difference_update", "__isub__": "difference_update", }.get(name) assert m is not None self.call_method(tx, m, args, kwargs) return self elif name == "__eq__": if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): return ConstantVariable.create(False) r = self.call_method(tx, "symmetric_difference", args, kwargs) return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined] elif name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): return ConstantVariable.create(NotImplemented) return ConstantVariable.create( cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] ) return super().call_method(tx, name, args, kwargs) def getitem_const( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker: raise RuntimeError("Illegal to getitem on a set") def install_dict_keys_match_guard(self) -> None: # Already EQUALS_MATCH guarded pass class FrozensetVariable(SetVariable): def debug_repr(self) -> str: if not self.items: return "frozenset()" else: return "{" + ",".join(k.vt.debug_repr() for k in self.items) + "}" @property def set_items(self) -> set["ConstDictVariable._HashableTracker"]: return self.items.keys() def python_type(self) -> type: return frozenset def as_python_constant(self) -> Any: return frozenset({k.vt.as_python_constant() for k in self.set_items}) def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach([x.vt for x in self.set_items]) codegen.add_push_null( lambda: codegen.extend_output( [ codegen.create_load_global("frozenset"), ] ) ) codegen.extend_output(create_call_function(0, False)) def call_method( self, tx: "InstructionTranslator", name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: if name in ["add", "pop", "update", "remove", "discard", "clear"]: raise RuntimeError(f"Illegal call_method {name} on a frozenset") elif name == "__init__": # frozenset is immutable. Calling __init__ again shouldn't have any effect # In[1]: s = frozenset([1, 2]) # # In[2]: s.__init__([3, 4]) # # In[3]: s # frozenset({1, 2}) return ConstantVariable.create(None) elif name in ( "copy", "difference", "intersection", "symmetric_difference", ): r = super().call_method(tx, name, args, kwargs) return FrozensetVariable(r.items) # type: ignore[attr-defined] return super().call_method(tx, name, args, kwargs) def is_python_hashable(self): """ Frozensets are immutable and hashable in Python. """ return True def get_python_hash(self): return hash(self.as_python_constant()) def is_python_equal(self, other): return self.as_python_constant() == other.as_python_constant() class DictKeySetVariable(SetVariable): def debug_repr(self) -> str: if not self.items: return "dict_keys([])" else: return ( "dict_keys([" + ",".join(k.vt.debug_repr() for k in self.items) + "])" ) def install_dict_keys_match_guard(self) -> None: # Already EQUALS_MATCH guarded pass def install_dict_contains_guard( self, tx: "InstructionTranslator", args: list[VariableTracker] ) -> None: # Already EQUALS_MATCH guarded pass @property def set_items(self) -> Any: return self.items def python_type(self) -> type: return dict_keys def as_python_constant(self) -> Any: return dict.fromkeys( {k.vt.as_python_constant() for k in self.set_items}, None ).keys() def call_method( self, tx: "InstructionTranslator", name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: if name in ["add", "pop", "update", "remove", "discard", "clear"]: raise RuntimeError(f"Illegal call_method {name} on a dict_keys") return super().call_method(tx, name, args, kwargs) class DictViewVariable(VariableTracker): """ Models _PyDictViewObject This is an "abstract" class. Subclasses will override kv and the items method """ kv: Optional[str] = None def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: super().__init__(**kwargs) assert self.kv in ("keys", "values", "items") assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict @property def view_items(self) -> Any: assert self.kv is not None return getattr(self.dv_dict.items, self.kv)() @property def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items # Implement in the subclasses raise NotImplementedError def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return self.view_items_vt def reconstruct(self, codegen: "PyCodegen") -> None: assert self.kv is not None codegen(self.dv_dict) codegen.load_method(self.kv) codegen.call_method(0) def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> ConstantVariable: assert self.kv is not None if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) def call_method( self, tx: "InstructionTranslator", name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: if name == "__len__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name == "__iter__": return ListIteratorVariable( self.view_items_vt, mutation_type=ValueMutationNew() ) return super().call_method(tx, name, args, kwargs) class DictKeysVariable(DictViewVariable): kv = "keys" @property def set_items(self) -> set[VariableTracker]: return set(self.view_items) @property def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items return [x.vt for x in self.view_items] def python_type(self) -> type: return dict_keys def call_method( self, tx: "InstructionTranslator", name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: if name == "__contains__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name in ( "__and__", "__iand__", "__or__", "__ior__", "__sub__", "__isub__", "__xor__", "__ixor__", ): # These methods always returns a set m = getattr(self.set_items, name) r = m(args[0].set_items) # type: ignore[attr-defined] return SetVariable(r) if name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, DictKeysVariable)): return ConstantVariable.create(NotImplemented) return ConstantVariable.create( cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] ) return super().call_method(tx, name, args, kwargs) class DictValuesVariable(DictViewVariable): # DictValuesVariable is an iterable but cannot be compared. kv = "values" @property def view_items_vt(self) -> list[VariableTracker]: return list(self.view_items) def python_type(self) -> type: return dict_values class DictItemsVariable(DictViewVariable): kv = "items" @property def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items] def python_type(self) -> type: return dict_items def call_method( self, tx: "InstructionTranslator", name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: # TODO(guilhermeleobas): This should actually check if args[0] # implements the mapping protocol. if name == "__eq__": if len(args) != 1: raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") if isinstance(args[0], DictItemsVariable): return self.dv_dict.call_method(tx, "__eq__", [args[0].dv_dict], {}) return ConstantVariable.create(False) return super().call_method(tx, name, args, kwargs) def is_python_hashable(self): """ Dictionary item views are not hashable in Python. """ return False