You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
227 lines
7.6 KiB
227 lines
7.6 KiB
from __future__ import annotations
|
|
|
|
from typing import Any, overload
|
|
|
|
import numpy as np
|
|
import torch
|
|
from scipy.sparse import coo_matrix
|
|
from torch import Tensor, device
|
|
|
|
|
|
def _convert_to_tensor(a: list | np.ndarray | Tensor) -> Tensor:
    """
    Converts the input `a` to a PyTorch tensor if it is not already a tensor.

    A non-empty list made up entirely of sparse tensors is coalesced, cast to
    float32 and stacked. Any other list (including an empty one) and any
    non-tensor input is passed to `torch.tensor`. A sparse tensor input is cast
    to float32; a dense tensor input is returned unchanged.

    Args:
        a (Union[list, np.ndarray, Tensor]): The input array or tensor.

    Returns:
        Tensor: The converted tensor.
    """
    if isinstance(a, list):
        # `all()` is vacuously True for an empty list and `torch.stack([])` raises,
        # so the sparse path is only taken when the list has at least one element.
        if a and all(isinstance(x, Tensor) and x.is_sparse for x in a):
            # Stack the coalesced float32 sparse tensors
            return torch.stack([x.coalesce().to(dtype=torch.float32) for x in a])
        a = torch.tensor(a)
    elif not isinstance(a, Tensor):
        a = torch.tensor(a)

    if a.is_sparse:
        return a.to(dtype=torch.float32)
    return a
|
|
|
|
|
|
def _convert_to_batch(a: Tensor) -> Tensor:
    """
    Ensures the tensor `a` has a leading batch dimension.

    A 1-dimensional tensor gets a new first dimension of size 1; a tensor of any
    other dimensionality is returned as-is.

    Args:
        a (Tensor): The input tensor.

    Returns:
        Tensor: The tensor with a batch dimension.
    """
    return a.unsqueeze(0) if a.dim() == 1 else a
|
|
|
|
|
|
def _convert_to_batch_tensor(a: list | np.ndarray | Tensor) -> Tensor:
    """
    Turns the input into a tensor and makes sure it carries a batch dimension.

    The input is first converted with `_convert_to_tensor`; a 1-dimensional
    result is then unsqueezed at dimension 0, anything else is returned as-is.

    Args:
        a (Union[list, np.ndarray, Tensor]): The input data to be converted.

    Returns:
        Tensor: The converted tensor with a batch dimension.
    """
    return _convert_to_batch(_convert_to_tensor(a))
|
|
|
|
|
|
def normalize_embeddings(embeddings: Tensor) -> Tensor:
    """
    Normalizes the embeddings matrix, so that each sentence embedding has unit length.

    Dense inputs are L2-normalized along dimension 1 with
    `torch.nn.functional.normalize`. Sparse inputs are coalesced and every stored
    value is divided by the L2 norm of its row (the first index dimension); values
    in rows whose norm is 0 are left untouched. The sparse result keeps the dtype
    of the input values.

    Args:
        embeddings (Tensor): The input embeddings matrix.

    Returns:
        Tensor: The normalized embeddings matrix.
    """
    if not embeddings.is_sparse:
        return torch.nn.functional.normalize(embeddings, p=2, dim=1)

    embeddings = embeddings.coalesce()
    indices, values = embeddings.indices(), embeddings.values()

    # Accumulate the squared values per row. The accumulator has to share the dtype
    # of `values`, because `index_add_` rejects a source of a different dtype
    # (e.g. float64 values added into a default float32 buffer).
    row_norms = torch.zeros(embeddings.size(0), device=embeddings.device, dtype=values.dtype)
    row_norms.index_add_(0, indices[0], values**2)
    # Gather the norms back so each stored value is paired with the norm of its own row
    row_norms = torch.sqrt(row_norms).index_select(0, indices[0])

    # Normalize values where norm > 0 (avoids dividing by zero)
    mask = row_norms > 0
    normalized_values = values.clone()
    normalized_values[mask] /= row_norms[mask]

    return torch.sparse_coo_tensor(indices, normalized_values, embeddings.size())
|
|
|
|
|
|
@overload
def truncate_embeddings(embeddings: np.ndarray, truncate_dim: int | None) -> np.ndarray: ...


@overload
def truncate_embeddings(embeddings: torch.Tensor, truncate_dim: int | None) -> torch.Tensor: ...


def truncate_embeddings(embeddings: np.ndarray | torch.Tensor, truncate_dim: int | None) -> np.ndarray | torch.Tensor:
    """
    Cuts the embeddings down to their first `truncate_dim` entries along the last axis.

    Args:
        embeddings (Union[np.ndarray, torch.Tensor]): Embeddings to truncate.
        truncate_dim (Optional[int]): The dimension to truncate sentence embeddings to. `None` does no truncation.

    Example:
        >>> from sentence_transformers import SentenceTransformer
        >>> from sentence_transformers.util import truncate_embeddings
        >>> model = SentenceTransformer("tomaarsen/mpnet-base-nli-matryoshka")
        >>> embeddings = model.encode(["It's so nice outside!", "Today is a beautiful day.", "He drove to work earlier"])
        >>> embeddings.shape
        (3, 768)
        >>> model.similarity(embeddings, embeddings)
        tensor([[1.0000, 0.8100, 0.1426],
                [0.8100, 1.0000, 0.2121],
                [0.1426, 0.2121, 1.0000]])
        >>> truncated_embeddings = truncate_embeddings(embeddings, 128)
        >>> truncated_embeddings.shape
        (3, 128)
        >>> model.similarity(truncated_embeddings, truncated_embeddings)
        tensor([[1.0000, 0.8092, 0.1987],
                [0.8092, 1.0000, 0.2716],
                [0.1987, 0.2716, 1.0000]])

    Returns:
        Union[np.ndarray, torch.Tensor]: Truncated embeddings.
    """
    # A stop of `None` makes the slice span the whole last axis, i.e. no truncation
    leading = slice(None, truncate_dim)
    return embeddings[..., leading]
