File: decorators.py

package info (click to toggle)
astroml 1.0.2-6
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 932 kB
  • sloc: python: 5,731; makefile: 3
file content (154 lines) | stat: -rw-r--r-- 5,784 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import warnings
import functools
from packaging.version import Version

import numpy as np
import astropy
import pickle

from astroML.utils.exceptions import AstroMLDeprecationWarning

# The fallback ``deprecated`` decorator defined in this module is only needed
# when the installed astropy predates the functionality we rely on, which was
# added in v2.0.10 LTS and v3.1.  The flag is therefore True for versions
# older than 2.0.10 and for versions in the half-open range [3.0, 3.1).
av = astropy.__version__
ASTROPY_LT_31 = (Version(av) < Version("2.0.10")
                 or Version("3.0") <= Version(av) < Version("3.1"))


__all__ = ['pickle_results', 'deprecated']


def _pickle_values_equal(cached, current):
    """Return True when two argument values are equal, comparing arrays by value.

    ``np.array_equal`` checks shapes before contents, so values that would
    merely broadcast against each other (an array of ones versus the scalar
    ``1``) are not reported as equal.
    """
    try:
        return bool(np.array_equal(cached, current))
    except Exception:
        # Anything numpy cannot compare is treated as "different": the worst
        # outcome is an unnecessary recomputation, never a stale result.
        return False


def _pickle_args_match(cached, current):
    """Return True when the cached positional arguments equal the current ones."""
    try:
        return bool(current == cached)
    except ValueError:
        # Plain tuple comparison fails when an element is a numpy array
        # (ambiguous truth value); fall back to element-wise comparison.
        return (len(cached) == len(current)
                and all(_pickle_values_equal(c_val, n_val)
                        for c_val, n_val in zip(cached, current)))


def _pickle_kwargs_match(cached, current):
    """Return True when the cached keyword arguments equal the current ones."""
    try:
        return bool(current == cached)
    except ValueError:
        # Same numpy issue as for positional arguments: compare key sets
        # first, then the values one key at a time.
        return (cached.keys() == current.keys()
                and all(_pickle_values_equal(cached[key], current[key])
                        for key in current))


def pickle_results(filename=None, verbose=True):
    """Generator for decorator which allows pickling the results of a function

    Pickle is python's built-in object serialization.  This decorator, when
    used on a function, saves the results of the computation in the function
    to a pickle file.  If the function is called a second time with the
    same inputs, then the computation will not be repeated and the previous
    results will be used.

    This functionality is useful for computations which take a long time,
    but will need to be repeated (such as the first step of a data analysis).

    Only the most recent call is cached: calling the function with different
    inputs recomputes the result and overwrites the pickle file.

    Parameters
    ----------
    filename : string (optional)
        pickle file to which results will be saved.
        If not specified, then the file is '<funcname>_output.pkl'
        where '<funcname>' is replaced by the name of the decorated function.
    verbose : boolean (optional)
        if True, then print a message to standard out specifying when the
        pickle file is written or read.

    Returns
    -------
    decorator : callable
        a decorator which wraps a function with the caching behavior
        described above.

    Notes
    -----
    The cache is read with ``pickle.load``, which can execute arbitrary code.
    Only point ``filename`` at files you trust.

    Examples
    --------
    >>> @pickle_results('tmp.pkl', verbose=True)
    ... def f(x):
    ...     return x * x
    >>> f(4)
    @pickle_results: computing results and saving to 'tmp.pkl'
    16
    >>> f(4)
    @pickle_results: using precomputed results from 'tmp.pkl'
    16
    """
    def pickle_func(f, filename=filename, verbose=verbose):
        if filename is None:
            filename = '%s_output.pkl' % f.__name__

        @functools.wraps(f)
        def new_f(*args, **kwargs):
            # While loading, pickle can raise any number of errors. Treat a
            # missing file and a pickle error as equivalent: in either case
            # the data in the cache will have to be regenerated.
            # NOTE: unpickling is only safe for trusted cache files.
            try:
                with open(filename, 'rb') as infile:
                    D = pickle.load(infile)
                cache_exists = True
            except Exception:
                D = {}
                cache_exists = False

            # A readable pickle that does not hold our dictionary layout is
            # unusable as a cache; ignore its content and recompute.
            if not isinstance(D, dict):
                D = {}

            # simple comparison doesn't work in the case of numpy arrays,
            # so the helpers fall back to an array-aware comparison
            args_match = _pickle_args_match(D.get('args'), args)
            kwargs_match = _pickle_kwargs_match(D.get('kwargs'), kwargs)

            if (D.get('funcname') == f.__name__ and 'retval' in D
                    and args_match and kwargs_match):
                if verbose:
                    print("@pickle_results: using precomputed "
                          "results from '%s'" % filename)
                retval = D['retval']

            else:
                if verbose:
                    print("@pickle_results: computing results "
                          "and saving to '%s'" % filename)
                    if cache_exists:
                        print("  warning: cache file '%s' exists" % filename)
                        print("    - args match:   %s" % args_match)
                        print("    - kwargs match: %s" % kwargs_match)
                retval = f(*args, **kwargs)

                funcdict = dict(funcname=f.__name__, retval=retval,
                                args=args, kwargs=kwargs)
                with open(filename, 'wb') as outfile:
                    pickle.dump(funcdict, outfile)

            return retval
        return new_f
    return pickle_func


