1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
|
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/scheduler/execution_service.h"
#include "base/memory/raw_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "components/prefs/pref_service.h"
#include "components/segmentation_platform/internal/database/cached_result_provider.h"
#include "components/segmentation_platform/internal/database/storage_service.h"
#include "components/segmentation_platform/internal/execution/execution_request.h"
#include "components/segmentation_platform/internal/execution/model_executor_impl.h"
#include "components/segmentation_platform/internal/execution/processing/feature_aggregator_impl.h"
#include "components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h"
#include "components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h"
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
#include "components/segmentation_platform/internal/signals/signal_handler.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/input_delegate.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
namespace segmentation_platform {
// Construction and destruction are member-wise defaults. Collaborators are
// installed afterwards by Initialize() (or InitForTesting() in tests), not here.
ExecutionService::ExecutionService() = default;
ExecutionService::~ExecutionService() = default;
// Test-only initialization: installs caller-provided collaborators in place of
// the ones Initialize() would construct.
//
// Takes ownership of `feature_processor`, `executor` and `scheduler`.
// `model_manager` is passed as a raw pointer and is not owned; the caller must
// keep it alive while this service uses it.
//
// Only these four members are assigned. `storage_service_` and
// `training_data_collector_` are left untouched, so paths that use them
// (e.g. RunDailyTasks()) need them to be set up by other means.
void ExecutionService::InitForTesting(
std::unique_ptr<processing::FeatureListQueryProcessor> feature_processor,
std::unique_ptr<ModelExecutor> executor,
std::unique_ptr<ModelExecutionScheduler> scheduler,
ModelManager* model_manager) {
// Assignment order is kept as-is: each owning reassignment destroys the
// previously held object (if any) at that point.
feature_list_query_processor_ = std::move(feature_processor);
model_executor_ = std::move(executor);
model_execution_scheduler_ = std::move(scheduler);
model_manager_ = model_manager;
}
// Builds the execution pipeline for production use: the feature query
// processor, the training data collector, the model executor and the
// execution scheduler.
//
// Construction order matters. The query processor is handed (as a raw pointer)
// to both the training data collector and the executor, and the scheduler is
// given raw pointers to the model manager and the executor, so each must exist
// before its consumers are built.
//
// Note: `task_runner` and `model_provider_factory` are not used by this body.
void ExecutionService::Initialize(
    StorageService* storage_service,
    SignalHandler* signal_handler,
    base::Clock* clock,
    scoped_refptr<base::SequencedTaskRunner> task_runner,
    const base::flat_set<SegmentId>& legacy_output_segment_ids,
    ModelProviderFactory* model_provider_factory,
    std::vector<raw_ptr<ModelExecutionScheduler::Observer,
                        VectorExperimental>>&& observers,
    const PlatformOptions& platform_options,
    std::unique_ptr<processing::InputDelegateHolder> input_delegate_holder,
    PrefService* profile_prefs,
    CachedResultProvider* cached_result_provider) {
  storage_service_ = storage_service;

  // Step 1: feature processing, owned by this service.
  feature_list_query_processor_ =
      std::make_unique<processing::FeatureListQueryProcessor>(
          storage_service, std::move(input_delegate_holder),
          std::make_unique<processing::FeatureAggregatorImpl>());
  auto* const query_processor = feature_list_query_processor_.get();

  // Step 2: training data collection, fed by the signal handlers.
  training_data_collector_ = TrainingDataCollector::Create(
      platform_options, query_processor,
      signal_handler->deprecated_histogram_signal_handler(),
      signal_handler->user_action_signal_handler(), storage_service,
      profile_prefs, clock, cached_result_provider);

  // Step 3: model execution.
  model_executor_ = std::make_unique<ModelExecutorImpl>(
      clock, storage_service->segment_info_database(), query_processor);

  // Step 4: scheduling, which depends on the model manager and the executor.
  model_manager_ = storage_service->model_manager();
  model_execution_scheduler_ = std::make_unique<ModelExecutionSchedulerImpl>(
      std::move(observers), storage_service->segment_info_database(),
      storage_service->signal_storage_config(), model_manager_,
      model_executor_.get(), legacy_output_segment_ids, clock,
      platform_options);
}
// Legacy notification path: forwards a newly available model's segment info
// directly to the execution scheduler.
void ExecutionService::OnNewModelInfoReadyLegacy(
    const proto::SegmentInfo& segment_info) {
  // TODO(crbug.com/40258591): Change path flow as
  // SPSI->RRM->EE::RequestModelExecution and migrate
  // MES::CancelOutstandingExecutionRequests() to EE.
  auto* const scheduler = model_execution_scheduler_.get();
  scheduler->OnNewModelInfoReady(segment_info);
}
// Returns the model provider the model manager holds for `segment_id` from
// `model_source`. This is a direct pass-through to the model manager.
ModelProvider* ExecutionService::GetModelProvider(SegmentId segment_id,
                                                  ModelSource model_source) {
  auto* const provider =
      model_manager_->GetModelProvider(segment_id, model_source);
  return provider;
}
// Hands `request` to the model executor, transferring ownership.
//
// Debug builds verify the contract: `request` is non-null, names a known
// segment id and model source, and carries a non-null completion callback.
void ExecutionService::RequestModelExecution(
    std::unique_ptr<ExecutionRequest> request) {
  // Check for null first; every check below dereferences `request`.
  DCHECK(request);
  DCHECK_NE(request->segment_id, SegmentId::OPTIMIZATION_TARGET_UNKNOWN);
  DCHECK_NE(request->model_source, proto::ModelSource::UNKNOWN_MODEL_SOURCE);
  DCHECK(!request->callback.is_null());
  model_executor_->ExecuteModel(std::move(request));
}
// Reports a caller-supplied score for `segment_id` to the scheduler, as if
// model execution for that segment had completed.
//
// The synthesized result has empty inputs and a single score, `result.first`.
// NOTE(review): `result.second` (the ModelExecutionStatus) is not consulted
// here; confirm whether callers expect a non-success status to be propagated.
void ExecutionService::OverwriteModelExecutionResult(
    proto::SegmentId segment_id,
    const std::pair<float, ModelExecutionStatus>& result) {
  // TODO(ritikagup): Change the use of this according to MultiOutputModel.
  ModelProvider::Response scores(1, result.first);
  auto execution_result = std::make_unique<ModelExecutionResult>(
      ModelProvider::Request(), std::move(scores));

  // Only the segment id is populated on the segment info passed along.
  proto::SegmentInfo info;
  info.set_segment_id(segment_id);
  model_execution_scheduler_->OnModelExecutionCompleted(
      info, std::move(execution_result));
}
// Asks the scheduler to execute eligible segments, passing expired_only=true.
// Presumably this restricts execution to segments whose stored results have
// expired; confirm against ModelExecutionScheduler.
void ExecutionService::RefreshModelResults() {
  constexpr bool kExpiredOnly = true;
  model_execution_scheduler_->RequestModelExecutionForEligibleSegments(
      kExpiredOnly);
}
// Daily maintenance: always refreshes expired model results, then does one
// training-data step. On a regular (non-startup) run, it reports the collected
// continuous training data. On startup, it notifies the collector that the
// service is initialized.
void ExecutionService::RunDailyTasks(bool is_startup) {
  RefreshModelResults();

  if (!is_startup) {
    training_data_collector_->ReportCollectedContinuousTrainingData();
    return;
  }

  // This will trigger data collection after initialization finishes.
  training_data_collector_->OnServiceInitialized();
}
} // namespace segmentation_platform
|