File: on_device_head_provider.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (116 lines) | stat: -rw-r--r-- 5,201 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_PROVIDER_H_
#define COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_PROVIDER_H_

#include <memory>

#include "base/callback_list.h"
#include "base/memory/raw_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "components/omnibox/browser/autocomplete_enums.h"
#include "components/omnibox/browser/autocomplete_provider.h"
#include "components/omnibox/browser/autocomplete_provider_client.h"
#include "components/omnibox/browser/on_device_head_model.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"

// TODO(crbug.com/40241602): clean up this build flag guard later if possible.
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "components/omnibox/browser/on_device_tail_model_executor.h"
#endif

class AutocompleteProviderListener;

// An asynchronous autocomplete provider which receives an input string and
// tries to find matches for it in an on-device head model. This provider is
// designed to help users get suggestions when they have poor network
// connectivity.
// By default, all matches provided by this provider will have a relevance no
// greater than 99, so that its matches will not show before those of any other
// provider; however, the relevance can be changed to any arbitrary value by
// Finch when the input is not classified as a URL.
// TODO(crbug.com/40241602): rename this provider to "OnDeviceProvider" since it
// will serve both head and tail suggestions.
class OnDeviceHeadProvider : public AutocompleteProvider {
 public:
  // Factory function. The constructor is private, so non-friend code obtains
  // instances through this function.
  // NOTE(review): the ownership convention for the returned raw pointer is not
  // visible in this header - confirm against AutocompleteProvider.
  static OnDeviceHeadProvider* Create(AutocompleteProviderClient* client,
                                      AutocompleteProviderListener* listener);

  // AutocompleteProvider overrides.
  void Start(const AutocompleteInput& input, bool minimal_changes) override;
  void Stop(AutocompleteStopReason stop_reason) override;
  void AddProviderInfo(ProvidersInfo* provider_info) const override;

  // Returns the AutocompleteProviderClient held in |client_|. Const-qualified
  // because it only reads the member, so it is also callable on a const
  // provider.
  AutocompleteProviderClient* client() const { return client_; }

 private:
  friend class OnDeviceHeadProviderTest;

  // Stores the Autocomplete input and the suggestions fetched from the on
  // device head model for a single search request to the model. Only declared
  // here; the definition lives outside this header.
  struct OnDeviceHeadProviderParams;

  OnDeviceHeadProvider(AutocompleteProviderClient* client,
                       AutocompleteProviderListener* listener);
  ~OnDeviceHeadProvider() override;
  OnDeviceHeadProvider(const OnDeviceHeadProvider&) = delete;
  OnDeviceHeadProvider& operator=(const OnDeviceHeadProvider&) = delete;

  // Returns whether this provider is allowed to run for the given |input|.
  bool IsOnDeviceHeadProviderAllowed(const AutocompleteInput& input);

  // Helper functions used for the asynchronous search of the on device models.
  // The Autocomplete input and the output from the model are passed from
  // DoSearch to AllSearchDone via the OnDeviceHeadProviderParams object.
  //
  // DoSearch: searches the on device models and returns the top suggestions
  // matching the given AutocompleteInput.
  void DoSearch(std::unique_ptr<OnDeviceHeadProviderParams> params);
  // AllSearchDone: called after all searches are completed; fills |matches_|
  // with the suggestions fetched by DoSearch and then calls NotifyListeners.
  void AllSearchDone(std::unique_ptr<OnDeviceHeadProviderParams> params);
  // Helper function to be called when the search of the head model is done.
  void HeadModelSearchDone(std::unique_ptr<OnDeviceHeadProviderParams> params);

#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  // Helper function to be called when the search of the tail model is done.
  // Only declared when the build includes the TFLite library.
  void TailModelSearchDone(
      std::unique_ptr<OnDeviceHeadProviderParams> params,
      std::vector<OnDeviceTailModelExecutor::Prediction> predictions);
#endif

  // Helper function to read the head model filename from the static
  // OnDeviceModelUpdateListener instance.
  std::string GetOnDeviceHeadModelFilename() const;

  // Fetches suggestions matching |params| from the on device head model
  // identified by |model_filename|. Takes ownership of |params| and hands a
  // params object back to the caller through the return value.
  // NOTE(review): |provider_max_matches| presumably caps the number of
  // suggestions fetched - confirm against the implementation.
  static std::unique_ptr<OnDeviceHeadProviderParams>
  GetSuggestionsFromHeadModel(
      const std::string& model_filename,
      const size_t provider_max_matches,
      std::unique_ptr<OnDeviceHeadProviderParams> params);

  // Determines whether tail suggestions should be fetched for the given
  // |params| and |locale|.
  static bool ShouldFetchTailSuggestions(
      const OnDeviceHeadProviderParams& params,
      const std::string& locale);

  // Non-owning pointer to the client; returned by client().
  raw_ptr<AutocompleteProviderClient> client_;

  // The task runner dedicated to on device model operations, which is added to
  // offload expensive operations from the UI sequence.
  scoped_refptr<base::SequencedTaskRunner> worker_task_runner_;

  // Sequence checker that ensures autocomplete request handling only happens
  // on the main thread.
  SEQUENCE_CHECKER(main_sequence_checker_);

  // The request id used to trace the current request to the on device models.
  // The id is increased whenever a new request is received from the
  // AutocompleteController. Zero-initialized here so that the id never holds
  // an indeterminate value, whatever the constructor's initializer list does.
  size_t on_device_search_request_id_ = 0;

  // Kept as the last data member of the class.
  base::WeakPtrFactory<OnDeviceHeadProvider> weak_ptr_factory_{this};
};

#endif  // COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_HEAD_PROVIDER_H_