You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
911 lines
38 KiB
911 lines
38 KiB
# Copyright 2022 The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""PyTorch ERNIE model."""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
from ... import initialization as init
|
|
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
|
from ...masking_utils import create_bidirectional_mask, create_causal_mask
|
|
from ...modeling_outputs import (
|
|
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
CausalLMOutputWithCrossAttentions,
|
|
MaskedLMOutput,
|
|
MultipleChoiceModelOutput,
|
|
NextSentencePredictorOutput,
|
|
QuestionAnsweringModelOutput,
|
|
SequenceClassifierOutput,
|
|
TokenClassifierOutput,
|
|
)
|
|
from ...modeling_utils import PreTrainedModel
|
|
from ...processing_utils import Unpack
|
|
from ...utils import TransformersKwargs, auto_docstring, logging
|
|
from ...utils.generic import can_return_tuple, check_model_inputs
|
|
from ..bert.modeling_bert import (
|
|
BertCrossAttention,
|
|
BertEmbeddings,
|
|
BertEncoder,
|
|
BertForMaskedLM,
|
|
BertForMultipleChoice,
|
|
BertForNextSentencePrediction,
|
|
BertForPreTraining,
|
|
BertForPreTrainingOutput,
|
|
BertForQuestionAnswering,
|
|
BertForSequenceClassification,
|
|
BertForTokenClassification,
|
|
BertLayer,
|
|
BertLMHeadModel,
|
|
BertLMPredictionHead,
|
|
BertModel,
|
|
BertPooler,
|
|
BertSelfAttention,
|
|
)
|
|
from .configuration_ernie import ErnieConfig
|
|
|
|
|
|
# Module-level logger, keyed by this module's import name.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class ErnieEmbeddings(BertEmbeddings):
    """Sum word, token-type and position embeddings, plus an optional ERNIE task-type embedding."""

    def __init__(self, config):
        super().__init__(config)

        # The task-type table only exists when the config enables it.
        self.use_task_id = config.use_task_id
        if self.use_task_id:
            self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        task_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed the inputs and return the normalized, dropout-regularized sum of all embedding terms.

        Either `input_ids` or `inputs_embeds` must be given. Missing `position_ids`, `token_type_ids` and
        (when task ids are enabled) `task_type_ids` are filled in with defaults.
        """
        # The leading (batch, sequence) dims come from the ids, or from the embeddings minus the hidden dim.
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        batch_size, seq_length = input_shape

        if position_ids is None:
            # Slice the registered position buffer, offset by the number of already-cached positions.
            first = past_key_values_length
            position_ids = self.position_ids[:, first : seq_length + first]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually
        # occurs when its auto-generated, registered buffer helps users when tracing the model without passing
        # token_type_ids, solves issue #5664
        if token_type_ids is None and hasattr(self, "token_type_ids"):
            # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
            buffer_rows = self.token_type_ids.expand(position_ids.shape[0], -1)
            gathered = torch.gather(buffer_rows, dim=1, index=position_ids)
            token_type_ids = gathered.expand(batch_size, seq_length)
        elif token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # .to is better than using _no_split_modules on ErnieEmbeddings as it's the first module and >1/2 the model size
        embeddings = inputs_embeds.to(token_type_embeddings.device) + token_type_embeddings
        embeddings = embeddings + self.position_embeddings(position_ids)

        # add `task_type_id` for ERNIE model
        if self.use_task_id:
            if task_type_ids is None:
                task_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
            embeddings += self.task_type_embeddings(task_type_ids)

        return self.dropout(self.LayerNorm(embeddings))
|
|
|
|
|
|
class ErnieSelfAttention(BertSelfAttention):
    # ERNIE-named subclass of BertSelfAttention; adds and overrides nothing.
    pass
|
|
|
|
|
|
class ErnieCrossAttention(BertCrossAttention):
    # ERNIE-named subclass of BertCrossAttention; adds and overrides nothing.
    pass
|
|
|
|
|
|
class ErnieLayer(BertLayer):
    # ERNIE-named subclass of BertLayer; adds and overrides nothing.
    pass
|
|
|
|
|
|
class ErniePooler(BertPooler):
    # ERNIE-named subclass of BertPooler; adds and overrides nothing.
    pass
|
|
|
|
|
|
class ErnieLMPredictionHead(BertLMPredictionHead):
    # ERNIE-named subclass of BertLMPredictionHead; adds and overrides nothing.
    pass
|
|
|
|
|
|
class ErnieEncoder(BertEncoder):
    # ERNIE-named subclass of BertEncoder; adds and overrides nothing.
    pass
|
|
|
|
|
|
@auto_docstring
class ErniePreTrainedModel(PreTrainedModel):
    # Configuration class and attribute prefix used for the ERNIE backbone.
    config_class = ErnieConfig
    base_model_prefix = "ernie"

    # Feature flags declared for this model family.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    # Maps each recordable output name to the module class associated with it.
    _can_record_outputs = {
        "hidden_states": ErnieLayer,
        "attentions": ErnieSelfAttention,
        "cross_attentions": ErnieCrossAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the inherited initialization, then reset ERNIE-specific biases and buffers."""
        super()._init_weights(module)

        if isinstance(module, ErnieLMPredictionHead):
            init.zeros_(module.bias)
            return

        if isinstance(module, ErnieEmbeddings):
            # position_ids becomes [[0, 1, ..., n-1]]; token_type_ids becomes all zeros.
            num_positions = module.position_ids.shape[-1]
            init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
            init.zeros_(module.token_type_ids)
