1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
|
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_REQUEST_DISPATCHER_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_REQUEST_DISPATCHER_H_
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include "base/containers/circular_deque.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/storage_service.h"
#include "components/segmentation_platform/internal/selection/request_handler.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/result.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
namespace segmentation_platform {

class ExecutionService;
class SegmentResultProvider;
struct PredictionOptions;

// RequestDispatcher is the topmost layer in serving API requests for all
// clients. It is responsible for:
// 1. Queuing API requests while the platform is not yet fully initialized.
// 2. Dispatching requests to client-specific request handlers.
class RequestDispatcher {
 public:
  // |storage_service| is held as a non-owning raw_ptr; presumably it must
  // outlive this dispatcher (TODO confirm with the owner).
  explicit RequestDispatcher(StorageService* storage_service);
  ~RequestDispatcher();

  // Not copyable or assignable.
  RequestDispatcher(const RequestDispatcher&) = delete;
  RequestDispatcher& operator=(const RequestDispatcher&) = delete;

  // Called when platform and database initializations are completed.
  // |success| reports whether initialization succeeded. |result_providers| is
  // taken by value, so ownership of the providers moves into this call.
  // NOTE(review): the map is presumably keyed by segmentation key — confirm.
  void OnPlatformInitialized(
      bool success,
      ExecutionService* execution_service,
      std::map<std::string, std::unique_ptr<SegmentResultProvider>>
          result_providers);

  // Called when the model for |segment_id| has been initialized. Used to
  // execute any queued requests that depend on that model.
  void OnModelUpdated(proto::SegmentId segment_id);

  // Client API. See `SegmentationPlatformService::GetClassificationResult`.
  void GetClassificationResult(const std::string& segmentation_key,
                               const PredictionOptions& options,
                               scoped_refptr<InputContext> input_context,
                               ClassificationResultCallback callback);

  // Client API. See `SegmentationPlatformService::GetAnnotatedNumericResult`.
  void GetAnnotatedNumericResult(const std::string& segmentation_key,
                                 const PredictionOptions& options,
                                 scoped_refptr<InputContext> input_context,
                                 AnnotatedNumericResultCallback callback);

  // For testing only. Presumably returns the number of actions still queued
  // in |pending_actions_| — confirm against the implementation.
  int GetPendingActionCountForTesting();

  // For testing only. Installs |request_handler| for |segmentation_key|,
  // replacing (and destroying) any handler previously registered for it.
  void set_request_handler_for_testing(
      const std::string& segmentation_key,
      std::unique_ptr<RequestHandler> request_handler) {
    request_handlers_[segmentation_key] = std::move(request_handler);
  }

 private:
  // Presumably fires when model initialization takes too long, so that
  // |uninitialized_segmentation_keys_| can be cleared (see that member).
  void OnModelInitializationTimeout();

  // Run queued actions from |pending_actions_|: for every key, or only for
  // |segmentation_key|, respectively (inferred from names — TODO confirm).
  void ExecuteAllPendingActions();
  void ExecutePendingActionsForKey(const std::string& segmentation_key);

  // Untyped result callback. Its (bool, const RawResult&) arguments match the
  // trailing (is_cached_result, raw_result) parameters of CallbackWrapper().
  using WrappedCallback = base::OnceCallback<void(bool, const RawResult&)>;

  // NOTE(review): likely the shared path behind the public Get*Result APIs,
  // reporting through a WrappedCallback — confirm in the .cc file.
  void GetModelResult(const std::string& segmentation_key,
                      const PredictionOptions& options,
                      scoped_refptr<InputContext> input_context,
                      WrappedCallback callback);

  // Wraps the result callback to record metrics and convert the raw result
  // to the necessary result type.
  template <typename ResultType>
  void CallbackWrapper(const std::string& segmentation_key,
                       base::Time start_time,
                       base::OnceCallback<void(const ResultType&)> callback,
                       bool is_cached_result,
                       const RawResult& raw_result);

  // Request handlers associated with the clients, keyed by segmentation key.
  std::map<std::string, std::unique_ptr<RequestHandler>> request_handlers_;

  // List of segmentation keys whose models haven't been initialized. Used to
  // enqueue requests that involve an uninitialized model. It gets populated
  // when the platform initializes and each element gets removed when
  // |OnModelUpdated| gets called with its corresponding segment ID. All
  // elements get cleared after a timeout to avoid waiting for too long.
  std::set<std::string> uninitialized_segmentation_keys_;

  // Not owned.
  const raw_ptr<StorageService> storage_service_;

  // Storage initialization status; presumably unset until
  // OnPlatformInitialized() runs.
  absl::optional<bool> storage_init_status_;

  // For caching any method calls that were received before initialization.
  // Key is a segmentation key, value is a queue of actions that use that model.
  std::map<std::string, base::circular_deque<base::OnceClosure>>
      pending_actions_;

  // Must remain the last member so weak pointers are invalidated before any
  // other member is destroyed.
  base::WeakPtrFactory<RequestDispatcher> weak_ptr_factory_{this};
};

}  // namespace segmentation_platform
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_REQUEST_DISPATCHER_H_
|