File: _common.py

package info (click to toggle)
imbalanced-learn 0.12.4-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 2,160 kB
  • sloc: python: 17,221; sh: 481; makefile: 187; javascript: 50
file content (105 lines) | stat: -rw-r--r-- 3,498 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from numbers import Integral, Real

from sklearn.tree._criterion import Criterion

from ..utils._param_validation import (
    HasMethods,
    Hidden,
    Interval,
    RealNotInt,
    StrOptions,
)


def _estimator_has(attr):
    """Build a predicate reporting whether ``attr`` can be delegated.

    The returned callable takes the meta-estimator instance and looks for
    ``attr`` on exactly one object, chosen in this order:

    1. the first element of ``estimators_`` when that attribute exists;
    2. ``estimator`` when it is not ``None``;
    3. ``base_estimator`` otherwise.

    Parameters
    ----------
    attr : str
        Name of the attribute or method to look for.

    Returns
    -------
    check : callable
        ``check(self) -> bool``, ``True`` when the selected object has
        ``attr``.
    """

    def check(self):
        # Select the object to probe first, then test it in a single place.
        if hasattr(self, "estimators_"):
            delegate = self.estimators_[0]
        elif self.estimator is not None:
            delegate = self.estimator
        else:
            # TODO(1.4): Remove when the base_estimator deprecation cycle ends
            delegate = self.base_estimator
        return hasattr(delegate, attr)

    return check


# Validation table for the bagging parameters. Keys are parameter names; each
# value lists the accepted alternatives, following scikit-learn's
# parameter-validation convention.
_bagging_parameter_constraints = {
    # An object exposing ``fit`` and ``predict``, or ``None``.
    "estimator": [
        HasMethods(["fit", "predict"]),
        None,
    ],
    # Integer, at least 1.
    "n_estimators": [
        Interval(Integral, 1, None, closed="left"),
    ],
    # Either an integer >= 1 or a non-integer real in (0, 1].
    "max_samples": [
        Interval(Integral, 1, None, closed="left"),
        Interval(RealNotInt, 0, 1, closed="right"),
    ],
    # Same two alternatives as ``max_samples``.
    "max_features": [
        Interval(Integral, 1, None, closed="left"),
        Interval(RealNotInt, 0, 1, closed="right"),
    ],
    # Plain on/off switches.
    "bootstrap": ["boolean"],
    "bootstrap_features": ["boolean"],
    "oob_score": ["boolean"],
    "warm_start": ["boolean"],
    # ``None`` or any integer (no range restriction is declared here).
    "n_jobs": [
        None,
        Integral,
    ],
    "random_state": ["random_state"],
    "verbose": ["verbose"],
    # Legacy spelling of ``estimator``: same duck-typed requirement, plus the
    # literal string "deprecated" and ``None``.
    "base_estimator": [
        HasMethods(["fit", "predict"]),
        StrOptions({"deprecated"}),
        None,
    ],
}

# Validation table for the AdaBoost classifier parameters. Keys are parameter
# names; each value lists the accepted alternatives, following scikit-learn's
# parameter-validation convention.
_adaboost_classifier_parameter_constraints = {
    # An object exposing ``fit`` and ``predict``, or ``None``.
    "estimator": [
        HasMethods(["fit", "predict"]),
        None,
    ],
    # Integer, at least 1.
    "n_estimators": [
        Interval(Integral, 1, None, closed="left"),
    ],
    # Strictly positive real; both interval ends are open.
    "learning_rate": [
        Interval(Real, 0, None, closed="neither"),
    ],
    "random_state": ["random_state"],
    # Legacy spelling of ``estimator``. Note that ``None`` is not listed among
    # the accepted values for this key.
    "base_estimator": [
        HasMethods(["fit", "predict"]),
        StrOptions({"deprecated"}),
    ],
    # One of the two named boosting algorithms.
    "algorithm": [
        StrOptions({"SAMME", "SAMME.R"}),
    ],
}

# Validation table for the random-forest classifier parameters. Keys are
# parameter names; each value lists the accepted alternatives, following
# scikit-learn's parameter-validation convention.
_random_forest_classifier_parameter_constraints = {
    # Integer, at least 1.
    "n_estimators": [Interval(Integral, 1, None, closed="left")],
    "bootstrap": ["boolean"],
    "oob_score": ["boolean"],
    # Any integer or ``None`` (no range restriction is declared here).
    "n_jobs": [Integral, None],
    "random_state": ["random_state"],
    "verbose": ["verbose"],
    "warm_start": ["boolean"],
    # A named criterion, or a ``Criterion`` instance (wrapped in ``Hidden``).
    "criterion": [StrOptions({"gini", "entropy", "log_loss"}), Hidden(Criterion)],
    # ``None``, a fraction in (0, 1], or an integer >= 1.
    "max_samples": [
        None,
        # ``RealNotInt`` rather than ``Real``: the integer 1 is already
        # accepted by the Integral interval below, so the set of valid values
        # is unchanged, and the entry now follows the same int-or-fraction
        # pattern as ``min_samples_split``, ``min_samples_leaf`` and
        # ``max_features`` (mixing ``Real`` and ``Integral`` intervals is
        # rejected by scikit-learn's parameter-validation checks).
        Interval(RealNotInt, 0.0, 1.0, closed="right"),
        Interval(Integral, 1, None, closed="left"),
    ],
    # Integer >= 1, or ``None``.
    "max_depth": [Interval(Integral, 1, None, closed="left"), None],
    # Integer >= 2, or a fraction in (0, 1].
    "min_samples_split": [
        Interval(Integral, 2, None, closed="left"),
        Interval(RealNotInt, 0.0, 1.0, closed="right"),
    ],
    # Integer >= 1, or a fraction strictly between 0 and 1.
    "min_samples_leaf": [
        Interval(Integral, 1, None, closed="left"),
        Interval(RealNotInt, 0.0, 1.0, closed="neither"),
    ],
    # Real in [0, 0.5], both ends included.
    "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
    # Integer >= 1, a fraction in (0, 1], a named rule, or ``None``.
    "max_features": [
        Interval(Integral, 1, None, closed="left"),
        Interval(RealNotInt, 0.0, 1.0, closed="right"),
        StrOptions({"sqrt", "log2"}),
        None,
    ],
    # Integer >= 2, or ``None``.
    "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
    # Non-negative reals.
    "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")],
    "ccp_alpha": [Interval(Real, 0.0, None, closed="left")],
    # A named weighting scheme, a dict, a list, or ``None``.
    "class_weight": [
        StrOptions({"balanced_subsample", "balanced"}),
        dict,
        list,
        None,
    ],
    "monotonic_cst": ["array-like", None],
}