# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
from collections import defaultdict
from contextlib import contextmanager

import torch
# Snapshot the torch init primitives up front, so they remain callable even after the
# context managers below monkey-patch `torch.nn.init` and friends.
TORCH_INIT_FUNCTIONS = {
    name: getattr(torch.nn.init, name)
    for name in (
        "uniform_",
        "normal_",
        "constant_",
        "ones_",
        "zeros_",
        "eye_",
        "dirac_",
        "xavier_uniform_",
        "xavier_normal_",
        "kaiming_uniform_",
        "kaiming_normal_",
        "trunc_normal_",
        "orthogonal_",
        "sparse_",
    )
}
def uniform_(
|
|
tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
|
|
) -> torch.Tensor:
|
|
if not getattr(tensor, "_is_hf_initialized", False):
|
|
return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
|
|
return tensor
|
|
|
|
|
|
def normal_(
|
|
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
|
|
) -> torch.Tensor:
|
|
if not getattr(tensor, "_is_hf_initialized", False):
|
|
return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
|
|
return tensor
|
|
|
|
|
|
def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
    """Guarded drop-in for `torch.nn.init.constant_`: no-op if the tensor was already initialized."""
    if getattr(tensor, "_is_hf_initialized", False):
        # Param was already loaded from a checkpoint -- leave it untouched.
        return tensor
    return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val)
def ones_(tensor: torch.Tensor) -> torch.Tensor:
    """Guarded drop-in for `torch.nn.init.ones_`: no-op if the tensor was already initialized."""
    if getattr(tensor, "_is_hf_initialized", False):
        # Param was already loaded from a checkpoint -- leave it untouched.
        return tensor
    return TORCH_INIT_FUNCTIONS["ones_"](tensor)
def zeros_(tensor: torch.Tensor) -> torch.Tensor:
    """Guarded drop-in for `torch.nn.init.zeros_`: no-op if the tensor was already initialized."""
    if getattr(tensor, "_is_hf_initialized", False):
        # Param was already loaded from a checkpoint -- leave it untouched.
        return tensor
    return TORCH_INIT_FUNCTIONS["zeros_"](tensor)
def eye_(tensor: torch.Tensor) -> torch.Tensor:
    """Guarded drop-in for `torch.nn.init.eye_`: no-op if the tensor was already initialized."""
    if getattr(tensor, "_is_hf_initialized", False):
        # Param was already loaded from a checkpoint -- leave it untouched.
        return tensor
    return TORCH_INIT_FUNCTIONS["eye_"](tensor)
def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
    """Guarded drop-in for `torch.nn.init.dirac_`: no-op if the tensor was already initialized."""
    if getattr(tensor, "_is_hf_initialized", False):
        # Param was already loaded from a checkpoint -- leave it untouched.
        return tensor
    return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups)
def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
|
|
if not getattr(tensor, "_is_hf_initialized", False):
|
|
return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
|
|
return tensor
|
|
|
|
|
|
def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
|
|
if not getattr(tensor, "_is_hf_initialized", False):
|
|
return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
|
|
return tensor
|
|
|
|
|
|
def kaiming_uniform_(
|
|
tensor: torch.Tensor,
|
|
a: float = 0,
|
|
mode: str = "fan_in",
|
|
nonlinearity: str = "leaky_relu",
|
|
generator: torch.Generator | None = None,
|
|
) -> torch.Tensor:
|
|
if not getattr(tensor, "_is_hf_initialized", False):
|
|
return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
|
|
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
|
|
)
|
|
return tensor
|
|
|
|
|
|
def kaiming_normal_(
|
|
tensor: torch.Tensor,
|
|
a: float = 0,
|
|
mode: str = "fan_in",
|
|
nonlinearity: str = "leaky_relu",
|
|
generator: torch.Generator | None = None,
|
|
) -> torch.Tensor:
|
|
if not getattr(tensor, "_is_hf_initialized", False):
|
|
return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
|
|
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
|
|
)
|
|
return tensor
|
|
|
|
|
|
def trunc_normal_(
|
|
tensor: torch.Tensor,
|
|
mean: float = 0.0,
|
|
std: float = 1.0,
|
|
a: float = -2.0,
|
|
b: float = 2.0,
|
|
generator: torch.Generator | None = None,
|
|
) -> torch.Tensor:
|
|
if not getattr(tensor, "_is_hf_initialized", False):
|
|
return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
|
|
return tensor
|
|
|
|
|
|
def orthogonal_(
|
|
tensor: torch.Tensor,
|
|
gain: float = 1,
|
|
generator: torch.Generator | None = None,
|
|
) -> torch.Tensor:
|
|
if not getattr(tensor, "_is_hf_initialized", False):
|
|
return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
|
|
return tensor
|
|
|
|
|
|
def sparse_(
|
|
tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
|
|
) -> torch.Tensor:
|
|
if not getattr(tensor, "_is_hf_initialized", False):
|
|
return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
|
|
return tensor
|
|
|
|
|
|
def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """Guarded in-place copy of `other` into `tensor`: no-op if `tensor` was already initialized."""
    if getattr(tensor, "_is_hf_initialized", False):
        # Param was already loaded from a checkpoint -- leave it untouched.
        return tensor
    # Wrap in no_grad so the copy is not recorded by autograd.
    with torch.no_grad():
        return tensor.copy_(other)
