File: test_link.py

package info (click to toggle)
scikit-learn 1.4.2+dfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 25,036 kB
  • sloc: python: 201,105; cpp: 5,790; ansic: 854; makefile: 304; sh: 56; javascript: 20
file content (111 lines) | stat: -rw-r--r-- 3,954 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal

from sklearn._loss.link import (
    _LINKS,
    HalfLogitLink,
    Interval,
    MultinomialLogit,
    _inclusive_low_high,
)

# Values of the `_LINKS` mapping; the parametrized tests in this module call
# each entry with no arguments (`link()`) to obtain a link object.
LINK_FUNCTIONS = [*_LINKS.values()]


def test_interval_raises():
    """Check that constructing an Interval with low > high raises ValueError."""
    # `match` is applied as a regular expression search against the message.
    expected_message = "One must have low <= high; got low=1, high=0."
    with pytest.raises(ValueError, match=expected_message):
        Interval(1, 0, False, False)


@pytest.mark.parametrize(
    "interval",
    [
        # Every combination of open/closed ends for each (low, high) pair,
        # enumerated in the order (False, False), (False, True),
        # (True, False), (True, True).
        Interval(low, high, low_inclusive, high_inclusive)
        for low, high in [(0, 1), (-np.inf, np.inf), (-10, -1)]
        for low_inclusive in (False, True)
        for high_inclusive in (False, True)
    ],
)
def test_is_in_range(interval):
    """Check `Interval.includes` on interior points and on the bounds.

    Interior points must always be accepted; appending a bound to the
    sample must be accepted exactly when that bound is inclusive.
    """
    # Endpoints for linspace that are always within the interval.
    start, stop = _inclusive_low_high(interval)
    grid = np.linspace(start, stop, num=10)
    assert interval.includes(grid)

    # Sample extended with the lower bound.
    grid_with_low = np.r_[grid, interval.low]
    assert interval.includes(grid_with_low) == interval.low_inclusive

    # Sample extended with the upper bound.
    grid_with_high = np.r_[grid, interval.high]
    assert interval.includes(grid_with_high) == interval.high_inclusive

    # Sample extended with both bounds: both ends must be inclusive.
    grid_with_both = np.r_[grid, interval.low, interval.high]
    assert interval.includes(grid_with_both) == (
        interval.low_inclusive and interval.high_inclusive
    )


@pytest.mark.parametrize("link", LINK_FUNCTIONS)
def test_link_inverse_identity(link, global_random_seed):
    """Check that `link` and `inverse` undo each other in both directions."""
    rng = np.random.RandomState(global_random_seed)
    link_obj = link()
    n_samples = 100

    # Raw predictions are kept within [-20, 20]: for `LogitLink`, `expit(x)`
    # gets very close to 1 for large positive x and loses precision.
    # `HalfLogitLink` is sampled from the narrower range [-10, 10].
    if link_obj.is_multiclass:
        bound, shape = 20, (n_samples, 10)  # 10 classes
    elif isinstance(link_obj, HalfLogitLink):
        bound, shape = 10, n_samples
    else:
        bound, shape = 20, n_samples

    raw_prediction = rng.uniform(low=-bound, high=bound, size=shape)
    if link_obj.is_multiclass and isinstance(link_obj, MultinomialLogit):
        raw_prediction = link_obj.symmetrize_raw_prediction(raw_prediction)

    # link(inverse(raw)) == raw
    raw_round_trip = link_obj.link(link_obj.inverse(raw_prediction))
    assert_allclose(raw_round_trip, raw_prediction)

    # inverse(link(y)) == y
    y_pred = link_obj.inverse(raw_prediction)
    y_round_trip = link_obj.inverse(link_obj.link(y_pred))
    assert_allclose(y_round_trip, y_pred)


@pytest.mark.parametrize("link", LINK_FUNCTIONS)
def test_link_out_argument(link):
    """Check that `inverse` and `link` store their result in `out`."""

    def check_out_is_filled(func, arg, expected):
        # Call `func` with a fresh buffer shaped like `arg`, then verify the
        # buffer holds `expected` and is the very array that was returned.
        buffer = np.empty_like(arg)
        returned = func(arg, out=buffer)
        assert_allclose(expected, buffer)
        assert_array_equal(buffer, returned)
        assert np.shares_memory(buffer, returned)

    rng = np.random.RandomState(42)
    link_obj = link()
    n_samples = 100

    if not link_obj.is_multiclass:
        # So far, the valid interval of raw_prediction is (-inf, inf) and
        # we do not need to distinguish.
        raw_prediction = rng.uniform(low=-10, high=10, size=n_samples)
    else:
        n_classes = 10
        raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples, n_classes))
        if isinstance(link_obj, MultinomialLogit):
            raw_prediction = link_obj.symmetrize_raw_prediction(raw_prediction)

    # inverse: result without `out` is the reference for the `out` call.
    y_pred = link_obj.inverse(raw_prediction, out=None)
    check_out_is_filled(link_obj.inverse, raw_prediction, y_pred)

    # link: applying it to y_pred should recover raw_prediction in `out`.
    check_out_is_filled(link_obj.link, y_pred, raw_prediction)