File: test_segment_info_database.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (113 lines) | stat: -rw-r--r-- 4,836 bytes parent folder | download | duplicates (9)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_

#include <stdint.h>

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "base/containers/flat_set.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/database/ukm_types.h"
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/public/proto/aggregation.pb.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"

namespace segmentation_platform::test {

// A fake database with sample entries that can be used for tests.
//
// The class overrides the SegmentInfoDatabase entry points declared below and
// adds a set of test-only helpers for populating segments. All member
// functions are declared here and defined out of line.
//
// TODO(b/285912101): Remove this class and migrate its callers to use a mock
// version of SegmentInfoDatabase.
class TestSegmentInfoDatabase : public SegmentInfoDatabase {
 public:
  TestSegmentInfoDatabase();
  ~TestSegmentInfoDatabase() override;

  // SegmentInfoDatabase overrides.
  // NOTE(review): these are presumably served from the in-memory
  // |segment_infos_| list rather than a real backing store — confirm against
  // the definitions.
  void Initialize(SuccessCallback callback) override;
  void GetSegmentInfoForSegments(const base::flat_set<SegmentId>& segment_ids,
                                 MultipleSegmentInfoCallback callback) override;
  std::unique_ptr<SegmentInfoDatabase::SegmentInfoList>
  GetSegmentInfoForBothModels(
      const base::flat_set<SegmentId>& segment_ids) override;
  const SegmentInfo* GetCachedSegmentInfo(SegmentId segment_id,
                                          ModelSource model_source) override;
  // |segment_info| is optional; NOTE(review): an empty optional presumably
  // requests deletion of the entry — confirm against the base class contract.
  void UpdateSegment(SegmentId segment_id,
                     ModelSource model_source,
                     std::optional<proto::SegmentInfo> segment_info,
                     SuccessCallback callback) override;
  void SaveSegmentResult(SegmentId segment_id,
                         ModelSource model_source,
                         std::optional<proto::PredictionResult> result,
                         SuccessCallback callback) override;
  void SaveTrainingData(SegmentId segment_id,
                        ModelSource model_source,
                        const proto::TrainingData& data,
                        SuccessCallback callback) override;
  void GetTrainingData(SegmentId segment_id,
                       ModelSource model_source,
                       TrainingRequestId request_id,
                       bool delete_from_db,
                       TrainingDataCallback callback) override;

  // Test helper methods.
  //
  // Every helper takes the |segment_id| to operate on and a trailing
  // |model_source| that defaults to ModelSource::SERVER_MODEL_SOURCE.
  // NOTE(review): the helpers presumably create the segment on demand via
  // FindOrCreateSegment() — confirm in the definitions.
  void AddUserActionFeature(
      SegmentId segment_id,
      const std::string& user_action,
      uint64_t bucket_count,
      uint64_t tensor_length,
      proto::Aggregation aggregation,
      ModelSource model_source = ModelSource::SERVER_MODEL_SOURCE);
  void AddHistogramValueFeature(
      SegmentId segment_id,
      const std::string& histogram,
      uint64_t bucket_count,
      uint64_t tensor_length,
      proto::Aggregation aggregation,
      ModelSource model_source = ModelSource::SERVER_MODEL_SOURCE);
  void AddHistogramEnumFeature(
      SegmentId segment_id,
      const std::string& histogram_name,
      uint64_t bucket_count,
      uint64_t tensor_length,
      proto::Aggregation aggregation,
      const std::vector<int32_t>& accepted_enum_ids,
      ModelSource model_source = ModelSource::SERVER_MODEL_SOURCE);
  void AddSqlFeature(
      SegmentId segment_id,
      const MetadataWriter::SqlFeature& feature,
      ModelSource model_source = ModelSource::SERVER_MODEL_SOURCE);
  void AddPredictionResult(
      SegmentId segment_id,
      float score,
      base::Time timestamp,
      ModelSource model_source = ModelSource::SERVER_MODEL_SOURCE);
  // |mappings| is an array whose rows each hold two floats.
  // NOTE(review): |num_pairs| is presumably the number of rows in |mappings|
  // — confirm in the definition.
  void AddDiscreteMapping(
      SegmentId segment_id,
      const float mappings[][2],
      int num_pairs,
      const std::string& discrete_mapping_key,
      ModelSource model_source = ModelSource::SERVER_MODEL_SOURCE);
  void SetBucketDuration(
      SegmentId segment_id,
      uint64_t bucket_duration,
      proto::TimeUnit time_unit,
      ModelSource model_source = ModelSource::SERVER_MODEL_SOURCE);

  // Finds a segment with given |segment_id| and |model_source|. Creates one if
  // it doesn't exist. By default the |model_source| corresponds to server
  // model.
  proto::SegmentInfo* FindOrCreateSegment(
      SegmentId segment_id,
      ModelSource model_source = ModelSource::SERVER_MODEL_SOURCE);

 private:
  // Entries held by this fake, stored as (segment id, segment info) pairs.
  std::vector<std::pair<SegmentId, proto::SegmentInfo>> segment_infos_;
};

}  // namespace segmentation_platform::test

#endif  // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_