1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
|
from .. import auc
from .. import roc_curve
from .base import _check_classifer_response_method
from ...utils import check_matplotlib_support
from ...base import is_classifier
from ...utils.validation import _deprecate_positional_args
class RocCurveDisplay:
    """ROC Curve visualization.

    It is recommended to use :func:`~sklearn.metrics.plot_roc_curve` to
    create a visualizer. All parameters are stored as attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    fpr : ndarray
        False positive rate.

    tpr : ndarray
        True positive rate.

    roc_auc : float, default=None
        Area under ROC curve. If None, the roc_auc score is not shown.

    estimator_name : str, default=None
        Name of estimator. If None, the estimator name is not shown.

    Attributes
    ----------
    line_ : matplotlib Artist
        ROC Curve.

    ax_ : matplotlib Axes
        Axes with ROC Curve.

    figure_ : matplotlib Figure
        Figure containing the curve.

    Examples
    --------
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    >>> import numpy as np
    >>> from sklearn import metrics
    >>> y = np.array([0, 0, 1, 1])
    >>> pred = np.array([0.1, 0.4, 0.35, 0.8])
    >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred)
    >>> roc_auc = metrics.auc(fpr, tpr)
    >>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
    ...                                   estimator_name='example estimator')
    >>> display.plot()  # doctest: +SKIP
    >>> plt.show()      # doctest: +SKIP
    """

    def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None):
        # Inputs are stored verbatim; nothing is validated or copied here.
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    @_deprecate_positional_args
    def plot(self, ax=None, *, name=None, **kwargs):
        """Plot visualization.

        Extra keyword arguments will be passed to matplotlib's ``plot``.

        Parameters
        ----------
        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        name : str, default=None
            Name of ROC Curve for labeling. If `None`, use the name of the
            estimator.

        Returns
        -------
        display : :class:`~sklearn.metrics.RocCurveDisplay`
            Object that stores computed values.
        """
        check_matplotlib_support('RocCurveDisplay.plot')
        import matplotlib.pyplot as plt

        if ax is None:
            # Only the axes are needed; the figure is read back from
            # ``ax.figure`` further down.
            _, ax = plt.subplots()

        if name is None:
            name = self.estimator_name

        # Work out the default legend label from whichever of the AUC value
        # and the curve name are available.
        auc_known = self.roc_auc is not None
        name_known = name is not None
        label_defaults = {}
        if auc_known and name_known:
            label_defaults["label"] = f"{name} (AUC = {self.roc_auc:0.2f})"
        elif auc_known:
            label_defaults["label"] = f"AUC = {self.roc_auc:0.2f}"
        elif name_known:
            label_defaults["label"] = name

        # Caller-supplied keyword arguments take precedence over the defaults.
        plot_kwargs = {**label_defaults, **kwargs}

        drawn = ax.plot(self.fpr, self.tpr, **plot_kwargs)
        self.line_ = drawn[0]

        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")

        # A legend is drawn only when a label ends up being passed to
        # ``ax.plot``.
        if "label" in plot_kwargs:
            ax.legend(loc='lower right')

        self.ax_ = ax
        self.figure_ = ax.figure
        return self
@_deprecate_positional_args
def plot_roc_curve(estimator, X, y, *, sample_weight=None,
                   drop_intermediate=True, response_method="auto",
                   name=None, ax=None, **kwargs):
    """Plot Receiver operating characteristic (ROC) curve.

    Extra keyword arguments will be passed to matplotlib's `plot`.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    drop_intermediate : bool, default=True
        Whether to drop some suboptimal thresholds which would not appear
        on a plotted ROC curve. This is useful in order to create lighter
        ROC curves.

    response_method : {'predict_proba', 'decision_function', 'auto'}, \
            default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name of ROC Curve for labeling. If `None`, use the name of the
        estimator.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    Returns
    -------
    display : :class:`~sklearn.metrics.RocCurveDisplay`
        Object that stores computed values.

    Raises
    ------
    ValueError
        If `estimator` is not a classifier, or if its response is
        two-dimensional with a number of columns different from 2.

    Examples
    --------
    >>> import matplotlib.pyplot as plt  # doctest: +SKIP
    >>> from sklearn import datasets, metrics, model_selection, svm
    >>> X, y = datasets.make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = model_selection.train_test_split(
    ...     X, y, random_state=0)
    >>> clf = svm.SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> metrics.plot_roc_curve(clf, X_test, y_test)  # doctest: +SKIP
    >>> plt.show()                                   # doctest: +SKIP
    """
    check_matplotlib_support('plot_roc_curve')

    estimator_cls_name = estimator.__class__.__name__
    # The same message is used for both rejection paths below.
    not_binary_msg = "{} should be a binary classifier".format(
        estimator_cls_name
    )
    if not is_classifier(estimator):
        raise ValueError(not_binary_msg)

    response_fn = _check_classifer_response_method(estimator,
                                                   response_method)
    scores = response_fn(X)

    if scores.ndim != 1:
        # A multi-column response must have exactly two columns; the second
        # column is the one scored against ``classes_[1]``.
        if scores.shape[1] != 2:
            raise ValueError(not_binary_msg)
        scores = scores[:, 1]

    fpr, tpr, _ = roc_curve(
        y,
        scores,
        pos_label=estimator.classes_[1],
        sample_weight=sample_weight,
        drop_intermediate=drop_intermediate,
    )
    roc_auc = auc(fpr, tpr)

    if name is None:
        name = estimator_cls_name

    display = RocCurveDisplay(
        fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name
    )
    return display.plot(ax=ax, name=name, **kwargs)
|