File: prediction_model_download_manager.h

package info (click to toggle)
chromium 138.0.7204.183-1~deb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm-proposed-updates
  • size: 6,080,960 kB
  • sloc: cpp: 34,937,079; ansic: 7,176,967; javascript: 4,110,704; python: 1,419,954; asm: 946,768; xml: 739,971; pascal: 187,324; sh: 89,623; perl: 88,663; objc: 79,944; sql: 50,304; cs: 41,786; fortran: 24,137; makefile: 21,811; php: 13,980; tcl: 13,166; yacc: 8,925; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (202 lines) | stat: -rw-r--r-- 7,994 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MODEL_DOWNLOAD_MANAGER_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MODEL_DOWNLOAD_MANAGER_H_

#include <map>
#include <optional>
#include <set>
#include <string>

#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/observer_list.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "base/time/time.h"
#include "components/download/public/background_service/download_params.h"
#include "components/optimization_guide/core/prediction_model_store.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "url/gurl.h"

// Forward declaration: this header only refers to the service through pointers
// (the constructor parameter and |download_service_|).
namespace download {
class BackgroundDownloadService;
}  // namespace download

namespace optimization_guide {

// Forward declarations of types referenced by PredictionModelDownloadManager
// below. PredictionModelDownloadClient is also declared a friend of it.
class PredictionModelDownloadClient;
class PredictionModelDownloadObserver;

namespace proto {
class PredictionModel;
}  // namespace proto

// Declared extern here; the definition lives in another translation unit.
// NOTE(review): judging by its name, this is the custom-data key under which a
// download's optimization target is stored — confirm against the .cc file.
extern const char kPredictionModelOptimizationTargetCustomDataKey[];

// Manages the downloads of prediction models.
class PredictionModelDownloadManager {
 public:
  // Callback that returns the directory to download the model for
  // |optimization_target| into.
  using GetBaseModelDirForDownloadCallback =
      base::RepeatingCallback<base::FilePath(
          proto::OptimizationTarget optimization_target)>;

  // The different states a prediction model download goes through.
  //
  // Keep in sync with OptimizationGuidePredictionModelDownloadState in
  // enums.xml. Because the values are mirrored there, existing entries must
  // not be renumbered or reused.
  enum class PredictionModelDownloadState {
    kUnknown = 0,
    // Model was requested to be downloaded.
    kRequested = 1,
    // Download service started the model download.
    kStarted = 2,

    // Add new values above this line.
    kMaxValue = kStarted,
  };

  // Creates a manager that schedules model downloads with |download_service|,
  // which must outlive |this|. |get_base_model_dir_for_download_callback| is
  // used to get the directory to download models to, and
  // |background_task_runner| is where download file processing is performed.
  PredictionModelDownloadManager(
      download::BackgroundDownloadService* download_service,
      GetBaseModelDirForDownloadCallback
          get_base_model_dir_for_download_callback,
      scoped_refptr<base::SequencedTaskRunner> background_task_runner);
  virtual ~PredictionModelDownloadManager();
  PredictionModelDownloadManager(const PredictionModelDownloadManager&) =
      delete;
  PredictionModelDownloadManager& operator=(
      const PredictionModelDownloadManager&) = delete;

  // Starts a download of |download_url| for |optimization_target|.
  virtual void StartDownload(const GURL& download_url,
                             proto::OptimizationTarget optimization_target);

  // Verifies that |download_file_path| came from a trusted source and
  // processes the downloaded contents. After verification, creates
  // |base_model_dir|. Returns true on success.
  //
  // NOTE(review): |delete_file_on_error| presumably controls whether
  // |download_file_path| is deleted when verification fails — confirm in the
  // implementation.
  //
  // Must be called on a background thread, as it performs file I/O.
  static bool VerifyDownload(const base::FilePath& download_file_path,
                             const base::FilePath& base_model_dir,
                             bool delete_file_on_error);

  // Cancels all pending downloads.
  virtual void CancelAllPendingDownloads();

  // Returns whether the downloader can download models.
  virtual bool IsAvailableForDownloads() const;

  // Adds and removes observers.
  //
  // All methods called on observers will be invoked on the UI thread.
  virtual void AddObserver(PredictionModelDownloadObserver* observer);
  virtual void RemoveObserver(PredictionModelDownloadObserver* observer);

 private:
  friend class PredictionModelDownloadClient;
  friend class PredictionModelDownloadManagerTest;

