# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

import math
from collections.abc import Callable
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ... import initialization as init
from ...activations import ACT2FN, get_activation
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import Conv1D
from ...utils import (
    ModelOutput,
    auto_docstring,
    logging,
)
from ...utils.generic import is_flash_attention_requested, maybe_autocast
from .configuration_gpt2 import GPT2Config


logger = logging.get_logger(__name__)


def eager_attention_forward(module, query, key, value, attention_mask, **kwargs):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))

    if module.scale_attn_weights:
        attn_weights = attn_weights / torch.full(
            [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
        )

    # Layer-wise attention scaling
    if module.scale_attn_by_inverse_layer_idx:
        attn_weights = attn_weights / float(module.layer_idx + 1)

    if not module.is_cross_attention:
        # only the "normal" self-attention layer implements the causal mask
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
        mask_value = torch.finfo(attn_weights.dtype).min
        # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
        # Needs to be on the same device, otherwise: `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

    if attention_mask is not None:
        # Apply the attention mask
        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- no-op otherwise
    attn_weights = attn_weights.type(value.dtype)
    attn_weights = module.attn_dropout(attn_weights)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2)

    return attn_output, attn_weights


class GPT2Attention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = not is_cross_attention

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute the scale factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (scale K by 1 / root(dk))
        with maybe_autocast(query.device.type, enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # only the "normal" self-attention layer implements the causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
            # Needs to be on the same device, otherwise: `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- no-op otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: tuple[torch.FloatTensor] | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
        **kwargs,
    ) -> tuple[torch.Tensor | tuple[torch.Tensor], ...]:
        is_cross_attention = encoder_hidden_states is not None
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated token, we can re-use all key/value states from the cache
                    curr_past_key_values = past_key_values.cross_attention_cache
                else:
                    curr_past_key_values = past_key_values.self_attention_cache
            else:
                curr_past_key_values = past_key_values

        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )
            query_states = self.q_attn(hidden_states)
            attention_mask = encoder_attention_mask

            # Try to get key/value states from cache if possible
            if past_key_values is not None and is_updated:
                key_states = curr_past_key_values.layers[self.layer_idx].keys
                value_states = curr_past_key_values.layers[self.layer_idx].values
            else:
                key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
                shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
                key_states = key_states.view(shape_kv).transpose(1, 2)
                value_states = value_states.view(shape_kv).transpose(1, 2)
        else:
            query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
            shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
            key_states = key_states.view(shape_kv).transpose(1, 2)
            value_states = value_states.view(shape_kv).transpose(1, 2)

        shape_q = (*query_states.shape[:-1], -1, self.head_dim)
        query_states = query_states.view(shape_q).transpose(1, 2)

        if (past_key_values is not None and not is_cross_attention) or (
            past_key_values is not None and is_cross_attention and not is_updated
        ):
            # save all key/value states to the cache to be re-used for fast auto-regressive generation
            cache_position = cache_position if not is_cross_attention else None
            key_states, value_states = curr_past_key_values.update(
                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
            )
            # mark the current cross-attention layer as updated so the cache can be re-used in subsequent calls
            if is_cross_attention:
                past_key_values.is_updated[self.layer_idx] = True

        using_eager = self.config._attn_implementation == "eager"
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        if using_eager and self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(
                query_states, key_states, value_states, attention_mask
            )
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=self.attn_dropout.p if self.training else 0.0,
                **kwargs,
            )

        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, attn_weights


class GPT2MLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: tuple[torch.FloatTensor] | None) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class GPT2Block(GradientCheckpointingLayer):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: tuple[torch.FloatTensor] | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        use_cache: bool | None = False,
        output_attentions: bool | None = False,
        **kwargs,
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, tuple[torch.FloatTensor, ...]] | None:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output, self_attn_weights = self.attn(
            hidden_states,
            past_key_values=past_key_values,
            cache_position=cache_position,
            attention_mask=attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs,
        )
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one cross-attention block after the self-attention block
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_output, cross_attn_weights = self.crossattention(
                hidden_states,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            # residual connection
            hidden_states = residual + cross_attn_output

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
            if encoder_hidden_states is not None:
                outputs += (cross_attn_weights,)

        return outputs


# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
class GPT2SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`GPT2Config`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented for now; would use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output;
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config: GPT2Config):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = nn.Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()

        self.first_dropout = nn.Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = nn.Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"`; if `None`, the last token of each sequence is taken as the
                classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
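
        Example (a minimal sketch, assuming the default `GPT2Config` summary settings; not from the original
        source):

        ```python
        >>> import torch
        >>> from transformers import GPT2Config

        >>> config = GPT2Config()
        >>> summary_module = GPT2SequenceSummary(config)
        >>> hidden_states = torch.randn(2, 5, config.hidden_size)
        >>> summary = summary_module(hidden_states)  # with `cls_index=None`, the last token is summarized
        ```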
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dims of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output


@auto_docstring
class GPT2PreTrainedModel(PreTrainedModel):
    config: GPT2Config
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPT2Block"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            # Here we need the explicit check, as we slice the weight in the `zeros_` call, so it loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
        elif isinstance(module, GPT2Attention):
            max_positions = module.config.max_position_embeddings
            init.copy_(
                module.bias,
                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                    1, 1, max_positions, max_positions
                ),
            )

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
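        # Illustrative arithmetic (an assumption based on the defaults, not part of the original code): for GPT-2
        # small, initializer_range=0.02 and n_layer=12, so c_proj.weight gets std = 0.02 / sqrt(2 * 12) ≈ 0.0041.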
        if isinstance(module, PreTrainedModel):
            for name, p in module.named_parameters():
                if name == "c_proj.weight":
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    init.normal_(p, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer))


@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of models predicting if two sentences are consecutive or not.
    """
)
class GPT2DoubleHeadsModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss.
    mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
        Multiple choice classification loss.
    logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
        Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    """

    loss: torch.FloatTensor | None = None
    mc_loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    mc_logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None


@auto_docstring
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False
        self._attn_implementation = config._attn_implementation

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
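
        Example (a minimal usage sketch with the public `openai-community/gpt2` checkpoint):

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, GPT2Model

        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = GPT2Model.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state  # shape (1, sequence_length, hidden_size)
        ```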
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache(config=self.config)

            if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config))

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        # Attention mask.
        # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
        if attention_mask is not None and attention_mask.ndim < 4:
            attention_mask = attention_mask.view(batch_size, -1)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # If a 2D or 3D attention mask is provided for the cross-attention,
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            if _use_sdpa:
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
            elif not is_flash_attention_requested(requested_attention_implementation=self._attn_implementation):
                encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                past_key_values if not (self.gradient_checkpointing and self.training) else None,
                cache_position,
                causal_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[2],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        past_key_values = past_key_values if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


@auto_docstring(
    custom_intro="""
    The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | CausalLMOutputWithCrossAttentions:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`.
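
        Example (a minimal usage sketch with the public `openai-community/gpt2` checkpoint):

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, GPT2LMHeadModel

        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs, labels=inputs["input_ids"])
        >>> loss = outputs.loss  # language modeling loss; labels are shifted internally
        >>> logits = outputs.logits
        ```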
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            cache_position=cache_position,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Flatten the tokens
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )


@auto_docstring(
    custom_intro="""
    The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
    RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
    input embeddings; the classification head takes as input the hidden state of a specified classification token
    index in the input sequence.
    """
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}

    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = GPT2SequenceSummary(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        mc_token_ids: torch.LongTensor | None = None,
        labels: torch.LongTensor | None = None,
        mc_labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | GPT2DoubleHeadsModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input):
            Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
            1]`.
        labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`.
        mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices -
            1]` where *num_choices* is the size of the second dimension of the input tensors. (See *input_ids* above.)

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel

        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")

        >>> # Add a [CLS] to the vocabulary (we should train it also!)
        >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
        >>> # Update the model embeddings with the new vocabulary size
        >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))

        >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        >>> encoded_choices = [tokenizer.encode(s) for s in choices]
        >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        >>> mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
        >>> lm_logits = outputs.logits
        >>> mc_logits = outputs.mc_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        mc_loss = None
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
        lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits, mc_logits) + transformer_outputs[1:]
            if mc_loss is not None:
                output = (mc_loss,) + output
            return ((lm_loss,) + output) if lm_loss is not None else output

        return GPT2DoubleHeadsModelOutput(
            loss=lm_loss,
            mc_loss=mc_loss,
            logits=lm_logits,
            mc_logits=mc_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    The GPT2 Model transformer with a sequence classification head on top (linear layer).

    [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
    each row of the batch).
    """
)
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPT2Model(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SequenceClassifierOutputWithPast:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
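
        Example (a minimal sketch; the classification head on the base `openai-community/gpt2` checkpoint is newly
        initialized, so the predicted class is meaningless until the model is fine-tuned):

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, GPT2ForSequenceClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = GPT2ForSequenceClassification.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits
        >>> predicted_class_id = logits.argmax().item()
        ```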
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right-padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds`."
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@auto_docstring
class GPT2ForTokenClassification(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = GPT2Model(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | TokenClassifierOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels -
            1]`.
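
        Example (a minimal sketch; the token classification head on the base `openai-community/gpt2` checkpoint is
        newly initialized, so the predictions are meaningless until the model is fine-tuned):

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, GPT2ForTokenClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = GPT2ForTokenClassification.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits
        >>> predicted_token_class_ids = logits.argmax(-1)  # one predicted label id per token
        ```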
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@auto_docstring
class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPT2Model(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        start_positions: torch.LongTensor | None = None,
        end_positions: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | QuestionAnsweringModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
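
        Example (a minimal sketch; the QA head on the base `openai-community/gpt2` checkpoint is newly initialized,
        so the predicted span is meaningless until the model is fine-tuned):

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, GPT2ForQuestionAnswering

        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = GPT2ForQuestionAnswering.from_pretrained("openai-community/gpt2")

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        >>> inputs = tokenizer(question, text, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> answer_start_index = outputs.start_logits.argmax()
        >>> answer_end_index = outputs.end_logits.argmax()
        ```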
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "GPT2DoubleHeadsModel",
    "GPT2ForQuestionAnswering",
    "GPT2ForSequenceClassification",
    "GPT2ForTokenClassification",
    "GPT2LMHeadModel",
    "GPT2Model",
    "GPT2PreTrainedModel",
]