# pylint: disable=useless-parent-delegation
"""Python wrappers for capturing and replaying MTIA graphs."""

from __future__ import annotations

from typing import Optional, Union
from typing_extensions import Self

import torch


# Opaque ``(int, int)`` token identifying a graph memory pool.
_POOL_HANDLE = tuple[int, int]


def graph_pool_handle() -> _POOL_HANDLE:
    """
    Return an opaque token representing the id of a graph memory pool.

    The token can be passed as ``pool`` to :class:`graph` or to
    :meth:`MTIAGraph.capture_begin`.
    """
    return torch._C._mtia_graphPoolHandle()


class MTIAGraph(torch._C._MTIAGraph):
    """
    Wrapper around a MTIA graph.

    Each method delegates directly to the native ``torch._C._MTIAGraph``
    base class.
    """

    def __new__(cls, keep_graph: bool = False) -> Self:
        """
        Create a new MTIA graph object.

        Args:
            keep_graph: forwarded unchanged to the native constructor.
                NOTE(review): by analogy with ``torch.cuda.CUDAGraph`` this
                presumably controls whether the captured graph is kept after
                instantiation -- confirm against the native implementation.
        """
        return super().__new__(cls, keep_graph)

    def capture_begin(self, pool: _POOL_HANDLE) -> None:
        """
        Begin capturing a MTIA graph.

        Args:
            pool: memory-pool handle, e.g. from :func:`graph_pool_handle` or
                :meth:`pool`. The :class:`graph` context manager passes
                ``(0, 0)`` when no pool was supplied.
        """
        super().capture_begin(pool)

    def capture_end(self) -> None:
        """
        End the capture of a MTIA graph.
        """
        super().capture_end()

    def instantiate(self) -> None:
        """
        Instantiate the captured MTIA graph.
        """
        super().instantiate()

    def replay(self) -> None:
        """
        Replay the captured MTIA graph.
        """
        super().replay()

    def reset(self) -> None:
        """
        Destroy the captured graph and reset the states.
        """
        super().reset()

    def pool(self) -> _POOL_HANDLE:
        """
        Return an opaque token representing the id of this graph's memory pool.
        """
        return super().pool()


class graph:
    """
    Context manager that captures MTIA work issued in its body into a graph.

    On entry it calls ``torch.mtia.synchronize()`` and
    ``torch.mtia.empty_cache()``, enters a ``torch.mtia.stream`` context for
    the capture stream, and calls ``mtia_graph.capture_begin``. On exit it
    calls ``mtia_graph.capture_end`` and leaves the stream context. The stream
    context is left even when beginning or ending the capture fails.

    Args:
        mtia_graph: the :class:`MTIAGraph` to capture into.
        pool: optional memory-pool handle forwarded to ``capture_begin``;
            ``(0, 0)`` is forwarded when it is omitted.
        stream: stream to capture on. Defaults to
            ``graph.default_capture_stream``.
    """

    # Shared fallback capture stream. It is filled in from
    # ``torch.mtia.current_stream()`` the first time a ``graph`` is
    # constructed while it is still ``None``, and reused afterwards.
    # NOTE(review): because this is cached on the class, later captures that
    # run while a different stream/device is current will still default to
    # the first one -- pass ``stream`` explicitly if that matters; confirm
    # this is intended.
    default_capture_stream: Optional[torch.mtia.Stream] = None

    def __init__(
        self,
        mtia_graph: MTIAGraph,
        pool: Optional[_POOL_HANDLE] = None,
        stream: Optional[torch.mtia.Stream] = None,
    ):
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.mtia.current_stream()

        # Empty tuple means "no pool supplied"; see __enter__.
        self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
            () if pool is None else (pool,)
        )
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        assert self.capture_stream is not None  # narrows the Optional type
        self.stream_ctx = torch.mtia.stream(self.capture_stream)
        self.mtia_graph = mtia_graph

    def __enter__(self) -> None:
        # Finish outstanding device work and release cached memory before
        # the capture starts.
        torch.mtia.synchronize()
        torch.mtia.empty_cache()

        self.stream_ctx.__enter__()
        pool_arg = self.pool[0] if self.pool else (0, 0)
        try:
            self.mtia_graph.capture_begin(pool_arg)
        except BaseException as exc:
            # A ``with`` statement does not call __exit__ when __enter__
            # raises, so undo the stream switch here instead of leaving the
            # capture stream current for the caller.
            self.stream_ctx.__exit__(type(exc), exc, exc.__traceback__)
            raise

    def __exit__(self, *args: object) -> None:
        try:
            self.mtia_graph.capture_end()
        finally:
            # Leave the stream context even if ending the capture fails.
            self.stream_ctx.__exit__(*args)
        # Returning None lets exceptions from the body, capture_end, or the
        # stream context propagate to the caller.


__all__ = [
    "MTIAGraph",
    "graph",
    "graph_pool_handle",
]