File: annotator_impl.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (156 lines) | stat: -rw-r--r-- 6,684 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_
#define COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_

#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

#include "base/callback_list.h"
#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "components/browsing_topics/annotator.h"
#include "components/optimization_guide/core/inference/bert_model_handler.h"

namespace optimization_guide {
class OptimizationGuideModelProvider;
}

namespace browsing_topics {

// An implementation of the |Annotator| base class. This Annotator supports
// concurrent batch annotations and manages the lifetimes of all underlying
// components. This class must only be owned and called on the UI thread.
//
// |BatchAnnotate| is the main entry point for callers. The callback given to
// |BatchAnnotate| is forwarded through many subsequent PostTasks until all
// annotations are ready to be returned to the caller.
//
// Life of an Annotation:
// 1. |BatchAnnotate| checks if the override list needs to be loaded. If so, it
// is done on a background thread. After that check and possibly loading the
// list in |OnOverrideListLoadAttemptDone|, |StartBatchAnnotate| is called.
// 2. |StartBatchAnnotate| shares ownership of the |BatchAnnotationCallback|
// among a series of callbacks (using |base::BarrierClosure|), one for each
// input. Ownership of the inputs is moved to the heap where all individual
// model executions can reference their input and set their output.
// 3. |AnnotateSingleInput| runs a single annotation, first checking the
// override list if available. If the input is not covered in the override list,
// the ML model is run on a background thread.
// 4. |PostprocessCategoriesToBatchAnnotationResult| is called to post-process
// the output of the ML model.
// 5. |OnBatchComplete| is called by the barrier closure which passes the
// annotations back to the caller and unloads the model if no other batches are
// in progress.
class AnnotatorImpl : public Annotator,
                      public optimization_guide::BertModelHandler {
 public:
  AnnotatorImpl(
      optimization_guide::OptimizationGuideModelProvider* model_provider,
      scoped_refptr<base::SequencedTaskRunner> background_task_runner,
      const std::optional<optimization_guide::proto::Any>& model_metadata);
  ~AnnotatorImpl() override;

  // Annotator:
  void BatchAnnotate(BatchAnnotationCallback callback,
                     const std::vector<std::string>& inputs) override;
  void NotifyWhenModelAvailable(base::OnceClosure callback) override;
  std::optional<optimization_guide::ModelInfo> GetBrowsingTopicsModelInfo()
      const override;

  //////////////////////////////////////////////////////////////////////////////
  // Public methods below here are exposed only for testing.
  //////////////////////////////////////////////////////////////////////////////

  // optimization_guide::BertModelHandler:
  void OnModelUpdated(
      optimization_guide::proto::OptimizationTarget optimization_target,
      base::optional_ref<const optimization_guide::ModelInfo> model_info)
      override;

  // Extracts the scored categories from the output of the model.
  std::optional<std::vector<int32_t>> ExtractCategoriesFromModelOutput(
      const std::vector<tflite::task::core::Category>& model_output) const;

 protected:
  // optimization_guide::BertModelHandler:
  void UnloadModel() override;

 private:
  // Sets the |override_list_| after it was loaded on a background thread and
  // calls |StartBatchAnnotate|.
  void OnOverrideListLoadAttemptDone(
      BatchAnnotationCallback callback,
      const std::vector<std::string>& inputs,
      std::optional<std::unordered_map<std::string, std::vector<int32_t>>>
          override_list);

  // Starts a batch annotation once the override list is loaded, if provided.
  void StartBatchAnnotate(BatchAnnotationCallback callback,
                          const std::vector<std::string>& inputs);

  // Does the required preprocessing on a input domain.
  std::string PreprocessHost(const std::string& host) const;

  // Runs a single input through the ML model, setting the result in
  // |annotation|.
  void AnnotateSingleInput(base::OnceClosure single_input_done_signal,
                           Annotation* annotation);

  // Called when all single inputs have been annotated and the |callback| from
  // the caller can finally be run.
  void OnBatchComplete(
      BatchAnnotationCallback callback,
      std::unique_ptr<std::vector<Annotation>> annotations_ptr);

  // Sets |annotation.topics| from the output of the model, calling
  // |ExtractCategoriesFromModelOutput| in the process.
  void PostprocessCategoriesToBatchAnnotationResult(
      base::OnceClosure single_input_done_signal,
      Annotation* annotation,
      const std::optional<std::vector<tflite::task::core::Category>>& output);

  // Used to read the override list file on a background thread.
  scoped_refptr<base::SequencedTaskRunner> background_task_runner_;

  // Set whenever a valid override list file is passed along with the model file
  // update. Used on the UI thread.
  std::optional<base::FilePath> override_list_file_path_;

  // Set whenever an override list file is available and the model file is
  // loaded into memory. Reset whenever the model file is unloaded.
  // Used on the UI thread. Lookups in this mapping should have |PreprocessHost|
  // applied first.
  std::optional<std::unordered_map<std::string, std::vector<int32_t>>>
      override_list_;

  // The version of topics model provided by the server in the model metadata
  // which specifies the expected functionality of execution not contained
  // within the model itself (e.g., preprocessing/post processing).
  int version_ = 0;

  // Counts the number of batches that are in progress. This counter is
  // incremented in |StartBatchAnnotate| and decremented in |OnBatchComplete|.
  // When this counter is 0 in |OnBatchComplete|, the model in unloaded from
  // memory.
  size_t in_progess_batches_ = 0;

  // Callbacks that are run when the model is updated with the correct taxonomy
  // version.
  base::OnceClosureList model_available_callbacks_;

  SEQUENCE_CHECKER(sequence_checker_);

  base::WeakPtrFactory<AnnotatorImpl> weak_ptr_factory_{this};
};

}  // namespace browsing_topics

#endif  // COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_