// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "services/audio/ml_model_manager.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "base/check.h"
#include "base/check_op.h"
#include "base/files/file.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "build/build_config.h"
#include "third_party/tflite/src/tensorflow/lite/model_builder.h"

namespace audio {
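// Holds the TFLite model and the buffer that backs it.
// The buffer must outlive the model, so `buffer` is declared before `model`:
// members are destroyed in reverse declaration order, which destroys `model`
// first.
struct ModelWithBuffer {
  explicit ModelWithBuffer(size_t buffer_size) : buffer(buffer_size) {}
  ~ModelWithBuffer() = default;

  std::vector<uint8_t> buffer;
  std::unique_ptr<tflite::FlatBufferModel> model;
  // Number of outstanding model handles that reference `model`.
  int num_active_clients = 0;
};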
namespace {

// Reads the model contents from `model_file` and builds a FlatBufferModel from
// them. This function performs blocking file I/O and is intended to run on a
// background thread.
// Returns a struct containing the model and its backing buffer, or nullptr on
// failure.
std::unique_ptr<ModelWithBuffer> ReadModelContents(base::File model_file) {
  if (!model_file.IsValid()) {
    LOG(ERROR) << "Invalid model file.";
    return nullptr;
  }
  const int64_t length = model_file.GetLength();
  if (length <= 0) {
    LOG(ERROR) << "Invalid model file length.";
    return nullptr;
  }
  // `length` is known to be positive here, so the cast cannot wrap a negative
  // value.
  auto model_with_buffer =
      std::make_unique<ModelWithBuffer>(static_cast<size_t>(length));
  if (!model_file.ReadAndCheck(0, model_with_buffer->buffer)) {
    LOG(ERROR) << "Failed to read model file contents.";
    return nullptr;
  }
  // The model references `buffer` rather than copying it, which is why both
  // are kept together in ModelWithBuffer.
  model_with_buffer->model = tflite::FlatBufferModel::BuildFromBuffer(
      reinterpret_cast<const char*>(model_with_buffer->buffer.data()),
      model_with_buffer->buffer.size());
  if (!model_with_buffer->model) {
    LOG(ERROR) << "Failed to build FlatBufferModel from buffer.";
    return nullptr;
  }
  return model_with_buffer;
}
class MlModelHandleImpl : public MlModelHandle {
 public:
  MlModelHandleImpl(tflite::FlatBufferModel* model,
                    base::OnceClosure on_destruction_closure)
      : model_(model),
        on_destruction_closure_(std::move(on_destruction_closure)) {
    CHECK(model);
  }

  ~MlModelHandleImpl() override {
    // Explicitly reset the model pointer as the pointed-to memory may be
    // affected by the destruction closure.
    model_ = nullptr;
    std::move(on_destruction_closure_).Run();
  }

  const tflite::FlatBufferModel* Get() override { return model_; }

 private:
  raw_ptr<tflite::FlatBufferModel> model_;
  base::OnceClosure on_destruction_closure_;
};

}  // namespace
MlModelManagerImpl::MlModelManagerImpl() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

MlModelManagerImpl::~MlModelManagerImpl() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CHECK(!used_serving_model_ && retired_models_.empty())
      << "MlModelManagerImpl has existing clients at destruction time";
}

void MlModelManagerImpl::BindReceiver(
    mojo::PendingReceiver<mojom::MlModelManager> receiver) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CHECK(!receiver_.has_value());
  receiver_.emplace(this, std::move(receiver));
}
void MlModelManagerImpl::OnResidualEchoEstimationModelRead(
    std::unique_ptr<ModelWithBuffer> model_with_buffer) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!model_with_buffer) {
    // The model load failed; keep serving whatever model is already in place.
    return;
  }
  if (used_serving_model_) {
    // Clients still hold handles to the old model. Retire it so that it stays
    // alive until the last handle is destroyed.
    retired_models_.emplace(used_serving_model_.get(),
                            std::move(used_serving_model_));
  }
  unused_serving_model_ = std::move(model_with_buffer);
}

void MlModelManagerImpl::StopServingResidualEchoEstimationModel() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Cancel any ongoing model load: this stop signal makes it obsolete.
  CancelModelLoadingTasks();
  if (used_serving_model_) {
    // Existing clients keep the model alive until their handles are destroyed.
    retired_models_.emplace(used_serving_model_.get(),
                            std::move(used_serving_model_));
  }
  unused_serving_model_.reset();
  used_serving_model_.reset();
}

void MlModelManagerImpl::SetResidualEchoEstimationModel(
    base::File tflite_file) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Cancel loading of any older models:
  // - They would soon be replaced by this new file, and
  // - we don't want races due to different file operation durations.
  CancelModelLoadingTasks();
  base::ThreadPool::PostTaskAndReplyWithResult(
      FROM_HERE, {base::MayBlock(), base::TaskPriority::BEST_EFFORT},
      base::BindOnce(&ReadModelContents, std::move(tflite_file)),
      base::BindOnce(&MlModelManagerImpl::OnResidualEchoEstimationModelRead,
                     weak_factory_.GetWeakPtr()));
}
void MlModelManagerImpl::OnModelHandleDestruction(ModelWithBuffer* model) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CHECK(model);
  // Find the model, and update the client count.
  if (model == used_serving_model_.get()) {
    // The model is actively serving.
    CHECK_GT(used_serving_model_->num_active_clients, 0);
    --used_serving_model_->num_active_clients;
    if (used_serving_model_->num_active_clients == 0) {
      // No clients remain, but the model is still the one being served, so
      // keep it around for future requests.
      unused_serving_model_ = std::move(used_serving_model_);
    }
    return;
  }
  // If we get here, the model is one of the retired models.
  auto iter = retired_models_.find(model);
  CHECK(iter != retired_models_.end());
  ModelWithBuffer& retired_model = *iter->second;
  CHECK_GT(retired_model.num_active_clients, 0);
  --retired_model.num_active_clients;
  if (retired_model.num_active_clients == 0) {
    // All clients are gone, the model can be deleted.
    retired_models_.erase(iter);
  }
}
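// Model lifecycle, as implemented in this file:
// - A successfully loaded model is stored in `unused_serving_model_`.
// - GetResidualEchoEstimationModel() moves it to `used_serving_model_` and
//   increments its client count once per returned handle.
// - When the last handle to the serving model is destroyed, the model moves
//   back to `unused_serving_model_`.
// - If a new model is loaded or serving is stopped while handles are
//   outstanding, the old model moves to `retired_models_` and is deleted when
//   its last handle is destroyed.
//
// Usage sketch (illustrative only; `manager` and `model_file` are hypothetical
// names, and the calls are assumed to run on the manager's sequence):
//
//   manager.SetResidualEchoEstimationModel(std::move(model_file));
//   // ... later, once the background read has completed ...
//   std::unique_ptr<MlModelHandle> handle =
//       manager.GetResidualEchoEstimationModel();
//   if (handle) {
//     const tflite::FlatBufferModel* model = handle->Get();
//     // Use `model` while `handle` is alive.
//   }
//   handle.reset();  // Must happen before `manager` is destroyed.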
std::unique_ptr<MlModelHandle>
MlModelManagerImpl::GetResidualEchoEstimationModel() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!unused_serving_model_ && !used_serving_model_) {
    return nullptr;
  }
  if (unused_serving_model_) {
    used_serving_model_ = std::move(unused_serving_model_);
  }
  ++(used_serving_model_->num_active_clients);
  return std::make_unique<MlModelHandleImpl>(
      used_serving_model_->model.get(),
      base::BindOnce(&MlModelManagerImpl::OnModelHandleDestruction,
                     // Safe because the MlModelManager API requires clients to
                     // destroy their model handles within the manager lifetime.
                     base::Unretained(this), used_serving_model_.get()));
}

void MlModelManagerImpl::CancelModelLoadingTasks() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  weak_factory_.InvalidateWeakPtrs();
}

bool MlModelManagerImpl::HasPendingTasksForTesting() const {
  return weak_factory_.HasWeakPtrs();
}

}  // namespace audio