File: plot_release_highlights_1_3_0.py

package info (click to toggle)
scikit-learn 1.7.2%2Bdfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 25,752 kB
  • sloc: python: 219,120; cpp: 5,790; ansic: 846; makefile: 191; javascript: 110
file content (163 lines) | stat: -rw-r--r-- 6,398 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# ruff: noqa: CPY001
"""
=======================================
Release Highlights for scikit-learn 1.3
=======================================

.. currentmodule:: sklearn

We are pleased to announce the release of scikit-learn 1.3! Many bug fixes
and improvements were added, as well as some new key features. We detail
below a few of the major features of this release. **For an exhaustive list of
all the changes**, please refer to the :ref:`release notes <release_notes_1_3>`.

To install the latest version (with pip)::

    pip install --upgrade scikit-learn

or with conda::

    conda install -c conda-forge scikit-learn

"""

# %%
# Metadata Routing
# ----------------
# We are in the process of introducing a new way to route metadata such as
# ``sample_weight`` throughout the codebase, which would affect how
# meta-estimators such as :class:`pipeline.Pipeline` and
# :class:`model_selection.GridSearchCV` route metadata. While the
# infrastructure for this feature is already included in this release, the work
# is ongoing and not all meta-estimators support this new feature. You can read
# more about this feature in the :ref:`Metadata Routing User Guide
# <metadata_routing>`. Note that this feature is still under development and
# not implemented for most meta-estimators.
#
# Third-party developers can already start incorporating this into their
# meta-estimators. For more details, see the
# :ref:`metadata routing developer guide
# <sphx_glr_auto_examples_miscellaneous_plot_metadata_routing.py>`.

# %%
# HDBSCAN: hierarchical density-based clustering
# ----------------------------------------------
# Originally hosted in the scikit-learn-contrib repository, :class:`cluster.HDBSCAN`
# has been adopted into scikit-learn. It is still missing a few features from the
# original implementation; these will be added in future releases.
# By performing a modified version of :class:`cluster.DBSCAN` over multiple epsilon
# values simultaneously, :class:`cluster.HDBSCAN` finds clusters of varying densities,
# making it more robust to parameter selection than :class:`cluster.DBSCAN`.
# More details in the :ref:`User Guide <hdbscan>`.
import numpy as np

from sklearn.cluster import HDBSCAN
from sklearn.datasets import load_digits
from sklearn.metrics import v_measure_score

X, true_labels = load_digits(return_X_y=True)
print(f"number of digits: {len(np.unique(true_labels))}")

hdbscan = HDBSCAN(min_cluster_size=15).fit(X)
# Samples labelled -1 were not assigned to any cluster (noise). Build the mask
# once and reuse it so the cluster count and the score use the same samples.
non_noisy_mask = hdbscan.labels_ != -1
non_noisy_labels = hdbscan.labels_[non_noisy_mask]
print(f"number of clusters found: {len(np.unique(non_noisy_labels))}")

# Compare the clusters found with the true digit classes, restricted to the
# non-noisy samples.
v_measure = v_measure_score(true_labels[non_noisy_mask], non_noisy_labels)
print(f"v-measure score (non-noisy samples): {v_measure:.3f}")

# %%
# TargetEncoder: a new category encoding strategy
# -----------------------------------------------
# Well suited for categorical features with high cardinality,
# :class:`preprocessing.TargetEncoder` encodes each category based on a shrunk
# estimate of the average target value of the observations belonging to that
# category. More details in the :ref:`User Guide <target_encoder>`.
import numpy as np

from sklearn.preprocessing import TargetEncoder

# A single categorical feature: "cat" rows have a much larger target than the
# "dog" and "snake" rows.
categories = ["cat"] * 30 + ["dog"] * 20 + ["snake"] * 38
X = np.array(categories, dtype=object).reshape(-1, 1)
y = [90.3] * 30 + [20.4] * 20 + [21.2] * 38

# `fit_transform` encodes the training data with cross fitting (its shuffling is
# seeded by `random_state`), so it differs from `fit(X, y).transform(X)`.
enc = TargetEncoder(random_state=0)
X_trans = enc.fit_transform(X, y)