  // Invoked when the Download Service is ready.
  //
  // |pending_download_guids| is the set of GUIDs that were previously scheduled
  // to be downloaded and have still not been downloaded yet.
  // |successful_downloads| is the map from GUID to the file path that it was
  // successfully downloaded to.
  void OnDownloadServiceReady(
      const std::set<std::string>& pending_download_guids,
      const std::map<std::string, base::FilePath>& successful_downloads);

  // Invoked when the Download Service fails to initialize and should not be
  // used for the session.
  void OnDownloadServiceUnavailable();

  // Invoked when the download identified by |guid| has been accepted and
  // persisted by the BackgroundDownloadService, with |start_result| as the
  // service's result. The download was requested at |download_requested_time|
  // for |optimization_target|.
  void OnDownloadStarted(proto::OptimizationTarget optimization_target,
                         base::TimeTicks download_requested_time,
                         const std::string& guid,
                         download::DownloadParams::StartResult start_result);

  // Invoked when the download specified by |downloaded_guid| succeeded and was
  // written to |download_file_path|. |optimization_target| is optional and may
  // be absent.
  void OnDownloadSucceeded(
      std::optional<proto::OptimizationTarget> optimization_target,
      const std::string& downloaded_guid,
      const base::FilePath& download_file_path);

  // Invoked when the download specified by |failed_download_guid| failed.
  // |optimization_target| is optional and may be absent.
  void OnDownloadFailed(
      std::optional<proto::OptimizationTarget> optimization_target,
      const std::string& failed_download_guid);

  // If the preceding verification step succeeded (|is_verify_success|), starts
  // unzipping the contents of |download_file_path| into |base_model_dir|.
  void StartUnzipping(proto::OptimizationTarget optimization_target,
                      const base::FilePath& download_file_path,
                      const base::FilePath& base_model_dir,
                      bool is_verify_success);

  // Invoked when unzipping the contents of |original_file_path| into
  // |base_model_dir| has finished; |success| reports the outcome.
  void OnDownloadUnzipped(proto::OptimizationTarget optimization_target,
                          const base::FilePath& original_file_path,
                          const base::FilePath& base_model_dir,
                          bool success);

  // Processes the contents in |base_model_dir| and returns the resulting
  // model, or std::nullopt if none could be produced.
  //
  // Must be called on the background thread, as it performs file I/O. It is
  // static so that it never touches |this| while running off the UI thread,
  // avoiding any dependency on this object's lifetime.
  static std::optional<proto::PredictionModel> ProcessUnzippedContents(
      const base::FilePath& base_model_dir);

  // Notifies |observers_| that a model is ready for |optimization_target| in
  // |base_model_dir|.
  //
  // Must be invoked on the UI thread.
  void NotifyModelReady(proto::OptimizationTarget optimization_target,
                        const base::FilePath& base_model_dir,
                        const std::optional<proto::PredictionModel>& model);

  // Notifies |observers_| that a model download failed for
  // |optimization_target|.
  void NotifyModelDownloadFailed(proto::OptimizationTarget optimization_target);

  // The set of GUIDs that are still pending download.
  std::set<std::string> pending_download_guids_;

  // The Download Service to schedule model downloads with.
  //
  // Guaranteed to outlive |this|.
  raw_ptr<download::BackgroundDownloadService> download_service_;

  // Whether the download service is available.
  //
  // NOTE(review): this member has no default initializer; it is presumably set
  // in the constructor's initializer list — confirm, or add an initializer.
  bool is_available_for_downloads_;

  // The API key to attach to download requests.
  std::string api_key_;

  // The set of observers to be notified of completed downloads.
  base::ObserverList<PredictionModelDownloadObserver> observers_;

  // Whether the download should be verified. Should only be false for testing.
  bool should_verify_download_ = true;

  // Callback to get the directory to download models.
  GetBaseModelDirForDownloadCallback get_base_model_dir_for_download_callback_;

  // Background thread where download file processing should be performed.
  scoped_refptr<base::SequencedTaskRunner> background_task_runner_;

  // Sequence checker used to verify all public API methods are called on the
  // UI thread.
  SEQUENCE_CHECKER(sequence_checker_);

  // Used to get weak ptr to self on the UI thread. Declared last so that weak
  // pointers are invalidated before the other members are destroyed.
  base::WeakPtrFactory<PredictionModelDownloadManager> ui_weak_ptr_factory_{
      this};
};

}  // namespace optimization_guide

#endif  // COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MODEL_DOWNLOAD_MANAGER_H_