File: caching.py

package info (click to toggle)
brian 2.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,872 kB
  • sloc: python: 51,820; cpp: 2,033; makefile: 108; sh: 72
file content (146 lines) | stat: -rw-r--r-- 5,379 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
Module to support caching of function results to memory (used to cache results
of parsing, generation of state update code, etc.). Provides the `cached`
decorator.
"""

import collections
import functools
from collections.abc import Mapping


class CacheKey:
    """
    Mixin giving an object a caching "identity" that is separate from its
    regular Python hash and equality semantics.

    Objects used as cache keys (e.g. `Variable` objects) may count as
    "identical" for caching purposes even though they are not equal in the
    usual sense: which object (e.g. `NeuronGroup`) a `Variable` belongs to is
    irrelevant for parsing, creating abstract code, etc., but it does matter
    for the values the variable refers to and hence for equality comparisons.

    Classes mixing in `CacheKey` should override
    ``_cache_irrelevant_attributes`` with the names of all instance attributes
    to ignore. The ``_state_tuple`` property then yields the values of the
    remaining attributes, and that tuple is what is used as the caching key.
    """

    #: Names of the attributes that are left out when caching state update
    #: code, etc.
    _cache_irrelevant_attributes = set()

    @property
    def _state_tuple(self):
        """Tuple of the values of all cache-relevant instance attributes,
        ordered by attribute name. See `CacheKey` for details."""
        ignored = self._cache_irrelevant_attributes
        relevant_values = []
        for name, value in sorted(self.__dict__.items()):
            if name in ignored:
                continue
            relevant_values.append(value)
        return tuple(relevant_values)


class _CacheStatistics:
    """
    Helper class to store cache statistics
    """

    def __init__(self):
        self.hits = 0
        self.misses = 0

    def __repr__(self):
        return f"<Cache statistics: {int(self.hits)} hits, {int(self.misses)} misses>"


def cached(func):
    """
    Decorator to cache a function so that it will not be re-evaluated when
    called with the same arguments. Uses the `_hashable` function to make
    arguments usable as a dictionary key even though they are mutable (lists,
    dictionaries, etc.).

    Notes
    -----
    This is *not* a general-purpose caching decorator in any way comparable to
    ``functools.lru_cache`` or joblib's caching functions. It is very simplistic
    (no maximum cache size, no normalization of calls, e.g. ``foo(3)`` and
    ``foo(x=3)`` are not considered equivalent function calls) and makes very
    specific assumptions for our use case. Most importantly, `Variable` objects
    are considered to be identical when they refer to the same object, even
    though the actually stored values might have changed.

    Positional and keyword arguments are kept apart in the cache key, so that
    a positional argument that happens to look like a ``(name, value)`` pair
    (e.g. ``foo(("x", 3))``) is never confused with the keyword argument
    ``foo(x=3)``.

    The cache dictionary and a hit/miss counter are stored on the decorated
    function as the ``_cache`` and ``_cache_statistics`` attributes.

    Parameters
    ----------
    func : function
        The function to decorate.

    Returns
    -------
    decorated : function
        The decorated function.
    """
    # For simplicity, we store the cache in the function itself
    func._cache = {}
    func._cache_statistics = _CacheStatistics()
    # Unique marker inserted between the positional and the keyword part of a
    # cache key. No user-supplied argument can compare equal to it, which
    # rules out collisions between the two parts.
    kwd_mark = object()

    @functools.wraps(func)
    def cached_func(*args, **kwds):
        try:
            key_parts = [_hashable(arg) for arg in args]
            if kwds:
                # Only added when keywords are present, so that keys for
                # purely positional calls stay plain tuples of the arguments
                key_parts.append(kwd_mark)
                key_parts.extend(
                    (key, _hashable(value)) for key, value in sorted(kwds.items())
                )
            cache_key = tuple(key_parts)
        except TypeError:
            # If we cannot handle a type here, that most likely means that the
            # user provided an argument of a type we don't handle. This will
            # lead to an error message later that is most likely more meaningful
            # to the user than an error message by the caching system
            # complaining about an unsupported type.
            return func(*args, **kwds)
        if cache_key in func._cache:
            func._cache_statistics.hits += 1
        else:
            func._cache_statistics.misses += 1
            func._cache[cache_key] = func(*args, **kwds)
        return func._cache[cache_key]

    return cached_func


_of_type_cache = collections.defaultdict(set)


def _of_type(obj_type, check_type):
    if (obj_type, check_type) not in _of_type_cache:
        _of_type_cache[(obj_type, check_type)] = issubclass(obj_type, check_type)
    return _of_type_cache[(obj_type, check_type)]


def _hashable(obj):
    """Convert ``obj`` into an equivalent value that can be used as (part of)
    a dictionary key.

    Objects providing a ``_state_tuple`` attribute are represented by that
    tuple, mappings and sets become frozensets, tuples and lists become
    tuples (all converted recursively), and scalar quantities become a
    ``(float, dimensions)`` pair. Anything else is returned unchanged as long
    as it is hashable; otherwise a `TypeError` is raised. This helper is
    tailored to our use case and not meant to be generally useful."""
    if hasattr(obj, "_state_tuple"):
        return _hashable(obj._state_tuple)

    kind = type(obj)
    if _of_type(kind, Mapping):
        pairs = ((_hashable(k), _hashable(v)) for k, v in obj.items())
        return frozenset(pairs)
    if _of_type(kind, set):
        return frozenset(map(_hashable, obj))
    if _of_type(kind, tuple) or _of_type(kind, list):
        return tuple(map(_hashable, obj))

    if hasattr(obj, "dim") and getattr(obj, "shape", None) == ():
        # Scalar Quantity object
        return float(obj), obj.dim

    try:
        # Make sure that the object is hashable
        hash(obj)
    except TypeError:
        raise TypeError(f"Do not know how to handle object of type {type(obj)}")
    return obj