// Declarations shared by the oneDNN scaled-GEMM code in the
// at::native::onednn::scaled namespace. The whole header sits inside the
// guard below; if either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is
// defined, the #else branch fires an #error instead.
//
// NOTE(review): in this copy every #include has lost its target, and the
// template parameter/argument lists (`<...>`) are missing from
// convert_int_to_enum and from the check_*_recipe declarations, so the
// header cannot compile as-is. The angle-bracket contents cannot be
// reconstructed reliably from this text; restore them from the upstream file.
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
#include
#include
#include
#include
#include
#include
// Defined before the ATen includes that follow so they are processed with it
// set; presumably this restricts which operator declarations those headers
// expose — confirm against ATen's header-macro conventions.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
// Extra dependency that is only pulled in when building with FBGEMM GenAI.
#ifdef USE_FBGEMM_GENAI
#include
#endif
// Two alternative include sets: a short list when AT_PER_OPERATOR_HEADERS is
// not defined, and a longer list of individual headers (presumably one per
// operator) when it is. TODO confirm which headers each branch names.
#ifndef AT_PER_OPERATOR_HEADERS
#include
#include
#else
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#endif

// NOTE(review): this using-declaration is at global scope in a header, so the
// name `ScalingType` becomes visible in every file that includes this one.
// Consider moving it inside the namespace below.
using at::blas::ScalingType;

namespace at::native::onednn::scaled {

/**
 * Concrete scaled-GEMM implementations that are available, named by the
 * scaling recipe used for each of the two operands. NONE (0) presumably
 * means "no supported implementation matched" — confirm at the call sites.
 */
enum class ScaledGemmImplementation {
  NONE = 0,
  TENSORWISE_TENSORWISE = 1,
  ROWWISE_ROWWISE = 2,
};

/**
 * Convert integer values (per the original comment, enum values passed in
 * from Python) back into a strictly-typed enum.
 *
 * @param v  Container of integral values; it must provide size() and support
 *           range-for iteration. It is taken by non-const reference but is
 *           only read.
 * @return   A new vector with one static_cast'ed copy of each element of v,
 *           in the same order. No range checking is performed on the values.
 */
template std::vector convert_int_to_enum(ArrayType& v) {
  std::vector converted;
  // Exactly one output entry per input element, so allocate once up front.
  converted.reserve(v.size());
  for (auto vi : v) {
    converted.push_back(static_cast(vi));
  }
  return converted;
}

/**
 * Recipe check for the tensorwise scaling path; defined elsewhere.
 *
 * Takes two identical (c10::ScalarType, std::vector&, ArrayRef&) triples,
 * presumably one per GEMM operand (a, then b): dtype, scaling recipe, and
 * (presumably) scale tensors. Presumably returns true when both operands
 * match the tensorwise recipe — TODO confirm against the definition.
 * Parameters are unnamed here, and the containers are passed by non-const
 * reference.
 */
bool check_tensorwise_recipe(
    c10::ScalarType,
    std::vector&,
    ArrayRef&,
    c10::ScalarType,
    std::vector&,
    ArrayRef&);

/**
 * Recipe check for the rowwise scaling path; defined elsewhere.
 *
 * Has the same parameter layout as check_tensorwise_recipe: two
 * (c10::ScalarType, std::vector&, ArrayRef&) triples, presumably one per GEMM
 * operand. Presumably returns true when both operands match the rowwise
 * recipe — TODO confirm against the definition.
 */
bool check_rowwise_recipe(
    c10::ScalarType,
    std::vector&,
    ArrayRef&,
    c10::ScalarType,
    std::vector&,
    ArrayRef&);

} // namespace at::native::onednn::scaled
#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)