1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
|
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/embedder/model_provider_factory_impl.h"
#include "base/test/task_environment.h"
#include "base/test/test_simple_task_runner.h"
#include "components/optimization_guide/core/delivery/test_optimization_guide_model_provider.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace segmentation_platform {
// Test fixture that builds a ModelProviderFactoryImpl backed by a
// TestOptimizationGuideModelProvider and a manually-driven
// TestSimpleTaskRunner.
class ModelProviderFactoryImplTest : public testing::Test {
 public:
  ModelProviderFactoryImplTest() = default;
  ~ModelProviderFactoryImplTest() override = default;

  void SetUp() override {
    task_runner_ = base::MakeRefCounted<base::TestSimpleTaskRunner>();
    model_provider_ = std::make_unique<
        optimization_guide::TestOptimizationGuideModelProvider>();
    provider_factory_ = std::make_unique<ModelProviderFactoryImpl>(
        model_provider_.get(), configs_, task_runner_);
  }

  void TearDown() override {
    // Flush whatever was posted to the test task runner during the test
    // before tearing down the factory and the model provider it points to.
    task_runner_->RunPendingTasks();
    // The factory is handed a raw pointer to |model_provider_| in SetUp(), so
    // it is destroyed first.
    provider_factory_.reset();
    model_provider_.reset();
  }

 protected:
  base::test::TaskEnvironment task_environment_;
  // Task runner handed to the factory; tasks only run when explicitly pumped.
  scoped_refptr<base::TestSimpleTaskRunner> task_runner_;
  // Model provider passed (as a raw pointer) to the factory. Subclasses may
  // leave this null.
  std::unique_ptr<optimization_guide::TestOptimizationGuideModelProvider>
      model_provider_;
  // Configs passed to the factory. Left empty here.
  // TODO(ssid): Fix test to take real configs.
  std::vector<std::unique_ptr<Config>> configs_;
  // Declared after |configs_| so it is destroyed before the configs it was
  // constructed with.
  std::unique_ptr<ModelProviderFactoryImpl> provider_factory_;
};
// With a (test) optimization guide model provider available, the factory
// should hand back a truthy (non-null) provider for each listed segment.
TEST_F(ModelProviderFactoryImplTest, ProviderCreated) {
  const proto::SegmentId kSegmentIds[] = {
      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE,
      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE,
  };
  for (proto::SegmentId segment_id : kSegmentIds) {
    // The returned provider is a temporary, discarded right after the check.
    EXPECT_TRUE(provider_factory_->CreateProvider(segment_id));
  }
}
class DummyModelProviderFactoryImplTest : public ModelProviderFactoryImplTest {
public:
void SetUp() override {
task_runner_ = base::MakeRefCounted<base::TestSimpleTaskRunner>();
std::vector<std::unique_ptr<Config>> configs;
provider_factory_ = std::make_unique<ModelProviderFactoryImpl>(
nullptr, configs, task_runner_);
}
};
// Without an optimization guide model provider, the factory still returns
// providers, but they report no model and execution yields no output.
TEST_F(DummyModelProviderFactoryImplTest, ProviderCreated) {
  EXPECT_TRUE(provider_factory_->CreateProvider(
      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE));

  auto share_provider = provider_factory_->CreateProvider(
      proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
  ASSERT_TRUE(share_provider);
  EXPECT_FALSE(share_provider->ModelAvailable());

  // The update callback is expected never to fire here, so a null callback
  // is passed on purpose: if the provider ever invoked it, the test would
  // crash.
  share_provider->InitAndFetchModel(ModelProvider::ModelUpdatedCallback());

  base::RunLoop run_loop;
  auto on_executed = base::BindOnce(
      [](base::OnceClosure quit_closure,
         const std::optional<ModelProvider::Response>& result) {
        // No model is available, so execution must produce no result.
        EXPECT_FALSE(result);
        std::move(quit_closure).Run();
      },
      run_loop.QuitClosure());
  share_provider->ExecuteModelWithInput({1, 2.5}, std::move(on_executed));
  run_loop.Run();
}
} // namespace segmentation_platform
|