File: history_embeddings_service.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (472 lines) | stat: -rw-r--r-- 21,021 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_HISTORY_EMBEDDINGS_HISTORY_EMBEDDINGS_SERVICE_H_
#define COMPONENTS_HISTORY_EMBEDDINGS_HISTORY_EMBEDDINGS_SERVICE_H_

#include <atomic>
#include <optional>
#include <string>
#include <unordered_set>
#include <vector>

#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/gtest_prod_util.h"
#include "base/memory/weak_ptr.h"
#include "base/threading/sequence_bound.h"
#include "base/time/time.h"
#include "base/timer/elapsed_timer.h"
#include "components/history/core/browser/history_service.h"
#include "components/history/core/browser/history_service_observer.h"
#include "components/history/core/browser/history_types.h"
#include "components/history/core/browser/url_database.h"
#include "components/history/core/browser/url_row.h"
#include "components/history_embeddings/answerer.h"
#include "components/history_embeddings/intent_classifier.h"
#include "components/history_embeddings/sql_database.h"
#include "components/history_embeddings/vector_database.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/optimization_guide/core/model_quality/model_quality_log_entry.h"
#include "components/optimization_guide/proto/features/common_quality_data.pb.h"
#include "components/os_crypt/async/common/encryptor.h"
#include "components/passage_embeddings/passage_embeddings_types.h"

// Forward declarations for types that this header only mentions in
// declarations (pointer parameters, `raw_ptr` members, and a vector taken by
// const reference), so no additional includes are pulled in for them here.
namespace optimization_guide {
class OptimizationGuideDecider;
}  // namespace optimization_guide

namespace page_content_annotations {
class BatchAnnotationResult;
class PageContentAnnotationsService;
}  // namespace page_content_annotations

namespace os_crypt_async {
class OSCryptAsync;
}  // namespace os_crypt_async

namespace history_embeddings {

// Counts the # of ' ' vanilla-space characters in `s`. Only the plain space
// character is recognized; no other whitespace or word-breaking rule applies.
// NOTE(review): the definition is not visible from this header, so whether the
// returned value is the raw space count or an adjusted word count should be
// confirmed against the implementation.
// TODO(crbug.com/343256907): Should work on international inputs which may:
//   a) Use special whitespace, OR
//   b) Not use whitespace for word breaks (e.g. Thai).
//   `String16VectorFromString16()` is the omnibox solution. We could probably
//   just replace-all `CountWords(s)` ->
//   `String16VectorFromString16(CleanUpTitleForMatching(s, nullptr)).size()`.
size_t CountWords(const std::string& s);

// One entry of a search result. Pairs the metadata held by the history
// embeddings database with extra information taken from the history database.
struct ScoredUrlRow {
  explicit ScoredUrlRow(ScoredUrl scored_url);
  ~ScoredUrlRow();

  // Copyable.
  ScoredUrlRow(const ScoredUrlRow&);
  ScoredUrlRow& operator=(const ScoredUrlRow&);

  // Movable.
  ScoredUrlRow(ScoredUrlRow&&);
  ScoredUrlRow& operator=(ScoredUrlRow&&);

  // Picks out of `passages_embeddings` the passage with the highest score and
  // returns it.
  std::string GetBestPassage() const;

  // Returns indices of the top scores, sorted from highest to lowest score.
  // Intended for choosing which part of `passages_embeddings` to hand to the
  // answerer as context. Two lower bounds drive how many indices come back,
  // both subject to enough data being available:
  //   - `min_count`: the returned vector holds at least this many indices.
  //   - `min_word_count`: the passages behind the returned indices have word
  //     counts that add up to at least this value.
  std::vector<size_t> GetBestScoreIndices(size_t min_count,
                                          size_t min_word_count) const;

  // Core scoring information for this URL.
  ScoredUrl scored_url;

  // History data for this URL.
  history::URLRow row;
  bool is_url_known_to_sync{false};

  // The complete (never partial) set of passages and embeddings for this URL.
  UrlData passages_embeddings;

  // Scores of everything in `passages_embeddings` against the query.
  std::vector<float> scores;
};

// Outcome of a history embeddings search: the request context that produced
// it, the matching rows, and (once available) the generated answer.
struct SearchResult {
  SearchResult();
  ~SearchResult();

  // Move-only; copying has to go through `Clone()`.
  SearchResult(SearchResult&&);
  SearchResult& operator=(SearchResult&&);

