import collections
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING

import torch
from torch._dynamo.variables.dicts import ConstDictVariable
from torch._dynamo.variables.lists import TupleVariable
from torch._library.custom_ops import custom_op
from torch.fx import has_side_effect, Proxy

from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import unimplemented
from ..graph_bytecode_inputs import (
    get_external_object_by_index,
    register_graph_created_object,
)
from ..source import CurrentStreamSource
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
from .lazy import LazyVariableTracker


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator

    from ..codegen import PyCodegen


Tensor = torch.Tensor


def new_event(*args: Any, **kwargs: Any) -> int:
    event = torch.Event(*args, **kwargs)
    return register_graph_created_object(
        event,
        EventVariable.make_construct_in_graph_event_fn(
            TupleVariable([]), ConstDictVariable({})
        ),
    )


def new_stream(*args: Any, **kwargs: Any) -> int:
    stream = torch.Stream(*args, **kwargs)  # type: ignore[no-matching-overload,call-overload]
    return register_graph_created_object(
        stream,
        StreamVariable.make_construct_in_graph_stream_fn(
            TupleVariable([]), ConstDictVariable({})
        ),
    )
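

# Usage sketch (only names defined or imported in this module): the int these
# helpers return indexes the external-object table; the stream ops below pass
# such indices around instead of the stream/event objects themselves.
#
#   event_index = new_event()
#   assert isinstance(get_external_object_by_index(event_index), torch.Event)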


def _codegen_current_stream(device: torch.device, cg: "PyCodegen") -> None:
    cg.add_push_null(
        lambda: cg.load_import_from(
            torch._dynamo.graph_bytecode_inputs.__name__,  # type: ignore[implicit-imports]
            "stash_graph_created_object",
        )
    )
    cg(CurrentStreamSource(device))
    cg.extend_output(create_call_function(1, False))


def get_current_stream(device: torch.device) -> int:
    stream = torch.accelerator.current_stream(device)
    return register_graph_created_object(
        stream, lambda _, cg: _codegen_current_stream(device, cg)
    )
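

# The bytecode emitted by _codegen_current_stream is roughly equivalent to the
# following call (an illustrative sketch, not the literal generated code):
#
#   stash_graph_created_object(torch.accelerator.current_stream(device))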


def _get_stream_by_index(index: int) -> torch.Stream:
    stream = get_external_object_by_index(index)
    assert isinstance(stream, torch.Stream), (
        f"Fork/join stream expected a stream object at index {index}"
    )
    return stream


def _get_event_by_index(index: int) -> torch.Event:
    event = get_external_object_by_index(index)
    assert isinstance(event, torch.Event), (
        f"Record/wait event expected an event object at index {index}"
    )
    return event


@custom_op("streams::fork", mutates_args=())
def fork_stream(
    from_index: int,  # kept to make stream transitions clearer
    to_index: int,
) -> None:
    torch.accelerator.set_stream(_get_stream_by_index(to_index))


@fork_stream.register_fake
def _(
    from_index: int,  # kept to make stream transitions clearer
    to_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.fork.default)


@custom_op("streams::join", mutates_args=())
def join_stream(from_index: int, to_index: int) -> None:
    torch.accelerator.set_stream(_get_stream_by_index(to_index))


@join_stream.register_fake
def _(
    from_index: int,
    to_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.join.default)
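

# In a captured graph, a stream context is bracketed by these two ops, e.g.
# (illustrative sketch; the integers are external-object-table indices):
#
#   torch.ops.streams.fork(0, 1)    # leave stream 0, enter side stream 1
#   # ... kernels launched on stream 1 ...
#   torch.ops.streams.join(1, 0)    # leave stream 1, re-enter stream 0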


@custom_op("streams::record_event", mutates_args=())
def record_event(event_index: int, stream_index: int) -> None:
    event = _get_event_by_index(event_index)
    stream = _get_stream_by_index(stream_index)
    stream.record_event(event)


@record_event.register_fake
def _(
    event_index: int,
    stream_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.record_event.default)


@custom_op("streams::wait_event", mutates_args=())
def wait_event(event_index: int, stream_index: int) -> None:
    event = _get_event_by_index(event_index)
    stream = _get_stream_by_index(stream_index)
    stream.wait_event(event)


@wait_event.register_fake
def _(
    event_index: int,
    stream_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.wait_event.default)


@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
    waiting = _get_stream_by_index(waiting_stream_index)
    waited_on = _get_stream_by_index(waited_on_stream_index)
    waiting.wait_stream(waited_on)


@wait_stream.register_fake
def _(
    waiting_stream_index: int,
    waited_on_stream_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.wait_stream.default)


@custom_op("streams::sync_dealloc", mutates_args=())
def sync_dealloc(
    wait_event_index: int, src_stream_index: int, to_dealloc: torch.Tensor
) -> None:
    """An op that waits on an event and anchors the last usage of `to_dealloc`
    after that wait, so that once the sync occurs, deallocation or subsequent
    reuse of the tensor's memory is guaranteed to happen after the side stream
    has finished using it.

    See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream
    for more details.
    """
    torch.ops.streams.wait_event.default(wait_event_index, src_stream_index)


has_side_effect(torch.ops.streams.sync_dealloc.default)
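

# Illustrative sketch of the hazard sync_dealloc guards against (hypothetical
# user code, not part of this module):
#
#   with side_stream:
#       out = t @ w        # `t` is read on the side stream
#   del t                  # without a sync, the allocator may hand t's memory
#                          # to a new allocation while `side_stream` is still
#                          # reading it
#
# Emitting sync_dealloc(wait_event_index, src_stream_index, t) as the last
# graph usage of `t` orders the free after the side stream's work.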


@custom_op("streams::record_stream", mutates_args=())
def record_stream(tensor: torch.Tensor, stream_index: int) -> None:
    tensor.record_stream(_get_stream_by_index(stream_index))


@record_stream.register_fake
def _(
    tensor: torch.Tensor,
    stream_index: int,
) -> None:
    pass


class SymbolicStreamState:
    """Track the currently entered stream, if any"""

    def __init__(self) -> None:
        cur_stack: list[StreamVariable] = []
        if torch.accelerator.is_available():
            current_stream = torch.accelerator.current_stream()
            stream_var = LazyVariableTracker.create(
                current_stream,
                source=CurrentStreamSource(current_stream.device),
            )
            cur_stack = [stream_var]  # type: ignore[list-item]

        self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque(
            cur_stack
        )

    def enter_stream(self, stream: "StreamVariable") -> None:
        self.cur_stream_stack.append(stream)

    def exit_stream(self) -> None:
        self.cur_stream_stack.pop()

    def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable":
        if device is not None:
            for stream in reversed(self.cur_stream_stack):
                if stream.device == device:
                    return stream

        return self.cur_stream_stack[-1]

    def in_stream_context(self) -> bool:
        return len(self.cur_stream_stack) > 0
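

# Usage sketch (hypothetical): tracing `with s:` pushes the stream's variable,
# so lookups resolve to the innermost entered stream:
#
#   state.enter_stream(s_var)   # entering `with s:`
#   state.cur_stream()          # -> s_var
#   state.exit_stream()         # leaving the block restores the outer stream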


class StreamContextVariable(FxTracebackAnnotateVariable):
    """This represents torch.cuda.StreamContext"""

    @staticmethod
    def create(
        tx: "InstructionTranslator",
        stream_to_enter: "StreamVariable",
        **kwargs: dict[str, Any],
    ) -> "StreamContextVariable":
        return StreamContextVariable(
            stream_to_enter,
            **kwargs,
        )

    def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None:
        self.stream = stream
        super().__init__(
            target_values={"stream": self.get_stream().user_object_index},
            initial_values=None,
            **kwargs,
        )

    def enter(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        # The arguments are ordered (to stream, from stream):
        # we are entering the target stream and leaving the initial stream
        tx.symbolic_stream_state.enter_stream(self.get_stream())
        return super().enter(tx)

    def exit(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        # The arguments are ordered (to stream, from stream):
        # we are leaving the target stream and entering the initial stream
        tx.symbolic_stream_state.exit_stream()
        return super().exit(tx, *args)

    def supports_graph_breaks(self) -> bool:
        return True

    def get_stream(self) -> "StreamVariable":
        assert self.stream, "Stream context should have a separate stream"
        return self.stream


class StreamVariable(StreamContextVariable):
    """Represents the device-agnostic torch.Stream class"""

    def __init__(
        self,
        proxy: Proxy,
        value: torch.Stream,
        user_object_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # user_object_index is an index into the user object table,
        # used to pass arbitrary objects to the graph
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value

        self.proxy = proxy
        self.value = value
        # pyrefly: ignore [read-only]
        self.device = value.device
        # pyrefly: ignore [read-only]
        self.user_object_index = user_object_index
        super().__init__(None, **kwargs)

    def python_type(self) -> type:
        return torch.Stream

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        assert hasattr(self.value, name), f"no stream method found named {name}"

        from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
            from ..guards import GuardBuilder, install_guard

            if self.source:
                install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))

            # NB: Checking for mutation is necessary because we compare
            # constant values
            other = args[0]
            if not isinstance(other, StreamVariable):
                return ConstantVariable.create(NotImplemented)

            if other.source:
                install_guard(other.source.make_guard(GuardBuilder.EQUALS_MATCH))
            return ConstantVariable.create(
                cmp_name_to_op_mapping[name](self.value, other.value)  # type: ignore[arg-type]
            )

        return super().call_method(tx, name, args, kwargs)
    def as_proxy(self) -> Proxy:
        return self.proxy

    def module_name(self) -> str:
        return "torch._C"

    def fn_name(self) -> str:
        return "Stream"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # If we got here, this stream is fully subsumed by the graph - this
        # means it is not an input or global
        assert not self.source
        if self.user_object_index is not None:
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,
                    "get_external_object_by_index",
                )
            )
            codegen.append_output(codegen.create_load_const(self.user_object_index))
            codegen.extend_output(create_call_function(1, False))
        else:
            # This preserves the legacy behavior
            prefix = f"_stream_{self.device}"
            name = codegen.tx.output.install_global_by_id(prefix, self.value)
            codegen.append_output(codegen.create_load_global(name, add=True))

    def get_stream(self) -> "StreamVariable":
        return self

    @staticmethod
    def make_construct_in_graph_stream_fn(
        args: TupleVariable, kwargs: ConstDictVariable
    ) -> Callable[[int, "PyCodegen"], None]:
        def fn(index: int, codegen: "PyCodegen") -> None:
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,  # type: ignore[implicit-imports]
                    "stash_graph_created_object",
                )
            )
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.utils.__name__, "build_stream"
                )
            )
            codegen(args)
            codegen(kwargs)
            codegen.extend_output(create_call_function(2, False))
            codegen.extend_output(create_call_function(1, False))

        return fn


class EventVariable(VariableTracker):
    def __init__(
        self,
        proxy: Proxy,
        value: torch.Event,
        user_object_index: Optional[int],
        **kwargs: Any,
    ) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.user_object_index = user_object_index

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name == "wait":
            tx.output.create_proxy(
                "call_function",
                torch.ops.streams.wait_event,
                (
                    self.user_object_index,
                    EventVariable._get_stream_arg(tx, args, kwargs).user_object_index,
                ),
                {},
            )
            return ConstantVariable(None)
        elif name == "record":
            tx.output.create_proxy(
                "call_function",
                torch.ops.streams.record_event,
                (
                    self.user_object_index,
                    EventVariable._get_stream_arg(tx, args, kwargs).user_object_index,
                ),
                {},
            )
            return ConstantVariable(None)
        elif name == "synchronize":
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            method_name = (
                f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
            )
            unimplemented(
                gb_type="Unsupported event method",
                context=str(name),
                explanation=f"Dynamo doesn't support tracing the {method_name} method. "
                "We currently support wait, record, synchronize, and query.",
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )
    def as_proxy(self) -> Proxy:
        return self.proxy

    @staticmethod
    def _get_stream_arg(
        tx: "InstructionTranslator",
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> "StreamVariable":
        stream_arg = None
        if args:
            stream_arg = args[0]
        elif kwargs:
            stream_arg = kwargs.get("stream")

        if stream_arg is None:
            stream_arg = tx.symbolic_stream_state.cur_stream()

        return stream_arg  # type: ignore[return-value]
    @staticmethod
    def make_construct_in_graph_event_fn(
        args: TupleVariable, kwargs: ConstDictVariable
    ) -> Callable[[int, "PyCodegen"], None]:
        def fn(index: int, codegen: "PyCodegen") -> None:
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,  # type: ignore[implicit-imports]
                    "stash_graph_created_object",
                )
            )
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.utils.__name__, "build_event"
                )
            )
            codegen(args)
            codegen(kwargs)
            codegen.extend_output(create_call_function(2, False))
            codegen.extend_output(create_call_function(1, False))

        return fn

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # If we got here, this event is fully subsumed by the graph - this
        # means it is not an input or global
        assert not self.source
        # Similar to stream handling, we lift the event into a global and then
        # codegen bytecode to load it from there.
        prefix = "_event"
        name = codegen.tx.output.install_global_by_id(prefix, self.value)
        codegen.append_output(codegen.create_load_global(name, add=True))