You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
645 lines
25 KiB
645 lines
25 KiB
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from ..utils import is_torch_available, logging
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
from torch import nn
|
|
from contextlib import contextmanager
|
|
|
|
from ..core_model_loading import ConversionOps
|
|
from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
FP4_VALUES = [
|
|
+0.0,
|
|
+0.5,
|
|
+1.0,
|
|
+1.5,
|
|
+2.0,
|
|
+3.0,
|
|
+4.0,
|
|
+6.0,
|
|
-0.0,
|
|
-0.5,
|
|
-1.0,
|
|
-1.5,
|
|
-2.0,
|
|
-3.0,
|
|
-4.0,
|
|
-6.0,
|
|
]
|
|
|
|
|
|
@contextmanager
|
|
def on_device(dev):
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
if isinstance(dev, torch.Tensor):
|
|
dev = dev.device
|
|
elif isinstance(dev, str):
|
|
dev = torch.device(dev)
|
|
dev_type = getattr(dev, "type", None)
|
|
if dev_type == "cuda":
|
|
with torch.cuda.device(dev):
|
|
yield
|
|
return
|
|
if dev_type == "xpu" and hasattr(torch, "xpu"):
|
|
with torch.xpu.device(dev):
|
|
yield
|
|
return
|
|
# other: CPU
|
|
yield
|
|
|
|
|
|
class Mxfp4Quantize(ConversionOps):
|
|
def __init__(self, hf_quantizer):
|
|
self.hf_quantizer = hf_quantizer
|
|
|
|
def convert(
|
|
self,
|
|
input_dict: dict[str, torch.Tensor],
|
|
model: torch.nn.Module | None = None,
|
|
missing_keys: list[str] | None = None,
|
|
full_layer_name: str | None = None,
|
|
**kwargs,
|
|
) -> dict[str, torch.Tensor]:
|
|
_, value = tuple(input_dict.items())[0]
|
|
value = value[0] if isinstance(value, list) else value
|
|
|
|
module, _ = get_module_from_name(model, full_layer_name)
|
|
|
|
with torch.device(value.device):
|
|
if isinstance(module, Mxfp4GptOssExperts):
|
|
triton_weight_tensor, weight_scale = quantize_to_mxfp4(value.transpose(-1, -2), triton_kernels_hub)
|
|
PrecisionConfig, FlexCtx, InFlexData = (
|
|
triton_kernels_hub.matmul_ogs.PrecisionConfig,
|
|
triton_kernels_hub.matmul_ogs.FlexCtx,
|
|
triton_kernels_hub.matmul_ogs.InFlexData,
|
|
)
|
|
triton_weight_tensor, weight_scale = swizzle_mxfp4(
|
|
triton_weight_tensor, weight_scale, triton_kernels_hub
|
|
)
|
|
|
|
proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"
|
|
|
|
if proj in module._parameters:
|
|
# Remove the nn.Parameter registration so we can attach the Triton tensor
|
|
del module._parameters[proj]
|
|
|
|
setattr(module, proj, triton_weight_tensor)
|
|
setattr(
|
|
module,
|
|
f"{proj}_precision_config",
|
|
PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
|
|
)
|
|
|
|
missing_keys.discard(f"{full_layer_name}")
|
|
module._is_hf_initialized = True
|
|
|
|
return {}
|
|
|
|
|
|
class Mxfp4Dequantize(ConversionOps):
|
|
def __init__(self, hf_quantizer):
|
|
self.hf_quantizer = hf_quantizer
|
|
|
|
def convert(
|
|
self,
|
|
input_dict: dict[str, torch.Tensor],
|
|
model: torch.nn.Module | None = None,
|
|
full_layer_name: str | None = None,
|
|
missing_keys=None,
|
|
**kwargs,
|
|
) -> dict[str, torch.Tensor]:
|
|
if "_blocks" in input_dict.keys():
|
|
if isinstance(input_dict["_blocks"], list):
|
|
blocks = input_dict["_blocks"][0]
|
|
else:
|
|
blocks = input_dict["_blocks"]
|
|
if "_scales" in input_dict.keys():
|
|
if isinstance(input_dict["_scales"], list):
|
|
scales = input_dict["_scales"][0]
|
|
else:
|
|
scales = input_dict["_scales"]
|
|
|
|
# Here we are dequantizing the weights
|
|
dequantized = dequantize_convertops(blocks, scales)
|
|
return {full_layer_name: dequantized}
|
|
|
|
|
|
class Mxfp4Deserialize(ConversionOps):
|
|
def __init__(self, hf_quantizer):
|
|
self.hf_quantizer = hf_quantizer
|
|
|
|
def convert(
|
|
self,
|
|
input_dict: dict[str, torch.Tensor],
|
|
model: torch.nn.Module | None = None,
|
|
full_layer_name: str | None = None,
|
|
missing_keys: list[str] | None = None,
|
|
**kwargs,
|
|
) -> dict[str, torch.Tensor]:
|
|
param_data = {}
|
|
if "_blocks" in input_dict.keys():
|
|
if isinstance(input_dict["_blocks"], list):
|
|
param_data["_blocks"] = input_dict["_blocks"][0]
|
|
else:
|
|
param_data["_blocks"] = input_dict["_blocks"]
|
|
if "_scales" in input_dict.keys():
|
|
if isinstance(input_dict["_scales"], list):
|
|
param_data["_scales"] = input_dict["_scales"][0]
|
|
else:
|
|
param_data["_scales"] = input_dict["_scales"]
|
|
|
|
# Eagerly set tensors on the module and perform swizzle
|
|
module, _ = get_module_from_name(model, full_layer_name)
|
|
proj = "gate_up_proj" if "gate_up_proj" in full_layer_name else "down_proj"
|
|
swizzle_mxfp4_convertops(
|
|
param_data["_blocks"],
|
|
param_data["_scales"],
|
|
module,
|
|
proj,
|
|
param_data["_blocks"].device,
|
|
triton_kernels_hub,
|
|
)
|
|
missing_keys.discard(f"{full_layer_name}")
|
|
module._is_hf_initialized = True
|
|
# We return an empty mapping since the module was updated in-place. This prevents
|
|
# the loader from trying to materialize the original meta-parameter names again.
|
|
# We don't use set_param_for_module since it expects mainly a torch.nn.Parameter or a safetensors pointer
|
|
return {}
|
|
|
|
|
|
# Copied from GPT_OSS repo and vllm
|
|
def quantize_to_mxfp4(w, triton_kernels_hub):
|
|
downcast_to_mxfp_torch = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp_torch
|
|
w, w_scale = downcast_to_mxfp_torch(w.to(torch.bfloat16), torch.uint8, axis=1)
|
|
return w, w_scale
|
|
|
|
|
|
def swizzle_mxfp4(w, w_scale, triton_kernels_hub):
|
|
"""
|
|
Changes the layout of the tensors depending on the hardware
|
|
"""
|
|
FP4, convert_layout, wrap_torch_tensor = (
|
|
triton_kernels_hub.tensor.FP4,
|
|
triton_kernels_hub.tensor.convert_layout,
|
|
triton_kernels_hub.tensor.wrap_torch_tensor,
|
|
)
|
|
layout = triton_kernels_hub.tensor_details.layout
|
|
StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout
|
|
|
|
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
|
|
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
|
|
w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
|
|
return w, w_scale
|
|
|
|
|
|
# Mostly copied from GPT_OSS repo
|
|
# TODO: Add absolute link when the repo is public
|
|
def _convert_moe_packed_tensors(
|
|
blocks,
|
|
scales,
|
|
*,
|
|
dtype: torch.dtype = torch.bfloat16,
|
|
rows_per_chunk: int = 32768 * 1024, # TODO these values are not here by mistake ;)
|
|
) -> torch.Tensor:
|
|
"""
|
|
Convert the mxfp4 weights again, dequantizing and makes them compatible with the forward
|
|
pass of GPT_OSS.
|
|
"""
|
|
import math
|
|
|
|
blocks = blocks.to(torch.uint8)
|
|
scales = scales.to(torch.int32) - 127 # TODO that's because 128=2**7
|
|
|
|
assert blocks.shape[:-1] == scales.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}"
|
|
|
|
lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
|
|
|
|
*prefix_shape, G, B = blocks.shape
|
|
rows_total = math.prod(prefix_shape) * G
|
|
|
|
blocks = blocks.reshape(rows_total, B)
|
|
scales = scales.reshape(rows_total, 1)
|
|
|
|
out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
|
|
|
|
for r0 in range(0, rows_total, rows_per_chunk):
|
|
r1 = min(r0 + rows_per_chunk, rows_total)
|
|
|
|
blk = blocks[r0:r1]
|
|
exp = scales[r0:r1]
|
|
sub = out[r0:r1]
|
|
|
|
# This vector is only used to index into `lut`, but is hugeee in GPU memory so we delete it immediately
|
|
idx_lo = (blk & 0x0F).to(torch.int)
|
|
sub[:, 0::2] = lut[idx_lo]
|
|
del idx_lo
|
|
|
|
# This vector is only used to index into `lut`, but is hugeee in GPU memory so we delete it immediately
|
|
idx_hi = (blk >> 4).to(torch.int)
|
|
sub[:, 1::2] = lut[idx_hi]
|
|
del idx_hi
|
|
|
|
# Perform op
|
|
torch.ldexp(sub, exp, out=sub)
|
|
del blk, exp, sub
|
|
|
|
out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
|
|
|
|
return out.transpose(1, 2).contiguous()
|
|
|
|
|
|
def convert_moe_packed_tensors(
|
|
blocks,
|
|
scales,
|
|
*,
|
|
dtype: torch.dtype = torch.bfloat16,
|
|
rows_per_chunk: int = 32768 * 1024, # TODO these values are not here by mistake ;)
|
|
) -> torch.Tensor:
|
|
"""
|
|
Convert the mxfp4 weights again, dequantizing and makes them compatible with the forward
|
|
pass of GPT_OSS.
|
|
"""
|
|
# Since the intermediate ops requite A LOT of memory, in very constrained device_map="auto" settings
|
|
# it may OOM, hence this wrapper and move back to cpu if needed
|
|
# torch statistics are not accurate enough to estimate if we will have enough memory due to fragmentation and
|
|
# in-place operation on non-contiguous tensors (may sometimes require more temporary copies)
|
|
try:
|
|
return _convert_moe_packed_tensors(blocks, scales, dtype=dtype, rows_per_chunk=rows_per_chunk)
|
|
# In the case of OOM due to very tight device_map, we convert and return on cpu - it will then be put back on correct
|
|
# devide with the accelerate dispatch (doing it right away may still lead to OOM, but more memory is available later)
|
|
except torch.OutOfMemoryError:
|
|
blocks = blocks.to("cpu")
|
|
scales = scales.to("cpu")
|
|
return _convert_moe_packed_tensors(blocks, scales, dtype=dtype, rows_per_chunk=rows_per_chunk)
|
|
|
|
|
|
class Mxfp4GptOssExperts(nn.Module):
|
|
def __init__(self, config):
|
|
super().__init__()
|
|
|
|
self.num_experts = config.num_local_experts
|
|
self.intermediate_size = config.intermediate_size
|
|
self.hidden_size = config.hidden_size
|
|
|
|
self.gate_up_proj = nn.Parameter(
|
|
torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
|
|
requires_grad=False,
|
|
)
|
|
|
|
self.gate_up_proj_bias = nn.Parameter(
|
|
torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
|
|
)
|
|
|
|
self.down_proj = nn.Parameter(
|
|
torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),
|
|
requires_grad=False,
|
|
)
|
|
|
|
self.down_proj_bias = nn.Parameter(
|
|
torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
|
|
)
|
|
self.alpha = 1.702
|
|
self.limit = getattr(config, "swiglu_limit", 7.0)
|
|
self.gate_up_proj_precision_config = None
|
|
self.down_proj_precision_config = None
|
|
self.limit = getattr(config, "swiglu_limit", 7.0)
|
|
|
|
def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
|
|
FnSpecs, FusedActivation, matmul_ogs = (
|
|
triton_kernels_hub.matmul_ogs.FnSpecs,
|
|
triton_kernels_hub.matmul_ogs.FusedActivation,
|
|
triton_kernels_hub.matmul_ogs.matmul_ogs,
|
|
)
|
|
swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
|
|
|
|
with on_device(hidden_states.device):
|
|
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2)
|
|
|
|
intermediate_cache1 = matmul_ogs(
|
|
hidden_states,
|
|
self.gate_up_proj,
|
|
self.gate_up_proj_bias.to(torch.float32),
|
|
routing_data,
|
|
gather_indx=gather_idx,
|
|
precision_config=self.gate_up_proj_precision_config,
|
|
gammas=None,
|
|
fused_activation=act,
|
|
)
|
|
|
|
intermediate_cache3 = matmul_ogs(
|
|
intermediate_cache1,
|
|
self.down_proj,
|
|
self.down_proj_bias.to(torch.float32),
|
|
routing_data,
|
|
scatter_indx=scatter_idx,
|
|
precision_config=self.down_proj_precision_config,
|
|
gammas=routing_data.gate_scal,
|
|
)
|
|
return intermediate_cache3
|
|
|
|
|
|
# Adapted from GPT_OSS repo
|
|
# TODO: Add absolute link when the repo is public
|
|
def routing_torch_dist(
|
|
logits,
|
|
n_expts_act,
|
|
):
|
|
import os
|
|
|
|
GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
|
|
triton_kernels_hub.routing.GatherIndx,
|
|
triton_kernels_hub.routing.RoutingData,
|
|
triton_kernels_hub.routing.ScatterIndx,
|
|
triton_kernels_hub.routing.compute_expt_data_torch,
|
|
)
|
|
|
|
with on_device(logits.device):
|
|
world_size = torch.distributed.get_world_size()
|
|
rank = int(os.environ.get("LOCAL_RANK", "0"))
|
|
replace_value = -1
|
|
|
|
n_tokens = logits.shape[0]
|
|
n_expts_tot = logits.shape[1]
|
|
|
|
n_local_experts = n_expts_tot // world_size
|
|
local_expert_start = rank * n_local_experts
|
|
local_expert_end = (rank + 1) * n_local_experts
|
|
|
|
n_gates_pad = n_tokens * n_expts_act
|
|
|
|
def topk(vals, k):
|
|
tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
|
|
tk_indx = tk_indx.long()
|
|
tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
|
|
return tk_val, tk_indx.int()
|
|
|
|
expt_scal, expt_indx = topk(logits, n_expts_act)
|
|
expt_scal = torch.softmax(expt_scal, dim=-1)
|
|
expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
|
|
expt_scal = torch.gather(expt_scal, 1, sort_indices)
|
|
|
|
# Flatten and mask for local experts
|
|
expt_scal = expt_scal.reshape(-1)
|
|
|
|
hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start:local_expert_end]
|
|
|
|
expt_indx = expt_indx.view(-1).to(torch.int32)
|
|
|
|
# we use a large value to replace the indices that are not in the local expert range
|
|
var = 1000
|
|
expt_indx = torch.where(expt_indx < local_expert_start, var, expt_indx)
|
|
topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32)
|
|
gate_indx = torch.argsort(topk_indx).to(torch.int32)
|
|
expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value)
|
|
expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value)
|
|
|
|
gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx)
|
|
gate_scal = expt_scal[topk_indx]
|
|
|
|
topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx)
|
|
|
|
# # Routing metadata for local expert computation
|
|
gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
|
|
scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
|
|
|
|
expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad)
|
|
|
|
hit_experts = n_expts_act
|
|
return RoutingData(gate_scal, hist, n_local_experts, hit_experts, expt_data), gather_indx, scatter_indx
|
|
|
|
|
|
def mlp_forward(self, hidden_states):
|
|
import torch.distributed as dist
|
|
|
|
if dist.is_available() and dist.is_initialized() and hasattr(self, "_is_hooked"):
|
|
routing = routing_torch_dist
|
|
else:
|
|
routing = triton_kernels_hub.routing.routing
|
|
|
|
batch_size = hidden_states.shape[0]
|
|
hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
|
|
router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
|
|
|
|
with on_device(router_logits.device):
|
|
routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
|
|
|
|
routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
|
|
routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
|
|
return routed_out, router_logits
|
|
|
|
|
|
def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
|
|
from ..integrations.tensor_parallel import shard_and_distribute_module
|
|
|
|
model = kwargs.get("model")
|
|
empty_param = kwargs.get("empty_param")
|
|
casting_dtype = kwargs.get("casting_dtype")
|
|
to_contiguous = kwargs.get("to_contiguous")
|
|
rank = kwargs.get("rank")
|
|
device_mesh = kwargs.get("device_mesh")
|
|
|
|
for proj in ["gate_up_proj", "down_proj"]:
|
|
if proj in param_name:
|
|
if device_mesh is not None:
|
|
param_value = shard_and_distribute_module(
|
|
model,
|
|
param_value,
|
|
empty_param,
|
|
dq_param_name,
|
|
casting_dtype,
|
|
to_contiguous,
|
|
rank,
|
|
device_mesh,
|
|
)
|
|
blocks_attr = f"{proj}_blocks"
|
|
scales_attr = f"{proj}_scales"
|
|
setattr(module, param_name.rsplit(".", 1)[1], param_value)
|
|
if hasattr(module, blocks_attr) and hasattr(module, scales_attr):
|
|
dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
|
|
setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
|
|
delattr(module, blocks_attr)
|
|
delattr(module, scales_attr)
|
|
|
|
|
|
def dequantize_convertops(blocks, scales):
|
|
dequantized = convert_moe_packed_tensors(blocks, scales)
|
|
return torch.nn.Parameter(dequantized)
|
|
|
|
|
|
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs):
|
|
"""
|
|
This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.
|
|
"""
|
|
PrecisionConfig, FlexCtx, InFlexData = (
|
|
triton_kernels_hub.matmul_ogs.PrecisionConfig,
|
|
triton_kernels_hub.matmul_ogs.FlexCtx,
|
|
triton_kernels_hub.matmul_ogs.InFlexData,
|
|
)
|
|
from ..integrations.tensor_parallel import shard_and_distribute_module
|
|
|
|
model = kwargs.get("model")
|
|
empty_param = kwargs.get("empty_param")
|
|
casting_dtype = kwargs.get("casting_dtype")
|
|
to_contiguous = kwargs.get("to_contiguous")
|
|
rank = kwargs.get("rank")
|
|
device_mesh = kwargs.get("device_mesh")
|
|
if "blocks" in param_name:
|
|
proj = param_name.split(".")[-1].split("_blocks")[0]
|
|
if "scales" in param_name:
|
|
proj = param_name.split(".")[-1].split("_scales")[0]
|
|
if device_mesh is not None:
|
|
shard_and_distribute_module(
|
|
model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh
|
|
)
|
|
else:
|
|
setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False))
|
|
blocks_attr = f"{proj}_blocks"
|
|
scales_attr = f"{proj}_scales"
|
|
blocks = getattr(module, blocks_attr) # at this point values were loaded from ckpt
|
|
scales = getattr(module, scales_attr)
|
|
# Check if both blocks and scales both not on meta device
|
|
if blocks.device.type != "meta" and scales.device.type != "meta":
|
|
local_experts = blocks.size(0)
|
|
if proj == "gate_up_proj":
|
|
blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
|
|
else:
|
|
blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
|
|
if getattr(target_device, "type", target_device) == "cpu":
|
|
target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
|
|
blocks = blocks.to(target_device).contiguous()
|
|
scales = scales.to(target_device).contiguous()
|
|
with on_device(target_device):
|
|
triton_weight_tensor, weight_scale = swizzle_mxfp4(
|
|
blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
|
|
)
|
|
|
|
# need to overwrite the shapes for the kernels
|
|
if proj == "gate_up_proj":
|
|
triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2])
|
|
else:
|
|
triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size])
|
|
|
|
# triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor
|
|
setattr(module, proj, triton_weight_tensor)
|
|
setattr(
|
|
module,
|
|
f"{proj}_precision_config",
|
|
PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
|
|
)
|
|
|
|
# delete blocks and scales
|
|
delattr(module, scales_attr)
|
|
delattr(module, blocks_attr)
|
|
del blocks
|
|
|
|
|
|
def swizzle_mxfp4_convertops(blocks, scales, module, proj, target_device, triton_kernels_hub):
|
|
"""
|
|
This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.
|
|
"""
|
|
PrecisionConfig, FlexCtx, InFlexData = (
|
|
triton_kernels_hub.matmul_ogs.PrecisionConfig,
|
|
triton_kernels_hub.matmul_ogs.FlexCtx,
|
|
triton_kernels_hub.matmul_ogs.InFlexData,
|
|
)
|
|
|
|
local_experts = blocks.size(0)
|
|
if getattr(target_device, "type", target_device) == "cpu":
|
|
target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
|
|
|
|
blocks = blocks.to(target_device).contiguous()
|
|
scales = scales.to(target_device).contiguous()
|
|
|
|
if proj == "gate_up_proj":
|
|
blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
|
|
else:
|
|
blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
|
|
if getattr(target_device, "type", target_device) == "cpu":
|
|
target_device = "cuda"
|
|
|
|
with on_device(target_device):
|
|
triton_weight_tensor, weight_scale = swizzle_mxfp4(
|
|
blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
|
|
)
|
|
# need to overwrite the shapes for the kernels
|
|
if proj == "gate_up_proj":
|
|
triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2])
|
|
else:
|
|
triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size])
|
|
|
|
# triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It's like a subtensor
|
|
# Since the Experts module registers gate_up_proj and down_proj as nn.Parameters, we need to remove them so we can attach the Triton tensor
|
|
if proj in module._parameters:
|
|
# Remove the nn.Parameter registration so we can attach the Triton tensor
|
|
del module._parameters[proj]
|
|
setattr(module, proj, triton_weight_tensor)
|
|
setattr(
|
|
module,
|
|
f"{proj}_precision_config",
|
|
PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
|
|
)
|
|
|
|
|
|
def replace_with_mxfp4_linear(model, quantization_config=None, modules_to_not_convert: list[str] | None = None):
|
|
"""
|
|
Public method that replaces the expert layers of the given model with mxfp4 quantized layers.
|
|
|
|
Args:
|
|
model (`torch.nn.Module`):
|
|
The model to convert, can be any `torch.nn.Module` instance.
|
|
quantization_config (`Mxfp4Config`, defaults to `None`):
|
|
The quantization config object that contains the quantization parameters.
|
|
modules_to_not_convert (`list`, *optional*, defaults to `None`):
|
|
A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
|
|
converted.
|
|
"""
|
|
if quantization_config.dequantize:
|
|
return model
|
|
|
|
from .hub_kernels import get_kernel
|
|
|
|
global triton_kernels_hub
|
|
triton_kernels_hub = get_kernel("kernels-community/gpt-oss-triton-kernels")
|
|
|
|
has_been_replaced = False
|
|
for module_name, module in model.named_modules():
|
|
if not should_convert_module(module_name, modules_to_not_convert):
|
|
continue
|
|
if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
|
|
with torch.device("meta"):
|
|
model.set_submodule(module_name, Mxfp4GptOssExperts(model.config))
|
|
has_been_replaced = True
|
|
if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize:
|
|
from types import MethodType
|
|
|
|
module.forward = MethodType(mlp_forward, module)
|
|
|
|
if not has_been_replaced:
|
|
logger.warning(
|
|
"You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model."
|
|
" Please double check your model architecture, or submit an issue on github if you think this is"
|
|
" a bug."
|
|
)
|
|
|
|
return model
|