You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

44 lines
1.5 KiB

import torch
# Detect whether this PyTorch build exposes the native green-context binding.
# When it does not, fall back to a plain ``object`` base so the module (and the
# ``GreenContext`` shim defined from it) can still be imported.
SUPPORTED = hasattr(torch._C, "_CUDAGreenContext")
_GreenContext = object
if SUPPORTED:
    _GreenContext = torch._C._CUDAGreenContext  # type: ignore[misc]
# Python shim helps Sphinx process docstrings more reliably.
# pyrefly: ignore [invalid-inheritance]
class GreenContext(_GreenContext):
    r"""Wrapper around a CUDA green context.

    .. warning::
       This API is in beta and may change in future releases.
    """

    @staticmethod
    def create(num_sms: int, device_id: int = 0) -> _GreenContext:
        r"""Create a CUDA green context.

        Arguments:
            num_sms (int): The number of SMs to use in the green context.
            device_id (int, optional): The device index of green context.
                Default: ``0``.

        Returns:
            The object produced by the native ``_GreenContext.create``.

        Raises:
            RuntimeError: If PyTorch was not built with Green Context support.
        """
        # Guard first: without support the base class is plain ``object``,
        # which has no ``create`` attribute to delegate to.
        if not SUPPORTED:
            raise RuntimeError("PyTorch was not built with Green Context support!")
        return _GreenContext.create(num_sms, device_id)  # type: ignore[attr-defined]

    # Note that these methods are bypassed at runtime (``create`` returns the
    # native ``_GreenContext`` object, per its annotation), but we define them
    # here for Sphinx documentation purposes.
    def set_context(self) -> None:  # pylint: disable=useless-parent-delegation
        r"""Make the green context the current context."""
        return super().set_context()  # type: ignore[misc]

    def pop_context(self) -> None:  # pylint: disable=useless-parent-delegation
        r"""Pop the green context and restore the previous context.

        Assumes the green context is the current context on the context stack.
        """
        return super().pop_context()  # type: ignore[misc]