1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
|
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_RESULT_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_RESULT_H_
#include <optional>
#include <string>
#include <string_view>
#include <vector>
#include "base/containers/flat_map.h"
#include "base/functional/callback_helpers.h"
#include "components/segmentation_platform/public/proto/prediction_result.pb.h"
#include "components/segmentation_platform/public/trigger.h"
namespace segmentation_platform {
// Outcome of a model prediction. Stored in the `status` field of the result
// structs declared in this header.
// GENERATED_JAVA_ENUM_PACKAGE: (
// org.chromium.components.segmentation_platform.prediction_status)
enum class PredictionStatus {
// No result is available yet. NOTE(review): presumably also reported when not
// enough data has been collected to run the model — confirm with producers.
kNotReady = 0,
// The prediction could not be produced (presumably e.g. model evaluation
// failed).
kFailed = 1,
// A prediction was produced.
kSucceeded = 2,
};
// Result delivered to the client when the Predictor chosen in its OutputConfig
// is a BinaryClassifier, a MultiClassClassifier or a BinnedClassifier.
struct ClassificationResult {
  explicit ClassificationResult(PredictionStatus status);
  ClassificationResult(const ClassificationResult& other);
  ClassificationResult& operator=(const ClassificationResult& other);
  ~ClassificationResult();

  // Returns a debug string describing this result.
  std::string ToDebugString() const;

  // Outcome of the prediction, e.g. the model failed or not enough data had
  // been collected.
  PredictionStatus status;

  // Output labels, ordered from the highest to the lowest model evaluation
  // result:
  //  - BinaryClassifier: either `positive_label` or `negative_label`.
  //  - MultiClassClassifier: the `top_k_outputs` labels, ranked by the score
  //    of each label.
  //  - BinnedClassifier: the label of the bin that the model's score falls
  //    into.
  std::vector<std::string> ordered_labels;

  // Identifies the specific set of training data inputs for this execution.
  // May be null when no training data was uploaded for the execution.
  TrainingRequestId request_id;
};
// Numeric output obtained by evaluating the TFLite model file or the default
// heuristic. At present this is only supported when the OutputConfig specifies
// a GenericPredictor.
struct AnnotatedNumericResult {
  explicit AnnotatedNumericResult(PredictionStatus status);
  AnnotatedNumericResult(const AnnotatedNumericResult& other);
  AnnotatedNumericResult& operator=(const AnnotatedNumericResult& other);
  ~AnnotatedNumericResult();

  // Looks up the score for `label`. Yields std::nullopt when the result could
  // not be fetched, or when `label` is not present in the output config.
  std::optional<float> GetResultForLabel(std::string_view label) const;

  // Returns every output label paired with its float score.
  base::flat_map<std::string, float> GetAllResults() const;

  // Returns a debug string describing this result.
  std::string ToDebugString() const;

  // Outcome of the prediction, e.g. the model failed or not enough data had
  // been collected.
  PredictionStatus status;

  // Output produced by the model.
  proto::PredictionResult result;

  // Identifies the specific set of training data inputs for this execution.
  // May be null when no training data was uploaded for the execution.
  TrainingRequestId request_id;
};
using ClassificationResultCallback =
base::OnceCallback<void(const ClassificationResult&)>;
using AnnotatedNumericResultCallback =
base::OnceCallback<void(const AnnotatedNumericResult&)>;
using RawResult = AnnotatedNumericResult;
using RawResultCallback = base::OnceCallback<void(const RawResult&)>;
} // namespace segmentation_platform
#endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_RESULT_H_
|