import collections
from collections.abc import Callable
from typing import Any, Optional

import torch
from torch._dynamo.variables.dicts import ConstDictVariable
from torch._dynamo.variables.lists import TupleVariable
from torch.fx import has_side_effect, Proxy

from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented
from ..graph_bytecode_inputs import (
    get_external_object_by_index,
    register_graph_created_object,
)
from ..source import CurrentStreamSource
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
from .lazy import LazyVariableTracker


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator

    from ..codegen import PyCodegen

from torch._library.custom_ops import custom_op


Tensor = torch.Tensor


def new_event(*args: Any, **kwargs: Any) -> int:
    """Create a ``torch.Event`` and register it as a graph-created object.

    Returns whatever ``register_graph_created_object`` returns for the event
    (annotated as an ``int`` index). The registered codegen callback is built
    with empty args/kwargs.
    """
    event = torch.Event(*args, **kwargs)
    return register_graph_created_object(
        event,
        EventVariable.make_construct_in_graph_event_fn(
            TupleVariable([]), ConstDictVariable({})
        ),
    )


def new_stream(*args: Any, **kwargs: Any) -> int:
    """Create a ``torch.Stream`` and register it as a graph-created object.

    Returns whatever ``register_graph_created_object`` returns for the stream
    (annotated as an ``int`` index). The registered codegen callback is built
    with empty args/kwargs.
    """
    stream = torch.Stream(*args, **kwargs)  # type: ignore[no-matching-overload,call-overload]
    return register_graph_created_object(
        stream,
        StreamVariable.make_construct_in_graph_stream_fn(
            TupleVariable([]), ConstDictVariable({})
        ),
    )


def _codegen_current_stream(device: torch.device, cg: "PyCodegen") -> None:
    """Emit bytecode calling ``stash_graph_created_object`` on the current stream.

    The single argument of the emitted call is produced by codegen-ing a
    ``CurrentStreamSource`` for ``device``.
    """
    cg.add_push_null(
        lambda: cg.load_import_from(
            torch._dynamo.graph_bytecode_inputs.__name__,  # type: ignore[implicit-imports]
            "stash_graph_created_object",
        )
    )
    cg(CurrentStreamSource(device))
    cg.extend_output(create_call_function(1, False))


def get_current_stream(device: torch.device) -> int:
    """Register the accelerator's current stream for ``device``.

    Returns the value produced by ``register_graph_created_object``; the
    codegen callback re-derives the stream via ``_codegen_current_stream``.
    """
    stream = torch.accelerator.current_stream(device)
    return register_graph_created_object(
        stream, lambda _, cg: _codegen_current_stream(device, cg)
    )


def _get_stream_by_index(index: int) -> torch.Stream:
    """Look up the external object at ``index`` and check it is a ``torch.Stream``."""
    stream = get_external_object_by_index(index)
    assert isinstance(stream, torch.Stream), (
        f"Fork/join stream expected a stream object at index {index}"
    )
    return stream


def _get_event_by_index(index: int) -> torch.Event:
    """Look up the external object at ``index`` and check it is a ``torch.Event``."""
    event = get_external_object_by_index(index)
    assert isinstance(event, torch.Event), (
        f"Record/wait event expected an event object at index {index}"
    )
    return event


@custom_op("streams::fork", mutates_args=())
def fork_stream(
    from_index: int,  # kept to make stream transitions clearer
    to_index: int,
) -> None:
    """Make the stream stored at ``to_index`` the current accelerator stream."""
    torch.accelerator.set_stream(_get_stream_by_index(to_index))


@fork_stream.register_fake
def _(
    from_index: int,  # kept to make stream transitions clearer
    to_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.fork.default)


@custom_op("streams::join", mutates_args=())
def join_stream(from_index: int, to_index: int) -> None:
    """Make the stream stored at ``to_index`` the current accelerator stream."""
    torch.accelerator.set_stream(_get_stream_by_index(to_index))


