You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1091 lines
40 KiB

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic utilities
"""
import inspect
import json
import os
import warnings
from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Callable, Iterable, MutableMapping
from contextlib import AbstractContextManager, ExitStack, nullcontext
from dataclasses import dataclass, fields, is_dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, Optional, TypedDict
import numpy as np
from ..utils import logging
from .import_utils import is_mlx_available, is_torch_available, is_torch_fx_proxy, requires
# Registry looked up by `check_model_inputs` with `str(self.__class__)` as key to find recordable outputs.
_CAN_RECORD_REGISTRY = {}
logger = logging.get_logger(__name__)
# Backend availability is resolved once at import time; the `is_*` helpers below only read these flags.
_is_torch_available = False
if is_torch_available():
    # required for @can_return_tuple decorator to work with torchdynamo
    import torch
    from torch.types import _dtype
    from ..model_debugging_utils import model_addition_debugger_context
    _is_torch_available = True
_is_mlx_available = False
if is_mlx_available():
    _is_mlx_available = True
# vendored from distutils.util
def strtobool(val):
    """Map a textual truth value onto the integers 1 (true) or 0 (false).

    Matching is case-insensitive. Accepted true spellings are 'y', 'yes', 't', 'true', 'on' and '1'; accepted false
    spellings are 'n', 'no', 'f', 'false', 'off' and '0'.

    Raises:
        ValueError: if `val` is none of the spellings above (the message shows the lowercased value).
    """
    lowered = val.lower()
    truth_table = {
        **dict.fromkeys(("y", "yes", "t", "true", "on", "1"), 1),
        **dict.fromkeys(("n", "no", "f", "false", "off", "0"), 0),
    }
    outcome = truth_table.get(lowered)
    if outcome is None:
        raise ValueError(f"invalid truth value {lowered!r}")
    return outcome
def infer_framework_from_repr(x):
    """
    Guess which framework `x` belongs to by looking only at the string form of its type.

    This is deliberately cheap and brittle: it lets `is_tensor` probe the most likely framework first without
    importing any framework. Returns `"pt"`, `"np"` or `"mlx"`, or `None` when the type string matches none of them.
    """
    type_repr = str(type(x))
    known_prefixes = (
        ("<class 'torch.", "pt"),
        ("<class 'numpy.", "np"),
        ("<class 'mlx.", "mlx"),
    )
    for prefix, framework in known_prefixes:
        if type_repr.startswith(prefix):
            return framework
    return None
def _get_frameworks_and_test_func(x):
    """
    Build the ordered mapping `framework name -> test function` used to probe `x`.

    The framework guessed from the type repr of `x` (if any) comes first, Numpy comes next, and the remaining
    frameworks follow in declaration order. Dict insertion order carries the priority.
    """
    checkers = {
        "pt": is_torch_tensor,
        "np": is_numpy_array,
        "mlx": is_mlx_array,
    }
    guessed = infer_framework_from_repr(x)
    probe_order = []
    # Guess first, then numpy, then everything else; duplicates are skipped so each framework appears once.
    for candidate in (guessed, "np", *checkers):
        if candidate is not None and candidate not in probe_order:
            probe_order.append(candidate)
    return {name: checkers[name] for name in probe_order}
def is_tensor(x):
    """
    Return `True` when `x` is a `torch.Tensor`, `np.ndarray` or `mlx.array` (probed in the order given by
    `_get_frameworks_and_test_func`), or a torch.fx proxy; `False` otherwise.
    """
    # Probe the frameworks in the "smart" order, stopping at the first match.
    checks = _get_frameworks_and_test_func(x).values()
    if any(check(x) for check in checks):
        return True
    # Tracers
    return bool(is_torch_fx_proxy(x))
def is_numpy_array(x):
    """
    Return whether `x` is an instance of `np.ndarray`.
    """
    ndarray_type = np.ndarray
    return isinstance(x, ndarray_type)
def is_torch_tensor(x):
    """
    Return whether `x` is a `torch.Tensor`. Always `False` when torch is not installed, so it is safe to call anywhere.
    """
    if not _is_torch_available:
        return False
    return isinstance(x, torch.Tensor)
def is_torch_device(x):
    """
    Return whether `x` is a `torch.device`. Always `False` when torch is not installed, so it is safe to call anywhere.
    """
    if not _is_torch_available:
        return False
    return isinstance(x, torch.device)
def is_torch_dtype(x):
    """
    Return whether `x` is a `torch.dtype`, or a string naming one (e.g. `"float16"`).

    Always `False` when torch is not installed, so it is safe to call anywhere.
    """
    if not _is_torch_available:
        return False
    # A string is resolved as an attribute of the `torch` module; unknown names resolve to `None`.
    candidate = getattr(torch, x, None) if isinstance(x, str) else x
    return isinstance(candidate, torch.dtype)
def _is_tensor_or_array_like(value):
    """
    Tell whether `value` looks like array data: a numpy array, a torch tensor, a numeric scalar, or a (possibly
    nested or ragged) list/tuple whose first leaf is one of those. Empty lists/tuples count as array-like.
    """
    # Descend along the first element of nested sequences; only that leading path is inspected.
    while isinstance(value, (list, tuple)):
        if len(value) == 0:
            # consider empty list or nested list as array-like
            return True
        value = value[0]
    return (
        is_numpy_array(value)
        or is_torch_tensor(value)
        or isinstance(value, (int, float, bool, np.number))
    )
def maybe_autocast(
    device_type: str,
    dtype: Optional["_dtype"] = None,
    enabled: bool = True,
    cache_enabled: bool | None = None,
):
    """
    Return a `torch.autocast` context only when it can have an effect, i.e. when:

    - `autocast` is already enabled for `device_type` in the current context, or
    - this call passes `enabled=True`.

    Otherwise a `nullcontext()` is returned, which keeps a no-op `autocast` out of the graph. That makes graph
    splitting in `torch.compile` more flexible as it removes the requirement that partition IDs be monotonically
    increasing.
    """
    already_autocasting = torch.is_autocast_enabled(device_type)
    if not (already_autocasting or enabled):
        return nullcontext()
    return torch.autocast(device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
def _is_mlx(x):
    """Return whether `x` is an `mlx.core.array`; imports mlx lazily, so mlx must be installed."""
    import mlx.core as mx

    mlx_array_type = mx.array
    return isinstance(x, mlx_array_type)
def is_mlx_array(x):
    """
    Return whether `x` is an mlx array. Always `False` when mlx is not installed, so it is safe to call anywhere.
    """
    if not _is_mlx_available:
        return False
    return _is_mlx(x)
def is_flash_attention_requested(config=None, requested_attention_implementation: str | None = None):
"""
Checks whether some flavor of flash attention is requested or not.
This is checked against one of the two arguments, i.e. either the `config` or the directly passed value
`requested_attention_implementation`. Otherwise, an error will be raised (ambiguity).
The different versions of flash attention are usually
- Implementations based on the original flash attention repo: https://github.com/Dao-AILab/flash-attention
- Kernels implementations such as: https://huggingface.co/kernels-community/vllm-flash-attn3
"""
if config is not None and requested_attention_implementation is not None:
raise ValueError(
"Requested attention implementation is ambiguous: "
"Please pass either the config or the name of the attention implementation, not both."
)
if config is not None:
checked_attention_implementation = config._attn_implementation
else:
checked_attention_implementation = requested_attention_implementation
return "flash" in checked_attention_implementation
def to_py_obj(obj):
    """
    Convert a PyTorch tensor, Numpy array, MLX array or python list to a python list.

    Dicts (and `UserDict`s) are converted value by value into a plain `dict`; lists and tuples are converted
    element by element into a `list`. Python ints/floats are returned unchanged, numpy scalars are converted with
    `tolist()`, and any other object is returned as is.
    """
    if isinstance(obj, (int, float)):
        return obj
    elif isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        # Only convert directly if all elements are numeric scalars
        if all(isinstance(x, (int, float, np.number)) for x in obj):
            return list(obj)
        # Otherwise recurse element-wise
        return [to_py_obj(o) for o in obj]

    # Every framework returned by `_get_frameworks_and_test_func` needs an entry here, otherwise a positive
    # test (e.g. for an mlx array) would end in a KeyError.
    framework_to_py_obj = {
        "pt": lambda obj: obj.tolist(),
        "np": lambda obj: obj.tolist(),
        "mlx": lambda obj: obj.tolist(),
    }
    # This gives us a smart order to test the frameworks with the corresponding tests.
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_py_obj[framework](obj)

    # tolist also works on 0d np arrays
    if isinstance(obj, np.number):
        return obj.tolist()
    return obj
def to_numpy(obj):
    """
    Convert a PyTorch tensor, Numpy array, MLX array or python list to a Numpy array.

    Dicts (and `UserDict`s) are converted value by value into a plain `dict`; lists and tuples go through
    `np.array`. Objects that match no supported framework are returned unchanged.
    """
    # Every framework returned by `_get_frameworks_and_test_func` needs an entry here, otherwise a positive
    # test (e.g. for an mlx array) would end in a KeyError.
    framework_to_numpy = {
        "pt": lambda obj: obj.detach().cpu().numpy(),
        "np": lambda obj: obj,
        "mlx": lambda obj: np.array(obj),
    }
    if isinstance(obj, (dict, UserDict)):
        return {k: to_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return np.array(obj)

    # This gives us a smart order to test the frameworks with the corresponding tests.
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_numpy[framework](obj)
    return obj
def safe_load_json_file(json_file: str):
    """
    Load a JSON config file and raise a clear error message if it wasn't serialized correctly.

    Args:
        json_file (`str`): Path of the UTF-8 encoded JSON file to read.

    Returns:
        The decoded JSON content.

    Raises:
        OSError: if the file content is not valid JSON. The underlying `json.JSONDecodeError` is attached as the
            cause so the exact parse position stays visible in the traceback.
    """
    try:
        with open(json_file, encoding="utf-8") as reader:
            text = reader.read()
        config_dict = json.loads(text)
    except json.JSONDecodeError as err:
        raise OSError(f"It looks like the config file at '{json_file}' is not a valid JSON file.") from err
    return config_dict
class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Items can be read by integer or slice (like a tuple) or by
    string key (like a dictionary); attributes that are `None` are not stored as keys and are therefore skipped.
    Otherwise behaves like a regular python dictionary.

    <Tip warning={true}>

    You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
    before.

    </Tip>
    """

    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `ModelOutput` subclasses.
        """
        if not _is_torch_available:
            return
        from torch.utils._pytree import register_pytree_node

        register_pytree_node(
            cls,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=cls),
            serialized_type_name=f"{cls.__module__}.{cls.__name__}",
        )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Subclasses of ModelOutput must use the @dataclass decorator.
        # The check lives here because @dataclass runs after __init_subclass__, and it compares classes directly
        # because issubclass(ModelOutput, ModelOutput) would be True when the base class must be exempt.
        if self.__class__ == ModelOutput or is_dataclass(self):
            return
        raise TypeError(
            f"{self.__module__}.{self.__class__.__name__} is not a dataclass."
            " This is a subclass of ModelOutput and so must use the @dataclass decorator."
        )

    def __post_init__(self):
        """Check the ModelOutput dataclass and mirror its non-`None` fields as dictionary items.

        Only occurs if @dataclass decorator has been used.
        """
        declared = fields(self)

        # Safety and consistency checks
        if len(declared) == 0:
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        head, tail = declared[0], declared[1:]
        if any(spec.default is not None for spec in tail):
            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

        head_value = getattr(self, head.name)
        only_head_is_set = all(getattr(self, spec.name) is None for spec in tail)

        # Regular case: several fields were given, or the first field is a tensor -> store every non-None field.
        if not only_head_is_set or is_tensor(head_value):
            for spec in declared:
                current = getattr(self, spec.name)
                if current is not None:
                    self[spec.name] = current
            return

        # Only the first field was given and it is not a tensor: it may be a collection of (key, value) pairs.
        if isinstance(head_value, dict):
            pairs = head_value.items()
        else:
            try:
                pairs = iter(head_value)
            except TypeError:
                pairs = None

        if pairs is None:
            # Not iterable: keep it as the plain value of the first field (when it is set at all).
            if head_value is not None:
                self[head.name] = head_value
            return

        # reset first field to None, the pairs (if any) decide which fields get populated
        setattr(self, head.name, None)
        for position, item in enumerate(pairs):
            looks_like_pair = isinstance(item, (list, tuple)) and len(item) == 2 and isinstance(item[0], str)
            if not looks_like_pair:
                if position != 0:
                    # If we have a mixed iterator, raise an error
                    raise ValueError(f"Cannot set key/value for {item}. It needs to be a tuple (key, value).")
                # If we do not have an iterator of key/values, set it as attribute
                self[head.name] = head_value
                break
            setattr(self, item[0], item[1])
            if item[1] is not None:
                self[item[0]] = item[1]

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        # Integers and slices index into the tuple view; strings are looked up as dictionary keys.
        if not isinstance(k, str):
            return self.to_tuple()[k]
        snapshot = dict(self.items())
        return snapshot[k]

    def __setattr__(self, name, value):
        # Keep an already-stored key in sync with its attribute.
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def __reduce__(self):
        base_reduction = super().__reduce__()
        if not is_dataclass(self):
            return base_reduction
        # For dataclass subclasses, rebuild from the field values instead of the default constructor arguments.
        constructor, _unused_args, *rest = base_reduction
        field_values = tuple(getattr(self, spec.name) for spec in fields(self))
        return (constructor, field_values, *rest)

    def to_tuple(self) -> tuple:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[key] for key in self.keys())
if _is_torch_available:
    import torch.utils._pytree as _torch_pytree

    def _model_output_flatten(output: ModelOutput) -> tuple[list[Any], "_torch_pytree.Context"]:
        """Split a `ModelOutput` into its stored values (leaves) and its keys (the pytree context)."""
        leaves = list(output.values())
        key_names = list(output.keys())
        return leaves, key_names

    def _model_output_unflatten(
        values: Iterable[Any],
        context: "_torch_pytree.Context",
        output_type=None,
    ) -> ModelOutput:
        """Rebuild an `output_type` instance by pairing the context keys with `values` as keyword arguments."""
        rebuilt_fields = dict(zip(context, values))
        return output_type(**rebuilt_fields)

    # Subclasses are registered from `ModelOutput.__init_subclass__`; the base class itself is registered here.
    _torch_pytree.register_pytree_node(
        ModelOutput,
        _model_output_flatten,
        partial(_model_output_unflatten, output_type=ModelOutput),
        serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
    )
class ExplicitEnum(str, Enum):
    """
    String-valued enum whose lookup failure lists every accepted value in the error message.
    """

    @classmethod
    def _missing_(cls, value):
        # Called by `Enum` when `value` matches no member; always fails, but with an actionable message.
        accepted_values = list(cls._value2member_map_.keys())
        raise ValueError(f"{value} is not a valid {cls.__name__}, please select one of {accepted_values}")
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an
    IDE.

    Members are strings (`ExplicitEnum` derives from `str`), so `PaddingStrategy("longest")` resolves to
    `PaddingStrategy.LONGEST`, and an unknown value raises a `ValueError` listing the accepted values.
    """

    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"
