File: vector.py

package info (click to toggle)
python-lunr 0.8.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,644 kB
  • sloc: python: 3,811; javascript: 114; makefile: 60
file content (156 lines) | stat: -rw-r--r-- 5,295 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from math import sqrt

from lunr.exceptions import BaseLunrException


class Vector:
    """Sparse vector used to build the vector space of documents and queries.

    Vectors support the operations needed to score how similar two documents,
    or a document and a query, are.

    Normally no parameters are required for initializing a vector, but when
    loading a previously dumped vector the raw elements can be provided to the
    constructor. They are stored as given (neither copied nor reordered).

    For performance reasons vectors are implemented with a flat list, where an
    element's index is immediately followed by its value.
    E.g. [index, value, index, value].

    The binary search in `position_for_index` and the merge walk in `dot`
    both rely on the pairs being ordered by ascending index; `insert` and
    `upsert` keep that ordering by always inserting at the position returned
    by `position_for_index`.

    TODO: consider implementation as 2-tuples.

    This allows the underlying list to be as sparse as possible and still
    offer decent performance when being used for vector calculations.
    """

    def __init__(self, elements=None):
        """Creates a vector, optionally from a flat [index, value, ...] list.

        Args:
            - elements (list, optional): Raw flat elements of a previously
                dumped vector. A falsy value results in an empty vector.
        """
        # 0 doubles as the "not computed yet" marker for the cached magnitude.
        self._magnitude = 0
        self.elements = elements or []

    def __repr__(self):
        return "<Vector magnitude={}>".format(self.magnitude)

    def __iter__(self):
        """Iterates over the raw flat elements (indexes and values alike)."""
        return iter(self.elements)

    def position_for_index(self, index):
        """Calculates the position within the vector to insert a given index.

        This is used internally by insert and upsert. If there are duplicate
        indexes then the position is returned as if the value for that index
        were to be updated, but it is the caller's responsibility to check
        whether there is a duplicate at that index.

        Args:
            - index (int): The element index to locate.

        Returns:
            The offset into the flat elements list (always even) at which the
            [index, value] pair lives or should be inserted.
        """
        if not self.elements:
            return 0

        # `start`, `end` and `pivot_point` count [index, value] pairs, hence
        # the `* 2` whenever the flat list is addressed.
        start = 0
        end = len(self.elements) // 2
        slice_length = end - start
        pivot_point = slice_length // 2
        pivot_index = self.elements[pivot_point * 2]

        while slice_length > 1:
            if pivot_index < index:
                start = pivot_point
            elif pivot_index > index:
                end = pivot_point
            else:
                break

            slice_length = end - start
            pivot_point = start + slice_length // 2
            pivot_index = self.elements[pivot_point * 2]

        # Either an exact match (update in place) or the pivot is the first
        # larger index (insert in front of it).
        if pivot_index >= index:
            return pivot_point * 2

        # The pivot is smaller than the index: insert right after it.
        return (pivot_point + 1) * 2

    def insert(self, insert_index, val):
        """Inserts an element at an index within the vector.

        Does not allow duplicates, will throw an error if there is already an
        entry for this index.

        Args:
            - insert_index (int): The index at which the element should be
                inserted.
            - val (int|float): The value to be inserted into the vector.

        Raises:
            BaseLunrException: If the vector already holds `insert_index`.
        """

        def prevent_duplicates(index, val):
            # Only invoked by upsert when the index is already present.
            raise BaseLunrException("Duplicate index")

        self.upsert(insert_index, val, prevent_duplicates)

    def upsert(self, insert_index, val, fn=None):
        """Inserts or updates an existing index within the vector.

        Args:
            - insert_index (int): The index at which the element should be
                inserted.
            - val (int|float): The value to be inserted into the vector.
            - fn (callable, optional): An optional callable taking two
                arguments, the current value and the passed value to generate
                the final inserted value at the position in case of collision.
                Without it the passed value replaces the current one.
        """
        # Any change to the elements invalidates the cached magnitude.
        self._magnitude = 0
        position = self.position_for_index(insert_index)
        elements = self.elements

        if position < len(elements) and elements[position] == insert_index:
            current = elements[position + 1]
            elements[position + 1] = fn(current, val) if fn else val
        else:
            # Splice the pair in with a single slice assignment so the tail
            # of the list is shifted once rather than once per item.
            elements[position:position] = [insert_index, val]

    def to_list(self):
        """Converts the vector to a list of the values within the vector,
        leaving out the indexes."""
        # Values sit at the odd offsets of the flat list.
        return list(self.elements[1::2])

    def serialize(self):
        """Returns the flat elements with every entry rounded to 3 decimal
        places."""
        # TODO: the JS version forces rounding on the elements upon insertion
        # to ensure symmetry upon serialization
        return [round(element, 3) for element in self.elements]

    @property
    def magnitude(self):
        """The Euclidean length of the vector, cached after the first use.

        The cache is cleared by `upsert` (and therefore by `insert`). A
        magnitude of zero is indistinguishable from "not computed yet" and is
        simply recomputed on every access.
        """
        if not self._magnitude:
            sum_of_squares = 0
            for value in self.elements[1::2]:
                sum_of_squares += value * value

            self._magnitude = sqrt(sum_of_squares)

        return self._magnitude

    def dot(self, other):
        """Calculates the dot product of this vector and another vector.

        Args:
            - other (Vector): The vector to multiply with.

        Returns:
            The sum of the products of the values stored under indexes that
            are present in both vectors, 0 if they share none.
        """
        dot_product = 0
        ours = self.elements
        theirs = other.elements
        ours_len = len(ours)
        theirs_len = len(theirs)
        i = j = 0

        # Walk both flat lists in lockstep, always advancing the side holding
        # the smaller index; only indexes present on both sides contribute.
        while i < ours_len and j < theirs_len:
            our_index = ours[i]
            their_index = theirs[j]
            if our_index < their_index:
                i += 2
            elif our_index > their_index:
                j += 2
            else:
                dot_product += ours[i + 1] * theirs[j + 1]
                i += 2
                j += 2

        return dot_product

    def similarity(self, other):
        """Calculates the similarity score between this vector and another
        vector.

        The score is the dot product of both vectors divided by the magnitude
        of this vector only; the magnitude of `other` is checked for zero but
        is not part of the division.

        Args:
            - other (Vector): The vector to compare against.

        Returns:
            0 if either vector has a magnitude of zero, the score otherwise.
        """
        if self.magnitude == 0 or other.magnitude == 0:
            return 0

        return self.dot(other) / self.magnitude