1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
|
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/passage_embedder_model_observer.h"
#include <memory>
#include "base/memory/raw_ptr.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace passage_embeddings {
// Test double for PassageEmbeddingsServiceController. Every call to
// MaybeUpdateModelInfo() reports, through the TestFuture handed in at
// construction, whether a model info value was supplied.
class FakePassageEmbeddingsServiceController
    : public passage_embeddings::PassageEmbeddingsServiceController {
 public:
  // `model_info_future` is held as a non-owning pointer and is dereferenced on
  // each MaybeUpdateModelInfo() call.
  explicit FakePassageEmbeddingsServiceController(
      base::test::TestFuture<bool>* model_info_future)
      : model_info_received_future_(model_info_future) {}
  ~FakePassageEmbeddingsServiceController() override = default;

  // passage_embeddings::PassageEmbeddingsServiceController:
  //
  // Publishes to the future whether `model_info` holds a value and returns
  // that same flag.
  bool MaybeUpdateModelInfo(
      base::optional_ref<const optimization_guide::ModelInfo> model_info)
      override {
    const bool has_model_info = model_info.has_value();
    model_info_received_future_->SetValue(has_model_info);
    return has_model_info;
  }

  // No-op overrides.
  void MaybeLaunchService() override {}
  void ResetServiceRemote() override {}

 protected:
  // Not owned. Receives the result of each MaybeUpdateModelInfo() call.
  raw_ptr<base::test::TestFuture<bool>> model_info_received_future_;
};
// Model provider test double. It records which optimization target an
// observer registers for, and forwards the currently stored model info (or
// std::nullopt when none is stored) to all registered observers.
class TestOptimizationGuideModelProvider
    : public optimization_guide::TestOptimizationGuideModelProvider {
 public:
  // `target_observed_future` is held as a non-owning pointer; it is set each
  // time an observer is added.
  explicit TestOptimizationGuideModelProvider(
      base::test::TestFuture<bool>* target_observed_future)
      : target_observed_future_(target_observed_future) {}

  // optimization_guide::OptimizationGuideModelProvider:
  //
  // Reports through the future whether `optimization_target` is the passage
  // embedder target, registers `observer`, then notifies every registered
  // observer. `model_metadata` is not used.
  void AddObserverForOptimizationTargetModel(
      optimization_guide::proto::OptimizationTarget optimization_target,
      const std::optional<optimization_guide::proto::Any>& model_metadata,
      optimization_guide::OptimizationTargetModelObserver* observer) override {
    const bool is_passage_embedder_target =
        optimization_target ==
        optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER;
    target_observed_future_->SetValue(is_passage_embedder_target);
    observer_list_.AddObserver(observer);
    NotifyObservers();
  }

  // Unregisters `observer`; `optimization_target` is not consulted.
  void RemoveObserverForOptimizationTargetModel(
      optimization_guide::proto::OptimizationTarget optimization_target,
      optimization_guide::OptimizationTargetModelObserver* observer) override {
    observer_list_.RemoveObserver(observer);
  }

  // Stores `model_info` as the value sent to observers and notifies them
  // right away.
  void SetModelInfo(std::unique_ptr<optimization_guide::ModelInfo> model_info) {
    model_info_ = std::move(model_info);
    NotifyObservers();
  }

 private:
  // Calls OnModelUpdated() on every registered observer for the passage
  // embedder target, passing the stored model info or std::nullopt if there
  // is none.
  void NotifyObservers() {
    constexpr auto kTarget =
        optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER;
    if (!model_info_) {
      observer_list_.Notify(
          &optimization_guide::OptimizationTargetModelObserver::OnModelUpdated,
          kTarget, std::nullopt);
      return;
    }
    observer_list_.Notify(
        &optimization_guide::OptimizationTargetModelObserver::OnModelUpdated,
        kTarget, *model_info_);
  }

  // Not owned. Set on every AddObserverForOptimizationTargetModel() call.
  raw_ptr<base::test::TestFuture<bool>> target_observed_future_;
  base::ObserverList<optimization_guide::OptimizationTargetModelObserver>
      observer_list_;
  // Null until SetModelInfo() is called with a non-null value.
  std::unique_ptr<optimization_guide::ModelInfo> model_info_;
};
// Fixture owning the futures that the test doubles in this file write to.
class PassageEmbedderModelObserverTest : public testing::Test {
 protected:
  // Set by TestOptimizationGuideModelProvider when an observer is added.
  base::test::TestFuture<bool> target_observed_future_;
  // Set by FakePassageEmbeddingsServiceController::MaybeUpdateModelInfo().
  base::test::TestFuture<bool> model_info_received_future_;
};
// Checks that constructing a PassageEmbedderModelObserver registers for the
// passage embedder target, and that model updates from the provider reach the
// service controller's MaybeUpdateModelInfo().
TEST_F(PassageEmbedderModelObserverTest, ObservesTargetAndNotifiesObserver) {
  auto model_provider = std::make_unique<TestOptimizationGuideModelProvider>(
      &target_observed_future_);
  EXPECT_FALSE(target_observed_future_.IsReady());

  auto service_controller =
      std::make_unique<FakePassageEmbeddingsServiceController>(
          &model_info_received_future_);
  EXPECT_FALSE(model_info_received_future_.IsReady());

  auto passage_embedder_model_observer =
      std::make_unique<PassageEmbedderModelObserver>(
          model_provider.get(), service_controller.get(), false);

  // Readiness checks that guard a Take() are fatal: calling Take() on a future
  // that has no value would wait for one instead of reporting a clean
  // failure, and this fixture sets up nothing that could deliver it.
  ASSERT_TRUE(target_observed_future_.IsReady());
  EXPECT_TRUE(target_observed_future_.Take());

  // The provider has no model info yet, so the first notification carries
  // std::nullopt and the controller reports `false`.
  ASSERT_TRUE(model_info_received_future_.IsReady());
  EXPECT_FALSE(model_info_received_future_.Take());

  // Once model info is set, the controller should see a value.
  model_provider->SetModelInfo(GetBuilderWithValidModelInfo().Build());
  ASSERT_TRUE(model_info_received_future_.IsReady());
  EXPECT_TRUE(model_info_received_future_.Take());
}
} // namespace passage_embeddings
|