#include <torch/csrc/jit/mobile/train/export_data.h>

#include <torch/csrc/jit/mobile/import_export_common.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

#include <caffe2/serialize/inline_container.h>

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

#include <fstream>
#include <functional>
#include <map>
#include <ostream>
#include <string>
#include <vector>
namespace torch {
namespace jit {
namespace mobile {

char const* toString(OpCode op);

namespace {
/**
* Serializes an IValue using Pickle, and puts it in a file named "data.pkl"
* in a ZIP wrapper.
*/
class IValuePickler final {
 public:
  explicit IValuePickler(const std::string& filename) : writer_(filename) {}

  explicit IValuePickler(
      const std::function<size_t(const void*, size_t)>& writer_func)
      : writer_(writer_func) {}

  void serialize(const IValue& object) {
    // Serialize just the data
    writeArchive("data", object);
  }

 private:
  void writeArchive(const std::string& archive_name, const IValue& value) {
    std::vector<char> data;
    // Vector to capture the run-time class types during pickling the IValues
    std::vector<c10::ClassTypePtr> memoizedClassTypes;
    Pickler data_pickle(
        [&](const char* buf, size_t size) {
          data.insert(data.end(), buf, buf + size);
        },
        nullptr,
        [&](const c10::ClassTypePtr& t) {
          return type_name_uniquer_.getUniqueName(t);
        },
        &memoizedClassTypes);
    data_pickle.protocol();
    data_pickle.pushIValue(value);
    data_pickle.stop();

    // Write each tensor's raw bytes as a separate record named
    // "<archive_name>/<index>"; the pickle stream refers to tensors by
    // these indices.
    size_t i = 0;
    std::string prefix = archive_name + "/";
    for (const auto& td : data_pickle.tensorData()) {
      WriteableTensorData writable_td = getWriteableTensorData(td);
      std::string fname = prefix + c10::to_string(i++);
      writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
    }

    // Write the pickle bytes themselves as "<archive_name>.pkl".
    std::string fname = archive_name + ".pkl";
    writer_.writeRecord(fname, data.data(), data.size());
  }

  caffe2::serialize::PyTorchStreamWriter writer_;
  TypeNameUniquer type_name_uniquer_;
};
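
// Illustrative use of IValuePickler (a sketch, not code that runs in this
// file; `out` and `value` are hypothetical names):
//
//   std::ostringstream out;
//   IValuePickler pickler([&out](const void* buf, size_t nbytes) -> size_t {
//     out.write(static_cast<const char*>(buf),
//               static_cast<std::streamsize>(nbytes));
//     return out ? nbytes : 0;
//   });
//   pickler.serialize(value);
//   // out.str() now holds a ZIP archive with "data.pkl" plus one
//   // "data/<index>" record per tensor referenced by the pickle stream.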
} // namespace
/**
* Converts a map of named tensors to a c10::Dict.
*/
c10::Dict<std::string, at::Tensor> tensor_map_to_dict(
    const std::map<std::string, at::Tensor>& map) {
  c10::Dict<std::string, at::Tensor> dict;
  for (const auto& e : map) {
    dict.insert(e.first, e.second);
  }
  return dict;
}
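
// For example, a map {{"fc.weight", w}, {"fc.bias", b}} becomes a c10::Dict
// with the same two entries; c10::Dict preserves insertion order, so entries
// follow the map's sorted key order.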
/**
* Returns a Module with a single attribute, with the attribute name specified
* by #internal::kSavedParametersAttributeName, whose value is the provided
* dict.
*/
mobile::Module tensor_dict_to_mobile(
    const c10::Dict<std::string, at::Tensor>& dict) {
  // Create an Object to back the Module, with an attribute to hold the dict.
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  // Note that the name doesn't really matter, but it must begin with
  // "__torch__." to be treated as a valid class when being imported.
  auto cls = c10::ClassType::create(
      "__torch__.SavedParameters", cu, /*is_module=*/true);
  cls->addAttribute(
      internal::kSavedParametersAttributeName,
      c10::DictType::create(dict.keyType(), dict.valueType()));
  auto object = c10::ivalue::Object::create(
      c10::StrongTypePtr(std::move(cu), std::move(cls)), /*numSlots=*/1);
  // Add the dict as an attribute.
  object->setAttr(internal::kSavedParametersAttributeName, dict);
  // Wrap the Object in a Module.
  auto mcu = std::make_shared<mobile::CompilationUnit>();
  return mobile::Module(object, mcu);
}
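
// The saved module thus looks like a TorchScript module with a single
// dict-valued attribute. As a sketch (assuming mobile::Module::_ivalue() and
// c10::ivalue::Object::getAttr(), as used elsewhere in the mobile runtime),
// a loader can recover the dict with:
//
//   mobile::Module m = tensor_dict_to_mobile(dict);
//   IValue recovered =
//       m._ivalue()->getAttr(internal::kSavedParametersAttributeName);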
} // namespace mobile
// Hook for exporting a mobile::Module as flatbuffer. Set by the flatbuffer
// serializer when that code is linked into the build; stays null otherwise
// (see the TORCH_CHECK in _save_parameters below).
void (*_save_mobile_module_to)(
    const mobile::Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func) = nullptr;
void _save_parameters(
    const std::map<std::string, at::Tensor>& map,
    std::ostream& out,
    bool use_flatbuffer) {
  auto dict = mobile::tensor_map_to_dict(map);

  auto write_func = [&out](const void* buf, size_t nbytes) -> size_t {
    out.write(
        static_cast<const char*>(buf), static_cast<std::streamsize>(nbytes));
    return !out ? 0 : nbytes;
  };

  if (use_flatbuffer) {
    if (_save_mobile_module_to != nullptr) {
      _save_mobile_module_to(mobile::tensor_dict_to_mobile(dict), write_func);
    } else {
      TORCH_CHECK(
          false,
          "Trying to export as flatbuffer file but "
          "the build hasn't enabled flatbuffer");
    }
  } else {
    // For Pickle, we only serialize the dict itself.
    mobile::IValuePickler pickler(write_func);
    pickler.serialize(dict);
  }
}
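
// Example usage of the stream overload (illustrative; `params` and `buffer`
// are hypothetical names):
//
//   std::map<std::string, at::Tensor> params;
//   params["fc.weight"] = torch::randn({4, 4});
//   std::ostringstream buffer;
//   torch::jit::_save_parameters(params, buffer, /*use_flatbuffer=*/false);
//   // buffer.str() is a ZIP archive: "data.pkl" plus one record per tensor.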
void _save_parameters(
    const std::map<std::string, at::Tensor>& map,
    const std::string& filename,
    bool use_flatbuffer) {
  // Open in binary mode and delegate to the stream-based overload.
  std::ofstream out(filename, std::ios::binary);
  _save_parameters(map, out, use_flatbuffer);
}
} // namespace jit
} // namespace torch