File: __init__.py

package info (click to toggle)
python-cassandra-driver 3.24.0-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 4,840 kB
  • sloc: python: 50,759; ansic: 771; makefile: 132
file content (303 lines) | stat: -rw-r--r-- 12,811 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import copy

from concurrent.futures import Future

HAVE_GREMLIN = False
try:
    import gremlin_python
    HAVE_GREMLIN = True
except ImportError:
    # gremlinpython is not installed.
    pass

if HAVE_GREMLIN:
    from gremlin_python.structure.graph import Graph
    from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal
    from gremlin_python.process.traversal import Traverser, TraversalSideEffects
    from gremlin_python.process.graph_traversal import GraphTraversal

    from cassandra.cluster import Session, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT
    from cassandra.datastax.graph import GraphOptions, GraphProtocol
    from cassandra.datastax.graph.query import _GraphSONContextRowFactory

    from cassandra.datastax.graph.fluent.serializers import (
        GremlinGraphSONReaderV2,
        GremlinGraphSONReaderV3,
        dse_graphson2_deserializers,
        gremlin_graphson2_deserializers,
        dse_graphson3_deserializers,
        gremlin_graphson3_deserializers
    )
    from cassandra.datastax.graph.fluent.query import _DefaultTraversalBatch, _query_from_traversal

    log = logging.getLogger(__name__)

    __all__ = ['BaseGraphRowFactory', 'graph_traversal_row_factory',
               'graph_traversal_dse_object_row_factory', 'DSESessionRemoteGraphConnection', 'DseGraph']

    # Traversal result keys
    _bulk_key = 'bulk'
    _result_key = 'result'


    class BaseGraphRowFactory(_GraphSONContextRowFactory):
        """
        Base row factory for graph traversal. This class basically wraps a
        graphson reader function to handle additional features of Gremlin/DSE
        and is callable as a normal row factory.

        Currently supported:
          - bulk results
        """

        def __call__(self, column_names, rows):
            for row in rows:
                parsed_row = self.graphson_reader.readObject(row[0])
                yield parsed_row[_result_key]
                bulk = parsed_row.get(_bulk_key, 1)
                for _ in range(bulk - 1):
                    yield copy.deepcopy(parsed_row[_result_key])


    class _GremlinGraphSON2RowFactory(BaseGraphRowFactory):
        """Row Factory that returns the decoded graphson2."""
        graphson_reader_class = GremlinGraphSONReaderV2
        graphson_reader_kwargs = {'deserializer_map': gremlin_graphson2_deserializers}


    class _DseGraphSON2RowFactory(BaseGraphRowFactory):
        """Row Factory that returns the decoded graphson2 as DSE types."""
        graphson_reader_class = GremlinGraphSONReaderV2
        graphson_reader_kwargs = {'deserializer_map': dse_graphson2_deserializers}

    gremlin_graphson2_traversal_row_factory = _GremlinGraphSON2RowFactory
    # TODO remove in next major
    graph_traversal_row_factory = gremlin_graphson2_traversal_row_factory

    dse_graphson2_traversal_row_factory = _DseGraphSON2RowFactory
    # TODO remove in next major
    graph_traversal_dse_object_row_factory = dse_graphson2_traversal_row_factory


    class _GremlinGraphSON3RowFactory(BaseGraphRowFactory):
        """Row Factory that returns the decoded graphson2."""
        graphson_reader_class = GremlinGraphSONReaderV3
        graphson_reader_kwargs = {'deserializer_map': gremlin_graphson3_deserializers}


    class _DseGraphSON3RowFactory(BaseGraphRowFactory):
        """Row Factory that returns the decoded graphson3 as DSE types."""
        graphson_reader_class = GremlinGraphSONReaderV3
        graphson_reader_kwargs = {'deserializer_map': dse_graphson3_deserializers}


    gremlin_graphson3_traversal_row_factory = _GremlinGraphSON3RowFactory
    dse_graphson3_traversal_row_factory = _DseGraphSON3RowFactory


    class DSESessionRemoteGraphConnection(RemoteConnection):
        """
        A Tinkerpop RemoteConnection to execute traversal queries on DSE.

        :param session: A DSE session
        :param graph_name: (Optional) DSE Graph name.
        :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
        """

        session = None
        graph_name = None
        execution_profile = None

        def __init__(self, session, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT):
            super(DSESessionRemoteGraphConnection, self).__init__(None, None)

            if not isinstance(session, Session):
                raise ValueError('A DSE Session must be provided to execute graph traversal queries.')

            self.session = session
            self.graph_name = graph_name
            self.execution_profile = execution_profile

        @staticmethod
        def _traversers_generator(traversers):
            for t in traversers:
                yield Traverser(t)

        def _prepare_query(self, bytecode):
            ep = self.session.execution_profile_clone_update(self.execution_profile)
            graph_options = ep.graph_options
            graph_options.graph_name = self.graph_name or graph_options.graph_name
            graph_options.graph_language = DseGraph.DSE_GRAPH_QUERY_LANGUAGE
            # We resolve the execution profile options here , to know how what gremlin factory to set
            self.session._resolve_execution_profile_options(ep)

            context = None
            if graph_options.graph_protocol == GraphProtocol.GRAPHSON_2_0:
                row_factory = gremlin_graphson2_traversal_row_factory
            elif graph_options.graph_protocol == GraphProtocol.GRAPHSON_3_0:
                row_factory = gremlin_graphson3_traversal_row_factory
                context = {
                    'cluster': self.session.cluster,
                    'graph_name': graph_options.graph_name.decode('utf-8')
                }
            else:
                raise ValueError('Unknown graph protocol: {}'.format(graph_options.graph_protocol))

            ep.row_factory = row_factory
            query = DseGraph.query_from_traversal(bytecode, graph_options.graph_protocol, context)

            return query, ep

        @staticmethod
        def _handle_query_results(result_set, gremlin_future):
            try:
                gremlin_future.set_result(
                    RemoteTraversal(DSESessionRemoteGraphConnection._traversers_generator(result_set), TraversalSideEffects())
                )
            except Exception as e:
                gremlin_future.set_exception(e)

        @staticmethod
        def _handle_query_error(response, gremlin_future):
            gremlin_future.set_exception(response)

        def submit(self, bytecode):
            # the only reason I don't use submitAsync here
            # is to avoid an unuseful future wrap
            query, ep = self._prepare_query(bytecode)

            traversers = self.session.execute_graph(query, execution_profile=ep)
            return RemoteTraversal(self._traversers_generator(traversers), TraversalSideEffects())

        def submitAsync(self, bytecode):
            query, ep = self._prepare_query(bytecode)

            # to be compatible with gremlinpython, we need to return a concurrent.futures.Future
            gremlin_future = Future()
            response_future = self.session.execute_graph_async(query, execution_profile=ep)
            response_future.add_callback(self._handle_query_results, gremlin_future)
            response_future.add_errback(self._handle_query_error, gremlin_future)

            return gremlin_future

        def __str__(self):
            return "<DSESessionRemoteGraphConnection: graph_name='{0}'>".format(self.graph_name)

        __repr__ = __str__


    class DseGraph(object):
        """
        Dse Graph utility class for GraphTraversal construction and execution.
        """

        DSE_GRAPH_QUERY_LANGUAGE = 'bytecode-json'
        """
        Graph query language, Default is 'bytecode-json' (GraphSON).
        """

        DSE_GRAPH_QUERY_PROTOCOL = GraphProtocol.GRAPHSON_2_0
        """
        Graph query language, Default is GraphProtocol.GRAPHSON_2_0.
        """

        @staticmethod
        def query_from_traversal(traversal, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, context=None):
            """
            From a GraphTraversal, return a query string based on the language specified in `DseGraph.DSE_GRAPH_QUERY_LANGUAGE`.

            :param traversal: The GraphTraversal object
            :param graph_protocol: The graph protocol. Default is `DseGraph.DSE_GRAPH_QUERY_PROTOCOL`.
            :param context: The dict of the serialization context, needed for GraphSON3 (tuple, udt).
                            e.g: {'cluster': cluster, 'graph_name': name}
            """

            if isinstance(traversal, GraphTraversal):
                for strategy in traversal.traversal_strategies.traversal_strategies:
                    rc = strategy.remote_connection
                    if (isinstance(rc, DSESessionRemoteGraphConnection) and
                            rc.session or rc.graph_name or rc.execution_profile):
                        log.warning("GraphTraversal session, graph_name and execution_profile are "
                                    "only taken into account when executed with TinkerPop.")

            return _query_from_traversal(traversal, graph_protocol, context)

        @staticmethod
        def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT,
                             traversal_class=None):
            """
            Returns a TinkerPop GraphTraversalSource binded to the session and graph_name if provided.

            :param session: (Optional) A DSE session
            :param graph_name: (Optional) DSE Graph name
            :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
            :param traversal_class: (Optional) The GraphTraversalSource class to use (DSL).

            .. code-block:: python

                from cassandra.cluster import Cluster
                from cassandra.datastax.graph.fluent import DseGraph

                c = Cluster()
                session = c.connect()

                g = DseGraph.traversal_source(session, 'my_graph')
                print g.V().valueMap().toList()

            """

            graph = Graph()
            traversal_source = graph.traversal(traversal_class)

            if session:
                traversal_source = traversal_source.withRemote(
                    DSESessionRemoteGraphConnection(session, graph_name, execution_profile))

            return traversal_source

        @staticmethod
        def create_execution_profile(graph_name, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, **kwargs):
            """
            Creates an ExecutionProfile for GraphTraversal execution. You need to register that execution profile to the
            cluster by using `cluster.add_execution_profile`.

            :param graph_name: The graph name
            :param graph_protocol: (Optional) The graph protocol, default is `DSE_GRAPH_QUERY_PROTOCOL`.
            """

            if graph_protocol == GraphProtocol.GRAPHSON_2_0:
                row_factory = dse_graphson2_traversal_row_factory
            elif graph_protocol == GraphProtocol.GRAPHSON_3_0:
                row_factory = dse_graphson3_traversal_row_factory
            else:
                raise ValueError('Unknown graph protocol: {}'.format(graph_protocol))

            ep = GraphExecutionProfile(row_factory=row_factory,
                                       graph_options=GraphOptions(graph_name=graph_name,
                                                                  graph_language=DseGraph.DSE_GRAPH_QUERY_LANGUAGE,
                                                                  graph_protocol=graph_protocol),
                                       **kwargs)
            return ep

        @staticmethod
        def batch(*args, **kwargs):
            """
            Returns the :class:`cassandra.datastax.graph.fluent.query.TraversalBatch` object allowing to
            execute multiple traversals in the same transaction.
            """
            return _DefaultTraversalBatch(*args, **kwargs)