You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
177 lines
6.6 KiB
177 lines
6.6 KiB
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import TYPE_CHECKING
|
|
|
|
from .base import HfQuantizer
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from ..modeling_utils import PreTrainedModel
|
|
|
|
from ..utils import (
|
|
ACCELERATE_MIN_VERSION,
|
|
BITSANDBYTES_MIN_VERSION,
|
|
is_accelerate_available,
|
|
is_bitsandbytes_available,
|
|
is_torch_available,
|
|
is_torch_hpu_available,
|
|
is_torch_npu_available,
|
|
is_torch_xpu_available,
|
|
logging,
|
|
)
|
|
from .quantizers_utils import get_module_from_name
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from ..core_model_loading import WeightConverter
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class Bnb8BitHfQuantizer(HfQuantizer):
    """
    8-bit quantization from bitsandbytes quantization method.

    Before the weights are loaded, the model is passed to `replace_with_bnb_linear` (minus the modules
    listed in `modules_to_not_convert`). After loading, the model is flagged with `is_loaded_in_8bit`
    and `is_8bit_serializable`.
    """

    # No calibration step is required by this quantizer.
    requires_calibration = False

    def __init__(self, quantization_config, **kwargs):
        """Forward `quantization_config` and any extra keyword arguments to `HfQuantizer.__init__`."""
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        """
        Check that the runtime can load a bitsandbytes 8-bit model.

        Args:
            **kwargs: only `device_map` is read here; everything else is ignored.

        Raises:
            ImportError: if `accelerate` or `bitsandbytes` is not available.
            ValueError: if `device_map` is a dict that places some (but not all) modules on `"cpu"` or any
                module on `"disk"` while `llm_int8_enable_fp32_cpu_offload` is not enabled.
        """
        if not is_accelerate_available():
            raise ImportError(
                f"Using `bitsandbytes` 8-bit quantization requires accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
            )
        if not is_bitsandbytes_available():
            # The requirement specifier is quoted so that a shell does not treat `>=...` as an output
            # redirection when the hint is copy-pasted (same form as the accelerate hint above).
            raise ImportError(
                f"Using `bitsandbytes` 8-bit quantization requires bitsandbytes: `pip install -U 'bitsandbytes>={BITSANDBYTES_MIN_VERSION}'`"
            )

        # Imported lazily, only once bitsandbytes is known to be available.
        from ..integrations import validate_bnb_backend_availability

        validate_bnb_backend_availability(raise_exception=True)

        device_map = kwargs.get("device_map")
        if not self.quantization_config.llm_int8_enable_fp32_cpu_offload and isinstance(device_map, dict):
            values = set(device_map.values())
            # A device map that is entirely on CPU is accepted; a partial CPU/disk offload is rejected
            # unless fp32 CPU offload was explicitly enabled.
            if values != {"cpu"} and ("cpu" in values or "disk" in values):
                raise ValueError(
                    "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
                    "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
                    "in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to "
                    "`from_pretrained`. Check "
                    "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
                    "for more details. "
                )

    def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
        """
        Return a copy of `max_memory` with every budget scaled down to 90%.

        NOTE(review): the scaling multiplies each value by a float, so values are presumably already
        numeric when this is called (a `str` value would raise `TypeError`) — confirm against the caller.
        """
        # need more space for buffers that are created during quantization
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def update_device_map(self, device_map):
        """
        Pick a default single-device map when the caller did not provide one.

        When `device_map` is `None`, the first available backend among CUDA, NPU, HPU and XPU is used
        (in that order), falling back to `"cpu"`. A non-`None` `device_map` is returned unchanged.
        """
        if device_map is None:
            if torch.cuda.is_available():
                device_map = {"": torch.cuda.current_device()}
            elif is_torch_npu_available():
                device_map = {"": f"npu:{torch.npu.current_device()}"}
            elif is_torch_hpu_available():
                device_map = {"": f"hpu:{torch.hpu.current_device()}"}
            elif is_torch_xpu_available():
                device_map = {"": torch.xpu.current_device()}
            else:
                device_map = {"": "cpu"}
            logger.info(
                "The device_map was not initialized. "
                f"Setting device_map to {device_map}. "
                "If you want to use the model for inference, please set device_map ='auto' "
            )
        return device_map

    def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
        """Return the element size (in bytes) for `param_name`."""
        if self.param_needs_quantization(model, param_name):
            # 8-bit
            return 1
        return super().param_element_size(model, param_name, param)

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        """Return `True` when `param_name` is a non-bias parameter of a `bnb.nn.Linear8bitLt` module."""
        import bitsandbytes as bnb

        module, name = get_module_from_name(model, param_name)
        return isinstance(module, bnb.nn.Linear8bitLt) and name != "bias"

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        """Flag `model` as loaded in 8-bit, record whether it is serializable, and return it."""
        model.is_loaded_in_8bit = True
        model.is_8bit_serializable = self.is_serializable()
        return model

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        device_map,
        **kwargs,
    ):
        """
        Compute `self.modules_to_not_convert` and run `replace_with_bnb_linear` on `model`.

        When `llm_int8_enable_fp32_cpu_offload` is set and `device_map` is a dict, every key mapped to
        `"cpu"` or `"disk"` is added to the modules that are not converted.
        """
        from ..integrations import replace_with_bnb_linear

        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.llm_int8_skip_modules, model._keep_in_fp32_modules
        )

        if self.quantization_config.llm_int8_enable_fp32_cpu_offload and isinstance(device_map, dict):
            keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
            self.modules_to_not_convert.extend(keys_on_cpu)

        # The return value is intentionally not bound: this method returns nothing, so the result
        # was never used here.
        replace_with_bnb_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )

    def is_serializable(self):
        """Return `True`: this quantizer always reports its models as serializable."""
        return True

    @property
    def is_trainable(self) -> bool:
        """Return `True`: this quantizer always reports its models as trainable."""
        return True

    def _dequantize(self, model, dtype=None):
        """Return the result of `dequantize_and_replace` applied to `model` with this quantization config."""
        from ..integrations import dequantize_and_replace

        model = dequantize_and_replace(model, quantization_config=self.quantization_config, dtype=dtype)
        return model

    def get_quantize_ops(self):
        """Return the `Bnb8bitQuantize` operation bound to this quantizer."""
        from ..integrations.bitsandbytes import Bnb8bitQuantize

        return Bnb8bitQuantize(self)

    def get_weight_conversions(self):
        """
        Return the weight converters needed to load a checkpoint.

        For a pre-quantized checkpoint, a single `WeightConverter` maps the `SCB`, `weight_format` and
        `weight` source entries onto `weight` through `Bnb8bitDeserialize`. Otherwise the list is empty.
        """
        from ..integrations.bitsandbytes import Bnb8bitDeserialize

        if self.pre_quantized:
            return [
                WeightConverter(
                    source_patterns=["SCB", "weight_format", "weight"],
                    target_patterns="weight",
                    operations=[Bnb8bitDeserialize(self)],
                )
            ]
        return []
|