|
|
|
|
|
|
class ErnieModel(BertModel):
    """ERNIE backbone: ERNIE embeddings, a stack of encoder layers and an optional pooler."""

    _no_split_modules = ["ErnieLayer"]

    def __init__(self, config, add_pooling_layer=True):
        """Build the embeddings, encoder and (optionally) pooler.

        Args:
            config: Model configuration.
            add_pooling_layer: When `False`, no pooler is created and `pooler_output` is `None`.
        """
        # `super()` already binds `self`; passing it again would shift `config` into the parent's next
        # positional parameter, so only `config` is forwarded.
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        self.embeddings = ErnieEmbeddings(config)
        self.encoder = ErnieEncoder(config)

        self.pooler = ErniePooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        task_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        r"""
        task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Task type embedding is a special embedding to represent the characteristic of different tasks, such as
            word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
            assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
            config.task_type_vocab_size-1]`
        """
        # Caching only applies in decoder mode; encoders always run without a cache.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Create a fresh cache when one is needed; cross-attention setups get an encoder-decoder cache pair.
        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Unpacking also enforces that the input has exactly two leading dimensions.
        _, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            # specific to ernie
            task_type_ids=task_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )

    # Copied from transformers.models.bert.modeling_bert.BertModel._create_attention_masks
    def _create_attention_masks(
        self,
        attention_mask,
        encoder_attention_mask,
        embedding_output,
        encoder_hidden_states,
        cache_position,
        past_key_values,
    ):
        """Build the self-attention mask and, when given, the encoder (cross-attention) mask.

        The self-attention mask is causal in decoder mode and bidirectional otherwise.
        `encoder_attention_mask` is converted only when it is not `None`.

        Returns:
            The tuple `(attention_mask, encoder_attention_mask)`.
        """
        if self.config.is_decoder:
            attention_mask = create_causal_mask(
                config=self.config,
                input_embeds=embedding_output,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
            )
        else:
            attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=embedding_output,
                attention_mask=attention_mask,
            )

        if encoder_attention_mask is not None:
            encoder_attention_mask = create_bidirectional_mask(
                config=self.config,
                input_embeds=embedding_output,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
            )

        return attention_mask, encoder_attention_mask
|
|
|
|
|
|
class ErnieForPreTrainingOutput(BertForPreTrainingOutput):
    # ERNIE-named subclass of BertForPreTrainingOutput; adds and overrides nothing.
    pass
|
|
|
|
|
|
class ErnieForPreTraining(BertForPreTraining):
    # Keys are the tied parameters, values the parameters they are tied to.
    _tied_weights_keys = {
        "cls.predictions.decoder.bias": "cls.predictions.bias",
        "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight",
    }

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        task_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        next_sentence_label: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | ErnieForPreTrainingOutput:
        r"""
        task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Task type ids select a task-specific embedding that represents the characteristics of different tasks,
            such as the word-aware, structure-aware and semantic-aware pre-training tasks. Each task is assigned a
            `task_type_id` in the range `[0, config.task_type_vocab_size-1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
            pair (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ErnieForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
        >>> model = ErnieForPreTraining.from_pretrained("nghuyong/ernie-1.0-base-zh")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```
        """
        backbone_out = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # The first two backbone outputs feed the pre-training heads.
        token_states, pooled_state = backbone_out[:2]
        mlm_logits, nsp_logits = self.cls(token_states, pooled_state)

        # The loss is only defined when both kinds of targets are supplied.
        total_loss = None
        have_both_targets = labels is not None and next_sentence_label is not None
        if have_both_targets:
            criterion = CrossEntropyLoss()
            mlm_term = criterion(mlm_logits.view(-1, self.config.vocab_size), labels.view(-1))
            nsp_term = criterion(nsp_logits.view(-1, 2), next_sentence_label.view(-1))
            total_loss = mlm_term + nsp_term

        return ErnieForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=mlm_logits,
            seq_relationship_logits=nsp_logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