if not ASTROPY_LT_31:
    from astropy.utils.decorators import deprecated
else:
    def deprecated(since, message='', alternative=None, **kwargs):
        """Fallback decorator marking a function or class as deprecated.

        Used only when the installed astropy is too old to supply the
        ``deprecated`` decorator this module otherwise imports.

        Parameters
        ----------
        since : string
            version since which the object is deprecated; inserted into the
            default warning message.
        message : string (optional)
            warning text.  If empty, a default message naming the object and
            ``since`` is generated.
        alternative : string (optional)
            name of a replacement; appended to the default message only
            (it is not added to a user-supplied ``message``).
        **kwargs
            accepted and ignored by this fallback.

        Returns
        -------
        deprecate : callable
            decorator which emits an ``AstroMLDeprecationWarning`` whenever
            the decorated function is called or the decorated class is
            instantiated.
        """

        # ``message``, ``since`` and ``alternative`` are re-bound as default
        # arguments because the helpers below assign to ``message``; without
        # the re-binding that assignment would make the name local and hide
        # the outer value.
        def deprecate_function(func, message=message, since=since,
                               alternative=alternative):
            """Wrap ``func`` so that every call issues the deprecation warning."""
            if message == '':
                message = ('Function {} has been deprecated since {}.'
                           .format(func.__name__, since))
                if alternative is not None:
                    message += '\n Use {} instead.'.format(alternative)

            @functools.wraps(func)
            def deprecated_func(*args, **kwargs):
                # stacklevel=2 attributes the warning to the caller of the
                # deprecated object instead of to this wrapper.
                warnings.warn(message, AstroMLDeprecationWarning,
                              stacklevel=2)
                return func(*args, **kwargs)
            return deprecated_func

        def deprecate_class(cls, message=message, since=since,
                            alternative=alternative):
            """Make instantiating ``cls`` issue the deprecation warning."""
            if message == '':
                message = ('Class {} has been deprecated since {}.'
                           .format(cls.__name__, since))
                if alternative is not None:
                    message += '\n Use {} instead.'.format(alternative)

            # The class is modified in place: its ``__init__`` is replaced by
            # a warning wrapper carrying the class-level message.
            cls.__init__ = deprecate_function(cls.__init__, message=message)

            return cls

        def deprecate(obj):
            """Dispatch to the class or function variant based on ``obj``."""
            if isinstance(obj, type):
                return deprecate_class(obj)
            else:
                return deprecate_function(obj)

        return deprecate