File: vector_database.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (239 lines) | stat: -rw-r--r-- 8,718 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_
#define COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_

#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_set>
#include <vector>

#include "base/time/time.h"
#include "components/history/core/browser/history_types.h"
#include "components/history_embeddings/proto/history_embeddings.pb.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/passage_embeddings/passage_embeddings_types.h"

namespace history_embeddings {

// Returns a 32-bit hash of `str`; this is the hash used when filtering query
// text. NOTE(review): presumably the same function that produces the values in
// the `stop_words_hashes` set given to SplitQueryToTerms() -- confirm in the .cc.
uint32_t HashString(std::string_view str);

struct ScoredUrl {
  ScoredUrl(history::URLID url_id,
            history::VisitID visit_id,
            base::Time visit_time,
            float score,
            float word_match_score);
  ~ScoredUrl();
  ScoredUrl(ScoredUrl&&);
  ScoredUrl& operator=(ScoredUrl&&);
  ScoredUrl(const ScoredUrl&);
  ScoredUrl& operator=(const ScoredUrl&);

  // Basic data about the found URL/visit.
  history::URLID url_id;
  history::VisitID visit_id;
  base::Time visit_time;

  // A measure of how closely the query matched the found data. This includes
  // the single best embedding score plus a word match boost from text search
  // across all passages.
  float score;

  // This is the score computed by word match text search. It's included in
  // the total `score`, but is also kept separate for second-chance word
  // match result filling.
  float word_match_score;
};

// Tunable parameters that control how a search scores and filters results.
struct SearchParams {
  SearchParams();
  SearchParams(const SearchParams&);
  SearchParams(SearchParams&&);
  ~SearchParams();
  SearchParams& operator=(const SearchParams&);
  // Declared explicitly because the user-declared copy assignment above
  // suppresses the implicit move assignment; without this, every
  // move-assignment would silently fall back to deep-copying `query_terms`.
  // Defaulted inline so it needs no out-of-line definition and gets a
  // deduced (non-throwing) exception specification.
  SearchParams& operator=(SearchParams&&) = default;

  // Portions of lower-cased query representing terms usable for text search.
  // Owned std::string instances are used instead of std::string_view into
  // an owned query instance because this struct can move, and view data
  // pointers are not guaranteed valid after source string moves.
  std::vector<std::string> query_terms;

  // Embedding similarity score below which no word matching takes place.
  float word_match_minimum_embedding_score = 0.0f;

  // Raw score boost, applied per word.
  float word_match_score_boost_factor = 0.2f;

  // Divides and caps a word match boost. Finding the word more than this many
  // times won't increase the boost for the word.
  size_t word_match_limit = 5;

  // Used as a term in final score boost divide to normalize for long queries.
  size_t word_match_smoothing_factor = 1;

  // Maximum number of terms a query may have. When term count exceeds this
  // limit, no text search for the terms occurs.
  size_t word_match_max_term_count = 3;

  // Makes the total word match boost zero when the ratio of terms matched to
  // total query terms is less than this required value. A term is considered
  // matched if there's at least one instance found in all passages.
  // Stop words are not considered query terms and are not counted.
  float word_match_required_term_ratio = 1.0f;

  // If true, any non-ASCII characters in queries or passages will be erased
  // instead of ignoring such queries or passages entirely.
  bool erase_non_ascii_characters = false;

  // If true, word match text search can still be applied for passages with
  // non-ASCII characters; similar to `erase_non_ascii_characters` but for word
  // match text search only.
  bool word_match_search_non_ascii_passages = false;

  // If true, answering step will be skipped even if the query is answerable.
  bool skip_answering = false;
};

// Outcome of a vector database search, plus counters and timing data recorded
// while it ran.
struct SearchInfo {
  SearchInfo();
  SearchInfo(SearchInfo&&);
  // The user-declared move constructor deletes the implicit copy operations
  // and suppresses the implicit move assignment, which left SearchInfo with no
  // assignment operator at all. Defaulted inline so a search result can be
  // move-assigned into an existing variable without an out-of-line definition.
  SearchInfo& operator=(SearchInfo&&) = default;
  ~SearchInfo();

  // Result of the search, the best scored URLs considering total score.
  std::vector<ScoredUrl> scored_urls;

  // Secondary results of the search, the best scored URLs considering
  // word match text search score.
  std::vector<ScoredUrl> word_match_scored_urls;

  // The number of URLs searched to find this result.
  size_t searched_url_count = 0u;

  // The number of embeddings searched to find this result.
  size_t searched_embedding_count = 0u;

  // The number of embeddings scored zero due to having a source passage
  // containing non-ASCII characters.
  size_t skipped_nonascii_passage_count = 0u;

  // The number of source passages modified by erasing non-ASCII characters.
  size_t modified_nonascii_passage_count = 0u;

  // Whether the search completed without interruption. Starting a new search
  // may cause a search to halt, and in that case this member will be false.
  bool completed = false;

