You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
179 lines
9.0 KiB
179 lines
9.0 KiB
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
import torch
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class SentenceTransformerDataCollator:
    """Collator for a SentenceTransformers model.

    This encodes the text columns to {column}_input_ids and {column}_attention_mask columns.
    This works with the two text dataset that is used as the example in the training overview:
    https://www.sbert.net/docs/sentence_transformer/training_overview.html

    It is important that the columns are in the expected order. For example, if your dataset has columns
    "answer", "question" in that order, then the MultipleNegativesRankingLoss will consider
    "answer" as the anchor and "question" as the positive, and it will (unexpectedly) optimize for
    "given the answer, what is the question?".
    """

    # Callable that tokenizes a list of strings; must accept a ``task`` keyword argument and
    # return a mapping of tensor-like values (e.g. "input_ids", "attention_mask").
    tokenize_fn: Callable
    # Column names that are treated as labels rather than as text columns to tokenize.
    valid_label_columns: list[str] = field(default_factory=lambda: ["label", "labels", "score", "scores"])
    # Maps column names to "task types" (optionally nested per dataset name); forwarded to tokenize_fn.
    router_mapping: dict[str, str] | dict[str, dict[str, str]] | None = field(default_factory=dict, repr=False)
    # Maps column names to prompt strings (optionally nested per dataset name) prepended to column values.
    prompts: dict[str, str] | dict[str, dict[str, str]] | None = field(default_factory=dict, repr=False)
    # When True, also emit {column}_prompt_length tensors so pooling can exclude prompt tokens.
    include_prompt_lengths: bool = field(default=False, repr=False)
    # Special token ids of the tokenizer; a trailing special token is not counted in prompt lengths.
    all_special_ids: set[int] = field(default_factory=set, repr=False)

    # Lazy cache of (prompt, task) -> token length, filled by _get_prompt_length.
    # NOTE(fix): was annotated dict[str, int], but the keys are (prompt, task) tuples.
    _prompt_length_mapping: dict[tuple[str, str | None], int] = field(default_factory=dict, init=False, repr=False)
    # Column-name tuples we already warned about, to avoid spamming the user.
    # NOTE(fix): was annotated set[tuple[str]] (a 1-tuple); the tuples have arbitrary length.
    _warned_columns: set[tuple[str, ...]] = field(default_factory=set, init=False, repr=False)

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        """Collate a list of per-row feature dicts into a batch dict of tensors.

        Text columns become ``{column}_{key}`` entries (one per key returned by
        ``tokenize_fn``); a recognized label column becomes ``batch["label"]``; a
        ``dataset_name`` column is passed through as a plain string.
        """
        column_names = list(features[0].keys())

        # We should always be able to return a loss, label or not:
        batch = {}

        if "dataset_name" in column_names:
            column_names.remove("dataset_name")
            batch["dataset_name"] = features[0]["dataset_name"]

        if tuple(column_names) not in self._warned_columns:
            self.maybe_warn_about_column_order(column_names)

        # Extract the label column if it exists
        for label_column in self.valid_label_columns:
            if label_column in column_names:
                batch["label"] = torch.tensor([row[label_column] for row in features])
                column_names.remove(label_column)
                break

        # Fall back to an empty mapping when router_mapping is None; the type hint allows
        # None, and `.get` below would otherwise raise AttributeError.
        router_mapping = self.router_mapping or {}
        # If the router_mapping is a nested dict, then the outer keys are dataset names and the
        # inner keys column names; grab the inner mapping for this specific dataset if it exists.
        if (
            router_mapping
            and isinstance(router_mapping, dict)
            and isinstance(next(iter(router_mapping.values())), dict)
        ):
            if "dataset_name" in batch and batch["dataset_name"] in router_mapping:
                # Use the mapping for the specific dataset
                router_mapping = router_mapping[batch["dataset_name"]]
            else:
                router_mapping = {}

        prompts = self.prompts
        if prompts and isinstance(prompts, dict):
            # If the prompts are a mapping, we should check if the outer keys are dataset names.
            is_multi_dataset = "dataset_name" in batch
            if is_multi_dataset and batch["dataset_name"] in prompts:
                # Use the prompts for the specific dataset
                prompts = prompts[batch["dataset_name"]]
            elif isinstance(next(iter(prompts.values())), dict):
                # If the prompts are a nested dictionary, but we are not in a multi-dataset setting,
                # we should raise an error. If we are in a multi-dataset setting, but this dataset
                # does not have prompts, we use an empty dictionary to denote no prompt.
                if not is_multi_dataset:
                    raise ValueError(
                        "The prompts provided to the trainer are a nested dictionary. In this setting, the first "
                        "level of the dictionary should map to dataset names and the second level to column names. "
                        "However, as the provided dataset is a not a DatasetDict, no dataset names can be inferred. "
                        f"The keys to the provided prompts dictionary are {list(prompts.keys())!r}"
                    )
                else:
                    prompts = {}

        for column_name in column_names:
            # Users can specify a router_mapping via the training arguments, which maps column names to "task types",
            # useful for the Router module (among others). This has to be provided to the tokenization function.
            task = router_mapping.get(column_name, None)

            # Get the string prompt for the column, if it exists.
            prompt = None
            if isinstance(prompts, str):
                prompt = prompts
            elif isinstance(prompts, dict) and column_name in prompts:
                prompt = prompts[column_name]

            # If a prompt is provided, we prepend it to the column values. Some Pooling setups require removing the
            # prompt tokens from the pooled embeddings, so we also store the prompt length which can be used for that.
            if prompt:
                if self.include_prompt_lengths:
                    prompt_length = self._get_prompt_length(prompt, task=task)
                    if prompt_length is not None:
                        batch[f"{column_name}_prompt_length"] = torch.tensor(
                            [prompt_length] * len(features), dtype=torch.int
                        )
                inputs = [prompt + row[column_name] for row in features]
            else:
                inputs = [row[column_name] for row in features]

            tokenized = self.tokenize_fn(inputs, task=task)
            for key, value in tokenized.items():
                batch[f"{column_name}_{key}"] = value

        return batch

    def _get_prompt_length(self, prompt: str, task: str | None = None) -> int | None:
        """Return the token length of ``prompt`` (excluding a trailing special token).

        Returns None when the tokenizer output has no "input_ids", in which case the
        length cannot be determined. Results are cached per (prompt, task) pair.
        """
        if (prompt, task) in self._prompt_length_mapping:
            return self._prompt_length_mapping[(prompt, task)]

        tokenized_prompt = self.tokenize_fn([prompt], task=task)
        if "input_ids" not in tokenized_prompt:
            # If the tokenizer does not return input_ids, we cannot determine the prompt length.
            # This can happen with some tokenizers that do not use input_ids.
            return None

        prompt_length = tokenized_prompt["input_ids"].shape[-1]
        # If the tokenizer adds a special EOS token, we do not count it as part of the prompt length.
        # This is to ensure that the prompt length does not include the EOS token.
        last_token = tokenized_prompt["input_ids"][..., -1].item()
        if last_token in self.all_special_ids:
            prompt_length -= 1

        self._prompt_length_mapping[(prompt, task)] = prompt_length
        return prompt_length

    def maybe_warn_about_column_order(self, column_names: list[str]) -> None:
        """Warn the user if the columns are likely not in the expected order."""
        # A mapping from common column names to the expected index in the dataset
        column_name_to_expected_idx = {
            "anchor": 0,
            "positive": 1,
            "negative": 2,
            "question": 0,
            "answer": 1,
            "query": 0,
            "response": 1,
            "hypothesis": 0,
            "entailment": 1,
            "contradiction": 2,
        }
        for column_name, expected_idx in column_name_to_expected_idx.items():
            if column_name in column_names and column_names.index(column_name) != expected_idx:
                # Every key in the mapping above belongs to exactly one of these groups,
                # so proposed_fix_columns is always assigned before use.
                if column_name in ("anchor", "positive", "negative"):
                    proposed_fix_columns = ["anchor", "positive", "negative"]
                elif column_name in ("question", "answer"):
                    proposed_fix_columns = ["question", "answer"]
                elif column_name in ("query", "response"):
                    proposed_fix_columns = ["query", "response"]
                elif column_name in ("hypothesis", "entailment", "contradiction"):
                    proposed_fix_columns = ["hypothesis", "entailment", "contradiction"]

                logger.warning(
                    f"Column {column_name!r} is at index {column_names.index(column_name)}, whereas "
                    f"a column with this name is usually expected at index {expected_idx}. Note that the column "
                    "order can be important for some losses, e.g. MultipleNegativesRankingLoss will always "
                    "consider the first column as the anchor and the second as the positive, regardless of "
                    "the dataset column names. Consider renaming the columns to match the expected order, e.g.:\n"
                    f"dataset = dataset.select_columns({proposed_fix_columns})"
                )
                # We only need to warn once per list of column names to prevent spamming the user
                break

        self._warned_columns.add(tuple(column_names))