You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

93 lines
4.6 KiB

4 days ago
from __future__ import annotations
from abc import abstractmethod
from typing import Any
import torch
from tokenizers import Tokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from sentence_transformers.models.Module import Module
class InputModule(Module):
"""
Subclass of :class:`sentence_transformers.models.Module`, base class for all input modules in the Sentence
Transformers library, i.e. modules that are used to process inputs and optionally also perform processing
in the forward pass.
This class provides a common interface for all input modules, including methods for loading and saving the module's
configuration and weights, as well as input processing. It also provides a method for performing the forward pass
of the module.
Three abstract methods are defined in this class, which must be implemented by subclasses:
- :meth:`sentence_transformers.models.Module.forward`: The forward pass of the module.
- :meth:`sentence_transformers.models.Module.save`: Save the module to disk.
- :meth:`sentence_transformers.models.InputModule.tokenize`: Tokenize the input texts and return a dictionary of tokenized features.
Optionally, you may also have to override:
- :meth:`sentence_transformers.models.Module.load`: Load the module from disk.
To assist with loading and saving the module, several utility methods are provided:
- :meth:`sentence_transformers.models.Module.load_config`: Load the module's configuration from a JSON file.
- :meth:`sentence_transformers.models.Module.load_file_path`: Load a file from the module's directory, regardless of whether the module is saved locally or on Hugging Face.
- :meth:`sentence_transformers.models.Module.load_dir_path`: Load a directory from the module's directory, regardless of whether the module is saved locally or on Hugging Face.
- :meth:`sentence_transformers.models.Module.load_torch_weights`: Load the PyTorch weights of the module, regardless of whether the module is saved locally or on Hugging Face.
- :meth:`sentence_transformers.models.Module.save_config`: Save the module's configuration to a JSON file.
- :meth:`sentence_transformers.models.Module.save_torch_weights`: Save the PyTorch weights of the module.
- :meth:`sentence_transformers.models.InputModule.save_tokenizer`: Save the tokenizer used by the module.
- :meth:`sentence_transformers.models.Module.get_config_dict`: Get the module's configuration as a dictionary.
And several class variables are defined to assist with loading and saving the module:
- :attr:`sentence_transformers.models.Module.config_file_name`: The name of the configuration file used to save the module's configuration.
- :attr:`sentence_transformers.models.Module.config_keys`: A list of keys used to save the module's configuration.
- :attr:`sentence_transformers.models.InputModule.save_in_root`: Whether to save the module's configuration in the root directory of the model or in a subdirectory named after the module.
- :attr:`sentence_transformers.models.InputModule.tokenizer`: The tokenizer used by the module.
"""
save_in_root: bool = True
tokenizer: PreTrainedTokenizerBase | Tokenizer
"""
The tokenizer used for tokenizing the input texts. It can be either a
:class:`transformers.PreTrainedTokenizerBase` subclass or a Tokenizer from the
``tokenizers`` library.
"""
@abstractmethod
def tokenize(self, texts: list[str], **kwargs) -> dict[str, torch.Tensor | Any]:
"""
Tokenizes the input texts and returns a dictionary of tokenized features.
Args:
texts (list[str]): List of input texts to tokenize.
**kwargs: Additional keyword arguments for tokenization, e.g. ``task``.
Returns:
dict[str, torch.Tensor | Any]: Dictionary containing tokenized features, e.g.
``{"input_ids": ..., "attention_mask": ...}``
"""
def save_tokenizer(self, output_path: str, **kwargs) -> None:
"""
Saves the tokenizer to the specified output path.
Args:
output_path (str): Path to save the tokenizer.
**kwargs: Additional keyword arguments for saving the tokenizer.
Returns:
None
"""
if not hasattr(self, "tokenizer"):
return
if isinstance(self.tokenizer, PreTrainedTokenizerBase):
self.tokenizer.save_pretrained(output_path, **kwargs)
elif isinstance(self.tokenizer, Tokenizer):
self.tokenizer.save(output_path, **kwargs)
return