@join_stream.register_fake
def _(
    from_index: int,
    to_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.join.default)


@custom_op("streams::record_event", mutates_args=())
def record_event(event_index: int, stream_index: int) -> None:
    """Record the event at ``event_index`` on the stream at ``stream_index``."""
    event = _get_event_by_index(event_index)
    stream = _get_stream_by_index(stream_index)
    stream.record_event(event)


@record_event.register_fake
def _(
    event_index: int,
    stream_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.record_event.default)


@custom_op("streams::wait_event", mutates_args=())
def wait_event(event_index: int, stream_index: int) -> None:
    """Make the stream at ``stream_index`` wait on the event at ``event_index``."""
    event = _get_event_by_index(event_index)
    stream = _get_stream_by_index(stream_index)
    stream.wait_event(event)


@wait_event.register_fake
def _(
    event_index: int,
    stream_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.wait_event.default)


@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
    """Make the stream at ``waiting_stream_index`` wait on the other stream."""
    waiting = _get_stream_by_index(waiting_stream_index)
    waited_on = _get_stream_by_index(waited_on_stream_index)
    waiting.wait_stream(waited_on)


# Parameter names mirror the real op so the fake reads as the same signature.
@wait_stream.register_fake
def _(
    waiting_stream_index: int,
    waited_on_stream_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.wait_stream.default)


@custom_op("streams::sync_dealloc", mutates_args=())
def sync_dealloc(
    wait_event_index: int, src_stream_index: int, to_dealloc: torch.Tensor
) -> None:
    """An op which waits on an event and moves the last usage of to_dealloc
    after the wait, so that after the sync occurs, the deallocation or
    subsequent reuse of the tensor's memory will be guaranteed to happen
    after a side stream is finished using it.
    See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream
    for more details"""
    torch.ops.streams.wait_event.default(wait_event_index, src_stream_index)


# The three-argument fake belongs to sync_dealloc (it matches this op's
# parameters); it was previously attached to record_stream by mistake.
@sync_dealloc.register_fake
def _(
    wait_event_index: int,
    src_stream_index: int,
    to_dealloc: torch.Tensor,
) -> None:
    pass


has_side_effect(torch.ops.streams.sync_dealloc.default)


@custom_op("streams::record_stream", mutates_args=())
def record_stream(tensor: torch.Tensor, stream_index: int) -> None:
    """Call ``tensor.record_stream`` with the stream stored at ``stream_index``."""
    tensor.record_stream(_get_stream_by_index(stream_index))


# The fake must accept exactly the arguments the real op takes.
@record_stream.register_fake
def _(
    tensor: torch.Tensor,
    stream_index: int,
) -> None:
    pass


# NOTE(review): unlike every other streams:: op in this file, record_stream is
# not registered via has_side_effect - confirm whether that is intentional.


class SymbolicStreamState:
    """Track the currently entered stream if any"""

    def __init__(self) -> None:
        """Seed the stack with the current accelerator stream, when one exists."""
        cur_stack: list[StreamVariable] = []
        if torch.accelerator.is_available():
            current = torch.accelerator.current_stream()
            stream_var = LazyVariableTracker.create(
                current,
                source=CurrentStreamSource(current.device),
            )
            cur_stack = [stream_var]  # type: ignore[list-item]

        self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque(
            cur_stack
        )

    def enter_stream(self, stream: "StreamVariable") -> None:
        """Push ``stream`` as the innermost entered stream."""
        self.cur_stream_stack.append(stream)

    def exit_stream(self) -> None:
        """Pop the innermost entered stream."""
        self.cur_stream_stack.pop()

    def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable":
        """Return the innermost stream, preferring one on ``device`` if given.

        When ``device`` is given, the stack is searched from innermost to
        outermost for a stream whose ``device`` equals it. If none matches
        (or ``device`` is None) the top of the stack is returned; indexing an
        empty stack raises ``IndexError``.
        """
        if device is not None:
            for stream in reversed(self.cur_stream_stack):
                if stream.device == device:
                    return stream

        return self.cur_stream_stack[-1]

    def in_stream_context(self) -> bool:
        """Return True when at least one stream is on the stack."""
        return len(self.cur_stream_stack) > 0


