#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) #include #include namespace at::native { enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; using fused_adam_fn = void (*)( const at::Tensor& param, const at::Tensor& grad, const at::Tensor& exp_avg, const at::Tensor& exp_avg_sq, const at::Tensor& max_exp_avg_sq, const at::Tensor& state_step, const double lr, const double beta1, const double beta2, const double weight_decay, const double eps, const bool amsgrad, const bool maximize, const float* grad_scale_ptr, const ADAM_MODE); DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub) } // namespace at::native #else #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)