File: ai_model_download_progress_manager.h

package info (click to toggle)
chromium 140.0.7339.127-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,192,880 kB
  • sloc: cpp: 35,093,808; ansic: 7,161,670; javascript: 4,199,694; python: 1,441,797; asm: 949,904; xml: 747,503; pascal: 187,748; perl: 88,691; sh: 88,248; objc: 79,953; sql: 52,714; cs: 44,599; fortran: 24,137; makefile: 22,114; tcl: 15,277; php: 13,980; yacc: 9,000; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (163 lines) | stat: -rw-r--r-- 5,561 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROME_BROWSER_AI_AI_MODEL_DOWNLOAD_PROGRESS_MANAGER_H_
#define CHROME_BROWSER_AI_AI_MODEL_DOWNLOAD_PROGRESS_MANAGER_H_

#include <cstdint>
#include <map>
#include <memory>
#include <optional>

#include "base/check.h"
#include "base/containers/flat_set.h"
#include "base/containers/unique_ptr_adapters.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/raw_ref.h"
#include "base/memory/weak_ptr.h"
#include "base/scoped_observation.h"
#include "base/time/time.h"
#include "base/types/id_type.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom.h"

namespace on_device_ai {

// Manages a set of `ModelDownloadProgressObserver`s and sends them download
// progress updates for their respective components.
//
// Each `AddObserver()` call hands an observer remote and a set of components
// to a private `Reporter`. The manager owns every `Reporter` (in
// `reporters_`), and each `Reporter` owns its components and its observer
// remote.
class AIModelDownloadProgressManager {
 public:
  // A component can be implemented to report progress for any resource or
  // operation. Once handed to `AIModelDownloadProgressManager` via
  // `AddObserver()`, its progress updates are reported to the corresponding
  // `ModelDownloadProgressObserver`. Implementers report progress through the
  // protected `SetDownloadedBytes()` and `SetTotalBytes()` setters.
  class Component {
   public:
    Component();
    virtual ~Component();

    // Move only: declaring the move operations implicitly deletes the copy
    // constructor and copy assignment operator.
    Component(Component&&);
    Component& operator=(Component&&) = default;

   protected:
    // The implementer calls this whenever the number of downloaded bytes
    // changes. Downloaded bytes should only ever increase monotonically.
    void SetDownloadedBytes(int64_t downloaded_bytes);

    // The implementer calls this once the total number of bytes has been
    // determined. Total bytes should never change after it has been
    // determined.
    void SetTotalBytes(int64_t total_bytes);

   private:
    // Grants `AIModelDownloadProgressManager` access to the private members
    // below, e.g. so it can install the event callback.
    friend AIModelDownloadProgressManager;

    // Callback type run on progress events. It receives a `Component&`
    // (presumably the component whose bytes changed — confirm in the .cc).
    using EventCallback = base::RepeatingCallback<void(Component&)>;

    // Returns the downloaded bytes. CHECK-fails unless `determined_bytes()`
    // is true.
    int64_t downloaded_bytes() const {
      CHECK(determined_bytes());
      return downloaded_bytes_.value();
    }

    // Returns the total bytes. CHECK-fails unless `determined_bytes()` is
    // true.
    int64_t total_bytes() const {
      CHECK(determined_bytes());
      return total_bytes_.value();
    }

    // True if both total and downloaded bytes are determined and they are
    // equal to each other.
    bool is_complete() const {
      return determined_bytes() &&
             (total_bytes_.value() == downloaded_bytes_.value());
    }

    // Returns true if both total and downloaded bytes are determined.
    bool determined_bytes() const {
      return downloaded_bytes_.has_value() && total_bytes_.has_value();
    }

    // `AIModelDownloadProgressManager` sets the event callback.
    void SetEventCallback(EventCallback event_callback);
    // Runs `event_callback_`; judging by the name, presumably only when one
    // has been set (confirm in the .cc).
    void MaybeRunEventCallback();

    // Both start out empty (std::nullopt); presumably filled in by
    // `SetDownloadedBytes()` and `SetTotalBytes()` respectively.
    std::optional<int64_t> downloaded_bytes_;
    std::optional<int64_t> total_bytes_;