class StreamContextVariable(FxTracebackAnnotateVariable):
    """This represents torch.cuda.StreamContext"""

    @staticmethod
    def create(
        tx: "InstructionTranslator",
        stream_to_enter: "StreamVariable",
        **kwargs: Any,
    ) -> "StreamContextVariable":
        """Build a context variable wrapping ``stream_to_enter``."""
        return StreamContextVariable(
            stream_to_enter,
            **kwargs,
        )

    def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None:
        self.stream = stream
        # get_stream() is consulted here, so subclasses that override it must
        # have their state ready before calling this constructor.
        super().__init__(
            target_values={"stream": self.get_stream().user_object_index},
            initial_values=None,
            **kwargs,
        )

    def enter(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        # to stream, from stream is the order of the arguments
        # we are entering the target, and leaving the initial stream
        tx.symbolic_stream_state.enter_stream(self.get_stream())
        return super().enter(tx)

    def exit(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        # to stream, from stream is the order of the arguments
        # we are leaving the target, and entering the initial stream
        tx.symbolic_stream_state.exit_stream()
        return super().exit(tx, *args)

    def supports_graph_breaks(self) -> bool:
        return True

    def get_stream(self) -> "StreamVariable":
        """Return the wrapped stream; asserts that one was provided."""
        assert self.stream, "Stream context should have a separate stream"
        return self.stream


class StreamVariable(StreamContextVariable):
    """Represents the device-agnostic torch.Stream class"""

    def __init__(
        self,
        proxy: Proxy,
        value: torch.Stream,
        user_object_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value

        self.proxy = proxy
        self.value = value
        # pyrefly: ignore [read-only]
        self.device = value.device
        # Index into the user object table
        # used to pass arbitrary objects to the graph
        # pyrefly: ignore [read-only]
        self.user_object_index = user_object_index
        # Must run last: the parent constructor reads
        # self.get_stream().user_object_index, and get_stream() returns self.
        super().__init__(None, **kwargs)

    def python_type(self) -> type:
        return torch.Stream

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Trace a method call on the stream.

        ``wait_stream``/``synchronize``/``wait_event``/``query``/``record_event``
        are emitted as ``call_method`` proxies; single-argument comparison
        operators are evaluated on the concrete values and returned as
        constants; everything else is delegated to the parent class.
        """
        assert hasattr(self.value, name), f"no stream method found named {name}"

        from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
            from ..guards import GuardBuilder, install_guard

            if self.source:
                install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))

            # NB : Checking for mutation is necessary because we compare
            # constant values
            other = args[0]
            if not isinstance(other, StreamVariable):
                return ConstantVariable.create(NotImplemented)

            # The result below is baked in as a constant computed from both
            # operands, so the other operand's source needs its own guard.
            if other.source:
                install_guard(other.source.make_guard(GuardBuilder.EQUALS_MATCH))
            return ConstantVariable.create(
                cmp_name_to_op_mapping[name](self.value, other.value)  # type: ignore[arg-type]
            )

        return super().call_method(tx, name, args, kwargs)

    def as_proxy(self) -> Proxy:
        return self.proxy

    def module_name(self) -> str:
        return "torch._C"

    def fn_name(self) -> str:
        return "Stream"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that loads this stream at runtime."""
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        if self.user_object_index is not None:
            # Emit get_external_object_by_index(<user_object_index>).
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,
                    "get_external_object_by_index",
                )
            )
            codegen.append_output(codegen.create_load_const(self.user_object_index))
            codegen.extend_output(create_call_function(1, False))
        else:
            # This will support the legacy behavior
            prefix = f"_stream_{self.device}"
            name = codegen.tx.output.install_global_by_id(prefix, self.value)
            codegen.append_output(codegen.create_load_global(name, add=True))

    def get_stream(self) -> "StreamVariable":
        return self

    @staticmethod
    def make_construct_in_graph_stream_fn(
        args: TupleVariable, kwargs: ConstDictVariable
    ) -> Callable[[int, "PyCodegen"], None]:
        """Return a codegen callback that builds and stashes a stream.

        The callback emits
        ``stash_graph_created_object(build_stream(args, kwargs))``; its
        ``index`` parameter is not used.
        """

        def fn(index: int, codegen: "PyCodegen") -> None:
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,  # type: ignore[implicit-imports]
                    "stash_graph_created_object",
                )
            )
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.utils.__name__, "build_stream"
                )
            )
            codegen(args)
            codegen(kwargs)
            # Inner call: build_stream(args, kwargs); outer call: stash(<result>).
            codegen.extend_output(create_call_function(2, False))
            codegen.extend_output(create_call_function(1, False))

        return fn


