// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_BASE_MODEL_EXECUTOR_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_BASE_MODEL_EXECUTOR_H_

#include <memory>
#include <optional>
#include <vector>

#include "base/task/sequenced_task_runner.h"
#include "base/types/expected.h"
#include "build/build_config.h"
#include "components/optimization_guide/core/base_model_executor_helpers.h"
#include "components/optimization_guide/core/execution_status.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/tflite_model_executor.h"
#include "components/optimization_guide/core/tflite_op_resolver.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"

namespace optimization_guide {

// A ModelExecutor that executes models with arbitrary input and output types.
// Note that callers will need to give an implementation of this class to a
// |ModelHandler|; the handler is the class that calling code actually owns
// and calls into.
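//
// A hypothetical subclass might look like the sketch below (the class name,
// the float-vector types, and the tflite::task::core helpers are illustrative
// choices, not requirements of this interface):
//
//   class FloatModelExecutor
//       : public BaseModelExecutor<std::vector<float>,
//                                  const std::vector<float>&> {
//    protected:
//     bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
//                     const std::vector<float>& input) override {
//       // Copy the input into the model's first input tensor.
//       return tflite::task::core::PopulateTensor<float>(input,
//                                                        input_tensors[0])
//           .ok();
//     }
//
//     std::optional<std::vector<float>> Postprocess(
//         const std::vector<const TfLiteTensor*>& output_tensors) override {
//       // Read the first output tensor back into a float vector.
//       std::vector<float> output;
//       if (!tflite::task::core::PopulateVector<float>(output_tensors[0],
//                                                      &output)
//                .ok()) {
//         return std::nullopt;
//       }
//       return output;
//     }
//   };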
template <class OutputType, class InputType>
class BaseModelExecutor : public TFLiteModelExecutor<OutputType, InputType>,
                          public InferenceDelegate<OutputType, InputType> {
 public:
  using ModelExecutionTask =
      tflite::task::core::BaseTaskApi<OutputType, InputType>;

  BaseModelExecutor() = default;
  ~BaseModelExecutor() override = default;

  BaseModelExecutor(const BaseModelExecutor&) = delete;
  BaseModelExecutor& operator=(const BaseModelExecutor&) = delete;

  // TFLiteModelExecutor:
  void InitializeAndMoveToExecutionThread(
      std::optional<base::TimeDelta> model_inference_timeout,
      proto::OptimizationTarget optimization_target,
      scoped_refptr<base::SequencedTaskRunner> execution_task_runner,
      scoped_refptr<base::SequencedTaskRunner> reply_task_runner) override {
    num_threads_ = features::OverrideNumThreadsForOptTarget(optimization_target)
                       .value_or(-1);
    TFLiteModelExecutor<OutputType, InputType>::
        InitializeAndMoveToExecutionThread(
            model_inference_timeout, optimization_target, execution_task_runner,
            reply_task_runner);
  }

 protected:
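  // Runs inference: downcasts |execution_task| to the
  // GenericModelExecutionTask produced by BuildModelExecutionTask() and
  // forwards |input| to it. Returns std::nullopt on failure and writes the
  // detailed result to |out_status|.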
  std::optional<OutputType> Execute(ModelExecutionTask* execution_task,
                                    ExecutionStatus* out_status,
                                    InputType input) override {
    return static_cast<GenericModelExecutionTask<OutputType, InputType>*>(
               execution_task)
        ->Execute(out_status, input);
  }
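
  // Loads the TFLite model from |model_file|, configures the interpreter to
  // use |num_threads_| CPU threads, and wraps the resulting engine in a
  // GenericModelExecutionTask. Returns an ExecutionStatus error if the model
  // cannot be loaded or the interpreter cannot be initialized.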
  base::expected<std::unique_ptr<ModelExecutionTask>, ExecutionStatus>
  BuildModelExecutionTask(base::File& model_file) override {
    std::unique_ptr<tflite::task::core::TfLiteEngine> tflite_engine =
        std::make_unique<tflite::task::core::TfLiteEngine>(
            std::make_unique<TFLiteOpResolver>());
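    // base::File exposes a HANDLE on Windows and a file descriptor on POSIX,
    // so the model is loaded through the matching TfLiteEngine entry point.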
#if BUILDFLAG(IS_WIN)
    absl::Status model_load_status =
        tflite_engine->BuildModelFromFileHandle(model_file.GetPlatformFile());
#else
    absl::Status model_load_status =
        tflite_engine->BuildModelFromFileDescriptor(
            model_file.GetPlatformFile());
#endif
    if (!model_load_status.ok()) {
      DLOG(ERROR) << "Failed to load model: " << model_load_status.ToString();
      return base::unexpected(ExecutionStatus::kErrorModelFileNotValid);
    }

    auto compute_settings = tflite::proto::ComputeSettings();
    compute_settings.mutable_tflite_settings()
        ->mutable_cpu_settings()
        ->set_num_threads(num_threads_);
    absl::Status interpreter_status =
        tflite_engine->InitInterpreter(compute_settings);
    if (!interpreter_status.ok()) {
      DLOG(ERROR) << "Failed to initialize model interpreter: "
                  << interpreter_status.ToString();
      return base::unexpected(ExecutionStatus::kErrorUnknown);
    }

    return std::make_unique<GenericModelExecutionTask<OutputType, InputType>>(
        std::move(tflite_engine), this);
  }

  // InferenceDelegate:
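  // Writes |input| into |input_tensors|. Returning false fails the execution.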
  bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                  InputType input) override = 0;
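  // Converts |output_tensors| into an OutputType, or returns std::nullopt if
  // the model output cannot be interpreted.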
  std::optional<OutputType> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors) override = 0;

 private:
  // -1 tells TFLite to use its own default number of threads.
  int num_threads_ = -1;
};

}  // namespace optimization_guide

#endif  // COMPONENTS_OPTIMIZATION_GUIDE_CORE_BASE_MODEL_EXECUTOR_H_