1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
|
#include <torch/csrc/jit/runtime/static/init.h>
#include <torch/csrc/jit/runtime/static/impl.h>
namespace torch {
namespace jit {

// Registers the Static Runtime Python bindings on the given Python module:
//   * the `StaticRuntime` class, exposing its `run` member, and
//   * `_jit_to_static_runtime`, registered twice so it accepts either a
//     TorchScript Graph or a TorchScript Module. Each overload passes its
//     argument through PrepareForStaticRuntime and wraps the result in a
//     StaticRuntime, which is returned to Python by value.
//
// `module` is the raw Python module object the bindings are attached to.
void initStaticRuntimeBindings(PyObject* module) {
  auto py_module = py::handle(module).cast<py::module>();

  // Register the runtime type first so the objects returned by the factory
  // overloads below have a Python-side class.
  py::class_<StaticRuntime> runtime_cls(py_module, "StaticRuntime");
  runtime_cls.def("run", &StaticRuntime::run);

  // Factory overload taking a Graph (registered first, as in the original
  // ordering, so pybind11 considers it before the Module overload).
  py_module.def(
      "_jit_to_static_runtime",
      [](const std::shared_ptr<torch::jit::Graph>& graph) {
        return StaticRuntime(PrepareForStaticRuntime(graph));
      });

  // Factory overload taking a Module. The parameter is deliberately not named
  // `m`/`module` to avoid shadowing the names in the enclosing scope.
  py_module.def(
      "_jit_to_static_runtime",
      [](const torch::jit::Module& script_module) {
        return StaticRuntime(PrepareForStaticRuntime(script_module));
      });
}

} // namespace jit
} // namespace torch
|