    // Called any time `SetDownloadedBytes()` or `SetTotalBytes()` is called.
    EventCallback event_callback_;
  };

  AIModelDownloadProgressManager();
  ~AIModelDownloadProgressManager();

  // Not copyable or movable.
  AIModelDownloadProgressManager(const AIModelDownloadProgressManager&) =
      delete;
  AIModelDownloadProgressManager& operator=(
      const AIModelDownloadProgressManager&) = delete;

  // Adds a `ModelDownloadProgressObserver` that will receive progress updates
  // for `components`. Ownership of both `observer_remote` and `components` is
  // transferred to this manager.
  void AddObserver(
      mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
          observer_remote,
      base::flat_set<std::unique_ptr<Component>> components);

  // Returns the number of reporters this manager currently holds (presumably
  // `reporters_.size()`).
  // NOTE(review): likely could be `const`; that requires updating the .cc
  // definition as well.
  int GetNumberOfReporters();

 private:
  // Observes progress updates from `components`, filters and processes them,
  // and reports the result to `observer_remote`.
  class Reporter {
   public:
    // `manager` owns the new `Reporter` and therefore outlives it. Takes
    // ownership of `observer_remote` and `components`.
    Reporter(AIModelDownloadProgressManager& manager,
             mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
                 observer_remote,
             base::flat_set<std::unique_ptr<Component>> components);
    ~Reporter();

    // Not copyable or movable.
    Reporter(const Reporter&) = delete;
    Reporter& operator=(const Reporter&) = delete;

    // Handles a progress event from `component`; its `Component&` parameter
    // matches `Component::EventCallback` (presumably bound as that callback).
    void OnEvent(Component& component);

   private:
    // Presumably the disconnect handler for `observer_remote_`.
    void OnRemoteDisconnect();
    // Processes a progress event from `component` (see the .cc for details).
    void ProcessEvent(const Component& component);
    // Presumably the downloaded-byte count to report, derived from the state
    // below — confirm in the .cc.
    int64_t GetDownloadedBytes();

    // `manager_` owns `this`.
    base::raw_ref<AIModelDownloadProgressManager> manager_;

    mojo::Remote<blink::mojom::ModelDownloadProgressObserver> observer_remote_;

    // The components we're reporting the progress for.
    base::flat_set<std::unique_ptr<Component>> components_;

    // Map of the components to their observed downloaded bytes. Also serves as
    // a way to keep track of which components we've observed the total bytes
    // of.
    //
    // The `raw_ptr` keys are safe: `this` owns every `Component` through
    // `components_`, which is declared before this map and is therefore
    // destroyed after it.
    std::map<raw_ptr<const Component>, int64_t> observed_downloaded_bytes_;

    // Sum of all observed components' total bytes.
    int64_t components_total_bytes_ = 0;

    // The bytes already downloaded before the components' total bytes were
    // determined (see `components_total_bytes_` and `ready_to_report_`).
    int64_t already_downloaded_bytes_ = 0;

    // True once we know the total bytes of the components we're watching,
    // meaning we can start reporting.
    bool ready_to_report_ = false;

    // The last progress value reported to `observer_remote_` and the time of
    // that report; presumably used to limit how often updates are sent.
    int last_reported_progress_ = 0;
    base::TimeTicks last_progress_time_;

    // Kept as the last member so that weak pointers are invalidated before
    // any other member is destroyed.
    base::WeakPtrFactory<Reporter> weak_ptr_factory_{this};
  };

  // Presumably erases (and thereby destroys) `reporter` from `reporters_`.
  void RemoveReporter(Reporter* reporter);

  // Owns every `Reporter`. `base::UniquePtrComparator` makes the set's
  // comparator transparent, so it can be searched with a raw `Reporter*`.
  base::flat_set<std::unique_ptr<Reporter>, base::UniquePtrComparator>
      reporters_;
};
}  // namespace on_device_ai

#endif  // CHROME_BROWSER_AI_AI_MODEL_DOWNLOAD_PROGRESS_MANAGER_H_