// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/execution/processing/feature_processor_state.h"

#include "base/strings/string_number_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/ukm_types.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
namespace segmentation_platform::processing {
FeatureProcessorState::FeatureProcessorState()
: prediction_time_(base::Time::Now()),
bucket_duration_(base::TimeDelta()),
      segment_id_(SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {}

FeatureProcessorState::FeatureProcessorState(
FeatureProcessorStateId id,
base::Time prediction_time,
base::Time observation_time,
base::TimeDelta bucket_duration,
SegmentId segment_id,
scoped_refptr<InputContext> input_context,
FeatureListQueryProcessor::FeatureProcessorCallback callback)
: id_(id),
prediction_time_(prediction_time),
observation_time_(observation_time),
bucket_duration_(bucket_duration),
segment_id_(segment_id),
input_context_(std::move(input_context)),
      callback_(std::move(callback)) {}

FeatureProcessorState::~FeatureProcessorState() = default;
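
// Records the failure in UMA, logs it, marks this state as failed and clears
// any partially collected input tensor.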
void FeatureProcessorState::SetError(stats::FeatureProcessingError error,
const std::string& message) {
stats::RecordFeatureProcessingError(segment_id_, error);
LOG(ERROR) << "Processing error occured: model "
<< SegmentIdToHistogramVariant(segment_id_) << " failed with "
<< stats::FeatureProcessingErrorToString(error)
<< ", message: " << message;
error_ = true;
input_tensor_.clear();
}

base::WeakPtr<FeatureProcessorState> FeatureProcessorState::GetWeakPtr() {
return weak_ptr_factory_.GetWeakPtr();
}
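
// Hands out the next queued processor, draining output feature processors
// before input feature processors. The bool in the returned pair is true for
// an input processor; std::nullopt is returned once both queues are empty.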
std::optional<std::pair<std::unique_ptr<QueryProcessor>, bool>>
FeatureProcessorState::PopNextProcessor() {
std::optional<std::pair<std::unique_ptr<QueryProcessor>, bool>>
next_processor;
if (!out_processors_.empty()) {
std::unique_ptr<QueryProcessor> processor =
std::move(out_processors_.front());
out_processors_.pop_front();
next_processor = std::make_pair(std::move(processor), false);
} else if (!in_processors_.empty()) {
std::unique_ptr<QueryProcessor> processor =
std::move(in_processors_.front());
in_processors_.pop_front();
next_processor = std::make_pair(std::move(processor), true);
}
return next_processor;
}
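
// Queues a processor on the input or output list so that it is picked up by a
// later PopNextProcessor() call.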
void FeatureProcessorState::AppendProcessor(
std::unique_ptr<QueryProcessor> processor,
bool is_input) {
if (is_input) {
in_processors_.emplace_back(std::move(processor));
} else {
out_processors_.emplace_back(std::move(processor));
}
}
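
// Copies the processor results into the accumulated input or output tensor,
// keyed by feature index.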
void FeatureProcessorState::AppendIndexedTensors(
const QueryProcessor::IndexedTensors& result,
bool is_input) {
if (is_input) {
for (const auto& item : result) {
input_tensor_[item.first] = item.second;
}
} else {
for (const auto& item : result) {
output_tensor_[item.first] = item.second;
}
}
}
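
// Flattens the accumulated tensors and reports the result to the stored
// callback. On error the callback receives empty tensors.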
void FeatureProcessorState::OnFinishProcessing() {
std::vector<float> input;
std::vector<float> output;
if (!error_) {
input = MergeTensors(std::move(input_tensor_));
output = MergeTensors(std::move(output_tensor_));
stats::RecordFeatureProcessingError(
segment_id_, stats::FeatureProcessingError::kSuccess);
}
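  // Return the result to the caller asynchronously on the current sequence.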
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback_), error_, std::move(input),
std::move(output)));
}
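
// Validates the indexed tensor and flattens it, in index order, into a single
// float vector. Any non-float value marks the state as failed.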
std::vector<float> FeatureProcessorState::MergeTensors(
const QueryProcessor::IndexedTensors& tensor) {
std::vector<float> result;
if (metadata_utils::ValidateIndexedTensors(tensor, tensor.size()) !=
metadata_utils::ValidationResult::kValidationSuccess) {
// Note that since the state does not know the expected size, if a tensor is
// missing from the end of the indexed tensor, this validation will not
// fail.
SetError(stats::FeatureProcessingError::kResultTensorError);
} else {
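    // Concatenate the per-index tensors in index order.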
for (size_t i = 0; i < tensor.size(); ++i) {
for (const ProcessedValue& value : tensor.at(i)) {
if (value.type == ProcessedValue::Type::FLOAT) {
result.push_back(value.float_val);
} else {
SetError(stats::FeatureProcessingError::kResultTensorError,
"Expected ProcessedValue::Type::FLOAT but found " +
base::NumberToString(static_cast<int>(value.type)));
return result;
}
}
}
}
return result;
}

}  // namespace segmentation_platform::processing