File: segment_selector_impl.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (162 lines) | stat: -rw-r--r-- 6,504 bytes parent folder | download | duplicates (11)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_SELECTOR_IMPL_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_SELECTOR_IMPL_H_

#include <memory>
#include <utility>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/functional/callback_helpers.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/platform_options.h"
#include "components/segmentation_platform/internal/scheduler/execution_service.h"
#include "components/segmentation_platform/internal/selection/segment_result_provider.h"
#include "components/segmentation_platform/internal/selection/segment_selector.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/segment_selection_result.h"

// Forward declarations for types this header only refers to through pointers,
// so their full definitions are not required here.
class PrefService;

namespace base {
class Clock;
}  // namespace base

namespace segmentation_platform {

// Forward declarations for types that SegmentSelectorImpl only holds through
// raw pointers, raw_ptr<> or std::unique_ptr<> in this header.
struct Config;
class ExperimentalGroupRecorder;
class FieldTrialRegister;
class SegmentationResultPrefs;
class SignalStorageConfig;

class SegmentSelectorImpl : public SegmentSelector {
 public:
  SegmentSelectorImpl(SegmentInfoDatabase* segment_database,
                      SignalStorageConfig* signal_storage_config,
                      PrefService* pref_service,
                      const Config* config,
                      FieldTrialRegister* field_trial_register,
                      base::Clock* clock,
                      const PlatformOptions& platform_options);

  SegmentSelectorImpl(SegmentInfoDatabase* segment_database,
                      SignalStorageConfig* signal_storage_config,
                      std::unique_ptr<SegmentationResultPrefs> prefs,
                      const Config* config,
                      FieldTrialRegister* field_trial_register,
                      base::Clock* clock,
                      const PlatformOptions& platform_options);

  ~SegmentSelectorImpl() override;

  // SegmentSelector overrides.
  void OnPlatformInitialized(ExecutionService* execution_service) override;
  void GetSelectedSegment(SegmentSelectionCallback callback) override;
  SegmentSelectionResult GetCachedSegmentResult() override;

  // Helper function to update the selected segment in the prefs. Auto-extends
  // the selection if the new result is unknown.
  virtual void UpdateSelectedSegment(SegmentId new_selection, float rank);

  // Called whenever a model eval completes. Runs segment selection to find the
  // best segment, and writes it to the pref.
  void OnModelExecutionCompleted(SegmentId segment_id) override;

  void set_segment_result_provider_for_testing(
      std::unique_ptr<SegmentResultProvider> result_provider) {
    segment_result_provider_ = std::move(result_provider);
  }

  void set_training_data_collector_for_testing(
      TrainingDataCollector* training_data_collector) {
    training_data_collector_ = training_data_collector;
  }

 private:
  // For testing.
  friend class SegmentSelectorTest;

  using SegmentRanks = base::flat_map<SegmentId, float>;

  // Determines whether segment selection can be run based on whether the
  // segment selection TTL has expired, or selection is unavailable.
  bool IsPreviousSelectionInvalid();

  // Gets scores for all segments and recomputes selection and stores the result
  // to prefs.
  void SelectSegmentAndStoreToPrefs();

  // Gets ranks for each segment from SegmentResultProvider, and then computes
  // segment selection.
  void GetRankForNextSegment(std::unique_ptr<SegmentRanks> ranks,
                             scoped_refptr<InputContext> input_context,
                             SegmentSelectionCallback callback);

  // Callback used to get result from SegmentResultProvider for each segment.
  void OnGetResultForSegmentSelection(
      std::unique_ptr<SegmentRanks> ranks,
      scoped_refptr<InputContext> input_context,
      SegmentSelectionCallback callback,
      SegmentId current_segment_id,
      std::unique_ptr<SegmentResultProvider::SegmentResult> result);

  void RecordFieldTrials() const;

  // Loops through all segments, performs discrete mapping, honors finch
  // supplied tie-breakers, TTL, inertia etc, and finds the highest rank.
  // Ignores the segments that have no results.
  std::pair<SegmentId, float> FindBestSegment(
      const SegmentRanks& segment_scores);

  // Wrapped result callback for recording metrics.
  void CallbackWrapper(base::Time start_time,
                       SegmentSelectionCallback callback,
                       const SegmentSelectionResult& result);

  std::unique_ptr<SegmentResultProvider> segment_result_provider_;

  // Helper class to read/write results to the prefs.
  std::unique_ptr<SegmentationResultPrefs> result_prefs_;

  // The database storing metadata and results.
  const raw_ptr<SegmentInfoDatabase> segment_database_;

  // The database to determine whether the signal storage requirements are met.
  const raw_ptr<SignalStorageConfig> signal_storage_config_;

  // The config for providing configuration params.
  const raw_ptr<const Config, DanglingUntriaged> config_;

  // Delegate that records selected segments to metrics.
  const raw_ptr<FieldTrialRegister> field_trial_register_;

  // Helper that records segmentation subgroups to `field_trial_register_`. Once
  // for each segment in the `config_`.
  std::vector<std::unique_ptr<ExperimentalGroupRecorder>>
      experimental_group_recorder_;

  // The time provider.
  const raw_ptr<base::Clock> clock_;

  const PlatformOptions platform_options_;

  // Segment selection result is read from prefs on init and used for serving
  // the clients in the current session. The selection could be updated if it
  // was unused by the client and a result refresh was triggered. If used by
  // client, then the result is not updated and effective onnly in the next
  // session.
  SegmentSelectionResult selected_segment_;
  bool used_result_in_current_session_ = false;

  // Pointer to the training data collector.
  raw_ptr<TrainingDataCollector, DanglingUntriaged> training_data_collector_ =
      nullptr;

  base::WeakPtrFactory<SegmentSelectorImpl> weak_ptr_factory_{this};
};

}  // namespace segmentation_platform

#endif  // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_SELECTOR_IMPL_H_