File: prediction_manager.h

package info (click to toggle)
chromium 139.0.7258.138-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,120,676 kB
  • sloc: cpp: 35,100,869; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (374 lines) | stat: -rw-r--r-- 15,571 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_DELIVERY_PREDICTION_MANAGER_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_DELIVERY_PREDICTION_MANAGER_H_

#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/containers/lru_cache.h"
#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/observer_list.h"
#include "base/sequence_checker.h"
#include "base/timer/timer.h"
#include "base/types/optional_ref.h"
#include "components/optimization_guide/core/delivery/model_enums.h"
#include "components/optimization_guide/core/delivery/prediction_model_download_observer.h"
#include "components/optimization_guide/core/delivery/prediction_model_fetch_timer.h"
#include "components/optimization_guide/core/delivery/prediction_model_store.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "url/origin.h"

namespace download {
class BackgroundDownloadService;
}  // namespace download

namespace network {
class SharedURLLoaderFactory;
}  // namespace network

namespace unzip::mojom {
class Unzipper;
}  // namespace unzip::mojom

namespace unzip {
// Repeatable callback that returns a `mojo::PendingRemote` bound to an
// `unzip::mojom::Unzipper`.
// TODO: crbug.com/421262905 - Avoid duplicating this alias.
using UnzipperFactory =
    base::RepeatingCallback<mojo::PendingRemote<mojom::Unzipper>()>;
}  // namespace unzip

// `GURL` is declared here rather than relied upon via a transitive include,
// because `PredictionManager::StartModelDownload()` takes a `const GURL&`.
class GURL;
class OptimizationGuideLogger;
class PrefService;

namespace optimization_guide {

// Types that this header refers to only through pointers, smart pointers, or
// references.
class ModelInfo;
class OptimizationTargetModelObserver;
class PredictionModelDownloadManager;
class PredictionModelFetcher;
class PredictionModelStore;

// Manages prediction models for optimization targets on behalf of registered
// observers. It fetches model metadata from the remote Optimization Guide
// Service, downloads model files, persists them in the install-wide
// PredictionModelStore, loads them into memory, and notifies the
// OptimizationTargetModelObservers registered for a target when that target's
// model is updated or removed.
class PredictionManager : public PredictionModelDownloadObserver {
 public:
  // Callback that reports whether component updates are enabled for the
  // browser.
  using ComponentUpdatesEnabledProvider = base::RepeatingCallback<bool(void)>;

  // `prediction_model_store` and `optimization_guide_logger` are not owned and
  // must outlive this object.
  PredictionManager(
      PredictionModelStore* prediction_model_store,
      scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
      PrefService* pref_service,
      bool off_the_record,
      const std::string& application_locale,
      OptimizationGuideLogger* optimization_guide_logger,
      ComponentUpdatesEnabledProvider component_updates_enabled_provider,
      unzip::UnzipperFactory unzipper_factory);

  PredictionManager(const PredictionManager&) = delete;
  PredictionManager& operator=(const PredictionManager&) = delete;

  ~PredictionManager() override;

  // Adds an observer for updates to the model for |optimization_target|.
  //
  // It is assumed that any model retrieved this way will be passed to the
  // Machine Learning Service for inference.
  void AddObserverForOptimizationTargetModel(
      proto::OptimizationTarget optimization_target,
      const std::optional<proto::Any>& model_metadata,
      OptimizationTargetModelObserver* observer);

  // Removes an observer for updates to the model for |optimization_target|.
  //
  // If |observer| is registered for multiple targets, |observer| must be
  // removed for all observed targets in order for it to be fully removed from
  // receiving any calls.
  void RemoveObserverForOptimizationTargetModel(
      proto::OptimizationTarget optimization_target,
      OptimizationTargetModelObserver* observer);

  // Set the prediction model fetcher for testing.
  void SetPredictionModelFetcherForTesting(
      std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher);

  // Returns the fetcher currently in use. Ownership stays with this object.
  PredictionModelFetcher* prediction_model_fetcher() const {
    return prediction_model_fetcher_.get();
  }

  // Set the prediction model download manager for testing.
  void SetPredictionModelDownloadManagerForTesting(
      std::unique_ptr<PredictionModelDownloadManager>
          prediction_model_download_manager);

  // Returns the download manager, which can be null if model downloading is
  // disabled. Ownership stays with this object.
  PredictionModelDownloadManager* prediction_model_download_manager() const {
    return prediction_model_download_manager_.get();
  }

  // Return the optimization targets that are registered.
  base::flat_set<proto::OptimizationTarget> GetRegisteredOptimizationTargets()
      const;

  // Override the model file returned to observers for |optimization_target|.
  // Use |TestModelInfoBuilder| to construct the model files. For
  // testing purposes only.
  void OverrideTargetModelForTesting(
      proto::OptimizationTarget optimization_target,
      std::unique_ptr<ModelInfo> model_info);

  // PredictionModelDownloadObserver:
  void OnModelReady(const base::FilePath& base_model_dir,
                    const proto::PredictionModel& model) override;
  void OnModelDownloadStarted(
      proto::OptimizationTarget optimization_target) override;
  void OnModelDownloadFailed(
      proto::OptimizationTarget optimization_target) override;

