File: feature_processor_state.cc

package info (click to toggle)
chromium 144.0.7559.109-2
  • links: PTS, VCS
  • area: main
  • in suites: forky
  • size: 5,915,868 kB
  • sloc: cpp: 35,866,215; ansic: 7,599,035; javascript: 3,623,761; python: 1,639,407; xml: 833,084; asm: 716,173; pascal: 185,323; sh: 88,763; perl: 88,699; objc: 79,984; sql: 58,217; cs: 42,430; fortran: 24,101; makefile: 20,747; tcl: 15,277; php: 14,022; yacc: 9,059; ruby: 7,553; awk: 3,720; lisp: 3,233; lex: 1,330; ada: 727; jsp: 228; sed: 36
file content (151 lines) | stat: -rw-r--r-- 5,540 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/segmentation_platform/internal/execution/processing/feature_processor_state.h"

#include "base/strings/string_number_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/ukm_types.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"

namespace segmentation_platform::processing {

// Default constructor: uses the current time as the prediction time, a
// default-constructed bucket duration, and the unknown segment id. Members that
// are not named in the initializer list keep their default initialization.
FeatureProcessorState::FeatureProcessorState()
    : prediction_time_(base::Time::Now()),
      bucket_duration_(base::TimeDelta()),
      segment_id_(SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {}

// Builds a state from caller-supplied values. Every argument is stored in the
// member carrying the same name (with a trailing underscore); |input_context|
// and |callback| are moved into the state instead of being copied.
FeatureProcessorState::FeatureProcessorState(
    FeatureProcessorStateId id,
    base::Time prediction_time,
    base::Time observation_time,
    base::TimeDelta bucket_duration,
    SegmentId segment_id,
    scoped_refptr<InputContext> input_context,
    FeatureListQueryProcessor::FeatureProcessorCallback callback)
    : id_(id),
      prediction_time_(prediction_time),
      observation_time_(observation_time),
      bucket_duration_(bucket_duration),
      segment_id_(segment_id),
      input_context_(std::move(input_context)),
      callback_(std::move(callback)) {}

// Uses the compiler-generated destructor; it is defined here, out of line.
FeatureProcessorState::~FeatureProcessorState() = default;

// Marks this state as failed.
//
// Reports |error| for this state's segment via
// stats::RecordFeatureProcessingError(), writes an ERROR log line that names
// the segment, the error and the free-form |message|, raises the error flag
// and drops the input tensors collected so far. Output tensors are left
// untouched by this method.
void FeatureProcessorState::SetError(stats::FeatureProcessingError error,
                                     const std::string& message) {
  stats::RecordFeatureProcessingError(segment_id_, error);
  LOG(ERROR) << "Processing error occurred: model "
             << SegmentIdToHistogramVariant(segment_id_) << " failed with "
             << stats::FeatureProcessingErrorToString(error)
             << ", message: " << message;
  error_ = true;
  input_tensor_.clear();
}

// Returns a weak pointer to this state, vended by |weak_ptr_factory_|.
base::WeakPtr<FeatureProcessorState> FeatureProcessorState::GetWeakPtr() {
  return weak_ptr_factory_.GetWeakPtr();
}

std::optional<std::pair<std::unique_ptr<QueryProcessor>, bool>>
FeatureProcessorState::PopNextProcessor() {
  std::optional<std::pair<std::unique_ptr<QueryProcessor>, bool>>
      next_processor;
  if (!out_processors_.empty()) {
    std::unique_ptr<QueryProcessor> processor =
        std::move(out_processors_.front());
    out_processors_.pop_front();
    next_processor = std::make_pair(std::move(processor), false);
  } else if (!in_processors_.empty()) {
    std::unique_ptr<QueryProcessor> processor =
        std::move(in_processors_.front());
    in_processors_.pop_front();
    next_processor = std::make_pair(std::move(processor), true);
  }
  return next_processor;
}

// Queues |processor| at the back of |in_processors_| when |is_input| is true,
// otherwise at the back of |out_processors_|. Ownership of |processor| moves
// into the chosen queue.
void FeatureProcessorState::AppendProcessor(
    std::unique_ptr<QueryProcessor> processor,
    bool is_input) {
  if (!is_input) {
    out_processors_.emplace_back(std::move(processor));
    return;
  }
  in_processors_.emplace_back(std::move(processor));
}

// Copies every (index, tensor) entry of |result| into |input_tensor_| when
// |is_input| is true, otherwise into |output_tensor_|. An entry whose index is
// already present in the destination is overwritten.
void FeatureProcessorState::AppendIndexedTensors(
    const QueryProcessor::IndexedTensors& result,
    bool is_input) {
  QueryProcessor::IndexedTensors& destination =
      is_input ? input_tensor_ : output_tensor_;
  for (const auto& [index, values] : result) {
    destination[index] = values;
  }
}

// Finishes processing for this state.
//
// When no error has been recorded, the collected input and output tensors are
// flattened with MergeTensors(). The stored callback is then posted to the
// current default sequenced task runner with the error flag and the two
// flattened vectors.
//
// If an error was recorded earlier, or flattening itself fails, the callback
// receives empty vectors and no kSuccess sample is recorded, so a single run
// never reports both an error and a success.
void FeatureProcessorState::OnFinishProcessing() {
  std::vector<float> input;
  std::vector<float> output;
  if (!error_) {
    // MergeTensors() takes its argument by const reference, so the tensors are
    // passed as-is. It may call SetError(), which flips |error_|.
    input = MergeTensors(input_tensor_);
    if (!error_) {
      output = MergeTensors(output_tensor_);
    }

    if (error_) {
      // Do not hand partially merged data to the callback next to an error.
      input.clear();
      output.clear();
    } else {
      stats::RecordFeatureProcessingError(
          segment_id_, stats::FeatureProcessingError::kSuccess);
    }
  }
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(std::move(callback_), error_, std::move(input),
                                std::move(output)));
}

// Flattens |tensor| into one float vector.
//
// After validating |tensor|, indices 0 .. tensor.size() - 1 are visited in
// order and every value of each entry is converted to float and appended.
// If validation fails, or a value has a type that is not handled below,
// SetError() is called with kResultTensorError and whatever has been merged so
// far (possibly nothing) is returned.
std::vector<float> FeatureProcessorState::MergeTensors(
    const QueryProcessor::IndexedTensors& tensor) {
  std::vector<float> merged;

  const bool validation_failed =
      metadata_utils::ValidateIndexedTensors(tensor, tensor.size()) !=
      metadata_utils::ValidationResult::kValidationSuccess;
  if (validation_failed) {
    // The state has no notion of the expected size, so a tensor missing from
    // the end of the indexed tensor is not detected by this validation.
    SetError(stats::FeatureProcessingError::kResultTensorError);
    return merged;
  }

  for (size_t index = 0; index < tensor.size(); ++index) {
    for (const ProcessedValue& entry : tensor.at(index)) {
      switch (entry.type) {
        case ProcessedValue::BOOL:
          merged.push_back(static_cast<float>(entry.bool_val));
          break;
        case ProcessedValue::INT:
          merged.push_back(static_cast<float>(entry.int_val));
          break;
        case ProcessedValue::DOUBLE:
          merged.push_back(static_cast<float>(entry.double_val));
          break;
        case ProcessedValue::INT64:
          merged.push_back(static_cast<float>(entry.int64_val));
          break;
        case ProcessedValue::FLOAT:
          merged.push_back(entry.float_val);
          break;
        default:
          // Bail out immediately after reporting: SetError() clears
          // |input_tensor_|, which |tensor| may refer to, so |entry| must not
          // be touched once SetError() has run.
          SetError(stats::FeatureProcessingError::kResultTensorError,
                   "Expected float compatible but found " +
                       base::NumberToString(static_cast<int>(entry.type)));
          return merged;
      }
    }
  }

  return merged;
}

}  // namespace segmentation_platform::processing