1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
|
#pragma once
#include <torch/csrc/api/include/torch/imethod.h>
#include <torch/csrc/jit/runtime/static/impl.h>
namespace torch {
namespace jit {
class StaticMethod : public torch::IMethod {
public:
StaticMethod(
std::shared_ptr<StaticModule> static_module,
std::string method_name)
: static_module_(std::move(static_module)),
method_name_(std::move(method_name)) {
TORCH_CHECK(static_module_);
}
c10::IValue operator()(
std::vector<IValue> args,
const IValueMap& kwargs = IValueMap()) const override {
return (*static_module_)(std::move(args), kwargs);
}
const std::string& name() const override {
return method_name_;
}
protected:
void setArgumentNames(
std::vector<std::string>& argument_names_out) const override {
const auto& schema = static_module_->schema();
CAFFE_ENFORCE(schema.has_value());
const auto& arguments = schema->arguments();
argument_names_out.clear();
argument_names_out.reserve(arguments.size());
std::transform(
arguments.begin(),
arguments.end(),
std::back_inserter(argument_names_out),
[](const c10::Argument& arg) -> std::string { return arg.name(); });
}
private:
std::shared_ptr<StaticModule> static_module_;
std::string method_name_;
};
} // namespace jit
} // namespace torch
|