  // Returns information about downloaded models for the optimization guide
  // internals WebUI.
  std::vector<optimization_guide_internals::mojom::DownloadedModelInfoPtr>
  GetDownloadedModelsInfoForWebUI() const;

  // Returns on-device supplementary model information for the optimization
  // guide internals WebUI, as a string-to-bool map (presumably model name to
  // availability -- TODO confirm against the definition).
  base::flat_map<std::string, bool> GetOnDeviceSupplementaryModelsInfoForWebUI()
      const;

  // Initialize the model metadata fetching and downloads.
  void MaybeInitializeModelDownloads(
      download::BackgroundDownloadService* background_download_service);

  // Returns the model fetch timer for tests. Ownership stays with this object.
  PredictionModelFetchTimer* GetPredictionModelFetchTimerForTesting() {
    return &prediction_model_fetch_timer_;
  }

 protected:
  // Process `prediction_models` to be stored in the in memory optimization
  // target prediction model map for immediate use and asynchronously write the
  // models to the model and features store to be persisted.
  // `models_request_info` is the list of models the fetch request was made
  // for, and `prediction_models` is the models received in response. Any models
  // missing in the response will be deleted from the store, since the remote
  // optimization guide service has no models for them.
  void UpdatePredictionModels(
      const std::vector<proto::ModelInfo>& models_request_info,
      const google::protobuf::RepeatedPtrField<proto::PredictionModel>&
          prediction_models);

 private:
  // Contains the model registration specific info to be kept for each
  // optimization target.
  struct ModelRegistrationInfo {
    explicit ModelRegistrationInfo(std::optional<proto::Any> metadata);
    ~ModelRegistrationInfo();

    // The feature-provided metadata that was registered with the prediction
    // manager.
    std::optional<proto::Any> metadata;

    // The set of model observers that were registered to receive model updates
    // from the prediction manager.
    base::ObserverList<OptimizationTargetModelObserver> model_observers;
  };

  friend class PredictionManagerTestBase;
  friend class PredictionModelStoreBrowserTestBase;

  // Called to make a request to fetch models from the remote Optimization Guide
  // Service. Used to fetch models for the registered optimization targets.
  void FetchModels();

  // Callback when the models have been fetched from the remote Optimization
  // Guide Service and are ready for parsing. Processes the prediction models in
  // the response and stores them for use. The metadata entry containing the
  // time that updates should be fetched from the remote Optimization Guide
  // Service is updated, even when the response is empty.
  void OnModelsFetched(std::vector<proto::ModelInfo> models_request_info,
                       std::optional<std::unique_ptr<proto::GetModelsResponse>>
                           get_models_response_data);

  // Load models for every target in |optimization_targets| that have not yet
  // been loaded from the store.
  void LoadPredictionModels(
      const base::flat_set<proto::OptimizationTarget>& optimization_targets);

  // Callback run after prediction models are stored in
  // `prediction_model_store_`.
  void OnPredictionModelsStored();

  // Callback run after a prediction model is loaded from the store.
  // |prediction_model| is used to construct a PredictionModel capable of making
  // prediction for the appropriate |optimization_target|.
  void OnLoadPredictionModel(
      proto::OptimizationTarget optimization_target,
      bool record_availability_metrics,
      std::unique_ptr<proto::PredictionModel> prediction_model);

  // Callback run after a prediction model is loaded from a command-line
  // override.
  void OnPredictionModelOverrideLoaded(
      proto::OptimizationTarget optimization_target,
      std::unique_ptr<proto::PredictionModel> prediction_model);

  // Process loaded |model| into memory. Return true if a prediction
  // model object was created and successfully stored, otherwise false.
  bool ProcessAndStoreLoadedModel(const proto::PredictionModel& model);

  // Removes the model for `optimization_target` from store, for the
  // `model_removal_reason`.
  void RemoveModelFromStore(
      proto::OptimizationTarget optimization_target,
      PredictionModelStoreModelRemovalReason model_removal_reason);

  // Return whether the model stored in memory for |optimization_target| should
  // be updated based on what's currently stored and |new_version|.
  bool ShouldUpdateStoredModelForTarget(
      proto::OptimizationTarget optimization_target,
      int64_t new_version) const;

  // Updates the in-memory model file for |optimization_target| to
  // |prediction_model_file|.
  void StoreLoadedModelInfo(proto::OptimizationTarget optimization_target,
                            std::unique_ptr<ModelInfo> prediction_model_file);

  // Post-processing callback invoked after processing |model|.
  void OnProcessLoadedModel(const proto::PredictionModel& model, bool success);

  // Return the time when a prediction model fetch was last attempted.
  base::Time GetLastFetchAttemptTime() const;

  // Set the last time when a prediction model fetch was last attempted to
  // |last_attempt_time|.
  void SetLastModelFetchAttemptTime(base::Time last_attempt_time);

  // Return the time when a prediction model fetch was last successfully
  // completed.
  base::Time GetLastFetchSuccessTime() const;

