You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
1.2 KiB
30 lines
1.2 KiB
|
6 days ago
|
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)

#pragma once

#include <ATen/ATen.h>

namespace at::caching {

// Some systems (just cudagraphs currently) will persist a static tensor output
// whose TensorImpl does not change across iterations. For these tensors caching
// dtype conversions is invalid. Additionally, there will be an extra reference
// count to these cached tensors that would prevent buffer inplacing and other
// checks on tensor uniqueness. If we are not using these systems the enabled
// flag will be false and we will avoid the hash lookup.

// Returns true if `t` has been registered as a persisted (cached) tensor.
TORCH_API bool is_cached_tensor(const at::Tensor& t);

// Registers `t` in the cached-tensor set.
TORCH_API void add_cached_tensor(const at::Tensor& t);

// Removes `t` from the cached-tensor set.
TORCH_API void remove_cached_tensor(const at::Tensor& t);

// Toggles cached-tensor tracking globally; when disabled, lookups above
// short-circuit and skip the hash lookup (see comment block above).
TORCH_API void set_cached_tensors_enabled(bool enable);

// For gradient buffer stealing we will adjust the use count of tensors
// which are persisted by cudagraphs, just as we need to adjust reference
// count of tensors with hooks.
TORCH_API size_t adjusted_use_count(const at::Tensor& t);

} // namespace at::caching

#else

#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."

#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)