You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
908 lines
39 KiB
908 lines
39 KiB
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Some of the functions here are derived from the `accelerate` library, with some tweaks for better performances
|
|
and simplicity/ease of use.
|
|
"""
|
|
|
|
import copy
|
|
import inspect
|
|
import os
|
|
import re
|
|
from collections import OrderedDict, defaultdict
|
|
from typing import TYPE_CHECKING
|
|
|
|
from safetensors import safe_open
|
|
from safetensors.torch import save_file
|
|
|
|
from ..utils import (
|
|
is_accelerate_available,
|
|
is_torch_available,
|
|
is_torch_xpu_available,
|
|
logging,
|
|
)
|
|
from ..utils.quantization_config import QuantizationMethod
|
|
from .deepspeed import is_deepspeed_zero3_enabled
|
|
from .fsdp import is_fsdp_enabled
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
if is_accelerate_available():
|
|
from accelerate import dispatch_model
|
|
from accelerate.utils import get_max_memory
|
|
from accelerate.utils.modeling import clean_device_map, get_max_layer_size
|
|
|
|
if TYPE_CHECKING:
|
|
from ..modeling_utils import PreTrainedModel
|
|
from ..quantizers import HfQuantizer
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
def get_module_size_with_ties(
    tied_params,
    module_size,
    module_sizes,
    modules_to_treat,
) -> tuple[int, list[str], list[nn.Module]]:
    """
    Compute the memory footprint of a module once its tied parameters are accounted for.

    Args:
        tied_params (`List[str]`): Names of the parameters tied to the module, living outside it.
        module_size (`int`): Size of the module alone (without tied parameters), in bytes.
        module_sizes (`Dict[str, int]`): Mapping from each module/parameter name to its size.
        modules_to_treat (`List[Tuple[str, nn.Module]]`): Remaining (name, module) pairs to allocate.

    Returns:
        `Tuple[int, List[str], List[nn.Module]]`: The combined size, the names of the modules owning
        the tied parameters, and those modules themselves.
    """
    if not tied_params:
        return module_size, [], []

    tied_module_names = []
    tied_modules = []
    for tied_param in tied_params:
        # The owner is the first pending module whose name is a proper prefix of the tied param
        owners = [(n, m) for n, m in modules_to_treat if tied_param.startswith(n + ".")]
        owner_name, owner_module = owners[0]
        tied_module_names.append(owner_name)
        tied_modules.append(owner_module)

    # Each tied module adds its full size minus the shared parameter, which was already counted once
    module_size_with_ties = module_size + sum(
        module_sizes[owner] - module_sizes[param] for param, owner in zip(tied_params, tied_module_names)
    )

    return module_size_with_ties, tied_module_names, tied_modules
|
|
|
|
|
|
def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
    """
    Normalize the user-provided `device_map` into `None`, one of the special string strategies
    ("auto", "balanced", "balanced_low_0", "sequential"), or a dict mapping module names to devices,
    validating it against the current environment (DeepSpeed ZeRO-3, accelerate availability).
    """
    from ..modeling_utils import get_torch_context_manager_or_global_device

    # No explicit device_map: fall back to a torch device context manager / global default, if any
    if device_map is None and not is_deepspeed_zero3_enabled():
        contextual_device = get_torch_context_manager_or_global_device()
        if contextual_device == torch.device("meta"):
            raise RuntimeError(
                "You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
                "This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
                "empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
            )
        device_map = contextual_device

    # Turn a torch.device, concrete str, or int into a {"": device} mapping
    if isinstance(device_map, torch.device):
        device_map = {"": device_map}
    elif isinstance(device_map, str) and device_map not in ("auto", "balanced", "balanced_low_0", "sequential"):
        try:
            if device_map == "cuda":
                # A bare "cuda" resolves to the local rank in distributed settings
                device_map = f"cuda:{int(os.environ.get('LOCAL_RANK', 0))}"
            device_map = {"": torch.device(device_map)}
        except RuntimeError:
            raise ValueError(
                "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
            )
    elif isinstance(device_map, int):
        if device_map < 0:
            raise ValueError(
                "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
            )
        device_map = {"": device_map}

    if device_map is not None:
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
        if not is_accelerate_available():
            raise ValueError(
                "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
                "requires `accelerate`. You can install it with `pip install accelerate`"
            )

    return device_map
|
|
|
|
|
|
def compute_module_sizes(
    model: "PreTrainedModel",
    hf_quantizer: "HfQuantizer | None" = None,
    buffers_only: bool = False,
    only_modules: bool = True,
) -> tuple[dict[str, int], dict[str, int]]:
    """
    Compute the size (in bytes) of every submodule of `model`.

    Returns a tuple of 2 dicts: the first maps every module name (including the root model, under
    the empty-string key) to its total size, and the second maps only the leaf modules (modules
    directly owning parameters, the end of the model graph) to their sizes.
    If `only_modules` is False, the first dict additionally contains one entry per individual
    parameter/buffer name.
    """
    all_module_sizes = defaultdict(int)
    leaves_module_sizes = defaultdict(int)

    if buffers_only:
        named_tensors = model.named_buffers()
    else:
        # state_dict would miss non-persistent buffers, which still take memory -> iterate both
        def _named_params_and_buffers():
            yield from model.named_parameters()
            yield from model.named_buffers()

        named_tensors = _named_params_and_buffers()

    # Tied weights must not be double-counted: if the model is not tied yet, the aliases still show
    # up in the iterator, so skip them explicitly (an already-tied model dedupes them by itself).
    tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
    for name, tensor in named_tensors:
        if name in tied_keys:
            continue
        if hf_quantizer is None:
            bytes_per_element = tensor.element_size()
        else:
            # Quantizers may store parameters with a different on-device element size
            bytes_per_element = hf_quantizer.param_element_size(model, name, tensor)
        tensor_bytes = tensor.numel() * bytes_per_element

        # Credit the tensor's size to every ancestor module, from the root ("") downwards
        ancestor_parts = []
        for part in name.split("."):
            all_module_sizes[".".join(ancestor_parts)] += tensor_bytes
            ancestor_parts.append(part)
        if "." in name:
            leaves_module_sizes[name.rsplit(".", 1)[0]] += tensor_bytes
        # Optionally record the tensor itself, keyed by its full name
        if not only_modules:
            all_module_sizes[name] += tensor_bytes

    return all_module_sizes, leaves_module_sizes
|
|
|
|
|
|
def compute_module_total_buffer_size(model: nn.Module, hf_quantizer: "HfQuantizer | None" = None):
    """
    Return the total size in bytes of all the buffers of `model`.
    """
    buffer_sizes, _ = compute_module_sizes(model, hf_quantizer, buffers_only=True)
    # The empty key aggregates over the whole model; it is absent when there are no buffers at all
    return buffer_sizes.get("", 0)
|
|
|
|
|
|
def get_balanced_memory(
    model: "PreTrainedModel",
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: set[str] | None = None,
    hf_quantizer: "HfQuantizer | None" = None,
    low_zero: bool = False,
):
    """
    Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU.

    <Tip>

    All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
    meta device (as it would if initialized within the `init_empty_weights` context manager).

    </Tip>

    Args:
        model (`PreTrainedModel`):
            The model to analyze.
        max_memory (`Dict`, *optional*):
            A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
            Example: `max_memory={0: "1GB"}`.
        no_split_module_classes (`set[str]`, *optional*):
            A set of layer class names that should never be split across device (for instance any layer that has a
            residual connection).
        hf_quantizer (`HfQuantizer`, *optional*):
            A quantizer for the model.
        low_zero (`bool`, *optional*):
            Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the
            Transformers generate function).

    Returns:
        `Dict[int | str, int]`: the per-device memory budget, balanced across accelerators.
    """
    # Get default / clean up max_memory
    user_not_set_max_memory = max_memory is None
    max_memory = get_max_memory(max_memory)
    # Check the number of accelerators available (everything except "cpu" and "disk")
    accelerator_max_memory = copy.deepcopy(max_memory)
    _, _ = accelerator_max_memory.pop("cpu", None), accelerator_max_memory.pop("disk", None)
    num_devices = len([d for d in accelerator_max_memory if accelerator_max_memory[d] > 0])

    # No accelerator with available memory: nothing to balance
    if num_devices == 0:
        return max_memory

    if num_devices == 1:
        # We cannot do low_zero on just one GPU, but we will still reserve some memory for the buffer
        low_zero = False
        # If user just asked us to handle memory usage, we should avoid OOM
        if user_not_set_max_memory:
            for key in max_memory.keys():
                if isinstance(key, int):
                    max_memory[key] *= 0.9  # 90% is a good compromise
                    logger.info(
                        f"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. "
                        "You can set `max_memory` in to a higher value to use more memory (at your own risk)."
                    )
                    break  # only one device

    module_sizes, leave_modules_sizes = compute_module_sizes(model, hf_quantizer)
    # module_sizes[""] is the size of the whole model; with low_zero, device 0 is excluded from the even split
    per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)

    # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get
    # slightly less layers and some layers will end up offload at the end. So this function computes a buffer size to
    # add which is the biggest of:
    # - the size of no split block (if applicable)
    # - the mean of the layer sizes
    if no_split_module_classes is None:
        no_split_module_classes = []
    elif not isinstance(no_split_module_classes, (list, tuple, set)):
        no_split_module_classes = [no_split_module_classes]

    # Identify the size of the no_split_block modules
    buffer = 0
    if len(no_split_module_classes) > 0:
        no_split_children = {}
        for name, size in module_sizes.items():
            if name == "":
                continue
            submodule = model.get_submodule(name)
            class_name = submodule.__class__.__name__
            # Only the first (largest, as parents come first) instance of each no-split class is recorded
            if class_name in no_split_module_classes and class_name not in no_split_children:
                no_split_children[class_name] = size

            # Stop as soon as one instance of every no-split class has been seen
            if set(no_split_children.keys()) == set(no_split_module_classes):
                break
        buffer = max(no_split_children.values()) if len(no_split_children) > 0 else 0

    # Buffer = 1.25 * max(biggest no-split block, mean leaf-module size)
    mean_leaves = int(sum(leave_modules_sizes.values()) / max(len(leave_modules_sizes), 1))
    buffer = int(1.25 * max(buffer, mean_leaves))
    per_gpu += buffer

    # Sorted list of GPUs id (we may have some gpu ids not included in the our max_memory list - let's ignore them)
    gpus_idx_list = sorted(
        device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0
    )
    # The last device is left with max_memory just in case the buffer is not enough.
    for idx in gpus_idx_list[:-1]:
        max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])

    if low_zero:
        # Device 0 only keeps what cannot fit on the other devices
        # NOTE(review): this assumes GPU ids are contiguous 0..num_devices-1 — gpus_idx_list above
        # suggests they may not be; confirm against callers.
        min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)]))
        max_memory[0] = min(min_zero, max_memory[0])

    return max_memory
|
|
|
|
|
|
def _get_device_map(
    model: "PreTrainedModel",
    device_map: dict | str | None,
    max_memory: dict | None,
    hf_quantizer: "HfQuantizer | None",
) -> dict:
    """Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential'].
    Otherwise, we check for any device inconsistencies in the device_map.

    Args:
        model (`PreTrainedModel`): The (possibly meta) model to place on devices.
        device_map (`dict | str | None`): Either an explicit mapping, or one of the auto strategies.
        max_memory (`dict`, *optional*): User-provided per-device memory caps, always respected.
        hf_quantizer (`HfQuantizer`, *optional*): Quantizer used to adjust memory estimates and
            validate the resulting map.
    """
    if isinstance(device_map, str):
        no_split_modules = model._no_split_modules

        # "auto"/"balanced"/"balanced_low_0" balance memory across devices; "sequential" just fills them in order
        if device_map != "sequential":
            inferred_max_memory = get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=no_split_modules,
                hf_quantizer=hf_quantizer,
                low_zero=(device_map == "balanced_low_0"),
            )
        else:
            inferred_max_memory = get_max_memory(max_memory)

        # If the user does not provide `max_memory`, accelerate sets the WHOLE cpu available memory as available.
        # This is unwanted, as we don't want to set extremely tight bound and pressure for cpu if we are memory-constrained,
        # especially if the model uses WeightConverter (because there will be some uncontrollable cpu memory spikes during
        # the conversions before we resave the weights). In those cases, it's better to offload to disk a bit more
        # if we were in-between, as otherwise we blow-up cpu memory
        if max_memory is None and "cpu" in inferred_max_memory:
            inferred_max_memory["cpu"] *= 0.90

        if hf_quantizer is not None:
            inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)

        # `inferred_max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU,
        # which we can use to allocate parameters.
        for device_name in inferred_max_memory:
            if isinstance(device_name, int):  # it's a GPU device
                if is_torch_xpu_available():
                    unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name)
                else:
                    unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
                inferred_max_memory[device_name] += unused_memory
                # respect the `max_memory` passed by the user
                if max_memory is not None and device_name in max_memory:
                    inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])

        # Turn the (string) strategy into an actual module-name -> device mapping
        device_map = infer_auto_device_map(
            model,
            max_memory=inferred_max_memory,
            no_split_module_classes=no_split_modules,
            hf_quantizer=hf_quantizer,
        )

    # Whether inferred or user-provided, let the quantizer veto incompatible placements
    if hf_quantizer is not None:
        hf_quantizer.validate_environment(device_map=device_map)

    return device_map
|
|
|
|
|
|
def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers):
    """
    Dispatch `model` onto the devices described by `device_map` using accelerate, wiring up the
    offloading hooks. Skipped entirely under FSDP or DeepSpeed ZeRO-3, which handle placement themselves.
    """
    dispatch_kwargs = {
        "device_map": device_map,
        "offload_dir": offload_folder,
        "offload_index": offload_index,
        "offload_buffers": offload_buffers,
    }
    dispatch_params = inspect.signature(dispatch_model).parameters
    if "skip_keys" in dispatch_params:
        dispatch_kwargs["skip_keys"] = model._skip_keys_device_placement

    quant_method = None if hf_quantizer is None else hf_quantizer.quantization_config.quant_method
    # For HQQ method we force-set the hooks for single GPU envs
    if quant_method == QuantizationMethod.HQQ and "force_hooks" in dispatch_params:
        dispatch_kwargs["force_hooks"] = True
    # FBGEMM FP8 needs its buffers offloaded along with the weights when part of the model is on cpu/disk
    if quant_method == QuantizationMethod.FBGEMM_FP8 and isinstance(device_map, dict):
        if "cpu" in device_map.values() or "disk" in device_map.values():
            dispatch_kwargs["offload_buffers"] = True

    if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
        dispatch_model(model, **dispatch_kwargs)
|
|
|
|
|
|
def expand_device_map(device_map: dict | None, param_names: list[str]):
|
|
"""
|
|
Expand a device map to return the correspondence parameter name to device.
|
|
"""
|
|
if device_map is None:
|
|
return dict.fromkeys(param_names, "cpu")
|
|
|
|
# Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly
|
|
device_map_regex = re.compile(
|
|
"|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: (x.count("."), len(x)), reverse=True))
|
|
)
|
|
new_device_map = {}
|
|
for param in param_names:
|
|
device_match = device_map_regex.match(param)
|
|
new_device_map[param] = device_map[device_match.group()] if device_match else device_map.get("", "cpu")
|
|
|
|
return new_device_map
|
|
|
|
|
|
def get_device(device_map: dict | None, param_name: str, valid_torch_device: bool = False) -> torch.device | str | int:
    """
    Look up the device that `param_name` should be on according to `device_map`.

    When `valid_torch_device` is True, a "disk" assignment is translated to "cpu", so the returned
    value is always a valid torch device.
    """
    resolved = expand_device_map(device_map, [param_name])[param_name]
    return "cpu" if valid_torch_device and resolved == "disk" else resolved
|
|
|
|
|
|
def accelerate_disk_offload(
    model: "PreTrainedModel",
    disk_offload_folder: str | None,
    checkpoint_files: list[str] | None,
    device_map: dict,
    sharded_metadata: dict | None,
    dtype: torch.dtype | None,
    weight_mapping=None,
):
    """
    Prepare the `disk_offload_index` that will be used for reading offloaded parameters. If reading from a safetensors
    file, parameters which do not need any special WeightConverter operation during loading (i.e. they are used as-is, or only
    renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved inside
    `disk_offload_folder` during loading.

    Args:
        model (`PreTrainedModel`): The model being loaded (its state dict keys define the target names).
        disk_offload_folder (`str`, *optional*): Folder created (if needed) to hold re-saved weights.
        checkpoint_files (`list[str]`, *optional*): The checkpoint shard paths on disk.
        device_map (`dict`): Mapping deciding which parameters end up on "disk".
        sharded_metadata (`dict`, *optional*): The index of a sharded checkpoint (`weight_map` entry).
        dtype (`torch.dtype`, *optional*): Target dtype recorded in the index (defaults to "float32").
        weight_mapping: Optional sequence of weight-conversion entries; only `WeightRenaming` entries
            are relevant here (pure renames can keep pointing at the original files).

    Returns:
        `dict`: target parameter name -> {"safetensors_file", "weight_name", "dtype"} for every
        parameter placed on "disk" (empty when weights must be re-saved during loading instead).
    """
    from ..core_model_loading import WeightRenaming, rename_source_key

    if disk_offload_folder is not None:
        os.makedirs(disk_offload_folder, exist_ok=True)
    # Only safetensors checkpoints can be read back lazily from their original location
    is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")

    renamings = []
    if weight_mapping is not None:
        renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]

    # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
    # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
    if is_offloaded_safetensors:
        meta_state_dict = model.state_dict()
        param_device_map = expand_device_map(device_map, meta_state_dict.keys())
        str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
        if sharded_metadata is None:
            # Single file: every tensor lives in checkpoint_files[0]
            weight_map = dict.fromkeys(safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0])
        else:
            # Sharded: resolve each tensor to its shard, relative to the checkpoint folder
            folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
            weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}

        # Update the weight names according to the `weight_mapping` (target model name -> on-disk name)
        weight_renaming_map = {
            rename_source_key(k, renamings, [], model.base_model_prefix, meta_state_dict)[0]: k for k in weight_map
        }

        # Prepare the index using existing safetensors files
        disk_offload_index = {
            target_name: {
                "safetensors_file": weight_map[source_name],
                "weight_name": source_name,
                "dtype": str_dtype,
            }
            for target_name, source_name in weight_renaming_map.items()
            # Need to check if it's in the mapping in case of unexpected keys that would result in KeyError (we skip them)
            if target_name in param_device_map and param_device_map[target_name] == "disk"
        }
    # In this case we will resave every offloaded weight
    else:
        disk_offload_index = {}

    return disk_offload_index
|
|
|
|
|
|
def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | None, offload_index: dict) -> dict:
    """
    Save `weight` as a single-tensor `safetensors` file inside `offload_folder` and record its
    location in `offload_index`, which is returned after the update.

    Raises:
        ValueError: if `offload_folder` is None while a weight needs to be re-saved to disk.
    """
    if offload_folder is None:
        raise ValueError(
            "The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is either "
            "because the weights are not in `safetensors` format, or because the model uses an internal weight format "
            "different than the one saved (i.e. most MoE models). Please provide an `offload_folder` for them in "
            "`from_pretrained`."
        )
    # One file per weight, named after it
    weight_file = os.path.join(offload_folder, f"{weight_name}.safetensors")
    save_file({weight_name: weight}, weight_file)
    # Record where the weight now lives, and with which dtype
    offload_index[weight_name] = {
        "safetensors_file": weight_file,
        "weight_name": weight_name,
        "dtype": str(weight.dtype).replace("torch.", ""),
    }
    return offload_index
|
|
|
|
|
|
def load_offloaded_parameter(model: "PreTrainedModel", param_name: str) -> torch.Tensor:
    """
    Load `param_name` from disk when it was offloaded by the device_map (and thus only exists as a
    meta parameter inside `model`).

    Walks up from the innermost module towards the root until it finds the accelerate hook
    (`_hf_hook`) whose weights map can serve the parameter. This is needed when resaving a model
    with offloaded parameters (they must be read back from disk before being written to the new shard).

    Raises:
        ValueError: if no offloading hook can be found for `param_name`.
    """
    name_parts = param_name.split(".")
    # Innermost parent first, then every ancestor, finally the root module ("")
    candidate_parents = [".".join(name_parts[:-depth]) for depth in range(1, len(name_parts))] + [""]
    for parent_name in candidate_parents:
        parent_module = model.get_submodule(parent_name)
        if hasattr(parent_module, "_hf_hook"):
            weights_map = parent_module._hf_hook.weights_map
            # The hook's weights map is keyed relative to the hooked module
            prefix = f"{parent_name}." if parent_name != "" else parent_name
            truncated_param_name = param_name.replace(prefix, "")
            break
    else:
        # No ancestor carries a hook: the param should not have been on the meta device
        raise ValueError(
            f"{param_name} is on the meta device because it was offloaded, but we could not find "
            "the corresponding hook for it"
        )

    # Indexing the weights map triggers the actual read from disk
    return weights_map[truncated_param_name]
|
|
|
|
|
|
def _init_infer_auto_device_map(
    model: nn.Module,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: set[str] | None = None,
    tied_parameters: list[list[str]] | None = None,
    hf_quantizer: "HfQuantizer | None" = None,
) -> tuple[
    list[int | str],
    dict[int | str, int | str],
    list[int | str],
    list[int],
    dict[str, int],
    list[list[str]],
    list[str],
    list[tuple[str, nn.Module]],
]:
    """
    Gather all the state `infer_auto_device_map` needs before its allocation loop: the normalized
    per-device memory budget, the ordered device list, the per-module/parameter sizes, the groups
    of tied parameters, and the top-level modules/parameters/buffers left to place.
    """
    max_memory = get_max_memory(max_memory)
    if no_split_module_classes is None:
        no_split_module_classes = []
    elif not isinstance(no_split_module_classes, (list, tuple, set)):
        no_split_module_classes = [no_split_module_classes]

    # "disk" is always available as the device of last resort
    devices = list(max_memory.keys())
    if "disk" not in devices:
        devices.append("disk")
    accelerators = [device for device in devices if device not in ("cpu", "disk")]

    # Devices that need to keep space for a potential offloaded layer.
    if "mps" in accelerators:
        main_devices = ["mps"]
    else:
        main_devices = [accelerators[0], "cpu"] if accelerators else ["cpu"]

    # `only_modules=False` so that individual parameter/buffer names get an entry too
    module_sizes, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)

    if tied_parameters is None:
        if len(model.all_tied_weights_keys) > 0:
            # Build one sorted group per unique tie target: all its aliases plus the target itself
            tie_groups = defaultdict(list)
            for alias, target in model.all_tied_weights_keys.items():
                tie_groups[target].append(alias)
            tied_parameters = [sorted(aliases + [target]) for target, aliases in tie_groups.items()]
        else:
            tied_parameters = [[]]

    # Direct submodules and parameters
    modules_to_treat = (
        list(model.named_parameters(recurse=False))
        + list(model.named_children())
        + list(model.named_buffers(recurse=False))
    )

    return (
        devices,
        max_memory,
        main_devices,
        accelerators,
        module_sizes,
        tied_parameters,
        no_split_module_classes,
        modules_to_treat,
    )
|
|
|
|
|
|
def infer_auto_device_map(
|
|
model: nn.Module,
|
|
max_memory: dict[int | str, int | str] | None = None,
|
|
no_split_module_classes: set[str] | None = None,
|
|
verbose: bool = False,
|
|
clean_result: bool = True,
|
|
offload_buffers: bool = False,
|
|
tied_parameters: list[list[str]] | None = None,
|
|
hf_quantizer: "HfQuantizer | None" = None,
|
|
):
|
|
"""
|
|
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
|
|
such that:
|
|
- we don't exceed the memory available of any of the GPU.
|
|
- if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that
|
|
has the largest size.
|
|
- if offload to the CPU is needed,we don't exceed the RAM available on the CPU.
|
|
- if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk
|
|
that has the largest size.
|
|
|
|
<Tip>
|
|
|
|
All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
|
|
meta device (as it would if initialized within the `init_empty_weights` context manager).
|
|
|
|
</Tip>
|
|
|
|
Args:
|
|
model (`torch.nn.Module`):
|
|
The model to analyze.
|
|
max_memory (`Dict`, *optional*):
|
|
A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
|
|
Example: `max_memory={0: "1GB"}`.
|
|
no_split_module_classes (`set[str]`, *optional*):
|
|
A set of layer class names that should never be split across device (for instance any layer that has a
|
|
residual connection).
|
|
verbose (`bool`, *optional*, defaults to `False`):
|
|
Whether or not to provide debugging statements as the function builds the device_map.
|
|
clean_result (`bool`, *optional*, defaults to `True`):
|
|
Clean the resulting device_map by grouping all submodules that go on the same device together.
|
|
offload_buffers (`bool`, *optional*, defaults to `False`):
|
|
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
|
|
well as the parameters.
|
|
"""
|
|
|
|
# Initialize the variables
|
|
(
|
|
devices,
|
|
max_memory,
|
|
main_devices,
|
|
gpus,
|
|
module_sizes,
|
|
tied_parameters,
|
|
no_split_module_classes,
|
|
modules_to_treat,
|
|
) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, tied_parameters, hf_quantizer)
|
|
|
|
device_map = OrderedDict()
|
|
current_device = 0
|
|
device_memory_used = dict.fromkeys(devices, 0)
|
|
device_buffer_sizes = {}
|
|
device_minimum_assignment_memory = {}
|
|
|
|
# Initialize maximum largest layer, to know which space to keep in memory
|
|
max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
|
|
|
|
# Ready ? This is going to be a bit messy.
|
|
while len(modules_to_treat) > 0:
|
|
name, module = modules_to_treat.pop(0)
|
|
if verbose:
|
|
print(f"\nTreating module {name}.")
|
|
# Max size in the remaining layers may have changed since we took one, so we maybe update it.
|
|
max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
|
|
if len(max_layer_names) == 0:
|
|
max_layer_size, max_layer_names = get_max_layer_size(
|
|
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
|
|
module_sizes,
|
|
no_split_module_classes,
|
|
)
|
|
# Assess size needed
|
|
module_size = module_sizes[name]
|
|
|
|
# We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module
|
|
# and the other is not.
|
|
# Note: If we are currently processing the name `compute.weight`, an other parameter named
|
|
# e.g. `compute.weight_submodule.parameter`
|
|
# needs to be considered outside the current module, hence the check with additional dots.
|
|
tied_param_groups = [
|
|
tied_group
|
|
for tied_group in tied_parameters
|
|
if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
|
|
]
|
|
|
|
if verbose and len(tied_param_groups) > 0:
|
|
print(f" Found the relevant tied param groups {tied_param_groups}")
|
|
|
|
# Then we keep track of all the parameters that are tied to the current module, but not in the current module
|
|
tied_params = sum(
|
|
[[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
|
|
)
|
|
|
|
if verbose and len(tied_params) > 0:
|
|
print(f" So those parameters need to be taken into account {tied_params}")
|
|
|
|
device = devices[current_device]
|
|
current_max_size = max_memory[device] if device != "disk" else None
|
|
current_memory_reserved = 0
|
|
# Reduce max size available by the largest layer.
|
|
if devices[current_device] in main_devices:
|
|
current_max_size = current_max_size - max_layer_size
|
|
current_memory_reserved = max_layer_size
|
|
|
|
module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(
|
|
tied_params, module_size, module_sizes, modules_to_treat
|
|
)
|
|
|
|
# The module and its tied modules fit on the current device.
|
|
if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
|
|
if verbose:
|
|
output = f"Putting {name}"
|
|
|
|
if tied_module_names:
|
|
output += f" and {tied_module_names}"
|
|
else:
|
|
output += f" (size={module_size})"
|
|
|
|
if current_max_size is not None:
|
|
output += f" (available={current_max_size - device_memory_used[device]})"
|
|
|
|
output += f" on {device}."
|
|
print(output)
|
|
|
|
device_memory_used[device] += module_size_with_ties
|
|
|
|
# Assign the primary module to the device.
|
|
device_map[name] = device
|
|
|
|
# Assign tied modules if any.
|
|
for tied_module_name in tied_module_names:
|
|
if tied_module_name in [m[0] for m in modules_to_treat]:
|
|
# Find the index of the tied module in the list
|
|
tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name)
|
|
# Remove the tied module from the list to prevent reprocessing
|
|
modules_to_treat.pop(tied_module_index)
|
|
|
|
# Assign the tied module to the device
|
|
device_map[tied_module_name] = device
|
|
|
|
# Buffer Handling
|
|
if not offload_buffers and isinstance(module, nn.Module):
|
|
# Compute the total buffer size for the module
|
|
current_buffer_size = compute_module_total_buffer_size(module, hf_quantizer)
|
|
# Update the buffer size on the device
|
|
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size
|
|
|
|
continue
|
|
|
|
# The current module itself fits, so we try to split the tied modules.
|
|
if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size:
|
|
# can we split one of the tied modules to make it smaller or do we need to go on the next device?
|
|
if verbose:
|
|
print(
|
|
f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
|
|
f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
|
|
)
|
|
split_happened = False
|
|
for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
|
|
tied_module_children = list(tied_module.named_children())
|
|
if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
|
|
# can't break this one.
|
|
continue
|
|
|
|
if verbose:
|
|
print(f"Splitting {tied_module_name}.")
|
|
tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
|
|
tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
|
|
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]
|
|
|
|
modules_to_treat = (
|
|
[(name, module)]
|
|
+ modules_to_treat[:tied_module_index]
|
|
+ tied_module_children
|
|
+ modules_to_treat[tied_module_index + 1 :]
|
|
)
|
|
# Update the max layer size.
|
|
max_layer_size, max_layer_names = get_max_layer_size(
|
|
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
|
|
module_sizes,
|
|
no_split_module_classes,
|
|
)
|
|
split_happened = True
|
|
break
|
|
|
|
if split_happened:
|
|
continue
|
|
|
|
# If the tied module is not split, we go to the next device
|
|
if verbose:
|
|
print("None of the tied module can be split, going to the next device.")
|
|
|
|
# The current module itself doesn't fit, so we have to split it or go to the next device.
|
|
if device_memory_used[device] + module_size >= current_max_size:
|
|
# Split or not split?
|
|
modules_children = (
|
|
[]
|
|
if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
|
|
else list(module.named_children())
|
|
)
|
|
if verbose:
|
|
print(
|
|
f"Not enough space on {devices[current_device]} to put {name} (space available "
|
|
f"{current_max_size - device_memory_used[device]}, module size {module_size})."
|
|
)
|
|
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
|
|
# -> no split, we go to the next device
|
|
if verbose:
|
|
print("This module cannot be split, going to the next device.")
|
|
|
|
else:
|
|
# -> split, we replace the module studied by its children + parameters
|
|
if verbose:
|
|
print(f"Splitting {name}.")
|
|
modules_children = list(module.named_parameters(recurse=False)) + modules_children
|
|
modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
|
|
# Update the max layer size.
|
|
max_layer_size, max_layer_names = get_max_layer_size(
|
|
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
|
|
module_sizes,
|
|
no_split_module_classes,
|
|
)
|
|
continue
|
|
|
|
if device_memory_used[device] == 0:
|
|
device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved
|
|
|
|
# Neither the current module nor any tied modules can be split, so we move to the next device.
|
|
device_memory_used[device] = device_memory_used[device] + current_memory_reserved
|
|
current_device += 1
|
|
modules_to_treat = [(name, module)] + modules_to_treat
|
|
|
|
device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}
|
|
|
|
if clean_result:
|
|
device_map = clean_device_map(device_map)
|
|
|
|
non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)
|
|
if non_gpu_buffer_size > 0 and not offload_buffers:
|
|
is_buffer_fit_any_gpu = False
|
|
for gpu_device, gpu_max_memory in max_memory.items():
|
|
if gpu_device == "cpu" or gpu_device == "disk":
|
|
continue
|
|
|
|
if not is_buffer_fit_any_gpu:
|
|
gpu_memory_used = device_memory_used.get(gpu_device, 0)
|
|
|
|
if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
|
|
is_buffer_fit_any_gpu = True
|
|
|
|
if len(gpus) > 0 and not is_buffer_fit_any_gpu:
|
|
logger.warning(
|
|
f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does "
|
|
f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using "
|
|
f"offload_buffers=True."
|
|
)
|
|
|
|
if device_minimum_assignment_memory:
|
|
devices_info = "\n".join(
|
|
f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items()
|
|
)
|
|
logger.info(
|
|
f"Based on the current allocation process, no modules could be assigned to the following devices due to "
|
|
f"insufficient memory:\n"
|
|
f"{devices_info}\n"
|
|
f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing "
|
|
f"the available memory for these devices to at least the specified minimum, or adjusting the model config."
|
|
)
|
|
|
|
check_tied_parameters_on_same_device(tied_parameters, device_map)
|
|
return device_map
|
|
|
|
|
|
def _get_param_device(param, device_map):
|
|
if param in device_map:
|
|
return device_map[param]
|
|
parent_param = ".".join(param.split(".")[:-1])
|
|
if parent_param == param:
|
|
raise ValueError(f"The `device_map` does not contain the module {param}.")
|
|
else:
|
|
return _get_param_device(parent_param, device_map)
|
|
|
|
|
|
def check_tied_parameters_on_same_device(tied_params, device_map):
    """
    Check if tied parameters are on the same device

    Args:
        tied_params (`List[List[str]]`):
            A list of lists of parameter names being all tied together.

        device_map (`Dict[str, Union[int, str, torch.device]]`):
            A map that specifies where each submodule should go.

    """
    for group in tied_params:
        # Resolve each tied parameter to its assigned device via the device map.
        tie_param_devices = {param: _get_param_device(param, device_map) for param in group}
        distinct_devices = set(tie_param_devices.values())
        if len(distinct_devices) > 1:
            # Tied weights must share storage, so split placement is a user error.
            logger.warning(
                f"Tied parameters are on different devices: {tie_param_devices}. "
                "Please modify your custom device map or set `device_map='auto'`. "
            )
|