File: index.py

Package: python-lunr 0.8.0-1

from collections import defaultdict
import json
import logging

from lunr.exceptions import BaseLunrException
from lunr.field_ref import FieldRef
from lunr.match_data import MatchData
from lunr.token_set import TokenSet
from lunr.token_set_builder import TokenSetBuilder
from lunr.pipeline import Pipeline
from lunr.query import Query, QueryPresence
from lunr.query_parser import QueryParser
from lunr.utils import CompleteSet
from lunr.vector import Vector

logger = logging.getLogger(__name__)


class Index:
    """An index contains the built index of all documents and provides a query
    interface to the index.

    Instances of lunr.Index will usually not be created using this
    constructor; instead, lunr.Builder should be used to construct new
    indexes, or lunr.Index.load should be used to load previously built
    and serialized indexes.
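
    A typical entry point is the ``lunr`` convenience function rather than
    this class directly (a minimal sketch; the documents and field names
    are illustrative)::

        from lunr import lunr

        idx = lunr(
            ref="id",
            fields=("title", "body"),
            documents=[{"id": "1", "title": "Moon", "body": "A rock."}],
        )
        results = idx.search("moon")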
    """

    def __init__(self, inverted_index, field_vectors, token_set, fields, pipeline):
        self.inverted_index = inverted_index
        self.field_vectors = field_vectors
        self.token_set = token_set
        self.fields = fields
        self.pipeline = pipeline

    def __eq__(self, other):
        # TODO: extend equality to other attributes
        return (
            self.inverted_index == other.inverted_index and self.fields == other.fields
        )

    def search(self, query_string):
        """Performs a search against the index using lunr query syntax.

        Results are returned sorted by their score; the most relevant
        results are returned first.

        For more programmatic querying use `lunr.Index.query`.

        Args:
            query_string (str): A string to parse into a Query.

        Returns:
            list: Matching documents, as dicts with "ref", "score" and
                "match_data" keys, sorted by descending score.
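
        Example (illustrative; ``idx`` is an Index instance and scores
        depend on the indexed corpus)::

            idx.search("moon")           # single term
            idx.search("+moon -cheese")  # require "moon", prohibit "cheese"
            idx.search("title:moon^10")  # field-scoped term with a boost
            idx.search("moon~1")         # allow an edit distance of 1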
        """
        query = self.create_query()
        # TODO: should QueryParser be a method of query? should it return one?
        parser = QueryParser(query_string, query)
        parser.parse()
        return self.query(query)

    def create_query(self, fields=None):
        """Convenience method to create a Query with the Index's fields.

        Args:
            fields (iterable, optional): The fields to include in the Query;
                defaults to all of the Index's fields.

        Returns:
            Query: A Query over the specified fields, or over all of the
                Index's fields.
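
        Example (a sketch; assumes ``idx`` is an Index built with a
        "title" field)::

            query = idx.create_query(["title"])
            query.term("moon")
            results = idx.query(query)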
        """
        if fields is None:
            return Query(self.fields)

        non_contained_fields = set(fields) - set(self.fields)
        if non_contained_fields:
            raise BaseLunrException(
                "Fields {} are not part of the index", non_contained_fields
            )

        return Query(fields)

    def query(self, query=None, callback=None):
        """Performs a query against the index using the passed lunr.Query
        object.

        If performing programmatic queries against the index, this method is
        preferred over `lunr.Index.search` so as to avoid the additional query
        parsing overhead.

        Args:
            query (lunr.Query): A configured Query to perform the search
                against; use `create_query` to get a preconfigured object,
                or use `callback` for convenience.
            callback (callable): An optional function taking the single
                Query object produced by `create_query`, for further
                configuration.
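
        Returns:
            list: Matching documents sorted by descending score.

        Example (a sketch; assumes ``idx`` is an Index instance and that
        ``Query.term`` accepts clause options such as ``fields`` and
        ``boost``)::

            results = idx.query(
                callback=lambda q: q.term("moon", fields=["title"], boost=10)
            )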
        """
        if query is None:
            query = self.create_query()

        if callback is not None:
            callback(query)

        if len(query.clauses) == 0:
            logger.warning(
                "Attempting a query with no clauses. Please add clauses by "
                "either using the `callback` argument or using `create_query` "
                "to create a preconfigured Query, manually adding clauses and "
                "passing it as the `query` argument."
            )
            return []

        # for each query clause
        # * process terms
        # * expand terms from token set
        # * find matching documents and metadata
        # * get document vectors
        # * score documents

        matching_fields = {}
        query_vectors = {field: Vector() for field in self.fields}
        term_field_cache = {}
        required_matches = {}
        prohibited_matches = defaultdict(set)

        for clause in query.clauses:
            # Unless the pipeline has been disabled for this term, which is
            # the case for terms with wildcards, we need to pass the clause
            # term through the search pipeline. A pipeline returns a list
            # of processed terms. Pipeline functions may expand the passed
            # term, which means we may end up performing multiple index lookups
            # for a single query term.
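            # For example, a stemmer in the pipeline may map "fishing" to
            # "fish", and other pipeline functions may emit several terms
            # for a single input term.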
            if clause.use_pipeline:
                terms = self.pipeline.run_string(clause.term, {"fields": clause.fields})
            else:
                terms = [clause.term]

            clause_matches = set()

            for term in terms:
                # Each term returned from the pipeline needs to use the same
                # query clause object, e.g. the same boost and/or edit
                # distance.
                # The simplest way to do this is to re-use the clause object
                # but mutate its term property.
                clause.term = term

                # From the term in the clause we create a token set which will
                # then be used to intersect the index's token set to get a
                # list
                # of terms to lookup in the inverted index
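                # For example, a clause term with a trailing wildcard such
                # as "he*" expands to every indexed term sharing that prefix
                # (e.g. "hello", "help"), and a non-zero edit distance
                # expands to similarly spelled terms.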
                term_token_set = TokenSet.from_clause(clause)
                expanded_terms = self.token_set.intersect(term_token_set).to_list()

                # If a term marked as required does not exist in the TokenSet
                # it is impossible for the search to return any matches.
                # We set all of the field-scoped required match sets to
                # empty and stop processing the remaining terms for this
                # clause.
                if (
                    len(expanded_terms) == 0
                    and clause.presence == QueryPresence.REQUIRED
                ):
                    for field in clause.fields:
                        required_matches[field] = set()
                    break

                for expanded_term in expanded_terms:
                    posting = self.inverted_index[expanded_term]
                    term_index = posting["_index"]

                    for field in clause.fields:
                        # For each field that this query term is scoped by
                        # (by default all fields are in scope) we need to get
                        # all the document refs that have this term in that
                        # field.
                        #
                        # The posting is the entry in the inverted_index for
                        # the
                        # matching term from above.
                        field_posting = posting[field]
                        matching_document_refs = field_posting.keys()
                        term_field = expanded_term + "/" + field
                        matching_documents_set = set(matching_document_refs)

                        # If the presence of this term is required, ensure that
                        # the matching documents are added to the set of
                        # required matches for this clause.
                        if clause.presence == QueryPresence.REQUIRED:
                            clause_matches = clause_matches.union(
                                matching_documents_set
                            )

                            if field not in required_matches:
                                required_matches[field] = CompleteSet()

                        # If the presence of this term is prohibited,
                        # ensure that the matching documents are added to the
                        # set of prohibited matches for this field, creating
                        # that set if it does not exist yet.
                        elif clause.presence == QueryPresence.PROHIBITED:
                            prohibited_matches[field] = prohibited_matches[field].union(
                                matching_documents_set
                            )

                            # prohibited matches should not be part of the
                            # query vector used for similarity scoring and no
                            # metadata should be extracted so we continue
                            # to the next field
                            continue

                        # The query field vector is populated using the
                        # term_index found for the term and a unit value with
                        # the appropriate boost
                        # Using upsert because there could already be an entry
                        # in the vector for the term we are working with.
                        # In that case we just add the scores together.
                        query_vectors[field].upsert(
                            term_index, clause.boost, lambda a, b: a + b
                        )

                        # If we've already seen this term/field combination
                        # then we've already collected the matching documents
                        # and metadata; no need to go through all that again.
                        if term_field in term_field_cache:
                            continue

                        for matching_document_ref in matching_document_refs:
                            # All metadata for this term/field/document triple
                            # are then extracted and collected into an instance
                            # of lunr.MatchData ready to be returned in the
                            # query results
                            matching_field_ref = FieldRef(matching_document_ref, field)
                            metadata = field_posting[str(matching_document_ref)]

                            if str(matching_field_ref) not in matching_fields:
                                matching_fields[str(matching_field_ref)] = MatchData(
                                    expanded_term, field, metadata
                                )
                            else:
                                matching_fields[str(matching_field_ref)].add(
                                    expanded_term, field, metadata
                                )

                        term_field_cache[term_field] = True

            # If the presence was required we need to update the required
            # matches field sets; we do this after all fields for the term
            # have collected their matches, because the clause term's
            # presence is required in _any_ of the fields, not _all_ of
            # the fields.
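            # For example, with fields ("title", "body") the clause "+moon"
            # matches a document whose title contains "moon" even if its
            # body does not.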
            if clause.presence == QueryPresence.REQUIRED:
                for field in clause.fields:
                    required_matches[field] = required_matches[field].intersection(
                        clause_matches
                    )

        # We need to combine the field-scoped required and prohibited
        # matching documents into a global set of required and prohibited
        # matches.
        all_required_matches = CompleteSet()
        all_prohibited_matches = set()
        for field in self.fields:
            if field in required_matches:
                all_required_matches = all_required_matches.intersection(
                    required_matches[field]
                )
            if field in prohibited_matches:
                all_prohibited_matches = all_prohibited_matches.union(
                    prohibited_matches[field]
                )

        matching_field_refs = matching_fields.keys()
        results = []
        matches = {}

        # If the query is negated (only contains prohibited terms) we need
        # to get _all_ field refs currently existing in the index; this is
        # only done for negated queries, to avoid the cost of listing every
        # field ref unnecessarily. Additionally, blank match data must be
        # created to correctly populate the results.
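        # For example, a query consisting solely of "-moon" should return
        # every document that does not contain "moon".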
        if query.is_negated():
            matching_field_refs = list(self.field_vectors.keys())

            for matching_field_ref in matching_field_refs:
                matching_fields[matching_field_ref] = MatchData()

        for matching_field_ref in matching_field_refs:
            # Currently we have document fields that match the query, but we
            # need to return documents. The match data and scores are
            # combined
            # from multiple fields belonging to the same document.
            #
            # Scores are calculated by field, using the query vectors created
            # above, and combined into a final document score using addition.
            field_ref = FieldRef.from_string(matching_field_ref)
            doc_ref = field_ref.doc_ref

            if doc_ref not in all_required_matches or doc_ref in all_prohibited_matches:
                continue

            field_vector = self.field_vectors[matching_field_ref]
            score = query_vectors[field_ref.field_name].similarity(field_vector)

            try:
                doc_match = matches[doc_ref]
                doc_match["score"] += score
                doc_match["match_data"].combine(matching_fields[matching_field_ref])
            except KeyError:
                match = {
                    "ref": doc_ref,
                    "score": score,
                    "match_data": matching_fields[matching_field_ref],
                }
                matches[doc_ref] = match
                results.append(match)

        return sorted(results, key=lambda a: a["score"], reverse=True)

    def serialize(self):
        """Returns a serialized index as a dict following lunr-schema."""
        from lunr import __TARGET_JS_VERSION__

        inverted_index = [
            [term, self.inverted_index[term]] for term in sorted(self.inverted_index)
        ]
        field_vectors = [
            [ref, vector.serialize()] for ref, vector in self.field_vectors.items()
        ]

        # CamelCased keys for compatibility with JS version
        return {
            "version": __TARGET_JS_VERSION__,
            "fields": self.fields,
            "fieldVectors": field_vectors,
            "invertedIndex": inverted_index,
            "pipeline": self.pipeline.serialize(),
        }

    @classmethod
    def load(cls, serialized_index):
        """Load a serialized index"""
        from lunr import __TARGET_JS_VERSION__

        if isinstance(serialized_index, str):
            serialized_index = json.loads(serialized_index)

        if serialized_index["version"] != __TARGET_JS_VERSION__:
            logger.warning(
                "Version mismatch when loading serialized index. "
                "Current version of lunr {} does not match that of serialized "
                "index {}".format(__TARGET_JS_VERSION__, serialized_index["version"])
            )

        field_vectors = {
            ref: Vector(elements) for ref, elements in serialized_index["fieldVectors"]
        }

        tokenset_builder = TokenSetBuilder()
        inverted_index = {}
        for term, posting in serialized_index["invertedIndex"]:
            tokenset_builder.insert(term)
            inverted_index[term] = posting

        tokenset_builder.finish()

        return Index(
            fields=serialized_index["fields"],
            field_vectors=field_vectors,
            inverted_index=inverted_index,
            token_set=tokenset_builder.root,
            pipeline=Pipeline.load(serialized_index["pipeline"]),
        )