// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_
#define COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_

#include <stdint.h>

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

#include "base/callback_list.h"
#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "base/types/optional_ref.h"
#include "components/browsing_topics/annotator.h"
#include "components/optimization_guide/core/inference/bert_model_handler.h"

namespace optimization_guide {
class OptimizationGuideModelProvider;
}

namespace browsing_topics {

// An implementation of the |Annotator| base class. This Annotator supports
// concurrent batch annotations and manages the lifetimes of all underlying
// components. This class must only be owned and called on the UI thread.
//
// |BatchAnnotate| is the main entry point for callers. The callback given to
// |BatchAnnotate| is forwarded through many subsequent PostTasks until all
// annotations are ready to be returned to the caller.
//
// Life of an Annotation:
// 1. |BatchAnnotate| checks whether the override list needs to be loaded. If
//    so, the list is loaded on a background thread and set in
//    |OnOverrideListLoadAttemptDone|. |StartBatchAnnotate| is called next in
//    either case.
// 2. |StartBatchAnnotate| shares ownership of the |BatchAnnotationCallback|
//    among a series of callbacks (using |base::BarrierClosure|), one for each
//    input. Ownership of the inputs is moved to the heap where all individual
//    model executions can reference their input and set their output.
// 3. |AnnotateSingleInput| runs a single annotation, first checking the
//    override list if available. If the input is not covered by the override
//    list, the ML model is run on a background thread.
// 4. |PostprocessCategoriesToBatchAnnotationResult| is called to post-process
//    the output of the ML model.
// 5. |OnBatchComplete| is called by the barrier closure. It passes the
//    annotations back to the caller and unloads the model if no other batches
//    are in progress.
class AnnotatorImpl : public Annotator,
                      public optimization_guide::BertModelHandler {
 public:
  AnnotatorImpl(
      optimization_guide::OptimizationGuideModelProvider* model_provider,
      scoped_refptr<base::SequencedTaskRunner> background_task_runner,
      const std::optional<optimization_guide::proto::Any>& model_metadata);
  ~AnnotatorImpl() override;

  // Annotator:
  void BatchAnnotate(BatchAnnotationCallback callback,
                     const std::vector<std::string>& inputs) override;
  void NotifyWhenModelAvailable(base::OnceClosure callback) override;
  std::optional<optimization_guide::ModelInfo> GetBrowsingTopicsModelInfo()
      const override;

  //////////////////////////////////////////////////////////////////////////////
  // Public methods below here are exposed only for testing.
  //////////////////////////////////////////////////////////////////////////////

  // optimization_guide::BertModelHandler:
  void OnModelUpdated(
      optimization_guide::proto::OptimizationTarget optimization_target,
      base::optional_ref<const optimization_guide::ModelInfo> model_info)
      override;

  // Extracts the scored categories from the output of the model.
  std::optional<std::vector<int32_t>> ExtractCategoriesFromModelOutput(
      const std::vector<tflite::task::core::Category>& model_output) const;

 protected:
  // optimization_guide::BertModelHandler:
  void UnloadModel() override;

 private:
  // Sets the |override_list_| after it was loaded on a background thread and
  // calls |StartBatchAnnotate|.
  void OnOverrideListLoadAttemptDone(
      BatchAnnotationCallback callback,
      const std::vector<std::string>& inputs,
      std::optional<std::unordered_map<std::string, std::vector<int32_t>>>
          override_list);

  // Starts a batch annotation once the override list is loaded, if provided.
  void StartBatchAnnotate(BatchAnnotationCallback callback,
                          const std::vector<std::string>& inputs);

  // Does the required preprocessing on an input domain.
  std::string PreprocessHost(const std::string& host) const;

  // Runs a single input through the ML model, setting the result in
  // |annotation|.
  void AnnotateSingleInput(base::OnceClosure single_input_done_signal,
                           Annotation* annotation);

  // Called when all single inputs have been annotated and the |callback| from
  // the caller can finally be run.
  void OnBatchComplete(
      BatchAnnotationCallback callback,
      std::unique_ptr<std::vector<Annotation>> annotations_ptr);

  // Sets |annotation.topics| from the output of the model, calling
  // |ExtractCategoriesFromModelOutput| in the process.
  void PostprocessCategoriesToBatchAnnotationResult(
      base::OnceClosure single_input_done_signal,
      Annotation* annotation,
      const std::optional<std::vector<tflite::task::core::Category>>& output);

  // Used to read the override list file on a background thread.
  scoped_refptr<base::SequencedTaskRunner> background_task_runner_;

  // Set whenever a valid override list file is passed along with the model file
  // update. Used on the UI thread.
  std::optional<base::FilePath> override_list_file_path_;

  // Set whenever an override list file is available and the model file is
  // loaded into memory. Reset whenever the model file is unloaded.
  // Used on the UI thread. Lookups in this mapping should have |PreprocessHost|
  // applied first.
  std::optional<std::unordered_map<std::string, std::vector<int32_t>>>
      override_list_;

  // The version of the topics model provided by the server in the model
  // metadata. It specifies the expected functionality of execution that is not
  // contained within the model itself (e.g., preprocessing/postprocessing).
  int version_ = 0;

  // Counts the number of batches that are in progress. This counter is
  // incremented in |StartBatchAnnotate| and decremented in |OnBatchComplete|.
  // When this counter is 0 in |OnBatchComplete|, the model is unloaded from
  // memory.
  size_t in_progess_batches_ = 0;

  // Callbacks that are run when the model is updated with the correct taxonomy
  // version.
  base::OnceClosureList model_available_callbacks_;

  SEQUENCE_CHECKER(sequence_checker_);

  base::WeakPtrFactory<AnnotatorImpl> weak_ptr_factory_{this};
};

}  // namespace browsing_topics

#endif  // COMPONENTS_BROWSING_TOPICS_ANNOTATOR_IMPL_H_