1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
|
#pragma once
#include <functional>
#include <utility>

#include <ATen/core/stack.h>
namespace torch {
namespace jit {
namespace detail {

// String constant "__backends__" naming the backends namespace. Its consumers
// are not visible in this header — confirm usage at the call sites.
constexpr static auto kBackendsNamespace = "__backends__";

// Schemas for a backend interface's `preprocess`, `compile` and `execute`
// methods. These are only declared here and defined out of line; presumably
// they describe the stack layouts consumed by the get*Func() wrappers below —
// confirm against the definitions.
c10::FunctionSchema TORCH_API getPreprocessSchema();
c10::FunctionSchema TORCH_API getCompileSchema();
c10::FunctionSchema TORCH_API getExecuteSchema();

/// Returns a stack-based wrapper around `TBackendInterface::preprocess`.
///
/// The returned callable expects the top of the stack to be
///   [..., self, mod, method_compile_spec]
/// where `self` holds a custom-class object of type TBackendInterface and
/// `method_compile_spec` is a generic dict. It pops all three values, calls
/// `self->preprocess(mod, method_compile_spec)` and pushes the result.
template <typename TBackendInterface>
std::function<void(Stack&)> getPreprocessFunc() {
  return [](Stack& stack) {
    // Arguments come off the stack in reverse of their push order.
    auto method_compile_spec = pop(stack).toGenericDict();
    auto mod = pop(stack);
    auto self = pop(stack).toCustomClass<TBackendInterface>();
    // The popped locals are dead after this call, so move them (and the
    // result) instead of copying the underlying refcounted handles.
    auto ret =
        self->preprocess(std::move(mod), std::move(method_compile_spec));
    push(stack, std::move(ret));
  };
}

/// Returns a stack-based wrapper around `TBackendInterface::compile`.
///
/// The returned callable expects the top of the stack to be
///   [..., self, processed, method_compile_spec]
/// where `self` holds a custom-class object of type TBackendInterface and
/// `method_compile_spec` is a generic dict. It pops all three values, calls
/// `self->compile(processed, method_compile_spec)` and pushes the result.
template <typename TBackendInterface>
std::function<void(Stack&)> getCompileFunc() {
  return [](Stack& stack) {
    // Arguments come off the stack in reverse of their push order.
    auto method_compile_spec = pop(stack).toGenericDict();
    auto processed = pop(stack);
    auto self = pop(stack).toCustomClass<TBackendInterface>();
    // The popped locals are dead after this call, so move them (and the
    // result) instead of copying the underlying refcounted handles.
    auto ret =
        self->compile(std::move(processed), std::move(method_compile_spec));
    push(stack, std::move(ret));
  };
}

/// Returns a stack-based wrapper around `TBackendInterface::execute`.
///
/// The returned callable expects the top of the stack to be
///   [..., self, handle, args]
/// where `self` holds a custom-class object of type TBackendInterface and
/// `args` must be convertible to a list. It pops all three values, calls
/// `self->execute(handle, args-as-list)` and pushes the result.
template <typename TBackendInterface>
std::function<void(Stack&)> getExecuteFunc() {
  return [](Stack& stack) {
    // Pop every argument before converting `args`, so a failed list
    // conversion leaves the stack in the same state as before this change.
    auto args = pop(stack);
    auto handle = pop(stack);
    auto self = pop(stack).toCustomClass<TBackendInterface>();
    // `std::move(args).toList()` uses the rvalue accessor and avoids copying
    // the list handle; `handle` and the result are likewise dead after use.
    auto res = self->execute(std::move(handle), std::move(args).toList());
    push(stack, std::move(res));
  };
}

} // namespace detail
} // namespace jit
} // namespace torch
|