You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
3.3 KiB
97 lines
3.3 KiB
from __future__ import annotations
|
|
|
|
import csv
|
|
import logging
|
|
import os
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
|
|
from sentence_transformers.util import batch_to_device
|
|
|
|
if TYPE_CHECKING:
|
|
from sentence_transformers.SentenceTransformer import SentenceTransformer
|
|
|
|
# Module-scoped logger; the evaluator below reports the evaluated dataset and the resulting accuracy through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class LabelAccuracyEvaluator(SentenceEvaluator):
    """
    Evaluate a model based on its accuracy on a labeled dataset.

    This requires a model with LossFunction.SOFTMAX: the ``softmax_model`` passed to the
    constructor is called as ``softmax_model(features, labels=None)`` and must return a
    2-tuple whose second element holds one row of class scores per example; the predicted
    label is the arg-max over dimension 1.

    The results are written in a CSV. If a CSV already exists, then values are appended.
    """

    def __init__(self, dataloader: DataLoader, name: str = "", softmax_model=None, write_csv: bool = True):
        """
        Constructs an evaluator for the given dataset.

        Args:
            dataloader (DataLoader): the data for the evaluation. Its ``collate_fn`` is replaced
                with the evaluated model's ``smart_batching_collate`` on every call.
            name (str): name of the dataset. Used in log messages, in the CSV file name
                (``accuracy_evaluation_<name>_results.csv``) and passed to
                ``prefix_name_to_metrics`` for the returned metric keys.
            softmax_model: module producing class scores from the batch features; called as
                ``softmax_model(features, labels=None)``. Must be set before the evaluator is called.
            write_csv (bool): whether to write the accuracy to a CSV file in ``output_path``.
        """
        super().__init__()
        self.dataloader = dataloader
        self.name = name
        self.softmax_model = softmax_model

        self.write_csv = write_csv
        # An empty name yields "accuracy_evaluation_results.csv" (no dangling underscore).
        name_suffix = f"_{name}" if name else ""
        self.csv_file = "accuracy_evaluation" + name_suffix + "_results.csv"
        self.csv_headers = ["epoch", "steps", "accuracy"]
        self.primary_metric = "accuracy"

    def __call__(
        self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
    ) -> dict[str, float]:
        """
        Compute the classification accuracy of ``softmax_model`` over ``self.dataloader``.

        Args:
            model: the SentenceTransformer being evaluated. It is switched to eval mode, and it
                supplies the collate function and the device that batches are moved to.
            output_path: directory for the CSV file; nothing is written when it is None or
                ``write_csv`` is False.
            epoch: epoch number used in the log message and CSV row (-1 if not applicable).
            steps: step number used in the log message and CSV row (-1 if not applicable).

        Returns:
            dict[str, float]: ``{"accuracy": ...}`` with keys passed through ``prefix_name_to_metrics``.
            The accuracy is 0.0 when the dataloader yields no examples.

        Raises:
            ValueError: if no ``softmax_model`` was provided to the constructor.
        """
        if self.softmax_model is None:
            raise ValueError(
                "LabelAccuracyEvaluator requires a `softmax_model` to compute predictions, but none was given."
            )

        model.eval()

        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}:"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps:"
        else:
            out_txt = ":"

        logger.info("Evaluation on the " + self.name + " dataset" + out_txt)

        # Collate with the model's own function so each batch unpacks into (features, label_ids).
        self.dataloader.collate_fn = model.smart_batching_collate

        total = 0
        correct = 0
        # Gradients are never needed during evaluation; enter no_grad once for the whole pass.
        with torch.no_grad():
            for features, label_ids in self.dataloader:
                features = [batch_to_device(feature, model.device) for feature in features]
                label_ids = label_ids.to(model.device)
                _, prediction = self.softmax_model(features, labels=None)

                total += prediction.size(0)
                correct += torch.argmax(prediction, dim=1).eq(label_ids).sum().item()

        if total == 0:
            # Guard against an empty dataloader, which would otherwise divide by zero.
            logger.warning("No examples were found in the " + self.name + " dataset; reporting accuracy 0.0")
            accuracy = 0.0
        else:
            accuracy = correct / total

        logger.info(f"Accuracy: {accuracy:.4f} ({correct}/{total})\n")

        if output_path is not None and self.write_csv:
            os.makedirs(output_path, exist_ok=True)
            csv_path = os.path.join(output_path, self.csv_file)
            write_header = not os.path.isfile(csv_path)
            # Append mode creates the file when missing, so one code path covers both cases;
            # the header is only written for a freshly created file.
            with open(csv_path, newline="", mode="a", encoding="utf-8") as f:
                writer = csv.writer(f)
                if write_header:
                    writer.writerow(self.csv_headers)
                writer.writerow([epoch, steps, accuracy])

        metrics = {"accuracy": accuracy}
        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics, epoch, steps)
        return metrics
|