File: test_loss.py (scikit-learn 0.23.2-5)

import numpy as np
from numpy.testing import assert_almost_equal
from numpy.testing import assert_allclose
from scipy.optimize import newton
from sklearn.utils import assert_all_finite
from sklearn.utils.fixes import sp_version, parse_version
import pytest

from sklearn.ensemble._hist_gradient_boosting.loss import _LOSSES
from sklearn.ensemble._hist_gradient_boosting.common import Y_DTYPE
from sklearn.ensemble._hist_gradient_boosting.common import G_H_DTYPE
from sklearn.utils._testing import skip_if_32bit


def get_derivatives_helper(loss):
    """Return get_gradients() and get_hessians() functions for a given loss.
    """

    def get_gradients(y_true, raw_predictions):
        # create gradients and hessians arrays, update in place, and return
        gradients = np.empty_like(raw_predictions, dtype=G_H_DTYPE)
        hessians = np.empty_like(raw_predictions, dtype=G_H_DTYPE)
        loss.update_gradients_and_hessians(gradients, hessians, y_true,
                                           raw_predictions, None)
        return gradients

    def get_hessians(y_true, raw_predictions):
        # create gradients and hessians arrays, update in place, and return
        gradients = np.empty_like(raw_predictions, dtype=G_H_DTYPE)
        hessians = np.empty_like(raw_predictions, dtype=G_H_DTYPE)
        loss.update_gradients_and_hessians(gradients, hessians, y_true,
                                           raw_predictions, None)

        if loss.__class__.__name__ == 'LeastSquares':
            # hessians aren't updated because they're constant:
            # the value is 1 (and not 2) because the loss is actually half
            # the least squares loss.
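            # (i.e. loss(y, p) = 0.5 * (y - p) ** 2, whose second derivative
            # w.r.t. the prediction is 1)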
            hessians = np.full_like(raw_predictions, fill_value=1)
        elif loss.__class__.__name__ == 'LeastAbsoluteDeviation':
            # hessians aren't updated because they're constant
            hessians = np.full_like(raw_predictions, fill_value=0)

        return hessians

    return get_gradients, get_hessians


@pytest.mark.parametrize('loss, x0, y_true', [
    ('least_squares', -2., 42),
    ('least_squares', 117., 1.05),
    ('least_squares', 0., 0.),
    # I don't understand why but y_true == 0 fails :/
    # ('binary_crossentropy', 0.3, 0),
    ('binary_crossentropy', -12, 1),
    ('binary_crossentropy', 30, 1),
    ('poisson', 12., 1.),
    ('poisson', 0., 2.),
    ('poisson', -22., 10.),
])
@pytest.mark.skipif(sp_version == parse_version('1.2.0'),
                    reason='bug in scipy 1.2.0, see scipy issue #9608')
@skip_if_32bit
def test_derivatives(loss, x0, y_true):
    # Check that gradients are zero when the loss is minimized on a 1D array
    # using Halley's method with the first and second order derivatives
    # computed by the Loss instance.
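    # Note: scipy.optimize.newton switches to Halley's method when fprime2
    # is provided. The iteration is
    #     x_{n+1} = x_n - 2*f(x_n)*f'(x_n) / (2*f'(x_n)**2 - f(x_n)*f''(x_n))
    # For the losses tested here the pointwise loss vanishes at its minimum,
    # so finding a root of the loss also drives the gradient to zero, which
    # is what the assertions below verify.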

    loss = _LOSSES[loss](sample_weight=None)
    y_true = np.array([y_true], dtype=Y_DTYPE)
    x0 = np.array([x0], dtype=Y_DTYPE).reshape(1, 1)
    get_gradients, get_hessians = get_derivatives_helper(loss)

    def func(x):
        return loss.pointwise_loss(y_true, x)

    def fprime(x):
        return get_gradients(y_true, x)

    def fprime2(x):
        return get_hessians(y_true, x)

    optimum = newton(func, x0=x0, fprime=fprime, fprime2=fprime2,
                     maxiter=70, tol=2e-8)
    assert np.allclose(loss.inverse_link_function(optimum), y_true)
    assert np.allclose(loss.pointwise_loss(y_true, optimum), 0)
    assert np.allclose(get_gradients(y_true, optimum), 0, atol=1e-7)


@pytest.mark.parametrize('loss, n_classes, prediction_dim', [
    ('least_squares', 0, 1),
    ('least_absolute_deviation', 0, 1),
    ('binary_crossentropy', 2, 1),
    ('categorical_crossentropy', 3, 3),
    ('poisson', 0, 1),
])
@pytest.mark.skipif(Y_DTYPE != np.float64,
                    reason='Need 64 bits float precision for numerical checks')
def test_numerical_gradients(loss, n_classes, prediction_dim, seed=0):
    # Make sure gradients and hessians computed in the loss are correct, by
    # comparing with their approximations computed with finite central
    # differences.
    # See https://en.wikipedia.org/wiki/Finite_difference.
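    # The approximations below are the central differences
    #     f'(x)  ≈ (f(x + eps / 2) - f(x - eps / 2)) / eps
    #     f''(x) ≈ (f(x + eps) + f(x - eps) - 2 * f(x)) / eps ** 2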

    rng = np.random.RandomState(seed)
    n_samples = 100
    if loss in ('least_squares', 'least_absolute_deviation'):
        y_true = rng.normal(size=n_samples).astype(Y_DTYPE)
    elif loss == 'poisson':
        y_true = rng.poisson(size=n_samples).astype(Y_DTYPE)
    else:
        y_true = rng.randint(0, n_classes, size=n_samples).astype(Y_DTYPE)
    raw_predictions = rng.normal(
        size=(prediction_dim, n_samples)
    ).astype(Y_DTYPE)
    loss = _LOSSES[loss](sample_weight=None)
    get_gradients, get_hessians = get_derivatives_helper(loss)

    # only take gradients and hessians of the first tree / class.
    gradients = get_gradients(y_true, raw_predictions)[0, :].ravel()
    hessians = get_hessians(y_true, raw_predictions)[0, :].ravel()

    # Approximate gradients
    # For multiclass loss, we should only change the predictions of one tree
    # (here the first), hence the use of offset[0, :] = eps.
    # As a softmax is computed, offsetting the whole array by a constant
    # would have no effect on the probabilities, and thus on the loss.
    eps = 1e-9
    offset = np.zeros_like(raw_predictions)
    offset[0, :] = eps
    f_plus_eps = loss.pointwise_loss(y_true, raw_predictions + offset / 2)
    f_minus_eps = loss.pointwise_loss(y_true, raw_predictions - offset / 2)
    numerical_gradients = (f_plus_eps - f_minus_eps) / eps

    # Approximate hessians
    eps = 1e-4  # need big enough eps as we divide by its square
    offset[0, :] = eps
    f_plus_eps = loss.pointwise_loss(y_true, raw_predictions + offset)
    f_minus_eps = loss.pointwise_loss(y_true, raw_predictions - offset)
    f = loss.pointwise_loss(y_true, raw_predictions)
    numerical_hessians = (f_plus_eps + f_minus_eps - 2 * f) / eps**2

    assert_allclose(numerical_gradients, gradients, rtol=1e-4, atol=1e-7)
    assert_allclose(numerical_hessians, hessians, rtol=1e-4, atol=1e-7)


def test_baseline_least_squares():
    rng = np.random.RandomState(0)

    loss = _LOSSES['least_squares'](sample_weight=None)
    y_train = rng.normal(size=100)
    baseline_prediction = loss.get_baseline_prediction(y_train, None, 1)
    assert baseline_prediction.shape == tuple()  # scalar
    assert baseline_prediction.dtype == y_train.dtype
    # Make sure baseline prediction is the mean of all targets
    assert_almost_equal(baseline_prediction, y_train.mean())
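    # For least squares the link function is the identity, so applying the
    # inverse link to the baseline prediction should leave it unchanged.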
    assert np.allclose(loss.inverse_link_function(baseline_prediction),
                       baseline_prediction)


def test_baseline_least_absolute_deviation():
    rng = np.random.RandomState(0)

    loss = _LOSSES['least_absolute_deviation'](sample_weight=None)
    y_train = rng.normal(size=100)
    baseline_prediction = loss.get_baseline_prediction(y_train, None, 1)
    assert baseline_prediction.shape == tuple()  # scalar
    assert baseline_prediction.dtype == y_train.dtype
    assert np.allclose(loss.inverse_link_function(baseline_prediction),
                       baseline_prediction)
    # Make sure baseline prediction is the median of all targets
    assert baseline_prediction == pytest.approx(np.median(y_train))


def test_baseline_poisson():
    rng = np.random.RandomState(0)

    loss = _LOSSES['poisson'](sample_weight=None)
    y_train = rng.poisson(size=100).astype(np.float64)
    # Sanity check, make sure at least one sample is non-zero so we don't take
    # log(0)
    assert y_train.sum() > 0
    baseline_prediction = loss.get_baseline_prediction(y_train, None, 1)
    assert np.isscalar(baseline_prediction)
    assert baseline_prediction.dtype == y_train.dtype
    assert_all_finite(baseline_prediction)
    # Make sure baseline prediction is the log of the mean of all targets
    assert_almost_equal(np.log(y_train.mean()), baseline_prediction)

    # Test baseline for y_true = 0
    y_train.fill(0.)
    baseline_prediction = loss.get_baseline_prediction(y_train, None, 1)
    assert_all_finite(baseline_prediction)


def test_baseline_binary_crossentropy():
    rng = np.random.RandomState(0)

    loss = _LOSSES['binary_crossentropy'](sample_weight=None)
    for y_train in (np.zeros(shape=100), np.ones(shape=100)):
        y_train = y_train.astype(np.float64)
        baseline_prediction = loss.get_baseline_prediction(y_train, None, 1)
        assert_all_finite(baseline_prediction)
        assert np.allclose(loss.inverse_link_function(baseline_prediction),
                           y_train[0])

    # Make sure baseline prediction is equal to link_function(p), where p
    # is the proba of the positive class. We want predict_proba() to return p,
    # and by definition
    # p = inverse_link_function(raw_prediction) = sigmoid(raw_prediction)
    # So we want raw_prediction = link_function(p) = log(p / (1 - p))
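    # For instance, with p = 0.25 the raw prediction is log(0.25 / 0.75)
    # ≈ -1.0986, and sigmoid(-1.0986) ≈ 0.25 recovers p.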
    y_train = rng.randint(0, 2, size=100).astype(np.float64)
    baseline_prediction = loss.get_baseline_prediction(y_train, None, 1)
    assert baseline_prediction.shape == tuple()  # scalar
    assert baseline_prediction.dtype == y_train.dtype
    p = y_train.mean()
    assert np.allclose(baseline_prediction, np.log(p / (1 - p)))


def test_baseline_categorical_crossentropy():
    rng = np.random.RandomState(0)

    prediction_dim = 4
    loss = _LOSSES['categorical_crossentropy'](sample_weight=None)
    for y_train in (np.zeros(shape=100), np.ones(shape=100)):
        y_train = y_train.astype(np.float64)
        baseline_prediction = loss.get_baseline_prediction(y_train, None,
                                                           prediction_dim)
        assert baseline_prediction.dtype == y_train.dtype
        assert_all_finite(baseline_prediction)

    # Same logic as for the test above. Here inverse_link_function = softmax
    # and link_function = log
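    # The baseline for class k should therefore be the log of the empirical
    # frequency of class k in y_train, which is what the loop below checks.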
    y_train = rng.randint(0, prediction_dim + 1, size=100).astype(np.float32)
    baseline_prediction = loss.get_baseline_prediction(y_train, None,
                                                       prediction_dim)
    assert baseline_prediction.shape == (prediction_dim, 1)
    for k in range(prediction_dim):
        p = (y_train == k).mean()
        assert np.allclose(baseline_prediction[k, :], np.log(p))


@pytest.mark.parametrize('loss, problem', [
    ('least_squares', 'regression'),
    ('least_absolute_deviation', 'regression'),
    ('binary_crossentropy', 'classification'),
    ('categorical_crossentropy', 'classification'),
    ('poisson', 'poisson_regression'),
    ])
@pytest.mark.parametrize('sample_weight', ['ones', 'random'])
def test_sample_weight_multiplies_gradients(loss, problem, sample_weight):
    # Make sure that passing sample weights to the gradient and hessians
    # computation methods is equivalent to multiplying by the weights.
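    # Concretely, for every class/sample entry we expect
    #     gradients_sw == sample_weight * gradients
    #     hessians_sw == sample_weight * hessians
    # elementwise, which the assertions at the end check.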

    rng = np.random.RandomState(42)
    n_samples = 1000

    if loss == 'categorical_crossentropy':
        n_classes = prediction_dim = 3
    else:
        n_classes = prediction_dim = 1

    if problem == 'regression':
        y_true = rng.normal(size=n_samples).astype(Y_DTYPE)
    elif problem == 'poisson_regression':
        y_true = rng.poisson(size=n_samples).astype(Y_DTYPE)
    else:
        y_true = rng.randint(0, n_classes, size=n_samples).astype(Y_DTYPE)

    if sample_weight == 'ones':
        sample_weight = np.ones(shape=n_samples, dtype=Y_DTYPE)
    else:
        sample_weight = rng.normal(size=n_samples).astype(Y_DTYPE)

    loss_ = _LOSSES[loss](sample_weight=sample_weight)

    baseline_prediction = loss_.get_baseline_prediction(
        y_true, None, prediction_dim
    )
    raw_predictions = np.zeros(shape=(prediction_dim, n_samples),
                               dtype=baseline_prediction.dtype)
    raw_predictions += baseline_prediction

    gradients = np.empty(shape=(prediction_dim, n_samples), dtype=G_H_DTYPE)
    hessians = np.ones(shape=(prediction_dim, n_samples), dtype=G_H_DTYPE)
    loss_.update_gradients_and_hessians(gradients, hessians, y_true,
                                        raw_predictions, None)

    gradients_sw = np.empty(shape=(prediction_dim, n_samples), dtype=G_H_DTYPE)
    hessians_sw = np.ones(shape=(prediction_dim, n_samples), dtype=G_H_DTYPE)
    loss_.update_gradients_and_hessians(gradients_sw, hessians_sw, y_true,
                                        raw_predictions, sample_weight)

    assert np.allclose(gradients * sample_weight, gradients_sw)
    assert np.allclose(hessians * sample_weight, hessians_sw)


def test_init_gradient_and_hessians_sample_weight():
    # Make sure that passing sample_weight to a loss correctly influences the
    # hessians_are_constant attribute, and consequently the shape of the
    # hessians array.
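    # With constant hessians (no sample weights), the hessians array is
    # collapsed to shape (1, 1); with sample weights the hessians vary per
    # sample, so a full (prediction_dim, n_samples) array is allocated.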

    prediction_dim = 2
    n_samples = 5
    sample_weight = None
    loss = _LOSSES['least_squares'](sample_weight=sample_weight)
    _, hessians = loss.init_gradients_and_hessians(
        n_samples=n_samples, prediction_dim=prediction_dim,
        sample_weight=None)
    assert loss.hessians_are_constant
    assert hessians.shape == (1, 1)

    sample_weight = np.ones(n_samples)
    loss = _LOSSES['least_squares'](sample_weight=sample_weight)
    _, hessians = loss.init_gradients_and_hessians(
        n_samples=n_samples, prediction_dim=prediction_dim,
        sample_weight=sample_weight)
    assert not loss.hessians_are_constant
    assert hessians.shape == (prediction_dim, n_samples)