1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
|
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_TESTING_MOCK_SEGMENTATION_PLATFORM_SERVICE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_TESTING_MOCK_SEGMENTATION_PLATFORM_SERVICE_H_
#include <utility>
#include "base/strings/strcat.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/segment_selection_result.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "components/segmentation_platform/public/trigger.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace segmentation_platform {
class MockSegmentationPlatformService : public SegmentationPlatformService {
public:
MockSegmentationPlatformService() = default;
~MockSegmentationPlatformService() override = default;
MOCK_METHOD(void,
GetSelectedSegment,
(const std::string&, SegmentSelectionCallback));
MOCK_METHOD(SegmentSelectionResult,
GetCachedSegmentResult,
(const std::string&));
MOCK_METHOD(void,
GetClassificationResult,
(const std::string&,
const PredictionOptions&,
scoped_refptr<InputContext>,
ClassificationResultCallback));
MOCK_METHOD(void,
GetAnnotatedNumericResult,
(const std::string&,
const PredictionOptions&,
scoped_refptr<InputContext>,
AnnotatedNumericResultCallback));
MOCK_METHOD(void,
GetInputKeysForModel,
(const std::string& segmentation_key,
InputContextKeysCallback callback));
MOCK_METHOD(void,
CollectTrainingData,
(proto::SegmentId,
TrainingRequestId,
const TrainingLabels&,
SuccessCallback));
MOCK_METHOD(void,
CollectTrainingData,
(proto::SegmentId,
TrainingRequestId,
ukm::SourceId,
const TrainingLabels&,
SuccessCallback));
MOCK_METHOD(void, EnableMetrics, (bool));
MOCK_METHOD(void, GetServiceStatus, ());
MOCK_METHOD(bool, IsPlatformInitialized, ());
MOCK_METHOD(DatabaseClient*, GetDatabaseClient, ());
};
// Matches an InputContext whose `metadata_args` equal `input_context_args`.
// The expected value is first converted to the exact map type stored on
// InputContext, so the comparison uses that map type's equality operator.
// The matcher's description is the printed form of `input_context_args`.
MATCHER_P(IsInputContextWithArgs,
          input_context_args,
          ::testing::PrintToString(input_context_args)) {
  using ExpectedArgsMap =
      base::flat_map<std::string, processing::ProcessedValue>;
  const ExpectedArgsMap expected_args(input_context_args);
  const auto metadata_args_matcher = ::testing::Field(
      &InputContext::metadata_args, ::testing::Eq(expected_args));
  return ::testing::ExplainMatchResult(metadata_args_matcher, arg,
                                       result_listener);
}
// Matches a TrainingLabels whose `output_metric` compares equal to
// std::nullopt, i.e. no output metric label has been set.
MATCHER(TrainingLabelEmpty, "no training labels present") {
  const auto no_output_metric = ::testing::Field(
      &TrainingLabels::output_metric, ::testing::Eq(std::nullopt));
  return ::testing::ExplainMatchResult(no_output_metric, arg,
                                       result_listener);
}
// Matches a TrainingLabels whose `output_metric` equals the pair
// (`histogram_name`, `histogram_value`). The value is compared as
// base::HistogramBase::Sample32.
// The description reads e.g. "histogram with name Foo and value 3". The
// trailing space after "name" keeps the name from running into the preceding
// word in failure messages.
MATCHER_P2(HasTrainingLabel,
           histogram_name,
           histogram_value,
           base::StrCat({"histogram with name ", histogram_name, " and value ",
                         testing::PrintToString(histogram_value)})) {
  return testing::ExplainMatchResult(
      testing::Field(
          &TrainingLabels::output_metric,
          testing::Eq(std::pair<std::string, base::HistogramBase::Sample32>(
              histogram_name, histogram_value))),
      arg, result_listener);
}
} // namespace segmentation_platform
#endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_TESTING_MOCK_SEGMENTATION_PLATFORM_SERVICE_H_
|