You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
889 lines
34 KiB
889 lines
34 KiB
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from ..core_model_loading import ConversionOps
|
|
from ..quantizers.quantizers_utils import should_convert_module
|
|
from ..utils import is_kernels_available, is_torch_accelerator_available, is_torch_available, logging
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
import torch.nn as nn
|
|
import triton
|
|
import triton.language as tl
|
|
from torch.nn import functional as F
|
|
|
|
|
|
logger = logging.get_logger(__name__)


# Cache for the CUTLASS quantization kernel, lazily populated by
# `_get_quantization_kernel`: None = not attempted yet, False = load failed.
_quantization_kernel = None
|
|
|
|
|
|
def _get_quantization_kernel():
    """Lazily load the CUTLASS quantization kernel from HuggingFace Hub.

    The hub fetch is attempted at most once per process: the module-level
    ``_quantization_kernel`` caches either the loaded kernel object or
    ``False`` when loading failed, so subsequent calls are cheap.

    Returns:
        The kernel object when available, otherwise ``None``.
    """
    global _quantization_kernel
    if _quantization_kernel is not None:
        # Cache hit: normalize the "load failed" sentinel (False) to None.
        return _quantization_kernel or None
    try:
        from .hub_kernels import get_kernel

        _quantization_kernel = get_kernel("RedHatAI/quantization")
    except Exception as e:
        logger.warning_once(f"Failed to load CUTLASS quantization kernel: {e}. Falling back to Triton.")
        _quantization_kernel = False  # Mark as unavailable
    return _quantization_kernel or None
|
|
|
|
|
|
def _supports_cutlass(block_size: list[int] | None, output_dtype: torch.dtype) -> bool:
    """
    Check if CUTLASS blockwise FP8 matmul is supported for the given block size and output dtype.

    CUTLASS blockwise kernels require:
    - SM90+ (Hopper or newer)
    - Block size [128, 128] for weights
    - Block size [1, 128] for activations (handled implicitly)
    - Output dtype bfloat16 or float16
    """

    # Environment gates: torch, a CUDA device, and the `kernels` package.
    environment_ok = is_torch_available() and torch.cuda.is_available() and is_kernels_available()
    if not environment_ok:
        return False

    # CUTLASS only supports bfloat16/float16 output
    if output_dtype not in (torch.bfloat16, torch.float16):
        return False

    # Block size compatibility: CUTLASS only supports exactly [128, 128]
    # (the tuple comparison also rejects any length other than 2).
    if block_size is None or tuple(block_size) != (128, 128):
        return False

    # GPU capability encoded as major*10 + minor (e.g. SM90 -> 90).
    major, minor = torch.cuda.get_device_capability()
    cuda_capability = major * 10 + minor

    # The kernel itself decides whether blockwise FP8 works on this arch.
    kernel = _get_quantization_kernel()
    if kernel is None:
        return False

    try:
        return kernel.cutlass_scaled_mm_supports_block_fp8(cuda_capability)
    except Exception:
        return False
|
|
|
|
|
|
# FP8 (e4m3) numeric limits. Older torch builds may lack the dtype entirely,
# in which case we fall back to the finite range of e4m3fn, which is +/-448.
try:
    _FP8_DTYPE = torch.float8_e4m3fn
    _FP8_MIN = torch.finfo(_FP8_DTYPE).min
    _FP8_MAX = torch.finfo(_FP8_DTYPE).max
except AttributeError:
    _FP8_DTYPE = None
    _FP8_MIN, _FP8_MAX = -448, 448
    logger.warning_once("torch.float8_e4m3fn not available")
|
|
|
|
|
|
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """Dynamically quantize one group of ``BLOCK_SIZE`` activations to FP8.

    Each program instance handles one contiguous group of ``x``: it derives
    the group scale from the max magnitude (448.0 is the e4m3 max), writes
    the scaled values to ``y`` and the scale to ``s[pid]``.
    """
    pid = tl.program_id(axis=0)
    # Element offsets of this program's group.
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    # Scale so the group's max magnitude maps to the FP8 e4m3 max (448).
    s = tl.max(tl.abs(x)) / 448.0
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    # One scale per group.
    tl.store(s_ptr + pid, s)
