File: updates.py

package info (click to toggle)
theano 1.0.3+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster, sid
  • size: 30,752 kB
  • sloc: python: 141,182; ansic: 9,505; makefile: 259; sh: 214; pascal: 81
file content (92 lines) | stat: -rw-r--r-- 3,468 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Defines Updates object for storing a (SharedVariable, new_value) mapping.

"""
from __future__ import absolute_import, print_function, division
from collections import OrderedDict

from six import iteritems
from theano.compile.sharedvalue import SharedVariable
import logging
import warnings


# Docstring markup used throughout this module.
__docformat__ = "restructuredtext en"

# Authorship / licensing metadata.
__authors__ = "theano-dev"
__copyright__ = "(c) 2010, Universite de Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev <theano-dev@googlegroups.com>"

# Logger for this module.
logger = logging.getLogger(name='theano.updates')


# Must be an OrderedDict or updates will be applied in a non-deterministic
# order.
class OrderedUpdates(OrderedDict):
    """
    Dict-like mapping from SharedVariable keys to their new values.

    This mapping supports the use of the "+" operator for the union of updates.
    Insertion order is preserved, so updates are applied in a deterministic
    order. Every key must be a ``SharedVariable``; any other key raises
    ``TypeError``.
    """

    def __init__(self, *args, **kwargs):
        """
        Build the mapping with the same arguments ``OrderedDict`` accepts.

        Emits a warning when the first positional argument is a plain
        (non-ordered) ``dict`` with two or more items, since its iteration
        order is not guaranteed by every Python version.

        Raises
        ------
        TypeError
            If any key of the resulting mapping is not a ``SharedVariable``.
        """
        if (len(args) >= 1 and
                isinstance(args[0], dict) and
                len(args[0]) > 1 and
                not isinstance(args[0], OrderedDict)):
            # Warn when using as input a non-ordered dictionary.
            warnings.warn('Initializing an `OrderedUpdates` from a '
                          'non-ordered dictionary with 2+ elements could '
                          'make your code non-deterministic. You can use '
                          'an OrderedDict that is available at '
                          'theano.compat.OrderedDict for python 2.6+.',
                          stacklevel=2)
        super(OrderedUpdates, self).__init__(*args, **kwargs)
        # Defensive re-check of every key: whether the base-class constructor
        # routes each insertion through our __setitem__ is an implementation
        # detail, so validate explicitly.
        for key in self:
            if not isinstance(key, SharedVariable):
                raise TypeError(
                    'OrderedUpdates keys must inherit from SharedVariable',
                    key)

    def __setitem__(self, key, value):
        """
        Map ``key`` to ``value``.

        Raises
        ------
        TypeError
            If ``key`` is not a ``SharedVariable``.
        """
        if not isinstance(key, SharedVariable):
            raise TypeError('OrderedUpdates keys must inherit from '
                            'SharedVariable', key)

        # TODO: consider doing error-checking on value.
        # insist that it is a Theano variable? Have the right type?
        # This could have weird consequences - for example a
        # GPU SharedVariable is customarily associated with a TensorType
        # value. Should it be cast to a GPU value right away?  Should
        # literals be transformed into constants immediately?

        return super(OrderedUpdates, self).__setitem__(key, value)

    def update(self, other=None):
        """
        Merge ``other`` into this mapping, in ``other``'s iteration order.

        ``other`` may be anything ``OrderedDict`` accepts (a mapping or an
        iterable of ``(key, value)`` pairs); ``None`` is a no-op. A key that
        is already present is accepted only if the new value compares equal
        (``==``) to the existing one. Entries preceding a collision have
        already been inserted when the error is raised.

        Raises
        ------
        KeyError
            ``('Collision', key)`` when ``key`` is already mapped to a value
            that does not compare equal to the new one.
        TypeError
            If a new key is not a ``SharedVariable``.
        """
        if other is None:
            return
        if (isinstance(other, dict) and
                len(other) > 1 and
                not isinstance(other, OrderedDict)):
            # Warn about non-determinism.
            warnings.warn('Updating an `OrderedUpdates` with a '
                          'non-ordered dictionary with 2+ elements could '
                          'make your code non-deterministic',
                          stacklevel=2)
        for key, val in OrderedDict(other).items():
            if key in self:
                if self[key] == val:
                    continue
                raise KeyError('Collision', key)
            self[key] = val  # __setitem__ does type-checking

    def __add__(self, other):
        """
        Return a new ``OrderedUpdates`` with this mapping's entries followed
        by ``other``'s.

        Raises ``KeyError`` on conflicting keys (see ``update``).
        """
        rval = OrderedUpdates()
        rval.update(self)
        rval.update(other)
        return rval

    def __radd__(self, other):
        """
        Support ``other + self`` when the left operand (e.g. a plain ``dict``
        or ``OrderedDict``) does not implement ``+`` itself.

        Python invokes this as ``self.__radd__(other)``, where ``self`` is the
        right-hand operand. ``other`` is therefore merged first, so the result
        keeps the left-to-right order of the expression.
        """
        rval = OrderedUpdates()
        rval.update(other)
        rval.update(self)
        return rval