class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.

    The member values are the same short framework codes (`"pt"`, `"np"`, `"mlx"`) returned by
    `infer_framework_from_repr`.
    """

    PYTORCH = "pt"
    NUMPY = "np"
    MLX = "mlx"
class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.

    The managers are entered in list order and exited in reverse order. If one of them fails while being entered,
    the ones already entered are exited before the error propagates.
    """

    def __init__(self, context_managers: list[AbstractContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        try:
            for context_manager in self.context_managers:
                self.stack.enter_context(context_manager)
        except BaseException as exc:
            # `with` never calls `__exit__` when `__enter__` raises, so unwind what was already entered here,
            # passing the real exception so the managers see a failure rather than a clean exit.
            self.stack.__exit__(type(exc), exc, exc.__traceback__)
            raise

    def __exit__(self, *args, **kwargs):
        # Propagate the stack's verdict so a wrapped manager that suppresses an exception keeps doing so.
        return self.stack.__exit__(*args, **kwargs)
def can_return_loss(model_class):
    """
    Tell whether a model can return a loss, i.e. whether its `forward` has a `return_loss` parameter defaulting
    to `True`.

    Args:
        model_class (`type`): The class of the model.
    """
    forward_params = inspect.signature(model_class.forward).parameters
    return_loss_param = forward_params.get("return_loss")
    return return_loss_param is not None and return_loss_param.default is True
def find_labels(model_class):
    """
    List the `forward` parameters of a model that carry labels.

    A parameter counts as a label when its name contains `"label"`. For classes whose name contains
    `"QuestionAnswering"`, `start_positions` and `end_positions` count as well.

    Args:
        model_class (`type`): The class of the model.
    """
    class_name = model_class.__name__
    forward_params = inspect.signature(model_class.forward).parameters
    extra_label_names = ("start_positions", "end_positions") if "QuestionAnswering" in class_name else ()
    return [name for name in forward_params if "label" in name or name in extra_label_names]
def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
    """Flatten a nested dict into a single level dict.

    Nested keys are joined with `delimiter` and prefixed with `parent_key` when it is non-empty. Empty nested
    mappings are kept as values rather than expanded.
    """
    flat = {}
    for k, v in d.items():
        full_key = str(parent_key) + delimiter + str(k) if parent_key else k
        if v and isinstance(v, MutableMapping):
            flat.update(flatten_dict(v, full_key, delimiter=delimiter))
        else:
            flat[full_key] = v
    return flat
def transpose(array, axes=None):
    """
    Transpose a numpy array or a torch tensor with a single call signature.

    Numpy arrays go through `np.transpose`. Torch tensors use `.T` when `axes` is `None` and `.permute(*axes)`
    otherwise. Any other type raises a `ValueError`.
    """
    if is_numpy_array(array):
        return np.transpose(array, axes=axes)
    if not is_torch_tensor(array):
        raise ValueError(f"Type not supported for transpose: {type(array)}.")
    if axes is None:
        return array.T
    return array.permute(*axes)
def reshape(array, newshape):
    """
    Reshape a numpy array or a torch tensor with a single call signature.

    Numpy arrays go through `np.reshape`; torch tensors through `.reshape(*newshape)`, so `newshape` must be
    unpackable on the torch path. Any other type raises a `ValueError`.
    """
    if is_numpy_array(array):
        return np.reshape(array, newshape)
    if not is_torch_tensor(array):
        raise ValueError(f"Type not supported for reshape: {type(array)}.")
    return array.reshape(*newshape)
def squeeze(array, axis=None):
    """
    Squeeze a numpy array or a torch tensor with a single call signature.

    Numpy arrays go through `np.squeeze`. Torch tensors use `.squeeze()` when `axis` is `None` and
    `.squeeze(dim=axis)` otherwise. Any other type raises a `ValueError`.
    """
    if is_numpy_array(array):
        return np.squeeze(array, axis=axis)
    if not is_torch_tensor(array):
        raise ValueError(f"Type not supported for squeeze: {type(array)}.")
    if axis is None:
        return array.squeeze()
    return array.squeeze(dim=axis)
def expand_dims(array, axis):
    """
    Insert a new axis into a numpy array (`np.expand_dims`) or a torch tensor (`.unsqueeze(dim=axis)`).

    Any other type raises a `ValueError`.
    """
    if is_numpy_array(array):
        return np.expand_dims(array, axis)
    if not is_torch_tensor(array):
        raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
    return array.unsqueeze(dim=axis)
def tensor_size(array):
    """
    Return the total number of elements of a numpy array (`np.size`) or a torch tensor (`.numel()`).

    Any other type raises a `ValueError`.
    """
    if is_numpy_array(array):
        return np.size(array)
    if not is_torch_tensor(array):
        raise ValueError(f"Type not supported for tensor_size: {type(array)}.")
    return array.numel()
def torch_int(x):
    """
    Cast `x` to a torch int64 tensor when tracing with `torch.jit` and `x` is a tensor; in every other case
    (including when torch is not installed) return `int(x)`.
    """
    if _is_torch_available and torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.int64)
    return int(x)
def torch_float(x):
    """
    Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float.

    The non-tracing path (and the path taken when torch is not installed) returns `float(x)`, so fractional
    values are preserved instead of being truncated to an integer.
    """
    if not _is_torch_available:
        return float(x)
    return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else float(x)
def filter_out_non_signature_kwargs(extra: list | None = None):
"""
Decorator to filter out named arguments that are not in the function signature.
This decorator ensures that only the keyword arguments that match the function's signature, or are specified in the
`extra` list, are passed to the function. Any additional keyword arguments are filtered out and a warning is issued.
Parameters:
extra (`Optional[list]`, *optional*):
A list of extra keyword argument names that are allowed even if they are not in the function's signature.
Returns:
Callable:
A decorator that wraps the function and filters out invalid keyword arguments.
Example usage:
```python
@filter_out_non_signature_kwargs(extra=["allowed_extra_arg"])
def my_function(arg1, arg2, **kwargs):
print(arg1, arg2, kwargs)
my_function(arg1=1, arg2=2, allowed_extra_arg=3, invalid_arg=4)
# This will print: 1 2 {"allowed_extra_arg": 3}
# And issue a warning: "The following named arguments are not valid for `my_function` and were ignored: 'invalid_arg'"
```
"""
extra = extra or []
extra_params_to_pass = set(extra)
def decorator(func):
sig = inspect.signature(func)
function_named_args = set(sig.parameters.keys())
valid_kwargs_to_pass = function_named_args.union(extra_params_to_pass)
# Required for better warning message
is_instance_method = "self" in function_named_args
is_class_method = "cls" in function_named_args
# Mark function as decorated
func._filter_out_non_signature_kwargs = True
@wraps(func)
def wrapper(*args, **kwargs):
valid_kwargs = {}
invalid_kwargs = {}
for k, v in kwargs.items():
if k in valid_kwargs_to_pass:
valid_kwargs[k] = v
else:
invalid_kwargs[k] = v
if invalid_kwargs:
invalid_kwargs_names = [f"'{k}'" for k in invalid_kwargs]
invalid_kwargs_names = ", ".join(invalid_kwargs_names)
# Get the class name for better warning message
if is_instance_method:
cls_prefix = args[0].__class__.__name__ + "."
elif is_class_method:
cls_prefix = args[0].__name__ + "."
else:
cls_prefix = ""
warnings.warn(
f"The following named arguments are not valid for `{cls_prefix}{func.__name__}`"
f" and were ignored: {invalid_kwargs_names}",
UserWarning,
stacklevel=2,
)
return func(*args, **valid_kwargs)
return wrapper
return decorator
class TransformersKwargs(TypedDict, total=False):
    """
    Keyword arguments to be passed to the forward pass of a `PreTrainedModel`.

    Declared with `total=False`, so every key is optional.

    Attributes:
        num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
            Number of items in the batch. It is recommended to pass it when you are doing gradient accumulation.
        output_hidden_states (`Optional[bool]`, *optional*):
            Most of the models support outputting all hidden states computed during the forward pass.
        output_attentions (`Optional[bool]`, *optional*):
            Turn this on to return the intermediary attention scores.
        output_router_logits (`Optional[bool]`, *optional*):
            For MoE models, this allows returning the router logits to compute the loss.
        cu_seq_lens_q (`torch.LongTensor`, *optional*):
            Gets cumulative sequence length for query state.
        cu_seq_lens_k (`torch.LongTensor`, *optional*):
            Gets cumulative sequence length for key state.
        max_length_q (`int`, *optional*):
            Maximum sequence length for query state.
        max_length_k (`int`, *optional*):
            Maximum sequence length for key state.
        position_ids (`torch.LongTensor`, *optional*):
            Indices of positions of each input sequence tokens.
    """

    # torch types are written as strings so the class can be defined when torch is not installed.
    num_items_in_batch: Optional["torch.Tensor"]
    output_hidden_states: bool | None
    output_attentions: bool | None
    output_router_logits: bool | None
    cu_seq_lens_q: Optional["torch.LongTensor"]
    cu_seq_lens_k: Optional["torch.LongTensor"]
    max_length_q: int | None
    max_length_k: int | None
    position_ids: Optional["torch.LongTensor"]
def is_timm_config_dict(config_dict: dict[str, Any]) -> bool:
    """Tell whether a config dict is a timm one, recognised by the presence of a `pretrained_cfg` key."""
    timm_marker = "pretrained_cfg"
    return timm_marker in config_dict
def is_timm_local_checkpoint(pretrained_model_path: str) -> bool:
    """
    Checks whether a checkpoint is a timm model checkpoint.

    Args:
        pretrained_model_path (`str` or path-like): Either a `.json` config file, or a directory that contains
            a `config.json`. `None` is accepted and yields `False`.

    Returns:
        `bool`: `True` when the located config is a timm config (see `is_timm_config_dict`), `False` when no
        config file can be located or when it is not a timm config.
    """
    if pretrained_model_path is None:
        return False

    # in case it's Path, not str
    pretrained_model_path = str(pretrained_model_path)

    if os.path.isfile(pretrained_model_path) and pretrained_model_path.endswith(".json"):
        # pretrained_model_path is a file
        config_file = pretrained_model_path
    elif os.path.isdir(pretrained_model_path) and os.path.exists(
        os.path.join(pretrained_model_path, "config.json")
    ):
        # pretrained_model_path is a directory with a config.json
        config_file = os.path.join(pretrained_model_path, "config.json")
    else:
        return False

    # JSON configs are read as UTF-8 (as in `safe_load_json_file`) instead of the platform default encoding.
    with open(config_file, encoding="utf-8") as f:
        config_dict = json.load(f)
    return is_timm_config_dict(config_dict)
def set_attribute_for_modules(module: "torch.nn.Module", key: str, value: Any):
    """
    Set attribute `key` to `value` on `module`, then recursively on everything reachable through `children()`.
    """
    setattr(module, key, value)
    for child in module.children():
        set_attribute_for_modules(child, key, value)
def del_attribute_from_modules(module: "torch.nn.Module", key: str):
    """
    Remove attribute `key` from `module`, then recursively from everything reachable through `children()`.

    Modules that do not carry the attribute are skipped silently.
    """
    # A shared module (e.g. an activation function) can be reached twice and may already have lost the attribute.
    attribute_present = hasattr(module, key)
    if attribute_present:
        delattr(module, key)

    for child in module.children():
        del_attribute_from_modules(child, key)
def can_return_tuple(func):
    """
    Decorator for model methods that converts the output with `output.to_tuple()` when `return_dict=False` is
    passed as a keyword argument, or when `return_dict` is falsy on the config and no non-`None` `return_dict`
    keyword argument overrides it.

    The `return_dict` keyword argument is consumed by the wrapper and never forwarded to the wrapped method.

    Note:
        output.to_tuple() convert output to tuple skipping all `None` values.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        config_default = self.config.return_dict if hasattr(self, "config") else True
        requested = kwargs.pop("return_dict", config_default)
        # An explicit `return_dict=None` means "use the config default".
        wants_dict = config_default if requested is None else requested

        output = func(self, *args, **kwargs)
        if wants_dict or isinstance(output, tuple):
            return output
        return output.to_tuple()

    return wrapper
@dataclass
@requires(backends=("torch",))
class OutputRecorder:
    """
    Configuration for recording outputs from a model via hooks.

    Attributes:
        target_class (Type): The class (e.g., nn.Module) to which the hook will be attached.
        index (int): If the output is a tuple/list, record only the element at this index. Defaults to 0.
        layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn".
        class_name (Optional[str]): Name of the class to which the hook will be attached. Could be the suffix of class name in some cases.
    """

    # NOTE(review): `check_model_inputs` builds instances with `target_class=None` when the spec is a string,
    # so `None` appears to be an accepted value despite the annotation.
    target_class: "type[torch.nn.Module]"
    index: int = 0
    layer_name: str | None = None
    class_name: str | None = None
def check_model_inputs(func=None, *, tie_last_hidden_states=True):
    """
    Decorator to intercept specific layer outputs without using hooks.
    Compatible with torch.compile (Dynamo tracing).

    The wrapper (1) fills `use_cache` and the vision feature arguments from `self.config` when they are not
    passed, (2) temporarily replaces the `forward` of the layers listed in `_CAN_RECORD_REGISTRY` for the model
    class so their outputs are collected, (3) calls the decorated function, (4) restores the original `forward`
    methods (also when the call raises) and (5) writes the collected outputs into the returned output object.

    Can be used bare (`@check_model_inputs`) or with arguments (`@check_model_inputs(tie_last_hidden_states=False)`).

    Args:
        tie_last_hidden_states (`bool`, *optional*, defaults to `True`):
            Whether to overwrite `out.hidden_states[-1]` with the `out.last_hidden_state`.
            This is true for all language models and should be toggled off only if
            `out.hidden_states[-1]` has to be the hidden state before last layer norm, which
            is needed for some vision models (e.g. CLIP, SigLIP)

    Raises:
        ValueError: if `vision_feature_select_strategy` is neither `"default"` nor `"full"`.
        TypeError: if the decorated function fails with a `TypeError` only because it does not accept the
            recordable `output_*` keyword arguments (i.e. `**kwargs` is missing from its signature).
    """

    def wrapped_fn(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            args_with_config_defaults = [
                "use_cache",
                "vision_feature_layer",
                "vision_feature_select_strategy",
                "vision_aspect_ratio",
            ]
            for arg_name in args_with_config_defaults:
                arg_index = None
                if arg_name in func.__code__.co_varnames:
                    arg_index = func.__code__.co_varnames.index(arg_name) - 1  # -1 for self

                # Precedence: positional value, then keyword value, then the config attribute.
                if arg_index is not None and len(args) > arg_index and args[arg_index] is not None:
                    arg_value = args[arg_index]
                elif kwargs.get(arg_name) is not None:
                    arg_value = kwargs[arg_name]
                else:
                    arg_value = getattr(self.config, arg_name, None)

                if arg_value is not None:
                    # Arg-specific handling
                    if arg_name == "use_cache":
                        if getattr(self, "gradient_checkpointing", False) and self.training and arg_value:
                            logger.warning_once(
                                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
                            )
                            arg_value = False
                    elif arg_name == "vision_feature_select_strategy":
                        valid_strategies = ["default", "full"]
                        if arg_value not in valid_strategies:
                            raise ValueError(
                                f"`Unexpected select feature strategy: {arg_value}. "
                                f"Please select from {valid_strategies}."
                            )

                    # Write the resolved value back where the caller put it (positional slot or keyword).
                    if arg_index is not None and len(args) > arg_index:
                        args = list(args)
                        args[arg_index] = arg_value
                        args = tuple(args)
                    else:
                        kwargs[arg_name] = arg_value

            return_dict = kwargs.pop("return_dict", None)
            if return_dict is None:
                return_dict = getattr(self.config, "return_dict", True)

            # Flatten a nested `kwargs` entry so its keys can be looked up like regular keyword arguments.
            all_args = kwargs.copy()
            if "kwargs" in all_args:
                for k, v in all_args["kwargs"].items():
                    all_args[k] = v

            # _can_record_outputs is None by default
            capture_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__)) or {}  # there is a weak ref for executorch
            # For each recordable output: explicit argument, then config value, then the `output_attentions` setting.
            recordable_keys = {
                f"output_{k}": all_args.get(
                    f"output_{k}",
                    getattr(
                        self.config,
                        f"output_{k}",
                        all_args.get("output_attentions", getattr(self.config, "output_attentions", False)),
                    ),
                )
                for k in capture_flags
            }

            # We let cross attentions to be saved separately because some models add `cross-attn` layer
            # when certain conditions are met. Let's output cross attention if attentions are requested (for BC)
            if "output_attentions" in recordable_keys:
                recordable_keys["output_cross_attentions"] = recordable_keys["output_attentions"]

            collected_outputs = defaultdict(tuple)
            monkey_patched_layers = []

            def make_capture_wrapper(module, orig_forward, key, index):
                @wraps(orig_forward)
                def wrapped_forward(*args, **kwargs):
                    # The very first hidden state recorded is the input of the first captured layer.
                    if key == "hidden_states" and len(collected_outputs[key]) == 0:
                        collected_outputs[key] += (args[0],)
                    output = orig_forward(*args, **kwargs)
                    if not isinstance(output, tuple):
                        collected_outputs[key] += (output,)
                    elif output[index] is not None:
                        if key not in collected_outputs:
                            collected_outputs[key] = (output[index],)
                        else:
                            collected_outputs[key] += (output[index],)
                    return output

                return wrapped_forward

            if any(recordable_keys.values()):
                capture_tasks = []
                for key, layer_specs in capture_flags.items():
                    if not recordable_keys.get(f"output_{key}", False):
                        continue
                    if not isinstance(layer_specs, list):
                        layer_specs = [layer_specs]
                    for specs in layer_specs:
                        # Bare classes / class-name strings are normalized into `OutputRecorder` specs.
                        if not isinstance(specs, OutputRecorder):
                            index = 0 if "hidden_states" in key else 1
                            class_name = None if not isinstance(specs, str) else specs
                            target_class = specs if not isinstance(specs, str) else None
                            specs = OutputRecorder(target_class=target_class, index=index, class_name=class_name)
                        capture_tasks.append((key, specs))

                for name, module in self.named_modules():
                    for key, specs in capture_tasks:
                        # The second check is for multimodals where only backbone layer suffix is available
                        if (specs.target_class is not None and isinstance(module, specs.target_class)) or (
                            specs.class_name is not None and name.endswith(specs.class_name)
                        ):
                            if specs.layer_name is not None and specs.layer_name not in name:
                                continue
                            # Monkey patch forward
                            original_forward = module.forward
                            module.forward = make_capture_wrapper(module, original_forward, key, specs.index)
                            monkey_patched_layers.append((module, original_forward))

            try:
                try:
                    if kwargs.get("debug_io", False):
                        with model_addition_debugger_context(
                            self, kwargs.get("debug_io_dir", "model_debug"), kwargs.get("prune_layers")
                        ):
                            outputs = func(self, *args, **kwargs)
                    else:
                        outputs = func(self, *args, **kwargs)
                except TypeError as original_exception:
                    # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
                    # Get a TypeError even after removing the recordable kwargs -> re-raise the original exception
                    # Otherwise -> we're probably missing `**kwargs` in the decorated function
                    kwargs_without_recordable = {k: v for k, v in kwargs.items() if k not in recordable_keys}
                    try:
                        outputs = func(self, *args, **kwargs_without_recordable)
                    except TypeError:
                        raise original_exception
                    raise TypeError(
                        "Missing `**kwargs` in the signature of the `@check_model_inputs`-decorated function "
                        f"({func.__qualname__})"
                    )
            finally:
                # Restore original forward methods, even when the call above raised: otherwise the modules
                # would stay patched and the next call would wrap the wrappers again.
                for module, original_forward in monkey_patched_layers:
                    module.forward = original_forward

            # Inject collected outputs into model output
            for key in collected_outputs:
                if key == "hidden_states":
                    if not tie_last_hidden_states:
                        pass
                    elif hasattr(outputs, "vision_hidden_states"):
                        collected_outputs[key] = collected_outputs[key][:-1]
                        collected_outputs[key] += (outputs.vision_hidden_states,)
                    elif hasattr(outputs, "last_hidden_state"):
                        collected_outputs[key] = collected_outputs[key][:-1]
                        collected_outputs[key] += (outputs.last_hidden_state,)

                    outputs[key] = collected_outputs[key]
                elif key == "attentions":
                    # Two specs for `attentions`: captured values alternate between self and cross attention.
                    if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2:
                        outputs[key] = collected_outputs[key][0::2]
                        outputs["cross_" + key] = collected_outputs[key][1::2]
                    else:
                        outputs[key] = collected_outputs[key]
                else:
                    outputs[key] = collected_outputs[key]
            if return_dict is False:
                outputs = outputs.to_tuple()
            return outputs

        return wrapper

    if func is not None:
        return wrapped_fn(func)
    return wrapped_fn
class GeneralInterface(MutableMapping):
    """
    Dict-like object layering a per-instance mapping on top of a class-wide one. Entries added with `register`
    live in the class-wide mapping and are visible from every instance; entries set through item assignment are
    local to one instance and shadow the class-wide entry with the same key.
    """

    # Class-level object, so that a call to `register` is reflected in every instance, including instances
    # created later (e.g. to locally override a given function)
    _global_mapping = {}

    def __init__(self):
        self._local_mapping = {}

    def __getitem__(self, key):
        # A local override wins over the class-wide entry
        try:
            return self._local_mapping[key]
        except KeyError:
            pass
        return self._global_mapping[key]

    def __setitem__(self, key, value):
        # Local update only: other instances keep seeing the class-wide default
        self._local_mapping[key] = value

    def __delitem__(self, key):
        del self._local_mapping[key]

    def __iter__(self):
        # Class-wide keys first, then the keys that only exist locally
        merged = dict(self._global_mapping)
        merged.update(self._local_mapping)
        return iter(merged)

    def __len__(self):
        distinct_keys = set(self._global_mapping).union(self._local_mapping)
        return len(distinct_keys)

    @classmethod
    def register(cls, key: str, value: Callable):
        cls._global_mapping[key] = value

    def valid_keys(self) -> list[str]:
        return [*self.keys()]