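// Python bindings for flatbuffer-based serialization of TorchScript and
// mobile modules, exposed through the torch._C_flatbuffer extension module.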
#include <torch/csrc/python_headers.h>
#include <libshm.h>
#include <cstdlib>
#include <cstring> // for memcpy, used in copyStr below
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/csrc/utils/pybind.h>
#include <Python.h> // NOLINT
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_sugared_value.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
namespace py = pybind11;
using torch::jit::kFlatbufferDataAlignmentBytes;
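// Copies `bytes` into a freshly allocated buffer padded to a multiple of
// kFlatbufferDataAlignmentBytes, since the flatbuffer loader expects aligned
// input. Each platform pairs its aligned allocator with the matching deleter.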
static std::shared_ptr<char> copyStr(const std::string& bytes) {
  // Round the allocation up to the next multiple of the required alignment.
  size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
      kFlatbufferDataAlignmentBytes;
#ifdef _WIN32
  std::shared_ptr<char> bytes_copy(
      static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
      _aligned_free);
#elif defined(__APPLE__)
  void* p = nullptr; // initialized so the assert below is valid on failure
  ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
  TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
  std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
#else
  std::shared_ptr<char> bytes_copy(
      static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
      free);
#endif
  memcpy(bytes_copy.get(), bytes.data(), bytes.size());
  return bytes_copy;
}
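// Creates the torch._C_flatbuffer Python module. Declared extern "C" (and
// dllexport'ed on Windows) so the symbol can be looked up by name; all
// bindings are registered through pybind11 rather than the PyMethodDef table.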
extern "C"
#ifdef _WIN32
__declspec(dllexport)
#endif
    PyObject* initModuleFlatbuffer() {
  using namespace torch::jit;
  PyMethodDef m[] = {{nullptr, nullptr, 0, nullptr}}; // NOLINT
  static struct PyModuleDef torchmodule = {
      PyModuleDef_HEAD_INIT,
      "torch._C_flatbuffer",
      nullptr,
      -1,
      m,
  }; // NOLINT
  PyObject* module = PyModule_Create(&torchmodule);
  auto pym = py::handle(module).cast<py::module>();
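  // Deserialization entry points. The *_from_bytes variants first copy the
  // payload into an aligned buffer via copyStr() above.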
  pym.def("_load_mobile_module_from_file", [](const std::string& filename) {
    return torch::jit::load_mobile_module_from_file(filename);
  });
  pym.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
    auto bytes_copy = copyStr(bytes);
    return torch::jit::parse_and_initialize_mobile_module(
        bytes_copy, bytes.size());
  });
  pym.def("_load_jit_module_from_file", [](const std::string& filename) {
    ExtraFilesMap extra_files = ExtraFilesMap();
    return torch::jit::load_jit_module_from_file(filename, extra_files);
  });
  pym.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
    auto bytes_copy = copyStr(bytes);
    ExtraFilesMap extra_files = ExtraFilesMap();
    return torch::jit::parse_and_initialize_jit_module(
        bytes_copy, bytes.size(), extra_files);
  });
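  // Serialization to a file on disk. Note that pybind11 does not pick up the
  // default arguments declared on the lambdas, so Python callers pass
  // _extra_files explicitly.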
  pym.def(
      "_save_mobile_module",
      [](const torch::jit::mobile::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        return torch::jit::save_mobile_module(module, filename, _extra_files);
      });
  pym.def(
      "_save_jit_module",
      [](const torch::jit::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        return torch::jit::save_jit_module(module, filename, _extra_files);
      });
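  // Serialization to an in-memory flatbuffer, returned to Python as bytes
  // (py::bytes copies out of the DetachedBuffer).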
  pym.def(
      "_save_mobile_module_to_bytes",
      [](const torch::jit::mobile::Module& module,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        auto detached_buffer =
            torch::jit::save_mobile_module_to_bytes(module, _extra_files);
        return py::bytes(
            reinterpret_cast<char*>(detached_buffer->data()),
            detached_buffer->size());
      });
  pym.def(
      "_save_jit_module_to_bytes",
      [](const torch::jit::Module& module,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        auto detached_buffer =
            torch::jit::save_jit_module_to_bytes(module, _extra_files);
        return py::bytes(
            reinterpret_cast<char*>(detached_buffer->data()),
            detached_buffer->size());
      });
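  // Extracts module metadata (bytecode/operator versions, function and type
  // names, operator arg counts) from a raw flatbuffer payload.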
  pym.def(
      "_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
        py::gil_scoped_acquire acquire;
        py::dict result;
        mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
            &flatbuffer_content[0]);
        result["bytecode_version"] = minfo.bytecode_version;
        result["operator_version"] = minfo.operator_version;
        result["function_names"] = minfo.function_names;
        result["type_names"] = minfo.type_names;
        result["opname_to_num_args"] = minfo.opname_to_num_args;
        return result;
      });
  return module;
}
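// A minimal sketch of driving these bindings from Python (assumes a PyTorch
// build where torch._C_flatbuffer is available; the file name is
// hypothetical):
//
//   from torch import _C_flatbuffer as _C_ff
//   m = _C_ff._load_mobile_module_from_file("model.ff")
//   blob = _C_ff._save_mobile_module_to_bytes(m, {})
//   info = _C_ff._get_module_info_from_flatbuffer(blob)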