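# NOTE: the following lines are residue from the repository web page, not part of the module.
# You cannot select more than 25 topics. Topics must start with a letter or number,
# can include dashes ('-'), and can be up to 35 characters long.
#
# 71 lines
# 2.6 KiB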

from __future__ import annotations
import functools
import logging
logger = logging.getLogger(__name__)
def cross_encoder_init_args_decorator(func):
    """Decorate a method so that deprecated CrossEncoder keyword arguments still work.

    Judging by its name and log messages this is meant for the ``CrossEncoder``
    constructor; the wrapper itself only requires that ``func`` takes
    ``(self, *args, **kwargs)``.

    Keyword arguments are rewritten before ``func`` is called:

    * Each deprecated name in the rename table is popped, a warning is logged,
      and its value is stored under the new name unless the caller already
      supplied the new name (in which case the new-style value wins).
    * ``classifier_dropout`` is popped, a warning is logged, and the value is
      stored as ``config_kwargs["classifier_dropout"]``, overriding any value
      already present under that key.

    Positional arguments are forwarded untouched.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapping function, carrying ``func``'s metadata via ``functools.wraps``.
    """
    # Deprecated keyword -> replacement keyword. Built once per decorated
    # function instead of on every call.
    renamed_kwargs = {
        "model_name": "model_name_or_path",
        "automodel_args": "model_kwargs",
        "tokenizer_args": "tokenizer_kwargs",
        "config_args": "config_kwargs",
        "cache_dir": "cache_folder",
        "default_activation_function": "activation_fn",
    }

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        for old_name, new_name in renamed_kwargs.items():
            if old_name not in kwargs:
                continue
            kwarg_value = kwargs.pop(old_name)
            logger.warning(
                f"The CrossEncoder `{old_name}` argument was renamed and is now deprecated, please use `{new_name}` instead."
            )
            # An explicitly passed new-style argument takes precedence over the alias.
            kwargs.setdefault(new_name, kwarg_value)

        # Handled after the renames so that a deprecated `config_args` has
        # already been moved to `config_kwargs` by the time we merge into it.
        if "classifier_dropout" in kwargs:
            classifier_dropout = kwargs.pop("classifier_dropout")
            logger.warning(
                f"The CrossEncoder `classifier_dropout` argument is deprecated. Please use `config_kwargs={{'classifier_dropout': {classifier_dropout}}}` instead."
            )
            # Copy instead of writing into the caller's dict, and tolerate an
            # explicit `config_kwargs=None` (which previously raised TypeError).
            config_kwargs = dict(kwargs.get("config_kwargs") or {})
            config_kwargs["classifier_dropout"] = classifier_dropout
            kwargs["config_kwargs"] = config_kwargs

        return func(self, *args, **kwargs)

    return wrapper
def cross_encoder_predict_rank_args_decorator(func):
    """Decorate a method so that deprecated predict/rank keyword arguments still work.

    Judging by its name this is meant for both ``CrossEncoder.predict`` and
    ``CrossEncoder.rank``; the wrapper itself only requires that ``func`` takes
    ``(self, *args, **kwargs)``.

    Keyword arguments are rewritten before ``func`` is called:

    * ``activation_fct`` is popped, a warning is logged, and its value is
      stored as ``activation_fn`` unless the caller already supplied
      ``activation_fn`` (in which case the new-style value wins).
    * ``num_workers`` is popped and discarded with a warning.

    Warnings name the wrapped callable (``func.__name__``) so that a call to
    ``rank`` is not reported as ``predict``. Positional arguments are forwarded
    untouched.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapping function, carrying ``func``'s metadata via ``functools.wraps``.
    """
    # Deprecated keyword -> replacement keyword.
    renamed_kwargs = {
        "activation_fct": "activation_fn",
    }
    # Keywords that are still accepted but silently have no effect.
    ignored_kwargs = ("num_workers",)
    # Fall back to the historical "predict" wording for callables without a name.
    method_name = getattr(func, "__name__", "predict")

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        for old_name, new_name in renamed_kwargs.items():
            if old_name not in kwargs:
                continue
            kwarg_value = kwargs.pop(old_name)
            logger.warning(
                f"The CrossEncoder.{method_name} `{old_name}` argument was renamed and is now deprecated, please use `{new_name}` instead."
            )
            # An explicitly passed new-style argument takes precedence over the alias.
            kwargs.setdefault(new_name, kwarg_value)

        for deprecated_arg in ignored_kwargs:
            if deprecated_arg in kwargs:
                kwargs.pop(deprecated_arg)
                logger.warning(
                    f"The CrossEncoder.{method_name} `{deprecated_arg}` argument is deprecated and has no effect. It will be removed in a future version."
                )

        return func(self, *args, **kwargs)

    return wrapper