File: default_model_test_base.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (93 lines) | stat: -rw-r--r-- 3,863 bytes parent folder | download | duplicates (9)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_EMBEDDER_DEFAULT_MODEL_DEFAULT_MODEL_TEST_BASE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_EMBEDDER_DEFAULT_MODEL_DEFAULT_MODEL_TEST_BASE_H_

#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "base/run_loop.h"
#include "base/test/task_environment.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace segmentation_platform {

// Base fixture for default model unit tests. It contains the common setup
// needed for writing such tests.
//
// Usage: derive the default model's unit test class from this class instead
// of repeating the whole setup, and write only the tests in the derived class.
class DefaultModelTestBase : public testing::Test {
 public:
  // Takes ownership of `model_provider`.
  explicit DefaultModelTestBase(
      std::unique_ptr<DefaultModelProvider> model_provider);
  ~DefaultModelTestBase() override;

  // testing::Test:
  void SetUp() override;

  void TearDown() override;

 protected:
  // NOTE(review): defined out of line; presumably initializes the model and
  // fetches its metadata into `fetched_metadata_` -- confirm against the .cc.
  void ExpectInitAndFetchModel();

  // Executes the model with the given set of inputs and checks the outcome:
  // 1. If `expected_error` is true, an error is expected and hence the model
  //    won't have any result.
  // 2. Otherwise `expected_result` is checked against the actual result given
  //    by the model after executing.
  void ExpectExecutionWithInput(const ModelProvider::Request& inputs,
                                bool expected_error,
                                ModelProvider::Response expected_result);

  // Executes the model with `inputs` and returns the output.
  std::optional<ModelProvider::Response> ExecuteWithInput(
      const ModelProvider::Request& inputs);

  // Executes the model with `input`, applies the classifier and checks the
  // result against `expected_ordered_labels`.
  void ExpectClassifierResults(
      const ModelProvider::Request& input,
      const std::vector<std::string>& expected_ordered_labels);

  // Executes the model with `inputs` and checks that the first score of the
  // response maps to the subsegment named `sub_segment_name`.
  //
  // `sub_segment_key` is the combination of `segmentation_key` +
  // `kSubSegmentDiscreteMappingSuffix`; use `GetSubsegmentKey()` from
  // constants.h. `sub_segment_name` is the name of the segment expected to be
  // returned as result from model execution. `T` indicates the segment class
  // for which the subsegment is evaluated based on the inputs; it must provide
  // a static `GetSubsegmentName()`.
  //
  // Fails the calling test (and returns early) if the execution produced no
  // result, produced an empty result, or if `fetched_metadata_` is not set.
  template <typename T>
  void ExecuteWithInputAndCheckSubsegmentName(
      const ModelProvider::Request& inputs,
      std::string sub_segment_key,
      std::string sub_segment_name) {
    std::optional<ModelProvider::Response> result =
        DefaultModelTestBase::ExecuteWithInput(inputs);
    ASSERT_TRUE(result);
    // Guard the accesses below: indexing an empty response or dereferencing
    // an empty optional is undefined behavior rather than a test failure.
    ASSERT_TRUE(!result->empty()) << "Model execution returned no scores.";
    ASSERT_TRUE(fetched_metadata_.has_value())
        << "fetched_metadata_ is not set.";
    EXPECT_EQ(sub_segment_name,
              T::GetSubsegmentName(metadata_utils::ConvertToDiscreteScore(
                  sub_segment_key, result.value()[0], *fetched_metadata_)));
  }

  base::test::TaskEnvironment task_environment_;
  // The default model under test.
  std::unique_ptr<DefaultModelProvider> model_;
  // Metadata of the model; empty until it has been fetched.
  std::optional<proto::SegmentationModelMetadata> fetched_metadata_;

 private:
  // NOTE(review): presumably the completion callback of
  // `ExpectExecutionWithInput()`, with `closure` run once the checks are done
  // -- confirm against the .cc.
  void OnFinishedExpectExecutionWithInput(
      base::RepeatingClosure closure,
      bool expected_error,
      ModelProvider::Response expected_result,
      const std::optional<ModelProvider::Response>& result);

  // NOTE(review): presumably the completion callback of `ExecuteWithInput()`,
  // storing `result` into `output` before running `closure` -- confirm against
  // the .cc.
  void OnFinishedExecuteWithInput(
      base::RepeatingClosure closure,
      std::optional<ModelProvider::Response>* output,
      const std::optional<ModelProvider::Response>& result);
};

}  // namespace segmentation_platform

#endif  // COMPONENTS_SEGMENTATION_PLATFORM_EMBEDDER_DEFAULT_MODEL_DEFAULT_MODEL_TEST_BASE_H_