from __future__ import annotations

import functools
import logging

logger = logging.getLogger(__name__)


def _rename_deprecated_kwargs(kwargs, renamed_mapping, owner):
    """Move values passed under deprecated keyword names to their new names.

    Every deprecated name found in ``kwargs`` is removed and a warning is
    logged. The value is stored under the new name only when the caller did
    not also pass the new name, so an explicitly given new-style argument
    always wins.

    Args:
        kwargs: Keyword arguments of the wrapped call; modified in place.
        renamed_mapping: Mapping of deprecated name -> replacement name.
        owner: Label used in the warning text, e.g. ``"CrossEncoder"``.
    """
    for old_name, new_name in renamed_mapping.items():
        if old_name not in kwargs:
            continue
        kwarg_value = kwargs.pop(old_name)
        logger.warning(
            f"The {owner} `{old_name}` argument was renamed and is now deprecated, please use `{new_name}` instead."
        )
        if new_name not in kwargs:
            kwargs[new_name] = kwarg_value


def cross_encoder_init_args_decorator(func):
    """Decorate an ``__init__``-style method to accept deprecated keyword names.

    The returned wrapper rewrites the keyword arguments before calling
    ``func``:

    * renamed arguments (``model_name``, ``automodel_args``,
      ``tokenizer_args``, ``config_args``, ``cache_dir``,
      ``default_activation_function``) are forwarded under their new names
      and a warning is logged for each;
    * ``classifier_dropout`` is removed and folded into ``config_kwargs``
      under the ``"classifier_dropout"`` key, overriding any value already
      present there.

    Positional arguments are passed through untouched.

    Args:
        func: The method to wrap; it is called as ``func(self, *args, **kwargs)``.

    Returns:
        The wrapping function, carrying ``func``'s metadata via
        ``functools.wraps``.
    """
    kwargs_renamed_mapping = {
        "model_name": "model_name_or_path",
        "automodel_args": "model_kwargs",
        "tokenizer_args": "tokenizer_kwargs",
        "config_args": "config_kwargs",
        "cache_dir": "cache_folder",
        "default_activation_function": "activation_fn",
    }

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        _rename_deprecated_kwargs(kwargs, kwargs_renamed_mapping, "CrossEncoder")

        if "classifier_dropout" in kwargs:
            classifier_dropout = kwargs.pop("classifier_dropout")
            logger.warning(
                f"The CrossEncoder `classifier_dropout` argument is deprecated. Please use `config_kwargs={{'classifier_dropout': {classifier_dropout}}}` instead."
            )
            # Build a fresh dict instead of writing into the caller's object:
            # this keeps the caller's `config_kwargs` untouched and also copes
            # with an explicit `config_kwargs=None`.
            existing_config = kwargs.get("config_kwargs")
            config_kwargs = dict(existing_config) if existing_config else {}
            config_kwargs["classifier_dropout"] = classifier_dropout
            kwargs["config_kwargs"] = config_kwargs

        return func(self, *args, **kwargs)

    return wrapper


def cross_encoder_predict_rank_args_decorator(func):
    """Decorate a predict/rank-style method to accept deprecated keyword names.

    The returned wrapper rewrites the keyword arguments before calling
    ``func``:

    * ``activation_fct`` is forwarded as ``activation_fn`` (unless
      ``activation_fn`` was also given) and a warning is logged;
    * ``num_workers`` is dropped with a warning and never reaches ``func``.

    Positional arguments are passed through untouched.

    Args:
        func: The method to wrap; it is called as ``func(self, *args, **kwargs)``.

    Returns:
        The wrapping function, carrying ``func``'s metadata via
        ``functools.wraps``.
    """
    kwargs_renamed_mapping = {
        "activation_fct": "activation_fn",
    }
    deprecated_args = ["num_workers"]

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        _rename_deprecated_kwargs(kwargs, kwargs_renamed_mapping, "CrossEncoder.predict")

        for deprecated_arg in deprecated_args:
            if deprecated_arg in kwargs:
                kwargs.pop(deprecated_arg)
                logger.warning(
                    f"The CrossEncoder.predict `{deprecated_arg}` argument is deprecated and has no effect. It will be removed in a future version."
                )

        return func(self, *args, **kwargs)

    return wrapper