// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_MODEL_PROVIDER_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_MODEL_PROVIDER_H_
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include "base/functional/callback.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
namespace segmentation_platform {
namespace proto {
// Forward declaration of the metadata message used by the interfaces below.
// NOTE(review): model_metadata.pb.h is included above and presumably defines
// this type (ModelConfig stores it by value, which needs the complete type),
// so this forward declaration looks redundant -- confirm before removing.
class SegmentationModelMetadata;
}  // namespace proto
// Abstract interface through which the segmentation platform obtains the model
// and its metadata for a single optimization target.
class ModelProvider {
 public:
  // Input values handed to ExecuteModelWithInput().
  using Request = std::vector<float>;

  // Output values delivered through `ExecutionCallback`.
  using Response = std::vector<float>;

  // Repeating callback run when a model is available for the optimization
  // target. It receives the segment id, the (optional) metadata of that model
  // and an int64_t (presumably the model version -- TODO confirm against the
  // implementations).
  using ModelUpdatedCallback = base::RepeatingCallback<
      void(proto::SegmentId,
           std::optional<proto::SegmentationModelMetadata>,
           int64_t)>;

  // One-shot callback that carries the execution result. std::nullopt signals
  // that no result is available, e.g. when execution was requested before
  // InitAndFetchModel().
  using ExecutionCallback =
      base::OnceCallback<void(const std::optional<Response>&)>;

  // Providers are neither copyable nor assignable.
  ModelProvider(const ModelProvider&) = delete;
  ModelProvider& operator=(const ModelProvider&) = delete;

  explicit ModelProvider(proto::SegmentId segment_id);
  virtual ~ModelProvider();

  // Starts a fetch request for the model of the optimization target.
  // Implementations should hand back, through `model_updated_callback`, the
  // metadata that will be used to execute the model; that metadata should
  // define the number of features ExecuteModelWithInput() needs.
  // `model_updated_callback` may run multiple times, once per newly available
  // model.
  virtual void InitAndFetchModel(
      const ModelUpdatedCallback& model_updated_callback) = 0;

  // Runs the latest available model on `inputs` and reports the outcome via
  // `callback`. Must only be called after InitAndFetchModel(); otherwise the
  // result is std::nullopt. An implementation may be a heuristic or a real
  // model execution. `inputs` are the tensors computed from the features
  // provided in the most recent `model_updated_callback` invocation. The
  // result is a float score with the probability of a positive result; see
  // the `discrete_mapping` field of `SegmentationModelMetadata` for how the
  // score is used to determine the segment.
  virtual void ExecuteModelWithInput(const Request& inputs,
                                     ExecutionCallback callback) = 0;

  // Returns true if a model is available.
  virtual bool ModelAvailable() = 0;

 protected:
  // Segment id associated with this provider (presumably the value passed to
  // the constructor -- the definition is not in this header).
  const proto::SegmentId segment_id_;
};
// ModelProvider wrapper used to implement default models directly in C++.
// Subclasses describe the model through GetModelConfig() and still have to
// implement ExecuteModelWithInput().
class DefaultModelProvider : public ModelProvider {
 public:
  // Configuration a default model has to supply.
  struct ModelConfig {
    ModelConfig(proto::SegmentationModelMetadata metadata,
                int64_t model_version);
    ~ModelConfig();

    // Not copyable or assignable.
    ModelConfig(const ModelConfig&) = delete;
    ModelConfig& operator=(const ModelConfig&) = delete;

    // Metadata of the model: inputs, outputs and other configuration fields.
    proto::SegmentationModelMetadata metadata;

    // Version of the model; should be incremented for any change to the
    // model.
    int64_t model_version;
  };

  explicit DefaultModelProvider(proto::SegmentId segment_id);
  ~DefaultModelProvider() override;

  // Not copyable or assignable.
  DefaultModelProvider(const DefaultModelProvider&) = delete;
  DefaultModelProvider& operator=(const DefaultModelProvider&) = delete;

  // Returns the configuration (metadata and version) of the default model.
  virtual std::unique_ptr<ModelConfig> GetModelConfig() = 0;

  // ModelProvider:
  // Reports true by default; override to disable the model when needed.
  bool ModelAvailable() override;

 private:
  // ModelProvider:
  // Marked `final` and private, so subclasses can neither override nor call
  // it directly. NOTE(review): presumably it reports GetModelConfig() through
  // `model_updated_callback` -- confirm in the .cc file.
  void InitAndFetchModel(
      const ModelUpdatedCallback& model_updated_callback) final;
};
// Interface used by segmentation platform to create ModelProvider(s).
class ModelProviderFactory {
public:
virtual ~ModelProviderFactory();
// Creates a model provider for the given `segment_id`.
virtual std::unique_ptr<ModelProvider> CreateProvider(proto::SegmentId) = 0;
// Creates a default model provider to be used when the original provider did
// not provide a model. Returns `nullptr` when a default provider is not
// available.
// TODO(crbug.com/40232484): This method should be moved to Config after
// migrating all the tests that use this.
virtual std::unique_ptr<DefaultModelProvider> CreateDefaultProvider(
proto::SegmentId) = 0;
};
} // namespace segmentation_platform
#endif  // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_MODEL_PROVIDER_H_