#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#ifdef USE_XPU
#include <torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h>
#endif
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::inductor {
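
// Registers the `_aoti` submodule with the AOTInductor model container
// runners and the tensor <-> handle conversion helpers. In PyTorch this
// submodule is exposed to Python as `torch._C._aoti`.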
void initAOTIRunnerBindings(PyObject* module) {
  auto rootModule = py::handle(module).cast<py::module>();
  auto m = rootModule.def_submodule("_aoti");
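
  // CPU model container runner. The constructor takes the path to the
  // AOTInductor-compiled shared library and the number of model instances
  // the container should hold.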
  py::class_<AOTIModelContainerRunnerCpu>(m, "AOTIModelContainerRunnerCpu")
      .def(py::init<const std::string&, int>())
      .def(
          "run",
          &AOTIModelContainerRunnerCpu::run,
          py::arg("inputs"),
          py::arg("stream_handle") = nullptr)
      .def("get_call_spec", &AOTIModelContainerRunnerCpu::get_call_spec)
      .def(
          "get_constant_names_to_original_fqns",
          &AOTIModelContainerRunnerCpu::getConstantNamesToOriginalFQNs)
      .def(
          "get_constant_names_to_dtypes",
          &AOTIModelContainerRunnerCpu::getConstantNamesToDtypes)
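      // The static_cast disambiguates between the C++ overloads of
      // update_constant_buffer, selecting the one that takes an
      // unordered_map of constant name -> tensor plus two boolean flags.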
      .def(
          "update_constant_buffer",
          static_cast<void (AOTIModelContainerRunnerCpu::*)(
              std::unordered_map<std::string, at::Tensor>&, bool, bool)>(
              &AOTIModelContainerRunnerCpu::update_constant_buffer));
#ifdef USE_CUDA
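  // CUDA model container runner. The extra string parameters in the longer
  // constructor overloads correspond to the device (e.g. "cuda:0") and the
  // directory holding precompiled kernel (cubin) binaries.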
  py::class_<AOTIModelContainerRunnerCuda>(m, "AOTIModelContainerRunnerCuda")
      .def(py::init<const std::string&, int>())
      .def(py::init<const std::string&, int, const std::string&>())
      .def(py::init<
           const std::string&,
           int,
           const std::string&,
           const std::string&>())
      .def(
          "run",
          &AOTIModelContainerRunnerCuda::run,
          py::arg("inputs"),
          py::arg("stream_handle") = nullptr)
      .def("get_call_spec", &AOTIModelContainerRunnerCuda::get_call_spec)
      .def(
          "get_constant_names_to_original_fqns",
          &AOTIModelContainerRunnerCuda::getConstantNamesToOriginalFQNs)
      .def(
          "get_constant_names_to_dtypes",
          &AOTIModelContainerRunnerCuda::getConstantNamesToDtypes)
      .def(
          "update_constant_buffer",
          static_cast<void (AOTIModelContainerRunnerCuda::*)(
              std::unordered_map<std::string, at::Tensor>&, bool, bool)>(
              &AOTIModelContainerRunnerCuda::update_constant_buffer));
#endif
#ifdef USE_XPU
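  // XPU (Intel GPU) model container runner, mirroring the CUDA bindings above.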
  py::class_<AOTIModelContainerRunnerXpu>(m, "AOTIModelContainerRunnerXpu")
      .def(py::init<const std::string&, int>())
      .def(py::init<const std::string&, int, const std::string&>())
      .def(py::init<
           const std::string&,
           int,
           const std::string&,
           const std::string&>())
      .def(
          "run",
          &AOTIModelContainerRunnerXpu::run,
          py::arg("inputs"),
          py::arg("stream_handle") = nullptr)
      .def("get_call_spec", &AOTIModelContainerRunnerXpu::get_call_spec)
      .def(
          "get_constant_names_to_original_fqns",
          &AOTIModelContainerRunnerXpu::getConstantNamesToOriginalFQNs)
      .def(
          "get_constant_names_to_dtypes",
          &AOTIModelContainerRunnerXpu::getConstantNamesToDtypes)
      .def(
          "update_constant_buffer",
          static_cast<void (AOTIModelContainerRunnerXpu::*)(
              std::unordered_map<std::string, at::Tensor>&, bool, bool)>(
              &AOTIModelContainerRunnerXpu::update_constant_buffer));
#endif
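
  // Helpers for moving tensors across the AOTInductor C ABI boundary, where
  // tensors are represented by opaque AtenTensorHandle values passed around
  // as raw void*. This one allocates new handles referencing the given
  // tensors; the caller takes ownership of the returned pointers.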
  m.def(
      "unsafe_alloc_void_ptrs_from_tensors",
      [](const std::vector<at::Tensor>& tensors) {
        std::vector<AtenTensorHandle> handles =
            torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(tensors);
        std::vector<void*> result(
            reinterpret_cast<void**>(handles.data()),
            reinterpret_cast<void**>(handles.data()) + handles.size());
        return result;
      });
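  // Wraps a single tensor in a newly allocated handle; the tensor is moved
  // into the handle.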
  m.def("unsafe_alloc_void_ptr_from_tensor", [](at::Tensor& tensor) {
    return reinterpret_cast<void*>(
        torch::aot_inductor::new_tensor_handle(std::move(tensor)));
  });
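  // Converts a batch of raw handles back into at::Tensors, stealing ownership
  // of the handles from the caller.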
  m.def(
      "alloc_tensors_by_stealing_from_void_ptrs",
      [](std::vector<void*>& raw_handles) {
        return torch::aot_inductor::alloc_tensors_by_stealing_from_handles(
            reinterpret_cast<AtenTensorHandle*>(raw_handles.data()),
            raw_handles.size());
      });
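  // Converts a single raw handle back into the at::Tensor it refers to.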
  m.def("alloc_tensor_by_stealing_from_void_ptr", [](void* raw_handle) {
    return *torch::aot_inductor::tensor_handle_to_tensor_pointer(
        reinterpret_cast<AtenTensorHandle>(raw_handle));
  });
}
} // namespace torch::inductor