class EventVariable(VariableTracker):
    """Variable tracker wrapping a ``torch.Event``."""

    def __init__(
        self,
        proxy: Proxy,
        value: torch.Event,
        user_object_index: Optional[int],
        **kwargs: Any,
    ) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value

        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.user_object_index = user_object_index

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Trace a method call on the event.

        ``wait`` and ``record`` are lowered to the ``streams::wait_event`` /
        ``streams::record_event`` ops using user-object indices;
        ``synchronize`` and ``query`` are emitted as ``call_method`` proxies;
        any other method is reported through ``unimplemented``.
        """
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name == "wait":
            tx.output.create_proxy(
                "call_function",
                torch.ops.streams.wait_event,
                (
                    self.user_object_index,
                    EventVariable._get_stream_arg(tx, args, kwargs).user_object_index,
                ),
                {},
            )
            return ConstantVariable(None)
        elif name == "record":
            tx.output.create_proxy(
                "call_function",
                torch.ops.streams.record_event,
                (
                    self.user_object_index,
                    EventVariable._get_stream_arg(tx, args, kwargs).user_object_index,
                ),
                {},
            )
            return ConstantVariable(None)
        elif name == "synchronize":
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            method_name = (
                f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
            )
            unimplemented(
                gb_type="Unsupported event method",
                context=str(name),
                explanation=f"Dynamo doesn't support tracing the {method_name} method. "
                f"We currently support wait, record, synchronize, and query.",
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )

    def as_proxy(self) -> Proxy:
        return self.proxy

    @staticmethod
    def _get_stream_arg(
        tx: "InstructionTranslator",
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> "StreamVariable":
        """Pick the stream argument of an event method call.

        Uses the first positional argument if present, otherwise the
        ``stream`` keyword; when neither is supplied, falls back to the
        translator's current symbolic stream.
        """
        stream_arg = None
        if args:
            stream_arg = args[0]
        elif kwargs:
            stream_arg = kwargs.get("stream")

        if stream_arg is None:
            stream_arg = tx.symbolic_stream_state.cur_stream()

        return stream_arg  # type: ignore[return-value]

    @staticmethod
    def make_construct_in_graph_event_fn(
        args: TupleVariable, kwargs: ConstDictVariable
    ) -> Callable[[int, "PyCodegen"], None]:
        """Return a codegen callback that builds and stashes an event.

        The callback emits
        ``stash_graph_created_object(build_event(args, kwargs))``; its
        ``index`` parameter is not used.
        """

        def fn(index: int, codegen: "PyCodegen") -> None:
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,  # type: ignore[implicit-imports]
                    "stash_graph_created_object",
                )
            )
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.utils.__name__, "build_event"
                )
            )
            codegen(args)
            codegen(kwargs)
            # Inner call: build_event(args, kwargs); outer call: stash(<result>).
            codegen.extend_output(create_call_function(2, False))
            codegen.extend_output(create_call_function(1, False))

        return fn

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that loads this event from an installed global."""
        # If we got here, this event is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
        prefix = "_event"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(name, add=True))