from __future__ import annotations

import csv
import logging
import os
from typing import TYPE_CHECKING, Literal

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.readers import InputExample
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.util import (
    pairwise_cos_sim,
    pairwise_dot_score,
    pairwise_euclidean_sim,
    pairwise_manhattan_sim,
)

if TYPE_CHECKING:
    import numpy as np

    from sentence_transformers.SentenceTransformer import SentenceTransformer

logger = logging.getLogger(__name__)

class TripletEvaluator(SentenceEvaluator):
    """
    Evaluate a model based on a triplet: (sentence, positive_example, negative_example).
    Checks if ``similarity(sentence, positive_example) > similarity(sentence, negative_example) + margin``.

    Args:
        anchors (List[str]): Sentences to check similarity to. (e.g. a query)
        positives (List[str]): List of positive sentences
        negatives (List[str]): List of negative sentences
        main_similarity_function (Union[str, SimilarityFunction], optional):
            The similarity function to report as the primary metric. If not specified,
            the primary metric is derived from the evaluated similarity functions
            (cosine similarity, dot product, Euclidean, or Manhattan similarity).
            Defaults to None.
        margin (Union[float, Dict[str, float]], optional): Margins for various similarity metrics.
            If a float is provided, it will be used as the margin for all similarity metrics.
            If a dictionary is provided, the keys should be 'cosine', 'dot', 'manhattan', and 'euclidean'.
            The value specifies the minimum margin by which the negative sample should be further from
            the anchor than the positive sample. Defaults to None.
        name (str): Name for the output. Defaults to "".
        batch_size (int): Batch size used to compute embeddings. Defaults to 16.
        show_progress_bar (bool): If True, shows a progress bar. Defaults to False.
        write_csv (bool): Write results to a CSV file. Defaults to True.
        truncate_dim (int, optional): The dimension to truncate sentence embeddings to.
            `None` uses the model's current truncation dimension. Defaults to None.
        similarity_fn_names (List[str], optional): List of similarity function names to evaluate.
            If not specified, evaluate using the ``model.similarity_fn_name``.
            Defaults to None.

    Example:
        ::

            from sentence_transformers import SentenceTransformer
            from sentence_transformers.evaluation import TripletEvaluator
            from datasets import load_dataset

            # Load a model
            model = SentenceTransformer('all-mpnet-base-v2')

            # Load a dataset with (anchor, positive, negative) triplets
            dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")

            # Initialize the TripletEvaluator using anchors, positives, and negatives
            triplet_evaluator = TripletEvaluator(
                anchors=dataset[:1000]["anchor"],
                positives=dataset[:1000]["positive"],
                negatives=dataset[:1000]["negative"],
                name="all_nli_dev",
            )
            results = triplet_evaluator(model)
            '''
            TripletEvaluator: Evaluating the model on the all_nli_dev dataset:
            Accuracy Cosine Similarity:   95.60%
            '''
            print(triplet_evaluator.primary_metric)
            # => "all_nli_dev_cosine_accuracy"
            print(results[triplet_evaluator.primary_metric])
            # => 0.956
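
            # A margin can also be enforced; this (illustrative) evaluator only counts a
            # triplet as correct if the positive scores at least 0.1 higher (cosine
            # similarity) than the negative:
            strict_evaluator = TripletEvaluator(
                anchors=dataset[:1000]["anchor"],
                positives=dataset[:1000]["positive"],
                negatives=dataset[:1000]["negative"],
                margin={"cosine": 0.1},
                name="all_nli_dev_strict",
            )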
"""
|
|
|
|
    def __init__(
        self,
        anchors: list[str],
        positives: list[str],
        negatives: list[str],
        main_similarity_function: str | SimilarityFunction | None = None,
        margin: float | dict[str, float] | None = None,
        name: str = "",
        batch_size: int = 16,
        show_progress_bar: bool = False,
        write_csv: bool = True,
        truncate_dim: int | None = None,
        similarity_fn_names: list[Literal["cosine", "dot", "euclidean", "manhattan"]] | None = None,
        main_distance_function: str | SimilarityFunction | None = "deprecated",
    ):
        super().__init__()
        self.anchors = anchors
        self.positives = positives
        self.negatives = negatives
        self.name = name
        self.truncate_dim = truncate_dim

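        # The three lists must be aligned: the i-th positive and negative belong to the i-th anchor.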
        assert len(self.anchors) == len(self.positives)
        assert len(self.anchors) == len(self.negatives)

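        # Backwards compatibility: 'main_distance_function' was renamed to
        # 'main_similarity_function'; honor the old argument if only it was given.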
        if main_distance_function != "deprecated" and main_similarity_function is None:
            main_similarity_function = main_distance_function
            logger.warning(
                "The 'main_distance_function' parameter is deprecated. Please use 'main_similarity_function' instead. "
                "'main_distance_function' will be removed in a future release."
            )

        self.main_similarity_function = (
            SimilarityFunction(main_similarity_function) if main_similarity_function else None
        )
        self.similarity_fn_names = similarity_fn_names or []

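        # Normalize `margin` to a dict with one entry per similarity metric;
        # metrics without an explicit margin default to 0.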
        if margin is None:
            self.margin = {"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0}
        elif isinstance(margin, (float, int)):
            self.margin = {"cosine": margin, "dot": margin, "manhattan": margin, "euclidean": margin}
        elif isinstance(margin, dict):
            self.margin = {
                **{"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0},
                **margin,
            }
        else:
            raise ValueError(
                "`margin` should be a float or a dictionary with keys 'cosine', 'dot', 'manhattan', and 'euclidean'"
            )

        self.batch_size = batch_size
        if show_progress_bar is None:
            show_progress_bar = (
                logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
            )
        self.show_progress_bar = show_progress_bar

        self.csv_file: str = "triplet_evaluation" + ("_" + name if name else "") + "_results.csv"
        self.csv_headers = ["epoch", "steps"]
        self.write_csv = write_csv

        self._append_csv_headers(self.similarity_fn_names)

    def _append_csv_headers(self, similarity_fn_names):
        for fn_name in similarity_fn_names:
            self.csv_headers.append(f"accuracy_{fn_name}")

    @classmethod
    def from_input_examples(cls, examples: list[InputExample], **kwargs):
        anchors = []
        positives = []
        negatives = []

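        # Each InputExample is expected to store [anchor, positive, negative] in `texts`.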
        for example in examples:
            anchors.append(example.texts[0])
            positives.append(example.texts[1])
            negatives.append(example.texts[2])
        return cls(anchors, positives, negatives, **kwargs)

    def __call__(
        self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
    ) -> dict[str, float]:
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps"
        else:
            out_txt = ""
        if self.truncate_dim is not None:
            out_txt += f" (truncated to {self.truncate_dim})"

        logger.info(f"TripletEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:")

        embeddings_anchors = self.embed_inputs(model, self.anchors)
        embeddings_positives = self.embed_inputs(model, self.positives)
        embeddings_negatives = self.embed_inputs(model, self.negatives)

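        # If no similarity functions were requested, fall back to the model's own
        # similarity function and extend the CSV headers accordingly.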
        if not self.similarity_fn_names:
            self.similarity_fn_names = [model.similarity_fn_name]
            self._append_csv_headers(self.similarity_fn_names)

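        # Each entry maps a similarity function name to a callable that returns the
        # pairwise (anchor, positive) and (anchor, negative) scores.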
        similarity_functions = {
            "cosine": lambda anchors, positives, negatives: (
                pairwise_cos_sim(anchors, positives),
                pairwise_cos_sim(anchors, negatives),
            ),
            "dot": lambda anchors, positives, negatives: (
                pairwise_dot_score(anchors, positives),
                pairwise_dot_score(anchors, negatives),
            ),
            "manhattan": lambda anchors, positives, negatives: (
                pairwise_manhattan_sim(anchors, positives),
                pairwise_manhattan_sim(anchors, negatives),
            ),
            "euclidean": lambda anchors, positives, negatives: (
                pairwise_euclidean_sim(anchors, positives),
                pairwise_euclidean_sim(anchors, negatives),
            ),
        }

        metrics = {}
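        # A triplet counts as correct when the positive is more similar to the anchor
        # than the negative by more than the configured margin.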
        for fn_name in self.similarity_fn_names:
            if fn_name in similarity_functions:
                positive_scores, negative_scores = similarity_functions[fn_name](
                    embeddings_anchors, embeddings_positives, embeddings_negatives
                )
                accuracy = (positive_scores > negative_scores + self.margin[fn_name]).float().mean().item()
                metrics[f"{fn_name}_accuracy"] = accuracy
                logger.info(f"Accuracy {fn_name.capitalize()} Similarity:\t{accuracy:.2%}")

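        # Optionally persist the results: write the header when the file is first
        # created, then append a row per evaluation.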
        if output_path is not None and self.write_csv:
            os.makedirs(output_path, exist_ok=True)
            csv_path = os.path.join(output_path, self.csv_file)
            if not os.path.isfile(csv_path):
                with open(csv_path, newline="", mode="w", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow(self.csv_headers)
                    writer.writerow([epoch, steps] + list(metrics.values()))

            else:
                with open(csv_path, newline="", mode="a", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow([epoch, steps] + list(metrics.values()))

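        # When several similarity functions are evaluated, also report the best accuracy among them.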
        if len(self.similarity_fn_names) > 1:
            metrics["max_accuracy"] = max(metrics.values())

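        # The primary metric is the user-chosen similarity function if given; otherwise
        # the maximum accuracy (when several functions were evaluated) or the single
        # function's accuracy.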
        if self.main_similarity_function:
            self.primary_metric = {
                SimilarityFunction.COSINE: "cosine_accuracy",
                SimilarityFunction.DOT_PRODUCT: "dot_accuracy",
                SimilarityFunction.EUCLIDEAN: "euclidean_accuracy",
                SimilarityFunction.MANHATTAN: "manhattan_accuracy",
            }.get(self.main_similarity_function)
        else:
            if len(self.similarity_fn_names) > 1:
                self.primary_metric = "max_accuracy"
            else:
                self.primary_metric = f"{self.similarity_fn_names[0]}_accuracy"

        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
        return metrics

    def embed_inputs(
        self,
        model: SentenceTransformer,
        sentences: str | list[str] | np.ndarray,
        **kwargs,
    ) -> np.ndarray:
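        # Encode with the evaluator's batch size and progress-bar settings; passing
        # truncate_dim=None keeps the model's current truncation dimension.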
        return model.encode(
            sentences,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_numpy=True,
            truncate_dim=self.truncate_dim,
            **kwargs,
        )

    def get_config_dict(self):
        config_dict = {}
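        # Only include settings that differ from their defaults.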
        if self.margin != {"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0}:
            config_dict["margin"] = self.margin
        if self.truncate_dim is not None:
            config_dict["truncate_dim"] = self.truncate_dim
        return config_dict