File: link.py

package info (click to toggle)
scikit-learn 0.23.2-5
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 21,892 kB
  • sloc: python: 132,020; cpp: 5,765; javascript: 2,201; ansic: 831; makefile: 213; sh: 44
file content (110 lines) | stat: -rw-r--r-- 2,690 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Link functions used in GLM
"""

# Author: Christian Lorentzen <lorentzen.ch@googlemail.com>
# License: BSD 3 clause

from abc import ABCMeta, abstractmethod

import numpy as np
from scipy.special import expit, logit


class BaseLink(metaclass=ABCMeta):
    """Interface that every GLM link function implements.

    A concrete link supplies four pieces: the link g itself, its derivative
    g', the inverse link h (so that h(g(y_pred)) == y_pred) and the
    derivative h' of that inverse.
    """

    @abstractmethod
    def __call__(self, y_pred):
        """Evaluate the link function g at y_pred.

        g maps the expected value y_pred = E[Y] onto the linear predictor
        (X*w), that is, g(y_pred) equals the linear predictor.

        Parameters
        ----------
        y_pred : array of shape (n_samples,)
            Typically the (predicted) mean.
        """

    @abstractmethod
    def derivative(self, y_pred):
        """Evaluate the first derivative g'(y_pred) of the link.

        Parameters
        ----------
        y_pred : array of shape (n_samples,)
            Typically the (predicted) mean.
        """

    @abstractmethod
    def inverse(self, lin_pred):
        """Evaluate the inverse link function h at lin_pred.

        h maps the linear predictor back to the mean y_pred = E[Y], that
        is, h(linear predictor) equals y_pred.

        Parameters
        ----------
        lin_pred : array of shape (n_samples,)
            Typically the (fitted) linear predictor.
        """

    @abstractmethod
    def inverse_derivative(self, lin_pred):
        """Evaluate the first derivative h'(lin_pred) of the inverse link.

        Parameters
        ----------
        lin_pred : array of shape (n_samples,)
            Typically the (fitted) linear predictor.
        """


class IdentityLink(BaseLink):
    """Identity link: g(y_pred) = y_pred, and its inverse is the identity too."""

    def __call__(self, y_pred):
        """Return y_pred unchanged; mean and linear predictor coincide."""
        return y_pred

    def derivative(self, y_pred):
        """Return g'(y_pred) = 1 via ``np.ones_like(y_pred)``."""
        unit_slope = np.ones_like(y_pred)
        return unit_slope

    def inverse(self, lin_pred):
        """Return lin_pred unchanged; the identity is its own inverse."""
        return lin_pred

    def inverse_derivative(self, lin_pred):
        """Return h'(lin_pred) = 1 via ``np.ones_like(lin_pred)``."""
        unit_slope = np.ones_like(lin_pred)
        return unit_slope


class LogLink(BaseLink):
    """Log link: g(y_pred) = log(y_pred).

    The inverse is h(lin_pred) = exp(lin_pred).
    """

    def __call__(self, y_pred):
        """Map the mean onto the linear predictor with the natural log."""
        log_mean = np.log(y_pred)
        return log_mean

    def derivative(self, y_pred):
        """Return g'(y_pred) = 1 / y_pred."""
        reciprocal = 1 / y_pred
        return reciprocal

    def inverse(self, lin_pred):
        """Map the linear predictor back onto the mean with exp."""
        mean = np.exp(lin_pred)
        return mean

    def inverse_derivative(self, lin_pred):
        """Return h'(lin_pred) = exp(lin_pred); exp is its own derivative."""
        slope = np.exp(lin_pred)
        return slope


class LogitLink(BaseLink):
    """Logit link: g(y_pred) = logit(y_pred) = log(y_pred / (1 - y_pred)).

    The inverse is the logistic sigmoid h(lin_pred) = expit(lin_pred).
    """

    def __call__(self, y_pred):
        """Return the log-odds logit(y_pred)."""
        return logit(y_pred)

    def derivative(self, y_pred):
        """Return g'(y_pred) = 1 / (y_pred * (1 - y_pred))."""
        return 1 / (y_pred * (1 - y_pred))

    def inverse(self, lin_pred):
        """Return the logistic sigmoid expit(lin_pred)."""
        return expit(lin_pred)

    def inverse_derivative(self, lin_pred):
        """Return h'(lin_pred) = expit(lin_pred) * (1 - expit(lin_pred)).

        The factor ``1 - expit(lin_pred)`` is evaluated as
        ``expit(-lin_pred)``, which is mathematically identical. For large
        positive lin_pred (beyond roughly 37 in float64), expit(lin_pred)
        rounds to exactly 1.0, and the subtraction would then cancel to 0.
        That would wrongly make the derivative exactly 0, whereas
        expit(-lin_pred) keeps the small but nonzero true value.
        """
        ep = expit(lin_pred)
        # np.negative (rather than unary minus) so list inputs keep working,
        # as they did with the original ``1 - ep`` formulation.
        return ep * expit(np.negative(lin_pred))