You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
130 lines
4.7 KiB
130 lines
4.7 KiB
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from enum import Enum
|
|
|
|
from numpy import ndarray
|
|
from torch import Tensor
|
|
|
|
from .util import (
|
|
cos_sim,
|
|
dot_score,
|
|
euclidean_sim,
|
|
manhattan_sim,
|
|
pairwise_cos_sim,
|
|
pairwise_dot_score,
|
|
pairwise_euclidean_sim,
|
|
pairwise_manhattan_sim,
|
|
)
|
|
|
|
|
|
class SimilarityFunction(Enum):
    """
    Enum class for supported similarity functions. The following functions are supported:

    - ``SimilarityFunction.COSINE`` (``"cosine"``): Cosine similarity
    - ``SimilarityFunction.DOT_PRODUCT`` (``"dot"``): Dot product similarity.
      ``SimilarityFunction.DOT`` is an alias member for the same value.
    - ``SimilarityFunction.EUCLIDEAN`` (``"euclidean"``): Euclidean distance
    - ``SimilarityFunction.MANHATTAN`` (``"manhattan"``): Manhattan distance

    Only the string values listed in quotes above (see :meth:`possible_values`) can be
    used to look a member up by name; ``"dot_product"`` is *not* an accepted value.
    """

    COSINE = "cosine"
    DOT_PRODUCT = "dot"
    DOT = "dot"  # Alias for DOT_PRODUCT
    EUCLIDEAN = "euclidean"
    MANHATTAN = "manhattan"

    @staticmethod
    def _unsupported_error(similarity_function: object) -> ValueError:
        """Build (but do not raise) the error reported for an unsupported similarity function."""
        return ValueError(
            f"The provided function {similarity_function} is not supported. Use one of the supported values: {SimilarityFunction.possible_values()}."
        )

    @staticmethod
    def _resolve(similarity_function: str | SimilarityFunction) -> SimilarityFunction:
        """
        Convert a name or enum member into a ``SimilarityFunction`` member.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            SimilarityFunction: The matching enum member.

        Raises:
            ValueError: If the value does not match any member; the message lists the supported values.
        """
        try:
            return SimilarityFunction(similarity_function)
        except ValueError:
            # The Enum lookup itself raises ValueError for unknown values, with a message that
            # does not list the valid options. Replace it with the descriptive one; the
            # original exception adds no information, so the context is suppressed.
            raise SimilarityFunction._unsupported_error(similarity_function) from None

    @staticmethod
    def to_similarity_fn(
        similarity_function: str | SimilarityFunction,
    ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Converts a similarity function name or enum value to the corresponding similarity function.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The corresponding similarity function.

        Raises:
            ValueError: If the provided function is not supported.

        Example:
            >>> similarity_fn = SimilarityFunction.to_similarity_fn("cosine")
            >>> similarity_scores = similarity_fn(embeddings1, embeddings2)
            >>> similarity_scores
            tensor([[0.3952, 0.0554],
                    [0.0992, 0.1570]])
        """
        member = SimilarityFunction._resolve(similarity_function)

        # Built after validation so that invalid input fails before any lookup happens.
        # DOT is an alias of DOT_PRODUCT, so it resolves to the same key.
        similarity_fns = {
            SimilarityFunction.COSINE: cos_sim,
            SimilarityFunction.DOT_PRODUCT: dot_score,
            SimilarityFunction.MANHATTAN: manhattan_sim,
            SimilarityFunction.EUCLIDEAN: euclidean_sim,
        }
        if member not in similarity_fns:
            # Defensive: only reachable if a member is added without a table entry.
            raise SimilarityFunction._unsupported_error(member)
        return similarity_fns[member]

    @staticmethod
    def to_similarity_pairwise_fn(
        similarity_function: str | SimilarityFunction,
    ) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Converts a similarity function into a pairwise similarity function.

        The pairwise similarity function returns the diagonal vector from the similarity matrix, i.e. it only
        computes the similarity(a[i], b[i]) for each i in the range of the input tensors, rather than
        computing the similarity between all pairs of a and b.

        Args:
            similarity_function (Union[str, SimilarityFunction]): The name or enum value of the similarity function.

        Returns:
            Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]: The pairwise similarity function.

        Raises:
            ValueError: If the provided similarity function is not supported.

        Example:
            >>> pairwise_fn = SimilarityFunction.to_similarity_pairwise_fn("cosine")
            >>> similarity_scores = pairwise_fn(embeddings1, embeddings2)
            >>> similarity_scores
            tensor([0.3952, 0.1570])
        """
        member = SimilarityFunction._resolve(similarity_function)

        # Built after validation so that invalid input fails before any lookup happens.
        # DOT is an alias of DOT_PRODUCT, so it resolves to the same key.
        pairwise_fns = {
            SimilarityFunction.COSINE: pairwise_cos_sim,
            SimilarityFunction.DOT_PRODUCT: pairwise_dot_score,
            SimilarityFunction.MANHATTAN: pairwise_manhattan_sim,
            SimilarityFunction.EUCLIDEAN: pairwise_euclidean_sim,
        }
        if member not in pairwise_fns:
            # Defensive: only reachable if a member is added without a table entry.
            raise SimilarityFunction._unsupported_error(member)
        return pairwise_fns[member]

    @staticmethod
    def possible_values() -> list[str]:
        """
        Returns a list of possible values for the SimilarityFunction enum.

        Iterating an Enum skips alias members, so ``"dot"`` appears once even though
        both ``DOT_PRODUCT`` and ``DOT`` carry it.

        Returns:
            list: A list of possible values for the SimilarityFunction enum.

        Example:
            >>> possible_values = SimilarityFunction.possible_values()
            >>> possible_values
            ['cosine', 'dot', 'euclidean', 'manhattan']
        """
        return [m.value for m in SimilarityFunction]
|