File: on_device_tail_model_executor.h

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_EXECUTOR_H_
#define COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_EXECUTOR_H_

#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "base/containers/lru_cache.h"
#include "base/files/file_path.h"
#include "base/files/memory_mapped_file.h"
#include "base/memory/raw_ptr.h"
#include "base/time/time.h"
#include "components/omnibox/browser/on_device_tail_tokenizer.h"
#include "components/optimization_guide/proto/on_device_tail_suggest_model_metadata.pb.h"
#include "third_party/tflite/src/tensorflow/lite/interpreter.h"
#include "third_party/tflite/src/tensorflow/lite/signature_runner.h"

// The on-device tail model executor implements a beam search algorithm
// (https://en.wikipedia.org/wiki/Beam_search) to generate complete suggestions
// for the given prefix.
// At each search step, the executor feeds the token and cell states from the
// previous step into the model to generate the predictions for the next token.
// TODO(crbug.com/40241602): migrate to optimization_guide::TFLiteModelExecutor
// once it supports multi-subgraph model.
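//
// A minimal usage sketch (the paths, metadata, and parameter values below are
// illustrative assumptions rather than values from production code):
//
//   OnDeviceTailModelExecutor executor;
//   if (executor.Init(model_filepath, additional_files, metadata)) {
//     OnDeviceTailModelExecutor::ModelInput input(
//         /*prefix=*/"why is the sky", /*previous_query=*/"",
//         /*max_num_suggestions=*/5);
//     std::vector<OnDeviceTailModelExecutor::Prediction> predictions =
//         executor.GenerateSuggestionsForPrefix(input);
//   }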
class OnDeviceTailModelExecutor {
 public:
  // The struct holds the prediction made by the model and its probability.
  struct Prediction {
    std::string suggestion;
    float probability;
  };

  // The struct holds the input parameters needed to generate predictions from
  // the model.
  struct ModelInput {
    ModelInput();
    ModelInput(std::string prefix,
               std::string previous_query,
               size_t max_num_suggestions);

    std::string prefix;
    std::string previous_query;
    size_t max_num_suggestions;
  };

  using ModelMetadata =
      optimization_guide::proto::OnDeviceTailSuggestModelMetadata;

  OnDeviceTailModelExecutor();
  ~OnDeviceTailModelExecutor();

  // Initializes the model executor.
  bool Init();
  bool Init(const base::FilePath& model_filepath,
            const base::flat_set<base::FilePath>& additional_files,
            const ModelMetadata& metadata);

  // Returns whether the executor is initialized.
  bool IsReady() const { return interpreter_ != nullptr; }

  // Resets the model executor.
  void Reset();

  // Returns at most `input.max_num_suggestions` suggestions and their
  // probabilities for the given `input.prefix` and `input.previous_query`.
  // Candidates whose accumulated log probability falls below
  // `log_probability_threshold_` are dropped, and the given prefix is extended
  // by at most `max_num_steps_` RNN steps.
  std::vector<Prediction> GenerateSuggestionsForPrefix(const ModelInput& input);

  // Returns the time when the executor was last called.
  base::TimeTicks GetExecutorLastCalledTime() const {
    return executor_last_called_time_;
  }

 private:
  friend class OnDeviceTailModelExecutorPublic;

  struct RnnCellStates {
    RnnCellStates();
    RnnCellStates(size_t num_layer, size_t state_size);
    RnnCellStates(const RnnCellStates& other);
    RnnCellStates(RnnCellStates&& other) noexcept;
    RnnCellStates& operator=(const RnnCellStates& other);
    RnnCellStates& operator=(RnnCellStates&& other) noexcept;
    ~RnnCellStates();

    friend bool operator==(const RnnCellStates&,
                           const RnnCellStates&) = default;

    // Cell states, see definitions at
    // https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_cell.py#L221.
    std::vector<std::vector<float>> c_i;
    std::vector<std::vector<float>> m_i;
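
    // For illustration (assumed sizes, not taken from an actual model): with
    // num_layer == 2 and state_size == 512, `c_i` and `m_i` would each hold
    // two vectors of 512 floats, one per layer.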
  };

  // The struct which holds the output from subgraph `rnn_step_`.
  struct RnnStepOutput {
    RnnStepOutput();
    RnnStepOutput(size_t num_layer, size_t state_size, size_t vocab_size);
    RnnStepOutput(const RnnStepOutput& other);
    ~RnnStepOutput();

    bool operator==(const RnnStepOutput& other) const {
      return states == other.states && probs == other.probs;
    }

    bool operator!=(const RnnStepOutput& other) const {
      return !(*this == other);
    }

    // The output RNN cell states.
    RnnCellStates states;

    // The probability vector; `probs[i]` corresponds to the probability of the
    // i-th token in the vocabulary.
    std::vector<float> probs;
  };

  // The node struct which holds all information needed to run the beam search.
  struct BeamNode {
    BeamNode();
    BeamNode(int num_layer, int state_size);
    BeamNode(const BeamNode& other);
    BeamNode(BeamNode&& other) noexcept;
    BeamNode& operator=(const BeamNode& other);
    BeamNode& operator=(BeamNode&& other) noexcept;
    ~BeamNode();

    bool operator>(const BeamNode& other) const {
      return this->log_prob > other.log_prob;
    }

    // The suggestion token IDs which the beam node is representing.
    OnDeviceTailTokenizer::TokenIds token_ids;

    // The cache key for `rnn_step_cache_` which is the vector of the previous
    // query token IDs plus suggestion token IDs.
    OnDeviceTailTokenizer::TokenIds rnn_step_cache_key;

    // The prefix which has to be matched in the next expansion.
    std::string constraint_prefix;

    // The output RNN cell states from the last `rnn_step_` invocation.
    RnnCellStates states;

    // The accumulated log probability for the node.
    float log_prob = 0.0;
  };

  // A min priority queue to store beam nodes such that we can conveniently
  // discard nodes with low probability when there are too many candidates.
  using CandidateQueue =
      std::priority_queue<BeamNode, std::vector<BeamNode>, std::greater<>>;
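
  // A minimal sketch of how this min-queue keeps only the most probable
  // candidates (the capacity check below is purely illustrative, not the
  // executor's actual insertion logic):
  //
  //   CandidateQueue queue;
  //   queue.push(candidate);
  //   if (queue.size() > max_num_suggestions) {
  //     queue.pop();  // Discards the node with the lowest log probability.
  //   }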

  using TokenIdAndProb = std::pair<OnDeviceTailTokenizer::TokenId, float>;

  // Helper function to initialize the TFLite model interpreter.
  bool InitModelInterpreter(const base::FilePath& model_filepath);
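
  // With the TFLite C++ API, initialization from the memory-mapped model file
  // could roughly look like the sketch below; the signature names
  // "prev_query_encoder" and "rnn_step" are assumptions for illustration:
  //
  //   auto model = tflite::FlatBufferModel::BuildFromBuffer(
  //       reinterpret_cast<const char*>(model_fb_->data()), model_fb_->length());
  //   tflite::ops::builtin::BuiltinOpResolver resolver;
  //   tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
  //   prev_query_encoder_ = interpreter_->GetSignatureRunner("prev_query_encoder");
  //   rnn_step_ = interpreter_->GetSignatureRunner("rnn_step");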

  // Gets the encoding for previous query token IDs.
  bool EncodePreviousQuery(
      const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
      std::vector<float>* prev_query_encoding);

  // Invokes subgraph `rnn_step_` to get the prediction for the next token.
  bool RunRnnStep(const OnDeviceTailTokenizer::TokenIds& rnn_step_cache_key,
                  const OnDeviceTailTokenizer::TokenId& input_id,
                  const std::vector<float>& prev_query_encoding,
                  const RnnCellStates& previous_states,
                  RnnStepOutput* rnn_step_output);
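
  // Invoking a subgraph through tflite::SignatureRunner roughly follows this
  // pattern (the tensor names used here are placeholders, not necessarily the
  // model's real input/output names):
  //
  //   rnn_step_->AllocateTensors();
  //   TfLiteTensor* input = rnn_step_->input_tensor("input_id");
  //   input->data.i32[0] = input_id;
  //   rnn_step_->Invoke();
  //   const TfLiteTensor* probs = rnn_step_->output_tensor("probs");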

  // Creates new beams from the current beam and the RNN step output, and pushes
  // them into related candidate queues.
  void CreateNewBeams(const RnnStepOutput& rnn_step_output,
                      const BeamNode& current_beam,
                      size_t max_num_suggestions,
                      float log_prob_threshold,
                      CandidateQueue* partial_candidates,
                      CandidateQueue* completed_candidates);

  // Builds a new beam node from the given token ID & probability pair, maybe
  // inserts it into the candidate queue, and drops low probability nodes from
  // the queue if needed.
  void InsertBeamNodeToCandidateQueue(const TokenIdAndProb& token_id_and_prob,
                                      const RnnCellStates& states,
                                      const BeamNode& current_beam,
                                      float log_prob_threshold,
                                      size_t max_num_suggestions,
                                      CandidateQueue* queue);

  // Gets the root beam node by feeding all unambiguous token IDs (except the
  // last token) into the model.
  bool GetRootBeamNode(
      const OnDeviceTailTokenizer::Tokenization& input_tokenization,
      const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
      std::vector<float>* prev_query_encoding,
      BeamNode* root_beam);

  // Resets LRU caches.
  void ResetCaches();

  // Helper to calculate log probability.
  static float GetLogProbability(float probability);
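
  // For example, assuming the natural logarithm is used, a beam built from two
  // tokens with probabilities 0.5 and 0.2 accumulates a log_prob of roughly
  // log(0.5) + log(0.2) ≈ -0.69 + -1.61 = -2.30, which is then compared
  // against `log_probability_threshold_`.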

  // Loads bad suggestion filter lists from filepaths.
  void LoadBadSubstringSet();
  void LoadBadwordHashSet();

  // Determines whether the given suggestion is bad and should be discarded, by
  // checking if the suggestion contains words specified by `badword_hashes_`.
  // Note that this function currently might not handle CJK languages properly,
  // as it uses whitespace to split the suggestion into words.
  // We use this on-device filter because the model is an ML model and we do
  // not have a good way to force it to drop a given result under all
  // circumstances during training.
  bool IsSuggestionBad(const std::string& suggestion);
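
  // A simplified sketch of the whitespace-based check described above (the
  // exact splitting and filtering logic in the implementation may differ):
  //
  //   for (const std::string& word :
  //        base::SplitString(suggestion, " ", base::TRIM_WHITESPACE,
  //                          base::SPLIT_WANT_NONEMPTY)) {
  //     if (badword_hashes_.contains(base::PersistentHash(word))) {
  //       return true;
  //     }
  //   }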

  // The tokenizer and the TensorFlow Lite model & interpreter instances.
  std::unique_ptr<OnDeviceTailTokenizer> tokenizer_;
  std::unique_ptr<base::MemoryMappedFile> model_fb_;
  std::unique_ptr<tflite::Interpreter> interpreter_;

  // The pointers to subgraphs in the model.
  raw_ptr<tflite::SignatureRunner> prev_query_encoder_;
  raw_ptr<tflite::SignatureRunner> rnn_step_;

  // We use LRU caches to keep track of most recent outputs of subgraphs, such
  // that we will not need to run the interpreter if a cache hit is found for a
  // specific input.
  base::LRUCache<OnDeviceTailTokenizer::TokenIds, std::vector<float>>
      prev_query_cache_;
  base::LRUCache<OnDeviceTailTokenizer::TokenIds, RnnStepOutput>
      rnn_step_cache_;
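
  // Lookups typically follow base::LRUCache's Get()/Put() pattern; a minimal
  // sketch for the RNN step cache:
  //
  //   auto it = rnn_step_cache_.Get(rnn_step_cache_key);
  //   if (it != rnn_step_cache_.end()) {
  //     *rnn_step_output = it->second;  // Cache hit: reuse the stored output.
  //   } else {
  //     // Cache miss: run the interpreter, then Put() the new output.
  //   }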

  // Parameters needed to run the executor.
  size_t state_size_;
  size_t num_layer_;
  size_t embedding_dimension_;
  size_t vocab_size_;
  size_t max_num_steps_;
  float log_probability_threshold_;

  // The time when the executor was last called.
  base::TimeTicks executor_last_called_time_;

  // Files and metadata needed to initialize the model executor.
  base::FilePath model_filepath_;
  base::FilePath vocab_filepath_;
  base::FilePath badword_hashes_filepath_;
  base::FilePath bad_substrings_filepath_;
  optimization_guide::proto::OnDeviceTailSuggestModelMetadata metadata_;

  // The hashes (calculated by base::PersistentHash) of bad words and the
  // BASE64-encoded bad substrings, both used to filter bad suggestions.
  std::set<uint32_t> badword_hashes_;
  std::set<std::string> bad_substrings_;
};

#endif  // COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_EXECUTOR_H_