|
|
|
|
|
|
def select_max_active_dims(embeddings: np.ndarray | torch.Tensor, max_active_dims: int | None) -> torch.Tensor:
    """
    Zeroes out everything except the top-k entries (by absolute value) of each embedding.

    When `max_active_dims` is `None` the input is returned untouched. A numpy array
    is first copied into a new tensor; a tensor input is modified in place and that
    same tensor object is returned. The result has the same shape as the input, with
    the non-selected positions set to 0.

    Args:
        embeddings (Union[np.ndarray, torch.Tensor]): 2-dimensional embeddings to sparsify by keeping only top_k values.
        max_active_dims (int): Number of values to keep as non-zeros per embedding.

    Returns:
        torch.Tensor: The embeddings with only the top-k values per row kept and all other entries set to 0.
    """
    if max_active_dims is None:
        return embeddings

    if isinstance(embeddings, np.ndarray):
        embeddings = torch.tensor(embeddings)

    # Never ask for more entries than a row actually has
    k = min(max_active_dims, embeddings.shape[1])

    # Column positions of the k largest magnitudes in every row
    top_indices = torch.topk(embeddings.abs(), k=k, dim=1).indices

    # Start from "drop everything" and clear the flag at the top-k positions of each row
    drop_mask = torch.ones_like(embeddings, dtype=torch.bool)
    drop_mask.scatter_(1, top_indices, False)

    # In-place: zero every entry that did not make the top-k
    embeddings[drop_mask] = 0

    return embeddings
|
|
|
|
|
|
def batch_to_device(batch: dict[str, Any], target_device: device) -> dict[str, Any]:
    """
    Moves every tensor value of a batch dictionary to the given device (e.g. "cpu", "cuda", "mps").

    The dictionary is updated in place and returned; values that are not tensors
    are left as they are.

    Args:
        batch (Dict[str, Tensor]): The batch to send to the device.
        target_device (torch.device): The target device (e.g. "cpu", "cuda", "mps").

    Returns:
        Dict[str, Tensor]: The batch with tensors sent to the target device.
    """
    # Only existing keys are reassigned, so the dict can be updated while iterating
    for name, value in batch.items():
        if isinstance(value, Tensor):
            batch[name] = value.to(target_device)
    return batch
|
|
|
|
|
|
def to_scipy_coo(x: Tensor) -> coo_matrix:
    """
    Converts a 2-dimensional torch tensor into a SciPy COO matrix.

    Dense inputs are converted to a sparse tensor first. The tensor is then
    coalesced, and its indices and values are copied to the CPU as numpy arrays
    that become the row/column coordinates and the data of the returned matrix.

    Args:
        x (Tensor): The tensor to convert; it must have exactly two dimensions.

    Returns:
        coo_matrix: A SciPy COO matrix with the same shape as `x`.
    """
    if not x.is_sparse:
        # `coalesce()` is only defined for sparse tensors
        x = x.to_sparse()
    x = x.coalesce()
    indices = x.indices().cpu().numpy()
    # `numpy()` refuses tensors that require grad, so detach the values first
    values = x.values().detach().cpu().numpy()
    return coo_matrix((values, (indices[0], indices[1])), shape=x.shape)
|
|
|
|
|
|
def compute_count_vector(embeddings: torch.Tensor) -> torch.Tensor:
    """
    Compute count vector from embeddings indicating how many samples have non-zero values in each dimension.

    Dense inputs are converted to sparse first. Entries that are stored explicitly
    in a sparse tensor but whose value is 0 are not counted.

    Args:
        embeddings: Tensor (sparse or dense) of shape (batch_size, vocab_size) or (vocab_size,)

    Returns:
        Count vector of shape (vocab_size,) with dtype int32, on the same device as the input.

    Raises:
        ValueError: If `embeddings` is neither 1-dimensional nor 2-dimensional.
    """
    # Validate the rank up front, before any conversion or allocation is attempted
    if embeddings.dim() not in (1, 2):
        raise ValueError(f"Expected 1D or 2D tensor, got {embeddings.dim()}D")

    if not embeddings.is_sparse:
        embeddings = embeddings.to_sparse()

    # Coalesce so that every coordinate appears at most once
    embeddings = embeddings.coalesce()

    count_vector = torch.zeros(embeddings.size(-1), device=embeddings.device, dtype=torch.int32)

    # The last row of the index matrix holds the vocab coordinate in both the 1D and
    # the 2D case. Explicitly stored zeros are dropped so only true non-zeros count.
    stored_nonzero = embeddings.values() != 0
    active_dims = embeddings.indices()[-1][stored_nonzero]

    if active_dims.numel() > 0:
        # Coordinates are unique after coalescing, so the number of occurrences of a
        # vocab index equals the number of samples that are non-zero in that dimension.
        unique_dims, counts = torch.unique(active_dims, return_counts=True)
        count_vector[unique_dims] = counts.int()

    return count_vector
|