You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
82 lines
2.9 KiB
82 lines
2.9 KiB
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
|
#pragma once
|
|
|
|
#include <c10/cuda/CUDAStream.h>
|
|
#include <iostream>
|
|
#include <utility>
|
|
|
|
// CUDA Graphs utils used by c10 and aten.
|
|
// aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only.
|
|
|
|
namespace c10::cuda {
|
|
|
|
// RAII guard for "cudaStreamCaptureMode", a thread-local value
|
|
// that controls the error-checking strictness of a capture.
|
|
struct C10_CUDA_API CUDAStreamCaptureModeGuard {
|
|
CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired)
|
|
: strictness_(desired) {
|
|
C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
|
|
}
|
|
CUDAStreamCaptureModeGuard(const CUDAStreamCaptureModeGuard&) = delete;
|
|
CUDAStreamCaptureModeGuard(CUDAStreamCaptureModeGuard&&) = delete;
|
|
CUDAStreamCaptureModeGuard& operator=(const CUDAStreamCaptureModeGuard&) =
|
|
delete;
|
|
CUDAStreamCaptureModeGuard& operator=(CUDAStreamCaptureModeGuard&&) = delete;
|
|
~CUDAStreamCaptureModeGuard() {
|
|
C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
|
|
}
|
|
|
|
private:
|
|
cudaStreamCaptureMode strictness_;
|
|
};
|
|
|
|
// Protects against enum cudaStreamCaptureStatus implementation changes.
|
|
// Some compilers seem not to like static_assert without the messages.
|
|
static_assert(
|
|
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
|
|
"unexpected int(cudaStreamCaptureStatusNone) value");
|
|
static_assert(
|
|
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
|
|
"unexpected int(cudaStreamCaptureStatusActive) value");
|
|
static_assert(
|
|
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
|
|
"unexpected int(cudaStreamCaptureStatusInvalidated) value");
|
|
|
|
enum class CaptureStatus : int {
|
|
None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
|
|
Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
|
|
Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
|
|
};
|
|
|
|
inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
|
|
switch (status) {
|
|
case CaptureStatus::None:
|
|
os << "cudaStreamCaptureStatusNone";
|
|
break;
|
|
case CaptureStatus::Active:
|
|
os << "cudaStreamCaptureStatusActive";
|
|
break;
|
|
case CaptureStatus::Invalidated:
|
|
os << "cudaStreamCaptureStatusInvalidated";
|
|
break;
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(
|
|
false, "Unknown CUDA graph CaptureStatus", int(status));
|
|
}
|
|
return os;
|
|
}
|
|
|
|
// Use this version where you're sure a CUDA context exists already.
|
|
inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
|
|
cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone};
|
|
C10_CUDA_CHECK(
|
|
cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
|
|
return CaptureStatus(is_capturing);
|
|
}
|
|
|
|
} // namespace c10::cuda
|
|
|
|
#else
|
|
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
|
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|