"""
The :mod:`sklearn.model_selection._search` includes utilities to fine-tune the
parameters of an estimator.
"""
from __future__ import print_function
from __future__ import division
# Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>,
# Gael Varoquaux <gael.varoquaux@normalesup.org>
# Andreas Mueller <amueller@ais.uni-bonn.de>
# Olivier Grisel <olivier.grisel@ensta.org>
# Raghav RV <rvraghav93@gmail.com>
# License: BSD 3 clause
from abc import ABCMeta, abstractmethod
from collections import namedtuple, defaultdict
from functools import partial, reduce
from itertools import product
import operator
import time
import warnings
import numpy as np
from scipy.stats import rankdata
from ..base import BaseEstimator, is_classifier, clone
from ..base import MetaEstimatorMixin
from ._split import check_cv
from ._validation import _fit_and_score
from ._validation import _aggregate_score_dicts
from ..exceptions import NotFittedError
from ..utils._joblib import Parallel, delayed
from ..externals import six
from ..utils import check_random_state
from ..utils.fixes import sp_version
from ..utils.fixes import MaskedArray
from ..utils.fixes import _Mapping as Mapping, _Sequence as Sequence
from ..utils.fixes import _Iterable as Iterable
from ..utils.random import sample_without_replacement
from ..utils.validation import indexable, check_is_fitted
from ..utils.metaestimators import if_delegate_has_method
from ..utils.deprecation import DeprecationDict
from ..metrics.scorer import _check_multimetric_scoring
from ..metrics.scorer import check_scoring
__all__ = ['GridSearchCV', 'ParameterGrid', 'fit_grid_point',
'ParameterSampler', 'RandomizedSearchCV']
class ParameterGrid(object):
"""Grid of parameters with a discrete number of values for each.
Can be used to iterate over parameter value combinations with the
Python built-in function iter.
Read more in the :ref:`User Guide <search>`.
Parameters
----------
param_grid : dict of string to sequence, or sequence of such
The parameter grid to explore, as a dictionary mapping estimator
parameters to sequences of allowed values.
An empty dict signifies default parameters.
A sequence of dicts signifies a sequence of grids to search, and is
useful to avoid exploring parameter combinations that make no sense
or have no effect. See the examples below.
Examples
--------
>>> from sklearn.model_selection import ParameterGrid
>>> param_grid = {'a': [1, 2], 'b': [True, False]}
>>> list(ParameterGrid(param_grid)) == (
... [{'a': 1, 'b': True}, {'a': 1, 'b': False},
... {'a': 2, 'b': True}, {'a': 2, 'b': False}])
True
>>> grid = [{'kernel': ['linear']}, {'kernel': ['rbf'], 'gamma': [1, 10]}]
>>> list(ParameterGrid(grid)) == [{'kernel': 'linear'},
... {'kernel': 'rbf', 'gamma': 1},
... {'kernel': 'rbf', 'gamma': 10}]
True
>>> ParameterGrid(grid)[1] == {'kernel': 'rbf', 'gamma': 1}
True
See also
--------
:class:`GridSearchCV`:
Uses :class:`ParameterGrid` to perform a full parallelized parameter
search.
"""
def __init__(self, param_grid):
if not isinstance(param_grid, (Mapping, Iterable)):
raise TypeError('Parameter grid is not a dict or '
'a list ({!r})'.format(param_grid))
if isinstance(param_grid, Mapping):
# wrap dictionary in a singleton list to support either dict
# or list of dicts
param_grid = [param_grid]
# check if all entries are dictionaries of lists
for grid in param_grid:
if not isinstance(grid, dict):
raise TypeError('Parameter grid is not a '
'dict ({!r})'.format(grid))
for key in grid:
if not isinstance(grid[key], Iterable):
raise TypeError('Parameter grid value is not iterable '
'(key={!r}, value={!r})'
.format(key, grid[key]))
self.param_grid = param_grid
def __iter__(self):
"""Iterate over the points in the grid.
Returns
-------
params : iterator over dict of string to any
Yields dictionaries mapping each estimator parameter to one of its
allowed values.
"""
for p in self.param_grid:
# Always sort the keys of a dictionary, for reproducibility
items = sorted(p.items())
if not items:
yield {}
else:
keys, values = zip(*items)
for v in product(*values):
params = dict(zip(keys, v))
yield params
def __len__(self):
"""Number of points on the grid."""
# Product function that can handle iterables (np.product can't).
product = partial(reduce, operator.mul)
return sum(product(len(v) for v in p.values()) if p else 1
for p in self.param_grid)
def __getitem__(self, ind):
"""Get the parameters that would be ``ind``th in iteration.
Parameters
----------
ind : int
The iteration index
Returns
-------
params : dict of string to any
Equal to list(self)[ind]
"""
# This is used to make discrete sampling without replacement memory
# efficient.
for sub_grid in self.param_grid:
# XXX: could memoize information used here
if not sub_grid:
if ind == 0:
return {}
else:
ind -= 1
continue
# Reverse so most frequent cycling parameter comes first
keys, values_lists = zip(*sorted(sub_grid.items())[::-1])
sizes = [len(v_list) for v_list in values_lists]
total = np.product(sizes)
if ind >= total:
# Try the next grid
ind -= total
else:
out = {}
for key, v_list, n in zip(keys, values_lists, sizes):
ind, offset = divmod(ind, n)
out[key] = v_list[offset]
return out
raise IndexError('ParameterGrid index out of range')
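The sketch below uses only the standard library to show why `__getitem__` agrees with `__iter__` for a single dict grid. Iteration runs `itertools.product` over the sorted keys, so the last key cycles fastest. Decoding the index with repeated `divmod` over the reversed keys therefore recovers the same point without materializing the grid. `iterate_grid` and `grid_point` are hypothetical helpers for illustration, not scikit-learn API, and the sketch handles only one non-empty dict rather than a list of grids.

```python
from itertools import product


def iterate_grid(grid):
    """Yield grid points in itertools.product order over the sorted keys."""
    keys = sorted(grid)
    for combo in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, combo))


def grid_point(grid, ind):
    """Return the ind-th point of iterate_grid(grid) without enumerating it."""
    # Reverse the sorted keys so the fastest-cycling parameter is decoded first.
    items = sorted(grid.items())[::-1]
    total = 1
    for _, values in items:
        total *= len(values)
    if not 0 <= ind < total:
        raise IndexError('grid index out of range')
    out = {}
    for key, values in items:
        # Mixed-radix decoding: the remainder selects this parameter's value.
        ind, offset = divmod(ind, len(values))
        out[key] = values[offset]
    return out


grid = {'a': [1, 2], 'b': [True, False], 'c': ['x', 'y', 'z']}
points = list(iterate_grid(grid))
print(len(points))           # 12
print(grid_point(grid, 11))  # {'c': 'z', 'b': False, 'a': 2}
```

Memory stays constant in the grid size, which is what lets random search sample grid indices rather than whole parameter lists.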
class ParameterSampler(object):
"""Generator on parameters sampled from given distributions.
Non-deterministic iterable over random candidate combinations for hyper-
parameter search. If all parameters are presented as a list,
sampling without replacement is performed. If at least one parameter
is given as a distribution, sampling with replacement is used.
It is highly recommended to use continuous distributions for continuous
parameters.
Note that before SciPy 0.16, the ``scipy.stats.distributions`` do not
accept a custom RNG instance and always use the singleton RNG from
``numpy.random``. Hence setting ``random_state`` will not guarantee a
deterministic iteration whenever ``scipy.stats`` distributions are used to
define the parameter search space. Deterministic behavior is however
guaranteed from SciPy 0.16 onwards.
Read more in the :ref:`User Guide <search>`.
Parameters
----------
param_distributions : dict
Dictionary where the keys are parameters and values
are distributions from which a parameter is to be sampled.
Distributions either have to provide a ``rvs`` function
to sample from them, or can be given as a list of values,
where a uniform distribution is assumed.
n_iter : integer
Number of parameter settings that are produced.
random_state : int, RandomState instance or None, optional (default=None)
Pseudo random number generator state used for random uniform sampling
from lists of possible values instead of scipy.stats distributions.
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance used
by `np.random`.
Returns
-------
params : dict of string to any
**Yields** dictionaries mapping each estimator parameter to
a sampled value.
Examples
--------
>>> from sklearn.model_selection import ParameterSampler
>>> from scipy.stats.distributions import expon
>>> import numpy as np
>>> np.random.seed(0)
>>> param_grid = {'a':[1, 2], 'b': expon()}
>>> param_list = list(ParameterSampler(param_grid, n_iter=4))
>>> rounded_list = [dict((k, round(v, 6)) for (k, v) in d.items())
... for d in param_list]
>>> rounded_list == [{'b': 0.89856, 'a': 1},
... {'b': 0.923223, 'a': 1},
... {'b': 1.878964, 'a': 2},
... {'b': 1.038159, 'a': 2}]
True
"""
def __init__(self, param_distributions, n_iter, random_state=None):
self.param_distributions = param_distributions
self.n_iter = n_iter
self.random_state = random_state
def __iter__(self):
# check if all distributions are given as lists
# in this case we want to sample without replacement
all_lists = np.all([not hasattr(v, "rvs")
for v in self.param_distributions.values()])
rnd = check_random_state(self.random_state)
if all_lists:
# look up sampled parameter settings in parameter grid
param_grid = ParameterGrid(self.param_distributions)
grid_size = len(param_grid)
n_iter = self.n_iter
if grid_size < n_iter:
warnings.warn(
'The total space of parameters %d is smaller '
'than n_iter=%d. Running %d iterations. For exhaustive '
'searches, use GridSearchCV.'
% (grid_size, self.n_iter, grid_size), UserWarning)
n_iter = grid_size
for i in sample_without_replacement(grid_size, n_iter,
random_state=rnd):
yield param_grid[i]
else:
# Always sort the keys of a dictionary, for reproducibility
items = sorted(self.param_distributions.items())
for _ in six.moves.range(self.n_iter):
params = dict()
for k, v in items:
if hasattr(v, "rvs"):
if sp_version < (0, 16):
params[k] = v.rvs()
else:
params[k] = v.rvs(random_state=rnd)
else:
params[k] = v[rnd.randint(len(v))]
yield params
def __len__(self):
"""Number of points that will be sampled."""
return self.n_iter
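The two sampling modes of `ParameterSampler` can be sketched with the standard library's `random` module instead of NumPy and SciPy. This is an illustration under stated assumptions. `sample_params` is a hypothetical name, and a callable `dist(rng)` stands in for a `scipy.stats` distribution with an `rvs` method.

- **All values are lists:** distinct grid indices are drawn without replacement and decoded the way `ParameterGrid.__getitem__` does it.
- **Otherwise:** each parameter is drawn independently, so duplicate settings can occur.

```python
import random
import warnings


def sample_params(param_distributions, n_iter, seed=None):
    """Yield n_iter parameter dicts (hypothetical stand-in for ParameterSampler).

    List values are sampled uniformly. As an assumption for this sketch,
    a callable ``dist(rng)`` plays the role of a scipy.stats distribution.
    """
    rng = random.Random(seed)
    keys = sorted(param_distributions)  # sorted for reproducibility
    values = [param_distributions[k] for k in keys]
    if all(not callable(v) for v in values):
        # Pure grid: draw distinct grid indices, i.e. sample without replacement.
        total = 1
        for v in values:
            total *= len(v)
        if total < n_iter:
            warnings.warn('The grid has only %d points; running %d iterations.'
                          % (total, total))
            n_iter = total
        for ind in rng.sample(range(total), n_iter):
            out = {}
            # Decode the index with the last key cycling fastest.
            for k, v in reversed(list(zip(keys, values))):
                ind, offset = divmod(ind, len(v))
                out[k] = v[offset]
            yield out
    else:
        # At least one distribution: draw every parameter independently, so
        # identical settings may repeat (sampling with replacement).
        for _ in range(n_iter):
            yield {k: v(rng) if callable(v) else rng.choice(v)
                   for k, v in zip(keys, values)}


print(list(sample_params({'a': [1, 2], 'b': ['x', 'y']}, n_iter=3, seed=0)))
```

`random.sample` over a `range` avoids building the full grid, in the same spirit as `sample_without_replacement` above. The drawn values will not match scikit-learn's for the same seed, because the random number generators differ.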
def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
verbose, error_score='raise-deprecating', **fit_params):
"""Run fit on one set of parameters.
Parameters
----------
X : array-like, sparse matrix or list
Input data.
y : array-like or None
Targets for input data.
estimator : estimator object
An object of that type is instantiated for each grid point.
This is assumed to implement the scikit-learn estimator interface.
Either estimator needs to provide a ``score`` function,
or ``scoring`` must be passed.
parameters : dict
Parameters to be set on estimator for this grid point.
train : ndarray, dtype int or bool
Boolean mask or indices for training set.
test : ndarray, dtype int or bool
Boolean mask or indices for test set.
scorer : callable or None
The scorer callable object / function must have its signature as
``scorer(estimator, X, y)``.
If ``None`` the estimator's default scorer is used.
verbose : int
Verbosity level.
error_score : 'raise' or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised. If a numeric value is given,
a FitFailedWarning is issued. This parameter does not affect the refit
step, which will always raise the error. The default,
'raise-deprecating', raises the error like 'raise'; from version 0.22
the default will change to np.nan.
**fit_params : kwargs
Additional parameters passed to the fit function of the estimator.
Returns
-------
score : float
Score of this parameter setting on given training / test split.
parameters : dict
The parameters that have been evaluated.
n_samples_test : int
Number of test samples in this split.
"""
# NOTE: we do not use the return value, as the scorer itself should have
# been validated beforehand. We call check_scoring only to reject
# multimetric scorers.
check_scoring(estimator, scorer)
scores, n_samples_test = _fit_and_score(estimator, X, y,
scorer, train,
test, verbose, parameters,
fit_params=fit_params,
return_n_test_samples=True,
error_score=error_score)
return scores, parameters, n_samples_test
def _check_param_grid(param_grid):
if hasattr(param_grid, 'items'):
param_grid = [param_grid]
for p in param_grid:
for name, v in p.items():
if isinstance(v, np.ndarray) and v.ndim > 1:
raise ValueError("Parameter array should be one-dimensional.")
if (isinstance(v, six.string_types) or
not isinstance(v, (np.ndarray, Sequence))):
raise ValueError("Parameter values for parameter ({0}) need "
"to be a sequence (but not a string) or"
" np.ndarray.".format(name))
if len(v) == 0:
raise ValueError("Parameter values for parameter ({0}) need "
"to be a non-empty sequence.".format(name))
# XXX Remove in 0.20
class _CVScoreTuple (namedtuple('_CVScoreTuple',
('parameters',
'mean_validation_score',
'cv_validation_scores'))):
# A raw namedtuple is very memory efficient: it packs the attributes
# in a struct and has no per-instance __dict__; in particular it
# does not copy the string for the keys on each instance.
# Deriving a namedtuple subclass just to add a __repr__ method would
# reintroduce the __dict__ on each instance. We avoid this by telling
# the Python interpreter that this subclass uses static __slots__ instead
# of dynamic attributes. We don't need any additional slot in the
# subclass, so we set __slots__ to the empty tuple.
__slots__ = ()
def __repr__(self):
"""Simple custom repr to summarize the main info"""
return "mean: {0:.5f}, std: {1:.5f}, params: {2}".format(
self.mean_validation_score,
np.std(self.cv_validation_scores),
self.parameters)
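The memory argument in the comment above can be checked directly. The sketch defines two subclasses of the same namedtuple, both adding only a `__repr__`; the class names are illustrative. The subclass without `__slots__` gains a per-instance `__dict__`. The one with `__slots__ = ()` stays as compact as the raw namedtuple and rejects new attributes.

```python
from collections import namedtuple

Base = namedtuple('Base', ('parameters', 'mean_validation_score'))


class WithDict(Base):
    # No __slots__: every instance gets its own __dict__.
    def __repr__(self):
        return 'WithDict(%r)' % (self.parameters,)


class Slotted(Base):
    # Empty __slots__: no __dict__, same layout as the raw namedtuple.
    __slots__ = ()

    def __repr__(self):
        return 'Slotted(%r)' % (self.parameters,)


a = WithDict({'C': 1}, 0.9)
b = Slotted({'C': 1}, 0.9)
print(hasattr(a, '__dict__'), hasattr(b, '__dict__'))  # True False
```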
class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
MetaEstimatorMixin)):
"""Abstract base class for hyperparameter search with cross-validation.
"""
@abstractmethod
def __init__(self, estimator, scoring=None,
fit_params=None, n_jobs=None, iid='warn',
refit=True, cv='warn', verbose=0, pre_dispatch='2*n_jobs',
error_score='raise-deprecating', return_train_score=True):
self.scoring = scoring
self.estimator = estimator
self.n_jobs = n_jobs
self.fit_params = fit_params
self.iid = iid
self.refit = refit
self.cv = cv
self.verbose = verbose
self.pre_dispatch = pre_dispatch
self.error_score = error_score
self.return_train_score = return_train_score
@property
def _estimator_type(self):
return self.estimator._estimator_type
def score(self, X, y=None):
"""Returns the score on the given data, if the estimator has been refit.
This uses the score defined by ``scoring`` where provided, and the
``best_estimator_.score`` method otherwise.
Parameters
----------
X : array-like, shape = [n_samples, n_features]
Input data, where n_samples is the number of samples and
n_features is the number of features.
y : array-like, shape = [n_samples] or [n_samples, n_output], optional
Target relative to X for classification or regression;
None for unsupervised learning.
Returns
-------
score : float
"""
self._check_is_fitted('score')
if self.scorer_ is None:
raise ValueError("No score function explicitly defined, "
"and the estimator doesn't provide one %s"
% self.best_estimator_)
score = self.scorer_[self.refit] if self.multimetric_ else self.scorer_
return score(self.best_estimator_, X, y)
def _check_is_fitted(self, method_name):
if not self.refit:
raise NotFittedError('This %s instance was initialized '
'with refit=False. %s is '
'available only after refitting on the best '
'parameters. You can refit an estimator '
'manually using the ``best_params_`` '
'attribute'
% (type(self).__name__, method_name))
else:
check_is_fitted(self, 'best_estimator_')
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def predict(self, X):
"""Call predict on the estimator with the best found parameters.
Only available if ``refit=True`` and the underlying estimator supports
``predict``.
Parameters
----------
X : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.
"""
self._check_is_fitted('predict')
return self.best_estimator_.predict(X)
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def predict_proba(self, X):
"""Call predict_proba on the estimator with the best found parameters.
Only available if ``refit=True`` and the underlying estimator supports
``predict_proba``.
Parameters
----------
X : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.
"""
self._check_is_fitted('predict_proba')
return self.best_estimator_.predict_proba(X)
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def predict_log_proba(self, X):
"""Call predict_log_proba on the estimator with the best found parameters.
Only available if ``refit=True`` and the underlying estimator supports
``predict_log_proba``.
Parameters
----------
X : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.
"""
self._check_is_fitted('predict_log_proba')
return self.best_estimator_.predict_log_proba(X)
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def decision_function(self, X):
"""Call decision_function on the estimator with the best found parameters.
Only available if ``refit=True`` and the underlying estimator supports
``decision_function``.
Parameters
----------
X : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.
"""
self._check_is_fitted('decision_function')
return self.best_estimator_.decision_function(X)
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def transform(self, X):
"""Call transform on the estimator with the best found parameters.
Only available if the underlying estimator supports ``transform`` and
``refit=True``.
Parameters
----------
X : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.
"""
self._check_is_fitted('transform')
return self.best_estimator_.transform(X)
@if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
def inverse_transform(self, Xt):
"""Call inverse_transform on the estimator with the best found params.
Only available if the underlying estimator implements
``inverse_transform`` and ``refit=True``.
Parameters
----------
Xt : indexable, length n_samples
Must fulfill the input assumptions of the
underlying estimator.
"""
self._check_is_fitted('inverse_transform')
return self.best_estimator_.inverse_transform(Xt)
@property
def classes_(self):
self._check_is_fitted("classes_")
return self.best_estimator_.classes_
def _run_search(self, evaluate_candidates):
"""Repeatedly calls `evaluate_candidates` to conduct a search.
This method, implemented in sub-classes, makes it possible to
customize the scheduling of evaluations: GridSearchCV and
RandomizedSearchCV schedule evaluations for their whole parameter
search space at once but other more sequential approaches are also
possible: for instance it is possible to iteratively schedule evaluations
for new regions of the parameter search space based on previously
collected evaluation results. This makes it possible to implement
Bayesian optimization or more generally sequential model-based
optimization by deriving from the BaseSearchCV abstract base class.
Parameters
----------
evaluate_candidates : callable
This callback accepts a list of candidates, where each candidate is
a dict of parameter settings. It returns a dict of all results so
far, formatted like ``cv_results_``.
Examples
--------
::
def _run_search(self, evaluate_candidates):
'Try C=0.1 only if C=1 is better than C=10'
all_results = evaluate_candidates([{'C': 1}, {'C': 10}])
score = all_results['mean_test_score']
if score[0] > score[1]:
evaluate_candidates([{'C': 0.1}])
"""
raise NotImplementedError("_run_search not implemented.")
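The callback protocol described in `_run_search` can be exercised without any cross-validation. In this sketch, `make_evaluator` is a hypothetical stand-in for the `evaluate_candidates` closure built in `fit`. Instead of fitting an estimator, it scores each candidate with a user-supplied function. Like the real callback, it returns all results collected so far. Scores follow the scikit-learn convention that higher is better.

```python
def make_evaluator(score_fn):
    """Return a hypothetical evaluate_candidates callback and its results."""
    results = {'params': [], 'mean_test_score': []}

    def evaluate_candidates(candidate_params):
        for params in candidate_params:
            results['params'].append(params)
            results['mean_test_score'].append(score_fn(params))
        # Like the real callback, return all results collected so far.
        return results

    return evaluate_candidates, results


def run_search(evaluate_candidates):
    """Try C=0.1 only if C=1 scores higher than C=10 (higher is better)."""
    all_results = evaluate_candidates([{'C': 1}, {'C': 10}])
    score = all_results['mean_test_score']
    if score[0] > score[1]:
        evaluate_candidates([{'C': 0.1}])


evaluate, results = make_evaluator(lambda p: -abs(p['C'] - 1))
run_search(evaluate)
print([p['C'] for p in results['params']])  # [1, 10, 0.1]
```

Because the callback accumulates results across calls, a subclass can decide each new batch from everything evaluated so far. This is the hook for sequential strategies.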
def fit(self, X, y=None, groups=None, **fit_params):
"""Run fit with all sets of parameters.
Parameters
----------
X : array-like, shape = [n_samples, n_features]
Training vector, where n_samples is the number of samples and
n_features is the number of features.
y : array-like, shape = [n_samples] or [n_samples, n_output], optional
Target relative to X for classification or regression;
None for unsupervised learning.
groups : array-like, with shape (n_samples,), optional
Group labels for the samples used while splitting the dataset into
train/test set.
**fit_params : dict of string -> object
Parameters passed to the ``fit`` method of the estimator
"""
if self.fit_params is not None:
warnings.warn('"fit_params" as a constructor argument was '
'deprecated in version 0.19 and will be removed '
'in version 0.21. Pass fit parameters to the '
'"fit" method instead.', DeprecationWarning)
if fit_params:
warnings.warn('Ignoring fit_params passed as a constructor '
'argument in favor of keyword arguments to '
'the "fit" method.', RuntimeWarning)
else:
fit_params = self.fit_params
estimator = self.estimator
cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
scorers, self.multimetric_ = _check_multimetric_scoring(
self.estimator, scoring=self.scoring)
if self.multimetric_:
if self.refit is not False and (
not isinstance(self.refit, six.string_types) or
# This will work for both dict / list (tuple)
self.refit not in scorers):
raise ValueError("For multi-metric scoring, the parameter "
"refit must be set to a scorer key "
"to refit an estimator with the best "
"parameter setting on the whole data and "
"make the best_* attributes "
"available for that metric. If this is not "
"needed, refit should be set to False "
"explicitly. %r was passed." % self.refit)
else:
refit_metric = self.refit
else:
refit_metric = 'score'
X, y, groups = indexable(X, y, groups)
n_splits = cv.get_n_splits(X, y, groups)
base_estimator = clone(self.estimator)
parallel = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
pre_dispatch=self.pre_dispatch)
fit_and_score_kwargs = dict(scorer=scorers,
fit_params=fit_params,
return_train_score=self.return_train_score,
return_n_test_samples=True,
return_times=True,
return_parameters=False,
error_score=self.error_score,
verbose=self.verbose)
results_container = [{}]
with parallel:
all_candidate_params = []
all_out = []
def evaluate_candidates(candidate_params):
candidate_params = list(candidate_params)
n_candidates = len(candidate_params)
if self.verbose > 0:
print("Fitting {0} folds for each of {1} candidates,"
" totalling {2} fits".format(
n_splits, n_candidates, n_candidates * n_splits))
out = parallel(delayed(_fit_and_score)(clone(base_estimator),
X, y,
train=train, test=test,
parameters=parameters,
**fit_and_score_kwargs)
for parameters, (train, test)
in product(candidate_params,
cv.split(X, y, groups)))
all_candidate_params.extend(candidate_params)
all_out.extend(out)
# XXX: When we drop Python 2 support, we can use nonlocal
# instead of results_container
results_container[0] = self._format_results(
all_candidate_params, scorers, n_splits, all_out)
return results_container[0]
self._run_search(evaluate_candidates)
results = results_container[0]
# For multi-metric evaluation, store the best_index_, best_params_ and
# best_score_ iff refit is one of the scorer names
# In single metric evaluation, refit_metric is "score"
if self.refit or not self.multimetric_:
self.best_index_ = results["rank_test_%s" % refit_metric].argmin()
self.best_params_ = results["params"][self.best_index_]
self.best_score_ = results["mean_test_%s" % refit_metric][
self.best_index_]
if self.refit:
self.best_estimator_ = clone(base_estimator).set_params(
**self.best_params_)
refit_start_time = time.time()
if y is not None:
self.best_estimator_.fit(X, y, **fit_params)
else:
self.best_estimator_.fit(X, **fit_params)
refit_end_time = time.time()
self.refit_time_ = refit_end_time - refit_start_time
# Store the only scorer not as a dict for single metric evaluation
self.scorer_ = scorers if self.multimetric_ else scorers['score']
self.cv_results_ = results
self.n_splits_ = n_splits
return self
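`fit` selects `best_index_` as the position of the smallest value in `rank_test_<metric>`. `_format_results` builds those ranks with `rankdata(-means, method='min')`. The pure-Python sketch below shows that convention; the helper names are hypituated for illustration only:

- the highest mean gets rank 1;
- tied means share the smallest rank;
- as with `ndarray.argmin`, the first tied candidate wins.

```python
def min_rank_descending(means):
    """Rank scores so the highest gets 1; ties share the smallest rank.

    Mirrors scipy.stats.rankdata(-means, method='min'); O(n^2) for clarity.
    """
    return [1 + sum(other > m for other in means) for m in means]


def best_index(means):
    ranks = min_rank_descending(means)
    # Like ndarray.argmin: the first position holding the minimum rank wins.
    return ranks.index(min(ranks))


means = [0.80, 0.92, 0.92, 0.75]
print(min_rank_descending(means))  # [3, 1, 1, 4]
print(best_index(means))           # 1
```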
def _format_results(self, candidate_params, scorers, n_splits, out):
n_candidates = len(candidate_params)
# if one chooses to see the train score, "out" will contain train score info
if self.return_train_score:
(train_score_dicts, test_score_dicts, test_sample_counts, fit_time,
score_time) = zip(*out)
else:
(test_score_dicts, test_sample_counts, fit_time,
score_time) = zip(*out)
# test_score_dicts and train_score_dicts are lists of dictionaries and
# we make them into dicts of lists
test_scores = _aggregate_score_dicts(test_score_dicts)
if self.return_train_score:
train_scores = _aggregate_score_dicts(train_score_dicts)
# TODO: replace by a dict in 0.21
results = (DeprecationDict() if self.return_train_score == 'warn'
else {})
def _store(key_name, array, weights=None, splits=False, rank=False):
"""A small helper to store the scores/times in the cv_results_"""
# `out` is ordered by candidate (outer loop), then by split (inner
# loop), so `array` reshapes to `n_candidates` rows and `n_splits` cols.
array = np.array(array, dtype=np.float64).reshape(n_candidates,
n_splits)
if splits:
for split_i in range(n_splits):
# Uses closure to alter the results
results["split%d_%s"
% (split_i, key_name)] = array[:, split_i]
array_means = np.average(array, axis=1, weights=weights)
results['mean_%s' % key_name] = array_means
# Weighted std is not directly available in numpy
array_stds = np.sqrt(np.average((array -
array_means[:, np.newaxis]) ** 2,
axis=1, weights=weights))
results['std_%s' % key_name] = array_stds
if rank:
results["rank_%s" % key_name] = np.asarray(
rankdata(-array_means, method='min'), dtype=np.int32)
_store('fit_time', fit_time)
_store('score_time', score_time)
# Use one MaskedArray and mask all the places where the param is not
# applicable for that candidate. Use defaultdict as each candidate may
# not contain all the params
param_results = defaultdict(partial(MaskedArray,
np.empty(n_candidates,),
mask=True,
dtype=object))
for cand_i, params in enumerate(candidate_params):
for name, value in params.items():
# An all masked empty array gets created for the key
# `"param_%s" % name` at the first occurrence of `name`.
# Setting the value at an index also unmasks that index
param_results["param_%s" % name][cand_i] = value
results.update(param_results)
# Store a list of param dicts at the key 'params'
results['params'] = candidate_params
# NOTE test_sample counts (weights) remain the same for all candidates
test_sample_counts = np.array(test_sample_counts[:n_splits],
dtype=np.int)
iid = self.iid
if self.iid == 'warn':
warn = False
for scorer_name in scorers.keys():
scores = test_scores[scorer_name].reshape(n_candidates,
n_splits)
means_weighted = np.average(scores, axis=1,
weights=test_sample_counts)
means_unweighted = np.average(scores, axis=1)
if not np.allclose(means_weighted, means_unweighted,
rtol=1e-4, atol=1e-4):
warn = True
break
if warn:
warnings.warn("The default of the `iid` parameter will change "
"from True to False in version 0.22 and will be"
" removed in 0.24. This will change numeric"
" results when test-set sizes are unequal.",
DeprecationWarning)
iid = True
for scorer_name in scorers.keys():
# Compute the (weighted) mean and std for test scores alone
_store('test_%s' % scorer_name, test_scores[scorer_name],
splits=True, rank=True,
weights=test_sample_counts if iid else None)
if self.return_train_score:
prev_keys = set(results.keys())
_store('train_%s' % scorer_name, train_scores[scorer_name],
splits=True)
if self.return_train_score == 'warn':
for key in set(results.keys()) - prev_keys:
message = (
'You are accessing a training score ({!r}), '
'which will not be available by default '
'any more in 0.21. If you need training scores, '
'please set return_train_score=True').format(key)
# warn on key access
results.add_warning(key, message, FutureWarning)
return results
class GridSearchCV(BaseSearchCV):
"""Exhaustive search over specified parameter values for an estimator.
Important members are fit, predict.
GridSearchCV implements a "fit" and a "score" method.
It also implements "predict", "predict_proba", "decision_function",
"transform" and "inverse_transform" if they are implemented in the
estimator used.
The parameters of the estimator used to apply these methods are optimized
by cross-validated grid-search over a parameter grid.
Read more in the :ref:`User Guide <grid_search>`.
Parameters
----------
estimator : estimator object.
This is assumed to implement the scikit-learn estimator interface.
Either estimator needs to provide a ``score`` function,
or ``scoring`` must be passed.
param_grid : dict or list of dictionaries
Dictionary with parameter names (string) as keys and lists of
parameter settings to try as values, or a list of such
dictionaries, in which case the grids spanned by each dictionary
in the list are explored. This enables searching over any sequence
of parameter settings.
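The sketch below shows how a list of dictionaries expands into candidates. Each dictionary spans its own Cartesian product, and the results are concatenated. The helper ``expand_param_grid`` is hypothetical and uses only :func:`itertools.product`. It follows the behavior described here, not the actual :class:`ParameterGrid` code.

```python
from itertools import product


def expand_param_grid(param_grid):
    # Accept a single dict or a list of dicts, as param_grid does.
    if isinstance(param_grid, dict):
        param_grid = [param_grid]
    candidates = []
    for grid in param_grid:
        # Sort the keys so the enumeration order is reproducible.
        keys = sorted(grid)
        # Each dict spans its own Cartesian product of values.
        for values in product(*(grid[key] for key in keys)):
            candidates.append(dict(zip(keys, values)))
    return candidates


param_grid = [
    {'kernel': ['linear'], 'C': [1, 10]},
    {'kernel': ['rbf'], 'C': [1, 10], 'gamma': [0.1, 0.2]},
]
candidates = expand_param_grid(param_grid)
print(len(candidates))  # 2 from the first dict + 4 from the second = 6
```

Only candidates from the second dictionary carry a ``gamma`` value. This is why ``cv_results_`` stores each parameter as a masked array.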
scoring : string, callable, list/tuple, dict or None, default: None
A single string (see :ref:`scoring_parameter`) or a callable
(see :ref:`scoring`) to evaluate the predictions on the test set.
For evaluating multiple metrics, either give a list of (unique) strings
or a dict with names as keys and callables as values.
NOTE that when using custom scorers, each scorer should return a single
value. Metric functions returning a list/array of values can be wrapped
into multiple scorers that return one value each.
See :ref:`multimetric_grid_search` for an example.
If None, the estimator's default scorer (if available) is used.
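The following is a minimal sketch of the wrapping mentioned above. It assumes the scorer calling convention ``scorer(estimator, X, y)``, which returns a single float. Under that convention, an array-valued metric can be split into one scorer per output and collected into a ``scoring`` dict. ``per_class_recall``, ``make_class_scorer`` and ``FixedPredictor`` are hypothetical names used only for this illustration.

```python
def per_class_recall(y_true, y_pred, labels):
    # Hypothetical array-valued metric: one recall value per label.
    recalls = []
    for label in labels:
        relevant = [p for t, p in zip(y_true, y_pred) if t == label]
        hits = sum(1 for p in relevant if p == label)
        recalls.append(hits / len(relevant) if relevant else 0.0)
    return recalls


def make_class_scorer(labels, index):
    # Wrap the metric so that each scorer returns a single float.
    def scorer(estimator, X, y):
        y_pred = estimator.predict(X)
        return per_class_recall(y, y_pred, labels)[index]
    return scorer


class FixedPredictor:
    # Hypothetical stand-in estimator returning precomputed predictions.
    def __init__(self, predictions):
        self.predictions = predictions

    def predict(self, X):
        return self.predictions


labels = [0, 1]
scoring = {'recall_%d' % label: make_class_scorer(labels, i)
           for i, label in enumerate(labels)}
est = FixedPredictor([0, 1, 1, 1])
y = [0, 0, 1, 1]
scores = {name: scorer(est, None, y) for name, scorer in scoring.items()}
print(scores)  # {'recall_0': 0.5, 'recall_1': 1.0}
```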
fit_params : dict, optional
Parameters to pass to the fit method.
.. deprecated:: 0.19
``fit_params`` as a constructor argument was deprecated in version
0.19 and will be removed in version 0.21. Pass fit parameters to
the ``fit`` method instead.
n_jobs : int or None, optional (default=None)
Number of jobs to run in parallel.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.
pre_dispatch : int, or string, optional
Controls the number of jobs that get dispatched during parallel
execution. Reducing this number can be useful to avoid an
explosion of memory consumption when more jobs get dispatched
than CPUs can process. This parameter can be:
- None, in which case all the jobs are immediately
created and spawned. Use this for lightweight and
fast-running jobs, to avoid delays due to on-demand
spawning of the jobs
- An int, giving the exact number of total jobs that are
spawned
- A string, giving an expression as a function of n_jobs,
as in '2*n_jobs'
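The sketch below resolves the three forms of ``pre_dispatch`` to a number of dispatched jobs. It also applies the ``n_jobs`` conventions above: ``None`` counts as 1 and -1 means all CPUs. Other negative values are counted back from the CPU count; this is an assumption taken from the ``n_jobs`` glossary entry. All helpers are illustrative and are not joblib's implementation. The string form is evaluated with a small :mod:`ast`-based evaluator that only allows integers, the name ``n_jobs`` and the operators ``+``, ``-``, ``*``, ``//``.

```python
import ast
import operator
import os

# Operators allowed in a pre_dispatch expression string.
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.FloorDiv: operator.floordiv}


def _eval_expr(node, n_jobs):
    # Evaluate a tree made only of integers, the name n_jobs and the
    # operators in _OPS; anything else is rejected.
    if isinstance(node, ast.Expression):
        return _eval_expr(node.body, n_jobs)
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.Name) and node.id == 'n_jobs':
        return n_jobs
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval_expr(node.left, n_jobs),
                                   _eval_expr(node.right, n_jobs))
    raise ValueError('unsupported pre_dispatch expression')


def resolve_n_jobs(n_jobs):
    # None means 1 here (no parallel_backend context is modelled);
    # negative values count back from the CPU count, so -1 means all CPUs.
    if n_jobs is None:
        return 1
    if n_jobs < 0:
        return max(1, (os.cpu_count() or 1) + 1 + n_jobs)
    return n_jobs


def resolve_pre_dispatch(pre_dispatch, n_jobs):
    # None: every job is dispatched immediately, so there is no limit.
    if pre_dispatch is None:
        return None
    # int: the exact number of jobs dispatched up front.
    if isinstance(pre_dispatch, int):
        return pre_dispatch
    # str: an expression in n_jobs, such as '2*n_jobs'.
    tree = ast.parse(pre_dispatch, mode='eval')
    return _eval_expr(tree, resolve_n_jobs(n_jobs))


print(resolve_pre_dispatch('2*n_jobs', 4))  # 8
```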
iid : boolean, default='warn'
If True, return the average score across folds, weighted by the number
of samples in each test set. In this case, the data is assumed to be
identically distributed across the folds, and the loss minimized is
the total loss per sample, and not the mean loss across the folds. If
False, return the average score across folds. Default is True, but
will change to False in version 0.22, to correspond to the standard
definition of cross-validation.
.. versionchanged:: 0.20
Parameter ``iid`` will change from True to False by default in
version 0.22, and will be removed in 0.24.
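The sketch below makes the difference concrete. It averages two fold scores with and without weighting by test-set size, which is what ``np.average`` computes with and without its ``weights`` argument. The helper ``fold_mean`` is hypothetical. With folds of 100 and 50 test samples scoring 0.8 and 0.6, the weighted mean is (0.8 * 100 + 0.6 * 50) / 150, about 0.733. The unweighted mean is 0.7.

```python
def fold_mean(scores, test_sizes=None):
    # With test_sizes (iid=True), weight each fold by its test-set size;
    # without them (iid=False), take the plain average over folds.
    if test_sizes is None:
        return sum(scores) / len(scores)
    weighted_sum = sum(s * n for s, n in zip(scores, test_sizes))
    return weighted_sum / sum(test_sizes)


scores = [0.8, 0.6]      # score obtained on each of two folds
test_sizes = [100, 50]   # unequal numbers of test samples
print(fold_mean(scores, test_sizes))  # about 0.733
print(fold_mean(scores))              # about 0.7
```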
cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
- None, to use the default 3-fold cross validation,
- integer, to specify the number of folds in a `(Stratified)KFold`,
- :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the estimator is a classifier and ``y`` is
either binary or multiclass, :class:`StratifiedKFold` is used. In all
other cases, :class:`KFold` is used.
Refer to the :ref:`User Guide <cross_validation>` for the various
cross-validation strategies that can be used here.
.. versionchanged:: 0.20
``cv`` default value if None will change from 3-fold to 5-fold
in v0.22.
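For the iterable option, each element is a ``(train, test)`` pair of index sequences. The following is a minimal sketch of unshuffled, contiguous folds in the spirit of :class:`KFold`. The generator name is hypothetical, and it yields Python lists rather than NumPy arrays.

```python
def contiguous_kfold(n_samples, n_splits):
    # Yield (train, test) index lists; the first n_samples % n_splits
    # test folds get one extra sample.
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        stop = start + base + (1 if fold < extra else 0)
        test = list(range(start, stop))
        train = list(range(0, start)) + list(range(stop, n_samples))
        yield train, test
        start = stop


splits = list(contiguous_kfold(10, 3))
for train, test in splits:
    print(len(train), test)
```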
refit : boolean, or string, default=True
Refit an estimator using the best found parameters on the whole
dataset.
For multiple metric evaluation, this needs to be a string denoting the
scorer that would be used to find the best parameters for refitting
the estimator at the end.
The refitted estimator is made available at the ``best_estimator_``
attribute and permits using ``predict`` directly on this
``GridSearchCV`` instance.
Also for multiple metric evaluation, the attributes ``best_index_``,
``best_score_`` and ``best_params_`` will only be available if
``refit`` is set and all of them will be determined w.r.t this specific
scorer.
See ``scoring`` parameter to know more about multiple metric
evaluation.
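The sketch below shows how a string ``refit`` selects the best candidate in multi-metric evaluation, using made-up ``cv_results_``-style values. The candidate ranked first for the named scorer determines ``best_index_``, ``best_params_`` and ``best_score_``. The function ``select_best`` is hypothetical.

```python
def select_best(cv_results, refit, scorer_names):
    # With several scorers, refit must name one of them; the best
    # candidate is the one ranked first for that scorer.
    if refit not in scorer_names:
        raise ValueError('refit must be one of %r' % (scorer_names,))
    ranks = cv_results['rank_test_%s' % refit]
    best_index = min(range(len(ranks)), key=ranks.__getitem__)
    return (best_index,
            cv_results['params'][best_index],
            cv_results['mean_test_%s' % refit][best_index])


cv_results = {  # illustrative values only
    'params': [{'C': 1}, {'C': 10}, {'C': 100}],
    'mean_test_accuracy': [0.90, 0.93, 0.91],
    'rank_test_accuracy': [3, 1, 2],
    'mean_test_recall': [0.85, 0.80, 0.88],
    'rank_test_recall': [2, 3, 1],
}
print(select_best(cv_results, 'recall', ['accuracy', 'recall']))
# (2, {'C': 100}, 0.88)
```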
verbose : integer
Controls the verbosity: the higher, the more messages.
error_score : 'raise' or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised. If a numeric value is given,
FitFailedWarning is raised. This parameter does not affect the refit
step, which will always raise the error. Default is 'raise' but from
version 0.22 it will change to np.nan.
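The sketch below illustrates the two ``error_score`` behaviors for a single fit. The helper ``fit_and_score`` is hypothetical, and fitting and scoring are passed in as plain callables. The warning class ``FitFailedWarningStandIn`` is also hypothetical and stands in for ``sklearn.exceptions.FitFailedWarning``.

```python
import warnings


class FitFailedWarningStandIn(UserWarning):
    # Hypothetical stand-in for sklearn.exceptions.FitFailedWarning.
    pass


def fit_and_score(fit, score, error_score='raise'):
    # Run one fit/score step; on failure either re-raise the error or
    # warn and return error_score in place of a real score.
    try:
        fit()
    except Exception as exc:
        if error_score == 'raise':
            raise
        warnings.warn('Estimator fit failed; score set to %r. Details: %r'
                      % (error_score, exc), FitFailedWarningStandIn)
        return error_score
    return score()


def failing_fit():
    raise ValueError('bad parameter combination')


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    result = fit_and_score(failing_fit, lambda: 0.9, error_score=float('nan'))
print(result, caught[0].category.__name__)
```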
return_train_score : boolean, optional
If ``False``, the ``cv_results_`` attribute will not include training
scores.
Current default is ``'warn'``, which behaves as ``True`` in addition
to raising a warning when a training score is looked up.
That default will be changed to ``False`` in 0.21.
Computing training scores is used to get insights on how different
parameter settings impact the overfitting/underfitting trade-off.
However, computing the scores on the training set can be computationally
expensive and is not strictly required to select the parameters that
yield the best generalization performance.
Examples
--------
>>> from sklearn import svm, datasets
>>> from sklearn.model_selection import GridSearchCV
>>> iris = datasets.load_iris()
>>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
>>> svc = svm.SVC(gamma="scale")
>>> clf = GridSearchCV(svc, parameters, cv=5)
>>> clf.fit(iris.data, iris.target)
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
GridSearchCV(cv=5, error_score=...,
estimator=SVC(C=1.0, cache_size=..., class_weight=..., coef0=...,
decision_function_shape='ovr', degree=..., gamma=...,
kernel='rbf', max_iter=-1, probability=False,
random_state=None, shrinking=True, tol=...,
verbose=False),
fit_params=None, iid=..., n_jobs=None,
param_grid=..., pre_dispatch=..., refit=..., return_train_score=...,
scoring=..., verbose=...)
>>> sorted(clf.cv_results_.keys())
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
['mean_fit_time', 'mean_score_time', 'mean_test_score',...
'mean_train_score', 'param_C', 'param_kernel', 'params',...
'rank_test_score', 'split0_test_score',...
'split0_train_score', 'split1_test_score', 'split1_train_score',...
'split2_test_score', 'split2_train_score',...
'std_fit_time', 'std_score_time', 'std_test_score', 'std_train_score'...]
Attributes
----------
cv_results_ : dict of numpy (masked) ndarrays
A dict with keys as column headers and values as columns, that can be
imported into a pandas ``DataFrame``.
For instance, the table below
+------------+-----------+------------+-----------------+---+---------+
|param_kernel|param_gamma|param_degree|split0_test_score|...|rank_t...|
+============+===========+============+=================+===+=========+
|  'poly'    |     --    |      2     |       0.80      |...|    2    |
+------------+-----------+------------+-----------------+---+---------+
|  'poly'    |     --    |      3     |       0.70      |...|    4    |
+------------+-----------+------------+-----------------+---+---------+
|  'rbf'     |     0.1   |     --     |       0.80      |...|    3    |
+------------+-----------+------------+-----------------+---+---------+
|  'rbf'     |     0.2   |     --     |       0.93      |...|    1    |
+------------+-----------+------------+-----------------+---+---------+
will be represented by a ``cv_results_`` dict of::
{
'param_kernel': masked_array(data = ['poly', 'poly', 'rbf', 'rbf'],
mask = [False False False False]...),
'param_gamma': masked_array(data = [-- -- 0.1 0.2],
mask = [ True True False False]...),
'param_degree': masked_array(data = [2.0 3.0 -- --],
mask = [False False True True]...),
'split0_test_score' : [0.80, 0.70, 0.80, 0.93],
'split1_test_score' : [0.82, 0.50, 0.70, 0.78],
'mean_test_score' : [0.81, 0.60, 0.75, 0.85],
'std_test_score' : [0.01, 0.10, 0.05, 0.08],
'rank_test_score' : [2, 4, 3, 1],
'split0_train_score' : [0.80, 0.92, 0.70, 0.93],
'split1_train_score' : [0.82, 0.55, 0.70, 0.87],
'mean_train_score' : [0.81, 0.74, 0.70, 0.90],
'std_train_score' : [0.01, 0.19, 0.00, 0.03],
'mean_fit_time' : [0.73, 0.63, 0.43, 0.49],
'std_fit_time' : [0.01, 0.02, 0.01, 0.01],
'mean_score_time' : [0.01, 0.06, 0.04, 0.04],
'std_score_time' : [0.00, 0.00, 0.00, 0.01],
'params' : [{'kernel': 'poly', 'degree': 2}, ...],
}
NOTE
The key ``'params'`` is used to store a list of parameter
settings dicts for all the parameter candidates.
The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
``std_score_time`` are all in seconds.
For multi-metric evaluation, the scores for all the scorers are
available in the ``cv_results_`` dict at the keys ending with that
scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown
above. ('split0_test_precision', 'mean_train_precision' etc.)
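The aggregate columns are derived from the per-split ones. They hold the mean (for test scores, weighted by test-set size when ``iid`` is True) and the population standard deviation. They also hold a rank in which tied candidates share the lowest rank. The standard-library sketch below rebuilds the unweighted ``mean_test_score`` and ``rank_test_score`` of the example above, up to rounding, from its two split columns. ``summarize`` is a hypothetical helper.

```python
from statistics import mean, pstdev


def summarize(split_columns):
    # split_columns: one list of per-candidate scores for each CV split.
    per_candidate = list(zip(*split_columns))
    means = [mean(scores) for scores in per_candidate]
    stds = [pstdev(scores) for scores in per_candidate]
    # 'min' ranking: tied candidates share the best (lowest) rank.
    ranks = [1 + sum(other > m for other in means) for m in means]
    return means, stds, ranks


split0 = [0.80, 0.70, 0.80, 0.93]
split1 = [0.82, 0.50, 0.70, 0.78]
means, stds, ranks = summarize([split0, split1])
print([round(m, 3) for m in means])  # [0.81, 0.6, 0.75, 0.855]
print(ranks)                         # [2, 4, 3, 1]
```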
best_estimator_ : estimator or dict
Estimator that was chosen by the search, i.e. estimator
which gave highest score (or smallest loss if specified)
on the left out data. Not available if ``refit=False``.
See ``refit`` parameter for more information on allowed values.
best_score_ : float
Mean cross-validated score of the best_estimator.
For multi-metric evaluation, this is present only if ``refit`` is
specified.
best_params_ : dict
Parameter setting that gave the best results on the hold out data.
For multi-metric evaluation, this is present only if ``refit`` is
specified.
best_index_ : int
The index (of the ``cv_results_`` arrays) which corresponds to the best
candidate parameter setting.
The dict at ``search.cv_results_['params'][search.best_index_]`` gives
the parameter setting for the best model, that gives the highest
mean score (``search.best_score_``).
For multi-metric evaluation, this is present only if ``refit`` is
specified.
scorer_ : function or a dict
Scorer function used on the held out data to choose the best
parameters for the model.
For multi-metric evaluation, this attribute holds the validated
``scoring`` dict which maps the scorer key to the scorer callable.
n_splits_ : int
The number of cross-validation splits (folds/iterations).
refit_time_ : float
Seconds used for refitting the best model on the whole dataset.
This is present only if ``refit`` is not False.
Notes
-----
The parameters selected are those that maximize the score of the left out
data, unless an explicit score is passed in which case it is used instead.
If `n_jobs` was set to a value higher than one, the data is copied for each
point in the grid (and not `n_jobs` times). This is done for efficiency
reasons if individual jobs take very little time, but may raise errors if
the dataset is large and not enough memory is available. A workaround in
this case is to set `pre_dispatch`. Then, the memory is copied only
`pre_dispatch` many times. A reasonable value for `pre_dispatch` is `2 *
n_jobs`.
See Also
--------
:class:`ParameterGrid`:
generates all the combinations of a hyperparameter grid.
:func:`sklearn.model_selection.train_test_split`:
utility function to split the data into a development set usable
for fitting a GridSearchCV instance and an evaluation set for
its final evaluation.
:func:`sklearn.metrics.make_scorer`:
Make a scorer from a performance metric or loss function.
"""
def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
n_jobs=None, iid='warn', refit=True, cv='warn', verbose=0,
pre_dispatch='2*n_jobs', error_score='raise-deprecating',
return_train_score="warn"):
super(GridSearchCV, self).__init__(
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score,
return_train_score=return_train_score)
self.param_grid = param_grid
_check_param_grid(param_grid)
def _run_search(self, evaluate_candidates):
"""Search all candidates in param_grid"""
evaluate_candidates(ParameterGrid(self.param_grid))
class RandomizedSearchCV(BaseSearchCV):
"""Randomized search on hyper parameters.
RandomizedSearchCV implements a "fit" and a "score" method.
It also implements "predict", "predict_proba", "decision_function",
"transform" and "inverse_transform" if they are implemented in the
estimator used.
The parameters of the estimator used to apply these methods are optimized
by cross-validated search over parameter settings.
In contrast to GridSearchCV, not all parameter values are tried out, but
rather a fixed number of parameter settings is sampled from the specified
distributions. The number of parameter settings that are tried is
given by n_iter.
If all parameters are presented as a list,
sampling without replacement is performed. If at least one parameter
is given as a distribution, sampling with replacement is used.
It is highly recommended to use continuous distributions for continuous
parameters.
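The two sampling modes can be sketched with the standard library as follows. ``sample_params`` and ``UniformStandIn`` are hypothetical. ``sample_params`` follows the rule described above, not the actual :class:`ParameterSampler` code. ``UniformStandIn`` only imitates an ``rvs(random_state=...)`` call on a ``scipy.stats`` distribution. Raising ``ValueError`` when ``n_iter`` exceeds the number of grid points is a choice made for this sketch.

```python
import random
from itertools import product


class UniformStandIn:
    # Hypothetical distribution exposing only rvs(random_state=...).
    def __init__(self, low, high):
        self.low, self.high = low, high

    def rvs(self, random_state=None):
        rng = random_state if random_state is not None else random
        return rng.uniform(self.low, self.high)


def sample_params(param_distributions, n_iter, seed=None):
    rng = random.Random(seed)
    keys = sorted(param_distributions)
    values = [param_distributions[k] for k in keys]
    if all(isinstance(v, list) for v in values):
        # Every parameter is a list: draw distinct grid points
        # (sampling without replacement).
        grid = list(product(*values))
        if n_iter > len(grid):
            raise ValueError('n_iter exceeds the number of grid points')
        return [dict(zip(keys, combo)) for combo in rng.sample(grid, n_iter)]
    # At least one distribution: draw every parameter independently
    # (sampling with replacement, so repeated settings are possible).
    return [{k: (rng.choice(v) if isinstance(v, list)
                 else v.rvs(random_state=rng))
             for k, v in zip(keys, values)}
            for _ in range(n_iter)]


grid_only = sample_params({'kernel': ['linear', 'rbf'], 'C': [1, 10]}, 4,
                          seed=0)
mixed = sample_params({'kernel': ['rbf'], 'gamma': UniformStandIn(0.01, 1.0)},
                      3, seed=0)
print(len({tuple(sorted(p.items())) for p in grid_only}))  # 4 distinct
```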
Note that before SciPy 0.16, the ``scipy.stats.distributions`` did not
accept a custom RNG instance and always used the singleton RNG from
``numpy.random``. Hence, with those SciPy versions, setting
``random_state`` does not guarantee a deterministic iteration whenever
``scipy.stats`` distributions are used to define the parameter search
space.
Read more in the :ref:`User Guide <randomized_parameter_search>`.
Parameters
----------
estimator : estimator object.
An object of that type is instantiated for each sampled parameter setting.
This is assumed to implement the scikit-learn estimator interface.
Either estimator needs to provide a ``score`` function,
or ``scoring`` must be passed.
param_distributions : dict
Dictionary with parameter names (string) as keys and distributions
or lists of parameters to try. Distributions must provide an ``rvs``
method for sampling (such as those from scipy.stats.distributions).
If a list is given, it is sampled uniformly.
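Any object with an ``rvs`` method can serve as a distribution. As a hypothetical example, a log-uniform distribution can be written with the standard library alone. Such a distribution is often convenient for scale parameters such as ``C`` or ``gamma``. The class name and its ``size``/``random_state`` keywords are assumptions modelled on the ``scipy.stats`` interface.

```python
import math
import random


class LogUniform:
    # Hypothetical distribution: exp of a uniform draw on
    # [log(low), log(high)], so every decade is equally likely.
    def __init__(self, low, high):
        if not 0 < low < high:
            raise ValueError('require 0 < low < high')
        self.log_low, self.log_high = math.log(low), math.log(high)

    def rvs(self, size=None, random_state=None):
        rng = random_state if random_state is not None else random.Random()

        def draw():
            return math.exp(rng.uniform(self.log_low, self.log_high))

        if size is None:
            return draw()
        return [draw() for _ in range(size)]


rng = random.Random(0)
samples = LogUniform(1e-3, 1e3).rvs(size=5, random_state=rng)
print(all(1e-3 <= s <= 1e3 for s in samples))  # True
```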
n_iter : int, default=10
Number of parameter settings that are sampled. n_iter trades
off runtime vs quality of the solution.
scoring : string, callable, list/tuple, dict or None, default: None
A single string (see :ref:`scoring_parameter`) or a callable
(see :ref:`scoring`) to evaluate the predictions on the test set.
For evaluating multiple metrics, either give a list of (unique) strings
or a dict with names as keys and callables as values.
NOTE that when using custom scorers, each scorer should return a single
value. Metric functions returning a list/array of values can be wrapped
into multiple scorers that return one value each.
See :ref:`multimetric_grid_search` for an example.
If None, the estimator's default scorer (if available) is used.
fit_params : dict, optional
Parameters to pass to the fit method.
.. deprecated:: 0.19
``fit_params`` as a constructor argument was deprecated in version
0.19 and will be removed in version 0.21. Pass fit parameters to
the ``fit`` method instead.
n_jobs : int or None, optional (default=None)
Number of jobs to run in parallel.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.
pre_dispatch : int, or string, optional
Controls the number of jobs that get dispatched during parallel
execution. Reducing this number can be useful to avoid an
explosion of memory consumption when more jobs get dispatched
than CPUs can process. This parameter can be:
- None, in which case all the jobs are immediately
created and spawned. Use this for lightweight and
fast-running jobs, to avoid delays due to on-demand
spawning of the jobs
- An int, giving the exact number of total jobs that are
spawned
- A string, giving an expression as a function of n_jobs,
as in '2*n_jobs'
iid : boolean, default='warn'
If True, return the average score across folds, weighted by the number
of samples in each test set. In this case, the data is assumed to be
identically distributed across the folds, and the loss minimized is
the total loss per sample, and not the mean loss across the folds. If
False, return the average score across folds. Default is True, but
will change to False in version 0.22, to correspond to the standard
definition of cross-validation.
.. versionchanged:: 0.20
Parameter ``iid`` will change from True to False by default in
version 0.22, and will be removed in 0.24.
cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
- None, to use the default 3-fold cross validation,
- integer, to specify the number of folds in a `(Stratified)KFold`,
- :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the estimator is a classifier and ``y`` is
either binary or multiclass, :class:`StratifiedKFold` is used. In all
other cases, :class:`KFold` is used.
Refer to the :ref:`User Guide <cross_validation>` for the various
cross-validation strategies that can be used here.
.. versionchanged:: 0.20
``cv`` default value if None will change from 3-fold to 5-fold
in v0.22.
refit : boolean, or string, default=True
Refit an estimator using the best found parameters on the whole
dataset.
For multiple metric evaluation, this needs to be a string denoting the
scorer that would be used to find the best parameters for refitting
the estimator at the end.
The refitted estimator is made available at the ``best_estimator_``
attribute and permits using ``predict`` directly on this
``RandomizedSearchCV`` instance.
Also for multiple metric evaluation, the attributes ``best_index_``,
``best_score_`` and ``best_params_`` will only be available if
``refit`` is set and all of them will be determined w.r.t this specific
scorer.
See ``scoring`` parameter to know more about multiple metric
evaluation.
verbose : integer
Controls the verbosity: the higher, the more messages.
random_state : int, RandomState instance or None, optional, default=None
Pseudo random number generator state used for random uniform sampling
from lists of possible values instead of scipy.stats distributions.
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance used
by `np.random`.
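The three accepted forms can be summarized by a small helper in the spirit of :func:`sklearn.utils.check_random_state`. This sketch is hypothetical. It uses :class:`random.Random` in place of NumPy's ``RandomState``, with a module-level instance standing in for the global ``np.random`` generator.

```python
import numbers
import random

_GLOBAL_RNG = random.Random()  # stands in for the global np.random state


def check_rng(seed):
    # None -> the shared global generator (results vary between runs).
    if seed is None:
        return _GLOBAL_RNG
    # int -> a fresh generator seeded with it (reproducible draws).
    if isinstance(seed, numbers.Integral):
        return random.Random(seed)
    # generator instance -> used as is, so its state advances in place.
    if isinstance(seed, random.Random):
        return seed
    raise ValueError('%r cannot be used to seed a generator' % (seed,))


first_draws = [check_rng(42).random() for _ in range(2)]
print(first_draws[0] == first_draws[1])  # True: each call reseeds with 42
shared = random.Random(7)
print(check_rng(shared) is shared)       # True
```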
error_score : 'raise' or numeric
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised. If a numeric value is given,
FitFailedWarning is raised. This parameter does not affect the refit
step, which will always raise the error. Default is 'raise' but from
version 0.22 it will change to np.nan.
return_train_score : boolean, optional
If ``False``, the ``cv_results_`` attribute will not include training
scores.
Current default is ``'warn'``, which behaves as ``True`` in addition
to raising a warning when a training score is looked up.
That default will be changed to ``False`` in 0.21.
Computing training scores is used to get insights on how different
parameter settings impact the overfitting/underfitting trade-off.
However, computing the scores on the training set can be computationally
expensive and is not strictly required to select the parameters that
yield the best generalization performance.
Attributes
----------
cv_results_ : dict of numpy (masked) ndarrays
A dict with keys as column headers and values as columns, that can be
imported into a pandas ``DataFrame``.
For instance, the table below
+--------------+-------------+-------------------+---+---------------+
| param_kernel | param_gamma | split0_test_score |...|rank_test_score|
+==============+=============+===================+===+===============+
|    'rbf'     |     0.1     |       0.80        |...|       1       |
+--------------+-------------+-------------------+---+---------------+
|    'rbf'     |     0.2     |       0.90        |...|       2       |
+--------------+-------------+-------------------+---+---------------+
|    'rbf'     |     0.3     |       0.70        |...|       2       |
+--------------+-------------+-------------------+---+---------------+
will be represented by a ``cv_results_`` dict of::
{
'param_kernel' : masked_array(data = ['rbf', 'rbf', 'rbf'],
mask = False),
'param_gamma' : masked_array(data = [0.1 0.2 0.3], mask = False),
'split0_test_score' : [0.80, 0.90, 0.70],
'split1_test_score' : [0.82, 0.50, 0.70],
'mean_test_score' : [0.81, 0.70, 0.70],
'std_test_score' : [0.01, 0.20, 0.00],
'rank_test_score' : [1, 2, 2],
'split0_train_score' : [0.80, 0.92, 0.70],
'split1_train_score' : [0.82, 0.55, 0.70],
'mean_train_score' : [0.81, 0.74, 0.70],
'std_train_score' : [0.01, 0.19, 0.00],
'mean_fit_time' : [0.73, 0.63, 0.43],
'std_fit_time' : [0.01, 0.02, 0.01],
'mean_score_time' : [0.01, 0.06, 0.04],
'std_score_time' : [0.00, 0.00, 0.00],
'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
}
NOTE
The key ``'params'`` is used to store a list of parameter
settings dicts for all the parameter candidates.
The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
``std_score_time`` are all in seconds.
For multi-metric evaluation, the scores for all the scorers are
available in the ``cv_results_`` dict at the keys ending with that
scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown
above. ('split0_test_precision', 'mean_train_precision' etc.)
best_estimator_ : estimator or dict
Estimator that was chosen by the search, i.e. estimator
which gave highest score (or smallest loss if specified)
on the left out data. Not available if ``refit=False``.
For multi-metric evaluation, this attribute is present only if
``refit`` is specified.
See ``refit`` parameter for more information on allowed values.
best_score_ : float
Mean cross-validated score of the best_estimator.
For multi-metric evaluation, this is not available if ``refit`` is
``False``. See ``refit`` parameter for more information.
best_params_ : dict
Parameter setting that gave the best results on the hold out data.
For multi-metric evaluation, this is not available if ``refit`` is
``False``. See ``refit`` parameter for more information.
best_index_ : int
The index (of the ``cv_results_`` arrays) which corresponds to the best
candidate parameter setting.
The dict at ``search.cv_results_['params'][search.best_index_]`` gives
the parameter setting for the best model, that gives the highest
mean score (``search.best_score_``).
For multi-metric evaluation, this is not available if ``refit`` is
``False``. See ``refit`` parameter for more information.
scorer_ : function or a dict
Scorer function used on the held out data to choose the best
parameters for the model.
For multi-metric evaluation, this attribute holds the validated
``scoring`` dict which maps the scorer key to the scorer callable.
n_splits_ : int
The number of cross-validation splits (folds/iterations).
refit_time_ : float
Seconds used for refitting the best model on the whole dataset.
This is present only if ``refit`` is not False.
Notes
-----
The parameters selected are those that maximize the score of the held-out
data, according to the scoring parameter.
If `n_jobs` was set to a value higher than one, the data is copied for each
parameter setting (and not `n_jobs` times). This is done for efficiency
reasons if individual jobs take very little time, but may raise errors if
the dataset is large and not enough memory is available. A workaround in
this case is to set `pre_dispatch`. Then, the memory is copied only
`pre_dispatch` many times. A reasonable value for `pre_dispatch` is `2 *
n_jobs`.
See Also
--------
:class:`GridSearchCV`:
Does exhaustive search over a grid of parameters.
:class:`ParameterSampler`:
A generator over parameter settings, constructed from
param_distributions.
"""
def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
fit_params=None, n_jobs=None, iid='warn', refit=True,
cv='warn', verbose=0, pre_dispatch='2*n_jobs',
random_state=None, error_score='raise-deprecating',
return_train_score="warn"):
self.param_distributions = param_distributions
self.n_iter = n_iter
self.random_state = random_state
super(RandomizedSearchCV, self).__init__(
estimator=estimator, scoring=scoring, fit_params=fit_params,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch, error_score=error_score,
return_train_score=return_train_score)
def _run_search(self, evaluate_candidates):
"""Search n_iter candidates from param_distributions"""
evaluate_candidates(ParameterSampler(
self.param_distributions, self.n_iter,
random_state=self.random_state))