|
|
|
|
|
|
def act_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize activations to FP8 (e4m3) with per-group dynamic scales.

    Args:
        x: Input tensor; must be contiguous and its last dimension must be
            divisible by ``block_size``.
        block_size: Number of consecutive elements that share one scale.

    Returns:
        ``(y, s)`` where ``y`` is the FP8 tensor (same shape as ``x``) and
        ``s`` holds float32 scales of shape
        ``(*x.shape[:-1], x.shape[-1] // block_size)``.
    """
    assert x.is_contiguous()
    assert x.shape[-1] % block_size == 0
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)

    def grid(meta):
        # One Triton program per quantization group.
        return (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)

    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s
|
|
|
|
|
|
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and
    store the result in output tensor `C`.

    `As` holds one scale per (row, k-group of size `group_k`) of `A`;
    `Bs` holds one scale per (`group_n` x `group_k`) tile of `B`. Scales are
    applied inside the K loop so each partial dot product is rescaled with
    the scales of the k-group it came from.
    """

    pid = tl.program_id(axis=0)
    # Grouped ("swizzled") program ordering: consecutive program ids cover
    # GROUP_SIZE_M tile rows before moving to the next tile column.
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/col offsets are wrapped with % so edge tiles read valid memory;
    # the final store is masked instead.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Per-row activation scale pointers and per-column weight scale pointers.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask out-of-range K elements of the last partial tile.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # Select the scales of the k-group this K tile belongs to.
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        # Rescale the partial product before accumulating in float32.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Masked store of the output tile.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
|
|
@triton.jit
def _w8a8_block_fp8_matmul_per_tensor(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with per-tensor quantization, and
    store the result in output tensor `C`.

    Unlike `_w8a8_block_fp8_matmul`, `As` and `Bs` are single scalars that
    apply to all of `A` and `B` respectively, so they are loaded once
    outside the K loop.
    """

    pid = tl.program_id(axis=0)
    # Grouped ("swizzled") program ordering, same scheme as the blockwise kernel.
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/col offsets wrapped with % so edge tiles read valid memory.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Single scalar scale for each operand.
    scale_a = tl.load(As)
    scale_b = tl.load(Bs)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask out-of-range K elements of the last partial tile.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        accumulator += tl.dot(a, b) * scale_a * scale_b
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Masked store of the output tile.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
|
|
def w8a8_block_fp8_matmul_triton(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise
    quantization.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should
        be 2-dim, e.g., [128, 128]. `None` means per-tensor quantization.
        output_dtype: The dtype of the returned tensor.
    Returns:
        torch.Tensor: The result of matmul.
    """
    if block_size is None:
        block_n, block_k = 128, 128
    else:
        assert len(block_size) == 2
        block_n, block_k = block_size[0], block_size[1]

    # if we have per-tensor quantization, we use 128x128 block size for tiled matmul multiplication
    if block_n == B.shape[-2] and block_k == B.shape[-1]:
        block_n = 128
        block_k = 128

    # The K (reduction) dimension of both operands must agree.
    assert A.shape[-1] == B.shape[-1]

    # Per-token-group activation scales: one scale per block_k-sized group.
    if As.numel() != 1:
        assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]

    # Collapse all leading dims of A into a single M dimension.
    M = A.numel() // A.shape[-1]

    N, K = B.shape
    assert B.ndim == 2 and B.is_contiguous()
    # Per-tile weight scales: one scale per (block_n x block_k) tile of B.
    if Bs.numel() != 1:
        assert Bs.ndim == 2
        assert triton.cdiv(N, block_n) == Bs.shape[0], f"{N}, {block_n}, {Bs.shape}"
        assert triton.cdiv(K, block_k) == Bs.shape[1], f"{K}, {block_k}, {Bs.shape}"

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    # Shrink the M tile for small batches (tl.dot needs at least 16).
    BLOCK_SIZE_M = 128
    if M < BLOCK_SIZE_M:
        BLOCK_SIZE_M = triton.next_power_of_2(M)
        BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
    BLOCK_SIZE_K = block_k
    assert block_k % BLOCK_SIZE_K == 0
    BLOCK_SIZE_N = block_n

    def grid(META):
        # One program per output tile, flattened over (M tiles) x (N tiles).
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)

    # Scalar scales on both operands -> per-tensor kernel; otherwise blockwise.
    if As.numel() == 1 and Bs.numel() == 1:
        _w8a8_block_fp8_matmul_per_tensor[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            block_n,
            block_k,
            A.stride(-2),
            A.stride(-1),
            # B is [N, K] row-major; passing swapped strides views it as [K, N].
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )
    else:
        _w8a8_block_fp8_matmul[grid](
            A,
            B,
            C,
            As,
            Bs,
            M,
            N,
            K,
            block_n,
            block_k,
            A.stride(-2),
            A.stride(-1),
            # B is [N, K] row-major; passing swapped strides views it as [K, N].
            B.stride(1),
            B.stride(0),
            C.stride(-2),
            C.stride(-1),
            As.stride(-2),
            As.stride(-1),
            # Bs is [N tiles, K tiles]; swapped strides view it as [K tiles, N tiles].
            Bs.stride(1),
            Bs.stride(0),
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            GROUP_SIZE_M=8,
        )

    return C
