You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
272 lines
11 KiB
272 lines
11 KiB
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import importlib
|
|
import re
|
|
from typing import TYPE_CHECKING
|
|
|
|
from packaging import version
|
|
|
|
from .base import HfQuantizer
|
|
from .quantizers_utils import get_module_from_name, should_convert_module
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from ..modeling_utils import PreTrainedModel
|
|
|
|
from safetensors import safe_open
|
|
|
|
from ..utils import is_torch_available, is_torchao_available, logging
|
|
|
|
|
|
if is_torch_available():
|
|
from ..core_model_loading import WeightConverter
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
if is_torchao_available():
|
|
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
|
|
from torchao.prototype.safetensors.safetensors_support import (
|
|
flatten_tensor_state_dict,
|
|
)
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
def fuzzy_match_size(config_name: str) -> str | None:
|
|
"""
|
|
Extract the size digit from strings like "4weight", "8weight".
|
|
Returns the digit as an integer if found, otherwise None.
|
|
"""
|
|
config_name = config_name.lower()
|
|
|
|
str_match = re.search(r"(\d)weight", config_name)
|
|
|
|
if str_match:
|
|
return str_match.group(1)
|
|
|
|
return None
|
|
|
|
|
|
def _quantization_type(weight):
    """
    Describe how a torchao tensor subclass is quantized, as a short string.

    An `AffineQuantizedTensor` is rendered as "<ClassName>(<its _quantization_type()>)".
    A `LinearActivationQuantizedTensor` is rendered with its `input_quant_func` and the description of
    its `original_weight_tensor`, obtained by calling this function recursively.
    Any other object yields `None`.
    """
    # torchao is imported lazily so this module can be imported without it.
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

    description = None
    if isinstance(weight, AffineQuantizedTensor):
        description = f"{weight.__class__.__name__}({weight._quantization_type()})"
    elif isinstance(weight, LinearActivationQuantizedTensor):
        description = f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
    return description
|
|
|
|
|
|
def _linear_extra_repr(self):
    """
    Build a repr fragment listing in/out features and the weight's quantization description.

    `in_features` is read from `self.weight.shape[1]` and `out_features` from `self.weight.shape[0]`.
    The weight part is whatever `_quantization_type(self.weight)` returns, or the literal "None" when
    the weight is not a recognised quantized tensor.
    NOTE(review): presumably bound as `extra_repr` on linear modules — confirm at the call site.
    """
    weight = _quantization_type(self.weight)
    if weight is not None:
        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
|
|
|
|
|
|
# Parsed version of the installed torchao distribution, resolved once at import time.
# The name is only bound when torchao is available, so the code below that reads
# TORCHAO_VERSION (get_state_dict_and_metadata, is_serializable) relies on torchao being installed.
if is_torchao_available():
    TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
