You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
71 lines
2.6 KiB
71 lines
2.6 KiB
from __future__ import annotations
|
|
|
|
import functools
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)


def cross_encoder_init_args_decorator(func):
    """Translate deprecated CrossEncoder initializer keyword arguments to their new names.

    Judging by the warning text, this wraps the ``CrossEncoder`` initializer. Only
    keyword arguments are inspected; positional arguments pass through untouched.

    * Each old keyword in the rename mapping below is moved to its new name and a
      deprecation warning is logged. If the caller also passed the new name
      explicitly, the new name wins and the old value is discarded.
    * ``classifier_dropout`` is removed and merged into ``config_kwargs`` as
      ``{"classifier_dropout": value}``. This merge runs after the renames, so a
      deprecated ``config_args`` is honoured too.

    Args:
        func: The callable to wrap. Its first positional parameter is ``self``.

    Returns:
        A wrapper around ``func`` whose metadata is copied via ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        kwargs_renamed_mapping = {
            "model_name": "model_name_or_path",
            "automodel_args": "model_kwargs",
            "tokenizer_args": "tokenizer_kwargs",
            "config_args": "config_kwargs",
            "cache_dir": "cache_folder",
            "default_activation_function": "activation_fn",
        }
        for old_name, new_name in kwargs_renamed_mapping.items():
            if old_name in kwargs:
                kwarg_value = kwargs.pop(old_name)
                logger.warning(
                    f"The CrossEncoder `{old_name}` argument was renamed and is now deprecated, please use `{new_name}` instead."
                )
                # An explicitly passed new-style argument takes precedence.
                if new_name not in kwargs:
                    kwargs[new_name] = kwarg_value

        if "classifier_dropout" in kwargs:
            classifier_dropout = kwargs.pop("classifier_dropout")
            logger.warning(
                f"The CrossEncoder `classifier_dropout` argument is deprecated. Please use `config_kwargs={{'classifier_dropout': {classifier_dropout}}}` instead."
            )
            # Build a new dict instead of writing into the caller's ``config_kwargs``.
            # The caller may reuse that dict, and an explicit ``config_kwargs=None``
            # must not fail with a TypeError on item assignment.
            config_kwargs = kwargs.get("config_kwargs") or {}
            kwargs["config_kwargs"] = {**config_kwargs, "classifier_dropout": classifier_dropout}

        return func(self, *args, **kwargs)

    return wrapper
|
|
|
|
|
|
def cross_encoder_predict_rank_args_decorator(func):
    """Adapt deprecated keyword arguments before calling the wrapped method.

    Going by its name and warning text, this is meant for the CrossEncoder
    ``predict``/``rank`` methods. Only keyword arguments are examined.

    * ``activation_fct`` is renamed to ``activation_fn`` and a deprecation warning
      is logged. If ``activation_fn`` was also passed, that value is kept and the
      old one is dropped.
    * ``num_workers`` is removed entirely, and a warning says it has no effect.

    Args:
        func: The callable to wrap. Its first positional parameter is ``self``.

    Returns:
        A wrapper around ``func`` whose metadata is copied via ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Deprecated keyword -> replacement keyword.
        renames = {"activation_fct": "activation_fn"}
        for legacy, replacement in renames.items():
            if legacy not in kwargs:
                continue
            legacy_value = kwargs.pop(legacy)
            logger.warning(
                f"The CrossEncoder.predict `{legacy}` argument was renamed and is now deprecated, please use `{replacement}` instead."
            )
            # Only fills the slot when the caller did not pass the replacement.
            kwargs.setdefault(replacement, legacy_value)

        # Keywords still accepted for backwards compatibility but silently discarded.
        for ignored in ("num_workers",):
            if ignored not in kwargs:
                continue
            del kwargs[ignored]
            logger.warning(
                f"The CrossEncoder.predict `{ignored}` argument is deprecated and has no effect. It will be removed in a future version."
            )

        return func(self, *args, **kwargs)

    return wrapper
|