|
|
|
|
|
|
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Dispatch to CUTLASS or Triton for block-wise FP8 matmul.

    Uses CUTLASS when:
    - Block size is [128, 128] (the only size CUTLASS supports)
    - Running on SM90+ (Hopper or newer)
    - The CUTLASS kernel is available
    - Output dtype is bfloat16 or float16 (CUTLASS requirement)
    - Tensor dimensions are compatible (divisible by 16)

    Otherwise falls back to Triton.
    """

    if _supports_cutlass(block_size, output_dtype):
        kernel = _get_quantization_kernel()
        if kernel is not None:
            try:
                # CUTLASS expects:
                # - A: [M, K] row-major, float8_e4m3fn
                # - B: [K, N] column-major, float8_e4m3fn
                # - As: [M, K//128] M-major (activation scales)
                # - Bs: [K//128, N//128] K-major (weight scales)

                # Reshape A to 2D if needed
                original_shape = A.shape
                M = A.numel() // A.shape[-1]
                K = A.shape[-1]
                N = B.shape[0]

                # CUTLASS requires dimensions divisible by 16
                if K % 16 != 0 or N % 16 != 0:
                    raise ValueError(f"CUTLASS requires K ({K}) and N ({N}) divisible by 16")

                A_2d = A.view(M, K).contiguous()
                # B needs to be column-major for CUTLASS: [K, N] with stride(0)==1
                # Our B is [N, K] row-major. Make it contiguous first, then transpose.
                # B.contiguous() gives [N, K] with stride=(K,1)
                # B.contiguous().t() gives [K, N] with stride=(1,K) which is column-major
                # Do NOT call .contiguous() after .t() as it would make it row-major!
                B_col_major = B.contiguous().t()

                # Scales need proper layout for CUTLASS blockwise:
                # As should be [M, K//128] with M-major layout (stride(0)==1)
                # Bs should be [K//128, N//128] with K-major layout (stride(0)==1)

                # As: reshape to [M, K//128], then make M-major via t().contiguous().t()
                As_2d = As.view(M, -1).contiguous()
                As_2d = As_2d.t().contiguous().t()  # [M, K//128] with stride(0)==1

                # Bs: our input is [N//128, K//128], need [K//128, N//128] with stride(0)==1
                # Transpose to get [K//128, N//128], then make K-major via t().contiguous().t()
                Bs_km = Bs.contiguous().t()  # [K//128, N//128]
                Bs_km = Bs_km.t().contiguous().t()  # Make K-major (stride(0)==1)

                # Call CUTLASS kernel - it returns the output tensor
                # Signature: cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias=None) -> Tensor
                C = kernel.cutlass_scaled_mm(A_2d, B_col_major, As_2d, Bs_km, output_dtype, None)
                # Reshape output back to A's leading dims.
                C_shape = original_shape[:-1] + (N,)
                return C.view(C_shape)
            except Exception as e:
                # Any CUTLASS failure degrades gracefully to the Triton path.
                logger.warning_once(f"CUTLASS kernel failed: {e}. Falling back to Triton.")

    # Fall back to Triton
    return w8a8_block_fp8_matmul_triton(A, B, As, Bs, block_size, output_dtype)
|
|
|
|
|
|
# Python version of the above triton function, it's much slower than the triton version, for testing
@torch.compile
def w8a8_block_fp8_matmul_compile(
    input_q: torch.Tensor,  # [batch, seq_len, hidden_dim]
    weight_q: torch.Tensor,  # [out_features, hidden_dim]
    input_scale: torch.Tensor,  # [batch * seq_len, num_input_groups]
    weight_scale: torch.Tensor,  # [num_weight_blocks_m, num_weight_blocks_n]
    block_size: tuple[int, int] | None = None,  # (M=128, N=128) for weights for example
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Performs blocked matrix multiplication with FP8 quantized matrices.

    Args:
        input_q: Quantized input tensor with 1x128 block quantization
        weight_q: Quantized weight tensor with 128x128 block quantization
        input_scale: Scaling factors for input blocks
        weight_scale: Scaling factors for weight blocks
        block_size: Tuple of (M, N) for weight block dimensions; defaults to
            (128, 128) when not given
        output_dtype: Desired output dtype
    """
    # Bug fix: the declared default `block_size=None` used to crash below on
    # `block_size[0]` (TypeError). Fall back to the documented 128x128 tiles.
    if block_size is None:
        block_size = (128, 128)

    batch_size, seq_len, hidden_dim = input_q.shape if input_q.ndim == 3 else (1, input_q.shape[0], input_q.shape[1])
    out_features = weight_q.shape[0]

    # Reshape input for batched matmul
    input_reshaped = input_q.view(-1, hidden_dim)  # [batch*seq_len, hidden_dim]
    input_scale_reshaped = input_scale.view(input_scale.shape[0], -1)  # [batch*seq_len, num_input_groups]
    # Calculate number of blocks (assumes exact tiling of the weight matrix)
    num_weight_blocks_m = out_features // block_size[0]
    num_weight_blocks_n = hidden_dim // block_size[1]

    # Accumulate in float32 for stability; cast to output_dtype at the end.
    output = torch.zeros((batch_size * seq_len, out_features), dtype=torch.float32, device=input_q.device)

    for i in range(num_weight_blocks_m):
        m_start = i * block_size[0]
        m_end = m_start + block_size[0]

        for j in range(num_weight_blocks_n):
            n_start = j * block_size[1]
            n_end = n_start + block_size[1]

            # Extract current blocks
            input_block = input_reshaped[:, n_start:n_end]
            weight_block = weight_q[m_start:m_end, n_start:n_end]

            # Get corresponding scales
            curr_input_scale = input_scale_reshaped[:, j : j + 1]  # [batch*seq_len, 1]
            curr_weight_scale = weight_scale[i, j]  # scalar

            # torch._scaled_mm applies the weight scale inside the matmul;
            # the per-token input scale is applied afterwards as a broadcast.
            block_result = (
                torch._scaled_mm(
                    input_block,
                    weight_block.t(),
                    scale_a=torch.tensor(1, dtype=torch.float32, device=input_q.device),
                    scale_b=curr_weight_scale,
                    out_dtype=output_dtype,
                )
                * curr_input_scale
            )

            output[:, m_start:m_end] += block_result

    output = output.view(batch_size, seq_len, out_features)

    return output.to(output_dtype)
