File: model_provider.h

package info (click to toggle)
chromium 139.0.7258.127-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,122,156 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (120 lines) | stat: -rw-r--r-- 4,605 bytes parent folder | download | duplicates (7)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_MODEL_PROVIDER_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_MODEL_PROVIDER_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include "base/functional/callback.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"

namespace segmentation_platform {
namespace proto {
class SegmentationModelMetadata;
}  // namespace proto

// Interface used by the segmentation platform to get model and metadata for a
// single optimization target. Each instance serves the one segment passed to
// its constructor, which subclasses can read through `segment_id_`.
class ModelProvider {
 public:
  // Input to ExecuteModelWithInput(): the values computed from the features
  // listed in the model metadata.
  using Request = std::vector<float>;
  // Output of a model execution: the model's float score(s).
  using Response = std::vector<float>;

  // Run each time a (new) model becomes available for the segment, with the
  // segment the model belongs to, the model's metadata and its version.
  // NOTE(review): `metadata` is optional; presumably std::nullopt signals that
  // no usable model/metadata is available — confirm against implementations.
  using ModelUpdatedCallback = base::RepeatingCallback<void(
      proto::SegmentId segment_id,
      std::optional<proto::SegmentationModelMetadata> metadata,
      int64_t model_version)>;
  // Run with the result of a model execution, or with std::nullopt when no
  // result could be produced (e.g. ExecuteModelWithInput() was called before
  // InitAndFetchModel()).
  using ExecutionCallback =
      base::OnceCallback<void(const std::optional<Response>& result)>;

  // `segment_id` identifies the optimization target this provider serves.
  explicit ModelProvider(proto::SegmentId segment_id);
  virtual ~ModelProvider();

  ModelProvider(const ModelProvider&) = delete;
  ModelProvider& operator=(const ModelProvider&) = delete;

  // Starts a fetch request for the model of this provider's optimization
  // target. Nothing is returned directly: the implementation delivers the
  // metadata used to execute the model through `model_updated_callback`. That
  // metadata must define the features whose computed values are later passed
  // to ExecuteModelWithInput(). `model_updated_callback` can be run multiple
  // times, as new models become available for the optimization target.
  virtual void InitAndFetchModel(
      const ModelUpdatedCallback& model_updated_callback) = 0;

  // Executes the latest available model on `inputs` and delivers the result
  // through `callback`. Should be called only after InitAndFetchModel();
  // otherwise `callback` is run with std::nullopt. The implementation could be
  // a heuristic or an actual model execution. `inputs` are the tensors
  // computed from the features in the metadata most recently passed to
  // `model_updated_callback`. The result holds float score(s), e.g. the
  // probability of a positive result. Also see the `discrete_mapping` field in
  // `SegmentationModelMetadata` for how the score is used to determine the
  // segment.
  virtual void ExecuteModelWithInput(const Request& inputs,
                                     ExecutionCallback callback) = 0;

  // Returns true if a model is available.
  virtual bool ModelAvailable() = 0;

 protected:
  // The optimization target served by this provider; fixed at construction.
  const proto::SegmentId segment_id_;
};

// ModelProvider wrapper for implementing default models in C++. Subclasses
// describe their model through GetModelConfig() and implement
// ExecuteModelWithInput(). InitAndFetchModel() is supplied by this class and
// is declared final, so subclasses cannot replace it.
class DefaultModelProvider : public ModelProvider {
 public:
  // Everything needed to describe a default model.
  struct ModelConfig {
    ModelConfig(proto::SegmentationModelMetadata metadata,
                int64_t model_version);
    ~ModelConfig();

    ModelConfig(const ModelConfig&) = delete;
    ModelConfig& operator=(const ModelConfig&) = delete;

    // The model's inputs, outputs and remaining configuration fields.
    proto::SegmentationModelMetadata metadata;
    // Version of the model; must be bumped whenever the model changes.
    int64_t model_version;
  };

  explicit DefaultModelProvider(proto::SegmentId segment_id);
  ~DefaultModelProvider() override;

  DefaultModelProvider(const DefaultModelProvider&) = delete;
  DefaultModelProvider& operator=(const DefaultModelProvider&) = delete;

  // Returns the configuration of this default model; the caller takes
  // ownership of the returned object.
  virtual std::unique_ptr<ModelConfig> GetModelConfig() = 0;

  // Reports true unless a subclass overrides it to disable the model.
  bool ModelAvailable() override;

 private:
  // ModelProvider:
  // NOTE(review): the definition is not visible here; presumably it reports
  // the metadata and version from GetModelConfig() to
  // `model_updated_callback` — confirm in the .cc file.
  void InitAndFetchModel(
      const ModelUpdatedCallback& model_updated_callback) final;
};

// Interface used by segmentation platform to create ModelProvider(s).
class ModelProviderFactory {
 public:
  virtual ~ModelProviderFactory();

  // Creates a model provider for the given `segment_id`.
  virtual std::unique_ptr<ModelProvider> CreateProvider(proto::SegmentId) = 0;

  // Creates a default model provider to be used when the original provider did
  // not provide a model. Returns `nullptr` when a default provider is not
  // available.
  // TODO(crbug.com/40232484): This method should be moved to Config after
  // migrating all the tests that use this.
  virtual std::unique_ptr<DefaultModelProvider> CreateDefaultProvider(
      proto::SegmentId) = 0;
};

}  // namespace segmentation_platform

#endif  // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_MODEL_PROVIDER_H_