# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import inspect
from collections.abc import Callable
from functools import lru_cache, wraps

import torch
from safetensors.torch import storage_ptr, storage_size
from torch import nn

from .utils import (
    is_torch_greater_or_equal,
    is_torch_xla_available,
    is_torchdynamo_compiling,
    logging,
)


ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

logger = logging.get_logger(__name__)

is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)

# For backwards compatibility (e.g. some remote code on the Hub still uses these variables).
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)

# Cache this result, as it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()

def softmax_backward_data(parent, grad_output, output):
    """
    A function that calls the internal `_softmax_backward_data` PyTorch method, forwarding the softmax dimension
    stored on `parent` (typically an autograd context) together with the gradient and output tensors.
    """

    from torch import _softmax_backward_data

    return _softmax_backward_data(grad_output, output, parent.dim, output.dtype)

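# A minimal usage sketch (illustrative only, not part of the public API): `softmax_backward_data` is meant to be
# called from the `backward` of a custom `torch.autograd.Function`, with the softmax dimension stashed on the
# context object as `dim`. The function and class names below are hypothetical.
def _example_softmax_backward_usage(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    class _ExampleSoftmax(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inputs, dim):
            ctx.dim = dim
            output = torch.softmax(inputs, dim)
            ctx.save_for_backward(output)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            # `ctx` plays the role of `parent`: it carries the softmax dimension as `ctx.dim`.
            (output,) = ctx.saved_tensors
            return softmax_backward_data(ctx, grad_output, output), None

    return _ExampleSoftmax.apply(x, dim)
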
def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    """
    Prune a linear layer to keep only entries in `index`.

    Used to remove heads.

    Args:
        layer (`torch.nn.Linear`): The layer to prune.
        index (`torch.LongTensor`): The indices to keep in the layer.
        dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.

    Returns:
        `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).detach().clone()
    if layer.bias is not None:
        if dim == 1:
            b = layer.bias.detach().clone()
        else:
            b = layer.bias[index].detach().clone()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer

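# A minimal usage sketch (illustrative only, not part of the public API): keep two of the four output rows of a
# linear layer; the function name and the row indices below are arbitrary.
def _example_prune_linear_layer() -> nn.Linear:
    layer = nn.Linear(8, 4)
    keep = torch.tensor([1, 3], dtype=torch.long)
    pruned = prune_linear_layer(layer, keep, dim=0)  # weight: (4, 8) -> (2, 8), bias: (4,) -> (2,)
    assert pruned.weight.shape == (2, 8) and pruned.bias.shape == (2,)
    return pruned
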
class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Basically works like a linear layer but the weights are transposed.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        self.nx = nx
        self.weight = nn.Parameter(torch.empty(nx, nf))
        self.bias = nn.Parameter(torch.zeros(nf))
        nn.init.normal_(self.weight, std=0.02)

    def __repr__(self) -> str:
        return "Conv1D(nf={nf}, nx={nx})".format(**self.__dict__)

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x

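# A minimal usage sketch (illustrative only, not part of the public API): Conv1D behaves like nn.Linear with a
# transposed weight matrix, mapping (..., nx) -> (..., nf). The function name below is hypothetical.
def _example_conv1d_equivalence() -> torch.Tensor:
    conv = Conv1D(nf=6, nx=4)
    x = torch.randn(2, 3, 4)  # (batch, seq_len, nx)
    out = conv(x)  # (2, 3, 6)
    # Same result as a linear projection using `conv.weight` directly (shape (nx, nf), i.e. already transposed).
    linear_out = x @ conv.weight + conv.bias
    assert torch.allclose(out, linear_out, atol=1e-6)
    return out
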
def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor],
    chunk_size: int,
    chunk_dim: int,
    *input_tensors,
) -> torch.Tensor:
    """
    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.

    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
    applying `forward_fn` to `input_tensors`.

    Args:
        forward_fn (`Callable[..., torch.Tensor]`):
            The forward function of the model.
        chunk_size (`int`):
            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
        chunk_dim (`int`):
            The dimension over which the `input_tensors` should be chunked.
        input_tensors (`tuple[torch.Tensor]`):
            The input tensors of `forward_fn` which will be chunked.

    Returns:
        `torch.Tensor`: A tensor with the same shape as the one `forward_fn` would have given if applied directly to
        `input_tensors`.

    Examples:

    ```python
    # rename the usual forward() fn to forward_chunk()
    def forward_chunk(self, hidden_states):
        hidden_states = self.decoder(hidden_states)
        return hidden_states


    # implement a chunked forward function
    def forward(self, hidden_states):
        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
    ```"""

    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"

    # inspect.signature exists since Python 3.5 and is a Python method -> no problem with backward compatibility
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):
        raise ValueError(
            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
            "tensors are given"
        )

    if chunk_size > 0:
        tensor_shape = input_tensors[0].shape[chunk_dim]
        for input_tensor in input_tensors:
            if input_tensor.shape[chunk_dim] != tensor_shape:
                raise ValueError(
                    f"All input tensors have to be of the same shape: {tensor_shape}, "
                    f"found shape {input_tensor.shape[chunk_dim]}"
                )

        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
            raise ValueError(
                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
                f"size {chunk_size}"
            )

        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size

        # chunk input tensor into tuples
        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
        # apply forward fn to every tuple
        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
        # concatenate output at same dimension
        return torch.cat(output_chunks, dim=chunk_dim)

    return forward_fn(*input_tensors)

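# A minimal usage sketch (illustrative only, not part of the public API): for a function that is independent across
# the chunked dimension, the chunked and unchunked results match. The function name below is hypothetical.
def _example_apply_chunking_to_forward() -> torch.Tensor:
    dense = nn.Linear(16, 16)
    hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden)

    def forward_chunk(chunk):
        return dense(chunk)

    # Chunk the sequence dimension (dim=1, length 8) into pieces of size 4 -> 2 chunks of shape (2, 4, 16).
    chunked = apply_chunking_to_forward(forward_chunk, 4, 1, hidden_states)
    assert torch.allclose(chunked, dense(hidden_states), atol=1e-6)
    return chunked
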
def meshgrid(*tensors: torch.Tensor | list[torch.Tensor], indexing: str | None = None) -> tuple[torch.Tensor, ...]:
    """
    Wrapper around torch.meshgrid to avoid warning messages about the introduced `indexing` argument.

    Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
    """
    return torch.meshgrid(*tensors, indexing=indexing)

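# A minimal usage sketch (illustrative only, not part of the public API): build a 3x4 coordinate grid with matrix
# ("ij") indexing. The function name below is hypothetical.
def _example_meshgrid() -> tuple[torch.Tensor, torch.Tensor]:
    rows, cols = meshgrid(torch.arange(3), torch.arange(4), indexing="ij")
    assert rows.shape == (3, 4) and cols.shape == (3, 4)
    return rows, cols
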
def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
    """
    Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
    example, "meta" tensors all share the same storage, and thus their identifiers will all be equal. This identifier
    is guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
    non-overlapping lifetimes may have the same id.
    """
    if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
        from torch.distributed.tensor import DTensor

        if isinstance(tensor, DTensor):
            local_tensor = tensor.to_local()
            return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes

    if tensor.device.type == "xla" and is_torch_xla_available():
        # NOTE: xla tensors don't have storage, so we use another unique id to distinguish them.
        # This is an XLA tensor, so it must have been created using torch_xla's device, and the
        # following import is safe:
        import torch_xla

        unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
    else:
        unique_id = storage_ptr(tensor)

    return tensor.device, unique_id, storage_size(tensor)

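# A minimal usage sketch (illustrative only, not part of the public API): two views of the same tensor share a
# storage and therefore get the same identifier, while an independent copy does not. The function name below is
# hypothetical.
def _example_id_tensor_storage() -> None:
    base = torch.randn(4, 4)
    view = base[1:]  # shares storage with `base`
    copy = base.clone()  # owns its own storage
    assert id_tensor_storage(base)[:2] == id_tensor_storage(view)[:2]
    assert id_tensor_storage(base)[1] != id_tensor_storage(copy)[1]
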
@wraps(lru_cache)
def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
    """
    LRU cache decorator from the standard functools library, but with a workaround to disable
    caching when torchdynamo is compiling. Expected to work with class methods.
    """

    def decorator(func):
        func_with_cache = lru_cache(*lru_args, **lru_kwargs)(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            if is_torchdynamo_compiling():
                return func(*args, **kwargs)
            else:
                return func_with_cache(*args, **kwargs)

        return wrapper

    return decorator

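# A minimal usage sketch (illustrative only, not part of the public API): cache a method's result between eager
# calls while leaving torchdynamo tracing unaffected. The class and method names below are hypothetical.
class _ExampleRotaryTable:
    @compile_compatible_method_lru_cache(maxsize=8)
    def inv_freq(self, dim: int) -> torch.Tensor:
        # Recomputed on every call while compiling, served from the LRU cache otherwise.
        return 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))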