You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
44 lines
1.5 KiB
44 lines
1.5 KiB
import torch
|
|
|
|
|
|
# Detect whether this PyTorch build exposes the CUDA green-context binding.
# When it does not, fall back to ``object`` so the ``GreenContext`` class below
# can still be defined (and documented); ``GreenContext.create`` then refuses
# to run by checking ``SUPPORTED``.
SUPPORTED = hasattr(torch._C, "_CUDAGreenContext")
_GreenContext = object
if SUPPORTED:
    _GreenContext = torch._C._CUDAGreenContext  # type: ignore[misc]
|
|
|
|
|
|
# Python shim helps Sphinx process docstrings more reliably.
|
|
# pyrefly: ignore [invalid-inheritance]
|
|
class GreenContext(_GreenContext):
    r"""Wrapper around a CUDA green context.

    .. warning::

        This API is in beta and may change in future releases.
    """

    @staticmethod
    def create(num_sms: int, device_id: int = 0) -> _GreenContext:
        r"""Create a CUDA green context.

        Arguments:
            num_sms (int): The number of SMs to use in the green context.
            device_id (int, optional): The index of the CUDA device on which
                to create the green context. Default: ``0``.

        Returns:
            The green context object produced by the underlying
            ``torch._C._CUDAGreenContext.create`` binding.

        Raises:
            RuntimeError: If PyTorch was not built with green context support.
        """
        # Guard explicitly: on unsupported builds ``_GreenContext`` is plain
        # ``object``, which has no ``create`` attribute to call.
        if not SUPPORTED:
            raise RuntimeError("PyTorch was not built with Green Context support!")
        return _GreenContext.create(num_sms, device_id)  # type: ignore[attr-defined]

    # ``create`` builds its result through the base binding class, so at
    # runtime these methods are presumably bypassed in favour of the binding's
    # own implementations. They are defined here only so Sphinx can document
    # them, and they simply delegate to the parent class.
    def set_context(self) -> None:  # pylint: disable=useless-parent-delegation
        r"""Make the green context the current context."""
        return super().set_context()  # type: ignore[misc]

    def pop_context(self) -> None:  # pylint: disable=useless-parent-delegation
        r"""Pop the green context and restore the previous context.

        This assumes the green context is the current context, i.e. it is on
        top of the context stack.
        """
        return super().pop_context()  # type: ignore[misc]
|