import torch


# Fall back to ``object`` so that the ``GreenContext`` class statement below
# still succeeds on builds where the native binding is absent.
_GreenContext = object
SUPPORTED = False

if hasattr(torch._C, "_CUDAGreenContext"):
    _GreenContext = torch._C._CUDAGreenContext  # type: ignore[misc]
    SUPPORTED = True


def _check_supported() -> None:
    """Raise ``RuntimeError`` when the native green context binding is absent."""
    if not SUPPORTED:
        raise RuntimeError("PyTorch was not built with Green Context support!")


# Python shim helps Sphinx process docstrings more reliably.
# pyrefly: ignore [invalid-inheritance]
class GreenContext(_GreenContext):
    r"""Wrapper around a CUDA green context.

    .. warning::
       This API is in beta and may change in future releases.
    """

    @staticmethod
    def create(num_sms: int, device_id: int = 0) -> _GreenContext:
        r"""Create a CUDA green context.

        Arguments:
            num_sms (int): The number of SMs to use in the green context.
            device_id (int, optional): The device index of green context.

        Returns:
            The object produced by the underlying ``_CUDAGreenContext.create``
            binding.

        Raises:
            RuntimeError: If PyTorch was not built with green context support.
        """
        _check_supported()
        return _GreenContext.create(num_sms, device_id)  # type: ignore[attr-defined]

    # Note that these functions are bypassed, but we define them here
    # for Sphinx documentation purposes.
    # NOTE(review): presumably they are bypassed because ``create`` returns the
    # object built by the native binding rather than an instance of this
    # subclass -- confirm against the C++ binding.
    def set_context(self) -> None:
        r"""Make the green context the current context.

        Raises:
            RuntimeError: If PyTorch was not built with green context support.
        """
        # Without the binding the parent class is ``object``, so delegating
        # would fail with an opaque AttributeError; report the real cause.
        _check_supported()
        return super().set_context()  # type: ignore[misc]

    def pop_context(self) -> None:
        r"""Assuming the green context is the current context, pop it from the
        context stack and restore the previous context.

        Raises:
            RuntimeError: If PyTorch was not built with green context support.
        """
        # Same guard as ``set_context``: keep the error consistent with
        # ``create`` on builds that lack the binding.
        _check_supported()
        return super().pop_context()  # type: ignore[misc]