1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
|
#include <torch/csrc/jit/backends/backend_detail.h>
#include <ATen/core/builtin_function.h>
namespace torch {
namespace jit {
namespace detail {

// Builds the FunctionSchema named "preprocess" (empty overload name):
//   preprocess(self: Any, mod: Any, method_compile_spec: Dict[str, Any])
//       -> (mod: Any)
c10::FunctionSchema getPreprocessSchema() {
  const auto any_ty = c10::AnyType::get();
  const auto spec_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  return c10::FunctionSchema(
      "preprocess",
      /*overload_name=*/"",
      /*arguments=*/
      {c10::Argument("self", any_ty),
       c10::Argument("mod", any_ty),
       c10::Argument("method_compile_spec", spec_ty)},
      // The single return reuses the name and type of the `mod` argument.
      /*returns=*/{c10::Argument("mod", any_ty)});
}

// Builds the FunctionSchema named "compile" (empty overload name):
//   compile(self: Any, processed: Any, method_compile_spec: Dict[str, Any])
//       -> (handles: Dict[str, Any])
c10::FunctionSchema getCompileSchema() {
  const auto any_ty = c10::AnyType::get();
  // One Dict[str, Any] type object is shared by the compile-spec argument and
  // the returned `handles`.
  const auto str_any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  return c10::FunctionSchema(
      "compile",
      /*overload_name=*/"",
      /*arguments=*/
      {c10::Argument("self", any_ty),
       c10::Argument("processed", any_ty),
       c10::Argument("method_compile_spec", str_any_dict_ty)},
      /*returns=*/{c10::Argument("handles", str_any_dict_ty)});
}

// Builds the FunctionSchema named "execute" (empty overload name):
//   execute(self: Any, handle: Any, input: List[Any]) -> (output: List[Any])
c10::FunctionSchema getExecuteSchema() {
  const auto any_ty = c10::AnyType::get();
  // One List[Any] type object is shared by `input` and `output`.
  const auto any_list_ty = c10::ListType::create(c10::AnyType::get());
  return c10::FunctionSchema(
      "execute",
      /*overload_name=*/"",
      /*arguments=*/
      {c10::Argument("self", any_ty),
       c10::Argument("handle", any_ty),
       c10::Argument("input", any_list_ty)},
      /*returns=*/{c10::Argument("output", any_list_ty)});
}

} // namespace detail
} // namespace jit
} // namespace torch
|