File: init_flatbuffer_module.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (130 lines) | stat: -rw-r--r-- 4,773 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <torch/csrc/python_headers.h>

#include <libshm.h>
#include <cstdlib>
#include <cstring>
#include <new>

#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/csrc/utils/pybind.h>

#include <Python.h> // NOLINT
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_sugared_value.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>

namespace py = pybind11;

using torch::jit::kFlatbufferDataAlignmentBytes;

// Copies `bytes` into a freshly allocated buffer aligned to
// kFlatbufferDataAlignmentBytes, which is the form the flatbuffer parsing
// entry points in this file are fed.
//
// The allocation size is (bytes.size() / align + 1) * align. That is always a
// non-zero multiple of the alignment, as std::aligned_alloc requires, and it is
// at least bytes.size() + 1. Bytes past bytes.size() are left uninitialized.
//
// The returned shared_ptr owns the buffer and releases it with the deallocator
// that matches the platform allocator used to obtain it.
//
// Throws std::bad_alloc if the aligned allocation fails on any platform.
static std::shared_ptr<char> copyStr(const std::string& bytes) {
  const size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
      kFlatbufferDataAlignmentBytes;
#ifdef _WIN32
  char* raw =
      static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes));
  if (raw == nullptr) {
    throw std::bad_alloc();
  }
  // Memory from _aligned_malloc must be released with _aligned_free, not free.
  std::shared_ptr<char> bytes_copy(raw, _aligned_free);
#elif defined(__APPLE__)
  // posix_memalign reports failure only through its return value; the output
  // pointer is not guaranteed to be set on failure, so initialize it and check
  // the result code rather than relying on `p` alone.
  void* p = nullptr;
  if (::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size) != 0 ||
      p == nullptr) {
    throw std::bad_alloc();
  }
  std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
#else
  char* raw =
      static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size));
  if (raw == nullptr) {
    throw std::bad_alloc();
  }
  std::shared_ptr<char> bytes_copy(raw, free);
#endif
  memcpy(bytes_copy.get(), bytes.data(), bytes.size());
  return bytes_copy;
}

// Creates the `torch._C_flatbuffer` extension module and registers the
// flatbuffer load/save bindings on it.
//
// Returns the module created by PyModule_Create. If PyModule_Create fails,
// returns nullptr and leaves the Python error it set in place, following the
// usual module-init convention.
extern "C"
#ifdef _WIN32
    __declspec(dllexport)
#endif
        PyObject* initModuleFlatbuffer() {
  using namespace torch::jit;
  // `torchmodule` is static and stores a pointer to this method table, so the
  // table must be static as well; a stack array would dangle after return.
  static PyMethodDef m[] = {{nullptr, nullptr, 0, nullptr}}; // NOLINT
  static struct PyModuleDef torchmodule = {
      PyModuleDef_HEAD_INIT,
      "torch._C_flatbuffer",
      nullptr,
      -1,
      m,
  }; // NOLINT
  PyObject* module = PyModule_Create(&torchmodule);
  if (module == nullptr) {
    // PyModule_Create has already set a Python exception.
    return nullptr;
  }
  auto pym = py::handle(module).cast<py::module>();
  pym.def("_load_mobile_module_from_file", [](const std::string& filename) {
    return torch::jit::load_mobile_module_from_file(filename);
  });
  pym.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
    // Parse from an aligned copy; std::string storage has no alignment
    // guarantee.
    auto bytes_copy = copyStr(bytes);
    return torch::jit::parse_and_initialize_mobile_module(
        bytes_copy, bytes.size());
  });
  pym.def("_load_jit_module_from_file", [](const std::string& filename) {
    ExtraFilesMap extra_files = ExtraFilesMap();
    return torch::jit::load_jit_module_from_file(filename, extra_files);
  });
  pym.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
    auto bytes_copy = copyStr(bytes);
    ExtraFilesMap extra_files = ExtraFilesMap();
    return torch::jit::parse_and_initialize_jit_module(
        bytes_copy, bytes.size(), extra_files);
  });
  // pybind11 does not see C++ default arguments on the bound callable, so the
  // optional `_extra_files` default must be declared through py::arg for it to
  // be optional on the Python side.
  pym.def(
      "_save_mobile_module",
      [](const torch::jit::mobile::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files) {
        return torch::jit::save_mobile_module(module, filename, _extra_files);
      },
      py::arg("module"),
      py::arg("filename"),
      py::arg("_extra_files") = ExtraFilesMap());
  pym.def(
      "_save_jit_module",
      [](const torch::jit::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files) {
        return torch::jit::save_jit_module(module, filename, _extra_files);
      },
      py::arg("module"),
      py::arg("filename"),
      py::arg("_extra_files") = ExtraFilesMap());
  pym.def(
      "_save_mobile_module_to_bytes",
      [](const torch::jit::mobile::Module& module,
         const ExtraFilesMap& _extra_files) {
        auto detached_buffer =
            torch::jit::save_mobile_module_to_bytes(module, _extra_files);
        return py::bytes(
            reinterpret_cast<char*>(detached_buffer->data()),
            detached_buffer->size());
      },
      py::arg("module"),
      py::arg("_extra_files") = ExtraFilesMap());
  pym.def(
      "_save_jit_module_to_bytes",
      [](const torch::jit::Module& module,
         const ExtraFilesMap& _extra_files) {
        auto detached_buffer =
            torch::jit::save_jit_module_to_bytes(module, _extra_files);
        return py::bytes(
            reinterpret_cast<char*>(detached_buffer->data()),
            detached_buffer->size());
      },
      py::arg("module"),
      py::arg("_extra_files") = ExtraFilesMap());
  pym.def(
      "_get_module_info_from_flatbuffer",
      [](const std::string& flatbuffer_content) {
        py::gil_scoped_acquire acquire;
        py::dict result;
        // Read from an aligned copy, as the *_from_bytes loaders above do.
        // The copy stays alive for the duration of the call below.
        auto aligned_content = copyStr(flatbuffer_content);
        mobile::ModuleInfo minfo =
            torch::jit::get_module_info_from_flatbuffer(aligned_content.get());
        result["bytecode_version"] = minfo.bytecode_version;
        result["operator_version"] = minfo.operator_version;
        result["function_names"] = minfo.function_names;
        result["type_names"] = minfo.type_names;
        result["opname_to_num_args"] = minfo.opname_to_num_args;
        return result;
      });

  return module;
}