// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_
#define COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_set>
#include <vector>

#include "base/functional/callback.h"
#include "base/time/time.h"
#include "components/history/core/browser/history_types.h"
#include "components/history_embeddings/proto/history_embeddings.pb.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/passage_embeddings/passage_embeddings_types.h"

namespace history_embeddings {

// Hash function used for query filtering.
uint32_t HashString(std::string_view str);

struct ScoredUrl {
  ScoredUrl(history::URLID url_id,
            history::VisitID visit_id,
            base::Time visit_time,
            float score,
            float word_match_score);
  ~ScoredUrl();
  ScoredUrl(ScoredUrl&&);
  ScoredUrl& operator=(ScoredUrl&&);
  ScoredUrl(const ScoredUrl&);
  ScoredUrl& operator=(const ScoredUrl&);

  // Basic data about the found URL/visit.
  history::URLID url_id;
  history::VisitID visit_id;
  base::Time visit_time;

  // A measure of how closely the query matched the found data. This includes
  // the single best embedding score plus a word match boost from text search
  // across all passages.
  float score;

  // This is the score computed by word match text search. It's included in
  // the total `score`, but is also kept separate for second-chance word
  // match result filling.
  float word_match_score;
};

struct SearchParams {
  SearchParams();
  SearchParams(const SearchParams&);
  SearchParams(SearchParams&&);
  ~SearchParams();
  SearchParams& operator=(const SearchParams&);

  // Portions of the lower-cased query representing terms usable for text
  // search. Owned std::string instances are used instead of std::string_view
  // into an owned query instance because this struct can move, and view data
  // pointers are not guaranteed valid after the source string moves.
  std::vector<std::string> query_terms;

  // Embedding similarity score below which no word matching takes place.
  float word_match_minimum_embedding_score = 0.0f;

  // Raw score boost, applied per word.
  float word_match_score_boost_factor = 0.2f;

  // Divides and caps a word match boost. Finding the word more than this many
  // times won't increase the boost for the word.
  size_t word_match_limit = 5;

  // Used as a term in the divisor of the final score boost, to normalize for
  // long queries.
  size_t word_match_smoothing_factor = 1;

  // Maximum number of terms a query may have. When the term count exceeds
  // this limit, no text search for the terms occurs.
  size_t word_match_max_term_count = 3;

  // Makes the total word match boost zero when the ratio of terms matched to
  // total query terms is less than this required value. A term is considered
  // matched if at least one instance of it is found across all passages.
  // Stop words are not considered query terms and are not counted.
  float word_match_required_term_ratio = 1.0f;

  // If true, any non-ASCII characters in queries or passages are erased
  // instead of such queries or passages being ignored entirely.
  bool erase_non_ascii_characters = false;

  // If true, word match text search can still be applied to passages with
  // non-ASCII characters; similar to `erase_non_ascii_characters` but for word
  // match text search only.
  bool word_match_search_non_ascii_passages = false;

  // If true, the answering step is skipped even if the query is answerable.
  bool skip_answering = false;
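
  // Worked example (a sketch, not the implementation): one plausible reading
  // of the word match fields above, assuming each term's boost is
  // word_match_score_boost_factor * min(count, word_match_limit) /
  // word_match_limit and the summed boost is divided by
  // query_terms.size() + word_match_smoothing_factor. With the defaults and
  // the hypothetical query_terms {"rust", "async"} (2 terms, within
  // word_match_max_term_count):
  //   "rust" found 7 times:  0.2 * min(7, 5) / 5 = 0.2
  //   "async" found 2 times: 0.2 * min(2, 5) / 5 = 0.08
  //   Both terms matched (ratio 1.0), so the boost is kept:
  //   (0.2 + 0.08) / (2 + 1) = 0.28 / 3 ~= 0.093
  // If only "rust" were found, the matched ratio 0.5 would fall below the
  // default word_match_required_term_ratio of 1.0, making the boost zero.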
};

struct SearchInfo {
  SearchInfo();
  SearchInfo(SearchInfo&&);
  ~SearchInfo();

  // Result of the search: the best scored URLs considering total score.
  std::vector<ScoredUrl> scored_urls;

  // Secondary results of the search: the best scored URLs considering
  // word match text search score.
  std::vector<ScoredUrl> word_match_scored_urls;

  // The number of URLs searched to find this result.
  size_t searched_url_count = 0u;

  // The number of embeddings searched to find this result.
  size_t searched_embedding_count = 0u;

  // The number of embeddings scored zero due to having a source passage
  // containing non-ASCII characters.
  size_t skipped_nonascii_passage_count = 0u;

  // The number of source passages modified by erasing non-ASCII characters.
  size_t modified_nonascii_passage_count = 0u;

  // Whether the search completed without interruption. Starting a new search
  // may cause a search to halt, in which case this member is false.
  bool completed = false;

  // Time breakdown for metrics: total > scoring > passage_scanning, as each
  // succeeding time value is a portion of the previous one.
  base::TimeDelta total_search_time;
  base::TimeDelta scoring_time;
  base::TimeDelta passage_scanning_time;
};

struct UrlScore {
  float score;
  float word_match_score;
};

struct UrlData {
  UrlData(history::URLID url_id,
          history::VisitID visit_id,
          base::Time visit_time);
  UrlData(const UrlData&);
  UrlData(UrlData&&);
  UrlData& operator=(const UrlData&);
  UrlData& operator=(UrlData&&);
  ~UrlData();

  bool operator==(const UrlData&) const;

  // Finds the score of the embedding nearest to the query, also taking
  // passages into consideration since some should be skipped. The passages
  // correspond to the embeddings 1:1 by index.
  UrlScore BestScoreWith(SearchInfo& search_info,
                         const SearchParams& search_params,
                         const passage_embeddings::Embedding& query_embedding,
                         size_t search_minimum_word_count) const;

  history::URLID url_id;
  history::VisitID visit_id;
  base::Time visit_time;
  proto::PassagesValue passages;
  std::vector<passage_embeddings::Embedding> embeddings;
};

// This base class decouples storage classes and inverts the dependency so that
// a vector database can work with a SQLite database, simple in-memory storage,
// flat files, or whatever kinds of storage will work efficiently.
class VectorDatabase {
 public:
  struct UrlDataIterator {
    virtual ~UrlDataIterator() = default;

    // Returns nullptr if none remain; otherwise advances the iterator
    // and returns a pointer to the next instance (which may be owned
    // by the iterator itself).
    virtual const UrlData* Next() = 0;
  };

  virtual ~VectorDatabase() = default;

  // Returns the expected number of dimensions for an embedding.
  virtual size_t GetEmbeddingDimensions() const = 0;

  // Inserts or updates all embeddings for a URL's full set of passages.
  // Returns true on success.
  virtual bool AddUrlData(UrlData url_data) = 0;

  // Creates an iterator that steps through database items.
  // nullptr may be returned if there are none.
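  //
  // Example (sketch; `database` stands for any hypothetical VectorDatabase
  // instance). Since Next() may return storage owned by the iterator, this
  // assumes a returned pointer should not be kept past the next call:
  //   std::unique_ptr<VectorDatabase::UrlDataIterator> iterator =
  //       database.MakeUrlDataIterator(/*time_range_start=*/std::nullopt);
  //   if (iterator) {
  //     while (const UrlData* url_data = iterator->Next()) {
  //       // Copy out anything from `*url_data` that must outlive this step.
  //     }
  //   }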
  virtual std::unique_ptr<UrlDataIterator> MakeUrlDataIterator(
      std::optional<base::Time> time_range_start) = 0;
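
  // Searches the database for embeddings near the given `query_embedding` and
  // returns information about where they were found and how closely the
  // query matched. If `is_search_halted` returns true, the search stops early
  // and the returned `SearchInfo::completed` is false.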
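  //
  // Example (sketch; `database`, `query_embedding`, and the 30-day window are
  // hypothetical):
  //   SearchParams params;
  //   params.query_terms = {"rust", "async"};
  //   SearchInfo info = database.FindNearest(
  //       /*time_range_start=*/base::Time::Now() - base::Days(30),
  //       /*count=*/3, params, query_embedding,
  //       base::BindRepeating([] { return false; }));  // Never halts.
  //   if (info.completed) {
  //     // `info.scored_urls` holds the best URLs by total score.
  //   }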
  SearchInfo FindNearest(std::optional<base::Time> time_range_start,
                         size_t count,
                         const SearchParams& search_params,
                         const passage_embeddings::Embedding& query_embedding,
                         base::RepeatingCallback<bool()> is_search_halted);
};

// This is an in-memory vector store that supports searching and saving to
// another persistent backing store.
class VectorDatabaseInMemory : public VectorDatabase {
 public:
  VectorDatabaseInMemory();
  ~VectorDatabaseInMemory() override;

  // Saves this store's data to another given store. Most implementations
  // don't need this, but it's useful for an in-memory store to work with a
  // separate backing database on a worker sequence.
  void SaveTo(VectorDatabase* database);

  // VectorDatabase:
  size_t GetEmbeddingDimensions() const override;
  bool AddUrlData(UrlData url_data) override;
  std::unique_ptr<UrlDataIterator> MakeUrlDataIterator(
      std::optional<base::Time> time_range_start) override;

 private:
  std::vector<UrlData> data_;
};

// Utility function to split a query into separate query terms for search.
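//
// Example (sketch): stop words are presumably supplied as HashString()
// values, since both are uint32_t hashes. Assuming whitespace-separated
// terms and that `min_term_length` drops shorter terms, this call would be
// expected to yield {"rust", "book"}; the exact rules live in the
// implementation.
//   std::unordered_set<uint32_t> stop_words_hashes = {HashString("the")};
//   std::vector<std::string> terms = SplitQueryToTerms(
//       stop_words_hashes, "the rust book", /*min_term_length=*/3);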
std::vector<std::string> SplitQueryToTerms(
    const std::unordered_set<uint32_t>& stop_words_hashes,
    std::string_view raw_query,
    size_t min_term_length);

// Destructively removes non-ASCII characters from single or many passages.
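//
// Example (sketch, assuming a UTF-8 execution character set):
//   std::string passage = "caf\u00e9 au lait";  // "cafe" with an acute e.
//   EraseNonAsciiCharacters(passage);  // `passage` is now "caf au lait".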
void EraseNonAsciiCharacters(std::string& passage);
void EraseNonAsciiCharacters(std::vector<std::string>& passages);

}  // namespace history_embeddings

#endif  // COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_