1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421
|
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import uuid
from typing import Any, Callable, Dict, List, Optional
from elasticsearch import AsyncElasticsearch
from elasticsearch._version import __versionstr__ as lib_version
from elasticsearch.helpers import BulkIndexError, async_bulk
from elasticsearch.helpers.vectorstore import (
AsyncEmbeddingService,
AsyncRetrievalStrategy,
)
from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance
# Module-level logger, namespaced by module path per the stdlib convention.
logger = logging.getLogger(__name__)
class AsyncVectorStore:
    """
    VectorStore is a higher-level abstraction of indexing and search.
    Users can pick from available retrieval strategies.

    Documents have up to 3 fields:
    - text_field: the text to be indexed and searched.
    - metadata: additional information about the document, either schema-free
      or defined by the supplied metadata_mappings.
    - vector_field (usually not filled by the user): the embedding vector of the text.

    Depending on the strategy, vector embeddings are
    - created by the user beforehand
    - created by this AsyncVectorStore class in Python
    - created in-stack by inference pipelines.
    """

    def __init__(
        self,
        client: AsyncElasticsearch,
        *,
        index: str,
        retrieval_strategy: AsyncRetrievalStrategy,
        embedding_service: Optional[AsyncEmbeddingService] = None,
        num_dimensions: Optional[int] = None,
        text_field: str = "text_field",
        vector_field: str = "vector_field",
        metadata_mappings: Optional[Dict[str, Any]] = None,
        user_agent: str = f"elasticsearch-py-vs/{lib_version}",
        custom_index_settings: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        :param client: Elasticsearch client connection. Alternatively specify the
            Elasticsearch connection with the other es_* parameters.
        :param index: The name of the index to query.
        :param retrieval_strategy: how to index and search the data. See the strategies
            module for available strategies.
        :param embedding_service: service used to compute embedding vectors in Python;
            only needed for strategies that do not embed in-stack.
        :param num_dimensions: dimensionality of the embedding vectors. If omitted and
            an embedding service is given, it is inferred at index creation time by
            embedding a probe string.
        :param text_field: Name of the field with the textual data.
        :param vector_field: For strategies that perform embedding inference in Python,
            the embedding vector goes in this field.
        :param metadata_mappings: optional explicit Elasticsearch mappings for metadata
            sub-fields; merged with the strategy's mappings at index creation.
        :param user_agent: user agent header specific to the 3rd party integration.
            Used for usage tracking in Elastic Cloud.
        :param custom_index_settings: A dictionary of custom settings for the index.
            This can include configurations like the number of shards, number of replicas,
            analysis settings, and other index-specific settings. If not provided, default
            settings will be used. Note that if the same setting is provided by both the user
            and the strategy, will raise an error.
        """
        # Add integration-specific usage header for tracking usage in Elastic Cloud.
        # client.options preserves existing (non-user-agent) headers.
        client = client.options(headers={"User-Agent": user_agent})

        # Propagate the chosen field names into the strategy, if it carries them.
        if hasattr(retrieval_strategy, "text_field"):
            retrieval_strategy.text_field = text_field
        if hasattr(retrieval_strategy, "vector_field"):
            retrieval_strategy.vector_field = vector_field

        self.client = client
        self.index = index
        self.retrieval_strategy = retrieval_strategy
        self.embedding_service = embedding_service
        self.num_dimensions = num_dimensions
        self.text_field = text_field
        self.vector_field = vector_field
        self.metadata_mappings = metadata_mappings
        self.custom_index_settings = custom_index_settings

    async def close(self) -> None:
        """Close the underlying Elasticsearch client connection."""
        return await self.client.close()

    async def add_texts(
        self,
        texts: List[str],
        *,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        vectors: Optional[List[List[float]]] = None,
        ids: Optional[List[str]] = None,
        refresh_indices: bool = True,
        create_index_if_not_exists: bool = True,
        bulk_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[str]:
        """Add documents to the Elasticsearch index.

        :param texts: List of text documents.
        :param metadatas: Optional list of document metadata. Must be of same length as
            texts.
        :param vectors: Optional list of embedding vectors. Must be of same length as
            texts.
        :param ids: Optional list of ID strings. Must be of same length as texts.
        :param refresh_indices: Whether to refresh the index after adding documents.
            Defaults to True.
        :param create_index_if_not_exists: Whether to create the index if it does not
            exist. Defaults to True.
        :param bulk_kwargs: Arguments to pass to the bulk function when indexing
            (for example chunk_size).

        :return: List of IDs of the created documents, either echoing the provided one
            or returning newly created ones.

        :raises BulkIndexError: if any of the documents failed to index.
        """
        bulk_kwargs = bulk_kwargs or {}
        # Assign random IDs to any documents the caller did not name.
        ids = ids or [str(uuid.uuid4()) for _ in texts]

        if create_index_if_not_exists:
            await self._create_index_if_not_exists()

        # Embed in Python only if the caller did not supply vectors already.
        if self.embedding_service and not vectors:
            vectors = await self.embedding_service.embed_documents(texts)

        requests = []
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            request: Dict[str, Any] = {
                "_op_type": "index",
                "_index": self.index,
                self.text_field: text,
                "metadata": metadata,
                "_id": ids[i],
            }
            if vectors:
                request[self.vector_field] = vectors[i]
            requests.append(request)

        if not requests:
            logger.debug("No texts to add to index")
            return []

        try:
            success, failed = await async_bulk(
                self.client,
                requests,
                stats_only=True,
                refresh=refresh_indices,
                **bulk_kwargs,
            )
            # With stats_only=True the helper returns (ok_count, failed_count).
            logger.debug("added texts %s to index (%s ok, %s failed)", ids, success, failed)
            return ids
        except BulkIndexError as e:
            logger.error("Error adding texts: %s", e)
            first_error = e.errors[0].get("index", {}).get("error", {})
            logger.error("First error reason: %s", first_error.get("reason"))
            raise

    async def delete(  # type: ignore[no-untyped-def]
        self,
        *,
        ids: Optional[List[str]] = None,
        query: Optional[Dict[str, Any]] = None,
        refresh_indices: bool = True,
        **delete_kwargs,
    ) -> bool:
        """Delete documents from the Elasticsearch index.

        :param ids: List of IDs of documents to delete. Exactly one of ids or query
            must be given.
        :param query: Elasticsearch query selecting the documents to delete.
        :param refresh_indices: Whether to refresh the index after deleting documents.
            Defaults to True.
        :param delete_kwargs: extra arguments forwarded to the bulk helper (for ids)
            or to delete_by_query (for query).

        :return: True if deletion was successful.

        :raises ValueError: if both or neither of ids and query are given.
        """
        # Exactly one of ids / query must be provided; reject both and neither.
        if ids is not None and query is not None:
            # Fixed error message: previously claimed "one of ids or query must be
            # specified" even though the problem here is that BOTH were given.
            raise ValueError("cannot specify both ids and query")
        elif ids is None and query is None:
            raise ValueError("either specify ids or query")

        try:
            if ids:
                body = [
                    {"_op_type": "delete", "_index": self.index, "_id": _id}
                    for _id in ids
                ]
                await async_bulk(
                    self.client,
                    body,
                    refresh=refresh_indices,
                    ignore_status=404,  # deleting a missing document is not an error
                    **delete_kwargs,
                )
                logger.debug("Deleted %d texts from index", len(body))
            else:
                await self.client.delete_by_query(
                    index=self.index,
                    query=query,
                    refresh=refresh_indices,
                    **delete_kwargs,
                )
        except BulkIndexError as e:
            logger.error("Error deleting texts: %s", e)
            first_error = e.errors[0].get("index", {}).get("error", {})
            logger.error("First error reason: %s", first_error.get("reason"))
            raise

        return True

    async def search(
        self,
        *,
        query: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        k: int = 4,
        num_candidates: int = 50,
        fields: Optional[List[str]] = None,
        filter: Optional[List[Dict[str, Any]]] = None,
        custom_query: Optional[
            Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]]
        ] = None,
    ) -> List[Dict[str, Any]]:
        """
        :param query: Input query string.
        :param query_vector: Input embedding vector. If given, input query string is
            ignored.
        :param k: Number of returned results.
        :param num_candidates: Number of candidates to fetch from data nodes in knn.
        :param fields: List of field names to return.
        :param filter: Elasticsearch filters to apply.
        :param custom_query: Function to modify the Elasticsearch query body before it is
            sent to Elasticsearch.

        :return: List of document hits. Includes _index, _id, _score and _source.
        """
        if fields is None:
            fields = []
        # Always return the metadata and text fields alongside any requested ones.
        if "metadata" not in fields:
            fields.append("metadata")
        if self.text_field not in fields:
            fields.append(self.text_field)

        if self.embedding_service and not query_vector:
            if not query:
                raise ValueError("specify a query or a query_vector to search")
            query_vector = await self.embedding_service.embed_query(query)

        query_body = self.retrieval_strategy.es_query(
            query=query,
            query_vector=query_vector,
            text_field=self.text_field,
            vector_field=self.vector_field,
            k=k,
            num_candidates=num_candidates,
            filter=filter or [],
        )

        if custom_query is not None:
            query_body = custom_query(query_body, query)
            logger.debug("Calling custom_query, Query body now: %s", query_body)

        response = await self.client.search(
            index=self.index,
            **query_body,
            size=k,
            source=True,
            source_includes=fields,
        )
        hits: List[Dict[str, Any]] = response["hits"]["hits"]
        return hits

    async def _create_index_if_not_exists(self) -> None:
        """Create the target index with strategy-derived mappings/settings, unless it
        already exists.

        :raises ValueError: if the strategy needs embeddings but neither
            num_dimensions nor embedding_service is set, if custom settings conflict
            with strategy settings, or if metadata_mappings collide with strategy
            metadata mappings.
        """
        exists = await self.client.indices.exists(index=self.index)
        if exists.meta.status == 200:
            logger.debug("Index %s already exists. Skipping creation.", self.index)
            return

        if self.retrieval_strategy.needs_inference():
            if not self.num_dimensions and not self.embedding_service:
                raise ValueError(
                    "retrieval strategy requires embeddings; either embedding_service "
                    "or num_dimensions need to be specified"
                )
            if not self.num_dimensions and self.embedding_service:
                # Infer the vector dimensionality by embedding a probe string.
                vector = await self.embedding_service.embed_query("get num dimensions")
                self.num_dimensions = len(vector)

        mappings, settings = self.retrieval_strategy.es_mappings_settings(
            text_field=self.text_field,
            vector_field=self.vector_field,
            num_dimensions=self.num_dimensions,
        )

        if self.custom_index_settings:
            conflicting_keys = set(self.custom_index_settings.keys()) & set(
                settings.keys()
            )
            if conflicting_keys:
                raise ValueError(f"Conflicting settings: {conflicting_keys}")
            else:
                settings.update(self.custom_index_settings)

        if self.metadata_mappings:
            metadata = mappings["properties"].get("metadata", {"properties": {}})
            for key in self.metadata_mappings.keys():
                # Fixed: compare against the metadata sub-field mappings, not the
                # {"properties": ...} wrapper dict (the old check could only ever
                # match the literal key "properties", so collisions went undetected).
                if key in metadata["properties"]:
                    raise ValueError(f"metadata key {key} already exists in mappings")
            metadata = dict(**metadata["properties"], **self.metadata_mappings)
            mappings["properties"]["metadata"] = {"properties": metadata}

        await self.retrieval_strategy.before_index_creation(
            client=self.client,
            text_field=self.text_field,
            vector_field=self.vector_field,
        )
        await self.client.indices.create(
            index=self.index, mappings=mappings, settings=settings
        )

    async def max_marginal_relevance_search(
        self,
        *,
        query: Optional[str] = None,
        query_embedding: Optional[List[float]] = None,
        embedding_service: Optional[AsyncEmbeddingService] = None,
        vector_field: str,
        k: int = 4,
        num_candidates: int = 20,
        lambda_mult: float = 0.5,
        fields: Optional[List[str]] = None,
        custom_query: Optional[
            Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]]
        ] = None,
    ) -> List[Dict[str, Any]]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        :param query: Text to look up documents similar to.
        :param query_embedding: Input embedding vector. If given, input query string is
            ignored.
        :param embedding_service: service used to embed the query; falls back to the
            store's embedding_service if not given.
        :param vector_field: Name of the field holding the document embeddings.
        :param k: Number of Documents to return. Defaults to 4.
        :param num_candidates: Number of Documents to fetch to pass to MMR algorithm.
        :param lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        :param fields: Other fields to get from elasticsearch source. These fields
            will be added to the document metadata.
        :param custom_query: Function to modify the Elasticsearch query body before it is
            sent to Elasticsearch.

        :return: A list of Documents selected by maximal marginal relevance.
        """
        # The MMR algorithm needs the candidates' embedding vectors, so make sure
        # vector_field is fetched; strip it from the hits afterwards unless the
        # caller asked for it explicitly.
        remove_vector_query_field_from_metadata = True
        if fields is None:
            fields = [vector_field]
        elif vector_field not in fields:
            fields.append(vector_field)
        else:
            remove_vector_query_field_from_metadata = False

        # Embed the query
        if query_embedding:
            query_vector = query_embedding
        else:
            if not query:
                raise ValueError("specify either query or query_embedding to search")
            elif embedding_service:
                query_vector = await embedding_service.embed_query(query)
            elif self.embedding_service:
                query_vector = await self.embedding_service.embed_query(query)
            else:
                raise ValueError("specify embedding_service to search with query")

        # Fetch the initial candidate documents by pure vector similarity.
        got_hits = await self.search(
            query=None,
            query_vector=query_vector,
            k=num_candidates,
            fields=fields,
            custom_query=custom_query,
        )

        # Get the embeddings for the fetched documents
        got_embeddings = [hit["_source"][vector_field] for hit in got_hits]

        # Select documents using maximal marginal relevance
        selected_indices = maximal_marginal_relevance(
            query_vector, got_embeddings, lambda_mult=lambda_mult, k=k
        )
        selected_hits = [got_hits[i] for i in selected_indices]

        if remove_vector_query_field_from_metadata:
            for hit in selected_hits:
                del hit["_source"][vector_field]

        return selected_hits
|