File: plot_display_object_visualization.py

package info (click to toggle)
scikit-learn 0.23.2-5
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bullseye, sid
  • size: 21,892 kB
  • sloc: python: 132,020; cpp: 5,765; javascript: 2,201; ansic: 831; makefile: 213; sh: 44
file content (92 lines) | stat: -rw-r--r-- 4,046 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
===================================
Visualizations with Display Objects
===================================

.. currentmodule:: sklearn.metrics

In this example, we will construct display objects,
:class:`ConfusionMatrixDisplay`, :class:`RocCurveDisplay`, and
:class:`PrecisionRecallDisplay` directly from their respective metrics. This
is an alternative to using their corresponding plot functions when
a model's predictions are already computed or expensive to compute. Note that
this is advanced usage, and in general we recommend using their respective
plot functions.
"""
# Echo the module docstring when the example is executed.
print(__doc__)

##############################################################################
# Load Data and train model
# -------------------------
# For this example, we load a blood transfusion service center data set from
# `OpenML <https://www.openml.org/d/1464>`_. This is a binary classification
# problem where the target is whether an individual donated blood. Then the
# data is split into a train and test dataset and a logistic regression is
# fitted with the train dataset.
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Fetch the OpenML data set with id 1464 as a (features, target) pair.
X, y = fetch_openml(data_id=1464, return_X_y=True)

# Stratify on the target so that both splits keep the class proportions, and
# seed the shuffle so that the split -- and therefore every figure produced
# from it -- is identical from one run to the next.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=0)

# Standardize the features, then fit a logistic regression on the train set.
clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
clf.fit(X_train, y_train)

##############################################################################
# Create :class:`ConfusionMatrixDisplay`
# --------------------------------------
# With the fitted model, we compute the predictions of the model on the test
# dataset. These predictions are used to compute the confusion matrix which
# is plotted with the :class:`ConfusionMatrixDisplay`.
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

# Hard class predictions for the held-out samples.
y_pred = clf.predict(X_test)

# Pin the row/column order of the matrix to ``clf.classes_`` and give the
# same labels to the display, so that the plot's ticks show the class labels
# of the data set rather than the default integer indices.
cm = confusion_matrix(y_test, y_pred, labels=clf.classes_)

cm_display = ConfusionMatrixDisplay(cm, display_labels=clf.classes_).plot()


##############################################################################
# Create :class:`RocCurveDisplay`
# -------------------------------
# The ROC curve requires either the probabilities or the non-thresholded
# decision values from the estimator. Since the logistic regression provides
# a decision function, we will use it to plot the ROC curve:
from sklearn.metrics import roc_curve
from sklearn.metrics import RocCurveDisplay
# Non-thresholded decision values for the test samples.
y_score = clf.decision_function(X_test)

# The second entry of ``clf.classes_`` is used as the positive class.
fpr, tpr, _ = roc_curve(
    y_test,
    y_score,
    pos_label=clf.classes_[1],
)
# Keep the object returned by ``plot`` so it can be drawn again later.
roc_display = RocCurveDisplay(tpr=tpr, fpr=fpr).plot()

##############################################################################
# Create :class:`PrecisionRecallDisplay`
# --------------------------------------
# Similarly, the precision recall curve can be plotted using `y_score` from
# the previous section.
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import PrecisionRecallDisplay

# The second entry of ``clf.classes_`` is used as the positive class.
prec, recall, _ = precision_recall_curve(
    y_test,
    y_score,
    pos_label=clf.classes_[1],
)
# Keep the object returned by ``plot`` so it can be drawn again later.
pr_display = PrecisionRecallDisplay(recall=recall, precision=prec).plot()

##############################################################################
# Combining the display objects into a single plot
# ------------------------------------------------
# The display objects store the computed values that were passed as arguments.
# This allows for the visualizations to be easily combined using matplotlib's
# API. In the following example, we place the displays next to each other in a
# row.

# sphinx_gallery_thumbnail_number = 4
import matplotlib.pyplot as plt
# One row, two columns: a separate axes for each display.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

# Redraw each display on its own axes of the shared figure.
roc_display.plot(ax=ax1)
pr_display.plot(ax=ax2)

plt.show()