// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/embedder/default_model/android_home_module_ranker.h"

#include <array>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

#include "base/feature_list.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/metrics/field_trial_params.h"
#include "base/task/sequenced_task_runner.h"
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/proto/aggregation.pb.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"

namespace segmentation_platform {
namespace {
using proto::SegmentId;
// Default parameters for AndroidHomeModuleRanker model.
constexpr SegmentId kSegmentId =
SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_ANDROID_HOME_MODULE_RANKER;
constexpr int64_t kModelVersion = 9;
// Store 28 buckets of input data (28 days).
constexpr int64_t kSignalStorageLength = 28;
// Do not require a minimum data collection period before the model can run.
constexpr int64_t kMinSignalCollectionLength = 0;
// Cached results expire after 5 minutes.
constexpr int64_t kResultTTLMinutes = 5;
constexpr std::array<const char*, 3> kAndroidHomeModuleLabels = {
kPriceChange, kSingleTab, kSafetyHub};
constexpr std::array<const char*, 3> kAndroidHomeModuleInputContextKeys = {
kPriceChangeFreshness, kSingleTabFreshness, kSafetyHubFreshness};
// InputFeatures.
// Enum values for the MagicStack.Clank.NewTabPage|StartSurface.Module.Click and
// MagicStack.Clank.NewTabPage|StartSurface.Module.TopImpressionV2 histograms.
constexpr std::array<int32_t, 1> kEnumValueForSingleTab{/*SingleTab=*/0};
constexpr std::array<int32_t, 1> kEnumValueForPriceChange{/*PriceChange=*/1};
constexpr std::array<int32_t, 1> kEnumValueForSafetyHub{/*SafetyHub=*/3};
// Set UMA metrics to use as input.
constexpr std::array<MetadataWriter::UMAFeature, 6> kUMAFeatures = {
// Single Tab Module
// 0 : click
MetadataWriter::UMAFeature::FromEnumHistogram(
"MagicStack.Clank.NewTabPage.Module.Click",
28,
kEnumValueForSingleTab.data(),
kEnumValueForSingleTab.size()),
// 1 : impression
MetadataWriter::UMAFeature::FromEnumHistogram(
"MagicStack.Clank.NewTabPage.Module.TopImpressionV2",
28,
kEnumValueForSingleTab.data(),
kEnumValueForSingleTab.size()),
// Price Change Module
// 2 : click
MetadataWriter::UMAFeature::FromEnumHistogram(
"MagicStack.Clank.NewTabPage.Module.Click",
28,
kEnumValueForPriceChange.data(),
kEnumValueForPriceChange.size()),
// 3 : impression
MetadataWriter::UMAFeature::FromEnumHistogram(
"MagicStack.Clank.NewTabPage.Module.TopImpressionV2",
28,
kEnumValueForPriceChange.data(),
kEnumValueForPriceChange.size()),
// Safety Hub Module
// 4 : click
MetadataWriter::UMAFeature::FromEnumHistogram(
"MagicStack.Clank.NewTabPage.Module.Click",
28,
kEnumValueForSafetyHub.data(),
kEnumValueForSafetyHub.size()),
// 5 : impression
MetadataWriter::UMAFeature::FromEnumHistogram(
"MagicStack.Clank.NewTabPage.Module.TopImpressionV2",
28,
kEnumValueForSafetyHub.data(),
kEnumValueForSafetyHub.size()),
};
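// Illustrative summary, derived from kUMAFeatures above and the input context
// features added in GetModelConfig(); not an authoritative spec. The model
// input is expected to be laid out as:
//   [0] SingleTab click      [1] SingleTab impression
//   [2] PriceChange click    [3] PriceChange impression
//   [4] SafetyHub click      [5] SafetyHub impression
// and, only when kSegmentationPlatformAndroidHomeModuleRankerV2 is enabled:
//   [6] SingleTab freshness  [7] PriceChange freshness  [8] SafetyHub freshness
// This gives 6 inputs without V2 and 6 + 3 = 9 inputs with V2.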
// Maps a raw freshness score to a binary signal: 1.0 if the score lies in
// [0, freshness_threshold], 0.0 otherwise (including negative scores).
float TransformFreshness(float freshness_score, float freshness_threshold) {
return (freshness_score >= 0.0f && freshness_score <= freshness_threshold)
? 1.0f
: 0.0f;
}
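// Illustrative values for TransformFreshness() with a threshold of 1.0:
//   TransformFreshness(0.0f, 1.0f)  -> 1.0f  (lower bound is inclusive)
//   TransformFreshness(0.5f, 1.0f)  -> 1.0f
//   TransformFreshness(1.0f, 1.0f)  -> 1.0f  (upper bound is inclusive)
//   TransformFreshness(2.0f, 1.0f)  -> 0.0f  (above the threshold)
//   TransformFreshness(-1.0f, 1.0f) -> 0.0f  (negative scores never count)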
} // namespace
// static
std::unique_ptr<Config> AndroidHomeModuleRanker::GetConfig() {
if (!base::FeatureList::IsEnabled(
features::kSegmentationPlatformAndroidHomeModuleRanker)) {
return nullptr;
}
auto config = std::make_unique<Config>();
config->segmentation_key = kAndroidHomeModuleRankerKey;
config->segmentation_uma_name = kAndroidHomeModuleRankerUmaName;
config->AddSegmentId(kSegmentId, std::make_unique<AndroidHomeModuleRanker>());
config->auto_execute_and_cache = !base::FeatureList::IsEnabled(
features::kSegmentationPlatformAndroidHomeModuleRankerV2);
return config;
}
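// Note (an interpretation of the flags above, not a documented guarantee):
// with kSegmentationPlatformAndroidHomeModuleRankerV2 enabled, the model also
// reads freshness values from the request's input context, which is presumably
// only available when a client asks for a result. auto_execute_and_cache is
// therefore turned off and the model runs on demand. Without V2, every input
// is a UMA histogram, so the result can be computed and cached in advance.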
AndroidHomeModuleRanker::AndroidHomeModuleRanker()
: DefaultModelProvider(kSegmentId),
is_android_home_module_ranker_v2_enabled(base::FeatureList::IsEnabled(
features::kSegmentationPlatformAndroidHomeModuleRankerV2)) {}
std::unique_ptr<DefaultModelProvider::ModelConfig>
AndroidHomeModuleRanker::GetModelConfig() {
proto::SegmentationModelMetadata metadata;
MetadataWriter writer(&metadata);
writer.SetDefaultSegmentationMetadataConfig(kMinSignalCollectionLength,
kSignalStorageLength);
metadata.set_upload_tensors(true);
// Set output config.
writer.AddOutputConfigForMultiClassClassifier(kAndroidHomeModuleLabels,
kAndroidHomeModuleLabels.size(),
/*threshold=*/-99999.0);
writer.AddPredictedResultTTLInOutputConfig(
/*top_label_to_ttl_list=*/{},
/*default_ttl=*/kResultTTLMinutes, proto::TimeUnit::MINUTE);
writer.SetIgnorePreviousModelTTLInOutputConfig();
// Set features.
writer.AddUmaFeatures(kUMAFeatures.data(), kUMAFeatures.size());
if (is_android_home_module_ranker_v2_enabled) {
// Add freshness for all modules as custom input.
writer.AddFromInputContext("single_tab_input", kSingleTabFreshness);
writer.AddFromInputContext("price_change_input", kPriceChangeFreshness);
writer.AddFromInputContext("safety_hub_input", kSafetyHubFreshness);
}
return std::make_unique<ModelConfig>(std::move(metadata), kModelVersion);
}
void AndroidHomeModuleRanker::ExecuteModelWithInput(
const ModelProvider::Request& inputs,
ExecutionCallback callback) {
// Reject requests whose input size does not match the configured features.
size_t expected_input_size =
is_android_home_module_ranker_v2_enabled
? kUMAFeatures.size() + kAndroidHomeModuleInputContextKeys.size()
: kUMAFeatures.size();
if (inputs.size() != expected_input_size) {
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), std::nullopt));
return;
}
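// Score each module as a weighted sum of its engagement (click), impression
// and, when V2 is enabled, binarized freshness inputs.
// Single Tab score calculation.
constexpr float single_tab_weights[3] = {1.5f, -0.5f, 1.0f};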
float single_tab_engagement = inputs[0];
float single_tab_impression = inputs[1];
float single_tab_freshness = is_android_home_module_ranker_v2_enabled
? TransformFreshness(inputs[6], 1.0)
: 0.0;
float single_tab_score = single_tab_weights[0] * single_tab_engagement +
single_tab_weights[1] * single_tab_impression +
single_tab_weights[2] * single_tab_freshness;
// Price Change score calculation.
constexpr float price_change_weights[3] = {2.0f, -1.0f, 2.0f};
float price_change_engagement = inputs[2];
float price_change_impression = inputs[3];
float price_change_freshness = is_android_home_module_ranker_v2_enabled
? TransformFreshness(inputs[7], 1.0)
: 0.0;
float price_change_score = price_change_weights[0] * price_change_engagement +
price_change_weights[1] * price_change_impression +
price_change_weights[2] * price_change_freshness;
// Safety Hub score calculation.
constexpr float safety_hub_weights[3] = {2.5f, -2.0f, 2.5f};
float safety_hub_engagement = inputs[4];
float safety_hub_impression = inputs[5];
float safety_hub_freshness = is_android_home_module_ranker_v2_enabled
? TransformFreshness(inputs[8], 1.0)
: 0.0;
float safety_hub_score = safety_hub_weights[0] * safety_hub_engagement +
safety_hub_weights[1] * safety_hub_impression +
safety_hub_weights[2] * safety_hub_freshness;
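// Worked example with hypothetical inputs (V2 enabled), assuming each UMA
// input is a count over the 28-day window:
//   SingleTab:   clicks 2, impressions 4, raw freshness 0.5 -> 1.0
//                score = 1.5 * 2 + (-0.5) * 4 + 1.0 * 1.0 = 2.0
//   PriceChange: clicks 1, impressions 3, raw freshness 5.0 -> 0.0
//                score = 2.0 * 1 + (-1.0) * 3 + 2.0 * 0.0 = -1.0
//   SafetyHub:   clicks 0, impressions 1, raw freshness 0.0 -> 1.0
//                score = 2.5 * 0 + (-2.0) * 1 + 2.5 * 1.0 = 0.5
// The response below would then be {-1.0, 2.0, 0.5}, so SingleTab would
// receive the highest score.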
ModelProvider::Response response(kAndroidHomeModuleLabels.size(), 0);
// Scores must follow the order of kAndroidHomeModuleLabels:
// {PriceChange, SingleTab, SafetyHub}.
response[0] = price_change_score;  // Price Change
response[1] = single_tab_score;    // Single Tab
response[2] = safety_hub_score;    // Safety Hub
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), std::move(response)));
}
} // namespace segmentation_platform