1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
|
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/execution/mock_model_provider.h"
#include <utility>
#include "base/containers/contains.h"
#include "base/functional/callback.h"
#include "base/logging.h"
#include "components/segmentation_platform/public/model_provider.h"
namespace segmentation_platform {
namespace {

using ::testing::_;
using ::testing::Invoke;
using ::testing::Return;

// Records `model_updated_callback` in `data->model_providers_callbacks`,
// keyed by `segment_id`. TestModelProviderFactory::CreateProvider binds this
// function (with `segment_id` and the factory's Data pointer) as the client
// callback of every MockModelProvider it creates.
// NOTE(review): if the container is a unique-key map, an entry already present
// for `segment_id` is kept and the new callback is dropped (emplace semantics).
void StoreClientCallback(
    proto::SegmentId segment_id,
    TestModelProviderFactory::Data* data,
    const ModelProvider::ModelUpdatedCallback& model_updated_callback) {
  auto& callbacks_by_segment = data->model_providers_callbacks;
  callbacks_by_segment.emplace(segment_id, model_updated_callback);
}

}  // namespace
// Mock ModelProvider for `segment_id`. By default, InitAndFetchModel() does
// not fetch anything; it hands the caller's ModelUpdatedCallback to
// `get_client_callback`, which lets the owner of that callback capture it.
//
// `get_client_callback` is a sink parameter: it is taken by value and moved
// into the member, avoiding an extra copy of the callback state.
MockModelProvider::MockModelProvider(
    proto::SegmentId segment_id,
    base::RepeatingCallback<void(const ModelProvider::ModelUpdatedCallback&)>
        get_client_callback)
    : ModelProvider(segment_id),
      get_client_callback_(std::move(get_client_callback)) {
  // The default action only needs the member callback, so capture `this`
  // explicitly rather than everything by reference. The action is registered
  // on this mock object's own method, so it does not outlive `this`.
  ON_CALL(*this, InitAndFetchModel(_))
      .WillByDefault(
          Invoke([this](const ModelUpdatedCallback& model_updated_callback) {
            get_client_callback_.Run(model_updated_callback);
          }));
}

MockModelProvider::~MockModelProvider() = default;
// Mock DefaultModelProvider for `segment_id` that keeps its own copy of
// `metadata`. Unless a test overrides it, GetModelConfig() returns a newly
// allocated ModelConfig built from that stored copy.
MockDefaultModelProvider::MockDefaultModelProvider(
    proto::SegmentId segment_id,
    const proto::SegmentationModelMetadata& metadata)
    : DefaultModelProvider(segment_id), metadata_(metadata) {
  // Each call builds a fresh config from the metadata captured at construction.
  // NOTE(review): the literal 1 is presumably the model version expected by
  // ModelConfig -- confirm against its declaration.
  auto make_config_from_stored_metadata = [this]() {
    return std::make_unique<ModelConfig>(metadata_, 1);
  };
  ON_CALL(*this, GetModelConfig())
      .WillByDefault(make_config_from_stored_metadata);
}

MockDefaultModelProvider::~MockDefaultModelProvider() = default;
TestModelProviderFactory::Data::Data() = default;
TestModelProviderFactory::Data::~Data() = default;

// Creates a MockModelProvider for `segment_id`.
//
// The provider's client callback is StoreClientCallback bound to `segment_id`
// and `data_`. By default, the callback passed to InitAndFetchModel() is
// therefore recorded in `data_->model_providers_callbacks`. A non-owning
// pointer to the provider is registered in `data_->model_providers`. The
// caller owns the returned provider, so that registered pointer is only valid
// while the returned object is alive.
std::unique_ptr<ModelProvider> TestModelProviderFactory::CreateProvider(
    proto::SegmentId segment_id) {
  auto record_client_callback =
      base::BindRepeating(&StoreClientCallback, segment_id, data_);
  auto mock_provider = std::make_unique<MockModelProvider>(
      segment_id, std::move(record_client_callback));

  MockModelProvider* const registered_provider = mock_provider.get();
  data_->model_providers.emplace(segment_id, registered_provider);
  return mock_provider;
}
// Creates a MockDefaultModelProvider for `segment_id`, or returns nullptr when
// the segment is not listed in `data_->segments_supporting_default_model`.
//
// The provider is built from `data_->default_provider_metadata[segment_id]`.
// A non-owning pointer to it is registered in `data_->default_model_providers`.
// The caller owns the returned provider.
std::unique_ptr<DefaultModelProvider>
TestModelProviderFactory::CreateDefaultProvider(proto::SegmentId segment_id) {
  const bool supports_default_model =
      base::Contains(data_->segments_supporting_default_model, segment_id);
  if (!supports_default_model)
    return nullptr;

  // A default provider should always come with real segment metadata, but
  // some tests register default providers without any. Instead of failing,
  // warn and install fallback metadata whose only field is a DAY time unit.
  // TODO(ssid): Fix the tests to remove this check.
  auto& metadata_by_segment = data_->default_provider_metadata;
  const bool has_metadata = base::Contains(metadata_by_segment, segment_id);
  if (!has_metadata) {
    LOG(WARNING)
        << "The test should set a valid segment info in "
           "`TestModelProviderFactory::Data.default_provider_metadata` for "
        << proto::SegmentId_Name(segment_id);
    metadata_by_segment[segment_id].set_time_unit(proto::TimeUnit::DAY);
  }

  auto default_provider = std::make_unique<MockDefaultModelProvider>(
      segment_id, metadata_by_segment[segment_id]);
  data_->default_model_providers.emplace(segment_id, default_provider.get());
  return default_provider;
}
} // namespace segmentation_platform
|