File: test_docstring_parameters_consistency.py

Package: scikit-learn 1.7.2+dfsg-3
  • area: main
  • in suites: forky, sid
  • size: 25,752 kB
  • sloc: python: 219,120; cpp: 5,790; ansic: 846; makefile: 191; javascript: 110
file content: 113 lines | stat: -rw-r--r-- 4,171 bytes
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
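
# This module verifies that groups of related scikit-learn classes and metric
# functions document their shared parameters, attributes, and return values
# consistently.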

import pytest

from sklearn import metrics
from sklearn.ensemble import (
    BaggingClassifier,
    BaggingRegressor,
    IsolationForest,
    StackingClassifier,
    StackingRegressor,
)
from sklearn.utils._testing import assert_docstring_consistency, skip_if_no_numpydoc
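
# numpydoc is needed to parse the docstrings, so the tests below are skipped
# via skip_if_no_numpydoc when it is not installed. Each case dict is unpacked
# directly into assert_docstring_consistency(**case).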

CLASS_DOCSTRING_CONSISTENCY_CASES = [
    {
        "objects": [BaggingClassifier, BaggingRegressor, IsolationForest],
        "include_params": ["max_samples"],
        "exclude_params": None,
        "include_attrs": False,
        "exclude_attrs": None,
        "include_returns": False,
        "exclude_returns": None,
        "descr_regex_pattern": r"The number of samples to draw from X to train each.*",
        "ignore_types": ("max_samples"),
    },
    {
        "objects": [StackingClassifier, StackingRegressor],
        "include_params": ["cv", "n_jobs", "passthrough", "verbose"],
        "exclude_params": None,
        "include_attrs": True,
        "exclude_attrs": ["final_estimator_"],
        "include_returns": False,
        "exclude_returns": None,
        "descr_regex_pattern": None,
    },
]

FUNCTION_DOCSTRING_CONSISTENCY_CASES = [
    {
        "objects": [
            metrics.precision_recall_fscore_support,
            metrics.f1_score,
            metrics.fbeta_score,
            metrics.precision_score,
            metrics.recall_score,
        ],
        "include_params": True,
        "exclude_params": ["average", "zero_division"],
        "include_attrs": False,
        "exclude_attrs": None,
        "include_returns": False,
        "exclude_returns": None,
        "descr_regex_pattern": None,
    },
    {
        "objects": [
            metrics.precision_recall_fscore_support,
            metrics.f1_score,
            metrics.fbeta_score,
            metrics.precision_score,
            metrics.recall_score,
        ],
        "include_params": ["average"],
        "exclude_params": None,
        "include_attrs": False,
        "exclude_attrs": None,
        "include_returns": False,
        "exclude_returns": None,
        "descr_regex_pattern": " ".join(
            (
                r"""This parameter is required for multiclass/multilabel targets\.
            If ``None``, the metrics for each class are returned\. Otherwise, this
            determines the type of averaging performed on the data:
            ``'binary'``:
                Only report results for the class specified by ``pos_label``\.
                This is applicable only if targets \(``y_\{true,pred\}``\) are binary\.
            ``'micro'``:
                Calculate metrics globally by counting the total true positives,
                false negatives and false positives\.
            ``'macro'``:
                Calculate metrics for each label, and find their unweighted
                mean\.  This does not take label imbalance into account\.
            ``'weighted'``:
                Calculate metrics for each label, and find their average weighted
                by support \(the number of true instances for each label\)\. This
                alters 'macro' to account for label imbalance; it can result in an
                F-score that is not between precision and recall\."""
                r"[\s\w]*\.*"  # optionally match additional sentence
                r"""
            ``'samples'``:
                Calculate metrics for each instance, and find their average \(only
                meaningful for multilabel classification where this differs from
                :func:`accuracy_score`\)\."""
            ).split()
        ),
    },
]


@pytest.mark.parametrize("case", CLASS_DOCSTRING_CONSISTENCY_CASES)
@skip_if_no_numpydoc
def test_class_docstring_consistency(case):
    """Check docstrings parameters consistency between related classes."""
    assert_docstring_consistency(**case)


@pytest.mark.parametrize("case", FUNCTION_DOCSTRING_CONSISTENCY_CASES)
@skip_if_no_numpydoc
def test_function_docstring_consistency(case):
    """Check docstrings parameters consistency between related functions."""
    assert_docstring_consistency(**case)
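

# A minimal standalone check (a sketch, assuming numpydoc is installed and the
# remaining keyword arguments of assert_docstring_consistency have defaults;
# the keyword names mirror the case dicts above):
#
#     from sklearn.ensemble import BaggingClassifier, BaggingRegressor
#     from sklearn.utils._testing import assert_docstring_consistency
#
#     assert_docstring_consistency(
#         objects=[BaggingClassifier, BaggingRegressor],
#         include_params=["max_samples"],
#         descr_regex_pattern=r"The number of samples to draw from X.*",
#     )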