You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

106 lines
2.6 KiB

# pylint: disable=useless-parent-delegation
from __future__ import annotations
from typing import Optional, Union
from typing_extensions import Self
import torch
# Type of a graph memory pool token: a pair of ints. It is the return type of
# graph_pool_handle() and MTIAGraph.pool(), and the argument type of
# MTIAGraph.capture_begin().
_POOL_HANDLE = tuple[int, int]
def graph_pool_handle() -> _POOL_HANDLE:
    """
    Fetch an opaque token that identifies a graph memory pool.

    The token is produced by ``torch._C._mtia_graphPoolHandle()`` and handed
    back to the caller unchanged.
    """
    handle = torch._C._mtia_graphPoolHandle()
    return handle
class MTIAGraph(torch._C._MTIAGraph):
    """
    Python-level subclass of ``torch._C._MTIAGraph``.

    Each method forwards straight to the base-class method of the same name;
    this class only contributes type annotations and docstrings.
    """

    def __new__(cls, keep_graph: bool = False) -> Self:
        """
        Create the graph object, passing ``keep_graph`` through to the base
        class constructor.
        """
        instance = super().__new__(cls, keep_graph)
        return instance

    def capture_begin(self, pool: _POOL_HANDLE) -> None:
        """
        Start capturing a MTIA graph, using ``pool`` as the memory pool token.
        """
        super().capture_begin(pool)

    def capture_end(self) -> None:
        """
        Finish the capture of a MTIA graph.
        """
        super().capture_end()

    def instantiate(self) -> None:
        """
        Instantiate the MTIA graph that was captured.
        """
        super().instantiate()

    def replay(self) -> None:
        """
        Replay the MTIA graph that was captured.
        """
        super().replay()

    def reset(self) -> None:
        """
        Destroy the captured graph and reset the states.
        """
        super().reset()

    def pool(self) -> _POOL_HANDLE:
        """
        Return the opaque token that identifies this graph's memory pool.
        """
        handle = super().pool()
        return handle
class graph:
    """
    Context manager that captures MTIA work into an ``MTIAGraph``.

    On entry it synchronizes the device, empties the cache, makes the capture
    stream current and calls ``capture_begin`` on the wrapped graph. On exit
    it calls ``capture_end`` and leaves the capture-stream context.

    Args:
        mtia_graph: Graph object that receives the capture.
        pool: Optional memory pool token (see ``graph_pool_handle`` and
            ``MTIAGraph.pool``). When omitted, ``(0, 0)`` is passed to
            ``capture_begin``.
        stream: Optional stream to capture on. When omitted, the class-wide
            ``default_capture_stream`` is used.
    """

    # Stream used when the caller passes no ``stream``. It is filled in lazily
    # from ``torch.mtia.current_stream()`` the first time an instance is built.
    default_capture_stream: Optional[torch.mtia.Stream] = None

    def __init__(
        self,
        mtia_graph: MTIAGraph,
        pool: Optional[_POOL_HANDLE] = None,
        stream: Optional[torch.mtia.Stream] = None,
    ) -> None:
        """Record the graph, pool token and capture stream for later use."""
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.mtia.current_stream()

        # Stored as a 0- or 1-tuple so that "no pool given" stays
        # distinguishable from any real pool token.
        self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
            () if pool is None else (pool,)
        )
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        # Narrows the Optional for type checkers; the default was set above.
        assert self.capture_stream is not None
        self.stream_ctx = torch.mtia.stream(self.capture_stream)
        self.mtia_graph = mtia_graph

    def __enter__(self) -> None:
        """
        Prepare the device, switch to the capture stream and begin capture.

        If beginning the capture fails, the stream context is exited again
        before the exception propagates, because Python does not call
        ``__exit__`` when ``__enter__`` raises.
        """
        torch.mtia.synchronize()
        torch.mtia.empty_cache()

        self.stream_ctx.__enter__()
        try:
            # (0, 0) is what this class passes when no pool was supplied;
            # presumably the binding's "no shared pool" value -- TODO confirm.
            pool_arg = self.pool[0] if self.pool else (0, 0)
            self.mtia_graph.capture_begin(pool_arg)
        except BaseException as exc:
            # Undo the stream switch so a failed capture_begin does not leave
            # the capture stream current.
            self.stream_ctx.__exit__(type(exc), exc, exc.__traceback__)
            raise

    def __exit__(self, *args: object) -> None:
        """
        End the capture and leave the capture-stream context.

        The stream context is exited even when ``capture_end`` raises; the
        exception from ``capture_end`` still propagates to the caller.
        """
        try:
            self.mtia_graph.capture_end()
        finally:
            self.stream_ctx.__exit__(*args)
# Names exported by ``from <this module> import *``.
__all__ = [
"MTIAGraph",
"graph",
"graph_pool_handle",
]