You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
311 lines
17 KiB
311 lines
17 KiB
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass, field
|
|
from typing import Union
|
|
|
|
from transformers import TrainingArguments as TransformersTrainingArguments
|
|
from transformers.training_args import ParallelMode
|
|
from transformers.utils import ExplicitEnum
|
|
|
|
from sentence_transformers.sampler import DefaultBatchSampler, MultiDatasetDefaultBatchSampler
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BatchSamplers(ExplicitEnum):
    """
    Stores the acceptable string identifiers for batch samplers.

    The batch sampler is responsible for determining how samples are grouped into batches during training.
    Valid options are:

    - ``BatchSamplers.BATCH_SAMPLER``: **[default]** Uses :class:`~sentence_transformers.sampler.DefaultBatchSampler`, the default
      PyTorch batch sampler.
    - ``BatchSamplers.NO_DUPLICATES``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler`,
      ensuring no duplicate samples in a batch. Recommended for losses that use in-batch negatives, such as:

      - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss`
      - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss`
      - :class:`~sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss`
      - :class:`~sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss`
      - :class:`~sentence_transformers.losses.MegaBatchMarginLoss`
      - :class:`~sentence_transformers.losses.GISTEmbedLoss`
      - :class:`~sentence_transformers.losses.CachedGISTEmbedLoss`
    - ``BatchSamplers.GROUP_BY_LABEL``: Uses :class:`~sentence_transformers.sampler.GroupByLabelBatchSampler`,
      ensuring that each batch has 2+ samples from the same label. Recommended for losses that require multiple
      samples from the same label, such as:

      - :class:`~sentence_transformers.losses.BatchAllTripletLoss`
      - :class:`~sentence_transformers.losses.BatchHardSoftMarginTripletLoss`
      - :class:`~sentence_transformers.losses.BatchHardTripletLoss`
      - :class:`~sentence_transformers.losses.BatchSemiHardTripletLoss`

    If you want to use a custom batch sampler, then you can subclass
    :class:`~sentence_transformers.sampler.DefaultBatchSampler` and pass the class (not an instance) to the
    ``batch_sampler`` argument in :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`
    (or :class:`~sentence_transformers.cross_encoder.training_args.CrossEncoderTrainingArguments`, etc.).
    Alternatively, you can pass a function that accepts ``dataset``, ``batch_size``, ``drop_last``,
    ``valid_label_columns``, ``generator``, and ``seed`` and returns a
    :class:`~sentence_transformers.sampler.DefaultBatchSampler` instance.

    Usage:
        ::

            from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
            from sentence_transformers.training_args import BatchSamplers
            from sentence_transformers.losses import MultipleNegativesRankingLoss
            from datasets import Dataset

            model = SentenceTransformer("microsoft/mpnet-base")
            train_dataset = Dataset.from_dict({
                "anchor": ["It's nice weather outside today.", "He drove to work."],
                "positive": ["It's so sunny.", "He took the car to the office."],
            })
            loss = MultipleNegativesRankingLoss(model)
            args = SentenceTransformerTrainingArguments(
                output_dir="checkpoints",
                batch_sampler=BatchSamplers.NO_DUPLICATES,
            )
            trainer = SentenceTransformerTrainer(
                model=model,
                args=args,
                train_dataset=train_dataset,
                loss=loss,
            )
            trainer.train()
    """

    BATCH_SAMPLER = "batch_sampler"
    NO_DUPLICATES = "no_duplicates"
    GROUP_BY_LABEL = "group_by_label"
class MultiDatasetBatchSamplers(ExplicitEnum):
    """
    Stores the acceptable string identifiers for multi-dataset batch samplers.

    The multi-dataset batch sampler is responsible for determining in what order batches are sampled from multiple
    datasets during training. Valid options are:

    - ``MultiDatasetBatchSamplers.ROUND_ROBIN``: Uses :class:`~sentence_transformers.sampler.RoundRobinBatchSampler`,
      which uses round-robin sampling from each dataset until one is exhausted.
      With this strategy, it's likely that not all samples from each dataset are used, but each dataset is sampled
      from equally.
    - ``MultiDatasetBatchSamplers.PROPORTIONAL``: **[default]** Uses :class:`~sentence_transformers.sampler.ProportionalBatchSampler`,
      which samples from each dataset in proportion to its size.
      With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.

    If you want to use a custom multi-dataset batch sampler, then you can subclass
    :class:`~sentence_transformers.sampler.MultiDatasetDefaultBatchSampler` and pass the class (not an instance) to the
    ``multi_dataset_batch_sampler`` argument in :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`
    (or :class:`~sentence_transformers.cross_encoder.training_args.CrossEncoderTrainingArguments`, etc.). Alternatively,
    you can pass a function that accepts ``dataset`` (a :class:`~torch.utils.data.ConcatDataset`), ``batch_samplers``
    (i.e. a list of batch samplers for each of the datasets in the :class:`~torch.utils.data.ConcatDataset`), ``generator``,
    and ``seed`` and returns a :class:`~sentence_transformers.sampler.MultiDatasetDefaultBatchSampler` instance.

    Usage:
        ::

            from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
            from sentence_transformers.training_args import MultiDatasetBatchSamplers
            from sentence_transformers.losses import CoSENTLoss
            from datasets import Dataset, DatasetDict

            model = SentenceTransformer("microsoft/mpnet-base")
            train_general = Dataset.from_dict({
                "sentence_A": ["It's nice weather outside today.", "He drove to work."],
                "sentence_B": ["It's so sunny.", "He took the car to the bank."],
                "score": [0.9, 0.4],
            })
            train_medical = Dataset.from_dict({
                "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
                "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
                "score": [0.8, 0.6, 0.7],
            })
            train_legal = Dataset.from_dict({
                "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
                "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
                "score": [0.7, 0.8],
            })
            train_dataset = DatasetDict({
                "general": train_general,
                "medical": train_medical,
                "legal": train_legal,
            })

            loss = CoSENTLoss(model)
            args = SentenceTransformerTrainingArguments(
                output_dir="checkpoints",
                multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
            )
            trainer = SentenceTransformerTrainer(
                model=model,
                args=args,
                train_dataset=train_dataset,
                loss=loss,
            )
            trainer.train()
    """

    ROUND_ROBIN = "round_robin"  # Round-robin sampling from each dataset
    PROPORTIONAL = "proportional"  # Sample from each dataset in proportion to its size [default]
