File: cached_result_provider_unittest.cc

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (121 lines) | stat: -rw-r--r-- 4,671 bytes parent folder | download | duplicates (9)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/database/cached_result_provider.h"

#include "components/prefs/pref_registry_simple.h"
#include "components/prefs/testing_pref_service.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/database/client_result_prefs.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/proto/prediction_result.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using testing::_;
using testing::Invoke;
using testing::Return;
using testing::SaveArg;

namespace segmentation_platform {

// Forward declaration of proto::SegmentInfo. No other line in this file names
// SegmentInfo.
// NOTE(review): the declaration looks removable here -- confirm before
// deleting.
namespace proto {
class SegmentInfo;
}  // namespace proto

namespace {

// Segmentation key used for the test config and for every pref save/lookup in
// this file.
constexpr char kClientKey[] = "test_key";

// Class labels passed to the multi-class classifier output config built by
// GetTestOutputConfigForMultiClassClassifier().
constexpr std::array<const char*, 5> kMultiClassLabels = {
    "Vanilla",
    "Chocolate",
    "Strawberry",
    "Mango",
    "Peach",
};

std::unique_ptr<Config> CreateTestConfig() {
  auto config = std::make_unique<Config>();
  config->segmentation_key = kClientKey;
  config->segmentation_uma_name = "test_key";
  config->segment_selection_ttl = base::Days(28);
  config->unknown_selection_ttl = base::Days(14);
  config->AddSegmentId(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
  return config;
}

// Returns the output config that MetadataWriter produces for a multi-class
// classifier over kMultiClassLabels, with top-k set to the number of labels
// and a threshold of 0.1.
proto::OutputConfig GetTestOutputConfigForMultiClassClassifier() {
  proto::SegmentationModelMetadata metadata;
  MetadataWriter metadata_writer(&metadata);

  metadata_writer.AddOutputConfigForMultiClassClassifier(
      kMultiClassLabels, /*top_k_outputs=*/kMultiClassLabels.size(),
      /*threshold=*/0.1);

  return metadata.output_config();
}

}  // namespace

// Fixture for the CachedResultProvider tests. It owns the testing pref
// service, the ClientResultPrefs built on top of it, and the configs handed to
// the provider. Each test constructs `cached_result_provider_` itself.
class CachedResultProviderTest : public testing::Test {
 public:
  CachedResultProviderTest() = default;
  ~CachedResultProviderTest() override = default;

  // Creates the ClientResultPrefs wrapper, registers the client-result string
  // pref with an empty default value, and adds the single test config.
  void SetUp() override {
    result_prefs_ = std::make_unique<ClientResultPrefs>(&pref_service_);

    auto* const registry = pref_service_.registry();
    registry->RegisterStringPref(kSegmentationClientResultPrefs,
                                 std::string());

    configs_.emplace_back(CreateTestConfig());
  }

 protected:
  // Declaration order is kept as-is: `result_prefs_` is constructed from a
  // pointer to `pref_service_`.
  TestingPrefServiceSimple pref_service_;
  std::unique_ptr<ClientResultPrefs> result_prefs_;
  std::unique_ptr<CachedResultProvider> cached_result_provider_;
  std::vector<std::unique_ptr<Config>> configs_;
};

// With nothing saved to prefs, looking up kClientKey yields no prediction
// result.
TEST_F(CachedResultProviderTest, CachedResultProviderWithEmptyPrefs) {
  cached_result_provider_ =
      std::make_unique<CachedResultProvider>(result_prefs_.get(), configs_);

  const std::optional<proto::PredictionResult> cached_result =
      cached_result_provider_->GetPredictionResultForClient(kClientKey);

  EXPECT_FALSE(cached_result.has_value());
}

// Saves a prediction result to prefs before the provider is constructed, then
// verifies that GetPredictionResultForClient() returns a result with the same
// number of entries and the same value at every index.
TEST_F(CachedResultProviderTest,
       GetPredictionResultForClient_WithNonEmptyResult) {
  std::vector<float> model_scores = {0, 0, 1, 0, 0};
  proto::PredictionResult saved_prediction_result =
      metadata_utils::CreatePredictionResult(
          model_scores, GetTestOutputConfigForMultiClassClassifier(),
          /*timestamp=*/base::Time::Now(), /*model_version=*/1);
  result_prefs_->SaveClientResultToPrefs(
      kClientKey, metadata_utils::CreateClientResultFromPredResult(
                      saved_prediction_result,
                      /*timestamp=*/base::Time::Now()));

  cached_result_provider_ =
      std::make_unique<CachedResultProvider>(result_prefs_.get(), configs_);
  std::optional<proto::PredictionResult> retrieved_prediction_result =
      cached_result_provider_->GetPredictionResultForClient(kClientKey);

  // Fatal: everything below dereferences the optional, so stop here if it is
  // empty instead of continuing into an invalid access.
  ASSERT_TRUE(retrieved_prediction_result.has_value());

  // Fatal: the loop indexes the retrieved result up to the saved size, so the
  // sizes must match before any element is read.
  ASSERT_EQ(saved_prediction_result.result_size(),
            retrieved_prediction_result->result_size());

  for (int i = 0; i < saved_prediction_result.result_size(); ++i) {
    EXPECT_EQ(saved_prediction_result.result(i),
              retrieved_prediction_result->result(i))
        << "mismatch at result index " << i;
  }
}

// A provider built over prefs that hold no saved client result returns an
// empty optional for kClientKey.
TEST_F(CachedResultProviderTest, GetPredictionResultForClient_WithNoResult) {
  cached_result_provider_ =
      std::make_unique<CachedResultProvider>(result_prefs_.get(), configs_);

  const std::optional<proto::PredictionResult> lookup =
      cached_result_provider_->GetPredictionResultForClient(kClientKey);

  EXPECT_FALSE(lookup);
}

}  // namespace segmentation_platform