File: ml_answerer.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/history_embeddings/ml_answerer.h"

#include <algorithm>

#include "base/barrier_callback.h"
#include "base/memory/scoped_refptr.h"
#include "base/strings/stringprintf.h"
#include "components/history_embeddings/history_embeddings_features.h"
#include "components/optimization_guide/core/model_quality/model_execution_logging_wrappers.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/features/history_answer.pb.h"

namespace history_embeddings {

using ModelExecutionError = optimization_guide::
    OptimizationGuideModelExecutionError::ModelExecutionError;
using optimization_guide::OptimizationGuideModelExecutionError;
using optimization_guide::OptimizationGuideModelStreamingExecutionResult;
using optimization_guide::SessionConfigParams;
using optimization_guide::proto::Answer;
using optimization_guide::proto::HistoryAnswerRequest;
using optimization_guide::proto::Passage;

namespace {

static constexpr std::string kPassageIdToken = "ID";
// Estimated token count of the preamble text in the prompt.
static constexpr size_t kPreambleTokenBufferSize = 100u;
// Estimated token count of overhead text per passage.
static constexpr size_t kExtraTokensPerPassage = 10u;

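// Formats a passage index as a zero-padded four-digit ID,
// e.g. GetPassageIdStr(3) returns "0003".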
std::string GetPassageIdStr(size_t id) {
  return base::StringPrintf("%04d", static_cast<int>(id));
}

float GetMlAnswerScoreThreshold() {
  return GetFeatureParameters().ml_answerer_min_score;
}

}  // namespace

// Helper struct to bundle raw model input (queries/passages) with its metadata.
struct MlAnswerer::ModelInput {
  // The string content of this input.
  std::string text;
  // Index 0 is reserved for the query, i.e. this index is 0 iff this input is
  // a query. If the input is a passage, `index` holds the position of the
  // passage in the original passage vector (where a lower index means higher
  // relevance), plus 1 to offset for the query.
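  // Example: for query "q" with passages {"p0", "p1"}, the inputs are
  // {"q", index 0}, {"p0", index 1}, and {"p1", index 2}.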
  size_t index;
  // The size of `text` in tokens.
  uint32_t token_count;
};

// Manages sessions for generating an answer for a given query and multiple
// URLs.
class MlAnswerer::SessionManager {
 public:
  using SessionScoreType = std::tuple<int, std::optional<float>>;

  SessionManager(std::string query,
                 Context context,
                 ComputeAnswerCallback callback,
                 base::WeakPtr<ModelQualityLogsUploaderService> logs_uploader)
      : query_(std::move(query)),
        context_(std::move(context)),
        callback_(std::move(callback)),
        origin_task_runner_(base::SequencedTaskRunner::GetCurrentDefault()),
        logs_uploader_(logs_uploader),
        weak_ptr_factory_(this) {}

  ~SessionManager() {
    // Explicitly invalidate weak pointers to prevent callbacks that may be
    // triggered by the destructor logic.
    weak_ptr_factory_.InvalidateWeakPtrs();
    // If the callback has not been run yet, run it now with a canceled status.
    if (!callback_.is_null()) {
      FinishAndResetSessions(AnswererResult(
          ComputeAnswerStatus::kExecutionCancelled, query_, Answer()));
    }
  }

  // Adds a session that contains the query and passage context. The session
  // exists until this manager is reset or destroyed.
  void AddSession(
      std::unique_ptr<OptimizationGuideModelExecutor::Session> session,
      std::string url) {
    sessions_.push_back(std::move(session));
    urls_.push_back(url);
  }

  // Runs speculative decoding by first getting scores for each URL candidate,
  // then continuing decoding with only the highest-scored session.
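  // A BarrierCallback collects one (session index, score) tuple per session
  // and invokes SortAndDecode once every session's Score() call completes.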
  void RunSpeculativeDecoding() {
    const size_t num_sessions = GetNumberOfSessions();
    base::OnceCallback<void(const std::vector<SessionScoreType>&)> cb =
        base::BindOnce(&SessionManager::SortAndDecode,
                       weak_ptr_factory_.GetWeakPtr());
    const auto barrier_cb =
        base::BarrierCallback<SessionScoreType>(num_sessions, std::move(cb));
    for (size_t s_index = 0; s_index < num_sessions; s_index++) {
      VLOG(3) << "Running Score for session " << s_index;
      sessions_[s_index]->Score(
          kPassageIdToken, base::BindOnce(
                               [](size_t index, std::optional<float> score) {
                                 VLOG(3) << "Score complete for " << index;
                                 return std::make_tuple(index, score);
                               },
                               s_index)
                               .Then(barrier_cb));
    }
  }

  size_t GetNumberOfSessions() { return sessions_.size(); }

  base::WeakPtr<MlAnswerer::SessionManager> GetWeakPtr() {
    return weak_ptr_factory_.GetWeakPtr();
  }

  // Runs the callback with the result, posting it to the sequence this
  // manager was created on.
  void FinishCallback(AnswererResult answer_result) {
    CHECK(!callback_.is_null());
    origin_task_runner_->PostTask(
        FROM_HERE,
        base::BindOnce(std::move(callback_), std::move(answer_result)));
  }

  // Finishes and cleans up sessions.
  void FinishAndResetSessions(AnswererResult answer_result) {
    FinishCallback(std::move(answer_result));

    // Destroy all existing sessions.
    VLOG(3) << "Sessions cleared.";
    sessions_.clear();
    urls_.clear();
  }

  // Called when all sessions are started and added.
  void OnSessionsStarted(std::vector<int> args) { RunSpeculativeDecoding(); }

  // Called when token counts of the query and all passages of a session are
  // computed.
  void OnTokenCountRetrieved(std::unique_ptr<Session> session,
                             const std::string url,
                             base::OnceCallback<void(int)> session_added_cb,
                             std::vector<ModelInput> inputs) {
    HistoryAnswerRequest request;
    int token_limit = session->GetTokenLimits().min_context_tokens;
    // Reserve space for preamble text.
    int token_count = kPreambleTokenBufferSize;

    // Sort the inputs according to their indices in the original vector, so
    // we prioritize passages that are more relevant.
    std::ranges::sort(inputs, {}, &ModelInput::index);

    // Add the query to the request. The query will always have index 0.
    token_count += inputs[0].token_count;
    request.set_query(inputs[0].text);

    // Add as many passages as the input window can fit.
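    // Worked example: with token_limit = 1024 and a 300-token query,
    // token_count starts at 100 (preamble) + 300 = 400. Each 300-token
    // passage then costs 300 + 10 overhead tokens, so two passages fit
    // (1020 <= 1024) and a third (1330) is dropped.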
    for (size_t i = 1; i < inputs.size(); ++i) {
      token_count += (inputs[i].token_count + kExtraTokensPerPassage);
      if (token_count > token_limit) {
        break;
      }

      auto* passage = request.add_passages();
      passage->set_text(inputs[i].text);
      passage->set_passage_id(GetPassageIdStr(i));
    }

    VLOG(3) << "Running AddContext for query: `" << request.query() << "`";
    session->AddContext(request);
    AddSession(std::move(session), url);
    std::move(session_added_cb).Run(1);
  }

 private:
  // Callback to be repeatedly called during streaming execution.
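  // Finishes with a failure (or filtered) status on error, or with the parsed
  // answer once the streamed response is complete; intermediate partial
  // responses are ignored.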
  void StreamingExecutionCallback(
      size_t session_index,
      optimization_guide::OptimizationGuideModelStreamingExecutionResult result,
      std::unique_ptr<optimization_guide::proto::HistoryAnswerLoggingData>
          logging_data) {
    auto log_entry = std::make_unique<optimization_guide::ModelQualityLogEntry>(
        logs_uploader_);
    log_entry->log_ai_data_request()->set_allocated_history_answer(
        logging_data.release());
    if (!result.response.has_value()) {
      ComputeAnswerStatus status = ComputeAnswerStatus::kExecutionFailure;
      auto error = result.response.error().error();
      if (error == ModelExecutionError::kFiltered) {
        status = ComputeAnswerStatus::kFiltered;
      }
      FinishCallback(AnswererResult(status, query_, Answer(),
                                    std::move(log_entry), "", {}));
    } else if (result.response->is_complete) {
      auto response = optimization_guide::ParsedAnyMetadata<
          optimization_guide::proto::HistoryAnswerResponse>(
          std::move(result.response).value().response);
      AnswererResult answerer_result(ComputeAnswerStatus::kSuccess, query_,
                                     response->answer(), std::move(log_entry),
                                     urls_[session_index], {});
      answerer_result.PopulateScrollToTextFragment(
          context_.url_passages_map[answerer_result.url]);
      FinishCallback(std::move(answerer_result));
    }
  }

  // Decodes with the highest-scored session.
  void SortAndDecode(const std::vector<SessionScoreType>& session_scores) {
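    // session_scores.size() serves as a sentinel: max_index keeps this value
    // if no session returned a positive score.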
    size_t max_index = session_scores.size();
    float max_score = 0.0;
    for (size_t i = 0; i < session_scores.size(); i++) {
      const std::optional<float> score = std::get<1>(session_scores[i]);
      if (score.has_value()) {
        VLOG(3) << "Session: " << std::get<0>(session_scores[i])
                << " Score: " << *score;
        VLOG(3) << "URL: " << urls_[std::get<0>(session_scores[i])];
        if (*score > max_score) {
          max_score = *score;
          max_index = i;
        }
      }
    }

    if (max_index == session_scores.size()) {
      FinishAndResetSessions(AnswererResult{
          ComputeAnswerStatus::kExecutionFailure, query_, Answer()});
      return;
    }

    // Return an unanswerable status because the highest score is below the
    // threshold.
    if (max_score < GetMlAnswerScoreThreshold()) {
      FinishAndResetSessions(
          AnswererResult{ComputeAnswerStatus::kUnanswerable, query_, Answer()});
      return;
    }

    // Continue decoding using the session with the highest score. Use a dummy
    // request here since both the passages and the query have already been
    // added to the context.
    if (!sessions_.empty()) {
      optimization_guide::proto::HistoryAnswerRequest request;
      const size_t session_index = std::get<0>(session_scores[max_index]);
      VLOG(3) << "Running ExecuteModel for session " << session_index;
      optimization_guide::ExecuteModelSessionWithLogging(
          sessions_[session_index].get(), request,
          base::BindRepeating(&SessionManager::StreamingExecutionCallback,
                              weak_ptr_factory_.GetWeakPtr(), session_index));
    } else {
      // If the sessions were already cleaned up, run the callback with a
      // canceled status.
      FinishAndResetSessions(AnswererResult{
          ComputeAnswerStatus::kExecutionCancelled, query_, Answer()});
    }
  }

  std::vector<std::unique_ptr<OptimizationGuideModelExecutor::Session>>
      sessions_;
  // URLs associated with sessions by index.
  std::vector<std::string> urls_;
  std::string query_;
  Context context_;
  ComputeAnswerCallback callback_;
  const scoped_refptr<base::SequencedTaskRunner> origin_task_runner_;
  base::WeakPtr<ModelQualityLogsUploaderService> logs_uploader_;
  base::WeakPtrFactory<SessionManager> weak_ptr_factory_;
};

MlAnswerer::MlAnswerer(OptimizationGuideModelExecutor* model_executor,
                       ModelQualityLogsUploaderService* logs_uploader)
    : model_executor_(model_executor) {
  if (logs_uploader) {
    logs_uploader_ = logs_uploader->GetWeakPtr();
  }
}

MlAnswerer::~MlAnswerer() = default;

int64_t MlAnswerer::GetModelVersion() {
  // This can be replaced with the real implementation.
  return 0;
}

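// Illustrative usage sketch (names and the callback body are hypothetical;
// assumes a valid executor/uploader and a populated url_passages_map):
//
//   MlAnswerer answerer(executor, uploader);
//   answerer.ComputeAnswer(
//       "query", std::move(context),
//       base::BindOnce([](AnswererResult result) {
//         if (result.status == ComputeAnswerStatus::kSuccess) {
//           VLOG(1) << result.answer.text();
//         }
//       }));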
void MlAnswerer::ComputeAnswer(std::string query,
                               Context context,
                               ComputeAnswerCallback callback) {
  CHECK(model_executor_);

  // Assign a new session manager (and destroy the existing one).
  session_manager_ = std::make_unique<SessionManager>(
      query, context, std::move(callback), logs_uploader_);

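  // The barrier fires OnSessionsStarted only after a session has been started
  // and added for every URL in the context.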
  const auto sessions_started_callback = base::BarrierCallback<int>(
      context.url_passages_map.size(),
      base::BindOnce(&MlAnswerer::SessionManager::OnSessionsStarted,
                     session_manager_->GetWeakPtr()));

  const SessionConfigParams session_config{
      .execution_mode = SessionConfigParams::ExecutionMode::kOnDeviceOnly};
  // Start a session for each URL.
  for (const auto& url_and_passages : context.url_passages_map) {
    std::unique_ptr<Session> session = model_executor_->StartSession(
        optimization_guide::ModelBasedCapabilityKey::kHistorySearch,
        session_config);
    if (session == nullptr) {
      session_manager_->FinishAndResetSessions(AnswererResult(
          ComputeAnswerStatus::kModelUnavailable, query, Answer()));
      return;
    }

    StartAndAddSession(query, url_and_passages.first, url_and_passages.second,
                       std::move(session), sessions_started_callback);
  }
}

void MlAnswerer::StartAndAddSession(
    const std::string& query,
    const std::string& url,
    const std::vector<std::string>& passages,
    std::unique_ptr<Session> session,
    base::OnceCallback<void(int)> session_started) {
  Session* raw_session = session.get();
  const auto token_count_callback = base::BarrierCallback<ModelInput>(
      passages.size() + 1,  // We need token count for passages + query.
      base::BindOnce(&MlAnswerer::SessionManager::OnTokenCountRetrieved,
                     session_manager_->GetWeakPtr(), std::move(session), url,
                     std::move(session_started)));
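  // Each GetSizeInTokens() reply below contributes one ModelInput; once the
  // query and all passages are counted, OnTokenCountRetrieved assembles the
  // request.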

  const auto make_model_input = [](std::string text, size_t index,
                                   std::optional<uint32_t> token_count) {
    VLOG(3) << "Created model input for " << index;
    return ModelInput{text, index, token_count.value_or(0)};
  };

  // Get the token count for the query; the query is always assigned index 0
  // when making its ModelInput.
  raw_session->GetSizeInTokens(
      query,
      base::BindOnce(make_model_input, query, 0).Then(token_count_callback));

  // Get the token count for each passage, assigning its original index + 1
  // when making its ModelInput so that index 0 stays reserved for the query.
  VLOG(3) << "Running GetSizeInTokens for " << passages.size()
          << " passages...";
  for (size_t i = 0; i < passages.size(); ++i) {
    raw_session->GetSizeInTokens(
        passages[i], base::BindOnce(make_model_input, passages[i], i + 1)
                         .Then(token_count_callback));
  }
}

}  // namespace history_embeddings