// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_DELIVERY_PREDICTION_MANAGER_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_DELIVERY_PREDICTION_MANAGER_H_

#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/containers/lru_cache.h"
#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/observer_list.h"
#include "base/sequence_checker.h"
#include "base/time/time.h"
#include "base/timer/timer.h"
#include "base/types/optional_ref.h"
#include "components/optimization_guide/core/delivery/model_enums.h"
#include "components/optimization_guide/core/delivery/prediction_model_download_observer.h"
#include "components/optimization_guide/core/delivery/prediction_model_fetch_timer.h"
#include "components/optimization_guide/core/delivery/prediction_model_store.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "url/origin.h"

namespace download {
class BackgroundDownloadService;
}  // namespace download

namespace network {
class SharedURLLoaderFactory;
}  // namespace network

namespace unzip::mojom {
class Unzipper;
}  // namespace unzip::mojom

namespace unzip {
// TODO: crbug.com/421262905 - Avoid duplicating this alias.
using UnzipperFactory =
    base::RepeatingCallback<mojo::PendingRemote<mojom::Unzipper>()>;
}  // namespace unzip

class GURL;
class OptimizationGuideLogger;
class PrefService;

namespace optimization_guide {

class OptimizationTargetModelObserver;
class PredictionModelDownloadManager;
class PredictionModelFetcher;
class PredictionModelStore;
class ModelInfo;

// A PredictionManager supported by the optimization guide that makes an
// OptimizationTargetDecision by evaluating the corresponding prediction model
// for an OptimizationTarget.
class PredictionManager : public PredictionModelDownloadObserver {
 public:
  // Callback that returns whether component updates are enabled for the
  // browser.
  using ComponentUpdatesEnabledProvider = base::RepeatingCallback<bool(void)>;

  PredictionManager(
      PredictionModelStore* prediction_model_store,
      scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
      PrefService* pref_service,
      bool off_the_record,
      const std::string& application_locale,
      OptimizationGuideLogger* optimization_guide_logger,
      ComponentUpdatesEnabledProvider component_updates_enabled_provider,
      unzip::UnzipperFactory unzipper_factory);

  PredictionManager(const PredictionManager&) = delete;
  PredictionManager& operator=(const PredictionManager&) = delete;

  ~PredictionManager() override;

  // Adds an observer for updates to the model for |optimization_target|.
  //
  // It is assumed that any model retrieved this way will be passed to the
  // Machine Learning Service for inference.
  void AddObserverForOptimizationTargetModel(
      proto::OptimizationTarget optimization_target,
      const std::optional<proto::Any>& model_metadata,
      OptimizationTargetModelObserver* observer);

  // Removes an observer for updates to the model for |optimization_target|.
  //
  // If |observer| is registered for multiple targets, it must be removed for
  // every observed target before it stops receiving all calls.
  void RemoveObserverForOptimizationTargetModel(
      proto::OptimizationTarget optimization_target,
      OptimizationTargetModelObserver* observer);
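
  // Usage sketch (illustrative only). It assumes `prediction_manager` is a
  // valid PredictionManager* and that `this` implements
  // OptimizationTargetModelObserver. The target values are examples. An
  // observer registered for two targets must be removed for both before it
  // stops receiving updates:
  //
  //   prediction_manager->AddObserverForOptimizationTargetModel(
  //       proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
  //       /*model_metadata=*/std::nullopt, this);
  //   prediction_manager->AddObserverForOptimizationTargetModel(
  //       proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
  //       /*model_metadata=*/std::nullopt, this);
  //   ...
  //   prediction_manager->RemoveObserverForOptimizationTargetModel(
  //       proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, this);
  //   prediction_manager->RemoveObserverForOptimizationTargetModel(
  //       proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, this);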

  // Sets the prediction model fetcher for testing.
  void SetPredictionModelFetcherForTesting(
      std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher);

  PredictionModelFetcher* prediction_model_fetcher() const {
    return prediction_model_fetcher_.get();
  }

  // Sets the prediction model download manager for testing.
  void SetPredictionModelDownloadManagerForTesting(
      std::unique_ptr<PredictionModelDownloadManager>
          prediction_model_download_manager);

  PredictionModelDownloadManager* prediction_model_download_manager() const {
    return prediction_model_download_manager_.get();
  }

  // Returns the optimization targets that are registered.
  base::flat_set<proto::OptimizationTarget> GetRegisteredOptimizationTargets()
      const;

  // Overrides the model file returned to observers for |optimization_target|.
  // Use |TestModelInfoBuilder| to construct the model files. For testing
  // purposes only.
  void OverrideTargetModelForTesting(
      proto::OptimizationTarget optimization_target,
      std::unique_ptr<ModelInfo> model_info);

  // PredictionModelDownloadObserver:
  void OnModelReady(const base::FilePath& base_model_dir,
                    const proto::PredictionModel& model) override;
  void OnModelDownloadStarted(
      proto::OptimizationTarget optimization_target) override;
  void OnModelDownloadFailed(
      proto::OptimizationTarget optimization_target) override;

  // Returns information about the downloaded models, for display on the
  // optimization guide internals WebUI page.
  std::vector<optimization_guide_internals::mojom::DownloadedModelInfoPtr>
  GetDownloadedModelsInfoForWebUI() const;

  // Returns information about the on-device supplementary models, for display
  // on the optimization guide internals WebUI page.
  base::flat_map<std::string, bool> GetOnDeviceSupplementaryModelsInfoForWebUI()
      const;

  // Initializes the model metadata fetching and downloads.
  void MaybeInitializeModelDownloads(
      download::BackgroundDownloadService* background_download_service);

  PredictionModelFetchTimer* GetPredictionModelFetchTimerForTesting() {
    return &prediction_model_fetch_timer_;
  }

 protected:
  // Processes `prediction_models` to be stored in the in-memory optimization
  // target prediction model map for immediate use, and asynchronously writes
  // the models to the model and features store to be persisted.
  // `models_request_info` is the list of models the fetch request was made
  // for, and `prediction_models` are the models received in response. Any
  // models missing from the response will be deleted from the store, since
  // the remote Optimization Guide Service has no models for them.
  void UpdatePredictionModels(
      const std::vector<proto::ModelInfo>& models_request_info,
      const google::protobuf::RepeatedPtrField<proto::PredictionModel>&
          prediction_models);

 private:
  // Contains the model registration specific info to be kept for each
  // optimization target.
  struct ModelRegistrationInfo {
    explicit ModelRegistrationInfo(std::optional<proto::Any> metadata);
    ~ModelRegistrationInfo();

    // The feature-provided metadata that was registered with the prediction
    // manager.
    std::optional<proto::Any> metadata;

    // The set of model observers that were registered to receive model updates
    // from the prediction manager.
    base::ObserverList<OptimizationTargetModelObserver> model_observers;
  };

  friend class PredictionManagerTestBase;
  friend class PredictionModelStoreBrowserTestBase;

  // Called to make a request to fetch models from the remote Optimization Guide
  // Service. Used to fetch models for the registered optimization targets.
  void FetchModels();

  // Callback when the models have been fetched from the remote Optimization
  // Guide Service and are ready for parsing. Processes the prediction models in
  // the response and stores them for use. The metadata entry containing the
  // time that updates should be fetched from the remote Optimization Guide
  // Service is updated, even when the response is empty.
  void OnModelsFetched(const std::vector<proto::ModelInfo> models_request_info,
                       std::optional<std::unique_ptr<proto::GetModelsResponse>>
                           get_models_response_data);

  // Loads models for every target in |optimization_targets| that has not yet
  // been loaded from the store.
  void LoadPredictionModels(
      const base::flat_set<proto::OptimizationTarget>& optimization_targets);

  // Callback run after prediction models are stored in
  // `prediction_model_store_`.
  void OnPredictionModelsStored();

  // Callback run after a prediction model is loaded from the store.
  // |prediction_model| is used to construct a PredictionModel capable of making
  // predictions for the appropriate |optimization_target|.
  void OnLoadPredictionModel(
      proto::OptimizationTarget optimization_target,
      bool record_availability_metrics,
      std::unique_ptr<proto::PredictionModel> prediction_model);

  // Callback run after a prediction model is loaded from a command-line
  // override.
  void OnPredictionModelOverrideLoaded(
      proto::OptimizationTarget optimization_target,
      std::unique_ptr<proto::PredictionModel> prediction_model);

  // Processes the loaded |model| into memory. Returns true if a prediction
  // model object was created and successfully stored, otherwise false.
  bool ProcessAndStoreLoadedModel(const proto::PredictionModel& model);

  // Removes the model for `optimization_target` from the store, for the
  // `model_removal_reason`.
  void RemoveModelFromStore(
      proto::OptimizationTarget optimization_target,
      PredictionModelStoreModelRemovalReason model_removal_reason);

  // Returns whether the model stored in memory for |optimization_target|
  // should be updated, based on what's currently stored and |new_version|.
  bool ShouldUpdateStoredModelForTarget(
      proto::OptimizationTarget optimization_target,
      int64_t new_version) const;

  // Updates the in-memory model file for |optimization_target| to
  // |prediction_model_file|.
  void StoreLoadedModelInfo(proto::OptimizationTarget optimization_target,
                            std::unique_ptr<ModelInfo> prediction_model_file);

  // Post-processing callback invoked after processing |model|.
  void OnProcessLoadedModel(const proto::PredictionModel& model, bool success);

  // Returns the time when a prediction model fetch was last attempted.
  base::Time GetLastFetchAttemptTime() const;

  // Sets the time of the last prediction model fetch attempt to
  // |last_attempt_time|.
  void SetLastModelFetchAttemptTime(base::Time last_attempt_time);

  // Returns the time when a prediction model fetch last completed
  // successfully.
  base::Time GetLastFetchSuccessTime() const;

  // Sets the time of the last successful prediction model fetch to
  // |last_success_time|.
  void SetLastModelFetchSuccessTime(base::Time last_success_time);

  // Schedules the first fetch for models, if enabled for this profile.
  void MaybeScheduleFirstModelFetch();

  // Schedules |prediction_model_fetch_timer_| to fire based on:
  // 1. The update time for models in the store, and
  // 2. The last time a fetch attempt was made.
  void ScheduleModelsFetch();

  // Notifies observers of `optimization_target` that the model has been
  // updated. `model_info` is nullopt when the server has stopped serving the
  // model and it has been removed from the store.
  void NotifyObserversOfNewModel(
      proto::OptimizationTarget optimization_target,
      base::optional_ref<const ModelInfo> model_info);
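
  // Observer-side sketch (illustrative only). It assumes the observer
  // interface exposes OnModelUpdated(proto::OptimizationTarget,
  // base::optional_ref<const ModelInfo>) and that ModelInfo exposes
  // GetModelFilePath(). `MyFeature`, `ResetModel()` and `LoadModel()` are
  // hypothetical:
  //
  //   void MyFeature::OnModelUpdated(
  //       proto::OptimizationTarget optimization_target,
  //       base::optional_ref<const ModelInfo> model_info) {
  //     if (!model_info.has_value()) {
  //       // The server stopped serving the model and it was removed from
  //       // the store, so drop any previously loaded copy.
  //       ResetModel();
  //       return;
  //     }
  //     LoadModel(model_info->GetModelFilePath());
  //   }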

  // Updates the metadata for |model|.
  void UpdateModelMetadata(const proto::PredictionModel& model);

  // Returns whether the model should be downloaded, i.e. whether the correct
  // model version does not already exist in the model store.
  bool ShouldDownloadNewModel(const proto::PredictionModel& model) const;

  // Starts the model download for |optimization_target| from |download_url|.
  void StartModelDownload(proto::OptimizationTarget optimization_target,
                          const GURL& download_url);

  // Starts downloading the model if loading it failed; otherwise, updates the
  // loaded model.
  void MaybeDownloadOrUpdatePredictionModel(
      proto::OptimizationTarget optimization_target,
      const proto::PredictionModel& get_models_response_model,
      std::unique_ptr<proto::PredictionModel> loaded_model);

  // Returns a new file path for the directory to download the model files for
  // |optimization_target|. The directory will not be created.
  base::FilePath GetBaseModelDirForDownload(
      proto::OptimizationTarget optimization_target);

  void SetModelCacheKeyForTesting(const proto::ModelCacheKey& model_cache_key) {
    model_cache_key_ = model_cache_key;
  }

  // A map of optimization target to the model file containing the model for the
  // target.
  base::flat_map<proto::OptimizationTarget, std::unique_ptr<ModelInfo>>
      optimization_target_model_info_map_ GUARDED_BY_CONTEXT(sequence_checker_);

  // The map from optimization target to the model registration specific data.
  std::map<proto::OptimizationTarget, ModelRegistrationInfo>
      model_registration_info_map_ GUARDED_BY_CONTEXT(sequence_checker_);

  // The fetcher that handles making requests to update the models and host
  // model features from the remote Optimization Guide Service.
  std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher_;

  // The downloader that handles making requests to download the prediction
  // models. Can be null if model downloading is disabled.
  std::unique_ptr<PredictionModelDownloadManager>
      prediction_model_download_manager_;

  // The new optimization guide model store. Will be null when the feature is
  // not enabled. Not owned; outlives |this| since it's an install-wide store.
  raw_ptr<PredictionModelStore> prediction_model_store_;

  // A stored response from a model and host model features fetch, used to hold
  // models to be stored once host model features are processed and stored.
  std::unique_ptr<proto::GetModelsResponse> get_models_response_data_to_store_;

  // The URL loader factory used for fetching model and host feature updates
  // from the remote Optimization Guide Service.
  scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;

  // The logger that plumbs the debug logs to the optimization guide
  // internals page. Not owned. Guaranteed to outlive |this|, since the logger
  // and |this| are owned by the optimization guide keyed service.
  raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;

  // The repeating callback used to determine whether component updates are
  // enabled.
  ComponentUpdatesEnabledProvider component_updates_enabled_provider_;

  // Callback to build Unzipper remotes.
  unzip::UnzipperFactory unzipper_factory_;

  // Time the prediction manager got initialized.
  // TODO(crbug.com/40861855): Remove this old model store once the new model
  // store is launched.
  base::TimeTicks init_time_;

  // The timer used to schedule prediction model fetches.
  PredictionModelFetchTimer prediction_model_fetch_timer_
      GUARDED_BY_CONTEXT(sequence_checker_);

  // Whether the profile for this PredictionManager is off the record.
  bool off_the_record_ = false;

  // The locale of the application.
  std::string application_locale_;

  // Model cache key for the profile.
  proto::ModelCacheKey model_cache_key_;

  // The path to the directory containing the models.
  base::FilePath models_dir_path_;

  // Whether to check for Google API key configuration.
  bool should_check_google_api_key_configuration_ = true;

  SEQUENCE_CHECKER(sequence_checker_);

  // Used to get |weak_ptr_| to self on the UI thread.
  base::WeakPtrFactory<PredictionManager> ui_weak_ptr_factory_{this};
};

}  // namespace optimization_guide

#endif  // COMPONENTS_OPTIMIZATION_GUIDE_CORE_DELIVERY_PREDICTION_MANAGER_H_