# Copyright 2024 NetEase, Inc. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_torch_available, logging


if is_torch_available():
    import torch
    import torch.nn as nn


logger = logging.get_logger(__name__)


class EetqQuantize(ConversionOps):
    """Conversion op that quantizes a weight tensor to int8 with the EETQ hub kernels.

    The module-level `eetq_kernels_hub` global is assigned by `replace_with_eetq_linear`,
    so that function must have been called before `convert` is used.
    """

    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self, input_dict: dict[str, list[torch.Tensor]], full_layer_name: str | None = None, **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Quantize the first tensor of the first entry of `input_dict`.

        Parameters:
            input_dict (`dict[str, list[torch.Tensor]]`):
                Mapping from a name to a list of tensors. Only the first tensor of the first entry is used.
            full_layer_name (`str`, *optional*):
                Key under which the quantized weight is returned; the scales are returned under
                `f"{full_layer_name}_scales"`.

        Returns:
            `dict[str, torch.Tensor]`: the int8 weight and its scales, both moved to the device of the
            input tensor.
        """
        _, value = tuple(input_dict.items())[0]
        value = value[0]
        value_device = value.device

        # The weight is transposed, made contiguous and moved to CPU before being handed to the kernel.
        # NOTE(review): the meaning of the trailing `False` flag of `quant_weights` is not visible here.
        int8_weight = torch.t(value).contiguous().cpu()
        int8_weight, scales = eetq_kernels_hub.quant_weights(int8_weight, torch.int8, False)

        # Move the results back to where the original tensor lived.
        int8_weight = int8_weight.to(value_device)
        scales = scales.to(value_device)

        return {full_layer_name: int8_weight, f"{full_layer_name}_scales": scales}


class EetqLinearMMFunction(torch.autograd.Function):
    """Autograd function wrapping the EETQ `w8_a16_gemm` kernel for the linear forward pass."""

    @staticmethod
    def forward(ctx, x, weight, scales, bias=None):
        """Compute `w8_a16_gemm(x, weight, scales)`, adding `bias` when it is not `None`."""
        # The forward pass can use ctx.
        ctx.save_for_backward(x, weight, scales, bias)
        output = eetq_kernels_hub.w8_a16_gemm(x, weight, scales)
        output = output + bias if bias is not None else output
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return the gradient with respect to `x` only; `weight`, `scales` and `bias` get no gradient.

        When `x` does not need a gradient, `None` is returned for it as well.
        """
        x, weight, scales, bias = ctx.saved_tensors

        # Default to `None` so the return below is valid even when the activation does not need a
        # gradient (backward can still be invoked because another input requires grad).
        grad_input = None
        if ctx.needs_input_grad[0]:
            identity = torch.eye(weight.shape[0]).to(weight.device).to(x.dtype)

            # Dequantize the weight
            dequantized_weight = eetq_kernels_hub.w8_a16_gemm(identity, weight, scales)

            # 2D matrix multiplication, unsqueeze to 3D
            grad_input = grad_output.squeeze(0).matmul(dequantized_weight.transpose(0, 1)).unsqueeze(0)

        return grad_input, None, None, None


class EetqLinear(nn.Module):
    """
    Linear layer holding an EETQ-quantized weight.

    Parameters:
        in_features (`int`):
            First dimension of the stored weight.
        out_features (`int`):
            Second dimension of the stored weight; also the length of the scales and of the bias.
        dtype (`torch.dtype`, *optional*, defaults to `torch.int8`):
            Data type of the stored weight.
        bias (`bool`, *optional*, defaults to `False`):
            Whether to create a float16 bias parameter.
    """

    def __init__(self, in_features, out_features, dtype=torch.int8, bias=False):
        super().__init__()
        # The weight is stored as (in_features, out_features) and never receives a gradient.
        self.weight = nn.Parameter(torch.empty((in_features, out_features), dtype=dtype), requires_grad=False)
        self.weight_scales = nn.Parameter(torch.empty(out_features, dtype=torch.float16))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.float16))
        else:
            self.bias = None

    def forward(self, input):
        """Apply `EetqLinearMMFunction` to `input` with this layer's weight, scales and bias."""
        output = EetqLinearMMFunction.apply(input, self.weight, self.weight_scales, self.bias)
        return output


def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = None, pre_quantized=False):
    """
    A helper function to replace all `torch.nn.Linear` modules by `EetqLinear` modules.

    As a side effect, the `kernels-community/quantization-eetq` kernel is loaded and stored in the
    module-level global `eetq_kernels_hub`, which the other EETQ classes in this file rely on.

    Parameters:
        model (`torch.nn.Module`):
            Model whose submodules are visited through `named_modules()` and replaced in place.
        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full
            precision for numerical stability reasons.
        pre_quantized (`bool`, *optional*, defaults to `False`):
            When `True`, the replacement layers keep the default weight dtype of `EetqLinear`. When `False`,
            they are created with `dtype=None` so the weights can be materialized before quantization.

    Returns:
        `torch.nn.Module`: the same `model` object, after replacement.
    """
    from .hub_kernels import get_kernel

    global eetq_kernels_hub
    eetq_kernels_hub = get_kernel("kernels-community/quantization-eetq")

    has_been_replaced = False
    # we need this to correctly materialize the weights during quantization
    module_kwargs = {} if pre_quantized else {"dtype": None}
    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue
        # Replacement layers are created on the meta device, so no real memory is allocated here.
        with torch.device("meta"):
            if isinstance(module, nn.Linear):
                new_module = EetqLinear(
                    module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
                )
                model.set_submodule(module_name, new_module)
                has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using eetq but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model