|
|
class ErnieForCausalLM(BertLMHeadModel):
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        task_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        past_key_values: list[torch.Tensor] | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
        r"""
        task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Task type ids select a task-specific embedding that represents the characteristics of different tasks,
            such as the word-aware, structure-aware and semantic-aware pre-training tasks. Each task is assigned a
            `task_type_id` in the range `[0, config.task_type_vocab_size-1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        has_labels = labels is not None
        if has_labels:
            # Caching is switched off whenever labels are provided.
            use_cache = False

        backbone_out: BaseModelOutputWithPoolingAndCrossAttentions = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            return_dict=True,
            **kwargs,
        )

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        if isinstance(logits_to_keep, int):
            # An int keeps that many trailing positions; 0 yields `slice(0, None)`, i.e. all of them.
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.cls(backbone_out.last_hidden_state[:, kept_positions, :])

        loss = None
        if has_labels:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
            cross_attentions=backbone_out.cross_attentions,
        )
|
|
|
|
|
|
class ErnieForMaskedLM(BertForMaskedLM):
    # Keys are the tied parameters, values the parameters they are tied to.
    _tied_weights_keys = {
        "cls.predictions.decoder.bias": "cls.predictions.bias",
        "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight",
    }

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        task_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | MaskedLMOutput:
        r"""
        task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Task type ids select a task-specific embedding that represents the characteristics of different tasks,
            such as the word-aware, structure-aware and semantic-aware pre-training tasks. Each task is assigned a
            `task_type_id` in the range `[0, config.task_type_vocab_size-1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        backbone_out = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=True,
            **kwargs,
        )

        vocab_logits = self.cls(backbone_out[0])

        masked_lm_loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()  # -100 index = padding token
            flat_logits = vocab_logits.view(-1, self.config.vocab_size)
            masked_lm_loss = criterion(flat_logits, labels.view(-1))

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=vocab_logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
|
|
class ErnieForNextSentencePrediction(BertForNextSentencePrediction):
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        task_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | NextSentencePredictorOutput:
        r"""
        task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Task type ids select a task-specific embedding that represents the characteristics of different tasks,
            such as the word-aware, structure-aware and semantic-aware pre-training tasks. Each task is assigned a
            `task_type_id` in the range `[0, config.task_type_vocab_size-1]`.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ErnieForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
        >>> model = ErnieForNextSentencePrediction.from_pretrained("nghuyong/ernie-1.0-base-zh")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```
        """
        backbone_out = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        # The classification head consumes the second backbone output.
        nsp_logits = self.cls(backbone_out[1])

        next_sentence_loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            next_sentence_loss = criterion(nsp_logits.view(-1, 2), labels.view(-1))

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=nsp_logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
|
|
class ErnieForSequenceClassification(BertForSequenceClassification):
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        task_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
        r"""
        task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Task type ids select a task-specific embedding that represents the characteristics of different tasks,
            such as the word-aware, structure-aware and semantic-aware pre-training tasks. Each task is assigned a
            `task_type_id` in the range `[0, config.task_type_vocab_size-1]`.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        backbone_out = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        pooled = self.dropout(backbone_out[1])
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Infer the problem type from the label count and dtype, and store it on the config.
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
|
|
class ErnieForMultipleChoice(BertForMultipleChoice):
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        task_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | MultipleChoiceModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        task_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Task type embedding is a special embedding to represent the characteristic of different tasks, such as
            word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
            assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
            config.task_type_vocab_size-1]`
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice dimension into the batch dimension so the backbone sees ordinary 2-D/3-D inputs.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        # task_type_ids carries the same (batch, choices, seq) layout as the other id tensors and must be
        # flattened too; otherwise its embedding cannot be added to the flattened token embeddings.
        # An already-flattened (batch * choices, seq) tensor passes through unchanged.
        task_type_ids = task_type_ids.view(-1, task_type_ids.size(-1)) if task_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # One score per flattened (example, choice) row; regroup so each row holds one example's choices.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
class ErnieForTokenClassification(BertForTokenClassification):
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        task_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
        r"""
        task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Task type ids select a task-specific embedding that represents the characteristics of different tasks,
            such as the word-aware, structure-aware and semantic-aware pre-training tasks. Each task is assigned a
            `task_type_id` in the range `[0, config.task_type_vocab_size-1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        backbone_out = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        token_states = self.dropout(backbone_out[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            # Flatten to one row per token so each token is classified independently.
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
|
|
class ErnieForQuestionAnswering(BertForQuestionAnswering):
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        task_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        start_positions: torch.Tensor | None = None,
        end_positions: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
        r"""
        task_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Task type ids select a task-specific embedding that represents the characteristics of different tasks,
            such as the word-aware, structure-aware and semantic-aware pre-training tasks. Each task is assigned a
            `task_type_id` in the range `[0, config.task_type_vocab_size-1]`.
        """
        backbone_out = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        span_logits = self.qa_outputs(backbone_out[0])
        # Split the last dimension into single-channel start and end scores, then drop that dimension.
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        have_span_targets = start_positions is not None and end_positions is not None
        if have_span_targets:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            criterion = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = criterion(start_logits, start_positions)
            end_loss = criterion(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
|
|
|
|
|
|
# Public names exported by this module (what `from <module> import *` exposes).
__all__ = [
    "ErnieForCausalLM",
    "ErnieForMaskedLM",
    "ErnieForMultipleChoice",
    "ErnieForNextSentencePrediction",
    "ErnieForPreTraining",
    "ErnieForQuestionAnswering",
    "ErnieForSequenceClassification",
    "ErnieForTokenClassification",
    "ErnieModel",
    "ErniePreTrainedModel",
]
|