File: passage_embedder_model_observer_unittest.cc

package info (click to toggle)
chromium 138.0.7204.157-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,071,864 kB
  • sloc: cpp: 34,936,859; ansic: 7,176,967; javascript: 4,110,704; python: 1,419,953; asm: 946,768; xml: 739,967; pascal: 187,324; sh: 89,623; perl: 88,663; objc: 79,944; sql: 50,304; cs: 41,786; fortran: 24,137; makefile: 21,806; php: 13,980; tcl: 13,166; yacc: 8,925; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (125 lines) | stat: -rw-r--r-- 4,712 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/passage_embeddings/passage_embedder_model_observer.h"

#include <memory>

#include "base/memory/raw_ptr.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace passage_embeddings {

class FakePassageEmbeddingsServiceController
    : public passage_embeddings::PassageEmbeddingsServiceController {
 public:
  explicit FakePassageEmbeddingsServiceController(
      base::test::TestFuture<bool>* model_info_future)
      : model_info_received_future_(model_info_future) {}
  ~FakePassageEmbeddingsServiceController() override = default;

  // passage_embeddings::PassageEmbeddingsServiceController:
  bool MaybeUpdateModelInfo(
      base::optional_ref<const optimization_guide::ModelInfo> model_info)
      override {
    const bool received_model_info = model_info.has_value();
    model_info_received_future_->SetValue(received_model_info);
    return received_model_info;
  }
  void MaybeLaunchService() override {}
  void ResetServiceRemote() override {}

 protected:
  raw_ptr<base::test::TestFuture<bool>> model_info_received_future_;
};

namespace {

// Model provider test double.
// - When an observer registers, it records through a caller-owned TestFuture
//   whether the registration was for OPTIMIZATION_TARGET_PASSAGE_EMBEDDER.
// - It pushes the currently configured model info (or std::nullopt when none is
//   set) to every registered observer. This happens on each registration and on
//   each SetModelInfo() call.
class TestOptimizationGuideModelProvider
    : public optimization_guide::TestOptimizationGuideModelProvider {
 public:
  // `target_observed_future` is not owned and must outlive this object: it is
  // dereferenced on every AddObserverForOptimizationTargetModel() call.
  explicit TestOptimizationGuideModelProvider(
      base::test::TestFuture<bool>* target_observed_future)
      : target_observed_future_(target_observed_future) {}

  // optimization_guide::OptimizationGuideModelProvider:
  void AddObserverForOptimizationTargetModel(
      optimization_guide::proto::OptimizationTarget optimization_target,
      const std::optional<optimization_guide::proto::Any>& model_metadata,
      optimization_guide::OptimizationTargetModelObserver* observer) override {
    // Report whether the registration targeted the passage embedder.
    target_observed_future_->SetValue(
        optimization_target ==
        optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER);
    observer_list_.AddObserver(observer);
    // Immediately push the current (possibly empty) model state to observers.
    NotifyObservers();
  }
  void RemoveObserverForOptimizationTargetModel(
      optimization_guide::proto::OptimizationTarget optimization_target,
      optimization_guide::OptimizationTargetModelObserver* observer) override {
    observer_list_.RemoveObserver(observer);
  }

  // Sets the model info to be sent to observers and notifies them right away.
  // Passing nullptr makes subsequent notifications carry std::nullopt.
  void SetModelInfo(std::unique_ptr<optimization_guide::ModelInfo> model_info) {
    model_info_ = std::move(model_info);
    NotifyObservers();
  }

 private:
  // Sends OnModelUpdated() for the passage embedder target to every registered
  // observer. It passes the stored model info if one is set, and std::nullopt
  // otherwise.
  void NotifyObservers() {
    if (model_info_) {
      observer_list_.Notify(
          &optimization_guide::OptimizationTargetModelObserver::OnModelUpdated,
          optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER,
          *model_info_);
    } else {
      observer_list_.Notify(
          &optimization_guide::OptimizationTargetModelObserver::OnModelUpdated,
          optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER,
          std::nullopt);
    }
  }

  // Not owned; receives one value per observer registration.
  const raw_ptr<base::test::TestFuture<bool>> target_observed_future_;
  base::ObserverList<optimization_guide::OptimizationTargetModelObserver>
      observer_list_;
  // Model info delivered to observers; null until SetModelInfo() provides one.
  std::unique_ptr<optimization_guide::ModelInfo> model_info_;
};

}  // namespace

class PassageEmbedderModelObserverTest : public testing::Test {
 protected:
  base::test::TestFuture<bool> target_observed_future_;
  base::test::TestFuture<bool> model_info_received_future_;
};

// Verifies that PassageEmbedderModelObserver registers for the passage embedder
// target and forwards model updates (first empty, then valid) to the service
// controller.
TEST_F(PassageEmbedderModelObserverTest, ObservesTargetAndNotifiesObserver) {
  auto model_provider = std::make_unique<TestOptimizationGuideModelProvider>(
      &target_observed_future_);
  // Nothing has registered with the provider yet.
  EXPECT_FALSE(target_observed_future_.IsReady());

  auto service_controller =
      std::make_unique<FakePassageEmbeddingsServiceController>(
          &model_info_received_future_);
  // The controller has not been handed any model info yet.
  EXPECT_FALSE(model_info_received_future_.IsReady());

  // Constructing the observer is expected to register it with the provider.
  // On registration the provider double pushes its still-empty model state,
  // which the observer should forward to the controller.
  // NOTE(review): annotate the trailing `false` with a /*param_name=*/ comment
  // matching PassageEmbedderModelObserver's constructor declaration.
  auto passage_embedder_model_observer =
      std::make_unique<PassageEmbedderModelObserver>(
          model_provider.get(), service_controller.get(), false);

  // Readiness is ASSERTed (not EXPECTed) before each Take(). Take() waits for a
  // value when none is ready, so a missing notification should fail here
  // instead of blocking.
  ASSERT_TRUE(target_observed_future_.IsReady());
  EXPECT_TRUE(target_observed_future_.Take());

  ASSERT_TRUE(model_info_received_future_.IsReady());
  EXPECT_FALSE(model_info_received_future_.Take());

  // Once the provider has a valid model, the controller should receive it.
  model_provider->SetModelInfo(GetBuilderWithValidModelInfo().Build());
  ASSERT_TRUE(model_info_received_future_.IsReady());
  EXPECT_TRUE(model_info_received_future_.Take());
}

}  // namespace passage_embeddings