You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.2 KiB
51 lines
1.2 KiB
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|
|
#pragma once
|
|
|
|
namespace at {
|
|
|
|
/// Seed/offset bundle describing a Philox generator state for XPU code.
///
/// The struct operates in one of two modes, selected by `captured_`:
///  - `captured_ == false`: `seed_.val` and `offset_.val` hold the numeric
///    seed and offset directly.
///  - `captured_ == true` (graph capture): `seed_.ptr` and `offset_.ptr` hold
///    pointers to the seed and offset, and `offset_intragraph_` carries an
///    additional offset that `at::xpu::philox::unpack` adds to the pointed-to
///    offset when it dereferences them.
struct PhiloxXpuState {
  /// Shared storage for either an immediate value or a pointer to one.
  /// Which member is meaningful is determined by `captured_`.
  union Payload {
    uint64_t val;
    int64_t* ptr;
  };

  /// Value-mode state with zero seed and zero offset.
  PhiloxXpuState() = default;

  /// Value-mode state: stores `seed` and `offset` directly.
  PhiloxXpuState(uint64_t seed, uint64_t offset)
      : seed_{seed}, offset_{offset} {}

  /// Graph-capture state: stores the pointers themselves (nothing is
  /// dereferenced here) together with the intragraph offset.
  PhiloxXpuState(
      int64_t* seed,
      int64_t* offset_extragraph,
      uint32_t offset_intragraph)
      : offset_intragraph_(offset_intragraph), captured_(true) {
    // A union can only be brace-initialized through its first member, so the
    // pointer members are assigned in the body.
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
  }

  Payload seed_{};
  Payload offset_{};
  uint32_t offset_intragraph_ = 0;
  bool captured_ = false;
};
|
|
|
|
namespace xpu::philox {
|
|
/// Resolves a PhiloxXpuState into a concrete `(seed, offset)` pair.
///
/// @param arg State to resolve (taken by value).
/// @return
///  - When `arg.captured_` is false: `(arg.seed_.val, arg.offset_.val)`.
///  - When `arg.captured_` is true: the value read through `arg.seed_.ptr`,
///    and the value read through `arg.offset_.ptr` plus
///    `arg.offset_intragraph_`.
///
/// In the captured case both pointers are dereferenced without a null check,
/// so they must point to readable memory at the time of the call.
inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxXpuState arg) {
  if (!arg.captured_) {
    return std::make_tuple(arg.seed_.val, arg.offset_.val);
  }

  const auto seed = static_cast<uint64_t>(*arg.seed_.ptr);
  // Convert to uint64_t *before* adding: the stored offset is an int64_t, and
  // adding in the signed domain would be undefined behaviour on overflow.
  // Unsigned addition wraps and yields the same value in every non-overflowing
  // case.
  const auto offset =
      static_cast<uint64_t>(*arg.offset_.ptr) + arg.offset_intragraph_;
  return std::make_tuple(seed, offset);
}
|
|
|
|
} // namespace xpu::philox
|
|
} // namespace at
|
|
|
|
#else
|
|
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
|
|
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|