You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
114 lines
4.1 KiB
114 lines
4.1 KiB
# Authors: The scikit-learn developers
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
import pytest
|
|
|
|
from sklearn import metrics
|
|
from sklearn.ensemble import (
|
|
BaggingClassifier,
|
|
BaggingRegressor,
|
|
IsolationForest,
|
|
StackingClassifier,
|
|
StackingRegressor,
|
|
)
|
|
from sklearn.utils._testing import assert_docstring_consistency, skip_if_no_numpydoc
|
|
|
|
# Parametrization cases for ``test_class_docstring_consistency``: each dict is
# unpacked as keyword arguments into ``assert_docstring_consistency``.
CLASS_DOCSTRING_CONSISTENCY_CASES = [
    # Bagging estimators and IsolationForest: compare only ``max_samples``,
    # with its description checked against ``descr_regex_pattern``.
    {
        "objects": [BaggingClassifier, BaggingRegressor, IsolationForest],
        "include_params": ["max_samples"],
        "exclude_params": None,
        "include_attrs": False,
        "exclude_attrs": None,
        "include_returns": False,
        "exclude_returns": None,
        "descr_regex_pattern": r"The number of samples to draw from X to train each.*",
        # Must be a real tuple: ``("max_samples")`` is just the string
        # "max_samples", so a membership test against it would match any
        # substring (e.g. "samples") instead of whole item names.
        # NOTE(review): presumably the type specs of ``max_samples`` differ
        # between these estimators, so only the description is compared.
        "ignore_types": ("max_samples",),
    },
    # Stacking estimators: compare the listed shared parameters and the
    # attributes, excluding ``final_estimator_``.
    {
        "objects": [StackingClassifier, StackingRegressor],
        "include_params": ["cv", "n_jobs", "passthrough", "verbose"],
        "exclude_params": None,
        "include_attrs": True,
        "exclude_attrs": ["final_estimator_"],
        "include_returns": False,
        "exclude_returns": None,
        "descr_regex_pattern": None,
    },
]
|
|
|
|
# Precision/recall/F-score family of metrics; both cases below check this
# same group of functions.
_PRF_METRIC_FUNCTIONS = (
    metrics.precision_recall_fscore_support,
    metrics.f1_score,
    metrics.fbeta_score,
    metrics.precision_score,
    metrics.recall_score,
)

# Regex source for the ``average`` parameter description. It is laid out over
# several lines for readability; the whitespace normalization below reduces it
# to single-space-separated tokens, so line breaks and indentation here are
# irrelevant to the final pattern.
_AVERAGE_DESCR_TEMPLATE = (
    r"""This parameter is required for multiclass/multilabel targets\.
    If ``None``, the metrics for each class are returned\. Otherwise, this
    determines the type of averaging performed on the data:
    ``'binary'``:
        Only report results for the class specified by ``pos_label``\.
        This is applicable only if targets \(``y_\{true,pred\}``\) are binary\.
    ``'micro'``:
        Calculate metrics globally by counting the total true positives,
        false negatives and false positives\.
    ``'macro'``:
        Calculate metrics for each label, and find their unweighted
        mean\. This does not take label imbalance into account\.
    ``'weighted'``:
        Calculate metrics for each label, and find their average weighted
        by support \(the number of true instances for each label\)\. This
        alters 'macro' to account for label imbalance; it can result in an
        F-score that is not between precision and recall\."""
    # Concatenated with no whitespace onto "recall\." so that an optional
    # trailing sentence may follow it.
    r"[\s\w]*\.*"
    r"""
    ``'samples'``:
        Calculate metrics for each instance, and find their average \(only
        meaningful for multilabel classification where this differs from
        :func:`accuracy_score`\)\."""
)
_AVERAGE_DESCR_REGEX = " ".join(_AVERAGE_DESCR_TEMPLATE.split())

# Parametrization cases for ``test_function_docstring_consistency``: each dict
# is unpacked as keyword arguments into ``assert_docstring_consistency``.
FUNCTION_DOCSTRING_CONSISTENCY_CASES = [
    # Compare every parameter except ``average`` and ``zero_division``.
    {
        "objects": list(_PRF_METRIC_FUNCTIONS),
        "include_params": True,
        "exclude_params": ["average", "zero_division"],
        "include_attrs": False,
        "exclude_attrs": None,
        "include_returns": False,
        "exclude_returns": None,
        "descr_regex_pattern": None,
    },
    # Compare only ``average``, using ``_AVERAGE_DESCR_REGEX`` as the
    # description pattern.
    {
        "objects": list(_PRF_METRIC_FUNCTIONS),
        "include_params": ["average"],
        "exclude_params": None,
        "include_attrs": False,
        "exclude_attrs": None,
        "include_returns": False,
        "exclude_returns": None,
        "descr_regex_pattern": _AVERAGE_DESCR_REGEX,
    },
]
|
|
|
|
|
|
@pytest.mark.parametrize("case", CLASS_DOCSTRING_CONSISTENCY_CASES)
@skip_if_no_numpydoc
def test_class_docstring_consistency(case):
    """Related classes must document their shared items consistently.

    ``case`` is one entry of ``CLASS_DOCSTRING_CONSISTENCY_CASES``; its keys
    are passed as keyword arguments to ``assert_docstring_consistency``.
    """
    consistency_kwargs = dict(case)
    assert_docstring_consistency(**consistency_kwargs)
|
|
|
|
|
|
@pytest.mark.parametrize("case", FUNCTION_DOCSTRING_CONSISTENCY_CASES)
@skip_if_no_numpydoc
def test_function_docstring_consistency(case):
    """Related functions must document their shared items consistently.

    ``case`` is one entry of ``FUNCTION_DOCSTRING_CONSISTENCY_CASES``; its
    keys are passed as keyword arguments to ``assert_docstring_consistency``.
    """
    consistency_kwargs = dict(case)
    assert_docstring_consistency(**consistency_kwargs)
|