You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

82 lines
3.1 KiB

#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
#pragma once
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/dlpack.h>
// this converter will:
// 1) take a Tensor object and wrap it in the DLPack tensor
// 2) take a dlpack tensor and convert it to the ATen Tensor
namespace at {
// Translates a DLPack data-type descriptor (`DLDataType`) into an ATen
// `ScalarType`.
// NOTE(review): how dtypes without an ATen equivalent are handled is decided
// in the implementation, which is not visible from this header -- confirm
// before relying on a particular failure mode.
TORCH_API ScalarType toScalarType(const DLDataType& dtype);
// Wraps the tensor `src` in a DLPack tensor and returns a raw pointer to it.
// NOTE(review): ownership of the returned pointer is not expressed in the
// signature -- presumably it is released through the DLPack deleter; confirm
// against the implementation.
TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
// Counterpart of `toDLPack` that returns a `DLManagedTensorVersioned*`.
// NOTE(review): the same ownership caveat presumably applies -- confirm.
TORCH_API struct DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src);
// Takes a caller-supplied `DLTensor* out` and returns nothing.
// NOTE(review): presumably fills `out` with a description of `src` without
// tying the lifetime of `src` to it (per the "NonOwning" name) -- confirm.
TORCH_API void toDLPackNonOwning(const Tensor& src, DLTensor* out);
// Converts a DLPack tensor (`DLManagedTensor`) into an ATen Tensor.
//   src     - the DLPack tensor to convert.
//   deleter - optional callback taking a `void*`; defaults to an empty
//             `std::function`.
// NOTE(review): presumably `deleter` is invoked when the resulting tensor's
// data is released, and an empty one falls back to a default cleanup of
// `src` -- confirm against the implementation.
TORCH_API Tensor
fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter = {});
// Counterpart of `fromDLPack` for `DLManagedTensorVersioned`; takes the same
// optional `deleter` argument with the same default.
TORCH_API Tensor fromDLPackVersioned(
DLManagedTensorVersioned* src,
std::function<void(void*)> deleter = {});
// Returns the DLPack data-type descriptor (`DLDataType`) for the tensor `t`.
TORCH_API DLDataType getDLDataType(const Tensor& t);
// Returns the DLPack device descriptor (`DLDevice`) for `tensor`, given a
// `device_id`.
// NOTE(review): `device_id` is taken by const reference although it is a
// plain `int64_t`; changing that would alter the exported signature, so it is
// left as-is here.
TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
// Copies the Tensor if there's a device mismatch or copy is forced.
// This should be used before actually creating the DLPack capsule.
//   data               - the tensor that may need to be copied.
//   optional_dl_device - optional DLPack device to compare against.
//   copy               - optional flag used to force the copy.
// NOTE(review): what an empty `optional_dl_device` or an empty / false `copy`
// means (e.g. "never copy" vs. "copy only if required") is decided in the
// implementation -- confirm before relying on it.
TORCH_API Tensor maybeCopyTensor(
const Tensor& data,
std::optional<DLDevice> optional_dl_device,
std::optional<bool> copy);
// Converts the given at::Device into a DLDevice.
// The device is taken by value.
TORCH_API DLDevice torchDeviceToDLDevice(at::Device device);
// Converts the DLDevice to an ATen device.
// The DLDevice is passed as its two components, `type` and `index`.
//   data - optional raw pointer, defaults to nullptr.
// NOTE(review): presumably `data` is the tensor's data pointer, used to
// disambiguate the device for some device types -- confirm against the
// implementation.
TORCH_API Device dlDeviceToTorchDevice(
DLDeviceType type,
c10::DeviceIndex index,
void* data = nullptr);
// This trait class is used for retrieving different attributes, such as the
// PyCapsule names and conversion functions for both DLPack tensor classes:
// `DLManagedTensor` and `DLManagedTensorVersioned`.
//
// Each specialization should contain the following 4 traits:
// - `capsule`: actual name of the capsule
// - `used`: name of the capsule after using it
// - `toDLPack`: function for converting a tensor into a DLPack capsule
// - `fromDLPack`: function for creating a tensor from a DLPack capsule
//
// While `toDLPack` is directly exposed to Python, `fromDLPack` is not.
// Although it contains the core implementation, it lacks the required
// bookkeeping logic contained in its caller `tensor_fromDLPack`.
//
// That said, `fromDLPack` is used directly in a few DLPack tests that live
// inside ATen (no Python available).
// Primary template: intentionally empty. Only the explicit specializations
// for the DLPack tensor classes provide the traits.
template <typename T>
struct DLPackTraits {};
// Traits for the unversioned DLPack tensor class, `DLManagedTensor`.
template <>
struct DLPackTraits<DLManagedTensor> {
  // Actual name of the capsule.
  static constexpr const char* capsule = "dltensor";
  // Name of the capsule after using it.
  static constexpr const char* used = "used_dltensor";

  // Function converting a tensor into a `DLManagedTensor`.
  static inline auto toDLPack = &at::toDLPack;
  // Function creating a tensor from a `DLManagedTensor`.
  static inline auto fromDLPack = &at::fromDLPack;
};
// Traits for the versioned DLPack tensor class, `DLManagedTensorVersioned`.
template <>
struct DLPackTraits<DLManagedTensorVersioned> {
  // Actual name of the capsule.
  static constexpr const char* capsule = "dltensor_versioned";
  // Name of the capsule after using it.
  static constexpr const char* used = "used_dltensor_versioned";

  // Function converting a tensor into a `DLManagedTensorVersioned`.
  static inline auto toDLPack = &at::toDLPackVersioned;
  // Function creating a tensor from a `DLManagedTensorVersioned`.
  static inline auto fromDLPack = &at::fromDLPackVersioned;
};
} // namespace at
#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)