// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_EXECUTOR_H_
#define COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_EXECUTOR_H_

#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "base/containers/lru_cache.h"
#include "base/files/file_path.h"
#include "base/files/memory_mapped_file.h"
#include "base/memory/raw_ptr.h"
#include "base/time/time.h"
#include "components/omnibox/browser/on_device_tail_tokenizer.h"
#include "components/optimization_guide/proto/on_device_tail_suggest_model_metadata.pb.h"
#include "third_party/tflite/src/tensorflow/lite/interpreter.h"
#include "third_party/tflite/src/tensorflow/lite/signature_runner.h"
// The on device tail model executor implements a beam search algorithm
// (https://en.wikipedia.org/wiki/Beam_search) to generate complete suggestions
// for the given prefix.
// At each search step, the executor feeds the token and cell states from the
// previous step into the model to generate the predictions for the next token.
// TODO(crbug.com/40241602): migrate to optimization_guide::TFLiteModelExecutor
// once it supports multi-subgraph models.
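//
// A minimal usage sketch (illustrative only; the actual call sites, file
// paths, and metadata plumbing live elsewhere in the suggest code):
//
//   OnDeviceTailModelExecutor executor;
//   if (executor.Init(model_filepath, additional_files, metadata)) {
//     OnDeviceTailModelExecutor::ModelInput input(
//         "why is the sky", /*previous_query=*/"", /*max_num_suggestions=*/5);
//     std::vector<OnDeviceTailModelExecutor::Prediction> predictions =
//         executor.GenerateSuggestionsForPrefix(input);
//   }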
class OnDeviceTailModelExecutor {
public:
// The struct holds the prediction made by the model and its probability.
struct Prediction {
std::string suggestion;
float probability;
};
// The struct holds the input parameters needed to generate predictions from
// the model.
struct ModelInput {
ModelInput();
ModelInput(std::string prefix,
std::string previous_query,
size_t max_num_suggestions);
std::string prefix;
std::string previous_query;
size_t max_num_suggestions;
};
using ModelMetadata =
optimization_guide::proto::OnDeviceTailSuggestModelMetadata;
OnDeviceTailModelExecutor();
~OnDeviceTailModelExecutor();
// Initializes the model executor.
bool Init();
bool Init(const base::FilePath& model_filepath,
const base::flat_set<base::FilePath>& additional_files,
const ModelMetadata& metadata);
// Returns whether the executor is initialized.
bool IsReady() const { return interpreter_ != nullptr; }
// Resets the model executor.
void Reset();
// Returns at most `input.max_num_suggestions` suggestions and their
// probabilities for the given `input.prefix` and `input.previous_query`.
// Suggestions whose log probability falls below `log_probability_threshold_`
// are discarded, and the given prefix will only be extended at most
// `max_num_steps_` times.
std::vector<Prediction> GenerateSuggestionsForPrefix(const ModelInput& input);
// Returns the time when the executor was last called.
base::TimeTicks GetExecutorLastCalledTime() const {
return executor_last_called_time_;
}
private:
friend class OnDeviceTailModelExecutorPublic;
struct RnnCellStates {
RnnCellStates();
RnnCellStates(size_t num_layer, size_t state_size);
RnnCellStates(const RnnCellStates& other);
RnnCellStates(RnnCellStates&& other) noexcept;
RnnCellStates& operator=(const RnnCellStates& other);
RnnCellStates& operator=(RnnCellStates&& other) noexcept;
~RnnCellStates();
friend bool operator==(const RnnCellStates&,
const RnnCellStates&) = default;
// Cell states, see definitions at
// https://github.com/tensorflow/lingvo/blob/master/lingvo/core/rnn_cell.py#L221.
std::vector<std::vector<float>> c_i;
std::vector<std::vector<float>> m_i;
};
// The struct which holds the output from subgraph `rnn_step_`.
struct RnnStepOutput {
RnnStepOutput();
RnnStepOutput(size_t num_layer, size_t state_size, size_t vocab_size);
RnnStepOutput(const RnnStepOutput& other);
~RnnStepOutput();
bool operator==(const RnnStepOutput& other) const {
return states == other.states && probs == other.probs;
}
bool operator!=(const RnnStepOutput& other) const {
return !(*this == other);
}
// The output RNN cell states.
RnnCellStates states;
// The probability vector; `probs[i]` corresponds to the probability of the
// i-th token in the vocabulary.
std::vector<float> probs;
};
// The node struct which holds all information needed to run the beam search.
struct BeamNode {
BeamNode();
BeamNode(int num_layer, int state_size);
BeamNode(const BeamNode& other);
BeamNode(BeamNode&& other) noexcept;
BeamNode& operator=(const BeamNode& other);
BeamNode& operator=(BeamNode&& other) noexcept;
~BeamNode();
bool operator>(const BeamNode& other) const {
return this->log_prob > other.log_prob;
}
// The suggestion token IDs which the beam node is representing.
OnDeviceTailTokenizer::TokenIds token_ids;
// The cache key for `rnn_step_cache_` which is the vector of the previous
// query token IDs plus suggestion token IDs.
OnDeviceTailTokenizer::TokenIds rnn_step_cache_key;
// The prefix that must be matched in the next expansion.
std::string constraint_prefix;
// The output RNN cell states from the last `rnn_step_` invocation.
RnnCellStates states;
// The accumulated log probability for the node.
float log_prob = 0.0;
};
// A min priority queue to store beam nodes such that we can conveniently
// discard nodes with low probability when there are too many candidates.
using CandidateQueue =
std::priority_queue<BeamNode, std::vector<BeamNode>, std::greater<>>;
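//
// Since `std::greater<>` compares nodes with BeamNode::operator>, the node
// with the lowest `log_prob` sits at the top of the queue. A sketch of the
// intended top-k handling (the real logic lives in
// InsertBeamNodeToCandidateQueue):
//
//   queue->push(std::move(node));
//   if (queue->size() > max_num_suggestions)
//     queue->pop();  // Discards the least probable candidate.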
using TokenIdAndProb = std::pair<OnDeviceTailTokenizer::TokenId, float>;
// Helper function to initialize the TFLite model interpreter.
bool InitModelInterpreter(const base::FilePath& model_filepath);
// Gets the encoding for previous query token IDs.
bool EncodePreviousQuery(
const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
std::vector<float>* prev_query_encoding);
// Invokes subgraph `rnn_step_` to get the prediction for the next token.
bool RunRnnStep(const OnDeviceTailTokenizer::TokenIds& rnn_step_cache_key,
const OnDeviceTailTokenizer::TokenId& input_id,
const std::vector<float>& prev_query_encoding,
const RnnCellStates& previous_states,
RnnStepOutput* rnn_step_output);
// Creates new beams from the current beam and the RNN step output, and pushes
// them into related candidate queues.
void CreateNewBeams(const RnnStepOutput& rnn_step_output,
const BeamNode& current_beam,
size_t max_num_suggestions,
float log_prob_threshold,
CandidateQueue* partial_candidates,
CandidateQueue* completed_candidates);
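//
// A sketch of a single beam-search expansion as described in the class-level
// comment (illustrative only; the real loop lives in
// GenerateSuggestionsForPrefix):
//
//   RnnStepOutput output;
//   if (RunRnnStep(beam.rnn_step_cache_key, beam.token_ids.back(),
//                  prev_query_encoding, beam.states, &output)) {
//     CreateNewBeams(output, beam, input.max_num_suggestions,
//                    log_probability_threshold_, &partial_candidates,
//                    &completed_candidates);
//   }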
// Builds a new beam node from the given token ID & probability pair and maybe
// inserts it into the candidate queue, dropping low-probability nodes from
// the queue if needed.
void InsertBeamNodeToCandidateQueue(const TokenIdAndProb& token_id_and_prob,
const RnnCellStates& states,
const BeamNode& current_beam,
float log_prob_threshold,
size_t max_num_suggestions,
CandidateQueue* queue);
// Gets the root beam node by feeding all unambiguous token IDs (except the
// last token) into the model.
bool GetRootBeamNode(
const OnDeviceTailTokenizer::Tokenization& input_tokenization,
const OnDeviceTailTokenizer::TokenIds& prev_query_token_ids,
std::vector<float>* prev_query_encoding,
BeamNode* root_beam);
// Resets LRU caches.
void ResetCaches();
// Helper to calculate log probability.
static float GetLogProbability(float probability);
// Loads bad suggestion filter lists from filepaths.
void LoadBadSubstringSet();
void LoadBadwordHashSet();
// Determines whether the given suggestion is bad and should be discarded, by
// checking whether the suggestion contains words specified by
// `badword_hashes_`.
// Note that this function currently might not handle CJK languages properly,
// as it uses whitespace to split the suggestion into words.
// We use this on-device filter since this is an ML model and we do not have a
// good way to force the model to drop a given result under all circumstances
// during training.
bool IsSuggestionBad(const std::string& suggestion);
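//
// A minimal sketch of the word-hash check described above (illustrative
// only; the real implementation may also consult `bad_substrings_` and
// differ in details):
//
//   for (const std::string& word :
//        base::SplitString(suggestion, " ", base::TRIM_WHITESPACE,
//                          base::SPLIT_WANT_NONEMPTY)) {
//     if (badword_hashes_.count(base::PersistentHash(word)) > 0)
//       return true;  // A bad word was found; discard the suggestion.
//   }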
// The tokenizer and the TensorFlow Lite model & interpreter instances.
std::unique_ptr<OnDeviceTailTokenizer> tokenizer_;
std::unique_ptr<base::MemoryMappedFile> model_fb_;
std::unique_ptr<tflite::Interpreter> interpreter_;
// The pointers to subgraphs in the model.
raw_ptr<tflite::SignatureRunner> prev_query_encoder_;
raw_ptr<tflite::SignatureRunner> rnn_step_;
// We use LRU caches to keep track of the most recent outputs of the
// subgraphs, so that we do not need to run the interpreter when a cache hit
// is found for a given input.
base::LRUCache<OnDeviceTailTokenizer::TokenIds, std::vector<float>>
prev_query_cache_;
base::LRUCache<OnDeviceTailTokenizer::TokenIds, RnnStepOutput>
rnn_step_cache_;
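//
// A sketch of the intended cache lookup (illustrative only; the real cache
// handling lives in EncodePreviousQuery and RunRnnStep):
//
//   auto it = rnn_step_cache_.Get(rnn_step_cache_key);
//   if (it != rnn_step_cache_.end()) {
//     *rnn_step_output = it->second;  // Cache hit: skip the interpreter.
//     return true;
//   }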
// Parameters needed to run the executor.
size_t state_size_;
size_t num_layer_;
size_t embedding_dimension_;
size_t vocab_size_;
size_t max_num_steps_;
float log_probability_threshold_;
// The time when the executor was last called.
base::TimeTicks executor_last_called_time_;
// Files and metadata needed to initialize the model executor.
base::FilePath model_filepath_;
base::FilePath vocab_filepath_;
base::FilePath badword_hashes_filepath_;
base::FilePath bad_substrings_filepath_;
optimization_guide::proto::OnDeviceTailSuggestModelMetadata metadata_;
// The hashes of bad words (calculated by base::PersistentHash) and the
// BASE64-encoded bad substrings, which are used to filter out bad
// suggestions.
std::set<uint32_t> badword_hashes_;
std::set<std::string> bad_substrings_;
};
#endif // COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_EXECUTOR_H_