1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_AI_AI_MODEL_DOWNLOAD_PROGRESS_MANAGER_H_
#define CHROME_BROWSER_AI_AI_MODEL_DOWNLOAD_PROGRESS_MANAGER_H_
#include <cstdint>
#include <map>
#include <memory>
#include <optional>

#include "base/check.h"
#include "base/containers/flat_set.h"
#include "base/containers/unique_ptr_adapters.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/raw_ref.h"
#include "base/memory/weak_ptr.h"
#include "base/scoped_observation.h"
#include "base/time/time.h"
#include "base/types/id_type.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom.h"
namespace on_device_ai {

// Manages a set of `ModelDownloadProgressObserver`s and sends them download
// progress updates for their respective components.
class AIModelDownloadProgressManager {
 public:
  // A component can be implemented to report progress for any resource or
  // operation. When added to `AIModelDownloadProgressManager` via
  // `AddObserver()`, its progress updates are reported to the associated
  // `ModelDownloadProgressObserver`.
  class Component {
   public:
    Component();
    virtual ~Component();

    // Move only. Copying is deleted explicitly rather than relying on the
    // implicit deletion caused by the user-declared move constructor.
    Component(const Component&) = delete;
    Component& operator=(const Component&) = delete;
    Component(Component&&);
    Component& operator=(Component&&) = default;

   protected:
    // The implementer calls this whenever the downloaded bytes change.
    // Downloaded bytes should only ever increase monotonically.
    void SetDownloadedBytes(int64_t downloaded_bytes);

    // The implementer calls this once the total bytes have been determined.
    // Total bytes should never change after they have been determined.
    void SetTotalBytes(int64_t total_bytes);

   private:
    friend AIModelDownloadProgressManager;

    // Receives the component whose byte counts were updated.
    using EventCallback = base::RepeatingCallback<void(Component&)>;

    // Only call if `determined_bytes()` is true; CHECK-fails otherwise.
    int64_t downloaded_bytes() const {
      CHECK(determined_bytes());
      return downloaded_bytes_.value();
    }

    // Only call if `determined_bytes()` is true; CHECK-fails otherwise.
    int64_t total_bytes() const {
      CHECK(determined_bytes());
      return total_bytes_.value();
    }

    // True if both total and downloaded bytes are determined and they are
    // equal to each other.
    bool is_complete() const {
      return determined_bytes() &&
             (total_bytes_.value() == downloaded_bytes_.value());
    }

    // Returns true if both total and downloaded bytes are determined.
    bool determined_bytes() const {
      return downloaded_bytes_.has_value() && total_bytes_.has_value();
    }

    // `AIModelDownloadProgressManager` sets the event callback.
    void SetEventCallback(EventCallback event_callback);
    void MaybeRunEventCallback();

    // Empty until determined; see `determined_bytes()`.
    std::optional<int64_t> downloaded_bytes_;
    std::optional<int64_t> total_bytes_;

    // Called anytime `SetDownloadedBytes()` or `SetTotalBytes()` is called.
    EventCallback event_callback_;
  };

  AIModelDownloadProgressManager();
  ~AIModelDownloadProgressManager();

  // Not copyable or movable.
  AIModelDownloadProgressManager(const AIModelDownloadProgressManager&) =
      delete;
  AIModelDownloadProgressManager& operator=(
      const AIModelDownloadProgressManager&) = delete;

  // Adds a `ModelDownloadProgressObserver` to send progress updates for
  // `components`. Takes ownership of `components` (passed by value as
  // `unique_ptr`s).
  void AddObserver(
      mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
          observer_remote,
      base::flat_set<std::unique_ptr<Component>> components);

  // NOTE(review): presumably returns the number of live reporters (the size
  // of `reporters_`); confirm against the definition.
  int GetNumberOfReporters();

 private:
  // Observes progress updates from `components`, filters and processes them,
  // and reports the result to `observer_remote`.
  class Reporter {
   public:
    Reporter(AIModelDownloadProgressManager& manager,
             mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
                 observer_remote,
             base::flat_set<std::unique_ptr<Component>> components);
    ~Reporter();

    // Not copyable or movable.
    Reporter(const Reporter&) = delete;
    Reporter& operator=(const Reporter&) = delete;

    // Has the same signature as `Component::EventCallback`; presumably bound
    // as each owned component's event callback — TODO confirm.
    void OnEvent(Component& component);

   private:
    void OnRemoteDisconnect();
    void ProcessEvent(const Component& component);
    int64_t GetDownloadedBytes();

    // `manager_` owns `this`.
    base::raw_ref<AIModelDownloadProgressManager> manager_;
    mojo::Remote<blink::mojom::ModelDownloadProgressObserver> observer_remote_;

    // The components we're reporting the progress for.
    base::flat_set<std::unique_ptr<Component>> components_;

    // Map of the components to their observed downloaded bytes. Also serves as
    // a way to keep track of which components we've observed the total bytes
    // of.
    //
    // `raw_ptr` is safe since `this` owns every `Component` in `components_`,
    // and `components_` is declared before `observed_downloaded_bytes_`, so
    // it is destroyed after this map.
    std::map<raw_ptr<const Component>, int64_t> observed_downloaded_bytes_;

    // Sum of all observed components' total bytes.
    int64_t components_total_bytes_ = 0;

    // The bytes already downloaded before the total bytes were determined.
    int64_t already_downloaded_bytes_ = 0;

    // True if we know the total bytes of the components we'll be watching,
    // meaning we can start reporting.
    bool ready_to_report_ = false;

    // NOTE(review): presumably the last progress value sent to
    // `observer_remote_` and the time it was sent (used for throttling);
    // confirm against the .cc.
    int last_reported_progress_ = 0;
    base::TimeTicks last_progress_time_;

    // Kept as the last member so that weak pointers are invalidated before
    // any other member is destroyed (members are destroyed in reverse
    // declaration order).
    base::WeakPtrFactory<Reporter> weak_ptr_factory_{this};
  };

  // NOTE(review): presumably erases `reporter` from `reporters_`; confirm.
  void RemoveReporter(Reporter* reporter);

  base::flat_set<std::unique_ptr<Reporter>, base::UniquePtrComparator>
      reporters_;
};

}  // namespace on_device_ai
#endif // CHROME_BROWSER_AI_AI_MODEL_DOWNLOAD_PROGRESS_MANAGER_H_
|