"""
Least Angle Regression algorithm. See the documentation on the
Generalized Linear Model for a complete discussion.
"""
from __future__ import print_function
# Author: Fabian Pedregosa <fabian.pedregosa@inria.fr>
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
# Gael Varoquaux
#
# License: BSD 3 clause
from math import log
import sys
import warnings
import numpy as np
from scipy import linalg, interpolate
from scipy.linalg.lapack import get_lapack_funcs
from .base import LinearModel
from ..base import RegressorMixin
from ..utils import arrayfuncs, as_float_array, check_X_y, deprecated
from ..model_selection import check_cv
from ..exceptions import ConvergenceWarning
from ..utils._joblib import Parallel, delayed
from ..externals.six.moves import xrange
from ..externals.six import string_types
solve_triangular_args = {'check_finite': False}
def lars_path(X, y, Xy=None, Gram=None, max_iter=500,
alpha_min=0, method='lar', copy_X=True,
eps=np.finfo(float).eps,
copy_Gram=True, verbose=0, return_path=True,
return_n_iter=False, positive=False):
"""Compute Least Angle Regression or Lasso path using the LARS algorithm [1].
The optimization objective for the case method='lasso' is::
(1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
In the case of method='lar', the objective function is only known in
the form of an implicit equation (see discussion in [1]).
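The plain LAR iteration of [1] can be illustrated with a short NumPy sketch. It is a simplified illustration under stated assumptions, not this function's implementation: ``lar_path_sketch`` is a hypothetical helper that re-solves a dense linear system at every step instead of updating a Cholesky factor, omits the 'lasso' drop rule, fits no intercept, and assumes linearly independent columns with more samples than features.

```python
import numpy as np


def lar_path_sketch(X, y):
    # Hypothetical helper: plain LAR (no 'lasso' drop rule, no intercept).
    # Assumes linearly independent columns and n_samples > n_features.
    # Returns one coefficient vector per breakpoint, starting at zero.
    n_features = X.shape[1]
    beta = np.zeros(n_features)
    path = [beta.copy()]
    c = X.T @ y  # correlations with the current residual
    active = [int(np.argmax(np.abs(c)))]
    while True:
        C = np.max(np.abs(c[active]))  # shared (tied) maximal correlation
        s = np.sign(c[active])
        XA = X[:, active] * s  # sign-adjusted active columns
        G_inv_1 = np.linalg.solve(XA.T @ XA, np.ones(len(active)))
        AA = 1.0 / np.sqrt(G_inv_1.sum())
        w = AA * G_inv_1  # weights of the equiangular direction
        u = XA @ w  # unit vector at equal angles to all active columns
        a = X.T @ u
        # Default: full step to the least-squares fit on the active set
        gamma, new_j = C / AA, None
        for j in range(n_features):
            if j in active:
                continue
            # Smallest positive step at which |c_j| ties with the active ones
            for g in ((C - c[j]) / (AA - a[j]), (C + c[j]) / (AA + a[j])):
                if 0 < g < gamma:
                    gamma, new_j = g, j
        beta[active] += gamma * s * w
        c = c - gamma * a
        path.append(beta.copy())
        if new_j is None:
            return path
        active.append(new_j)


rng = np.random.RandomState(0)
X = rng.randn(50, 4)
y = X @ np.array([1.5, -2.0, 0.0, 0.5]) + 0.1 * rng.randn(50)
path = lar_path_sketch(X, y)
print(len(path))  # 5: the origin plus one breakpoint per entered variable
print(np.allclose(path[-1], np.linalg.lstsq(X, y, rcond=None)[0]))
```

Because every active correlation reaches zero on the last step, the final breakpoint should match the unpenalized least-squares fit, which is what the last check illustrates.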
Read more in the :ref:`User Guide <least_angle_regression>`.
Parameters
----------
X : array, shape: (n_samples, n_features)
Input data.
y : array, shape: (n_samples,)
Input targets.
Xy : array-like, shape (n_features,) or (n_features, n_targets), \
optional
Xy = np.dot(X.T, y) that can be precomputed. It is useful
only when the Gram matrix is precomputed.
Gram : None, 'auto', array, shape: (n_features, n_features), optional
Precomputed Gram matrix (X' * X). If ``'auto'``, the Gram
matrix is precomputed from the given X when there are more samples
than features.
max_iter : integer, optional (default=500)
Maximum number of iterations to perform, set to infinity for no limit.
alpha_min : float, optional (default=0)
Minimum correlation along the path. It corresponds to the
regularization parameter alpha in the Lasso.
method : {'lar', 'lasso'}, optional (default='lar')
Specifies the returned model. Select ``'lar'`` for Least Angle
Regression, ``'lasso'`` for the Lasso.
copy_X : bool, optional (default=True)
If ``False``, ``X`` is overwritten.
eps : float, optional (default=``np.finfo(float).eps``)
The machine-precision regularization in the computation of the
Cholesky diagonal factors. Increase this for very ill-conditioned
systems.
copy_Gram : bool, optional (default=True)
If ``False``, ``Gram`` is overwritten.
verbose : int (default=0)
Controls output verbosity.
return_path : bool, optional (default=True)
If ``return_path==True`` returns the entire path, else returns only the
last point of the path.
return_n_iter : bool, optional (default=False)
Whether to return the number of iterations.
positive : boolean (default=False)
Restrict coefficients to be >= 0.
This option is only allowed with method 'lasso'. Note that the model
coefficients will not converge to the ordinary-least-squares solution
for small values of alpha. Only coefficients up to the smallest alpha
value (``alphas[alphas > 0.].min()`` when return_path=True) reached by
the stepwise Lars-Lasso algorithm are typically in congruence with the
solution of the coordinate descent lasso_path function.
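The way ``Gram`` is interpreted, and the shape of ``Xy``, can be sketched as follows. ``resolve_gram`` is a hypothetical helper that only mirrors the branch logic of this function (the copying of ``X`` and ``Gram`` is left out).

```python
import numpy as np


def resolve_gram(X, Gram='auto'):
    # Hypothetical helper mirroring how lars_path interprets ``Gram``
    if Gram is None or Gram is False:
        return None  # work directly on the columns of X
    if Gram is True or (isinstance(Gram, str) and Gram == 'auto'):
        # 'auto' precomputes only when there are more samples than features
        if Gram is True or X.shape[0] > X.shape[1]:
            return X.T @ X  # shape (n_features, n_features)
        return None
    return Gram  # a user-supplied array is used as-is


rng = np.random.RandomState(0)
X_tall, X_wide = rng.randn(100, 5), rng.randn(3, 10)
y_tall = rng.randn(100)
Xy = X_tall.T @ y_tall  # shape (n_features,)
print(resolve_gram(X_tall).shape, resolve_gram(X_wide), Xy.shape)
# (5, 5) None (5,)
```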
Returns
-------
alphas : array, shape: (n_alphas + 1,)
Maximum of covariances (in absolute value) at each iteration.
``n_alphas`` is either ``max_iter``, ``n_features`` or the
number of nodes in the path with ``alpha >= alpha_min``, whichever
is smaller.
active : array, shape [n_alphas]
Indices of active variables at the end of the path.
coefs : array, shape (n_features, n_alphas + 1)
Coefficients along the path
n_iter : int
Number of iterations run. Returned only if return_n_iter is set
to True.
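The first and last entries of ``alphas`` can be illustrated with two hypothetical helpers, assuming ``positive=False``. The first alpha is the largest absolute covariance ``|X' y|`` divided by ``n_samples``; when ``alpha_min`` falls between two breakpoints, the final coefficients are linearly interpolated, mirroring the early-stopping branch in the loop below.

```python
import numpy as np


def first_alpha(X, y):
    # Largest absolute covariance |X' y| scaled by n_samples (positive=False)
    return np.max(np.abs(X.T @ y)) / X.shape[0]


def interpolate_to_alpha_min(prev_coef, coef, prev_alpha, alpha, alpha_min):
    # Assumes alpha < alpha_min < prev_alpha, so that 0 < ss < 1
    ss = (prev_alpha - alpha_min) / (prev_alpha - alpha)
    return prev_coef + ss * (coef - prev_coef)


X = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
y = np.array([3.0, 0.0, 3.0])
print(first_alpha(X, y))  # max(|6|, |3|) / 3 = 2.0
# alpha_min lies halfway between the two breakpoints, so ss = 0.5
print(interpolate_to_alpha_min(np.array([0.0, 1.0]), np.array([0.5, 2.0]),
                               prev_alpha=1.0, alpha=0.5, alpha_min=0.75))
```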
See also
--------
lasso_path
LassoLars
Lars
LassoLarsCV
LarsCV
sklearn.decomposition.sparse_encode
References
----------
.. [1] "Least Angle Regression", Efron et al.
http://statweb.stanford.edu/~tibs/ftp/lars.pdf
.. [2] `Wikipedia entry on the Least-angle regression
<https://en.wikipedia.org/wiki/Least-angle_regression>`_
.. [3] `Wikipedia entry on the Lasso
<https://en.wikipedia.org/wiki/Lasso_(statistics)>`_
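The incremental Cholesky update performed in the main loop (solve ``L w = Xa' x_j`` for the new off-diagonal row, then compute the new diagonal pivot) can be sketched as a standalone function. ``cholesky_append`` is a hypothetical helper; as in this function, the pivot is floored at ``eps``, which is why raising ``eps`` can help for ill-conditioned systems.

```python
import numpy as np
from scipy.linalg import solve_triangular


def cholesky_append(L, g, c, eps=np.finfo(float).eps):
    # L: lower Cholesky factor of the active Gram block, shape (k, k)
    # g: Gram entries between the k active columns and the new column
    # c: Gram diagonal entry (squared norm) of the new column
    k = L.shape[0]
    w = solve_triangular(L, g, lower=True) if k else np.zeros(0)
    diag = max(np.sqrt(abs(c - w @ w)), eps)  # pivot floored at eps
    L_new = np.zeros((k + 1, k + 1))
    L_new[:k, :k] = L
    L_new[k, :k] = w
    L_new[k, k] = diag
    return L_new


rng = np.random.RandomState(0)
X = rng.randn(20, 3)
G = X.T @ X
L = np.zeros((0, 0))
for j in range(3):  # add columns one at a time, as in the LARS loop
    L = cholesky_append(L, G[:j, j], G[j, j])
print(np.allclose(L @ L.T, G))  # True
```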
"""
if method == 'lar' and positive:
warnings.warn('positive option is broken for Least'
' Angle Regression (LAR). Use method="lasso".'
' This option will be removed in version 0.22.',
DeprecationWarning)
n_features = X.shape[1]
n_samples = y.size
max_features = min(max_iter, n_features)
if return_path:
coefs = np.zeros((max_features + 1, n_features))
alphas = np.zeros(max_features + 1)
else:
coef, prev_coef = np.zeros(n_features), np.zeros(n_features)
alpha, prev_alpha = np.array([0.]), np.array([0.]) # better ideas?
n_iter, n_active = 0, 0
active, indices = list(), np.arange(n_features)
# holds the sign of covariance
sign_active = np.empty(max_features, dtype=np.int8)
drop = False
# will hold the cholesky factorization. Only lower part is
# referenced.
L = np.empty((max_features, max_features), dtype=X.dtype)
swap, nrm2 = linalg.get_blas_funcs(('swap', 'nrm2'), (X,))
solve_cholesky, = get_lapack_funcs(('potrs',), (X,))
if Gram is None or Gram is False:
Gram = None
if copy_X:
# Force a copy. Making the array Fortran-ordered
# speeds up the calculation of the (partial) Gram matrix
# and makes it easy to swap columns.
X = X.copy('F')
elif isinstance(Gram, string_types) and Gram == 'auto' or Gram is True:
if Gram is True or X.shape[0] > X.shape[1]:
Gram = np.dot(X.T, X)
else:
Gram = None
elif copy_Gram:
Gram = Gram.copy()
if Xy is None:
Cov = np.dot(X.T, y)
else:
Cov = Xy.copy()
if verbose:
if verbose > 1:
print("Step\t\tAdded\t\tDropped\t\tActive set size\t\tC")
else:
sys.stdout.write('.')
sys.stdout.flush()
tiny32 = np.finfo(np.float32).tiny # to avoid division by 0 warning
equality_tolerance = np.finfo(np.float32).eps
while True:
if Cov.size:
if positive:
C_idx = np.argmax(Cov)
else:
C_idx = np.argmax(np.abs(Cov))
C_ = Cov[C_idx]
if positive:
C = C_
else:
C = np.fabs(C_)
else:
C = 0.
if return_path:
alpha = alphas[n_iter, np.newaxis]
coef = coefs[n_iter]
prev_alpha = alphas[n_iter - 1, np.newaxis]
prev_coef = coefs[n_iter - 1]
alpha[0] = C / n_samples
if alpha[0] <= alpha_min + equality_tolerance: # early stopping
if abs(alpha[0] - alpha_min) > equality_tolerance:
# interpolation factor 0 <= ss < 1
if n_iter > 0:
# In the first iteration, all alphas are zero, the formula
# below would make ss a NaN
ss = ((prev_alpha[0] - alpha_min) /
(prev_alpha[0] - alpha[0]))
coef[:] = prev_coef + ss * (coef - prev_coef)
alpha[0] = alpha_min
if return_path:
coefs[n_iter] = coef
break
if n_iter >= max_iter or n_active >= n_features:
break
if not drop:
##########################################################
# Append x_j to the Cholesky factorization of (Xa * Xa') #
# #
# ( L 0 ) #
# L -> ( ) , where L * w = Xa' x_j #
# ( w z ) and z = ||x_j|| #
# #
##########################################################
if positive:
sign_active[n_active] = np.ones_like(C_)
else:
sign_active[n_active] = np.sign(C_)
m, n = n_active, C_idx + n_active
Cov[C_idx], Cov[0] = swap(Cov[C_idx], Cov[0])
indices[n], indices[m] = indices[m], indices[n]
Cov_not_shortened = Cov
Cov = Cov[1:] # remove Cov[0]
if Gram is None:
X.T[n], X.T[m] = swap(X.T[n], X.T[m])
c = nrm2(X.T[n_active]) ** 2
L[n_active, :n_active] = \
np.dot(X.T[n_active], X.T[:n_active].T)
else:
# swap only works in place if the matrix is Fortran
# contiguous ...
Gram[m], Gram[n] = swap(Gram[m], Gram[n])
Gram[:, m], Gram[:, n] = swap(Gram[:, m], Gram[:, n])
c = Gram[n_active, n_active]
L[n_active, :n_active] = Gram[n_active, :n_active]
# Update the cholesky decomposition for the Gram matrix
if n_active:
linalg.solve_triangular(L[:n_active, :n_active],
L[n_active, :n_active],
trans=0, lower=1,
overwrite_b=True,
**solve_triangular_args)
v = np.dot(L[n_active, :n_active], L[n_active, :n_active])
diag = max(np.sqrt(np.abs(c - v)), eps)
L[n_active, n_active] = diag
if diag < 1e-7:
# The system is becoming too ill-conditioned.
# We have degenerate vectors in our active set.
# We'll 'drop for good' the last regressor added.
# Note: this case is very rare. It is no longer triggered by
# the test suite. The `equality_tolerance` margin added in 0.16
# to get early stopping to work consistently on all versions of
# Python including 32 bit Python under Windows seems to make it
# very difficult to trigger the 'drop for good' strategy.
warnings.warn('Regressors in active set degenerate. '
'Dropping a regressor, after %i iterations, '
'i.e. alpha=%.3e, '
'with an active set of %i regressors, and '
'the smallest cholesky pivot element being %.3e.'
' Reduce max_iter or increase eps parameters.'
% (n_iter, alpha, n_active, diag),
ConvergenceWarning)
# XXX: need to figure a 'drop for good' way
Cov = Cov_not_shortened
Cov[0] = 0
Cov[C_idx], Cov[0] = swap(Cov[C_idx], Cov[0])
continue
active.append(indices[n_active])
n_active += 1
if verbose > 1:
print("%s\t\t%s\t\t%s\t\t%s\t\t%s" % (n_iter, active[-1], '',
n_active, C))
if method == 'lasso' and n_iter > 0 and prev_alpha[0] < alpha[0]:
# alpha is increasing. This is because the updates of Cov are
# bringing in too much numerical error that is greater than
# the remaining correlation with the
# regressors. Time to bail out
warnings.warn('Early stopping the lars path, as the residues '
'are small and the current value of alpha is no '
'longer well controlled. %i iterations, alpha=%.3e, '
'previous alpha=%.3e, with an active set of %i '
'regressors.'
% (n_iter, alpha, prev_alpha, n_active),
ConvergenceWarning)
break
# least squares solution
least_squares, info = solve_cholesky(L[:n_active, :n_active],
sign_active[:n_active],
lower=True)
if least_squares.size == 1 and least_squares == 0:
# This happens because sign_active[:n_active] = 0
least_squares[...] = 1
AA = 1.
else:
# is this really needed ?
AA = 1. / np.sqrt(np.sum(least_squares * sign_active[:n_active]))
if not np.isfinite(AA):
# L is too ill-conditioned
i = 0
L_ = L[:n_active, :n_active].copy()
while not np.isfinite(AA):
L_.flat[::n_active + 1] += (2 ** i) * eps
least_squares, info = solve_cholesky(
L_, sign_active[:n_active], lower=True)
tmp = max(np.sum(least_squares * sign_active[:n_active]),
eps)
AA = 1. / np.sqrt(tmp)
i += 1
least_squares *= AA
if Gram is None:
# equiangular direction of variables in the active set
eq_dir = np.dot(X.T[:n_active].T, least_squares)
# correlation between each inactive variable and the
# equiangular vector
corr_eq_dir = np.dot(X.T[n_active:], eq_dir)
else:
# if huge number of features, this takes 50% of time, I
# think could be avoided if we just update it using an
# orthogonal (QR) decomposition of X
corr_eq_dir = np.dot(Gram[:n_active, n_active:].T,
least_squares)
g1 = arrayfuncs.min_pos((C - Cov) / (AA - corr_eq_dir + tiny32))
if positive:
gamma_ = min(g1, C / AA)
else:
g2 = arrayfuncs.min_pos((C + Cov) / (AA + corr_eq_dir + tiny32))
gamma_ = min(g1, g2, C / AA)
# TODO: better names for these variables: z
drop = False
z = -coef[active] / (least_squares + tiny32)
z_pos = arrayfuncs.min_pos(z)
if z_pos < gamma_:
# some coefficients have changed sign
idx = np.where(z == z_pos)[0][::-1]
# update the sign, important for LAR
sign_active[idx] = -sign_active[idx]
if method == 'lasso':
gamma_ = z_pos
drop = True
n_iter += 1
if return_path:
if n_iter >= coefs.shape[0]:
del coef, alpha, prev_alpha, prev_coef
# resize the coefs and alphas array
add_features = 2 * max(1, (max_features - n_active))
coefs = np.resize(coefs, (n_iter + add_features, n_features))
coefs[-add_features:] = 0
alphas = np.resize(alphas, n_iter + add_features)
alphas[-add_features:] = 0
coef = coefs[n_iter]
prev_coef = coefs[n_iter - 1]
else:
# mimic the effect of incrementing n_iter on the array references
prev_coef = coef
prev_alpha[0] = alpha[0]
coef = np.zeros_like(coef)
coef[active] = prev_coef[active] + gamma_ * least_squares
# update correlations
Cov -= gamma_ * corr_eq_dir
# See if any coefficient has changed sign
if drop and method == 'lasso':
# handle the case when idx is not length of 1
[arrayfuncs.cholesky_delete(L[:n_active, :n_active], ii) for ii in
idx]
n_active -= 1
m, n = idx, n_active
# handle the case when idx is not length of 1
drop_idx = [active.pop(ii) for ii in idx]
if Gram is None:
# propagate dropped variable
for ii in idx:
for i in range(ii, n_active):
X.T[i], X.T[i + 1] = swap(X.T[i], X.T[i + 1])
# yeah this is stupid
indices[i], indices[i + 1] = indices[i + 1], indices[i]
# TODO: this could be updated
residual = y - np.dot(X[:, :n_active], coef[active])
temp = np.dot(X.T[n_active], residual)
Cov = np.r_[temp, Cov]
else:
for ii in idx:
for i in range(ii, n_active):
indices[i], indices[i + 1] = indices[i + 1], indices[i]
Gram[i], Gram[i + 1] = swap(Gram[i], Gram[i + 1])
Gram[:, i], Gram[:, i + 1] = swap(Gram[:, i],
Gram[:, i + 1])
# Cov_n = Cov_j + x_j * X + increment(betas) TODO:
# will this still work with multiple drops ?
# recompute covariance. Probably could be done better
# wrong as Xy is not swapped with the rest of variables
# TODO: this could be updated
residual = y - np.dot(X, coef)
temp = np.dot(X.T[drop_idx], residual)
Cov = np.r_[temp, Cov]
sign_active = np.delete(sign_active, idx)
sign_active = np.append(sign_active, 0.) # just to maintain size
if verbose > 1:
print("%s\t\t%s\t\t%s\t\t%s\t\t%s" % (n_iter, '', drop_idx,
n_active, abs(temp)))
if return_path:
# resize coefs in case of early stop
alphas = alphas[:n_iter + 1]
coefs = coefs[:n_iter + 1]
if return_n_iter:
return alphas, active, coefs.T, n_iter
else:
return alphas, active, coefs.T
else:
if return_n_iter:
return alpha, active, coef, n_iter
else:
return alpha, active, coef
###############################################################################
# Estimator classes
class Lars(LinearModel, RegressorMixin):
"""Least Angle Regression model a.k.a. LAR
Read more in the :ref:`User Guide <least_angle_regression>`.
Parameters
----------
fit_intercept : boolean
Whether to calculate the intercept for this model. If set
to false, no intercept will be used in calculations
(e.g. data is expected to be already centered).
verbose : boolean or integer, optional
Sets the verbosity amount
normalize : boolean, optional, default True
This parameter is ignored when ``fit_intercept`` is set to False.
If True, the regressors X will be normalized before regression by
subtracting the mean and dividing by the l2-norm.
If you wish to standardize, please use
:class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
on an estimator with ``normalize=False``.
precompute : True | False | 'auto' | array-like
Whether to use a precomputed Gram matrix to speed up
calculations. If set to ``'auto'`` let us decide. The Gram
matrix can also be passed as argument.
n_nonzero_coefs : int, optional
Target number of non-zero coefficients. Use ``np.inf`` for no limit.
eps : float, optional
The machine-precision regularization in the computation of the
Cholesky diagonal factors. Increase this for very ill-conditioned
systems. Unlike the ``tol`` parameter in some iterative
optimization-based algorithms, this parameter does not control
the tolerance of the optimization.
copy_X : boolean, optional, default True
If ``True``, X will be copied; else, it may be overwritten.
fit_path : boolean
If True the full path is stored in the ``coef_path_`` attribute.
If you compute the solution for a large problem or many targets,
setting ``fit_path`` to ``False`` will lead to a speedup, especially
with a small alpha.
positive : boolean (default=False)
Restrict coefficients to be >= 0. Be aware that you might want to
remove fit_intercept which is set True by default.
.. deprecated:: 0.20
The option is broken and deprecated. It will be removed in v0.22.
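The ``normalize=True`` preprocessing described above (subtract the mean, divide by the l2-norm) can be reproduced in NumPy for illustration; this is not the estimator's own preprocessing code. The last check shows why it differs from ``StandardScaler``: the l2-norm of a centered column equals its (ddof=0) standard deviation times ``sqrt(n_samples)``.

```python
import numpy as np

X = np.array([[1.0, 2.0], [3.0, 6.0], [5.0, 4.0]])
X_offset = X.mean(axis=0)  # column means
X_centered = X - X_offset
X_scale = np.linalg.norm(X_centered, axis=0)  # l2-norm of each centered column
X_normalized = X_centered / X_scale

# Each normalized column has zero mean and unit l2-norm
print(np.allclose(X_normalized.mean(axis=0), 0))
print(np.allclose(np.linalg.norm(X_normalized, axis=0), 1))
# The l2-norm equals the standard deviation times sqrt(n_samples)
print(np.allclose(X_scale, X.std(axis=0) * np.sqrt(X.shape[0])))
```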
Attributes
----------
alphas_ : array, shape (n_alphas + 1,) | list of n_targets such arrays
Maximum of covariances (in absolute value) at each iteration. \
``n_alphas`` is either ``n_nonzero_coefs`` or ``n_features``, \
whichever is smaller.
active_ : list, length = n_alphas | list of n_targets such lists
Indices of active variables at the end of the path.
coef_path_ : array, shape (n_features, n_alphas + 1) \
| list of n_targets such arrays
The varying values of the coefficients along the path. It is not
present if the ``fit_path`` parameter is ``False``.
coef_ : array, shape (n_features,) or (n_targets, n_features)
Parameter vector (``w`` in the problem formulation).
intercept_ : float | array, shape (n_targets,)
Independent term in decision function.
n_iter_ : array-like or int
The number of iterations taken by lars_path to find the
grid of alphas for each target.
Examples
--------
>>> from sklearn import linear_model
>>> reg = linear_model.Lars(n_nonzero_coefs=1)
>>> reg.fit([[-1, 1], [0, 0], [1, 1]], [-1.1111, 0, -1.1111])
... # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
Lars(copy_X=True, eps=..., fit_intercept=True, fit_path=True,
n_nonzero_coefs=1, normalize=True, positive=False, precompute='auto',
verbose=False)
>>> print(reg.coef_) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
[ 0. -1.11...]
See also
--------
lars_path, LarsCV
sklearn.decomposition.sparse_encode
"""
method = 'lar'
def __init__(self, fit_intercept=True, verbose=False, normalize=True,
precompute='auto', n_nonzero_coefs=500,
eps=np.finfo(float).eps, copy_X=True, fit_path=True,
positive=False):
self.fit_intercept = fit_intercept
self.verbose = verbose
self.normalize = normalize
self.precompute = precompute
self.n_nonzero_coefs = n_nonzero_coefs
self.positive = positive
self.eps = eps
self.copy_X = copy_X
self.fit_path = fit_path
@staticmethod
def _get_gram(precompute, X, y):
if (not hasattr(precompute, '__array__')) and (
(precompute is True) or
(precompute == 'auto' and X.shape[0] > X.shape[1]) or
(precompute == 'auto' and y.shape[1] > 1)):
precompute = np.dot(X.T, X)
return precompute
def _fit(self, X, y, max_iter, alpha, fit_path, Xy=None):
"""Auxiliary method to fit the model using X, y as training data"""
n_features = X.shape[1]
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
X, y, self.fit_intercept, self.normalize, self.copy_X)
if y.ndim == 1:
y = y[:, np.newaxis]
n_targets = y.shape[1]
Gram = self._get_gram(self.precompute, X, y)
self.alphas_ = []
self.n_iter_ = []
self.coef_ = np.empty((n_targets, n_features))
if fit_path:
self.active_ = []
self.coef_path_ = []
for k in xrange(n_targets):
this_Xy = None if Xy is None else Xy[:, k]
alphas, active, coef_path, n_iter_ = lars_path(
X, y[:, k], Gram=Gram, Xy=this_Xy, copy_X=self.copy_X,
copy_Gram=True, alpha_min=alpha, method=self.method,
verbose=max(0, self.verbose - 1), max_iter=max_iter,
eps=self.eps, return_path=True,
return_n_iter=True, positive=self.positive)
self.alphas_.append(alphas)
self.active_.append(active)
self.n_iter_.append(n_iter_)
self.coef_path_.append(coef_path)
self.coef_[k] = coef_path[:, -1]
if n_targets == 1:
self.alphas_, self.active_, self.coef_path_, self.coef_ = [
a[0] for a in (self.alphas_, self.active_, self.coef_path_,
self.coef_)]
self.n_iter_ = self.n_iter_[0]
else:
for k in xrange(n_targets):
this_Xy = None if Xy is None else Xy[:, k]
alphas, _, self.coef_[k], n_iter_ = lars_path(
X, y[:, k], Gram=Gram, Xy=this_Xy, copy_X=self.copy_X,
copy_Gram=True, alpha_min=alpha, method=self.method,
verbose=max(0, self.verbose - 1), max_iter=max_iter,
eps=self.eps, return_path=False, return_n_iter=True,
positive=self.positive)
self.alphas_.append(alphas)
self.n_iter_.append(n_iter_)
if n_targets == 1:
self.alphas_ = self.alphas_[0]
self.n_iter_ = self.n_iter_[0]
self._set_intercept(X_offset, y_offset, X_scale)
return self
def fit(self, X, y, Xy=None):
"""Fit the model using X, y as training data.
Parameters
----------
X : array-like, shape (n_samples, n_features)
Training data.
y : array-like, shape (n_samples,) or (n_samples, n_targets)
Target values.
Xy : array-like, shape (n_samples,) or (n_samples, n_targets), \
optional
Xy = np.dot(X.T, y) that can be precomputed. It is useful
only when the Gram matrix is precomputed.
Returns
-------
self : object
returns an instance of self.
"""
X, y = check_X_y(X, y, y_numeric=True, multi_output=True)
alpha = getattr(self, 'alpha', 0.)
if hasattr(self, 'n_nonzero_coefs'):
alpha = 0. # n_nonzero_coefs parametrization takes priority
max_iter = self.n_nonzero_coefs
else:
max_iter = self.max_iter
self._fit(X, y, max_iter=max_iter, alpha=alpha, fit_path=self.fit_path,
Xy=Xy)
return self
class LassoLars(Lars):
"""Lasso model fit with Least Angle Regression a.k.a. Lars
It is a Linear Model trained with an L1 prior as regularizer.
The optimization objective for Lasso is::
(1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
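The objective above can be evaluated directly, as a sanity check of its scaling. ``lasso_objective`` is a hypothetical helper; it uses X and y as given, without the centering and normalization the estimator applies by default, so it is not a substitute for fitting.

```python
import numpy as np


def lasso_objective(X, y, w, alpha):
    # (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
    n_samples = X.shape[0]
    residual = y - X @ w
    return residual @ residual / (2 * n_samples) + alpha * np.sum(np.abs(w))


X = np.array([[-1.0, 1.0], [0.0, 0.0], [1.0, 1.0]])
y = np.array([-1.0, 0.0, -1.0])
# w = [0, -1] fits y exactly, so only the penalty term remains
print(lasso_objective(X, y, np.array([0.0, -1.0]), alpha=0.01))  # 0.01
# w = 0 pays no penalty but keeps the full residual: 2 / (2 * 3)
print(lasso_objective(X, y, np.zeros(2), alpha=0.01))
```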
Read more in the :ref:`User Guide <least_angle_regression>`.
Parameters
----------
alpha : float
Constant that multiplies the penalty term. Defaults to 1.0.
``alpha = 0`` is equivalent to ordinary least squares, solved
by :class:`LinearRegression`. For numerical reasons, using
``alpha = 0`` with the LassoLars object is not advised and you
should prefer the LinearRegression object.
fit_intercept : boolean
Whether to calculate the intercept for this model. If set
to false, no intercept will be used in calculations
(e.g. data is expected to be already centered).
verbose : boolean or integer, optional
Sets the verbosity amount
normalize : boolean, optional, default True
This parameter is ignored when ``fit_intercept`` is set to False.
If True, the regressors X will be normalized before regression by
subtracting the mean and dividing by the l2-norm.
If you wish to standardize, please use
:class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
on an estimator with ``normalize=False``.
precompute : True | False | 'auto' | array-like
Whether to use a precomputed Gram matrix to speed up
calculations. If set to ``'auto'`` let us decide. The Gram
matrix can also be passed as argument.
max_iter : integer, optional
Maximum number of iterations to perform.
eps : float, optional
The machine-precision regularization in the computation of the
Cholesky diagonal factors. Increase this for very ill-conditioned
systems. Unlike the ``tol`` parameter in some iterative
optimization-based algorithms, this parameter does not control
the tolerance of the optimization.
copy_X : boolean, optional, default True
If True, X will be copied; else, it may be overwritten.
fit_path : boolean
If ``True`` the full path is stored in the ``coef_path_`` attribute.
If you compute the solution for a large problem or many targets,
setting ``fit_path`` to ``False`` will lead to a speedup, especially
with a small alpha.
positive : boolean (default=False)
Restrict coefficients to be >= 0. Be aware that you might want to
remove fit_intercept which is set True by default.
Under the positive restriction the model coefficients will not converge
to the ordinary-least-squares solution for small values of alpha.
Only coefficients up to the smallest alpha value (``alphas_[alphas_ >
0.].min()`` when fit_path=True) reached by the stepwise Lars-Lasso
algorithm are typically in congruence with the solution of the
coordinate descent Lasso estimator.
Attributes
----------
alphas_ : array, shape (n_alphas + 1,) | list of n_targets such arrays
Maximum of covariances (in absolute value) at each iteration. \
``n_alphas`` is either ``max_iter``, ``n_features``, or the number of \
nodes in the path with correlation greater than ``alpha``, whichever \
is smaller.
active_ : list, length = n_alphas | list of n_targets such lists
Indices of active variables at the end of the path.
coef_path_ : array, shape (n_features, n_alphas + 1) \
| list of n_targets such arrays
The varying values of the coefficients along the path. It is not
present if the ``fit_path`` parameter is ``False``.
coef_ : array, shape (n_features,) or (n_targets, n_features)
Parameter vector (``w`` in the cost function formula above).
intercept_ : float | array, shape (n_targets,)
Independent term in decision function.
n_iter_ : array-like or int
The number of iterations taken by lars_path to find the
grid of alphas for each target.
Examples
--------
>>> from sklearn import linear_model
>>> reg = linear_model.LassoLars(alpha=0.01)
>>> reg.fit([[-1, 1], [0, 0], [1, 1]], [-1, 0, -1])
... # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
LassoLars(alpha=0.01, copy_X=True, eps=..., fit_intercept=True,
fit_path=True, max_iter=500, normalize=True, positive=False,
precompute='auto', verbose=False)
>>> print(reg.coef_) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
[ 0. -0.963257...]
See also
--------
lars_path
lasso_path
Lasso
LassoCV
LassoLarsCV
LassoLarsIC
sklearn.decomposition.sparse_encode
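What distinguishes ``method='lasso'`` from plain LAR is how sign changes are handled: if an active coefficient would cross zero before the next variable enters, the step is truncated at the crossing and that variable leaves the active set, whereas plain LAR only flips the stored sign. A hypothetical NumPy sketch of that rule, loosely mirroring the ``z`` / ``z_pos`` computation in ``lars_path`` (not its actual code):

```python
import numpy as np


def lasso_truncate_step(coef_active, direction, gamma_lar):
    # coef_active: current coefficients of the active variables
    # direction: per-unit-step change of those coefficients
    # gamma_lar: step length plain LAR would take
    with np.errstate(divide='ignore', invalid='ignore'):
        z = -coef_active / direction  # step at which each coefficient hits 0
    crossings = z[z > 0]
    z_pos = crossings.min() if crossings.size else np.inf
    if z_pos < gamma_lar:
        # Stop at the first zero crossing and drop the variable(s) there
        return z_pos, np.flatnonzero(z == z_pos)
    return gamma_lar, np.array([], dtype=int)


coef_active = np.array([0.5, -0.2])
direction = np.array([0.1, 0.4])
gamma, drop = lasso_truncate_step(coef_active, direction, gamma_lar=1.0)
print(gamma, drop)  # 0.5 [1]
print(coef_active + gamma * direction)  # the second coefficient reaches 0
```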
"""
method = 'lasso'
def __init__(self, alpha=1.0, fit_intercept=True, verbose=False,
normalize=True, precompute='auto', max_iter=500,
eps=np.finfo(float).eps, copy_X=True, fit_path=True,
positive=False):
self.alpha = alpha
self.fit_intercept = fit_intercept
self.max_iter = max_iter
self.verbose = verbose
self.normalize = normalize
self.positive = positive
self.precompute = precompute
self.copy_X = copy_X
self.eps = eps
self.fit_path = fit_path
###############################################################################
# Cross-validated estimator classes
def _check_copy_and_writeable(array, copy=False):
if copy or not array.flags.writeable:
return array.copy()
return array
def _lars_path_residues(X_train, y_train, X_test, y_test, Gram=None,
copy=True, method='lar', verbose=False,
fit_intercept=True, normalize=True, max_iter=500,
eps=np.finfo(float).eps, positive=False):
"""Compute the residues on left-out data for a full LARS path
Parameters
----------
X_train : array, shape (n_samples, n_features)
The data to fit the LARS on
y_train : array, shape (n_samples,)
The target variable to fit LARS on
X_test : array, shape (n_samples, n_features)
The data to compute the residues on
y_test : array, shape (n_samples,)
The target variable to compute the residues on
Gram : None, 'auto', array, shape: (n_features, n_features), optional
Precomputed Gram matrix (X' * X). If ``'auto'``, the Gram matrix
is precomputed from the given X when there are more samples than
features.
copy : boolean, optional
Whether X_train, X_test, y_train and y_test should be copied;
if False, they may be overwritten.
method : 'lar' | 'lasso'
Specifies the returned model. Select ``'lar'`` for Least Angle
Regression, ``'lasso'`` for the Lasso.
verbose : integer, optional
Sets the amount of verbosity
fit_intercept : boolean
Whether to calculate the intercept for this model. If set
to False, no intercept will be used in calculations
(e.g. data is expected to be already centered).
positive : boolean (default=False)
Restrict coefficients to be >= 0. Be aware that you might want to
disable ``fit_intercept``, which is set to True by default.
See the reservations, in the docs of LassoLarsCV and LassoLarsIC,
about using this option with method 'lasso' for small values of
alpha.
normalize : boolean, optional, default True
This parameter is ignored when ``fit_intercept`` is set to False.
If True, the regressors X will be normalized before regression by
subtracting the mean and dividing by the l2-norm.
If you wish to standardize, please use
:class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
on an estimator with ``normalize=False``.
max_iter : integer, optional
Maximum number of iterations to perform.
eps : float, optional
The machine-precision regularization in the computation of the
Cholesky diagonal factors. Increase this for very ill-conditioned
systems. Unlike the ``tol`` parameter in some iterative
optimization-based algorithms, this parameter does not control
the tolerance of the optimization.
Returns
-------
alphas : array, shape (n_alphas,)
Maximum of covariances (in absolute value) at each iteration.
``n_alphas`` is either ``max_iter`` or ``n_features``, whichever
is smaller.
active : list
Indices of active variables at the end of the path.
coefs : array, shape (n_features, n_alphas)
Coefficients along the path
residues : array, shape (n_alphas, n_samples)
Residues of the prediction on the test data
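The preprocessing performed by this function centers both splits with the training means, scales only the training design by its column norms, and then divides the coefficients by those norms. As a result, the returned coefficients and residues apply directly to the unscaled test data. The following is a minimal sketch of that round trip. It assumes NumPy and substitutes an ordinary least-squares solve (``np.linalg.lstsq``) for ``lars_path``, so it covers only the unregularized end of the path. ``ols_residues`` is a hypothetical helper and is not part of this module.

```python
import numpy as np


def ols_residues(X_train, y_train, X_test, y_test):
    # Hypothetical helper: same centering/normalization as
    # _lars_path_residues, but with a least-squares fit instead of lars_path.
    X_mean = X_train.mean(axis=0)
    y_mean = y_train.mean(axis=0)
    X_tr = X_train - X_mean
    X_te = X_test - X_mean  # test data is centered with the training mean
    y_tr = y_train - y_mean
    y_te = y_test - y_mean
    norms = np.sqrt(np.sum(X_tr ** 2, axis=0))
    nonzeros = np.flatnonzero(norms)
    X_tr[:, nonzeros] /= norms[nonzeros]  # only the training design is scaled
    coef = np.linalg.lstsq(X_tr, y_tr, rcond=None)[0]
    coef[nonzeros] /= norms[nonzeros]  # map coefficients back to unscaled X
    residues = X_te @ coef - y_te  # so they apply directly to X_test
    return coef, residues


rng = np.random.RandomState(0)
X_train, X_test = rng.randn(20, 3), rng.randn(5, 3)
y_train, y_test = rng.randn(20), rng.randn(5)
coef, residues = ols_residues(X_train, y_train, X_test, y_test)
# Reference: fit directly on the centered but unscaled training data
X_c = X_train - X_train.mean(axis=0)
coef_ref = np.linalg.lstsq(X_c, y_train - y_train.mean(), rcond=None)[0]
print(np.allclose(coef, coef_ref))
```

Because least squares is equivariant to column scaling, the rescaled coefficients match the reference fit.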
"""
X_train = _check_copy_and_writeable(X_train, copy)
y_train = _check_copy_and_writeable(y_train, copy)
X_test = _check_copy_and_writeable(X_test, copy)
y_test = _check_copy_and_writeable(y_test, copy)
if fit_intercept:
X_mean = X_train.mean(axis=0)
X_train -= X_mean
X_test -= X_mean
y_mean = y_train.mean(axis=0)
y_train = as_float_array(y_train, copy=False)
y_train -= y_mean
y_test = as_float_array(y_test, copy=False)
y_test -= y_mean
if normalize:
norms = np.sqrt(np.sum(X_train ** 2, axis=0))
nonzeros = np.flatnonzero(norms)
X_train[:, nonzeros] /= norms[nonzeros]
alphas, active, coefs = lars_path(
X_train, y_train, Gram=Gram, copy_X=False, copy_Gram=False,
method=method, verbose=max(0, verbose - 1), max_iter=max_iter, eps=eps,
positive=positive)
if normalize:
coefs[nonzeros] /= norms[nonzeros][:, np.newaxis]
residues = np.dot(X_test, coefs) - y_test[:, np.newaxis]
return alphas, active, coefs, residues.T
class LarsCV(Lars):
"""Cross-validated Least Angle Regression model.
See glossary entry for :term:`cross-validation estimator`.
Read more in the :ref:`User Guide <least_angle_regression>`.
Parameters
----------
fit_intercept : boolean
Whether to calculate the intercept for this model. If set
to False, no intercept will be used in calculations
(e.g. data is expected to be already centered).
verbose : boolean or integer, optional
Sets the verbosity amount
max_iter : integer, optional
Maximum number of iterations to perform.
normalize : boolean, optional, default True
This parameter is ignored when ``fit_intercept`` is set to False.
If True, the regressors X will be normalized before regression by
subtracting the mean and dividing by the l2-norm.
If you wish to standardize, please use
:class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
on an estimator with ``normalize=False``.
precompute : True | False | 'auto'
Whether to use a precomputed Gram matrix to speed up
calculations. If set to ``'auto'``, let us decide. The Gram matrix
cannot be passed as argument since we will use only subsets of X;
if an array is given, a warning is raised and ``'auto'`` is used instead.
cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, :class:`KFold` is used.
Refer to the :ref:`User Guide <cross_validation>` for the various
cross-validation strategies that can be used here.
.. versionchanged:: 0.20
``cv`` default value if None will change from 3-fold to 5-fold
in v0.22.
max_n_alphas : integer, optional
The maximum number of points on the path used to compute the
residuals in the cross-validation
n_jobs : int or None, optional (default=None)
Number of CPUs to use during the cross validation.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.
eps : float, optional
The machine-precision regularization in the computation of the
Cholesky diagonal factors. Increase this for very ill-conditioned
systems.
copy_X : boolean, optional, default True
If ``True``, X will be copied; else, it may be overwritten.
positive : boolean (default=False)
Restrict coefficients to be >= 0. Be aware that you might want to
disable ``fit_intercept``, which is set to True by default.
.. deprecated:: 0.20
The option is broken and deprecated. It will be removed in v0.22.
Attributes
----------
coef_ : array, shape (n_features,)
parameter vector (w in the formulation formula)
intercept_ : float
independent term in decision function
coef_path_ : array, shape (n_features, n_alphas)
the varying values of the coefficients along the path
alpha_ : float
the estimated regularization parameter alpha
alphas_ : array, shape (n_alphas,)
the different values of alpha along the path
cv_alphas_ : array, shape (n_cv_alphas,)
all the values of alpha along the path for the different folds
mse_path_ : array, shape (n_cv_alphas, n_folds)
the mean square error on left-out data for each fold along the path
(alpha values given by ``cv_alphas_``)
n_iter_ : array-like or int
the number of iterations run by Lars with the optimal alpha.
Examples
--------
>>> from sklearn.linear_model import LarsCV
>>> from sklearn.datasets import make_regression
>>> X, y = make_regression(n_samples=200, noise=4.0, random_state=0)
>>> reg = LarsCV(cv=5).fit(X, y)
>>> reg.score(X, y) # doctest: +ELLIPSIS
0.9996...
>>> reg.alpha_
0.0254...
>>> reg.predict(X[:1,])
array([154.0842...])
See also
--------
lars_path, LassoLars, LassoLarsCV
"""
method = 'lar'
def __init__(self, fit_intercept=True, verbose=False, max_iter=500,
normalize=True, precompute='auto', cv='warn',
max_n_alphas=1000, n_jobs=None, eps=np.finfo(float).eps,
copy_X=True, positive=False):
self.max_iter = max_iter
self.cv = cv
self.max_n_alphas = max_n_alphas
self.n_jobs = n_jobs
super(LarsCV, self).__init__(fit_intercept=fit_intercept,
verbose=verbose, normalize=normalize,
precompute=precompute,
n_nonzero_coefs=500,
eps=eps, copy_X=copy_X, fit_path=True,
positive=positive)
def fit(self, X, y):
"""Fit the model using X, y as training data.
Parameters
----------
X : array-like, shape (n_samples, n_features)
Training data.
y : array-like, shape (n_samples,)
Target values.
Returns
-------
self : object
returns an instance of self.
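The way this method merges the per-fold paths onto a single alpha grid can be sketched with NumPy as follows. This is a sketch, not the method itself. ``merge_cv_paths`` is a hypothetical helper. It uses ``np.interp`` column by column in place of ``scipy.interpolate.interp1d``. Each fold is assumed to be an ``(alphas, residues)`` pair in the order returned by ``_lars_path_residues``: alphas decreasing, and residues of shape (n_alphas, n_test_samples).

```python
import numpy as np


def merge_cv_paths(fold_paths, max_n_alphas=1000):
    # Hypothetical helper mirroring the grid-merging steps of LarsCV.fit.
    # fold_paths: list of (alphas, residues); alphas decreasing,
    # residues of shape (n_alphas, n_test_samples).
    all_alphas = np.unique(np.concatenate([a for a, _ in fold_paths]))  # sorted
    stride = max(1, int(len(all_alphas) / float(max_n_alphas)))
    all_alphas = all_alphas[::stride]  # keep at most ~max_n_alphas points
    # One row per grid alpha, one column per fold
    mse_path = np.empty((len(all_alphas), len(fold_paths)))
    for index, (alphas, residues) in enumerate(fold_paths):
        alphas, residues = alphas[::-1], residues[::-1]  # increasing alphas
        if alphas[0] != 0:  # extend the path down to alpha = 0
            alphas = np.r_[0, alphas]
            residues = np.r_[residues[0, np.newaxis], residues]
        if alphas[-1] != all_alphas[-1]:  # extend up to the largest grid alpha
            alphas = np.r_[alphas, all_alphas[-1]]
            residues = np.r_[residues, residues[-1, np.newaxis]]
        # Piecewise-linear interpolation of each test sample's residue
        this_residues = np.column_stack(
            [np.interp(all_alphas, alphas, residues[:, j])
             for j in range(residues.shape[1])])
        mse_path[:, index] = np.mean(this_residues ** 2, axis=-1)
    mask = np.all(np.isfinite(mse_path), axis=-1)  # drop non-finite rows
    all_alphas, mse_path = all_alphas[mask], mse_path[mask]
    best_alpha = all_alphas[np.argmin(mse_path.mean(axis=-1))]
    return all_alphas, mse_path, best_alpha


fold_1 = (np.array([2.0, 1.0, 0.0]),
          np.array([[2.0, 2.0], [0.5, 0.5], [1.0, 1.0]]))
fold_2 = (np.array([3.0, 1.5]), np.array([[3.0, 3.0], [1.0, 1.0]]))
alphas, mse_path, best_alpha = merge_cv_paths([fold_1, fold_2])
print(alphas, mse_path.shape, best_alpha)
```

In this toy case, the second fold stops at alpha = 1.5, so its residues are held constant down to alpha = 0. The selected alpha is the grid point with the lowest fold-averaged error (1.0 here).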
"""
X, y = check_X_y(X, y, y_numeric=True)
X = as_float_array(X, copy=self.copy_X)
y = as_float_array(y, copy=self.copy_X)
# init cross-validation generator
cv = check_cv(self.cv, classifier=False)
# As we use cross-validation, the Gram matrix is not precomputed here
Gram = self.precompute
if hasattr(Gram, '__array__'):
warnings.warn("Parameter 'precompute' cannot be an array in "
"%s. Automatically switch to 'auto' instead."
% self.__class__.__name__)
Gram = 'auto'
cv_paths = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
delayed(_lars_path_residues)(
X[train], y[train], X[test], y[test], Gram=Gram, copy=False,
method=self.method, verbose=max(0, self.verbose - 1),
normalize=self.normalize, fit_intercept=self.fit_intercept,
max_iter=self.max_iter, eps=self.eps, positive=self.positive)
for train, test in cv.split(X, y))
all_alphas = np.concatenate(list(zip(*cv_paths))[0])
# Unique also sorts
all_alphas = np.unique(all_alphas)
# Take at most max_n_alphas values
stride = int(max(1, int(len(all_alphas) / float(self.max_n_alphas))))
all_alphas = all_alphas[::stride]
mse_path = np.empty((len(all_alphas), len(cv_paths)))
for index, (alphas, active, coefs, residues) in enumerate(cv_paths):
alphas = alphas[::-1]
residues = residues[::-1]
if alphas[0] != 0:
alphas = np.r_[0, alphas]
residues = np.r_[residues[0, np.newaxis], residues]
if alphas[-1] != all_alphas[-1]:
alphas = np.r_[alphas, all_alphas[-1]]
residues = np.r_[residues, residues[-1, np.newaxis]]
this_residues = interpolate.interp1d(alphas,
residues,
axis=0)(all_alphas)
this_residues **= 2
mse_path[:, index] = np.mean(this_residues, axis=-1)
mask = np.all(np.isfinite(mse_path), axis=-1)
all_alphas = all_alphas[mask]
mse_path = mse_path[mask]
# Select the alpha that minimizes left-out error
i_best_alpha = np.argmin(mse_path.mean(axis=-1))
best_alpha = all_alphas[i_best_alpha]
# Store our parameters
self.alpha_ = best_alpha
self.cv_alphas_ = all_alphas
self.mse_path_ = mse_path
# Now compute the full model
# it will call a lasso internally when self is a LassoLarsCV,
# as self.method == 'lasso'
self._fit(X, y, max_iter=self.max_iter, alpha=best_alpha,
Xy=None, fit_path=True)
return self
@property
@deprecated("Attribute alpha is deprecated in 0.19 and "
"will be removed in 0.21. See ``alpha_`` instead")
def alpha(self):
# impedance matching for the above Lars.fit (should not be documented)
return self.alpha_
class LassoLarsCV(LarsCV):
"""Cross-validated Lasso, using the LARS algorithm.
See glossary entry for :term:`cross-validation estimator`.
The optimization objective for Lasso is::
(1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
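To make the objective concrete, here is a minimal NumPy sketch. It evaluates the objective and computes the smallest alpha for which w = 0 is optimal, max_j abs(X[:, j] . y) / n_samples, which follows from the subgradient optimality condition of this convex objective. The sketch assumes X and y are already centered (no intercept, no normalization). ``lasso_objective`` and ``alpha_max`` are hypothetical helpers and are not part of this module.

```python
import numpy as np


def lasso_objective(X, y, w, alpha):
    # (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
    n_samples = X.shape[0]
    resid = y - X @ w
    return resid @ resid / (2.0 * n_samples) + alpha * np.sum(np.abs(w))


def alpha_max(X, y):
    # w = 0 is a minimizer iff max_j |X[:, j] . y| / n_samples <= alpha
    return np.max(np.abs(X.T @ y)) / X.shape[0]


X = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
y = np.array([1.0, 2.0, -1.0, -2.0])
w0, w_step = np.zeros(2), np.array([0.0, 0.01])
print(alpha_max(X, y))
# Above alpha_max, a small move away from zero increases the objective;
# below it, the same move decreases it.
print(lasso_objective(X, y, w_step, 1.2) > lasso_objective(X, y, w0, 1.2))
print(lasso_objective(X, y, w_step, 0.8) < lasso_objective(X, y, w0, 0.8))
```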
Read more in the :ref:`User Guide <least_angle_regression>`.
Parameters
----------
fit_intercept : boolean
Whether to calculate the intercept for this model. If set
to False, no intercept will be used in calculations
(e.g. data is expected to be already centered).
verbose : boolean or integer, optional
Sets the verbosity amount
max_iter : integer, optional
Maximum number of iterations to perform.
normalize : boolean, optional, default True
This parameter is ignored when ``fit_intercept`` is set to False.
If True, the regressors X will be normalized before regression by
subtracting the mean and dividing by the l2-norm.
If you wish to standardize, please use
:class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
on an estimator with ``normalize=False``.
precompute : True | False | 'auto'
Whether to use a precomputed Gram matrix to speed up
calculations. If set to ``'auto'``, let us decide. The Gram matrix
cannot be passed as argument since we will use only subsets of X;
if an array is given, a warning is raised and ``'auto'`` is used instead.
cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
- None, to use the default 3-fold cross-validation,
- integer, to specify the number of folds.
- :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, :class:`KFold` is used.
Refer to the :ref:`User Guide <cross_validation>` for the various
cross-validation strategies that can be used here.
.. versionchanged:: 0.20
``cv`` default value if None will change from 3-fold to 5-fold
in v0.22.
max_n_alphas : integer, optional
The maximum number of points on the path used to compute the
residuals in the cross-validation
n_jobs : int or None, optional (default=None)
Number of CPUs to use during the cross validation.
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
for more details.
eps : float, optional
The machine-precision regularization in the computation of the
Cholesky diagonal factors. Increase this for very ill-conditioned
systems.
copy_X : boolean, optional, default True
If True, X will be copied; else, it may be overwritten.
positive : boolean (default=False)
Restrict coefficients to be >= 0. Be aware that you might want to
disable ``fit_intercept``, which is set to True by default.
Under the positive restriction the model coefficients do not converge
to the ordinary-least-squares solution for small values of alpha.
Only coefficients up to the smallest alpha value (``alphas_[alphas_ >
0.].min()`` when fit_path=True) reached by the stepwise Lars-Lasso
algorithm typically agree with the solution of the
coordinate descent Lasso estimator.
As a consequence, using LassoLarsCV only makes sense for problems where
a sparse solution is expected and/or reached.
Attributes
----------
coef_ : array, shape (n_features,)
parameter vector (w in the formulation formula)
intercept_ : float
independent term in decision function.
coef_path_ : array, shape (n_features, n_alphas)
the varying values of the coefficients along the path
alpha_ : float
the estimated regularization parameter alpha
alphas_ : array, shape (n_alphas,)
the different values of alpha along the path
cv_alphas_ : array, shape (n_cv_alphas,)
all the values of alpha along the path for the different folds
mse_path_ : array, shape (n_cv_alphas, n_folds)
the mean square error on left-out data for each fold along the path
(alpha values given by ``cv_alphas_``)
n_iter_ : array-like or int
the number of iterations run by Lars with the optimal alpha.
Examples
--------
>>> from sklearn.linear_model import LassoLarsCV
>>> from sklearn.datasets import make_regression
>>> X, y = make_regression(noise=4.0, random_state=0)
>>> reg = LassoLarsCV(cv=5).fit(X, y)
>>> reg.score(X, y) # doctest: +ELLIPSIS
0.9992...
>>> reg.alpha_
0.0484...
>>> reg.predict(X[:1,])
array([-77.8723...])
Notes
-----
The object solves the same problem as the LassoCV object. However,
unlike the LassoCV, it finds the relevant alpha values by itself.
In general, because of this property, it will be more stable.
However, it is more fragile to heavily multicollinear datasets.
It is more efficient than the LassoCV if only a small number of
features are selected compared to the total number, for instance if
there are very few samples compared to the number of features.
See also
--------
lars_path, LassoLars, LarsCV, LassoCV
"""
method = 'lasso'
def __init__(self, fit_intercept=True, verbose=False, max_iter=500,
normalize=True, precompute='auto', cv='warn',
max_n_alphas=1000, n_jobs=None, eps=np.finfo(float).eps,
copy_X=True, positive=False):
self.fit_intercept = fit_intercept
self.verbose = verbose
self.max_iter = max_iter
self.normalize = normalize
self.precompute = precompute
self.cv = cv
self.max_n_alphas = max_n_alphas
self.n_jobs = n_jobs
self.eps = eps
self.copy_X = copy_X
self.positive = positive
# XXX : we don't use super(LarsCV, self).__init__
# to avoid setting n_nonzero_coefs
class LassoLarsIC(LassoLars):
"""Lasso model fit with Lars using BIC or AIC for model selection
The optimization objective for Lasso is::
(1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
AIC is the Akaike information criterion and BIC is the Bayesian
information criterion. Such criteria are useful to select the value
of the regularization parameter by making a trade-off between the
goodness of fit and the complexity of the model. A good model should
explain the data well while being simple.
Read more in the :ref:`User Guide <least_angle_regression>`.
Parameters
----------
criterion : 'bic' | 'aic'
The type of criterion to use.
fit_intercept : boolean
Whether to calculate the intercept for this model. If set
to False, no intercept will be used in calculations
(e.g. data is expected to be already centered).
verbose : boolean or integer, optional
Sets the verbosity amount
normalize : boolean, optional, default True
This parameter is ignored when ``fit_intercept`` is set to False.
If True, the regressors X will be normalized before regression by
subtracting the mean and dividing by the l2-norm.
If you wish to standardize, please use
:class:`sklearn.preprocessing.StandardScaler` before calling ``fit``
on an estimator with ``normalize=False``.
precompute : True | False | 'auto' | array-like
Whether to use a precomputed Gram matrix to speed up
calculations. If set to ``'auto'`` let us decide. The Gram
matrix can also be passed as argument.
max_iter : integer, optional
Maximum number of iterations to perform. Can be used for
early stopping.
eps : float, optional
The machine-precision regularization in the computation of the
Cholesky diagonal factors. Increase this for very ill-conditioned
systems. Unlike the ``tol`` parameter in some iterative
optimization-based algorithms, this parameter does not control
the tolerance of the optimization.
copy_X : boolean, optional, default True
If True, X will be copied; else, it may be overwritten.
positive : boolean (default=False)
Restrict coefficients to be >= 0. Be aware that you might want to
disable ``fit_intercept``, which is set to True by default.
Under the positive restriction the model coefficients do not converge
to the ordinary-least-squares solution for small values of alpha.
Only coefficients up to the smallest alpha value (``alphas_[alphas_ >
0.].min()`` when fit_path=True) reached by the stepwise Lars-Lasso
algorithm typically agree with the solution of the
coordinate descent Lasso estimator.
As a consequence, using LassoLarsIC only makes sense for problems where
a sparse solution is expected and/or reached.
Attributes
----------
coef_ : array, shape (n_features,)
parameter vector (w in the formulation formula)
intercept_ : float
independent term in decision function.
alpha_ : float
the alpha parameter chosen by the information criterion
n_iter_ : int
number of iterations run by lars_path to find the grid of
alphas.
criterion_ : array, shape (n_alphas,)
The value of the information criterion ('aic' or 'bic') across all
alphas. The alpha which has the smallest information criterion is
chosen. This value is larger by a factor of ``n_samples`` compared to
Eqns. 2.15 and 2.16 in (Zou et al, 2007).
Examples
--------
>>> from sklearn import linear_model
>>> reg = linear_model.LassoLarsIC(criterion='bic')
>>> reg.fit([[-1, 1], [0, 0], [1, 1]], [-1.1111, 0, -1.1111])
... # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
LassoLarsIC(copy_X=True, criterion='bic', eps=..., fit_intercept=True,
max_iter=500, normalize=True, positive=False, precompute='auto',
verbose=False)
>>> print(reg.coef_) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
[ 0. -1.11...]
Notes
-----
The estimation of the number of degrees of freedom is given by:
"On the degrees of freedom of the lasso"
Hui Zou, Trevor Hastie, and Robert Tibshirani
Ann. Statist. Volume 35, Number 5 (2007), 2173-2192.
https://en.wikipedia.org/wiki/Akaike_information_criterion
https://en.wikipedia.org/wiki/Bayesian_information_criterion
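``fit`` takes the degrees of freedom at each point of the path to be the number of non-zero coefficients. When the active columns Xc are linearly independent, this equals the trace of the hat matrix Xc (Xc^T Xc)^-1 Xc^T. The following is a quick NumPy check of that identity. ``hat_matrix_trace`` is a hypothetical helper used only for illustration.

```python
import numpy as np


def hat_matrix_trace(X, mask):
    # Trace of Xc (Xc^T Xc)^-1 Xc^T for the active columns Xc = X[:, mask]
    Xc = X[:, mask]
    H = Xc @ np.linalg.inv(Xc.T @ Xc) @ Xc.T
    return float(np.trace(H))


rng = np.random.RandomState(0)
X = rng.randn(10, 4)
mask = np.array([True, False, True, True])
print(hat_matrix_trace(X, mask), int(np.sum(mask)))
```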
See also
--------
lars_path, LassoLars, LassoLarsCV
"""
def __init__(self, criterion='aic', fit_intercept=True, verbose=False,
normalize=True, precompute='auto', max_iter=500,
eps=np.finfo(float).eps, copy_X=True, positive=False):
self.criterion = criterion
self.fit_intercept = fit_intercept
self.positive = positive
self.max_iter = max_iter
self.verbose = verbose
self.normalize = normalize
self.copy_X = copy_X
self.precompute = precompute
self.eps = eps
self.fit_path = True
def fit(self, X, y, copy_X=True):
"""Fit the model using X, y as training data.
Parameters
----------
X : array-like, shape (n_samples, n_features)
training data.
y : array-like, shape (n_samples,)
target values. Will be cast to X's dtype if necessary
copy_X : boolean, optional, default True
If ``True``, X will be copied; else, it may be overwritten.
Returns
-------
self : object
returns an instance of self.
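Once ``lars_path`` has returned the coefficient path, the model-selection step in this method reduces to a few array operations. The following is a minimal NumPy sketch of that step. ``select_by_criterion`` is a hypothetical helper that takes the already preprocessed X and y and one coefficient column per alpha. It is not the method itself.

```python
import numpy as np
from math import log


def select_by_criterion(X, y, coef_path, criterion='aic'):
    # Hypothetical helper mirroring the model-selection part of fit.
    n_samples = X.shape[0]
    if criterion == 'aic':
        K = 2
    elif criterion == 'bic':
        K = log(n_samples)
    else:
        raise ValueError('criterion should be either bic or aic')
    R = y[:, np.newaxis] - X @ coef_path  # residuals, one column per alpha
    mean_squared_error = np.mean(R ** 2, axis=0)
    sigma2 = np.var(y)
    # Degrees of freedom: number of non-zero coefficients at each alpha
    df = np.sum(np.abs(coef_path) > np.finfo(coef_path.dtype).eps, axis=0)
    eps64 = np.finfo('float64').eps
    crit = n_samples * mean_squared_error / (sigma2 + eps64) + K * df
    return crit, int(np.argmin(crit))


X = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]])
y = np.array([1.0, 0.0, 0.0, 0.0])
# Columns: the path at three alphas (none, one, two active coefficients)
coef_path = np.array([[0.0, 1.0, 1.0], [0.0, 0.0, 0.5]])
crit, best = select_by_criterion(X, y, coef_path, 'aic')
print(crit, best)
```

In this example, the middle column fits y exactly with a single active coefficient, so both AIC and BIC select it.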
"""
X, y = check_X_y(X, y, y_numeric=True)
X, y, Xmean, ymean, Xstd = LinearModel._preprocess_data(
X, y, self.fit_intercept, self.normalize, self.copy_X)
max_iter = self.max_iter
Gram = self.precompute
alphas_, active_, coef_path_, self.n_iter_ = lars_path(
X, y, Gram=Gram, copy_X=copy_X, copy_Gram=True, alpha_min=0.0,
method='lasso', verbose=self.verbose, max_iter=max_iter,
eps=self.eps, return_n_iter=True, positive=self.positive)
n_samples = X.shape[0]
if self.criterion == 'aic':
K = 2 # AIC
elif self.criterion == 'bic':
K = log(n_samples) # BIC
else:
raise ValueError('criterion should be either bic or aic')
R = y[:, np.newaxis] - np.dot(X, coef_path_) # residuals
mean_squared_error = np.mean(R ** 2, axis=0)
sigma2 = np.var(y)
df = np.zeros(coef_path_.shape[1], dtype=int) # Degrees of freedom
for k, coef in enumerate(coef_path_.T):
mask = np.abs(coef) > np.finfo(coef.dtype).eps
if not np.any(mask):
continue
# get the number of degrees of freedom equal to:
# Xc = X[:, mask]
# Trace(Xc * inv(Xc.T * Xc) * Xc.T), i.e. the number of non-zero coefs
df[k] = np.sum(mask)
self.alphas_ = alphas_
eps64 = np.finfo('float64').eps
self.criterion_ = (n_samples * mean_squared_error / (sigma2 + eps64) +
K * df) # Eqns. 2.15--16 in (Zou et al, 2007)
n_best = np.argmin(self.criterion_)
self.alpha_ = alphas_[n_best]
self.coef_ = coef_path_[:, n_best]
self._set_intercept(Xmean, ymean, Xstd)
return self