#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) #include #include #include #include #include #include #include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef USE_FBGEMM_GENAI #include #endif #ifndef AT_PER_OPERATOR_HEADERS #include #include #else #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #endif using at::blas::ScalingType; using at::blas::SwizzleType; namespace at::cuda::scaled { static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) { #ifdef USE_ROCM static const std::vector archs = { "gfx942", #if ROCM_VERSION >= 60300 "gfx1200", "gfx1201", #endif #if ROCM_VERSION >= 60500 "gfx950" #endif }; return at::detail::getCUDAHooks().isGPUArch(archs); #else auto dprops = at::cuda::getCurrentDeviceProperties(); if (sm90_only || sm100_only) { return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10); } else { return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); } #endif } #ifdef USE_ROCM static bool _scaled_mm_is_fnuz() { return at::detail::getCUDAHooks().isGPUArch({"gfx942"}); } #endif /** * Track concrete implementations available */ enum class ScaledGemmImplementation { NONE = 0, TENSORWISE_TENSORWISE = 1, ROWWISE_ROWWISE = 2, BLOCK_128x128_1x128 = 3, BLOCK_1x128_128x128 = 4, BLOCK_1x128_1x128 = 5, MXFP8_MXFP8 = 6, NVFP4_NVFP4 = 7, NVFP4_NVFP4_SINGLE_SCALE = 8, MXFP4_MXFP4 = 9, }; /** * Convert passed int (enum) from python back into a * strictly-typed enum */ template std::vector convert_int_to_enum(ArrayType& v) { std::vector converted; converted.reserve(v.size()); for (auto vi : v) { converted.push_back(static_cast(vi)); } return converted; } bool check_tensorwise_recipe(c10::ScalarType, std::vector&, ArrayRef&, c10::ScalarType, std::vector&, ArrayRef&); bool check_rowwise_recipe(c10::ScalarType, std::vector&, ArrayRef&, c10::ScalarType, std::vector&, ArrayRef&); bool check_nvfp4_recipe(c10::ScalarType, std::vector&, ArrayRef&, c10::ScalarType, std::vector&, ArrayRef&); bool check_nvfp4_recipe_single_scale (c10::ScalarType, std::vector&, ArrayRef&, c10::ScalarType, std::vector&, ArrayRef&); bool check_deepseek_recipe(ScalingType, ScalingType, c10::ScalarType, std::vector&, ArrayRef&, c10::ScalarType, std::vector&, ArrayRef&); bool check_mxfp8_recipe(c10::ScalarType, std::vector&, ArrayRef&, c10::ScalarType, std::vector&, ArrayRef&); bool check_mxfp4_recipe(c10::ScalarType, std::vector&, ArrayRef&, c10::ScalarType, std::vector&, ArrayRef&); } // namespace at::native::cuda::blas::scaled #else #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)