# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Some of the functions here are derived from the `accelerate` library, with some tweaks for better performance and
simplicity/ease of use.
"""

import copy
import inspect
import os
import re
from collections import OrderedDict, defaultdict
from typing import TYPE_CHECKING

from safetensors import safe_open
from safetensors.torch import save_file

from ..utils import (
    is_accelerate_available,
    is_torch_available,
    is_torch_xpu_available,
    logging,
)
from ..utils.quantization_config import QuantizationMethod
from .deepspeed import is_deepspeed_zero3_enabled
from .fsdp import is_fsdp_enabled


if is_torch_available():
    import torch
    import torch.nn as nn

if is_accelerate_available():
    from accelerate import dispatch_model
    from accelerate.utils import get_max_memory
    from accelerate.utils.modeling import clean_device_map, get_max_layer_size

if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from ..quantizers import HfQuantizer


logger = logging.get_logger(__name__)


def _is_within_module(param_name: str, module_name: str) -> bool:
    """Return whether `param_name` is `module_name` itself or lives inside it (i.e. `module_name` is a dotted prefix).

    The trailing dots make sure that e.g. `compute.weight_submodule.parameter` is NOT considered to be inside
    `compute.weight`.
    """
    return (param_name + ".").startswith(module_name + ".")


def get_module_size_with_ties(
    tied_params,
    module_size,
    module_sizes,
    modules_to_treat,
) -> tuple[int, list[str], list[nn.Module]]:
    """
    Calculate the total size of a module, including its tied parameters.

    Args:
        tied_params (`List[str]`): The list of tied parameters.
        module_size (`int`): The size of the module without tied parameters.
        module_sizes (`Dict[str, int]`): A dictionary mapping each layer name to its size.
        modules_to_treat (`List[Tuple[str, nn.Module]]`): The list of named modules to treat.

    Returns:
        `Tuple[int, List[str], List[nn.Module]]`: The total size of the module, the names of the tied modules, and the
        tied modules.
    """
    if len(tied_params) < 1:
        return module_size, [], []
    tied_module_names = []
    tied_modules = []

    for tied_param in tied_params:
        # The first entry of `modules_to_treat` that is a dotted prefix of the tied param is the module holding it
        tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + ".")][0]
        tied_module_names.append(modules_to_treat[tied_module_index][0])
        tied_modules.append(modules_to_treat[tied_module_index][1])

    module_size_with_ties = module_size
    for tied_param, tied_module_name in zip(tied_params, tied_module_names):
        # The whole tied module has to come along, minus the tied param itself which is shared with the current module
        module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]

    return module_size_with_ties, tied_module_names, tied_modules


def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
    """Normalize the `device_map` argument of `from_pretrained`.

    - If no `device_map` is given (and DeepSpeed ZeRO-3 is off), the device of the active torch device context
      manager / global default device is used instead. A `meta` device there is rejected.
    - A `torch.device`, a device-name string, or a non-negative int is turned into `{"": device}`.
    - The strategy strings `"auto"`, `"balanced"`, `"balanced_low_0"` and `"sequential"` are returned unchanged.

    Raises:
        RuntimeError: if the context/global device is `meta`.
        ValueError: for an unparsable device string, a negative int, a `device_map` combined with DeepSpeed ZeRO-3,
            or a `device_map` when `accelerate` is not installed.
    """
    from ..modeling_utils import get_torch_context_manager_or_global_device

    # Potentially detect context manager or global device, and use it (only if no device_map was provided)
    if device_map is None and not is_deepspeed_zero3_enabled():
        device_in_context = get_torch_context_manager_or_global_device()
        if device_in_context == torch.device("meta"):
            raise RuntimeError(
                "You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
                "This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
                "empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
            )
        device_map = device_in_context

    # change device_map into a map if we passed an int, a str or a torch.device
    if isinstance(device_map, torch.device):
        device_map = {"": device_map}
    elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
        try:
            if device_map == "cuda":
                # setting to the local rank
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                device_map = f"cuda:{local_rank}"
            device_map = {"": torch.device(device_map)}
        except RuntimeError as err:
            raise ValueError(
                "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
            ) from err
    elif isinstance(device_map, int):
        if device_map < 0:
            raise ValueError(
                "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
            )
        else:
            device_map = {"": device_map}

    if device_map is not None:
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
        if not is_accelerate_available():
            raise ValueError(
                "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
                "requires `accelerate`. You can install it with `pip install accelerate`"
            )
    return device_map


def compute_module_sizes(
    model: "PreTrainedModel",
    hf_quantizer: "HfQuantizer | None" = None,
    buffers_only: bool = False,
    only_modules: bool = True,
) -> tuple[dict[str, int], dict[str, int]]:
    """
    Compute the size of each submodule of a given model (in bytes).

    Returns a tuple of 2 dicts. The first one maps every module name (including the root module `""`) to its size
    in bytes. The second one maps every leaf module (a module directly holding parameters/buffers, i.e. the end of
    the model graph) to its size.
    If `only_modules` is set to False, the first mapping will not only contain the size of all modules, but also
    the size of all parameters and buffers.
    Both dicts are `defaultdict(int)`, so missing names read as 0. Keys listed in the model's
    `all_tied_weights_keys` (if any) are not counted.
    """
    all_module_sizes = defaultdict(int)
    leaves_module_sizes = defaultdict(int)

    if buffers_only:
        iterator = model.named_buffers()
    else:
        # We need parameters + buffers here, as state_dict does not count non-persistent buffers which are taking space
        def all_tensors():
            yield from model.named_parameters()
            yield from model.named_buffers()

        iterator = all_tensors()

    tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
    for name, param in iterator:
        # Do not count tied keys (the model is usually not tied yet here, so they will appear in the iterator)
        # If the model is already tied, then they simply do not appear in the iterator anyway (remove_duplicates=True by default)
        if name in tied_keys:
            continue
        if hf_quantizer is not None:
            dtype_size = hf_quantizer.param_element_size(model, name, param)
        else:
            dtype_size = param.element_size()
        size = param.numel() * dtype_size
        name_parts = name.split(".")
        # Add the size to every ancestor module, from the root `""` down to the direct parent
        for idx in range(len(name_parts)):
            all_module_sizes[".".join(name_parts[:idx])] += size
        if "." in name:
            leaves_module_sizes[name.rsplit(".", 1)[0]] += size
        # If we want to also have the full leaves in `all_module_sizes`
        if not only_modules:
            all_module_sizes[name] += size

    return all_module_sizes, leaves_module_sizes


def compute_module_total_buffer_size(model: nn.Module, hf_quantizer: "HfQuantizer | None" = None):
    """
    Compute the total size (in bytes) of all the buffers of `model`, including those of its submodules.
    """
    module_sizes, _ = compute_module_sizes(model, hf_quantizer, buffers_only=True)
    return module_sizes.get("", 0)


def get_balanced_memory(
    model: "PreTrainedModel",
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: set[str] | None = None,
    hf_quantizer: "HfQuantizer | None" = None,
    low_zero: bool = False,
):
    """
    Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU.

    <Tip>

    All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
    meta device (as it would if initialized within the `init_empty_weights` context manager).

    </Tip>

    Args:
        model (`PreTrainedModel`):
            The model to analyze.
        max_memory (`Dict`, *optional*):
            A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
            Example: `max_memory={0: "1GB"}`.
        no_split_module_classes (`set[str]`, *optional*):
            A set of layer class names that should never be split across device (for instance any layer that has a
            residual connection).
        hf_quantizer (`HfQuantizer`, *optional*):
            A quantizer for the model.
        low_zero (`bool`, *optional*):
            Minimizes the number of weights on the first GPU (GPU 0 in the usual case), which is convenient when it's
            used for other operations (like the Transformers generate function).
    """
    # Get default / clean up max_memory
    user_not_set_max_memory = max_memory is None
    max_memory = get_max_memory(max_memory)
    # Check the number of accelerators available
    accelerator_max_memory = copy.deepcopy(max_memory)
    accelerator_max_memory.pop("cpu", None)
    accelerator_max_memory.pop("disk", None)
    num_devices = len([d for d in accelerator_max_memory if accelerator_max_memory[d] > 0])

    if num_devices == 0:
        return max_memory

    if num_devices == 1:
        # We cannot do low_zero on just one GPU, but we will still reserve some memory for the buffer
        low_zero = False
        # If user just asked us to handle memory usage, we should avoid OOM
        if user_not_set_max_memory:
            for key in max_memory.keys():
                if isinstance(key, int):
                    max_memory[key] *= 0.9  # 90% is a good compromise
                    logger.info(
                        f"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. "
                        "You can set `max_memory` in to a higher value to use more memory (at your own risk)."
                    )
                    break  # only one device

    module_sizes, leave_modules_sizes = compute_module_sizes(model, hf_quantizer)
    per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)

    # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get
    # slightly less layers and some layers will end up offload at the end. So this function computes a buffer size to
    # add which is the biggest of:
    # - the size of no split block (if applicable)
    # - the mean of the layer sizes
    if no_split_module_classes is None:
        no_split_module_classes = []
    elif not isinstance(no_split_module_classes, (list, tuple, set)):
        no_split_module_classes = [no_split_module_classes]

    # Identify the size of the no_split_block modules
    buffer = 0
    if len(no_split_module_classes) > 0:
        no_split_children = {}
        for name, size in module_sizes.items():
            if name == "":
                continue
            submodule = model.get_submodule(name)
            class_name = submodule.__class__.__name__
            if class_name in no_split_module_classes and class_name not in no_split_children:
                no_split_children[class_name] = size

            if set(no_split_children.keys()) == set(no_split_module_classes):
                break
        buffer = max(no_split_children.values()) if len(no_split_children) > 0 else 0

    mean_leaves = int(sum(leave_modules_sizes.values()) / max(len(leave_modules_sizes), 1))
    buffer = int(1.25 * max(buffer, mean_leaves))
    per_gpu += buffer

    # Sorted list of GPUs id (we may have some gpu ids not included in the our max_memory list - let's ignore them)
    gpus_idx_list = sorted(
        device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0
    )
    # The GPU that should hold as few weights as possible when `low_zero` is set. Use the real first id rather than
    # hard-coding `0`, so that non-contiguous / zero-memory GPU ids do not cause a KeyError or wrong sums below.
    first_gpu = gpus_idx_list[0] if len(gpus_idx_list) > 0 else None

    # The last device is left with max_memory just in case the buffer is not enough.
    for idx in gpus_idx_list[:-1]:
        if low_zero and idx == first_gpu:
            # Capped separately below, once the other GPUs have been sized
            continue
        max_memory[idx] = min(per_gpu, max_memory[idx])

    if low_zero and first_gpu is not None:
        min_zero = max(0, module_sizes[""] - sum(max_memory[i] for i in gpus_idx_list[1:]))
        max_memory[first_gpu] = min(min_zero, max_memory[first_gpu])

    return max_memory


def _get_device_map(
    model: "PreTrainedModel",
    device_map: dict | str | None,
    max_memory: dict | None,
    hf_quantizer: "HfQuantizer | None",
) -> dict:
    """Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential'].
    Any other `device_map` is returned unchanged (after quantizer validation of the inferred map, when one is computed).
    """
    if isinstance(device_map, str):
        no_split_modules = model._no_split_modules

        if device_map != "sequential":
            inferred_max_memory = get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=no_split_modules,
                hf_quantizer=hf_quantizer,
                low_zero=(device_map == "balanced_low_0"),
            )
        else:
            inferred_max_memory = get_max_memory(max_memory)

        # If the user does not provide `max_memory`, accelerate sets the WHOLE cpu available memory as available.
        # This is unwanted, as we don't want to set extremely tight bound and pressure for cpu if we are memory-constrained,
        # especially if the model uses WeightConverter (because there will be some uncontrollable cpu memory spikes during
        # the conversions before we resave the weights). In those cases, it's better to offload to disk a bit more
        # if we were in-between, as otherwise we blow-up cpu memory
        if max_memory is None and "cpu" in inferred_max_memory:
            inferred_max_memory["cpu"] *= 0.90

        if hf_quantizer is not None:
            inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)

        # `inferred_max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU,
        # which we can use to allocate parameters.
        for device_name in inferred_max_memory:
            if isinstance(device_name, int):  # it's a GPU device
                if is_torch_xpu_available():
                    unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name)
                else:
                    unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
                inferred_max_memory[device_name] += unused_memory
            # respect the `max_memory` passed by the user
            if max_memory is not None and device_name in max_memory:
                inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])

        device_map = infer_auto_device_map(
            model,
            max_memory=inferred_max_memory,
            no_split_module_classes=no_split_modules,
            hf_quantizer=hf_quantizer,
        )

        if hf_quantizer is not None:
            hf_quantizer.validate_environment(device_map=device_map)

    return device_map


def accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers):
    """Dispatch `model` on its devices with accelerate's `dispatch_model`, according to `device_map`.

    Optional `dispatch_model` arguments (`skip_keys`, `force_hooks`) are only passed when the installed accelerate
    version supports them. Nothing is dispatched when FSDP or DeepSpeed ZeRO-3 is enabled.
    """
    device_map_kwargs = {
        "device_map": device_map,
        "offload_dir": offload_folder,
        "offload_index": offload_index,
        "offload_buffers": offload_buffers,
    }
    dispatch_params = inspect.signature(dispatch_model).parameters
    if "skip_keys" in dispatch_params:
        device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
    # For HQQ method we force-set the hooks for single GPU envs
    if (
        "force_hooks" in dispatch_params
        and hf_quantizer is not None
        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
    ):
        device_map_kwargs["force_hooks"] = True
    if (
        hf_quantizer is not None
        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
        and isinstance(device_map, dict)
        and ("cpu" in device_map.values() or "disk" in device_map.values())
    ):
        device_map_kwargs["offload_buffers"] = True

    if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
        dispatch_model(model, **device_map_kwargs)


def expand_device_map(device_map: dict | None, param_names: list[str]):
    """
    Expand a device map to return the correspondence parameter name to device.

    Each parameter gets the device of the deepest key of `device_map` that is either the parameter name itself or one
    of its (dotted) parent module names. Parameters matched by no key fall back to the device of the `""` key if
    present, and to `"cpu"` otherwise. A `None` device map puts everything on `"cpu"`.
    """
    if device_map is None:
        return dict.fromkeys(param_names, "cpu")

    default_device = device_map.get("", "cpu")
    module_names = [k for k in device_map if k != ""]
    if len(module_names) == 0:
        return dict.fromkeys(param_names, default_device)

    # Here, we first sort by number of submodules, then length of the full string, to make sure to match correctly.
    # Keys are escaped (module names contain literal dots) and must be followed by a `.` or the end of the name, so
    # that e.g. `model.layers.1` does not match `model.layers.10.weight`
    device_map_regex = re.compile(
        "|".join(
            rf"{re.escape(k)}(?=\.|$)"
            for k in sorted(module_names, key=lambda x: (x.count("."), len(x)), reverse=True)
        )
    )
    new_device_map = {}
    for param in param_names:
        device_match = device_map_regex.match(param)
        new_device_map[param] = device_map[device_match.group()] if device_match else default_device

    return new_device_map


def get_device(device_map: dict | None, param_name: str, valid_torch_device: bool = False) -> torch.device | str | int:
    """Return the device on which `param_name` should be according to the `device_map`. If `valid_torch_device` is `True`,
    then if the device is `"disk"`, `"cpu"` will be returned instead."""
    device = expand_device_map(device_map, [param_name])[param_name]
    if valid_torch_device and device == "disk":
        return "cpu"
    return device


def accelerate_disk_offload(
    model: "PreTrainedModel",
    disk_offload_folder: str | None,
    checkpoint_files: list[str] | None,
    device_map: dict,
    sharded_metadata: dict | None,
    dtype: torch.dtype | None,
    weight_mapping=None,
):
    """
    Prepare the `disk_offload_index` that will be used for reading offloaded parameters.

    If the checkpoint is in safetensors format, every parameter mapped to `"disk"` is indexed directly in the
    checkpoint file(s) where it already resides. Source names are translated through the `WeightRenaming` entries of
    `weight_mapping`. Otherwise, an empty index is returned, and the parameters will be resaved inside
    `disk_offload_folder` during loading.
    """
    from ..core_model_loading import WeightRenaming, rename_source_key

    if disk_offload_folder is not None:
        os.makedirs(disk_offload_folder, exist_ok=True)
    is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")

    renamings = []
    if weight_mapping is not None:
        renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]

    # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
    # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
    if is_offloaded_safetensors:
        meta_state_dict = model.state_dict()
        param_device_map = expand_device_map(device_map, meta_state_dict.keys())
        str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
        if sharded_metadata is None:
            # Use a context manager so the file handle is released once the key names have been read
            with safe_open(checkpoint_files[0], framework="pt") as f:
                weight_map = dict.fromkeys(f.keys(), checkpoint_files[0])
        else:
            folder = os.path.dirname(checkpoint_files[0])
            weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}

        # Update the weight names according to the `weight_mapping`
        weight_renaming_map = {
            rename_source_key(k, renamings, [], model.base_model_prefix, meta_state_dict)[0]: k for k in weight_map
        }

        # Prepare the index using existing safetensors files
        disk_offload_index = {
            target_name: {
                "safetensors_file": weight_map[source_name],
                "weight_name": source_name,
                "dtype": str_dtype,
            }
            for target_name, source_name in weight_renaming_map.items()
            # Need to check if it's in the mapping in case of unexpected keys that would result in KeyError (we skip them)
            if target_name in param_device_map and param_device_map[target_name] == "disk"
        }
    # In this case we will resave every offloaded weight
    else:
        disk_offload_index = {}

    return disk_offload_index


def offload_weight(weight: torch.Tensor, weight_name: str, offload_folder: str | None, offload_index: dict) -> dict:
    """Write `weight` to disk inside `offload_folder`, and update `offload_index` accordingly. Everything is
    saved in `safetensors` format.

    Returns the (mutated in place) `offload_index`. Raises a `ValueError` if `offload_folder` is None."""

    if offload_folder is None:
        raise ValueError(
            "The current `device_map` had weights offloaded to the disk, which needed to be re-saved. This is either "
            "because the weights are not in `safetensors` format, or because the model uses an internal weight format "
            "different than the one saved (i.e. most MoE models). Please provide an `offload_folder` for them in "
            "`from_pretrained`."
        )
    # Write the weight to disk
    safetensor_file = os.path.join(offload_folder, f"{weight_name}.safetensors")
    save_file({weight_name: weight}, safetensor_file)
    # Update the offloading index
    str_dtype = str(weight.dtype).replace("torch.", "")
    offload_index[weight_name] = {"safetensors_file": safetensor_file, "weight_name": weight_name, "dtype": str_dtype}
    return offload_index


def load_offloaded_parameter(model: "PreTrainedModel", param_name: str) -> torch.Tensor:
    """Load `param_name` from disk, if it was offloaded due to the device_map, and thus lives as a meta parameter
    inside `model`.
    This is needed when resaving a model, when some parameters were offloaded (we need to load them from disk, to
    then resave them to disk in the correct shard...)."""
    # Start from the most inner module, and try to find the hook that was used for offloading the param
    module_parts = param_name.split(".")
    modules_to_check = [".".join(module_parts[:-idx]) for idx in range(1, len(module_parts))] + [""]
    for parent_name in modules_to_check:
        parent = model.get_submodule(parent_name)
        if hasattr(parent, "_hf_hook"):
            weights_map = parent._hf_hook.weights_map
            # `parent_name` is always a dotted prefix of `param_name` here: strip only that leading prefix (a global
            # `str.replace` would also remove any later occurrence, e.g. in `mlp.experts.mlp.weight`)
            truncated_param_name = param_name[len(parent_name) + 1 :] if parent_name != "" else param_name
            break
    # If we did not break the loop, something is wrong
    else:
        raise ValueError(
            f"{param_name} is on the meta device because it was offloaded, but we could not find "
            "the corresponding hook for it"
        )

    # This call loads it from disk
    tensor = weights_map[truncated_param_name]
    return tensor


def _init_infer_auto_device_map(
    model: nn.Module,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: set[str] | None = None,
    tied_parameters: list[list[str]] | None = None,
    hf_quantizer: "HfQuantizer | None" = None,
) -> tuple[
    list[int | str],
    dict[int | str, int | str],
    list[int | str],
    list[int],
    dict[str, int],
    list[list[str]],
    list[str],
    list[tuple[str, nn.Module]],
]:
    """
    Initialize variables required for computing the device map for model allocation.

    Returns, in order: the ordered list of devices (always ending with `"disk"`), the cleaned `max_memory`, the
    "main" devices that must keep room for an offloaded layer, the accelerator devices, the per-module/per-tensor
    sizes, the groups of tied parameter names, the normalized `no_split_module_classes`, and the direct
    parameters/children/buffers of `model` to start the allocation from.
    """
    max_memory = get_max_memory(max_memory)
    if no_split_module_classes is None:
        no_split_module_classes = []
    elif not isinstance(no_split_module_classes, (list, tuple, set)):
        no_split_module_classes = [no_split_module_classes]

    devices = list(max_memory.keys())
    if "disk" not in devices:
        devices.append("disk")
    gpus = [device for device in devices if device not in ["cpu", "disk"]]

    # Devices that need to keep space for a potential offloaded layer.
    if "mps" in gpus:
        main_devices = ["mps"]
    elif len(gpus) > 0:
        main_devices = [gpus[0], "cpu"]
    else:
        main_devices = ["cpu"]

    module_sizes, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)

    if tied_parameters is None:
        if len(model.all_tied_weights_keys) > 0:
            # create a list of list of tied params based on unique tied groups, in one pass. Groups are sorted by
            # their source key so that the allocation order does not depend on set/hash ordering between runs
            groups = defaultdict(list)
            for tied_key, source_key in model.all_tied_weights_keys.items():
                groups[source_key].append(tied_key)
            tied_parameters = [sorted(tied_keys + [source_key]) for source_key, tied_keys in sorted(groups.items())]
        else:
            tied_parameters = [[]]

    # Direct submodules and parameters
    modules_to_treat = (
        list(model.named_parameters(recurse=False))
        + list(model.named_children())
        + list(model.named_buffers(recurse=False))
    )

    return (
        devices,
        max_memory,
        main_devices,
        gpus,
        module_sizes,
        tied_parameters,
        no_split_module_classes,
        modules_to_treat,
    )


def infer_auto_device_map(
    model: nn.Module,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: set[str] | None = None,
    verbose: bool = False,
    clean_result: bool = True,
    offload_buffers: bool = False,
    tied_parameters: list[list[str]] | None = None,
    hf_quantizer: "HfQuantizer | None" = None,
):
    """
    Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
    such that:
    - we don't exceed the memory available of any of the GPU.
    - if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that
      has the largest size.
    - if offload to the CPU is needed, we don't exceed the RAM available on the CPU.
    - if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk
      that has the largest size.

    <Tip>

    All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
    meta device (as it would if initialized within the `init_empty_weights` context manager).

    </Tip>

    Args:
        model (`torch.nn.Module`):
            The model to analyze.
        max_memory (`Dict`, *optional*):
            A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
            Example: `max_memory={0: "1GB"}`.
        no_split_module_classes (`set[str]`, *optional*):
            A set of layer class names that should never be split across device (for instance any layer that has a
            residual connection).
        verbose (`bool`, *optional*, defaults to `False`):
            Whether or not to provide debugging statements as the function builds the device_map.
        clean_result (`bool`, *optional*, defaults to `True`):
            Clean the resulting device_map by grouping all submodules that go on the same device together.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
            well as the parameters.
        tied_parameters (`List[List[str]]`, *optional*):
            Groups of parameter names that are tied together and should end up on the same device. When unset, they
            are derived from `model.all_tied_weights_keys`.
        hf_quantizer (`HfQuantizer`, *optional*):
            A quantizer for the model, used to compute the size of each (possibly quantized) parameter.
    """
    # Initialize the variables
    (
        devices,
        max_memory,
        main_devices,
        gpus,
        module_sizes,
        tied_parameters,
        no_split_module_classes,
        modules_to_treat,
    ) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, tied_parameters, hf_quantizer)

    device_map = OrderedDict()
    current_device = 0
    device_memory_used = dict.fromkeys(devices, 0)
    device_buffer_sizes = {}
    device_minimum_assignment_memory = {}

    # Initialize maximum largest layer, to know which space to keep in memory
    max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)

    # Ready ? This is going to be a bit messy.
    while len(modules_to_treat) > 0:
        name, module = modules_to_treat.pop(0)
        if verbose:
            print(f"\nTreating module {name}.")
        # Max size in the remaining layers may have changed since we took one, so we maybe update it.
        max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
        if len(max_layer_names) == 0:
            max_layer_size, max_layer_names = get_max_layer_size(
                [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
                module_sizes,
                no_split_module_classes,
            )
        # Assess size needed
        module_size = module_sizes[name]

        # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module
        # and the other is not.
        # Note: If we are currently processing the name `compute.weight`, an other parameter named
        # e.g. `compute.weight_submodule.parameter`
        # needs to be considered outside the current module, hence the dotted-prefix check in `_is_within_module`.
        tied_param_groups = [
            tied_group
            for tied_group in tied_parameters
            if any(_is_within_module(k, name) for k in tied_group)
            and not all(_is_within_module(k, name) for k in tied_group)
        ]

        if verbose and len(tied_param_groups) > 0:
            print(f" Found the relevant tied param groups {tied_param_groups}")

        # Then we keep track of all the parameters that are tied to the current module, but not in the current module
        tied_params = [
            p for tied_group in tied_param_groups for p in tied_group if not _is_within_module(p, name)
        ]

        if verbose and len(tied_params) > 0:
            print(f" So those parameters need to be taken into account {tied_params}")

        device = devices[current_device]
        current_max_size = max_memory[device] if device != "disk" else None
        current_memory_reserved = 0
        # Reduce max size available by the largest layer.
        if devices[current_device] in main_devices:
            current_max_size = current_max_size - max_layer_size
            current_memory_reserved = max_layer_size

        module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(
            tied_params, module_size, module_sizes, modules_to_treat
        )

        # The module and its tied modules fit on the current device.
        if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
            if verbose:
                output = f"Putting {name}"

                if tied_module_names:
                    output += f" and {tied_module_names}"
                else:
                    output += f" (size={module_size})"

                if current_max_size is not None:
                    output += f" (available={current_max_size - device_memory_used[device]})"

                output += f" on {device}."
                print(output)

            device_memory_used[device] += module_size_with_ties

            # Assign the primary module to the device.
            device_map[name] = device

            # Assign tied modules if any.
            for tied_module_name in tied_module_names:
                if tied_module_name in [m[0] for m in modules_to_treat]:
                    # Find the index of the tied module in the list
                    tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name)
                    # Remove the tied module from the list to prevent reprocessing
                    modules_to_treat.pop(tied_module_index)

                # Assign the tied module to the device
                device_map[tied_module_name] = device

            # Buffer Handling
            if not offload_buffers and isinstance(module, nn.Module):
                # Compute the total buffer size for the module
                current_buffer_size = compute_module_total_buffer_size(module, hf_quantizer)
                # Update the buffer size on the device
                device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

            continue

        # The current module itself fits, so we try to split the tied modules.
        if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size:
            # can we split one of the tied modules to make it smaller or do we need to go on the next device?
            if verbose:
                print(
                    f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
                    f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
                )
            split_happened = False
            for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
                tied_module_children = list(tied_module.named_children())
                if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
                    # can't break this one.
                    continue

                if verbose:
                    print(f"Splitting {tied_module_name}.")
                tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
                tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
                tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]

                # Put the current module back in front, and replace the tied module by its children
                modules_to_treat = (
                    [(name, module)]
                    + modules_to_treat[:tied_module_index]
                    + tied_module_children
                    + modules_to_treat[tied_module_index + 1 :]
                )
                # Update the max layer size.
                max_layer_size, max_layer_names = get_max_layer_size(
                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
                    module_sizes,
                    no_split_module_classes,
                )
                split_happened = True
                break

            if split_happened:
                continue

            # If the tied module is not split, we go to the next device
            if verbose:
                print("None of the tied module can be split, going to the next device.")

        # The current module itself doesn't fit, so we have to split it or go to the next device.
        if device_memory_used[device] + module_size >= current_max_size:
            # Split or not split?
            modules_children = (
                []
                if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
                else list(module.named_children())
            )
            if verbose:
                print(
                    f"Not enough space on {devices[current_device]} to put {name} (space available "
                    f"{current_max_size - device_memory_used[device]}, module size {module_size})."
                )
            if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
                # -> no split, we go to the next device
                if verbose:
                    print("This module cannot be split, going to the next device.")

            else:
                # -> split, we replace the module studied by its children + parameters
                if verbose:
                    print(f"Splitting {name}.")
                modules_children = list(module.named_parameters(recurse=False)) + modules_children
                modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
                # Update the max layer size.
                max_layer_size, max_layer_names = get_max_layer_size(
                    [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
                    module_sizes,
                    no_split_module_classes,
                )
                continue

        if device_memory_used[device] == 0:
            device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved

        # Neither the current module nor any tied modules can be split, so we move to the next device.
        device_memory_used[device] = device_memory_used[device] + current_memory_reserved
        current_device += 1
        modules_to_treat = [(name, module)] + modules_to_treat

    device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}

    if clean_result:
        device_map = clean_device_map(device_map)

    non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)
    if non_gpu_buffer_size > 0 and not offload_buffers:
        is_buffer_fit_any_gpu = False
        for gpu_device, gpu_max_memory in max_memory.items():
            if gpu_device == "cpu" or gpu_device == "disk":
                continue

            if not is_buffer_fit_any_gpu:
                gpu_memory_used = device_memory_used.get(gpu_device, 0)

                if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
                    is_buffer_fit_any_gpu = True

        if len(gpus) > 0 and not is_buffer_fit_any_gpu:
            logger.warning(
                f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does "
                f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using "
                f"offload_buffers=True."
            )

    if device_minimum_assignment_memory:
        devices_info = "\n".join(
            f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items()
        )
        logger.info(
            f"Based on the current allocation process, no modules could be assigned to the following devices due to "
            f"insufficient memory:\n"
            f"{devices_info}\n"
            f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing "
            f"the available memory for these devices to at least the specified minimum, or adjusting the model config."
        )

    check_tied_parameters_on_same_device(tied_parameters, device_map)
    return device_map


def _get_param_device(param, device_map):
    """Return the device of `param`, i.e. the device of its closest ancestor (or itself) present in `device_map`.

    Raises a `ValueError` if neither `param` nor any of its parents (down to the root `""`) is in `device_map`."""
    if param in device_map:
        return device_map[param]
    parent_param = ".".join(param.split(".")[:-1])
    if parent_param == param:
        raise ValueError(f"The `device_map` does not contain the module {param}.")
    else:
        return _get_param_device(parent_param, device_map)


def check_tied_parameters_on_same_device(tied_params, device_map):
    """
    Check if tied parameters are on the same device, and log a warning for every group that is not.

    Args:
        tied_params (`List[List[str]]`):
            A list of lists of parameter names being all tied together.

        device_map (`Dict[str, Union[int, str, torch.device]]`):
            A map that specifies where each submodule should go.

    """
    for tie_param in tied_params:
        tie_param_devices = {}
        for param in tie_param:
            tie_param_devices[param] = _get_param_device(param, device_map)
        if len(set(tie_param_devices.values())) > 1:
            logger.warning(
                f"Tied parameters are on different devices: {tie_param_devices}. "
                "Please modify your custom device map or set `device_map='auto'`. "
            )