# Copyright 2024 NetEase, Inc. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_torch_available, logging


if is_torch_available():
    import torch
    import torch.nn as nn


logger = logging.get_logger(__name__)


class EetqQuantize(ConversionOps):
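    """
    Conversion op that turns a full-precision linear weight into the int8 weight and
    per-output-channel float16 scales consumed by `EetqLinear`.
    """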
    def __init__(self, hf_quantizer):
        self.hf_quantizer = hf_quantizer

    def convert(
        self, input_dict: dict[str, list[torch.Tensor]], full_layer_name: str | None = None, **kwargs
    ) -> dict[str, torch.Tensor]:
        # `input_dict` maps the weight name to a one-element list holding the tensor.
        _, value = tuple(input_dict.items())[0]
        value = value[0]

        value_device = value.device
        # EETQ quantizes the transposed weight on CPU.
        int8_weight = torch.t(value).contiguous().cpu()
        int8_weight, scales = eetq_kernels_hub.quant_weights(int8_weight, torch.int8, False)

        int8_weight = int8_weight.to(value_device)
        scales = scales.to(value_device)

        return {full_layer_name: int8_weight, f"{full_layer_name}_scales": scales}


class EetqLinearMMFunction(torch.autograd.Function):
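    """
    Autograd wrapper around the EETQ W8A16 GEMM kernel: `forward` runs the quantized
    matmul, `backward` dequantizes the weight to compute the gradient w.r.t. the input.
    """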
    @staticmethod
    def forward(ctx, x, weight, scales, bias=None):
        # Save the tensors needed to recompute the input gradient in `backward`.
        ctx.save_for_backward(x, weight, scales, bias)
        output = eetq_kernels_hub.w8_a16_gemm(x, weight, scales)
        output = output + bias if bias is not None else output
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, scales, bias = ctx.saved_tensors
        identity = torch.eye(weight.shape[0]).to(weight.device).to(input.dtype)

        # Dequantize the weight by running the GEMM against an identity matrix
        weight = eetq_kernels_hub.w8_a16_gemm(identity, weight, scales)

        grad_input = None
        if ctx.needs_input_grad[0]:
            # 2D matrix multiplication, unsqueeze to 3D
            grad_input = grad_output.squeeze(0).matmul(weight.transpose(0, 1)).unsqueeze(0)

        return grad_input, None, None, None


class EetqLinear(nn.Module):
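    """
    Drop-in replacement for `torch.nn.Linear` that stores a transposed int8 weight together
    with float16 scales and dispatches `forward` to the EETQ W8A16 GEMM kernel.
    """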
    def __init__(self, in_features, out_features, dtype=torch.int8, bias=False):
        super().__init__()
        # The quantized weight is stored transposed, with shape (in_features, out_features).
        self.weight = nn.Parameter(torch.empty((in_features, out_features), dtype=dtype), requires_grad=False)
        self.weight_scales = nn.Parameter(torch.empty((out_features), dtype=torch.float16))
        if bias:
            self.bias = nn.Parameter(torch.empty((out_features), dtype=torch.float16))
        else:
            self.bias = None

    def forward(self, input):
        output = EetqLinearMMFunction.apply(input, self.weight, self.weight_scales, self.bias)
        return output


def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = None, pre_quantized=False):
    """
    A helper function to replace all `torch.nn.Linear` modules by `EetqLinear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
            Names of the modules to not convert to `EetqLinear`. In practice we keep the `lm_head` in full
            precision for numerical stability reasons.
        pre_quantized (`bool`, *optional*, defaults to `False`):
            Whether the checkpoint already contains EETQ-quantized weights. When `False`, the replacement
            modules are created with `dtype=None` so the full-precision weights can be materialized and then
            quantized in place.
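
    Example (a minimal sketch: the checkpoint name is illustrative, and a CUDA device with the
    EETQ kernels available from the hub is assumed):

    ```python
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    model = replace_with_eetq_linear(model, modules_to_not_convert=["lm_head"])
    ```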
    """
    from .hub_kernels import get_kernel

    # Fetch the EETQ kernels from the hub and expose them module-wide.
    global eetq_kernels_hub
    eetq_kernels_hub = get_kernel("kernels-community/quantization-eetq")

    has_been_replaced = False
    # We need this to correctly materialize the weights during quantization.
    module_kwargs = {} if pre_quantized else {"dtype": None}
    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue
        with torch.device("meta"):
            if isinstance(module, nn.Linear):
                new_module = EetqLinear(
                    module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
                )
                model.set_submodule(module_name, new_module)
                has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using eetq but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on GitHub if you think this is"
            " a bug."
        )

    return model