# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import sys from collections import defaultdict from contextlib import contextmanager import torch # Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch # in context managers TORCH_INIT_FUNCTIONS = { "uniform_": torch.nn.init.uniform_, "normal_": torch.nn.init.normal_, "constant_": torch.nn.init.constant_, "ones_": torch.nn.init.ones_, "zeros_": torch.nn.init.zeros_, "eye_": torch.nn.init.eye_, "dirac_": torch.nn.init.dirac_, "xavier_uniform_": torch.nn.init.xavier_uniform_, "xavier_normal_": torch.nn.init.xavier_normal_, "kaiming_uniform_": torch.nn.init.kaiming_uniform_, "kaiming_normal_": torch.nn.init.kaiming_normal_, "trunc_normal_": torch.nn.init.trunc_normal_, "orthogonal_": torch.nn.init.orthogonal_, "sparse_": torch.nn.init.sparse_, } def uniform_( tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None ) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator) return tensor def normal_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None ) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator) return tensor def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val) return tensor def ones_(tensor: torch.Tensor) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["ones_"](tensor) return tensor def zeros_(tensor: torch.Tensor) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["zeros_"](tensor) return tensor def eye_(tensor: torch.Tensor) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["eye_"](tensor) return tensor def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups) return tensor def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator) return tensor def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator) return tensor def kaiming_uniform_( tensor: torch.Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu", generator: torch.Generator | None = None, ) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["kaiming_uniform_"]( tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator ) return tensor def kaiming_normal_( tensor: torch.Tensor, a: float = 0, mode: str = "fan_in", nonlinearity: str = "leaky_relu", generator: torch.Generator | None = None, ) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["kaiming_normal_"]( tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator ) return tensor def trunc_normal_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0, generator: torch.Generator | None = None, ) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator) return tensor def orthogonal_( tensor: torch.Tensor, gain: float = 1, generator: torch.Generator | None = None, ) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator) return tensor def sparse_( tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None ) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator) return tensor def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor: if not getattr(tensor, "_is_hf_initialized", False): with torch.no_grad(): return tensor.copy_(other) return tensor def _variance_scaling(tensor, mode="fan_in", distribution="normal"): fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in elif mode == "fan_out": denom = fan_out elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 variance = 1.0 / denom if distribution == "truncated_normal": trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": normal_(tensor, std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) uniform_(tensor, -bound, bound) else: raise ValueError(f"invalid distribution {distribution}") def lecun_normal_(tensor): if not getattr(tensor, "_is_hf_initialized", False): _variance_scaling(tensor, mode="fan_in", distribution="truncated_normal") return tensor def default_flax_embed_init_(tensor): if not getattr(tensor, "_is_hf_initialized", False): _variance_scaling(tensor, mode="fan_in", distribution="normal") return tensor # Here, we need to check several modules imported, and hot patch all of them, as sometimes torch does # something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules.activations, # where MultiHeadAttention lives), so the function name is binded at import time and just doing # `setattr(torch.nn.init, name, globals()[name])` is thus not enough # The following list should be enough for all torch versions we work with TORCH_MODULES_TO_PATCH = ( "torch.nn.init", "torch.nn.modules.activation", "torch.nn.modules.transformer", "torch.nn.modules.linear", "torch.nn.modules.loss", "torch.nn.modules.batchnorm", "torch.nn.modules.conv", "torch.nn.modules.normalization", "torch.nn.modules.rnn", "torch.nn.modules.sparse", ) @contextmanager def guard_torch_init_functions(): """ Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded. Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure and for remote code, we also use this context manager. """ originals = defaultdict(dict) try: # Replace all torch funcs by the ones in this file for module_name in TORCH_MODULES_TO_PATCH: if module_name in sys.modules: module = sys.modules[module_name] for func_name in TORCH_INIT_FUNCTIONS.keys(): if hasattr(module, func_name): originals[module][func_name] = getattr(module, func_name) setattr(module, func_name, globals()[func_name]) yield finally: # Set back the original functions on all modules for module, functions in originals.items(): for func_name, func in functions.items(): setattr(module, func_name, func) @contextmanager def no_init_weights(): """ Disable weight initialization both at the torch-level, and at the transformers-level (`init_weights`). This is used to speed-up initializing an empty model with deepspeed, as we do not initialize the model on meta device with deepspeed, but we still don't need to run expensive weight initializations as we are loading params afterwards. """ from .modeling_utils import PreTrainedModel def empty_func(*args, **kwargs): pass originals = defaultdict(dict) try: # Replace all torch funcs by empty ones for module_name in TORCH_MODULES_TO_PATCH: if module_name in sys.modules: module = sys.modules[module_name] for func_name in TORCH_INIT_FUNCTIONS.keys(): if hasattr(module, func_name): originals[module][func_name] = getattr(module, func_name) setattr(module, func_name, empty_func) # Also patch our own `init_weights` original_init_weights = PreTrainedModel.init_weights PreTrainedModel.init_weights = empty_func yield finally: # Set back the original torch functions on all modules for module, functions in originals.items(): for func_name, func in functions.items(): setattr(module, func_name, func) # Set back `init_weights` PreTrainedModel.init_weights = original_init_weights @contextmanager def no_tie_weights(): """ Disable weight tying during loading with `from_pretrained`. This is needed as we want to have access to ALL weights in the state_dict during `from_pretrained`, and otherwise tying them would remove them from it, as it's called in `post_init` when instantiating. """ from .modeling_utils import PreTrainedModel def empty_func(*args, **kwargs): pass try: original_tie_weights = PreTrainedModel.tie_weights PreTrainedModel.tie_weights = empty_func yield finally: # Set back the original PreTrainedModel.tie_weights = original_tie_weights