You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

345 lines
14 KiB

4 days ago
import inspect
from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
from ..utils import (
get_available_devices,
is_accelerate_available,
is_bitsandbytes_available,
is_torch_available,
logging,
)
if is_bitsandbytes_available():
import bitsandbytes as bnb
if is_torch_available():
import torch
import torch.nn as nn
from ..pytorch_utils import Conv1D
if is_accelerate_available():
import accelerate
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
logger = logging.get_logger(__name__)
class Bnb4bitQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer
def convert(
self,
input_dict: dict[str, list[torch.Tensor]],
full_layer_name: str | None = None,
model: torch.nn.Module | None = None,
**kwargs,
) -> dict[str, torch.Tensor]:
"""
we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that it stored in hf_quantizer for now as we can't save it in the op since we create an op per tensor.
"""
value = list(input_dict.values())[0]
value = value[0]
# update param name to get the weights instead of the quantized stats
module, _ = get_module_from_name(model, full_layer_name)
# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D):
value = value.T
old_value = model.get_parameter_or_buffer(full_layer_name)
new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device)
module._is_hf_initialized = True
return {full_layer_name: new_value}
class Bnb4bitDeserialize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer
def convert(
self,
input_dict: dict[str, list[torch.Tensor]],
model: torch.nn.Module | None = None,
full_layer_name: str | None = None,
**kwargs,
) -> dict[str, torch.Tensor]:
"""
Deserialization of bnb keys. We need 6 keys to recreate the quantized weights
"""
if len(input_dict) == 1:
return input_dict
for key, value in input_dict.items():
if isinstance(value, list):
input_dict[key] = value[0]
key_weight = "weight"
weight = input_dict.pop(key_weight)
module, _ = get_module_from_name(model, full_layer_name)
new_value = bnb.nn.Params4bit.from_prequantized(
data=weight,
quantized_stats=input_dict,
requires_grad=False,
device=weight.device,
module=module,
)
module._is_hf_initialized = True
return {key_weight: new_value}
class Bnb8bitQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer
def convert(
self,
input_dict: dict[str, list[torch.Tensor]],
model: torch.nn.Module | None = None,
full_layer_name: str | None = None,
**kwargs,
) -> dict[str, torch.Tensor]:
value = list(input_dict.values())[0]
value = value[0] if isinstance(value, list) else value
module, _ = get_module_from_name(model, full_layer_name)
# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D):
value = value.T
value_device = value.device
kwargs = model.get_parameter_or_buffer(full_layer_name).__dict__
kwargs.pop("SCB", None)
new_value = bnb.nn.Int8Params(value.to("cpu"), requires_grad=False, **kwargs).to(value_device)
return {full_layer_name: new_value}
class Bnb8bitDeserialize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer
def convert(
self,
input_dict: dict[str, list[torch.Tensor]],
model: torch.nn.Module | None = None,
full_layer_name: str | None = None,
**kwargs,
) -> dict[str, torch.Tensor]:
"""
Deserialization of bnb keys.
"""
if len(input_dict) == 1:
# special case when we only fetched the weight
# since we collected keys, we need to return it like that
return input_dict
for key, value in input_dict.items():
if isinstance(value, list):
input_dict[key] = value[0]
module, _ = get_module_from_name(model, full_layer_name)
key_weight = "weight"
weight = input_dict[key_weight]
kwargs = model.get_parameter_or_buffer(full_layer_name).__dict__
kwargs["SCB"] = input_dict["SCB"]
new_value = bnb.nn.Int8Params(weight, requires_grad=False, **kwargs).to(weight.device)
module._is_hf_initialized = True
return {key_weight: new_value}
def replace_with_bnb_linear(
model: torch.nn.Module,
modules_to_not_convert: list[str] | None = None,
quantization_config=None,
pre_quantized=False,
):
"""
A helper function to replace all `torch.nn.Linear` modules by bnb modules from the `bitsandbytes` library.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
modules_to_not_convert (`list[str]`, defaults to `None`):
A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
converted.
quantization_config (`BitsAndBytesConfig`):
The quantization config object that contains the quantization parameters.
pre_quantized (`book`, defaults to `False`):
Whether the model is pre-quantized or not
"""
has_been_replaced = False
# we need this to correctly materialize the weights during quantization
for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
new_module = None
with torch.device("meta"):
if isinstance(module, (nn.Linear, Conv1D)):
if isinstance(module, Conv1D):
in_features, out_features = module.weight.shape
else:
in_features = module.in_features
out_features = module.out_features
if quantization_config.quantization_method() == "llm_int8":
new_module = bnb.nn.Linear8bitLt(
in_features,
out_features,
module.bias is not None,
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
threshold=quantization_config.llm_int8_threshold,
)
if pre_quantized:
# this is kind of an edge case when supporting both loading and quantization ...
# we need to set the right dtype as we cast the checkpoint with the dtype of the meta model
new_module.weight.data = new_module.weight.data.to(dtype=torch.int8)
else:
new_module = bnb.nn.Linear4bit(
in_features,
out_features,
module.bias is not None,
quantization_config.bnb_4bit_compute_dtype,
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
quant_type=quantization_config.bnb_4bit_quant_type,
quant_storage=quantization_config.bnb_4bit_quant_storage,
)
if pre_quantized:
# same here
new_module.weight.data = new_module.weight.data.to(
dtype=quantization_config.bnb_4bit_quant_storage
)
if new_module is not None:
# Store the module class in case we need to transpose the weight later
new_module.source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
new_module.requires_grad_(False)
model.set_submodule(module_name, new_module)
has_been_replaced = True
if not has_been_replaced:
logger.warning(
"You are loading your model using eetq but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
"""
Helper function to dequantize 4bit or 8bit bnb weights.
If the weight is not a bnb quantized weight, it will be returned as is.
"""
if not isinstance(weight, torch.nn.Parameter):
raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
cls_name = weight.__class__.__name__
if cls_name not in ("Params4bit", "Int8Params"):
return weight
if cls_name == "Params4bit":
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
return output_tensor
if state.SCB is None:
state.SCB = weight.SCB
if hasattr(bnb.functional, "int8_vectorwise_dequant"):
# Use bitsandbytes API if available (requires v0.45.0+)
dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
else:
# Multiply by (scale/127) to dequantize.
dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
return dequantized
def _create_accelerate_new_hook(old_hook):
r"""
Creates a new hook based on the old hook. Use it only if you know what you are doing !
This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
with some changes
"""
old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
old_hook_attr = old_hook.__dict__
filtered_old_hook_attr = {}
old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
for k in old_hook_attr:
if k in old_hook_init_signature.parameters:
filtered_old_hook_attr[k] = old_hook_attr[k]
new_hook = old_hook_cls(**filtered_old_hook_attr)
return new_hook
def dequantize_and_replace(model, quantization_config=None, dtype=None):
"""
Converts a quantized model into its dequantized original version. The newly converted model will have
some performance drop compared to the original model before quantization - use it only for specific usecases
such as QLoRA adapters merging.
Returns the converted model.
"""
quant_method = quantization_config.quantization_method()
target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
for module_name, module in model.named_modules():
if isinstance(module, target_cls):
with torch.device("meta"):
bias = getattr(module, "bias", None)
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
state = module.state if quant_method == "llm_int8" else None
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
weight = dequantize_bnb_weight(module.weight, state)
if dtype is None:
logger.warning_once(
f"The modules are dequantized in {weight.dtype}. If you want to change the dtype, please specify `dtype` in `dequantize`. "
)
else:
logger.warning_once(f"The modules are dequantized in {weight.dtype} and casted to {dtype}.")
weight = weight.to(dtype)
new_module.weight = torch.nn.Parameter(weight)
if bias is not None:
new_module.bias = bias
if hasattr(module, "_hf_hook"):
old_hook = module._hf_hook
new_hook = _create_accelerate_new_hook(old_hook)
remove_hook_from_module(module)
add_hook_to_module(new_module, new_hook)
new_module.to(module.weight.device)
model.set_submodule(module_name, new_module)
has_been_replaced = True
if not has_been_replaced:
logger.warning(
"For some reason the model has not been properly dequantized. You might see unexpected behavior."
)
return model
def validate_bnb_backend_availability(raise_exception=False):
"""
Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not.
"""
bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
available_devices = set(get_available_devices())
if not available_devices.intersection(bnb_supported_devices):
if raise_exception:
err_msg = (
f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices}`. "
"Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation"
)
raise RuntimeError(err_msg)
logger.warning("No supported devices found for bitsandbytes")
return False
return True