|
|
|
|
|
|
class TorchAoHfQuantizer(HfQuantizer):
    """
    Quantizer for torchao: https://github.com/pytorch/ao/

    Decides which parameters get quantized, estimates their per-element size for device-map planning,
    and supplies the quantize / deserialize operations used while weights are loaded.
    """

    requires_calibration = False

    def __init__(self, quantization_config, **kwargs):
        """
        Store the config and derive `quantized_param_size`, the per-element size (in bytes) of quantized params.

        For string quant types the size comes from a fixed table and stays `None` for unknown names.
        For config objects it is guessed from the class name: 0.5 when the name contains "4weight"
        (case-insensitive), otherwise 1.
        """
        super().__init__(quantization_config, **kwargs)

        self.quantized_param_size = None
        quant_type = self.quantization_config.quant_type
        if isinstance(quant_type, str):
            map_to_param_size = {
                "int4_weight_only": 0.5,
                "int8_weight_only": 1,
                "int8_dynamic_activation_int8_weight": 1,
            }
            # Unknown string quant types leave the size as None.
            self.quantized_param_size = map_to_param_size.get(quant_type)
        else:
            size_digit = fuzzy_match_size(quant_type.__class__.__name__)
            self.quantized_param_size = 0.5 if size_digit == "4" else 1

    def validate_environment(self, *args, **kwargs):
        """
        Check that the runtime can load a torchao model and record whether offloading is in use.

        Sets `self.offload` to True when `device_map` is a dict with more than one entry and places
        something on "cpu" or "disk".

        Raises:
            ImportError: torchao is not installed.
            ValueError: a pre-quantized model is loaded with such a device map that includes "disk".
            RuntimeError: a pre-quantized model is loaded with a truthy `weights_only` on torch < 2.5.0.
        """
        if not is_torchao_available():
            raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")

        self.offload = False
        device_map = kwargs.get("device_map")
        if isinstance(device_map, dict) and len(device_map) > 1:
            placements = device_map.values()
            has_disk = "disk" in placements
            if has_disk or "cpu" in placements:
                self.offload = True
                if self.pre_quantized and has_disk:
                    raise ValueError(
                        "You are attempting to perform disk offload with a pre-quantized torchao model "
                        "This is not supported yet . Please remove the disk device from the device_map."
                    )

        if self.pre_quantized and kwargs.get("weights_only"):
            torch_version = version.parse(importlib.metadata.version("torch"))
            if torch_version < version.parse("2.5.0"):
                raise RuntimeError(
                    f"In order to use torchao pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
                    " You can also set with `weights_only=False` in `from_pretrained` if you don't want to update torch"
                )

    def update_dtype(self, dtype):
        """
        Return the dtype to load with; "int4_weight_only" is forced to `torch.bfloat16` (with a one-time warning).
        """
        if self.quantization_config.quant_type == "int4_weight_only":
            if dtype != torch.bfloat16:
                logger.warning_once(
                    f"Setting dtype to {dtype} for int4_weight_only quantization, but only bfloat16 is supported right now. Overwriting torch_dtype to bfloat16."
                )
                dtype = torch.bfloat16
        return dtype

    def get_state_dict_and_metadata(self, model):
        """
        We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format.

        Returns the result of torchao's `flatten_tensor_state_dict(model.state_dict())`
        (presumably a state-dict/metadata pair — confirm against torchao).

        Raises:
            RuntimeError: the installed torchao is older than 0.15.0.
        """
        if version.parse("0.15.0") <= TORCHAO_VERSION:
            return flatten_tensor_state_dict(model.state_dict())
        else:
            raise RuntimeError(
                f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
            )

    def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
        "Return the element size (in bytes) for `param_name`."
        # Use the quantized estimate only for params that will be quantized and when an estimate exists;
        # everything else is sized by the base class.
        if self.param_needs_quantization(model, param_name) and self.quantized_param_size is not None:
            return self.quantized_param_size

        return super().param_element_size(model, param_name, param)

    def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
        """
        Return a copy of `max_memory` with every budget scaled to 90%.
        """
        # need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128
        # NOTE(review): the values are multiplied by a float, so they must be numeric here; a string
        # entry would raise TypeError — confirm callers normalise the values beforehand.
        max_memory = {key: val * 0.9 for key, val in max_memory.items()}
        return max_memory

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", checkpoint_files=None, **kwargs):
        """
        Resolve `self.modules_to_not_convert` and, if checkpoint files are given, collect their metadata.

        When `include_input_output_embeddings` is set, the names of the model's input and output
        embedding modules are removed from the skip list so they can be quantized as well.
        """
        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
        )
        if self.quantization_config.include_input_output_embeddings:
            input_emb = model.get_input_embeddings()
            output_emb = model.get_output_embeddings()
            # One walk over the module tree; identity comparison finds the names of both embeddings.
            embedding_names = {
                name for name, module in model.named_modules() if module is input_emb or module is output_emb
            }
            self.modules_to_not_convert = [x for x in self.modules_to_not_convert if x not in embedding_names]
        if checkpoint_files is not None:
            # Torchao needs access to all metadata later
            self.set_metadata(checkpoint_files)

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        """
        Tell whether the parameter called `param_name` should be quantized.

        Returns False for anything excluded by `self.modules_to_not_convert`. With an `FqnToConfig`
        quant type (torchao >= 0.15.0), returns True when the owning module's name or the parameter's
        name matches the config, or when the config has a "_default" entry and the owning module is a
        quantizable type. Otherwise returns True only for the "weight" of a quantizable module.
        """
        # check if the param_name is not in self.modules_to_not_convert
        if not should_convert_module(param_name, self.modules_to_not_convert):
            return False

        # we only quantize the weight of nn.Linear, plus nn.Embedding when embeddings are included
        module, tensor_name = get_module_from_name(model, param_name)
        quantizable_types = (
            (torch.nn.Linear, torch.nn.Embedding)
            if self.quantization_config.include_input_output_embeddings
            else (torch.nn.Linear,)
        )

        # Handle FqnToConfig, introduced in torchao 0.15.0+
        if self.quantization_config._get_ao_version() >= version.parse("0.15.0"):
            from torchao.quantization import FqnToConfig, fqn_matches_fqn_config

            quant_type = self.quantization_config.quant_type
            if isinstance(quant_type, FqnToConfig):
                # rpartition gives "" as the module name for a root-level parameter (no dot in its
                # name), where a two-target unpack of rsplit(".", 1) would raise ValueError.
                module_fqn = param_name.rpartition(".")[0]
                if (
                    fqn_matches_fqn_config(module_fqn, quant_type)
                    or fqn_matches_fqn_config(param_name, quant_type)
                    or ("_default" in quant_type.fqn_to_config and isinstance(module, quantizable_types))
                ):
                    return True

        return isinstance(module, quantizable_types) and tensor_name == "weight"

    def _process_model_after_weight_loading(self, model, **kwargs):
        """Nothing to do for torchao once the weights are loaded."""
        return

    def is_serializable(self) -> bool:
        """
        Return True when the installed torchao (>= 0.15.0) can serialize quantized models; warn otherwise.
        """
        _is_torchao_serializable = version.parse("0.15.0") <= TORCHAO_VERSION
        if not _is_torchao_serializable:
            logger.warning(
                "torchao quantized model only supports serialization for torchao version >= 0.15.0, please upgrade "
                "your version to save the quantized model"
            )
        return _is_torchao_serializable

    @property
    def is_trainable(self) -> bool:
        """True only for the string quant types "int8_weight_only" and "int8_dynamic_activation_int8_weight"."""
        supported_quant_types_for_training = [
            "int8_weight_only",
            "int8_dynamic_activation_int8_weight",
        ]
        return self.quantization_config.quant_type in supported_quant_types_for_training

    @property
    def is_compileable(self) -> bool:
        """Always True."""
        return True

    def set_metadata(self, checkpoint_files: list[str]):
        """
        Merge the safetensors header metadata of all checkpoint files into `self.metadata`.

        Does nothing when the list is empty or its first entry is not a ".safetensors" file.
        On key clashes, later files override earlier ones.
        """
        # An empty list has no first entry to inspect; treat it like "no safetensors checkpoints".
        if not checkpoint_files or not checkpoint_files[0].endswith(".safetensors"):
            return

        metadata = {}
        for checkpoint in checkpoint_files:
            with safe_open(checkpoint, framework="pt") as f:
                # f.metadata() may be None when the file carries no metadata.
                metadata.update(f.metadata() or {})
        # Save it
        self.metadata = metadata

    def get_quantize_ops(self):
        """Return the `TorchAoQuantize` operation bound to this quantizer."""
        # Imported lazily, at call time rather than module import time.
        from ..integrations.torchao import TorchAoQuantize

        return TorchAoQuantize(self)

    def get_weight_conversions(self):
        """
        Return the weight converters for pre-quantized checkpoints, or an empty list otherwise.

        For pre-quantized models a single `WeightConverter` maps the flattened "_weight_*" entries
        onto "weight" through `TorchAoDeserialize`.
        """
        from ..integrations.torchao import TorchAoDeserialize

        if self.pre_quantized:
            return [
                WeightConverter(
                    # TODO: incr flexibility by generalizing the source patterns to match the format of "_weight_"
                    # note that the matching logic is greedy, so for ex, if _weight_scale is before _weight_scale_and_zero in this list, it will match _weight_scale always (this is incorrect)
                    # thus, the order of source_patterns is intentional
                    source_patterns=[
                        "_weight_qdata",
                        "_weight_scale_and_zero",
                        "_weight_scale",
                        "_weight_zero_point",
                        "_weight_act_pre_scale",
                    ],
                    target_patterns="weight",
                    operations=[TorchAoDeserialize(self)],
                ),
            ]
        return []
|