  // The only way to copy a result, kept explicit because `answerer_result`
  // carries a log entry. Call this only while `answerer_result` has not yet
  // been given a log entry, e.g. between the initial search and answering.
  SearchResult Clone();

  // True when this result and `other` (a result handed back by
  // HistoryEmbeddingsService::Search) are related, i.e. they share the same
  // session/query.
  bool IsContinuationOf(const SearchResult& other);

  // Accessor for the answer text stored inside `answerer_result`.
  const std::string& AnswerText() const;

  // Locates, within `scored_url_rows`, the entry whose URL `answerer_result`
  // selected, which tells where the answer came from. Returns its index.
  size_t AnswerIndex() const;

  // Identifier tying a query to its answers.
  std::string session_id;

  // The parameters the search was requested with, retained so that logging is
  // easier.
  std::string query;
  std::optional<base::Time> time_range_start;
  size_t count{0};
  SearchParams search_params;

  // The search result data proper. Its size is not guaranteed to equal the
  // requested `count` above.
  std::vector<ScoredUrlRow> scored_url_rows;

  // May be empty in the initial embeddings search result because the answer is
  // not ready at that point. After the answerer is done, a second search
  // result is delivered with this field populated.
  AnswererResult answerer_result;
};

// One-shot callback receiving the stored data for a URL, or nullopt when the
// HistoryEmbeddings database has none. Taken by
// `HistoryEmbeddingsService::GetUrlData()`.
using UrlDataCallback = base::OnceCallback<void(std::optional<UrlData>)>;

// Repeating callback receiving a `UrlData`. Installed through
// `HistoryEmbeddingsService::SetPassagesStoredCallbackForTesting()`.
using PassagesStoredCallback = base::RepeatingCallback<void(UrlData)>;

// Repeating callback receiving a `SearchResult`. It is repeating because
// `HistoryEmbeddingsService::Search()` may run it a second time with a result
// that contains an answer.
using SearchResultCallback = base::RepeatingCallback<void(SearchResult)>;

// Owning handle to a model quality log entry. May be null: the default
// `HistoryEmbeddingsService::PrepareQualityLogEntry()` is documented to return
// null.
using QualityLogEntry =
    std::unique_ptr<optimization_guide::ModelQualityLogEntry>;

// Keyed service that computes and stores passage embeddings for history URLs
// and answers searches against them. It observes the history service for
// deletions and the embedder metadata provider for metadata updates. Database
// work is delegated to a `Storage` instance bound to a separate sequence.
class HistoryEmbeddingsService
    : public KeyedService,
      public history::HistoryServiceObserver,
      public passage_embeddings::EmbedderMetadataObserver {
 public:
  // Number of low-order bits to use in session_id for sequence number.
  static constexpr uint64_t kSessionIdSequenceBits = 16;
  // Mask selecting the sequence-number bits of a session_id. The shift is done
  // on a `uint64_t` operand (not the `int` literal `1`) so that the arithmetic
  // matches the declared type regardless of the bit count chosen above.
  static constexpr uint64_t kSessionIdSequenceBitMask =
      (uint64_t{1} << kSessionIdSequenceBits) - 1;

  // `history_service` is never nullptr and must outlive `this`.
  // Storage uses its `history_dir()` location for the database.
  HistoryEmbeddingsService(
      os_crypt_async::OSCryptAsync* os_crypt_async,
      history::HistoryService* history_service,
      page_content_annotations::PageContentAnnotationsService*
          page_content_annotations_service,
      optimization_guide::OptimizationGuideDecider* optimization_guide_decider,
      passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
      passage_embeddings::Embedder* embedder,
      std::unique_ptr<Answerer> answerer,
      std::unique_ptr<IntentClassifier> intent_classifier);
  HistoryEmbeddingsService(const HistoryEmbeddingsService&) = delete;
  HistoryEmbeddingsService& operator=(const HistoryEmbeddingsService&) = delete;
  ~HistoryEmbeddingsService() override;

  // Identify if the given URL is eligible for history embeddings.
  bool IsEligible(const GURL& url);

  // Called by `HistoryEmbeddingsTabHelper` when passage extraction completes.
  // Retrieves existing passages and embeddings for `url_id` from the database
  // before calling
  // `ComputeAndStorePassageEmbeddingsWithExistingData()`.
  void ComputeAndStorePassageEmbeddings(history::URLID url_id,
                                        history::VisitID visit_id,
                                        base::Time visit_time,
                                        std::vector<std::string> passages);

  // Finds the top `count` URL visit info entries nearest to `query`. Passes the
  // results to `callback` when search completes, whether successfully or not.
  // Search will be narrowed to a time range if `time_range_start` is provided.
  // In that case, the start of the time range is inclusive and the end is
  // unbounded. Practically, this can be thought of as [start, now) but now
  // isn't fixed.
  // The `callback` may be called a second time with another search result
  // containing an answer, only if `skip_answering` is false and an answer is
  // successfully generated. This two-phase result callback scheme lets callers
  // receive initial search results without having to wait longer for answers.
  // The `previous_search_result` may be nullptr to signal the beginning of a
  // completely new search session; if it is non-null and the session_id is set,
  // the new session_id is set based on the previous to indicate a continuing
  // search session.
  // Returns a stub result that can be used to detect if a later published
  // SearchResult instance is related to this search.
  // Virtual for testing.
  virtual SearchResult Search(SearchResult* previous_search_result,
                              std::string query,
                              std::optional<base::Time> time_range_start,
                              size_t count,
                              bool skip_answering,
                              SearchResultCallback callback);

  // Weak `this` provider method.
  base::WeakPtr<HistoryEmbeddingsService> AsWeakPtr();

  // Submit quality logging data after user selects an item from search result.
  // Note, the `result` contains a log entry that will be consumed by this call.
  void SendQualityLog(SearchResult& result,
                      std::set<size_t> selections,
                      size_t num_entered_characters,
                      optimization_guide::proto::UserFeedback user_feedback,
                      optimization_guide::proto::UiSurface ui_surface);

  // KeyedService:
  void Shutdown() override;

  // history::HistoryServiceObserver:
  void OnHistoryDeletions(history::HistoryService* history_service,
                          const history::DeletionInfo& deletion_info) override;

  // This can be overridden to gate answer generation for some accounts.
  virtual bool IsAnswererUseAllowed() const;

  // Asynchronously gets passages and embeddings from storage for given
  // `url_id`. Calls `callback` with the data or nullopt if no data is found in
  // the HistoryEmbeddings database.
  void GetUrlData(history::URLID url_id, UrlDataCallback callback) const;

  // Asynchronously gets passages and embeddings from storage where visits
  // are within a given time range. Calls `callback` with the data.
  // The `limit` and `offset` can be used to control data range with
  // standard SQL style paging.
  void GetUrlDataInTimeRange(
      base::Time from_time,
      base::Time to_time,
      size_t limit,
      size_t offset,
      base::OnceCallback<void(std::vector<UrlData>)> callback) const;

  // Targeted deletion for testing scenarios like model version change.
  // `callback` is a completion closure.
  void DeleteDataForTesting(bool delete_passages,
                            bool delete_embeddings,
                            base::OnceClosure callback);

  // Set a callback to be called when `ProcessAndStorePassages` completes.
  void SetPassagesStoredCallbackForTesting(PassagesStoredCallback callback);

 private:
  friend class HistoryEmbeddingsServicePublic;

  // A utility container to wrap anything that should be accessed on
  // the separate storage worker sequence.
  struct Storage {
    Storage(const base::FilePath& storage_dir,
            bool erase_non_ascii_characters,
            bool delete_embeddings);

    // Associate the given metadata with this Storage instance. The storage is
    // not considered initialized until this metadata is supplied.
    void SetEmbedderMetadata(passage_embeddings::EmbedderMetadata metadata,
                             os_crypt_async::Encryptor encryptor);

    // Called on the worker sequence to persist passages and embeddings.
    void ProcessAndStorePassages(UrlData url_data);

    // Runs search on worker sequence. `weak_latest_query_id` and `query_id`
    // belong to the stale-query invalidation scheme described at `query_id_`.
    std::vector<ScoredUrlRow> Search(
        base::WeakPtr<std::atomic<size_t>> weak_latest_query_id,
        size_t query_id,
        SearchParams search_params,
        passage_embeddings::Embedding query_embedding,
        std::optional<base::Time> time_range_start,
        size_t count);

    // Handles the History deletions on the worker thread.
    void HandleHistoryDeletions(bool for_all_history,
                                history::URLRows deleted_rows,
                                std::set<history::VisitID> deleted_visit_ids);

    // Targeted deletion for testing scenarios like model version change.
    void DeleteDataForTesting(bool delete_passages, bool delete_embeddings);

    // Gathers URL and passage data from the database where corresponding
    // embeddings are absent. This is used to rebuild the embeddings table
    // when the model changes.
    std::vector<UrlData> CollectPassagesWithoutEmbeddings();

    // Retrieves passages and embeddings from the database for use as a cache
    // to avoid recomputing embeddings that exist for identical passages.
    std::optional<UrlData> GetUrlData(history::URLID url_id);

    // Retrieves passages and embeddings from the database that have visit times
    // within specified range.
    std::vector<UrlData> GetUrlDataInTimeRange(base::Time from_time,
                                               base::Time to_time,
                                               size_t limit,
                                               size_t offset);

    // A VectorDatabase implementation that holds data in memory.
    VectorDatabaseInMemory vector_database;

    // The underlying SQL database for persistent storage.
    SqlDatabase sql_database;
  };

  // passage_embeddings::EmbedderMetadataObserver:
  // Passes the metadata to the internal storage.
  void EmbedderMetadataUpdated(
      passage_embeddings::EmbedderMetadata metadata) override;

  // Receives the `encryptor`.
  // NOTE(review): presumably invoked once `os_crypt_async_` has finished
  // producing an encryptor; confirm against the implementation.
  void OnOsCryptAsyncReady(os_crypt_async::Encryptor encryptor);

  // This can be overridden to prepare a log entry that will then be filled
  // with data and sent on destruction. Default implementation returns null.
  virtual QualityLogEntry PrepareQualityLogEntry();

  // Called by `ComputeAndStorePassageEmbeddings()` after retrieving existing
  // passages and embeddings for `url_data.url_id` from the database.
  // `existing_url_data` may be nullopt if no existing data was found.
  void ComputeAndStorePassageEmbeddingsWithExistingData(
      UrlData url_data,
      std::vector<std::string> passages,
      std::optional<base::ElapsedTimer> database_access_timer,
      std::optional<UrlData> existing_url_data);

  // Invoked after the embeddings for `passages` have been computed. Stores the
  // passages along with their embeddings in the database.
  void OnPassagesEmbeddingsComputed(
      UrlData url_passages,
      std::vector<std::string> passages,
      std::vector<passage_embeddings::Embedding> embeddings,
      passage_embeddings::Embedder::TaskId task_id,
      passage_embeddings::ComputeEmbeddingsStatus status);

  // Invoked after the embedding for the original search query has been
  // computed.
  void OnQueryEmbeddingComputed(
      SearchResultCallback callback,
      SearchResult result,
      std::vector<std::string> query_passages,
      std::vector<passage_embeddings::Embedding> query_embedding,
      passage_embeddings::Embedder::TaskId task_id,
      passage_embeddings::ComputeEmbeddingsStatus status);

  // Finishes a search result by combining found data with additional data from
  // history database. Moves each ScoredUrl into a more complete structure with
  // a history URLRow. Omits any entries that don't have corresponding data in
  // the history database.
  void OnSearchCompleted(SearchResultCallback callback,
                         SearchResult result,
                         std::vector<ScoredUrlRow> scored_url_rows);

  // Calls `page_content_annotations_service_` to determine whether the passage
  // of each ScoredUrl should be shown to the user.
  void DeterminePassageVisibility(SearchResultCallback callback,
                                  SearchResult result,
                                  std::vector<ScoredUrlRow> scored_url_rows);

  // Called after `page_content_annotations_service_` has determined visibility
  // for the passage of each ScoredUrl. This will filter `scored_url_rows` to
  // only contain entries that can be shown to the user.
  void OnPassageVisibilityCalculated(
      SearchResultCallback callback,
      SearchResult result,
      std::vector<ScoredUrlRow> scored_url_rows,
      const std::vector<page_content_annotations::BatchAnnotationResult>&
          annotation_results);

  // Called on main sequence after the history worker thread finalizes
  // the initial search result with URL rows. Calls the `callback` and
  // then proceeds to intent check and v2 answer generation if needed.
  void OnPrimarySearchResultReady(SearchResultCallback callback,
                                  SearchResult result);

  // Invoked after the intent classifier computes query answerability.
  void OnQueryIntentComputed(SearchResultCallback callback,
                             SearchResult result,
                             ComputeIntentStatus status,
                             bool query_is_answerable);

  // Called after the answerer finishes computing an answer. Combines
  // the `answerer_result` into `search_result` and invokes `callback`
  // with new search result complete with answer.
  void OnAnswerComputed(base::Time start_time,
                        SearchResultCallback callback,
                        SearchResult search_result,
                        AnswererResult answerer_result);

  // Rebuild absent embeddings from source passages.
  void RebuildAbsentEmbeddings(std::vector<UrlData> all_url_passages);

  // Returns true if query should be filtered. If false, then `search_params`
  // will have its query_terms set.
  bool QueryIsFiltered(const std::string& raw_query,
                       SearchParams& search_params) const;

  // Not owned.
  // NOTE(review): presumably the source of the `Encryptor` delivered to
  // `OnOsCryptAsyncReady()`; confirm against the implementation.
  raw_ptr<os_crypt_async::OSCryptAsync> os_crypt_async_;

  // The history service is used to fill in details about URLs and visits
  // found via search. It strictly outlives this due to the dependency
  // specified in HistoryEmbeddingsServiceFactory.
  raw_ptr<history::HistoryService> history_service_;

  // The page content annotations service is used to determine whether the
  // content is safe. It strictly outlives this due to the dependency specified
  // in `HistoryEmbeddingsServiceFactory`. Can be nullptr if the underlying
  // capabilities are not supported.
  raw_ptr<page_content_annotations::PageContentAnnotationsService>
      page_content_annotations_service_;

  // Used to determine whether a page should be excluded from history
  // embeddings.
  raw_ptr<optimization_guide::OptimizationGuideDecider>
      optimization_guide_decider_;

  // Tracks the observed history service, for cleanup.
  base::ScopedObservation<history::HistoryService,
                          history::HistoryServiceObserver>
      history_service_observation_{this};

  // The embedder used to compute embeddings. Outlives this.
  raw_ptr<passage_embeddings::Embedder> embedder_;

  // The answerer used to answer queries with context. May be nullptr if
  // the kHistoryEmbeddingsAnswers feature is disabled.
  std::unique_ptr<Answerer> answerer_;

  // The intent classifier used to determine query intent and answerability.
  std::unique_ptr<IntentClassifier> intent_classifier_;

  // Metadata about the embedder; Set when valid metadata is received from
  // `embedder_metadata_provider`.
  passage_embeddings::EmbedderMetadata embedder_metadata_{0, 0};

  // Storage is bound to a separate sequence.
  // This will be null if the feature flag is disabled.
  base::SequenceBound<Storage> storage_;

  // Callback called when `ProcessAndStorePassages` completes. Needed for tests
  // as the blink dependency doesn't have a 'wait for pending requests to
  // complete' mechanism.
  PassagesStoredCallback passages_stored_callback_for_tests_ =
      base::DoNothing();

  // A thread-safe invalidation mechanism to halt searches for stale queries:
  // Each query is run with the current `query_id_` and a weak pointer to the
  // atomic value itself. When it changes, any queries other than the latest
  // can be halted. Note this is not task cancellation, it breaks the inner
  // search loop while running so the atomic is needed for thread safety.
  std::atomic<size_t> query_id_ = 0u;

  // Used to cancel the in-flight embedding task for the previous stale query.
  std::optional<passage_embeddings::Embedder::TaskId> query_embedding_task_id_;

  // Scoped observation for when the embedder metadata is available.
  base::ScopedObservation<passage_embeddings::EmbedderMetadataProvider,
                          passage_embeddings::EmbedderMetadataObserver>
      embedder_metadata_observation_{this};

  // Vends weak pointers to a `std::atomic<size_t>`.
  // NOTE(review): presumably bound to `query_id_` and used to produce the
  // `weak_latest_query_id` argument of `Storage::Search()`; confirm against
  // the constructor.
  base::WeakPtrFactory<std::atomic<size_t>> query_id_weak_ptr_factory_;

  // Vends weak pointers to this service.
  base::WeakPtrFactory<HistoryEmbeddingsService> weak_ptr_factory_;
};

}  // namespace history_embeddings

#endif  // COMPONENTS_HISTORY_EMBEDDINGS_HISTORY_EMBEDDINGS_SERVICE_H_