File: vectorstore.py

#  Licensed to Elasticsearch B.V. under one or more contributor
#  license agreements. See the NOTICE file distributed with
#  this work for additional information regarding copyright
#  ownership. Elasticsearch B.V. licenses this file to you under
#  the Apache License, Version 2.0 (the "License"); you may
#  not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
# 	http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.

import logging
import uuid
from typing import Any, Callable, Dict, List, Optional

from elasticsearch import AsyncElasticsearch
from elasticsearch._version import __versionstr__ as lib_version
from elasticsearch.helpers import BulkIndexError, async_bulk
from elasticsearch.helpers.vectorstore import (
    AsyncEmbeddingService,
    AsyncRetrievalStrategy,
)
from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)


class AsyncVectorStore:
    """
    AsyncVectorStore is a higher-level abstraction over indexing and search.
    Users can pick from available retrieval strategies.

    Documents have up to 3 fields:
      - text_field: the text to be indexed and searched.
      - metadata: additional information about the document, either schema-free
        or defined by the supplied metadata_mappings.
      - vector_field (usually not filled by the user): the embedding vector of the text.

    Depending on the strategy, vector embeddings are
      - created by the user beforehand
      - created by this AsyncVectorStore class in Python
      - created in-stack by inference pipelines.
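
    A minimal construction sketch (illustrative only; the strategy choice, index
    name, and dimensionality below are assumptions, not requirements)::

        from elasticsearch import AsyncElasticsearch
        from elasticsearch.helpers.vectorstore import (
            AsyncDenseVectorStrategy,
            AsyncVectorStore,
        )

        store = AsyncVectorStore(
            client=AsyncElasticsearch(hosts=["http://localhost:9200"]),
            index="my-docs",
            retrieval_strategy=AsyncDenseVectorStrategy(),
            num_dimensions=384,
        )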
    """

    def __init__(
        self,
        client: AsyncElasticsearch,
        *,
        index: str,
        retrieval_strategy: AsyncRetrievalStrategy,
        embedding_service: Optional[AsyncEmbeddingService] = None,
        num_dimensions: Optional[int] = None,
        text_field: str = "text_field",
        vector_field: str = "vector_field",
        metadata_mappings: Optional[Dict[str, Any]] = None,
        user_agent: str = f"elasticsearch-py-vs/{lib_version}",
        custom_index_settings: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        :param user_header: user agent header specific to the 3rd party integration.
            Used for usage tracking in Elastic Cloud.
        :param index: The name of the index to query.
        :param retrieval_strategy: how to index and search the data. See the strategies
            module for availble strategies.
        :param text_field: Name of the field with the textual data.
        :param vector_field: For strategies that perform embedding inference in Python,
            the embedding vector goes in this field.
        :param client: Elasticsearch client connection. Alternatively specify the
            Elasticsearch connection with the other es_* parameters.
        :param custom_index_settings: A dictionary of custom settings for the index.
            This can include configurations like the number of shards, number of replicas,
            analysis settings, and other index-specific settings. If not provided, default
            settings will be used. Note that if the same setting is provided by both the user
            and the strategy, will raise an error.
        """
        # Add integration-specific usage header for tracking usage in Elastic Cloud.
        # client.options preserves existing (non-user-agent) headers.
        client = client.options(headers={"User-Agent": user_agent})

        if hasattr(retrieval_strategy, "text_field"):
            retrieval_strategy.text_field = text_field
        if hasattr(retrieval_strategy, "vector_field"):
            retrieval_strategy.vector_field = vector_field

        self.client = client
        self.index = index
        self.retrieval_strategy = retrieval_strategy
        self.embedding_service = embedding_service
        self.num_dimensions = num_dimensions
        self.text_field = text_field
        self.vector_field = vector_field
        self.metadata_mappings = metadata_mappings
        self.custom_index_settings = custom_index_settings

    async def close(self) -> None:
        return await self.client.close()

    async def add_texts(
        self,
        texts: List[str],
        *,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        vectors: Optional[List[List[float]]] = None,
        ids: Optional[List[str]] = None,
        refresh_indices: bool = True,
        create_index_if_not_exists: bool = True,
        bulk_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[str]:
        """Add documents to the Elasticsearch index.

        :param texts: List of text documents.
        :param metadatas: Optional list of document metadata. Must be of same length
            as texts.
        :param vectors: Optional list of embedding vectors. Must be of same length as
            texts.
        :param ids: Optional list of ID strings. Must be of same length as texts.
        :param refresh_indices: Whether to refresh the index after adding documents.
            Defaults to True.
        :param create_index_if_not_exists: Whether to create the index if it does not
            exist. Defaults to True.
        :param bulk_kwargs: Arguments to pass to the bulk function when indexing
            (for example chunk_size).

        :return: List of IDs of the created documents, either echoing the provided
            ones or newly generated ones.
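
        Example (an illustrative sketch; ``store`` is an already-initialized
        AsyncVectorStore and the vector dimensionality is an assumption)::

            ids = await store.add_texts(
                texts=["hello world", "goodbye world"],
                metadatas=[{"lang": "en"}, {"lang": "en"}],
                # skip vectors if an embedding_service embeds for you
                vectors=[[0.1] * 384, [0.2] * 384],
                bulk_kwargs={"chunk_size": 500},
            )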
        """
        bulk_kwargs = bulk_kwargs or {}
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        requests = []

        if create_index_if_not_exists:
            await self._create_index_if_not_exists()

        if self.embedding_service and not vectors:
            vectors = await self.embedding_service.embed_documents(texts)

        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}

            request: Dict[str, Any] = {
                "_op_type": "index",
                "_index": self.index,
                self.text_field: text,
                "metadata": metadata,
                "_id": ids[i],
            }

            if vectors:
                request[self.vector_field] = vectors[i]

            requests.append(request)

        if len(requests) > 0:
            try:
                success, failed = await async_bulk(
                    self.client,
                    requests,
                    stats_only=True,
                    refresh=refresh_indices,
                    **bulk_kwargs,
                )
                logger.debug(f"added texts {ids} to index")
                return ids
            except BulkIndexError as e:
                logger.error(f"Error adding texts: {e}")
                first_error = e.errors[0].get("index", {}).get("error", {})
                logger.error(f"First error reason: {first_error.get('reason')}")
                raise e

        else:
            logger.debug("No texts to add to index")
            return []

    async def delete(  # type: ignore[no-untyped-def]
        self,
        *,
        ids: Optional[List[str]] = None,
        query: Optional[Dict[str, Any]] = None,
        refresh_indices: bool = True,
        **delete_kwargs,
    ) -> bool:
        """Delete documents from the Elasticsearch index.

        :param ids: List of IDs of documents to delete.
        :param query: Elasticsearch query selecting the documents to delete. Exactly
            one of ids or query must be specified.
        :param refresh_indices: Whether to refresh the index after deleting documents.
            Defaults to True.
        :param delete_kwargs: Additional arguments passed to the bulk helper (when
            deleting by ids) or to delete_by_query (when deleting by query).

        :return: True if deletion was successful.
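
        Example (illustrative; ``store`` is an assumed AsyncVectorStore instance
        and the metadata field name is an assumption)::

            # delete two documents by ID
            await store.delete(ids=["doc-1", "doc-2"])

            # or delete every document matching a query
            await store.delete(query={"term": {"metadata.lang": "en"}})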
        """
        if ids is not None and query is not None:
            raise ValueError("specify only one of ids or query, not both")
        elif ids is None and query is None:
            raise ValueError("specify either ids or query")

        try:
            if ids:
                body = [
                    {"_op_type": "delete", "_index": self.index, "_id": _id}
                    for _id in ids
                ]
                await async_bulk(
                    self.client,
                    body,
                    refresh=refresh_indices,
                    ignore_status=404,
                    **delete_kwargs,
                )
                logger.debug(f"Deleted {len(body)} texts from index")

            else:
                await self.client.delete_by_query(
                    index=self.index,
                    query=query,
                    refresh=refresh_indices,
                    **delete_kwargs,
                )

        except BulkIndexError as e:
            logger.error(f"Error deleting texts: {e}")
            first_error = e.errors[0].get("index", {}).get("error", {})
            logger.error(f"First error reason: {first_error.get('reason')}")
            raise e

        return True

    async def search(
        self,
        *,
        query: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        k: int = 4,
        num_candidates: int = 50,
        fields: Optional[List[str]] = None,
        filter: Optional[List[Dict[str, Any]]] = None,
        custom_query: Optional[
            Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]]
        ] = None,
    ) -> List[Dict[str, Any]]:
        """
        :param query: Input query string.
        :param query_vector: Input embedding vector. If given, it is used instead of
            embedding the query string.
        :param k: Number of returned results.
        :param num_candidates: Number of candidates to fetch from data nodes in knn.
        :param fields: List of field names to return.
        :param filter: Elasticsearch filters to apply.
        :param custom_query: Function to modify the Elasticsearch query body before it is
            sent to Elasticsearch.

        :return: List of document hits. Includes _index, _id, _score and _source.
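
        Example (illustrative; ``store`` is an assumed AsyncVectorStore configured
        with an embedding service, and the filter field is an assumption)::

            hits = await store.search(
                query="what is a vector store?",
                k=3,
                filter=[{"term": {"metadata.lang": "en"}}],
            )
            for hit in hits:
                print(hit["_score"], hit["_source"][store.text_field])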
        """
        if fields is None:
            fields = []
        if "metadata" not in fields:
            fields.append("metadata")
        if self.text_field not in fields:
            fields.append(self.text_field)

        if self.embedding_service and not query_vector:
            if not query:
                raise ValueError("specify a query or a query_vector to search")
            query_vector = await self.embedding_service.embed_query(query)

        query_body = self.retrieval_strategy.es_query(
            query=query,
            query_vector=query_vector,
            text_field=self.text_field,
            vector_field=self.vector_field,
            k=k,
            num_candidates=num_candidates,
            filter=filter or [],
        )

        if custom_query is not None:
            query_body = custom_query(query_body, query)
            logger.debug(f"Calling custom_query, Query body now: {query_body}")

        response = await self.client.search(
            index=self.index,
            **query_body,
            size=k,
            source=True,
            source_includes=fields,
        )
        hits: List[Dict[str, Any]] = response["hits"]["hits"]

        return hits

    async def _create_index_if_not_exists(self) -> None:
        exists = await self.client.indices.exists(index=self.index)
        if exists.meta.status == 200:
            logger.debug(f"Index {self.index} already exists. Skipping creation.")
            return

        if self.retrieval_strategy.needs_inference():
            if not self.num_dimensions and not self.embedding_service:
                raise ValueError(
                    "retrieval strategy requires embeddings; either embedding_service "
                    "or num_dimensions need to be specified"
                )
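            # Probe the embedding service with a throwaway string to learn the
            # vector dimensionality needed for the mappings.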
            if not self.num_dimensions and self.embedding_service:
                vector = await self.embedding_service.embed_query("get num dimensions")
                self.num_dimensions = len(vector)

        mappings, settings = self.retrieval_strategy.es_mappings_settings(
            text_field=self.text_field,
            vector_field=self.vector_field,
            num_dimensions=self.num_dimensions,
        )

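        # Illustrative note on the merge below: custom_index_settings such as
        # {"number_of_shards": 2} are applied on top of the strategy's settings,
        # but a key that the strategy also defines raises an error instead of
        # silently overriding either side. (The setting name is just an example.)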
        if self.custom_index_settings:
            conflicting_keys = set(self.custom_index_settings.keys()) & set(
                settings.keys()
            )
            if conflicting_keys:
                raise ValueError(f"Conflicting settings: {conflicting_keys}")
            else:
                settings.update(self.custom_index_settings)

        if self.metadata_mappings:
            metadata = mappings["properties"].get("metadata", {"properties": {}})
            for key in self.metadata_mappings.keys():
                if key in metadata["properties"]:
                    raise ValueError(f"metadata key {key} already exists in mappings")

            metadata = dict(**metadata["properties"], **self.metadata_mappings)
            mappings["properties"]["metadata"] = {"properties": metadata}

        await self.retrieval_strategy.before_index_creation(
            client=self.client,
            text_field=self.text_field,
            vector_field=self.vector_field,
        )
        await self.client.indices.create(
            index=self.index, mappings=mappings, settings=settings
        )

    async def max_marginal_relevance_search(
        self,
        *,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        embedding_service: Optional[AsyncEmbeddingService] = None,
        vector_field: str,
        k: int = 4,
        num_candidates: int = 20,
        lambda_mult: float = 0.5,
        fields: Optional[List[str]] = None,
        custom_query: Optional[
            Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]]
        ] = None,
    ) -> List[Dict[str, Any]]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
            among selected documents.
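
        As a sketch of the criterion (not the exact implementation), each next
        document ``d`` is picked greedily to maximize::

            lambda_mult * sim(d, query)
                - (1 - lambda_mult) * max(sim(d, s) for s in selected)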

        :param query: Text to look up documents similar to.
        :param query_embedding: Input embedding vector. If given, it is used instead
            of embedding the query string.
        :param embedding_service: Service used to embed the query string. Falls back
            to the store's own embedding service if not given.
        :param vector_field: Name of the document field that holds the embedding
            vectors of the fetched candidates.
        :param k: Number of documents to return. Defaults to 4.
        :param num_candidates: Number of documents to fetch and pass to the MMR
            algorithm. Defaults to 20.
        :param lambda_mult: Number between 0 and 1 that determines the degree of
            diversity among the results, with 0 corresponding to maximum diversity
            and 1 to minimum diversity. Defaults to 0.5.
        :param fields: Other fields to retrieve from the Elasticsearch source. These
            fields are included in the returned hits.
        :param custom_query: Function to modify the Elasticsearch query body before
            it is sent to Elasticsearch.

        :return: A list of document hits selected by maximal marginal relevance.
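
        Example (illustrative; ``store`` is an assumed AsyncVectorStore with an
        embedding service configured)::

            hits = await store.max_marginal_relevance_search(
                query="vector search",
                vector_field="vector_field",
                k=4,
                num_candidates=20,
                lambda_mult=0.5,
            )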
        """
        # Track whether the vector field was requested by the caller; if we added
        # it ourselves just for MMR, strip it from the hits before returning.
        remove_vector_field_from_hits = True
        if fields is None:
            fields = [vector_field]
        elif vector_field not in fields:
            fields.append(vector_field)
        else:
            remove_vector_field_from_hits = False

        # Embed the query
        if query_embedding:
            query_vector = query_embedding
        else:
            if not query:
                raise ValueError("specify either query or query_embedding to search")
            elif embedding_service:
                query_vector = await embedding_service.embed_query(query)
            elif self.embedding_service:
                query_vector = await self.embedding_service.embed_query(query)
            else:
                raise ValueError("specify embedding_service to search with query")

        # Fetch the initial documents
        got_hits = await self.search(
            query=None,
            query_vector=query_vector,
            k=num_candidates,
            fields=fields,
            custom_query=custom_query,
        )

        # Get the embeddings for the fetched documents
        got_embeddings = [hit["_source"][vector_field] for hit in got_hits]

        # Select documents using maximal marginal relevance
        selected_indices = maximal_marginal_relevance(
            query_vector, got_embeddings, lambda_mult=lambda_mult, k=k
        )
        selected_hits = [got_hits[i] for i in selected_indices]

        if remove_vector_field_from_hits:
            for hit in selected_hits:
                del hit["_source"][vector_field]

        return selected_hits