# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
import re
from typing import TYPE_CHECKING

from packaging import version

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name, should_convert_module


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from safetensors import safe_open

from ..utils import is_torch_available, is_torchao_available, logging


if is_torch_available():
    import torch

    from ..core_model_loading import WeightConverter

if is_torchao_available():
    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
        from torchao.prototype.safetensors.safetensors_support import (
            flatten_tensor_state_dict,
        )


logger = logging.get_logger(__name__)


def fuzzy_match_size(config_name: str) -> str | None:
    """
    Extract the size digit from strings like "4weight", "8weight" (e.g. from a lowercased
    torchao config class name). Returns the digit as a string if found, otherwise None.
    """
    config_name = config_name.lower()
    str_match = re.search(r"(\d)weight", config_name)
    if str_match:
        return str_match.group(1)
    return None
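
# Illustrative examples (class names as in torchao; results follow from the regex above):
#   fuzzy_match_size("Int4WeightOnlyConfig")    # -> "4"  ("int4weight..." matches "4weight")
#   fuzzy_match_size("Float8WeightOnlyConfig")  # -> "8"
#   fuzzy_match_size("SomeOtherConfig")         # -> None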


def _quantization_type(weight):
    # Describe a torchao-quantized weight for repr purposes; implicitly returns None for plain tensors.
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

    if isinstance(weight, AffineQuantizedTensor):
        return f"{weight.__class__.__name__}({weight._quantization_type()})"

    if isinstance(weight, LinearActivationQuantizedTensor):
        return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"


def _linear_extra_repr(self):
    # Replacement `extra_repr` for quantized nn.Linear modules, showing the quantization scheme.
    weight = _quantization_type(self.weight)
    if weight is None:
        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
    else:
        return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"


if is_torchao_available():
    TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))


class TorchAoHfQuantizer(HfQuantizer):
    """
    Quantizer for torchao: https://github.com/pytorch/ao/
    """

    requires_calibration = False

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantized_param_size = None
        quant_type = self.quantization_config.quant_type
        if isinstance(quant_type, str):
            # Approximate bytes per quantized parameter (int4 packs two values per byte)
            map_to_param_size = {
                "int4_weight_only": 0.5,
                "int8_weight_only": 1,
                "int8_dynamic_activation_int8_weight": 1,
            }
            if quant_type in map_to_param_size:
                self.quantized_param_size = map_to_param_size[quant_type]
        else:
            # For AOBaseConfig instances, infer the bit width from the config class name
            size_digit = fuzzy_match_size(quant_type.__class__.__name__)
            self.quantized_param_size = 0.5 if size_digit == "4" else 1
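
    # Back-of-envelope example (illustrative numbers): with int4 weight-only quantization, a
    # 7B-parameter model's weights occupy roughly 7e9 * 0.5 bytes ~= 3.5 GB, before the scale /
    # zero-point metadata that adjust_max_memory budgets for below.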

    def validate_environment(self, *args, **kwargs):
        if not is_torchao_available():
            raise ImportError("Loading a torchao quantized model requires the torchao library (`pip install torchao`)")
        self.offload = False

        device_map = kwargs.get("device_map")
        if isinstance(device_map, dict):
            if ("disk" in device_map.values() or "cpu" in device_map.values()) and len(device_map) > 1:
                self.offload = True
            if self.pre_quantized and "disk" in device_map.values():
                raise ValueError(
                    "You are attempting to perform disk offload with a pre-quantized torchao model. "
                    "This is not supported yet. Please remove the disk device from the device_map."
                )

        if self.pre_quantized:
            weights_only = kwargs.get("weights_only")
            if weights_only:
                torch_version = version.parse(importlib.metadata.version("torch"))
                if torch_version < version.parse("2.5.0"):
                    raise RuntimeError(
                        f"In order to use a torchao pre-quantized model, you need torch>=2.5.0. However, the current version is {torch_version}."
                        f" You can also set `weights_only=False` in `from_pretrained` if you don't want to update torch."
                    )
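
    # Illustrative trigger (hypothetical model id and map): a device_map spanning an accelerator
    # and CPU enables the offload path above, while adding "disk" for a pre-quantized checkpoint
    # raises the ValueError:
    #   AutoModelForCausalLM.from_pretrained(
    #       "some-org/some-model",
    #       device_map={"model.embed_tokens": 0, "model.layers": "cpu"},
    #   )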

    def update_dtype(self, dtype):
        if self.quantization_config.quant_type == "int4_weight_only":
            if dtype != torch.bfloat16:
                logger.warning_once(
                    f"Setting dtype to {dtype} for int4_weight_only quantization, but only bfloat16 is supported right now. Overwriting dtype to bfloat16."
                )
            dtype = torch.bfloat16
        return dtype

    def get_state_dict_and_metadata(self, model):
        """
        We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format.
        """
        if TORCHAO_VERSION >= version.parse("0.15.0"):
            return flatten_tensor_state_dict(model.state_dict())
        else:
            raise RuntimeError(
                f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
            )
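
    # Illustrative shape of the flattening (key names assumed from the source_patterns used in
    # get_weight_conversions below): one tensor-subclass entry becomes several plain tensors, roughly
    #   {"layer.weight": <quantized tensor subclass>}
    #     -> {"layer._weight_qdata": ..., "layer._weight_scale": ..., ...} plus metadata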

    def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
        """Return the element size (in bytes) for `param_name`."""
        if self.param_needs_quantization(model, param_name) and self.quantized_param_size is not None:
            return self.quantized_param_size
        return super().param_element_size(model, param_name, param)

    def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
        # Reserve extra space for the quantization parameters (e.g. scale). Tested with int4
        # weight-only quantization and group size 128.
        max_memory = {key: val * 0.9 for key, val in max_memory.items()}
        return max_memory
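
    # Example of the budget adjustment (illustrative numbers): max_memory {0: 16 GiB} is reported
    # as {0: ~14.4 GiB}, leaving ~10% headroom for scales and zero points.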

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", checkpoint_files=None, **kwargs):
        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
        )
        if self.quantization_config.include_input_output_embeddings:
            input_emb = model.get_input_embeddings()
            input_emb_names = [name for name, module in model.named_modules() if module is input_emb]
            output_emb = model.get_output_embeddings()
            output_emb_names = [name for name, module in model.named_modules() if module is output_emb]
            self.modules_to_not_convert = [
                x for x in self.modules_to_not_convert if x not in input_emb_names + output_emb_names
            ]
        if checkpoint_files is not None:
            # Torchao needs access to all metadata later
            self.set_metadata(checkpoint_files)
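
    # Illustrative effect (module names typical of decoder-only models, assumed here): with
    # include_input_output_embeddings=True, entries like "model.embed_tokens" and "lm_head" are
    # dropped from modules_to_not_convert, making the embedding weights eligible for quantization.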

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        # Skip any param covered by self.modules_to_not_convert
        if not should_convert_module(param_name, self.modules_to_not_convert):
            return False
        # We only quantize the weight of nn.Linear (and optionally nn.Embedding)
        module, tensor_name = get_module_from_name(model, param_name)
        _QUANTIZABLE = [torch.nn.Linear]
        if self.quantization_config.include_input_output_embeddings:
            _QUANTIZABLE.append(torch.nn.Embedding)

        # Handle FqnToConfig, introduced in torchao 0.15.0+
        if self.quantization_config._get_ao_version() >= version.parse("0.15.0"):
            from torchao.quantization import FqnToConfig, fqn_matches_fqn_config

            if isinstance(self.quantization_config.quant_type, FqnToConfig):
                module_fqn, _ = param_name.rsplit(".", 1)
                if (
                    fqn_matches_fqn_config(module_fqn, self.quantization_config.quant_type)
                    or fqn_matches_fqn_config(param_name, self.quantization_config.quant_type)
                    or (
                        "_default" in self.quantization_config.quant_type.fqn_to_config
                        and isinstance(module, tuple(_QUANTIZABLE))
                    )
                ):
                    return True

        return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"
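
    # Illustrative FqnToConfig usage (hypothetical fqn; FqnToConfig / Int8WeightOnlyConfig are
    # torchao 0.15.0+ names): per-module configs match the module fqn or the parameter fqn, and a
    # "_default" entry covers every other quantizable module, e.g.
    #   from torchao.quantization import FqnToConfig, Int8WeightOnlyConfig
    #   quant_type = FqnToConfig(fqn_to_config={"model.layers.0.self_attn.q_proj": Int8WeightOnlyConfig()})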

    def _process_model_after_weight_loading(self, model, **kwargs):
        return

    def is_serializable(self) -> bool:
        _is_torchao_serializable = TORCHAO_VERSION >= version.parse("0.15.0")
        if not _is_torchao_serializable:
            logger.warning(
                "torchao quantized models only support serialization with torchao version >= 0.15.0, please upgrade "
                "your torchao version to save the quantized model"
            )
        return _is_torchao_serializable

    @property
    def is_trainable(self) -> bool:
        supported_quant_types_for_training = [
            "int8_weight_only",
            "int8_dynamic_activation_int8_weight",
        ]
        return self.quantization_config.quant_type in supported_quant_types_for_training

    @property
    def is_compileable(self) -> bool:
        return True

    def set_metadata(self, checkpoint_files: list[str]):
        if checkpoint_files[0].endswith(".safetensors"):
            metadata = {}
            for checkpoint in checkpoint_files:
                with safe_open(checkpoint, framework="pt") as f:
                    metadata_ = f.metadata() or {}
                    metadata.update(metadata_)
            # Save it
            self.metadata = metadata

    def get_quantize_ops(self):
        from ..integrations.torchao import TorchAoQuantize

        return TorchAoQuantize(self)

    def get_weight_conversions(self):
        from ..integrations.torchao import TorchAoDeserialize

        if self.pre_quantized:
            return [
                WeightConverter(
                    # TODO: increase flexibility by generalizing the source patterns to match the "_weight_" format
                    # Note that the matching logic is greedy: if "_weight_scale" appeared before
                    # "_weight_scale_and_zero" in this list, it would always match "_weight_scale"
                    # (which is incorrect), so the order of source_patterns is intentional.
                    source_patterns=[
                        "_weight_qdata",
                        "_weight_scale_and_zero",
                        "_weight_scale",
                        "_weight_zero_point",
                        "_weight_act_pre_scale",
                    ],
                    target_patterns="weight",
                    operations=[TorchAoDeserialize(self)],
                ),
            ]
        return []
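
# Illustrative mapping (layer fqn hypothetical; key suffixes from the source_patterns above): on
# load, the flattened entries for one quantized linear, e.g. "model.layers.0.mlp.up_proj._weight_qdata"
# and "model.layers.0.mlp.up_proj._weight_scale", are gathered by the WeightConverter and passed to
# TorchAoDeserialize, which reconstructs the single parameter "model.layers.0.mlp.up_proj.weight".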