  // Time breakdown for metrics: total > scoring > passage_scanning as each
  // succeeding time value is a portion of the last.
  base::TimeDelta total_search_time;
  base::TimeDelta scoring_time;
  base::TimeDelta passage_scanning_time;
};

// Pair of scores returned by UrlData::BestScoreWith(). NOTE(review): the two
// fields presumably mean the same as the same-named ScoredUrl fields (total
// score, and its word match component) -- confirm in the .cc.
//
// Both fields default to zero so a default-constructed UrlScore never holds
// indeterminate values. The type stays an aggregate, so `UrlScore{a, b}`
// brace-initialization keeps working unchanged.
struct UrlScore {
  float score = 0.0f;
  float word_match_score = 0.0f;
};

// A URL/visit with its passages and their embeddings; the unit of data handed
// to VectorDatabase::AddUrlData().
struct UrlData {
  UrlData(history::URLID url_id,
          history::VisitID visit_id,
          base::Time visit_time);
  ~UrlData();

  // Copy and move operations, all defined out of line.
  UrlData(const UrlData&);
  UrlData& operator=(const UrlData&);
  UrlData(UrlData&&);
  UrlData& operator=(UrlData&&);

  bool operator==(const UrlData&) const;

  // Returns the score of the embedding nearest to `query_embedding`. Passages
  // are taken into account as well, because some of them must be skipped;
  // `passages` and `embeddings` line up 1:1 by index.
  // NOTE(review): `search_info` is taken by mutable reference, presumably so
  // its counters (e.g. skipped/modified non-ASCII passages) are updated here;
  // `search_minimum_word_count` presumably skips passages with fewer words --
  // confirm both in the .cc.
  UrlScore BestScoreWith(SearchInfo& search_info,
                         const SearchParams& search_params,
                         const passage_embeddings::Embedding& query_embedding,
                         size_t search_minimum_word_count) const;

  history::URLID url_id;
  history::VisitID visit_id;
  base::Time visit_time;
  proto::PassagesValue passages;
  std::vector<passage_embeddings::Embedding> embeddings;
};

// Abstract storage interface for embeddings. Keeping storage behind this base
// class inverts the dependency, so the same vector search can sit on top of a
// SQLite database, simple in-memory storage, flat files, or any other storage
// that works efficiently.
class VectorDatabase {
 public:
  // Cursor that walks the UrlData entries of a database.
  struct UrlDataIterator {
    virtual ~UrlDataIterator() = default;

    // Moves to the next entry and returns a pointer to it, or returns nullptr
    // once nothing remains. The pointee may be owned by the iterator itself,
    // so do not rely on it outliving the iterator.
    virtual const UrlData* Next() = 0;
  };

  // Looks for embeddings close to `query_embedding` and reports where they
  // were found and how closely the query matched them.
  // NOTE(review): `time_range_start` presumably limits the search to data at
  // or after that time, `count` presumably caps the number of results, and
  // `is_search_halted` is presumably polled to stop early (an interrupted
  // search reports SearchInfo::completed == false) -- confirm in the .cc.
  SearchInfo FindNearest(std::optional<base::Time> time_range_start,
                         size_t count,
                         const SearchParams& search_params,
                         const passage_embeddings::Embedding& query_embedding,
                         base::RepeatingCallback<bool()> is_search_halted);

  virtual ~VectorDatabase() = default;

  // Number of dimensions every embedding is expected to have.
  virtual size_t GetEmbeddingDimensions() const = 0;

  // Inserts or updates the embeddings for the complete set of passages of one
  // URL. Returns true on success.
  virtual bool AddUrlData(UrlData url_data) = 0;

  // Builds an iterator over the database's items. May return null when there
  // are no items.
  virtual std::unique_ptr<UrlDataIterator> MakeUrlDataIterator(
      std::optional<base::Time> time_range_start) = 0;
};

// VectorDatabase kept entirely in memory. Besides searching, it can write its
// contents out to a different (e.g. persistent) backing store.
class VectorDatabaseInMemory : public VectorDatabase {
 public:
  VectorDatabaseInMemory();
  ~VectorDatabaseInMemory() override;

  // VectorDatabase:
  bool AddUrlData(UrlData url_data) override;
  std::unique_ptr<UrlDataIterator> MakeUrlDataIterator(
      std::optional<base::Time> time_range_start) override;
  size_t GetEmbeddingDimensions() const override;

  // Writes this store's data into `database`. Most VectorDatabase
  // implementations have no use for this; it exists so an in-memory store can
  // cooperate with a separate backing database living on a worker sequence.
  void SaveTo(VectorDatabase* database);

 private:
  // Entries held by this store.
  std::vector<UrlData> data_;
};

// Breaks `raw_query` into the individual query terms used for search.
// NOTE(review): presumably terms whose HashString() value is in
// `stop_words_hashes` are dropped (stop words are not query terms) and terms
// shorter than `min_term_length` are discarded -- confirm in the .cc.
std::vector<std::string> SplitQueryToTerms(
    const std::unordered_set<uint32_t>& stop_words_hashes,
    std::string_view raw_query,
    size_t min_term_length);

// Strips every non-ASCII character, modifying the argument in place. One
// overload handles a whole batch of passages, the other a single passage.
void EraseNonAsciiCharacters(std::vector<std::string>& passages);
void EraseNonAsciiCharacters(std::string& passage);

}  // namespace history_embeddings

#endif  // COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_