File: request_dispatcher.h

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (139 lines) | stat: -rw-r--r-- 6,132 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_REQUEST_DISPATCHER_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_REQUEST_DISPATCHER_H_

#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>

#include "base/containers/circular_deque.h"
#include "base/memory/scoped_refptr.h"
#include "components/segmentation_platform/internal/database/storage_service.h"
#include "components/segmentation_platform/internal/selection/request_handler.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/result.h"

namespace segmentation_platform {
struct PredictionOptions;
class SegmentResultProvider;

// RequestDispatcher is the topmost layer in serving API requests for all
// clients. It's responsible for
// 1. Queuing API requests until the platform is fully initialized.
// 2. Dispatching requests to client specific request handlers.
class RequestDispatcher {
 public:
  // |storage_service| is not owned by the dispatcher (it is held in a
  // raw_ptr); NOTE(review): presumably it must outlive this object.
  explicit RequestDispatcher(StorageService* storage_service);
  ~RequestDispatcher();

  // Disallow copy/assign.
  RequestDispatcher(const RequestDispatcher&) = delete;
  RequestDispatcher& operator=(const RequestDispatcher&) = delete;

  // Called when platform and database initializations are completed.
  // |success| reports whether initialization succeeded. |execution_service|
  // is passed as a raw pointer and |result_providers| is handed over by value
  // (ownership of the providers moves into this call). NOTE(review): the map
  // key is presumably the segmentation key — confirm against the caller.
  void OnPlatformInitialized(
      bool success,
      ExecutionService* execution_service,
      std::map<std::string, std::unique_ptr<SegmentResultProvider>>
          result_providers);

  // Called when the model for |segment_id| has been initialized. Used to
  // execute any queued requests that depend on that model.
  void OnModelUpdated(proto::SegmentId segment_id);

  // Client API. See `SegmentationPlatformService::GetClassificationResult`.
  void GetClassificationResult(const std::string& segmentation_key,
                               const PredictionOptions& options,
                               scoped_refptr<InputContext> input_context,
                               ClassificationResultCallback callback);

  // Client API. See `SegmentationPlatformService::GetAnnotatedNumericResult`.
  void GetAnnotatedNumericResult(const std::string& segmentation_key,
                                 const PredictionOptions& options,
                                 scoped_refptr<InputContext> input_context,
                                 AnnotatedNumericResultCallback callback);

  // Client API. See `SegmentationPlatformService::GetInputKeysForModel`.
  void GetInputKeysForModel(const std::string& segmentation_key,
                            InputContextKeysCallback callback);

  // For testing only.
  int GetPendingActionCountForTesting();
  // Installs (or replaces) the handler used for |segmentation_key|, taking
  // ownership of |request_handler|.
  void set_request_handler_for_testing(
      const std::string& segmentation_key,
      std::unique_ptr<RequestHandler> request_handler) {
    request_handlers_[segmentation_key] = std::move(request_handler);
  }

 private:
  // Invoked when waiting for model initialization takes too long; see
  // |uninitialized_segmentation_keys_| for the related bookkeeping.
  void OnModelInitializationTimeout();
  // Runs queued actions, either for every key or for a single key.
  void ExecuteAllPendingActions();
  void ExecutePendingActionsForKey(const std::string& segmentation_key);

  // Internal result callback: receives a flag (passed through as
  // |is_cached_result| in CallbackWrapper) and the raw model result.
  using WrappedCallback = base::OnceCallback<void(bool, const RawResult&)>;
  void GetModelResult(const std::string& segmentation_key,
                      const PredictionOptions& options,
                      scoped_refptr<InputContext> input_context,
                      WrappedCallback callback);

  // On-demand execution path and its completion handler.
  void ExecuteOnDemand(const std::string& segmentation_key,
                       const Config* config,
                       const PredictionOptions& options,
                       scoped_refptr<InputContext> input_context,
                       WrappedCallback callback);

  void OnFinishedOnDemandExecution(const std::string& segmentation_key,
                                   const Config* config,
                                   const PredictionOptions& options,
                                   scoped_refptr<InputContext> input_context,
                                   WrappedCallback callback,
                                   const RawResult& raw_result);

  // Cached-result execution path.
  void HandleCachedExecution(const std::string& segmentation_key,
                             const Config* config,
                             const PredictionOptions& options,
                             scoped_refptr<InputContext> input_context,
                             WrappedCallback callback);

  // Wrap the result callback for recording metrics and converting raw result to
  // necessary result type.
  template <typename ResultType>
  void CallbackWrapper(const std::string& segmentation_key,
                       base::Time start_time,
                       base::OnceCallback<void(const ResultType&)> callback,
                       bool is_cached_result,
                       const RawResult& raw_result);

  // Request handlers associated with the clients, keyed by segmentation key.
  std::map<std::string, std::unique_ptr<RequestHandler>> request_handlers_;

  // List of segmentation keys whose models haven't been initialized. Used to
  // enqueue requests that involve an uninitialized model. It gets populated
  // when the platform initializes and each element gets removed when
  // |OnModelUpdated| gets called with its corresponding segment ID. All
  // elements get cleared after a timeout to avoid waiting for too long.
  std::set<std::string> uninitialized_segmentation_keys_;

  // Not owned.
  const raw_ptr<StorageService> storage_service_;

  // Storage initialization status; empty until a status has been recorded.
  std::optional<bool> storage_init_status_;

  // For caching any method calls that were received before initialization.
  // Key is a segmentation key, value is a queue of actions that use that model.
  std::map<std::string, base::circular_deque<base::OnceClosure>>
      pending_actions_;

  // Must remain the last member so that weak pointers are invalidated before
  // any other member is destroyed.
  base::WeakPtrFactory<RequestDispatcher> weak_ptr_factory_{this};
};

}  // namespace segmentation_platform

#endif  // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_REQUEST_DISPATCHER_H_