  // Set the last time when a fetch for prediction models last succeeded to
  // |last_success_time|.
  void SetLastModelFetchSuccessTime(base::Time last_success_time);

  // Schedule first fetch for models if enabled for this profile.
  void MaybeScheduleFirstModelFetch();

  // Schedules the model fetch timer to fire based on:
  // 1. The update time for models in the store and
  // 2. The last time a fetch attempt was made.
  void ScheduleModelsFetch();

  // Notifies observers of `optimization_target` that the model has been
  // updated. `model_info` is nullopt when the server stopped serving the model
  // and it was removed from the store.
  void NotifyObserversOfNewModel(
      proto::OptimizationTarget optimization_target,
      base::optional_ref<const ModelInfo> model_info);

  // Updates the metadata for |model|.
  void UpdateModelMetadata(const proto::PredictionModel& model);

  // Returns whether the model should be downloaded, or the correct model
  // version already exists in the model store.
  bool ShouldDownloadNewModel(const proto::PredictionModel& model) const;

  // Starts the model download for |optimization_target| from |download_url|.
  void StartModelDownload(proto::OptimizationTarget optimization_target,
                          const GURL& download_url);

  // Start downloading the model if the load failed, or update the model if it
  // is loaded fine.
  void MaybeDownloadOrUpdatePredictionModel(
      proto::OptimizationTarget optimization_target,
      const proto::PredictionModel& get_models_response_model,
      std::unique_ptr<proto::PredictionModel> loaded_model);

  // Returns a new file path for the directory to download the model files for
  // |optimization_target|. The directory will not be created.
  base::FilePath GetBaseModelDirForDownload(
      proto::OptimizationTarget optimization_target);

  void SetModelCacheKeyForTesting(const proto::ModelCacheKey& model_cache_key) {
    model_cache_key_ = model_cache_key;
  }

  // A map of optimization target to the model file containing the model for the
  // target.
  base::flat_map<proto::OptimizationTarget, std::unique_ptr<ModelInfo>>
      optimization_target_model_info_map_ GUARDED_BY_CONTEXT(sequence_checker_);

  // The map from optimization target to the model registration specific data.
  // A node-based std::map is used rather than base::flat_map, presumably
  // because ModelRegistrationInfo (which declares a destructor but no move
  // operations, and holds a base::ObserverList) cannot be relocated the way
  // flat_map's contiguous storage requires.
  std::map<proto::OptimizationTarget, ModelRegistrationInfo>
      model_registration_info_map_ GUARDED_BY_CONTEXT(sequence_checker_);

  // The fetcher that handles making requests to update the models and host
  // model features from the remote Optimization Guide Service.
  std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher_;

  // The downloader that handles making requests to download the prediction
  // models. Can be null if model downloading is disabled.
  std::unique_ptr<PredictionModelDownloadManager>
      prediction_model_download_manager_;

  // The new optimization guide model store. Will be null when the feature is
  // not enabled. Not owned and outlives |this| since it's an install-wide
  // store.
  raw_ptr<PredictionModelStore> prediction_model_store_;

  // A stored response from a model and host model features fetch used to hold
  // models to be stored once host model features are processed and stored.
  std::unique_ptr<proto::GetModelsResponse> get_models_response_data_to_store_;

  // The URL loader factory used for fetching model and host feature updates
  // from the remote Optimization Guide Service.
  scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;

  // The logger that plumbs the debug logs to the optimization guide
  // internals page. Not owned. Guaranteed to outlive |this|, since the logger
  // and |this| are owned by the optimization guide keyed service.
  raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;

  // The repeating callback that will be used to determine if component updates
  // are enabled.
  ComponentUpdatesEnabledProvider component_updates_enabled_provider_;

  // Callback to build Unzipper remotes.
  unzip::UnzipperFactory unzipper_factory_;

  // Time the prediction manager got initialized.
  // TODO(crbug.com/40861855): Remove this old model store once the new model
  // store is launched.
  // NOTE(review): the TODO above refers to an "old model store", which this
  // member is not; it may be stale or misplaced -- confirm before acting on it.
  base::TimeTicks init_time_;

  // Timer used to schedule model fetches (see ScheduleModelsFetch()).
  PredictionModelFetchTimer prediction_model_fetch_timer_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Whether the profile for this PredictionManager is off the record.
  bool off_the_record_ = false;

  // The locale of the application.
  std::string application_locale_;

  // Model cache key for the profile.
  proto::ModelCacheKey model_cache_key_;

  // The path to the directory containing the models.
  base::FilePath models_dir_path_;

  // Whether to check for Google API key configuration.
  bool should_check_google_api_key_configuration_ = true;

  SEQUENCE_CHECKER(sequence_checker_);

  // Used to get |weak_ptr_| to self on the UI thread. Kept as the last member
  // so that weak pointers are invalidated before any other member is
  // destroyed.
  base::WeakPtrFactory<PredictionManager> ui_weak_ptr_factory_{this};
};

}  // namespace optimization_guide

#endif  // COMPONENTS_OPTIMIZATION_GUIDE_CORE_DELIVERY_PREDICTION_MANAGER_H_