You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
785 B
30 lines
785 B
from __future__ import annotations
|
|
|
|
try:
|
|
from typing import Self
|
|
except ImportError:
|
|
from typing_extensions import Self
|
|
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
from sentence_transformers.models.Module import Module
|
|
|
|
|
|
class Normalize(Module):
    """Normalization layer that rescales sentence embeddings to unit L2 norm.

    A stateless module: it holds no parameters, so saving is a no-op and
    loading simply constructs a new instance.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        """Replace ``features["sentence_embedding"]`` with its L2-normalized form.

        Args:
            features: Feature dict containing a ``"sentence_embedding"`` tensor
                of shape (batch, dim) — normalization is along dim=1.

        Returns:
            The same dict, with the embedding normalized in place.
        """
        embeddings = features["sentence_embedding"]
        features["sentence_embedding"] = F.normalize(embeddings, p=2, dim=1)
        return features

    def save(self, output_path: str, *args, safe_serialization: bool = True, **kwargs) -> None:
        # Nothing to persist: this layer has no weights or configuration.
        return

    @classmethod
    def load(cls, *args, **kwargs) -> Self:
        # Stateless layer: loading is just instantiation.
        return cls()