# Per-category encodings learned on the whole training set.
enc.encodings_

# %%
# Missing values support in decision trees
# ----------------------------------------
# The classes :class:`tree.DecisionTreeClassifier` and
# :class:`tree.DecisionTreeRegressor` now support missing values. For each potential
# threshold on the non-missing data, the splitter will evaluate the split with all the
# missing values going to the left node or the right node.
# See more details in the :ref:`User Guide <tree_missing_value_support>` or see
# :ref:`sphx_glr_auto_examples_ensemble_plot_hgbt_regression.py` for an example
# use case of this feature in :class:`~ensemble.HistGradientBoostingRegressor`.
import numpy as np

from sklearn.tree import DecisionTreeClassifier

# One feature column whose last sample is missing (NaN); it is passed to the
# tree as-is, without any imputation step.
X = np.array([[0], [1], [6], [np.nan]])
y = [0, 0, 1, 1]

tree = DecisionTreeClassifier(random_state=0)
tree.fit(X, y)
# Predicting on the training data shows how the sample with the missing value
# is handled by the learned split.
tree.predict(X)

# %%
# New display :class:`~model_selection.ValidationCurveDisplay`
# ------------------------------------------------------------
# :class:`model_selection.ValidationCurveDisplay` is now available to plot results
# from :func:`model_selection.validation_curve`.
import numpy as np

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ValidationCurveDisplay

X, y = make_classification(n_samples=1000, n_features=10, random_state=0)

# Sweep the regularization parameter `C` over 9 log-spaced values from 1e-5 to
# 1e3; `score_type="both"` plots the training and the test scores.
_ = ValidationCurveDisplay.from_estimator(
    LogisticRegression(),
    X,
    y,
    param_name="C",
    param_range=np.geomspace(1e-5, 1e3, num=9),
    score_type="both",
    score_name="Accuracy",
)

# %%
# Gamma loss for gradient boosting
# --------------------------------
# The class :class:`ensemble.HistGradientBoostingRegressor` supports the
# Gamma deviance loss function via `loss="gamma"`. This loss function is useful for
# modeling strictly positive targets with a right-skewed distribution.
import numpy as np

from sklearn.datasets import make_low_rank_matrix
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.model_selection import cross_val_score

# Synthetic regression data. The features, coefficients and targets are all
# drawn from the same RandomState, in this order, so the example is
# reproducible; reordering these statements would change the generated data.
n_samples, n_features = 500, 10
rng = np.random.RandomState(0)
X = make_low_rank_matrix(n_samples, n_features, random_state=rng)
coef = rng.uniform(low=-10, high=20, size=n_features)
# Gamma-distributed, strictly positive targets. With shape=2 and
# scale=exp(X @ coef) / 2, the mean is shape * scale = exp(X @ coef).
y = rng.gamma(shape=2, scale=np.exp(X @ coef) / 2)
gbdt = HistGradientBoostingRegressor(loss="gamma")
# Mean of the per-fold scores; with no `scoring` argument, `cross_val_score`
# uses the estimator's default `score` method.
cross_val_score(gbdt, X, y).mean()

# %%
# Grouping infrequent categories in :class:`~preprocessing.OrdinalEncoder`
# ------------------------------------------------------------------------
# Similarly to :class:`preprocessing.OneHotEncoder`, the class
# :class:`preprocessing.OrdinalEncoder` now supports aggregating infrequent categories
# into a single output for each feature. The parameters to enable the gathering of
# infrequent categories are `min_frequency` and `max_categories`.
# See the :ref:`User Guide <encoder_infrequent_categories>` for more details.
import numpy as np

from sklearn.preprocessing import OrdinalEncoder

# One categorical feature with counts: dog=5, cat=20, rabbit=10, snake=3.
animals = ["dog"] * 5 + ["cat"] * 20 + ["rabbit"] * 10 + ["snake"] * 3
X = np.array(animals, dtype=object).reshape(-1, 1)

# Categories seen fewer than `min_frequency` times are grouped as infrequent;
# with the counts above that should be "dog" and "snake".
enc = OrdinalEncoder(min_frequency=6)
enc.fit(X)
enc.infrequent_categories_