File: transformers.py

package info (click to toggle)
sklearn-pandas 2.2.0-5
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 440 kB
  • sloc: python: 1,177; sh: 12; makefile: 8
file content (55 lines) | stat: -rw-r--r-- 1,595 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import pandas as pd
from sklearn.base import TransformerMixin
import warnings


def _get_mask(X, value):
    """
    Compute the boolean mask X == missing_values.
    """
    if value == "NaN" or \
       value is None or \
       (isinstance(value, float) and np.isnan(value)):
        return pd.isnull(X)
    else:
        return X == value


class NumericalTransformer(TransformerMixin):
    """
    Provides commonly used numerical transformers.

    Instantiating the class emits a DeprecationWarning.
    """
    SUPPORTED_FUNCTIONS = ['log', 'log1p']

    # Maps each supported function name to the numpy ufunc implementing it.
    _UFUNCS = {'log': np.log, 'log1p': np.log1p}

    def __init__(self, func):
        """
        Params

        func    name of the function to apply to input columns. The function
                is applied to each value. Supported names are defined in
                the SUPPORTED_FUNCTIONS variable. Raises AssertionError if
                the name is not supported.
        """

        # stacklevel=2 attributes the warning to the code instantiating the
        # transformer instead of to this module.
        warnings.warn("""
            NumericalTransformer will be deprecated in 3.0 version.
            Please use Sklearn.base.TransformerMixin to write
            customer transformers
            """, DeprecationWarning, stacklevel=2)

        # Raised explicitly instead of via `assert` so the validation still
        # runs under `python -O`; the exception type is kept for callers.
        if func not in self.SUPPORTED_FUNCTIONS:
            raise AssertionError(
                "Only following func are supported: "
                f"{self.SUPPORTED_FUNCTIONS}")
        super().__init__()
        self.__func = func

    def fit(self, X, y=None):
        """
        No-op: nothing is learned from the data. Returns self.
        """
        return self

    def transform(self, X, y=None):
        """
        Apply the configured function element-wise to X.

        Params

        X       array-like of numeric values; it is converted with
                np.asarray before the function is applied.
        y       ignored.

        Returns a numpy array of float64 values (complex128 when X holds
        complex numbers). Raises ValueError if the stored function name has
        no known implementation.
        """
        ufunc = self._UFUNCS.get(self.__func)
        if ufunc is None:
            raise ValueError(f"Invalid function name: {self.__func}")

        # The ufunc is applied to the whole array at once; unlike
        # np.vectorize this also works for empty input. Casting up front
        # lets object-dtype input (e.g. from a mixed DataFrame) through.
        values = np.asarray(X)
        dtype = np.complex128 if np.iscomplexobj(values) else np.float64
        return ufunc(values.astype(dtype, copy=False))