from __future__ import annotations

import logging

import torch
from torch import Tensor, nn

from sentence_transformers.models.Module import Module

logger = logging.getLogger(__name__)


class WordWeights(Module):
    """This model can weight word embeddings, for example, with idf-values."""

    config_keys: list[str] = ["vocab", "word_weights", "unknown_word_weight"]
    def __init__(self, vocab: list[str], word_weights: dict[str, float], unknown_word_weight: float = 1):
        """
        Initializes the WordWeights class.

        Args:
            vocab (List[str]): Vocabulary of the tokenizer.
            word_weights (Dict[str, float]): Mapping of tokens to a float weight value. Word embeddings are multiplied
                by this float value. The keys of word_weights do not have to match the vocab exactly; the mapping may
                contain more or fewer tokens.
            unknown_word_weight (float, optional): Weight for words in vocab that do not appear in the word_weights
                lookup. These can be, for example, rare words in the vocab for which no weight exists. Defaults to 1.
        """
        super().__init__()
        self.vocab = vocab
        self.word_weights = word_weights
        self.unknown_word_weight = unknown_word_weight

        weights = []
        num_unknown_words = 0
        for word in vocab:
            weight = unknown_word_weight
            if word in word_weights:
                weight = word_weights[word]
            elif word.lower() in word_weights:
                weight = word_weights[word.lower()]
            else:
                num_unknown_words += 1
            weights.append(weight)

        logger.info(
            f"{num_unknown_words} of {len(vocab)} words without a weighting value. Set weight to {unknown_word_weight}"
        )

        self.emb_layer = nn.Embedding(len(vocab), 1)
        self.emb_layer.load_state_dict({"weight": torch.FloatTensor(weights).unsqueeze(1)})
    def forward(self, features: dict[str, Tensor]):
        attention_mask = features["attention_mask"]
        token_embeddings = features["token_embeddings"]

        # Compute a weight value for each token
        token_weights_raw = self.emb_layer(features["input_ids"]).squeeze(-1)
        token_weights = token_weights_raw * attention_mask.float()
        token_weights_sum = torch.sum(token_weights, 1)

        # Multiply embedding by token weight value
        token_weights_expanded = token_weights.unsqueeze(-1).expand(token_embeddings.size())
        token_embeddings = token_embeddings * token_weights_expanded

        features.update({"token_embeddings": token_embeddings, "token_weights_sum": token_weights_sum})
        return features
    def save(self, output_path: str, *args, safe_serialization: bool = True, **kwargs) -> None:
        self.save_config(output_path)
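

# --- Illustrative usage sketch (not part of the module above) ---
# A minimal example of how WordWeights is typically placed between a Transformer
# and a Pooling module, so token embeddings are scaled (e.g. by idf values)
# before pooling. The model name and the word_weights values below are
# placeholder assumptions chosen only for illustration.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import Pooling, Transformer

    word_embedding_model = Transformer("bert-base-uncased")
    vocab = list(word_embedding_model.tokenizer.get_vocab().keys())

    # Hypothetical idf-style weights: common words get a low weight, informative words a high one.
    word_weights = {"the": 0.1, "a": 0.1, "and": 0.1, "neural": 2.5, "transformer": 3.0}

    weighting = WordWeights(vocab=vocab, word_weights=word_weights, unknown_word_weight=1.0)
    pooling = Pooling(word_embedding_model.get_word_embedding_dimension())

    model = SentenceTransformer(modules=[word_embedding_model, weighting, pooling])
    embeddings = model.encode(["Weighted pooling emphasizes informative words."])
    print(embeddings.shape)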