File: custom_management.py

package info (click to toggle)
sqlalchemy 0.6.3-3+squeeze1
  • links: PTS, VCS
  • area: main
  • in suites: squeeze
  • size: 10,744 kB
  • ctags: 15,132
  • sloc: python: 93,431; ansic: 787; makefile: 137; xml: 17
file content (193 lines) | stat: -rw-r--r-- 6,249 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""This example illustrates how to replace SQLAlchemy's class descriptors with a user-defined system.

This sort of thing is appropriate for integration with frameworks that redefine class behaviors
in their own way, such that SQLA's default instrumentation is not compatible.

The example illustrates redefinition of instrumentation at the class level as well as the collection
level, and redefines per-instance storage so that attribute state is kept in "instance._goofy_dict"
instead of "instance.__dict__".  Note that the default collection implementations can be used
with a custom attribute system as well.

"""
import weakref

from sqlalchemy import (create_engine, MetaData, Table, Column, Integer, Text,
    ForeignKey)
from sqlalchemy.orm import (mapper, relationship, create_session,
    InstrumentationManager)

from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented
from sqlalchemy.orm.collections import collection_adapter


class MyClassState(InstrumentationManager):
    """Instrumentation manager that keeps SQLAlchemy's bookkeeping off the class.

    No descriptors are installed on the mapped class, attribute values for an
    instance are stored in ``instance._goofy_dict``, and each instance's state
    object is kept in ``self.states`` keyed by ``id(instance)``.
    """

    def __init__(self, cls):
        # id(instance) -> state for every live instrumented instance.
        self.states = {}
        # id(instance) -> weakref to that instance.  The weakref's callback
        # removes both entries once the instance is garbage collected.
        self._refs = {}

    def instrument_attribute(self, class_, key, attr):
        # Nothing to do: attribute access is routed by MyClass itself.
        pass

    def install_descriptor(self, class_, key, attr):
        # Deliberately leave the class untouched (no descriptors).
        pass

    def uninstall_descriptor(self, class_, key, attr=None):
        # ``attr`` is optional because SQLAlchemy invokes this hook as
        # ``uninstall_descriptor(class_, key)``; requiring a third argument
        # would raise TypeError.  Nothing was installed, so nothing to undo.
        pass

    def instrument_collection_class(self, class_, key, collection_class):
        # Every collection attribute uses MyCollection, regardless of the
        # collection_class requested by the mapping.
        return MyCollection

    def get_instance_dict(self, class_, instance):
        return instance._goofy_dict

    def initialize_instance_dict(self, class_, instance):
        # Write through __dict__ directly: MyClass.__setattr__ stores plain
        # attributes into _goofy_dict, which does not exist yet at this point.
        instance.__dict__['_goofy_dict'] = {}

    def initialize_collection(self, key, state, factory):
        """Create an empty collection and its adapter; return (adapter, data)."""
        data = factory()
        return MyCollectionAdapter(key, state, data), data

    def install_state(self, class_, instance, state):
        """Associate ``state`` with ``instance``.

        A weakref with a cleanup callback is kept next to the entry.  Without
        it, ``self.states`` would grow forever, and a new object that reuses
        a collected instance's id() would be handed that dead instance's state
        by the getter returned from state_getter().
        """
        key = id(instance)
        states = self.states
        refs = self._refs

        def _discard(dead_ref):
            # Only clear the entry if it still belongs to this weakref; a
            # later install_state()/remove_state() may have replaced it.
            if refs.get(key) is dead_ref:
                del refs[key]
                states.pop(key, None)

        ref = weakref.ref(instance, _discard)
        states[key] = state
        refs[key] = ref

    def remove_state(self, class_, instance):
        """Forget the state registered for ``instance`` (no-op if absent)."""
        key = id(instance)
        self.states.pop(key, None)
        self._refs.pop(key, None)

    def state_getter(self, class_):
        """Return a callable mapping an instance to its registered state."""
        def find(instance):
            # Raises KeyError for instances that have no registered state.
            return self.states[id(instance)]
        return find

class MyClass(object):
    """Base class whose plain attribute values live in ``self._goofy_dict``.

    MyClassState installs no class-level descriptors, so mapped attributes are
    not found by normal lookup.  The hooks below route instrumented keys
    through SQLAlchemy's attribute API and all other keys to ``_goofy_dict``.
    """

    __sa_instrumentation_manager__ = MyClassState

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails.  If _goofy_dict
        # itself is missing (e.g. an instance created via __new__ without
        # instrumentation), fail fast rather than recursing into this method
        # through the ``self._goofy_dict`` lookup below.
        if key == '_goofy_dict':
            raise AttributeError(key)
        if is_instrumented(self, key):
            return get_attribute(self, key)
        try:
            return self._goofy_dict[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        if is_instrumented(self, key):
            set_attribute(self, key, value)
        else:
            self._goofy_dict[key] = value

    def __delattr__(self, key):
        if is_instrumented(self, key):
            del_attribute(self, key)
        else:
            try:
                del self._goofy_dict[key]
            except KeyError:
                # Deleting a missing attribute raises AttributeError, as
                # __getattr__ above does for a missing read.
                raise AttributeError(key)

class MyCollectionAdapter(object):
    """A wholly alternative collection-adapter implementation.

    Links itself to its collection through ``collection._sa_adapter`` and
    forwards membership-change notifications to the attribute implementation
    returned by ``state.get_impl(key)``.  An initiator of ``False`` means
    "do not fire events".
    """

    def __init__(self, key, state, collection):
        self.key = key
        self.state = state
        self.collection = collection
        setattr(collection, '_sa_adapter', self)

    def unlink(self, data):
        setattr(data, '_sa_adapter', None)

    def adapt_like_to_iterable(self, obj):
        return iter(obj)

    def append_with_event(self, item, initiator=None):
        self.collection.add(item, emit=initiator)

    def append_without_event(self, item):
        self.collection.add(item, emit=False)

    def remove_with_event(self, item, initiator=None):
        self.collection.remove(item, emit=initiator)

    def remove_without_event(self, item):
        self.collection.remove(item, emit=False)

    def clear_with_event(self, initiator=None):
        # Snapshot first: removing while iterating the live member list
        # would skip elements.
        for item in list(self):
            self.remove_with_event(item, initiator)

    def clear_without_event(self):
        for item in list(self):
            self.remove_without_event(item)

    def __iter__(self):
        return iter(self.collection)

    def fire_append_event(self, item, initiator=None):
        # Events are suppressed when initiator is False; None items are ignored.
        if initiator is not False and item is not None:
            self.state.get_impl(self.key).fire_append_event(
                self.state, self.state.dict, item, initiator)

    def fire_remove_event(self, item, initiator=None):
        if initiator is not False and item is not None:
            self.state.get_impl(self.key).fire_remove_event(
                self.state, self.state.dict, item, initiator)

    def fire_pre_remove_event(self, initiator=None):
        # Honour initiator=False (the "no events" marker used by the
        # *_without_event methods), consistent with the other fire_* hooks.
        if initiator is not False:
            self.state.get_impl(self.key).fire_pre_remove_event(
                self.state, self.state.dict, initiator)

class MyCollection(object):
    """Minimal list-backed collection that reports changes to its adapter.

    ``emit`` is handed to the adapter's fire_* hooks as the event initiator;
    the adapter's append/remove hooks skip firing when it is False.
    """

    def __init__(self):
        self.members = []

    def add(self, object, emit=None):
        self.members.append(object)
        collection_adapter(self).fire_append_event(object, emit)

    def remove(self, object, emit=None):
        adapter = collection_adapter(self)
        # The pre-remove hook's only argument is the initiator, so pass
        # ``emit``.  Passing the removed item here would masquerade as an
        # initiator and ignore emit=False.
        adapter.fire_pre_remove_event(emit)
        self.members.remove(object)
        adapter.fire_remove_event(object, emit)

    def __getitem__(self, index):
        return self.members[index]

    def __iter__(self):
        return iter(self.members)

    def __len__(self):
        return len(self.members)

if __name__ == '__main__':
    # Demonstration: map two classes through the custom instrumentation,
    # round-trip them through an in-memory SQLite database, and check that
    # values and the custom collection type survive.
    meta = MetaData(create_engine('sqlite://'))

    table1 = Table('table1', meta,
                   Column('id', Integer, primary_key=True),
                   Column('name', Text))
    table2 = Table('table2', meta,
                   Column('id', Integer, primary_key=True),
                   Column('name', Text),
                   Column('t1id', Integer, ForeignKey('table1.id')))
    meta.create_all()

    class A(MyClass):
        pass

    class B(MyClass):
        pass

    mapper(A, table1, properties={
        # Without order_by the order of rows loaded into a1.bs is unspecified,
        # yet the checks after reloading expect a1.bs[0] to be 'b1'.
        'bs': relationship(B, order_by=table2.c.id),
    })

    mapper(B, table2)

    a1 = A(name='a1', bs=[B(name='b1'), B(name='b2')])

    # Before persisting: values are readable and the collection is a MyCollection.
    assert a1.name == 'a1'
    assert a1.bs[0].name == 'b1'
    assert isinstance(a1.bs, MyCollection)

    sess = create_session()
    sess.add(a1)

    sess.flush()
    sess.expunge_all()

    # Reload from the database into a fresh instance.
    a1 = sess.query(A).get(a1.id)

    assert a1.name == 'a1'
    assert a1.bs[0].name == 'b1'
    assert isinstance(a1.bs, MyCollection)

    # A removal made through the custom collection is persisted on flush.
    a1.bs.remove(a1.bs[0])

    sess.flush()
    sess.expunge_all()

    a1 = sess.query(A).get(a1.id)
    assert len(a1.bs) == 1