|
|
|
|
|
|
class FP8Linear(nn.Linear):
    """Linear layer whose weight is stored in FP8 (e4m3) alongside inverse
    scales (`weight_scale_inv`).

    With `block_size=None` a single per-tensor scale is used; otherwise one
    scale per (block_size[0] x block_size[1]) weight tile. `activation_scheme`
    selects between on-the-fly per-group activation quantization ("dynamic")
    and a stored scalar activation scale ("static").
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        dtype=torch.float8_e4m3fn,
        block_size: tuple[int, int] | None = None,
        activation_scheme="dynamic",
    ):
        super().__init__(in_features, out_features)

        # If block size is None, it means that we are doing per-tensor quantization
        self.block_size = block_size
        self.activation_scheme = activation_scheme

        # Replace nn.Linear's weight with an FP8 (by default) parameter.
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))

        if self.block_size is None:
            # Per-tensor quantization: a single scalar inverse scale.
            self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        else:
            # One inverse scale per weight tile (ceil division covers ragged edges).
            scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
            scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
            self.weight_scale_inv = nn.Parameter(
                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
            )

        if self.activation_scheme == "static":
            # Static scheme: activations are quantized with this stored scale.
            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # element_size() > 1 means the weight is not FP8 (e.g. still fp16/fp32),
        # so use the regular dense path.
        if self.weight.element_size() > 1:
            return F.linear(input, self.weight, self.bias)
        else:
            # Under tensor parallelism the parameters may be DTensors; compute
            # on the local shard.
            if isinstance(self.weight, torch.distributed.tensor.DTensor):
                weight = self.weight._local_tensor.contiguous()
                scale_inv = self.weight_scale_inv._local_tensor.contiguous()
            else:
                weight = self.weight.contiguous()
                scale_inv = self.weight_scale_inv.contiguous()
            # Context manager used to switch among the available accelerators
            device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
            torch_accelerator_module = getattr(torch, device_type, torch.cuda)
            with torch_accelerator_module.device(input.device):
                if self.activation_scheme == "dynamic":
                    # Quantize activations on the fly with per-group scales.
                    qinput, scale = act_quant(input, self.block_size[1])
                elif self.activation_scheme == "static":
                    # Quantize with the stored scalar activation scale.
                    scale = self.activation_scale.to(torch.float32)
                    qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(torch.float8_e4m3fn)

                else:
                    raise NotImplementedError("Not supported")

                output = w8a8_block_fp8_matmul(
                    qinput,
                    weight,
                    scale,
                    scale_inv,
                    self.block_size,
                    output_dtype=input.dtype,
                )

                # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
                # preceding operations are ready before proceeding
                torch_accelerator_module.synchronize()
            if self.bias is not None:
                output = output + self.bias

            return output.to(dtype=input.dtype)
|
|
|
|
|
|
def _ceil_div(a, b):
|
|
return (a + b - 1) // b
|
|
|
|
|
|
class FP8Expert(nn.Module):
    """Stacked MoE expert weights stored in FP8 with per-tile inverse scales.

    Holds fused `gate_up_proj` and `down_proj` weights for all local experts
    plus their `*_scale_inv` tensors, and routes tokens through experts
    following the Mixtral "eager" MoE loop.
    """

    def __init__(self, config, block_size, dtype=torch.float8_e4m3fn):
        super().__init__()

        from ..activations import ACT2FN

        self.block_size = block_size
        # TODO we don't need exact expert count here but only in forward
        self.num_experts = config.num_local_experts if hasattr(config, "num_local_experts") else config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = (
            config.moe_intermediate_size if hasattr(config, "moe_intermediate_size") else config.intermediate_size
        )

        # gate_up fuses the gate and up projections, hence the factor of 2.
        Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
        Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim

        self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype))
        self.down_proj = nn.Parameter(torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype))

        bo, bi = self.block_size

        # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
        gu_scale_o = _ceil_div(Wg_out, bo)
        gu_scale_i = _ceil_div(Wg_in, bi)
        self.gate_up_proj_scale_inv = nn.Parameter(
            torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32)
        )

        # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
        dp_scale_o = _ceil_div(Wd_out, bo)
        dp_scale_i = _ceil_div(Wd_in, bi)
        self.down_proj_scale_inv = nn.Parameter(
            torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32)
        )

        # (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
        self.register_parameter("gate_up_bias", None)
        self.register_parameter("down_bias", None)

        # Activation used in the MLP (same as your config / ACT2FN)
        # Keep a handle here; actual usage happens in forward of your MoE block
        self.act_fn = ACT2FN[config.hidden_act]

    # We follow the mixtral "eager" moe implementation at
    # https://github.com/huggingface/transformers/blob/457048fbfdba9a7dee8bd03328c62f49e57b95f9/src/transformers/models/mixtral/modular_mixtral.py#L148
    # The core changes in this FP8 version should only relate to how we call the linear projections
    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # One-hot over experts, permuted so dim 0 indexes experts.
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Only iterate over experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            if expert_idx == len(self.gate_up_proj):  # weights will load fine
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            # Fused gate/up projection, split into its two halves.
            gate, up = self.linear(
                current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scale_inv[expert_idx]
            ).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = self.linear(
                current_hidden_states, self.down_proj[expert_idx], self.down_proj_scale_inv[expert_idx]
            )

            # Weight each token's contribution by its routing probability.
            routing_weights = top_k_weights[token_idx, top_k_pos, None]
            current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states

    def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor:
        """FP8 block-quantized linear for one expert's weight (no bias).

        Falls back to a dense `F.linear` when the weight is not FP8.
        """
        if weight.element_size() > 1:
            return F.linear(input, weight, None)
        else:
            # Context manager used to switch among the available accelerators
            device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
            torch_accelerator_module = getattr(torch, device_type, torch.cuda)
            with torch_accelerator_module.device(input.device):
                # Dynamic per-group activation quantization, then FP8 matmul.
                qinput, scale = act_quant(input, self.block_size[1])
                output = w8a8_block_fp8_matmul(
                    qinput,
                    weight,
                    scale,
                    weight_scale_inv,
                    self.block_size,
                    output_dtype=input.dtype,
                )
                # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
                # preceding operations are ready before proceeding
                torch_accelerator_module.synchronize()
            return output.to(dtype=input.dtype)
|
|
|
|
|
|
def replace_with_fp8_linear(
    model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False
):
    """
    A helper function to replace all `torch.nn.Linear` modules by `FP8Linear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
            Names of the modules to not convert. In practice we keep the `lm_head` in full precision for numerical stability reasons.
        quantization_config (`FbgemmFp8Config`):
            The quantization config object that contains the quantization parameters.
        pre_quantized (`bool`, defaults to `False`):
            Whether the model is pre-quantized or not
    """

    # Nothing to do when the config asks for dequantized weights.
    if quantization_config.dequantize:
        return model

    has_been_replaced = False
    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue
        # we need this to correctly materialize the weights during quantization
        module_kwargs = {} if pre_quantized else {"dtype": None}
        new_module = None
        # Build replacements on the meta device; real weights are loaded later.
        with torch.device("meta"):
            if module_name.endswith(".experts"):
                new_module = FP8Expert(
                    config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs
                )
            elif isinstance(module, nn.Linear):
                new_module = FP8Linear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                    activation_scheme=quantization_config.activation_scheme,
                    block_size=quantization_config.weight_block_size,
                    **module_kwargs,
                )
        if new_module is not None:
            model.set_submodule(module_name, new_module)
            has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using fp8 but no linear modules were found in your model."
            " Please double check your model architecture."
        )
    return model
|
|
|
|
|
|
class Fp8Quantize(ConversionOps):
    """
    A quantization operation that creates two tensors, weight and scale out of a weight.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
        """Quantize one weight tensor to FP8 with per-tile inverse scales.

        `input_dict` maps a single target key to its value (possibly wrapped
        in a one-element list). Returns the quantized tensor under the
        original key plus the matching `*_scale_inv` entry.
        """
        # Unpack single key/value (value may be wrapped in a list)
        target_keys, value = tuple(input_dict.items())[0]
        value = value[0]

        # Resolve block size (support dict-like or attr-like quant_config)
        block_size = None
        if self.hf_quantizer.quantization_config is not None:
            if isinstance(self.hf_quantizer.quantization_config, dict):
                block_size = self.hf_quantizer.quantization_config.get("weight_block_size")
            else:
                block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None)
        if block_size is None:
            # Per-tensor quantization: one tile spanning the whole matrix.
            block_size = (value.shape[-2], value.shape[-1])

        block_m, block_n = block_size
        rows, cols = value.shape[-2], value.shape[-1]

        # Enforce exact tiling like your original
        if rows % block_m != 0 or cols % block_n != 0:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
            )

        # Leading dims can be empty (2D) or include num_experts/... (3D+)
        leading_shape = value.shape[:-2]
        rows_tiles = rows // block_m
        cols_tiles = cols // block_n

        original_shape = value.shape
        value_fp32 = value.to(torch.float32)

        # Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
        reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)

        # Per-tile max-abs over the block dims
        # dims: block_m is at -3, block_n is at -1 after the reshape
        max_abs = reshaped.abs().amax(dim=(-3, -1))
        safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))

        # Tile scale (we store inverse scale like your Linear: weight_scale_inv)
        scales = _FP8_MAX / safe_max_abs
        scales = torch.where(max_abs > 0, scales, torch.ones_like(scales))  # keep zeros stable

        # Broadcast scales back over the block dims and quantize
        # max_abs/scales shape: (..., rows_tiles, cols_tiles)
        scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3)  # -> (..., rows_tiles, 1, cols_tiles, 1)
        scaled = reshaped * scales_broadcast

        quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)

        quantized = quantized.reshape(original_shape)

        inv_scales = (1.0 / scales).to(torch.float32)  # shape: (*leading, rows_tiles, cols_tiles)
        if target_keys.endswith("weight"):
            # Scale tensor lives next to the weight: `<prefix>.weight_scale_inv`.
            scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
        else:
            scale_key = target_keys + "_scale_inv"

        # Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
        return {
            target_keys: quantized,
            scale_key: inv_scales,
        }
|
|
|
|
|
|
class Fp8Dequantize(ConversionOps):
    """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: dict[str, torch.Tensor],
        full_layer_name: str | None = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Dequantize the `weight$` tensor using `weight_scale_inv` per-tile
        scales; when no scale is present, the weight is returned as-is.
        """
        if len(input_dict) < 2:
            # case where we only got weights, need to check for "weight$"
            # NOTE(review): unlike the path below, the value is NOT unwrapped
            # with [0] here — confirm whether callers expect the wrapper list.
            return {full_layer_name: input_dict["weight$"]}

        quantized = input_dict["weight$"][0]
        scales = input_dict["weight_scale_inv"][0]

        rows, cols = quantized.shape[-2:]
        block_size = self.hf_quantizer.quantization_config.weight_block_size
        if block_size is None:
            # Per-tensor: a single tile spanning the whole matrix.
            block_size = (quantized.shape[-2], quantized.shape[-1])

        block_m, block_n = block_size

        if rows % block_m != 0 or cols % block_n != 0:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
            )
        quantized = quantized.to(scales.dtype)
        # Expand each per-tile scale over its (block_m x block_n) tile and
        # multiply; leading dims (e.g. num_experts) are folded into dim 0.
        reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
        expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n)
        expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
        dequantized = reshaped * expanded_scales

        return {
            full_layer_name: dequantized.reshape(quantized.shape),
        }
|