@dataclass
class SentenceTransformerTrainingArguments(TransformersTrainingArguments):
    r"""
    SentenceTransformerTrainingArguments extends :class:`~transformers.TrainingArguments` with additional arguments
    specific to Sentence Transformers. See :class:`~transformers.TrainingArguments` for the complete list of
    available arguments.

    Args:
        output_dir (`str`):
            The output directory where the model checkpoints will be written.
        prompts (`Union[Dict[str, Dict[str, str]], Dict[str, str], str]`, *optional*):
            The prompts to use for each column in the training, evaluation and test datasets. Four formats are accepted:

            1. `str`: A single prompt to use for all columns in the datasets, regardless of whether the training/evaluation/test
               datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
            2. `Dict[str, str]`: A dictionary mapping column names to prompts, regardless of whether the training/evaluation/test
               datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`.
            3. `Dict[str, str]`: A dictionary mapping dataset names to prompts. This should only be used if your training/evaluation/test
               datasets are a :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.
            4. `Dict[str, Dict[str, str]]`: A dictionary mapping dataset names to dictionaries mapping column names to
               prompts. This should only be used if your training/evaluation/test datasets are a
               :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`.

        batch_sampler (Union[:class:`~sentence_transformers.training_args.BatchSamplers`, `str`, :class:`~sentence_transformers.sampler.DefaultBatchSampler`, Callable[[...], :class:`~sentence_transformers.sampler.DefaultBatchSampler`]], *optional*):
            The batch sampler to use. See :class:`~sentence_transformers.training_args.BatchSamplers` for valid options.
            Defaults to ``BatchSamplers.BATCH_SAMPLER``.
        multi_dataset_batch_sampler (Union[:class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers`, `str`, :class:`~sentence_transformers.sampler.MultiDatasetDefaultBatchSampler`, Callable[[...], :class:`~sentence_transformers.sampler.MultiDatasetDefaultBatchSampler`]], *optional*):
            The multi-dataset batch sampler to use. See :class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers`
            for valid options. Defaults to ``MultiDatasetBatchSamplers.PROPORTIONAL``.
        router_mapping (`Dict[str, str] | Dict[str, Dict[str, str]]`, *optional*):
            A mapping of dataset column names to Router routes, like "query" or "document". This is used to specify
            which Router submodule to use for each dataset. Two formats are accepted:

            1. `Dict[str, str]`: A mapping of column names to routes.
            2. `Dict[str, Dict[str, str]]`: A mapping of dataset names to a mapping of column names to routes for
               multi-dataset training/evaluation.
        learning_rate_mapping (`Dict[str, float] | None`, *optional*):
            A mapping of parameter name regular expressions to learning rates. This allows you to set different
            learning rates for different parts of the model, e.g., `{'SparseStaticEmbedding\.*': 1e-3}` for the
            SparseStaticEmbedding module. This is useful when you want to fine-tune specific parts of the model
            with different learning rates.
    """

    # Sometimes users will pass in a `str` repr of a dict in the CLI
    # We need to track what fields those can be. Each time a new arg
    # has a dict type, it must be added to this list.
    # Important: These should be typed with Optional[Union[dict,str,...]]
    _VALID_DICT_FIELDS = [
        "accelerator_config",
        "fsdp_config",
        "deepspeed",
        "gradient_checkpointing_kwargs",
        "lr_scheduler_kwargs",
        "prompts",
        "router_mapping",
        "learning_rate_mapping",
    ]

    prompts: Union[str, None, dict[str, str], dict[str, dict[str, str]]] = field(  # noqa: UP007
        default=None,
        metadata={
            "help": "The prompts to use for each column in the datasets. "
            "Either 1) a single string prompt, 2) a mapping of column names to prompts, 3) a mapping of dataset names "
            "to prompts, or 4) a mapping of dataset names to a mapping of column names to prompts."
        },
    )
    batch_sampler: Union[BatchSamplers, str, DefaultBatchSampler, Callable[..., DefaultBatchSampler]] = field(  # noqa: UP007
        default=BatchSamplers.BATCH_SAMPLER, metadata={"help": "The batch sampler to use."}
    )
    multi_dataset_batch_sampler: Union[  # noqa: UP007
        MultiDatasetBatchSamplers, str, MultiDatasetDefaultBatchSampler, Callable[..., MultiDatasetDefaultBatchSampler]
    ] = field(
        default=MultiDatasetBatchSamplers.PROPORTIONAL, metadata={"help": "The multi-dataset batch sampler to use."}
    )
    router_mapping: Union[str, None, dict[str, str], dict[str, dict[str, str]]] = field(  # noqa: UP007
        default_factory=dict,
        metadata={
            "help": 'A mapping of dataset column names to Router routes, like "query" or "document". '
            "Either 1) a mapping of column names to routes or 2) a mapping of dataset names to a mapping "
            "of column names to routes for multi-dataset training/evaluation. "
        },
    )
    learning_rate_mapping: Union[str, None, dict[str, float]] = field(  # noqa: UP007
        default_factory=dict,
        metadata={
            "help": "A mapping of parameter name regular expressions to learning rates. "
            "This allows you to set different learning rates for different parts of the model, e.g., "
            r"{'SparseStaticEmbedding\.*': 1e-3} for the SparseStaticEmbedding module."
        },
    )

    def __post_init__(self):
        """Validate and normalize the Sentence Transformers-specific arguments after the superclass has run."""
        super().__post_init__()

        # Coerce string identifiers into their enum counterparts; class/callable values pass through untouched.
        self.batch_sampler = (
            BatchSamplers(self.batch_sampler) if isinstance(self.batch_sampler, str) else self.batch_sampler
        )
        self.multi_dataset_batch_sampler = (
            MultiDatasetBatchSamplers(self.multi_dataset_batch_sampler)
            if isinstance(self.multi_dataset_batch_sampler, str)
            else self.multi_dataset_batch_sampler
        )

        self.router_mapping = self.router_mapping if self.router_mapping is not None else {}
        if isinstance(self.router_mapping, str):
            # Note that we allow a stringified dictionary for router_mapping, but then it should have been
            # parsed by the superclass's `__post_init__` method already
            raise ValueError(
                "The `router_mapping` argument must be a dictionary mapping dataset column names to Router routes, "
                "like 'query' or 'document'. A stringified dictionary also works."
            )

        self.learning_rate_mapping = self.learning_rate_mapping if self.learning_rate_mapping is not None else {}
        if isinstance(self.learning_rate_mapping, str):
            # Note that we allow a stringified dictionary for learning_rate_mapping, but then it should have been
            # parsed by the superclass's `__post_init__` method already
            raise ValueError(
                "The `learning_rate_mapping` argument must be a dictionary mapping parameter name regular expressions "
                "to learning rates. A stringified dictionary also works."
            )

        # The `compute_loss` method in `SentenceTransformerTrainer` is overridden to only compute the prediction loss,
        # so we set `prediction_loss_only` to `True` here to avoid unnecessary additional computation.
        self.prediction_loss_only = True

        # Disable broadcasting of buffers to avoid `RuntimeError: one of the variables needed for gradient computation
        # has been modified by an inplace operation.` when training with DDP & a BertModel-based model.
        self.ddp_broadcast_buffers = False

        if self.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
            # If output_dir is "unused", then this instance is created to compare training arguments vs the defaults,
            # so we don't have to warn.
            if self.output_dir != "unused":
                logger.warning(
                    "Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. "
                    "See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information."
                )

        elif self.parallel_mode == ParallelMode.DISTRIBUTED and not self.dataloader_drop_last:
            # If output_dir is "unused", then this instance is created to compare training arguments vs the defaults,
            # so we don't have to warn.
            if self.output_dir != "unused":
                logger.warning(
                    "When using DistributedDataParallel (DDP), it is recommended to set `dataloader_drop_last=True` to avoid hanging issues with an uneven last batch. "
                    "Setting `dataloader_drop_last=True`."
                )
            self.dataloader_drop_last = True

    def to_dict(self):
        """
        Serialize this instance to a dictionary, dropping the batch sampler entries when they hold
        non-serializable callables (e.g. a custom sampler class or factory function).
        """
        training_args_dict = super().to_dict()
        if callable(training_args_dict["batch_sampler"]):
            del training_args_dict["batch_sampler"]
        if callable(training_args_dict["multi_dataset_batch_sampler"]):
            del training_args_dict["multi_dataset_batch_sampler"]
        return training_args_dict