You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

51 lines
1.6 KiB

#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
#pragma once
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDACachingAllocator.h>
namespace at::cuda {
// Kept for backward compatibility only: these using-declarations make the
// c10 ID types reachable as at::cuda::CaptureId_t and at::cuda::MempoolId_t.
using c10::CaptureId_t;
using c10::MempoolId_t;
// MemPool represents a pool of memory in a caching allocator. Currently,
// it's just the ID of the pool object maintained in the CUDACachingAllocator.
//
// An allocator pointer can be passed to the MemPool to define how the
// allocations should be done in the pool. For example: using a different
// system allocator such as ncclMemAlloc.
struct TORCH_CUDA_CPP_API MemPool {
MemPool(
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
bool is_user_created = true,
bool use_on_oom = false,
bool no_split = false);
MemPool(const MemPool&) = delete;
MemPool(MemPool&&) = default;
MemPool& operator=(const MemPool&) = delete;
MemPool& operator=(MemPool&&) = default;
~MemPool();
MempoolId_t id();
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator();
int use_count();
c10::DeviceIndex device();
static MempoolId_t graph_pool_handle(bool is_user_created = true);
private:
static std::atomic<CaptureId_t> uid_;
static std::atomic<CaptureId_t> uuid_;
c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator_;
bool is_user_created_;
MempoolId_t id_;
c10::DeviceIndex device_;
};
} // namespace at::cuda
#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)