File: pipeline.py

package info (click to toggle)
python-lunr 0.8.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,644 kB
  • sloc: python: 3,811; javascript: 114; makefile: 60
file content (161 lines) | stat: -rw-r--r-- 5,451 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from collections import defaultdict
import logging
from typing import Callable, Dict, List, Set

from lunr.exceptions import BaseLunrException
from lunr.token import Token

# Module-level logger used by Pipeline for its registration warnings.
log = logging.getLogger(__name__)


class Pipeline:
    """An ordered stack of functions applied to every token.

    lunr.Pipelines maintain a list of functions to be applied to all tokens
    in documents entering the search index and queries run against the index.

    Functions can be registered under a label (see ``register_function``) so
    that a pipeline can be serialised as a list of labels and rebuilt later
    with ``load``.
    """

    # Label -> function registry, defined on the class and therefore shared by
    # every Pipeline instance.
    registered_functions: Dict[str, Callable] = {}

    def __init__(self):
        # Functions in the order in which ``run`` applies them.
        self._stack: List[Callable] = []
        # Function -> names of the fields for which ``run`` skips it.
        self._skip: Dict[Callable, Set[str]] = defaultdict(set)

    def __len__(self):
        """Return the number of functions currently in the pipeline."""
        return len(self._stack)

    def __repr__(self):
        # Unregistered functions carry no ``label``; fall back to the function
        # itself so that repr() never raises on such a pipeline.
        labels = (str(getattr(fn, "label", fn)) for fn in self._stack)
        return '<Pipeline stack="{}">'.format(",".join(labels))

    # TODO: add iterator methods?

    @classmethod
    def register_function(cls, fn, label=None):
        """Register a function with the pipeline.

        The function is stored in ``registered_functions`` under ``label``
        (defaulting to ``fn.__name__``) and the label is also attached to the
        function as ``fn.label``. Registering an already used label logs a
        warning and replaces the previous entry.
        """
        label = label or fn.__name__
        if label in cls.registered_functions:
            log.warning("Overwriting existing registered function %s", label)

        fn.label = label
        cls.registered_functions[fn.label] = fn

    @classmethod
    def load(cls, serialised):
        """Loads a previously serialised pipeline.

        :param serialised: An iterable of labels of registered functions.
        :raises BaseLunrException: If a label has no registered function.
        """
        pipeline = cls()
        for fn_name in serialised:
            try:
                fn = cls.registered_functions[fn_name]
            except KeyError as e:
                raise BaseLunrException(
                    "Cannot load unregistered function {}".format(fn_name)
                ) from e
            pipeline.add(fn)

        return pipeline

    def add(self, *args):
        """Adds new functions to the end of the pipeline.

        Functions must accept three arguments:
        - Token: A lunr.Token object which will be updated
        - i: The index of the token in the set
        - tokens: A list of tokens representing the set
        """
        for fn in args:
            self.warn_if_function_not_registered(fn)
            self._stack.append(fn)

    def warn_if_function_not_registered(self, fn):
        """Log a warning when ``fn`` is not a registered pipeline function.

        A function counts as registered when it has a ``label`` attribute and
        that label is a key of ``registered_functions``. Both a missing label
        and a label that is not in the registry trigger the warning.

        :returns: ``True`` if the function is registered, ``False`` otherwise.
        """
        label = getattr(fn, "label", None)
        is_registered = label is not None and label in self.registered_functions
        if not is_registered:
            log.warning(
                'Function "%s" is not registered with pipeline. '
                "This may cause problems when serialising the index.",
                fn if label is None else label,
            )
        return is_registered

    def _insert_relative(self, existing_fn, new_fn, offset):
        """Insert ``new_fn`` at the position of ``existing_fn`` plus ``offset``."""
        # Warn first, so the warning is emitted even when the lookup fails.
        self.warn_if_function_not_registered(new_fn)
        try:
            index = self._stack.index(existing_fn)
        except ValueError as e:
            raise BaseLunrException("Cannot find existing_fn") from e
        self._stack.insert(index + offset, new_fn)

    def after(self, existing_fn, new_fn):
        """Adds a single function after a function that already exists in the
        pipeline.

        :raises BaseLunrException: If ``existing_fn`` is not in the pipeline.
        """
        self._insert_relative(existing_fn, new_fn, 1)

    def before(self, existing_fn, new_fn):
        """Adds a single function before a function that already exists in the
        pipeline.

        :raises BaseLunrException: If ``existing_fn`` is not in the pipeline.
        """
        self._insert_relative(existing_fn, new_fn, 0)

    def remove(self, fn):
        """Removes a function from the pipeline.

        Removing a function that is not in the pipeline is a no-op.
        """
        try:
            self._stack.remove(fn)
        except ValueError:
            # Deliberately ignored: there is nothing to remove.
            pass

    def skip(self, fn: Callable, field_names: List[str]):
        """
        Make the pipeline skip the function based on field name we're processing.

        This relies on passing the field name to Pipeline.run().
        """
        self._skip[fn].update(field_names)

    def run(self, tokens, field_name=None):
        """
        Runs the current list of functions that make up the pipeline against
        the passed tokens.

        Each function receives the output of the previous one. A falsy result
        drops the token, a list or tuple result is flattened into the output
        and any other result is appended as is.

        :param tokens: The tokens to process.
        :param field_name: The name of the field these tokens belong to, can be omitted.
            Used to skip some functions based on field names.
        :returns: The tokens left after the last function has run.
        """
        for fn in self._stack:
            # Skip the function based on field name. The lookup uses .get() so
            # that running the pipeline does not add empty entries to the
            # defaultdict for functions that have no skip rule.
            if field_name and field_name in self._skip.get(fn, ()):
                continue
            results = []
            for i, token in enumerate(tokens):
                # JS ignores additional arguments to the functions but we
                # force pipeline functions to declare (token, i, tokens)
                # or *args
                result = fn(token, i, tokens)
                if not result:
                    continue
                if isinstance(result, (list, tuple)):  # simulate Array.concat
                    results.extend(result)
                else:
                    results.append(result)
            tokens = results

        return tokens

    def run_string(self, string, metadata=None):
        """Convenience method for passing a string through a pipeline and
        getting strings out. This method takes care of wrapping the passed
        string in a token and mapping the resulting tokens back to strings.

        .. note:: This ignores the skipped functions since we can't
            access field names from this context.
        """
        token = Token(string, metadata)
        return [str(tkn) for tkn in self.run([token])]

    def reset(self):
        """Empty the function stack; skip rules set via ``skip`` are kept."""
        self._stack = []

    def serialize(self):
        """Return the labels of the functions in the pipeline, in order.

        Every function in the stack must have a ``label`` attribute, which
        ``register_function`` sets; otherwise ``AttributeError`` is raised.
        """
        return [fn.label for fn in self._stack]