You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
125 lines
3.6 KiB
125 lines
3.6 KiB
from __future__ import annotations
|
|
|
|
import csv
|
|
import importlib
|
|
import logging
|
|
from contextlib import contextmanager
|
|
|
|
|
|
def fullname(o) -> str:
    """
    Build the qualified name (``module_name.ClassName``) of an object's class.
    The result is used to locate the right class again when loading from JSON files.

    Args:
        o: The object whose class name should be reported.

    Returns:
        str: ``module.ClassName``, or just ``ClassName`` when the class has no
            module or lives in the builtins module.

    Example:
        >>> from sentence_transformers.losses import MultipleNegativesRankingLoss
        >>> from sentence_transformers import SentenceTransformer
        >>> from sentence_transformers.util import fullname
        >>> model = SentenceTransformer('all-MiniLM-L6-v2')
        >>> loss = MultipleNegativesRankingLoss(model)
        >>> fullname(loss)
        'sentence_transformers.losses.MultipleNegativesRankingLoss.MultipleNegativesRankingLoss'
    """
    cls = o.__class__
    module_name = cls.__module__
    # Module name under which the builtin types are registered.
    builtins_name = str.__class__.__module__

    if module_name is not None and module_name != builtins_name:
        return ".".join((module_name, cls.__name__))

    # Builtin (or module-less) classes are reported without a prefix.
    return cls.__name__
|
|
|
|
|
|
def import_from_string(dotted_path: str) -> type:
    """
    Import a dotted module path and return the attribute/class designated by the
    last name in the path. Raise ImportError if the import failed.

    The full path is first tried as a module of its own, so that for
    ``package.Name`` the attribute ``Name`` is read from the module
    ``package.Name`` when such a module can be imported. Otherwise ``Name`` is
    read from the module ``package``.

    Args:
        dotted_path (str): The dotted module path.

    Returns:
        Any: The attribute/class designated by the last name in the path.

    Raises:
        ImportError: If the path is malformed (no dot, leading dot, or trailing
            dot), if the module cannot be found, or if the module does not
            define the requested attribute.

    Example:
        >>> import_from_string('sentence_transformers.losses.MultipleNegativesRankingLoss')
        <class 'sentence_transformers.losses.MultipleNegativesRankingLoss.MultipleNegativesRankingLoss'>
    """
    try:
        module_path, class_name = dotted_path.rsplit(".", 1)
    except ValueError as err:
        msg = f"{dotted_path} doesn't look like a module path"
        raise ImportError(msg) from err

    # A leading dot would be treated as a relative import by importlib and
    # surface as TypeError/ValueError; a trailing dot leaves no attribute name.
    # Report both as the documented ImportError instead.
    if dotted_path.startswith(".") or not class_name:
        msg = f"{dotted_path} doesn't look like a module path"
        raise ImportError(msg)

    try:
        module = importlib.import_module(dotted_path)
    except Exception:
        # Deliberately broad: whatever goes wrong importing the full path as a
        # module, fall back to importing the parent module. Errors from this
        # second import propagate to the caller.
        module = importlib.import_module(module_path)

    try:
        return getattr(module, class_name)
    except AttributeError as err:
        msg = f'Module "{module_path}" does not define a "{class_name}" attribute/class'
        raise ImportError(msg) from err
|
|
|
|
|
|
@contextmanager
def disable_datasets_caching():
    """
    Context manager that switches off caching in the datasets library for the
    duration of the ``with`` body.

    Caching is only re-enabled on exit if it was enabled on entry; if it was
    already disabled, the caching state is not touched.
    """
    from datasets import disable_caching, enable_caching, is_caching_enabled

    if not is_caching_enabled():
        # Caching is already off: nothing to toggle and nothing to restore.
        yield
        return

    try:
        disable_caching()
        yield
    finally:
        # Restore the state found on entry, even if the body raised.
        enable_caching()
|
|
|
|
|
|
@contextmanager
def disable_logging(highest_level=logging.CRITICAL):
    """
    Context manager that silences logging while the ``with`` body runs.

    The process-wide ``logging.disable`` threshold is set to ``highest_level``
    on entry and restored to its previous value on exit, even if the body raises.

    Args:
        highest_level: the threshold passed to ``logging.disable``; messages of
            this severity and below are suppressed. Defaults to ``logging.CRITICAL``.
    """
    manager = logging.root.manager
    # Remember the current threshold so nested or pre-existing settings survive.
    saved_threshold = manager.disable
    logging.disable(highest_level)

    try:
        yield
    finally:
        logging.disable(saved_threshold)
|
|
|
|
|
|
def append_to_last_row(csv_path, additional_data):
    """
    Extend the final row of a CSV file with extra cells, rewriting the file in place.

    Args:
        csv_path: Path of the CSV file to modify (read and written as UTF-8).
        additional_data: Values appended as extra cells to the last row.

    Returns:
        bool: True if the last row was extended and the file rewritten; False if
            the file holds at most one row (e.g. only a header), in which case
            the file is left untouched.
    """
    with open(csv_path, newline="", encoding="utf-8") as source:
        table = list(csv.reader(source))

    # The first row is treated as the header; without a second row there is
    # no data row to extend.
    if len(table) <= 1:
        return False

    table[-1].extend(additional_data)

    # The whole file is written back, including the unchanged rows.
    with open(csv_path, "w", newline="", encoding="utf-8") as target:
        csv.writer(target).writerows(table)
    return True
|