File: segmentation_platform_service.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (162 lines) | stat: -rw-r--r-- 7,465 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_SEGMENTATION_PLATFORM_SERVICE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_SEGMENTATION_PLATFORM_SERVICE_H_

#include <string>

#include "base/functional/callback.h"
#include "base/metrics/histogram_base.h"
#include "base/supports_user_data.h"
#include "base/types/id_type.h"
#include "build/build_config.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/segmentation_platform/public/database_client.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/trigger.h"
#include "services/metrics/public/cpp/ukm_source_id.h"

#if BUILDFLAG(IS_ANDROID)
#include "base/android/jni_android.h"
#endif  // BUILDFLAG(IS_ANDROID)

class PrefRegistrySimple;

namespace segmentation_platform {
class ServiceProxy;
struct SegmentSelectionResult;

using CallbackId = base::IdType32<class OnDemandSegmentSelectionCallbackTag>;

// Structure used to store data for training output collection.
// The name should be UMA histogram name or user action name. Currently this
// output will be appended to the training data, but the histogram name is not
// recorded. So, each model can only take one type of metric. Using 2 different
// metrics for same model would make it unclear what the value means while
// training.
struct TrainingLabels {
  TrainingLabels();
  ~TrainingLabels();

  // Name and sample of the output metric to be collected as training data.
  std::optional<std::pair<std::string, base::HistogramBase::Sample32>>
      output_metric;

  TrainingLabels(const TrainingLabels& other);
};

// The core class of segmentation platform that integrates all the required
// pieces on the client side.
class SegmentationPlatformService : public KeyedService,
                                    public base::SupportsUserData {
 public:
#if BUILDFLAG(IS_ANDROID)
  // Returns a Java object of the type SegmentationPlatformService for the given
  // SegmentationPlatformService.
  static base::android::ScopedJavaLocalRef<jobject> GetJavaObject(
      SegmentationPlatformService* segmentation_platform_service);
#endif  // BUILDFLAG(IS_ANDROID)
  using SuccessCallback = base::OnceCallback<void(bool)>;

  SegmentationPlatformService() = default;
  ~SegmentationPlatformService() override = default;

  // Disallow copy/assign.
  SegmentationPlatformService(const SegmentationPlatformService&) = delete;
  SegmentationPlatformService& operator=(const SegmentationPlatformService&) =
      delete;

  // Registers preferences used by this class in the provided |registry|.  This
  // should be called for the Profile registry.
  static void RegisterProfilePrefs(PrefRegistrySimple* registry);

  // Registers preferences used by this class in the provided |registry|.  This
  // should be called for the local state registry.
  static void RegisterLocalStatePrefs(PrefRegistrySimple* registry);

  using SegmentSelectionCallback =
      base::OnceCallback<void(const SegmentSelectionResult&)>;

  // Called to get the selected segment asynchronously. If none, returns empty
  // result.
  virtual void GetSelectedSegment(const std::string& segmentation_key,
                                  SegmentSelectionCallback callback) = 0;

  // Called to get the classification results for a given client. The
  // classification config must be defined in the associated model metadata.
  // Depending on the options and client config, it either runs the associated
  // model or uses unexpired cached results.
  virtual void GetClassificationResult(
      const std::string& segmentation_key,
      const PredictionOptions& prediction_options,
      scoped_refptr<InputContext> input_context,
      ClassificationResultCallback callback) = 0;

  // Get the result from the model execution, annotated with output config to
  // interpret the results. Depending on the options and client config, it
  // either runs the associated model or uses unexpired cached results. This API
  // is experimental and does not cleanly support transitions from a heuristic
  // to ML models. This API is not usable for most ML models since ML models
  // require normalization of the output values to make them usable.
  virtual void GetAnnotatedNumericResult(
      const std::string& segmentation_key,
      const PredictionOptions& prediction_options,
      scoped_refptr<InputContext> input_context,
      AnnotatedNumericResultCallback callback) = 0;

  // Called to get the selected segment synchronously. If none, returns empty
  // result.
  virtual SegmentSelectionResult GetCachedSegmentResult(
      const std::string& segmentation_key) = 0;

  // Get the set of input keys required for the model execution for
  // `segmentation_key`.
  virtual void GetInputKeysForModel(const std::string& segmentation_key,
                                    InputContextKeysCallback callback) = 0;

  // Called to trigger training data collection for a given request ID. Request
  // IDs are given when |GetClassificationResult| is called. `param` is used to
  // pass one additional output feature to be uploaded as training data. It is
  // recommended that the additional feature is also recorded as UMA histogram.
  // Optionally set `ukm_source_id` to attach the training data to the right
  // URL. The source ID should be created by the caller. If the ID is invalid,
  // the data will be uploaded with a no-URL UKM source.
  virtual void CollectTrainingData(proto::SegmentId segment_id,
                                   TrainingRequestId request_id,
                                   const TrainingLabels& param,
                                   SuccessCallback callback) = 0;
  virtual void CollectTrainingData(proto::SegmentId segment_id,
                                   TrainingRequestId request_id,
                                   ukm::SourceId ukm_source_id,
                                   const TrainingLabels& param,
                                   SuccessCallback callback) = 0;

  // Called to enable or disable metrics collection. Must be explicitly called
  // on startup.
  virtual void EnableMetrics(bool signal_collection_allowed) = 0;

  // Called to get the proxy that is used for debugging purpose.
  virtual ServiceProxy* GetServiceProxy();

  // Get access to the segmentation databases using the client.
  // WARNING: This will return nullptr till `IsPlatformInitialized()` is false.
  // You can observe ServiceProxy to get notified when platform is initialized.
  // TODO(ssid): Remove the initialization requirement by handling waiting for
  // init internally.
  // TODO(ssid): Add a Java version of this API.
  virtual DatabaseClient* GetDatabaseClient();

  // Returns true when platform finished initializing, and can execute models.
  // The `GetSelectedSegment()` calls work without full platform initialization
  // since they load results from previous sessions.
  virtual bool IsPlatformInitialized() = 0;
};

}  // namespace segmentation_platform

#endif  // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_SEGMENTATION_PLATFORM_SERVICE_H_