You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
106 lines
2.6 KiB
106 lines
2.6 KiB
# pylint: disable=useless-parent-delegation
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional, Union
|
|
from typing_extensions import Self
|
|
|
|
import torch
|
|
|
|
|
|
# Opaque (int, int) pair identifying a graph memory pool. Returned by
# graph_pool_handle() and MTIAGraph.pool(); accepted by MTIAGraph.capture_begin()
# and by the ``pool`` argument of the ``graph`` context manager.
_POOL_HANDLE = tuple[int, int]
|
|
|
|
|
|
def graph_pool_handle() -> _POOL_HANDLE:
    """Return an opaque token that identifies an MTIA graph memory pool.

    The token has the same type as the ``pool`` argument accepted by
    :class:`graph` and :meth:`MTIAGraph.capture_begin`. It is produced by the
    native ``torch._C._mtia_graphPoolHandle`` binding and is meant to be
    passed around, not interpreted.
    """
    native_handle = torch._C._mtia_graphPoolHandle()
    return native_handle
|
|
|
|
|
|
class MTIAGraph(torch._C._MTIAGraph):
|
|
"""
|
|
Wrapper around a MTIA graph.
|
|
"""
|
|
|
|
def __new__(cls, keep_graph: bool = False) -> Self:
|
|
return super().__new__(cls, keep_graph)
|
|
|
|
def capture_begin(self, pool: _POOL_HANDLE) -> None:
|
|
"""
|
|
Begin capturing a MTIA graph.
|
|
"""
|
|
super().capture_begin(pool)
|
|
|
|
def capture_end(self) -> None:
|
|
"""
|
|
End the capture of a MTIA graph.
|
|
"""
|
|
super().capture_end()
|
|
|
|
def instantiate(self) -> None:
|
|
"""
|
|
Instantiate the captured MTIA graph.
|
|
"""
|
|
super().instantiate()
|
|
|
|
def replay(self) -> None:
|
|
"""
|
|
Replay the captured MTIA graph.
|
|
"""
|
|
super().replay()
|
|
|
|
def reset(self) -> None:
|
|
"""
|
|
Destroy the captured graph and reset the states.
|
|
"""
|
|
super().reset()
|
|
|
|
def pool(self) -> _POOL_HANDLE:
|
|
"""
|
|
Return an opaque token representing the id of this graph's memory pool
|
|
"""
|
|
return super().pool()
|
|
|
|
|
|
class graph:
    """Context manager that captures MTIA work into an :class:`MTIAGraph`.

    On entry the device is synchronized, the cache is emptied, the stream
    context for ``capture_stream`` is entered and
    ``mtia_graph.capture_begin`` is called. On exit ``capture_end`` is called
    and the stream context is exited. The stream context is exited even when
    ``capture_begin`` or ``capture_end`` raises.

    Args:
        mtia_graph: Graph object that records the captured work.
        pool: Optional memory-pool handle (for example from
            :func:`graph_pool_handle` or :meth:`MTIAGraph.pool`). When
            omitted, ``(0, 0)`` is passed to ``capture_begin``.
        stream: Stream to capture on. Defaults to the class-wide
            ``default_capture_stream``.
    """

    # Shared by all instances of the class; filled in lazily by __init__ with
    # torch.mtia.current_stream() the first time it is found to be None.
    default_capture_stream: Optional[torch.mtia.Stream] = None

    def __init__(
        self,
        mtia_graph: MTIAGraph,
        pool: Optional[_POOL_HANDLE] = None,
        stream: Optional[torch.mtia.Stream] = None,
    ):
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.mtia.current_stream()

        # Empty tuple when no pool was given; __enter__ then falls back to (0, 0).
        self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
            () if pool is None else (pool,)
        )
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        # Narrows the Optional type; the class default was populated above.
        assert self.capture_stream is not None
        self.stream_ctx = torch.mtia.stream(self.capture_stream)
        self.mtia_graph = mtia_graph

    def __enter__(self) -> None:
        """Synchronize, switch to the capture stream and begin capturing."""
        torch.mtia.synchronize()
        torch.mtia.empty_cache()

        self.stream_ctx.__enter__()

        pool_arg = self.pool[0] if self.pool else (0, 0)
        try:
            self.mtia_graph.capture_begin(pool_arg)
        except BaseException as exc:
            # Python does not call __exit__ when __enter__ raises, so unwind
            # the stream context here instead of leaving it entered.
            self.stream_ctx.__exit__(type(exc), exc, exc.__traceback__)
            raise

    def __exit__(self, *args: object) -> None:
        """End the capture and leave the capture stream context."""
        try:
            self.mtia_graph.capture_end()
        finally:
            # Always exit the stream context, even if capture_end() fails.
            self.stream_ctx.__exit__(*args)
|
|
|
|
|
|
# Public names exported by ``from <module> import *``.
__all__ = ["MTIAGraph", "graph", "graph_pool_handle"]
|