def _variance_scaling(tensor, mode="fan_in", distribution="normal"):
|
|
fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(tensor)
|
|
if mode == "fan_in":
|
|
denom = fan_in
|
|
elif mode == "fan_out":
|
|
denom = fan_out
|
|
elif mode == "fan_avg":
|
|
denom = (fan_in + fan_out) / 2
|
|
|
|
variance = 1.0 / denom
|
|
|
|
if distribution == "truncated_normal":
|
|
trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
|
elif distribution == "normal":
|
|
normal_(tensor, std=math.sqrt(variance))
|
|
elif distribution == "uniform":
|
|
bound = math.sqrt(3 * variance)
|
|
uniform_(tensor, -bound, bound)
|
|
else:
|
|
raise ValueError(f"invalid distribution {distribution}")
|
|
|
|
|
|
def lecun_normal_(tensor):
    """LeCun normal init (fan-in truncated normal): no-op if the tensor was already initialized."""
    if getattr(tensor, "_is_hf_initialized", False):
        # Param was already loaded from a checkpoint -- leave it untouched.
        return tensor
    _variance_scaling(tensor, mode="fan_in", distribution="truncated_normal")
    return tensor
def default_flax_embed_init_(tensor):
    """Flax default embedding init (fan-in normal): no-op if the tensor was already initialized."""
    if getattr(tensor, "_is_hf_initialized", False):
        # Param was already loaded from a checkpoint -- leave it untouched.
        return tensor
    _variance_scaling(tensor, mode="fan_in", distribution="normal")
    return tensor
# Here, we need to check several modules imported, and hot patch all of them, as sometimes torch does
# something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in
# torch.nn.modules.activations, where MultiHeadAttention lives), so the function name is binded at
# import time and just doing `setattr(torch.nn.init, name, globals()[name])` is thus not enough.
# The following list should be enough for all torch versions we work with
TORCH_MODULES_TO_PATCH = tuple(
    f"torch.nn.{suffix}"
    for suffix in (
        "init",
        "modules.activation",
        "modules.transformer",
        "modules.linear",
        "modules.loss",
        "modules.batchnorm",
        "modules.conv",
        "modules.normalization",
        "modules.rnn",
        "modules.sparse",
    )
)
@contextmanager
def guard_torch_init_functions():
    """
    Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be
    protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded.

    Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure
    and for remote code, we also use this context manager.
    """
    # module -> {function name -> original function}, so we can restore on exit.
    saved = defaultdict(dict)
    try:
        # Swap every torch init primitive for the guarded version defined in this file.
        for module_name in TORCH_MODULES_TO_PATCH:
            module = sys.modules.get(module_name)
            if module is None:
                # Module not imported in this torch version/process -- nothing to patch.
                continue
            for func_name in TORCH_INIT_FUNCTIONS:
                if hasattr(module, func_name):
                    saved[module][func_name] = getattr(module, func_name)
                    setattr(module, func_name, globals()[func_name])
        yield
    finally:
        # Always restore the originals, even if the body raised.
        for module, functions in saved.items():
            for func_name, original in functions.items():
                setattr(module, func_name, original)
@contextmanager
def no_init_weights():
    """
    Disable weight initialization both at the torch-level, and at the transformers-level (`init_weights`).
    This is used to speed-up initializing an empty model with deepspeed, as we do not initialize the model on meta device
    with deepspeed, but we still don't need to run expensive weight initializations as we are loading params afterwards.
    """
    from .modeling_utils import PreTrainedModel

    def empty_func(*args, **kwargs):
        pass

    # Capture the original before entering `try`: if patching below raises midway, the `finally`
    # clause would otherwise hit a NameError on `original_init_weights` and mask the real error.
    original_init_weights = PreTrainedModel.init_weights
    # module -> {function name -> original function}, so we can restore on exit.
    originals = defaultdict(dict)
    try:
        # Replace all torch init funcs by no-ops.
        for module_name in TORCH_MODULES_TO_PATCH:
            if module_name in sys.modules:
                module = sys.modules[module_name]
                for func_name in TORCH_INIT_FUNCTIONS.keys():
                    if hasattr(module, func_name):
                        originals[module][func_name] = getattr(module, func_name)
                        setattr(module, func_name, empty_func)

        # Also patch our own `init_weights`.
        PreTrainedModel.init_weights = empty_func

        yield
    finally:
        # Set back the original torch functions on all modules.
        for module, functions in originals.items():
            for func_name, func in functions.items():
                setattr(module, func_name, func)
        # Set back `init_weights`.
        PreTrainedModel.init_weights = original_init_weights
@contextmanager
def no_tie_weights():
    """
    Disable weight tying during loading with `from_pretrained`. This is needed as we want to have access to ALL
    weights in the state_dict during `from_pretrained`, and otherwise tying them would remove them from it, as it's
    called in `post_init` when instantiating.
    """
    from .modeling_utils import PreTrainedModel

    def empty_func(*args, **kwargs):
        pass

    # Capture the original first, so the `finally` clause can always restore it.
    original_tie_weights = PreTrainedModel.tie_weights
    try:
        PreTrainedModel.tie_weights = empty_func
        yield
    finally:
        # Set back the original.
        PreTrainedModel.tie_weights = original_tie_weights