#pragma once
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
#include <torch/csrc/onnx/onnx.h>
#include <ostream>
namespace ONNX_NAMESPACE {
class ModelProto;
}
namespace torch {
namespace jit {
// This map is used to keep track of parameters that should be exported
// externally. When `defer_weight_export` is true, the returned map contains
// kv pairs that map {external reference name} -> {at::Tensor to be exported}.
// It is the responsibility of the caller to export these appropriately.
//
// For example, when exporting to a zip archive, the caller may write out files
// for each entry in the export map, with the filename being the key and the
// file contents being the raw tensor data.
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
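//
// A minimal sketch (assuming a populated `export_map`; the file layout is
// hypothetical) of how a caller might honor the map when defer_weight_export
// is true:
//
//   for (const auto& kv : export_map) {
//     at::Tensor t = kv.second.contiguous().cpu();
//     std::ofstream out(kv.first, std::ios::binary);
//     out.write(
//         static_cast<const char*>(t.data_ptr()),
//         t.numel() * t.element_size());
//   }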
using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
using NodeNameMap = std::unordered_map<const Node*, std::string>;
// Used by modularized export to assign names to function and node attributes.
using NodeAttrNameMap = std::
unordered_map<const Node*, std::unordered_map<std::string, std::string>>;
TORCH_API std::tuple<
std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
RawDataExportMap,
SymbolDimMap,
bool,
NodeNameMap>
export_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
const std::unordered_map<
std::string,
std::unordered_map<int64_t, std::string>>& dynamic_axes,
bool defer_weight_export = false,
::torch::onnx::OperatorExportTypes operator_export_type =
::torch::onnx::OperatorExportTypes::ONNX,
bool strip_doc_string = true,
bool keep_initializers_as_inputs = true,
const std::map<std::string, int>& custom_opsets = {},
bool add_node_names = true,
bool use_external_data_format = false,
const std::string& onnx_file_path = std::string(),
const NodeAttrNameMap& node_attr_to_name = {});
TORCH_API std::string serialize_model_proto_to_string(
const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto);
TORCH_API void check_onnx_proto(
const std::string& proto_string,
bool full_check = false);
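// An illustrative sketch (assuming a lowered `graph` and an `initializers`
// map; opset 17 and the name of the bool tuple element are guesses, since the
// element is undocumented here) of driving the three calls above:
//
//   auto [model_proto, export_map, symbol_map, used_external_data, node_names] =
//       export_onnx(
//           graph, initializers, /*onnx_opset_version=*/17, /*dynamic_axes=*/{});
//   std::string proto_string = serialize_model_proto_to_string(model_proto);
//   check_onnx_proto(proto_string, /*full_check=*/true);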
// Serializer for both the old-style and the unified TorchScript serialization
// formats
class TORCH_API ScriptModuleSerializer {
public:
explicit ScriptModuleSerializer(
caffe2::serialize::PyTorchStreamWriter& export_writer)
: writer_(export_writer), current_source_range_tag_(0) {}
void writeFiles(const std::string& code_dir);
void serialize(
const Module& module,
const ExtraFilesMap& extra_files,
bool bytecode_format,
bool save_mobile_debug_info);
void serialize_unified_format(Module& module, uint64_t script_module_id);
SerializationStorageContext& storage_context();
~ScriptModuleSerializer() = default;
private:
void convertNamedType(const c10::NamedTypePtr& class_type);
void convertTypes(const at::NamedTypePtr& root_type);
void writeExtraFiles(const Module& module, const ExtraFilesMap& extra_files);
void writeByteCode(const Module& module, bool save_mobile_debug_info);
void writeArchive(
const IValue& value,
const std::string& archive_name,
const std::string& archive_dir,
const std::string& tensor_dir,
bool use_storage_context = false);
void updateSourceRangeTags(const SourceRangeRecords& ranges);
caffe2::serialize::PyTorchStreamWriter& writer_;
std::vector<at::IValue> constant_table_;
std::unordered_set<c10::NamedTypePtr> converted_types_;
PrintDepsTable class_deps_;
TypeNameUniquer type_name_uniquer_;
// Maps a qualifier, e.g. '__torch__.Bar', to the PythonPrint for the file
// that will be created for it
OrderedDict<std::string, PythonPrint> file_streams_;
// Used to keep references to storages alive during serialization, to work
// around the ABA memory-reuse problem hit when storages are created and
// destroyed during the serialization process. Also used to coordinate
// sharing of storages between Script and eager modules in torch.package.
SerializationStorageContext storage_context_;
// Uniquely identifies a SourceRange in a model.
// SourceRanges are associated with Nodes of Graphs.
// However, for mobile deployment we don't intend to ship the full JIT with
// its ability to read code and construct graphs.
// Instead we serialize the Code generated from the graph of each method.
// Code is serialized in a bytecode format whose instructions correspond to
// the nodes of the graph. Since the original graph is gone, the question is
// how to identify where the ops in the serialized bytecode come from in the
// original model code. We do this in several parts:
// 1. Associate a unique tag with each SourceRange.
// 2. Serialize this unique tag.
//    2.1 That is, save <byte_offset, source_range_tag, source range> instead
//        of <byte_offset, source range>.
// 3. When serializing the model for mobile, i.e. during bytecode generation,
//    save the unique tag of the SourceRange corresponding to each Node.
// 4. During deserialization, read all the debug_pkl files to construct a map
//    of <unique_tag, SourceRange> and use the tag saved with the ops in the
//    bytecode to look up the source range.
// Strictly speaking, we serialize the InlinedCallStack directly, which
// contains the SourceRange. This way we have access to the entire call stack
// and not just the source information about where the node is, since bytecode
// inlines the graph before saving it.
SourceRangeTagMap source_range_tags_;
int64_t current_source_range_tag_;
};
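// A minimal usage sketch (assuming a Module `m` and the hypothetical output
// path "model.pt"):
//
//   caffe2::serialize::PyTorchStreamWriter writer("model.pt");
//   ScriptModuleSerializer serializer(writer);
//   serializer.serialize(
//       m,
//       /*extra_files=*/ExtraFilesMap(),
//       /*bytecode_format=*/false,
//       /*save_mobile_debug_info=*/false);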
// For testing purposes
TORCH_API std::string pretty_print_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
int64_t onnx_opset_version,
bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type =
::torch::onnx::OperatorExportTypes::ONNX,
bool google_printer = false,
bool keep_initializers_as_inputs = true,
const std::map<std::string, int>& custom_opsets = {},
bool add_node_names = true);
TORCH_API void ExportModule(
const Module& module,
std::ostream& out,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
bool save_mobile_debug_info = false,
bool use_flatbuffer = false);
TORCH_API void ExportModule(
const Module& module,
const std::string& filename,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
bool save_mobile_debug_info = false,
bool use_flatbuffer = false);
TORCH_API void ExportModule(
const Module& module,
const std::function<size_t(const void*, size_t)>& writer_func,
const ExtraFilesMap& metadata = ExtraFilesMap(),
bool bytecode_format = false,
bool save_mobile_debug_info = false,
bool use_flatbuffer = false);
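// A sketch of the writer-function overload (assuming a Module `m`): the
// callback receives each chunk of serialized bytes and returns the number of
// bytes it consumed, so output can be streamed anywhere:
//
//   std::string buffer;
//   ExportModule(m, [&buffer](const void* data, size_t n) {
//     buffer.append(static_cast<const char*>(data), n);
//     return n;
//   });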
// Write the bytes of a pickle archive and the tensors referenced inside that
// archive
TORCH_API void writeArchiveAndTensors(
const std::string& archive_name,
const char* pickle_bytes,
size_t size,
const std::vector<at::Tensor>& tensors,
caffe2::serialize::PyTorchStreamWriter& out);
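// A hedged sketch (uses jit::pickle from serialization/pickle.h, which this
// header does not include; the archive name is arbitrary) of producing the
// pickle bytes and tensor table and writing both:
//
//   std::vector<at::Tensor> tensor_table;
//   std::vector<char> data = pickle(IValue(42), &tensor_table);
//   caffe2::serialize::PyTorchStreamWriter writer("archive.pt");
//   writeArchiveAndTensors(
//       "data", data.data(), data.size(), tensor_table, writer);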
// Surrounding system can install an additional hook to produce extra files
// with metadata based on environment every time a module is serialized.
using ExportModuleExtraFilesHook = std::function<ExtraFilesMap(const Module&)>;
TORCH_API void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook);
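// A sketch (the file name and contents are hypothetical) of a hook that
// stamps provenance metadata into every exported module:
//
//   SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
//     return {{"build_info.txt", "built-by: example-ci"}};
//   });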
/**
 * Generates new bytecode for a Script module and returns the op list that a
 * LiteScriptModule would have, based on the current code base. If you
 * already have a LiteScriptModule and want the list of ops it currently
 * contains, call _export_operator_list instead.
 */
TORCH_API std::vector<std::string> export_opnames(const Module& m);
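// A usage sketch (assuming a Module `m`), e.g. for selective-build tooling:
//
//   std::vector<std::string> ops = export_opnames(m);
//   for (const auto& op : ops) {
//     std::cout << op << '\n'; // e.g. "aten::add.Tensor"
//   }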
struct TORCH_API BytecodeEmitMode {
static bool is_default_value_for_unspecified_arg_enabled();
static void set_default_value_for_unspecified_arg_enabled(bool enabled);
static bool is_default_args_before_out_args_enabled();
static void set_default_args_before_out_args_enabled(bool enabled);
static bool is_emit_promoted_ops_enabled();
static void set_default_emit_promoted_ops_enabled(bool enabled);
};
// RAII guard to switch the way the JIT emits bytecode for inputs.
// default_value_for_unspecified_arg:
//   true: instructions for default argument values (like LOADC) are emitted.
//   false: instructions for default argument values are not emitted; instead
//     the values are fetched from the operator schema.
// default_args_before_out_args (for forward compatibility with operators
//   that allow both out arguments and default arguments):
//   true: the number of specified arguments will be deserialized as
//     (#all_args - #default_args).
//   false: the number of specified arguments will be deserialized as
//     (#all_args).
struct TORCH_API BytecodeEmitModeGuard {
BytecodeEmitModeGuard(
bool enable_default_value_for_unspecified_arg,
bool enable_default_args_before_out_args,
bool enable_emit_promoted_ops)
: prev_default_value_for_unspecified_arg_mode(
BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()),
prev_default_args_before_out_args(
BytecodeEmitMode::is_default_args_before_out_args_enabled()),
prev_default_emit_promoted_ops(
BytecodeEmitMode::is_emit_promoted_ops_enabled()) {
BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
enable_default_value_for_unspecified_arg);
BytecodeEmitMode::set_default_args_before_out_args_enabled(
enable_default_args_before_out_args);
BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
enable_emit_promoted_ops);
}
~BytecodeEmitModeGuard() {
BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
prev_default_value_for_unspecified_arg_mode);
BytecodeEmitMode::set_default_args_before_out_args_enabled(
prev_default_args_before_out_args);
BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
prev_default_emit_promoted_ops);
}
bool prev_default_value_for_unspecified_arg_mode;
bool prev_default_args_before_out_args;
bool prev_default_emit_promoted_ops;
};
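// A sketch (assuming a Module `m`; the output path is hypothetical) of
// exporting bytecode with all three modes pinned for the duration of one
// export; the guard's destructor restores the previous settings:
//
//   {
//     BytecodeEmitModeGuard guard(
//         /*enable_default_value_for_unspecified_arg=*/false,
//         /*enable_default_args_before_out_args=*/true,
//         /*enable_emit_promoted_ops=*/false);
//     ExportModule(m, "model.ptl", ExtraFilesMap(), /*bytecode_format=*/true);
//   }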
// Helpers used by serialization to build tuple-structured records.
TORCH_API IValue to_tuple(std::vector<IValue> ivalues);
TORCH_API IValue
Table(const std::vector<std::pair<std::string, IValue>>& entries);
// TODO remove these switches once interface call is rolled out.
TORCH_API void enableMobileInterfaceCallExport();
bool getMobileInterfaceCallExport();
CompilationOptions getOptionsFromGlobal();
// Hook for saving a JIT module in an alternative format (installed, e.g., by
// the flatbuffer serializer when it is linked in); consulted by ExportModule
// when use_flatbuffer is true.
extern void (*_save_jit_module_to)(
const Module& module,
const ExtraFilesMap& extra_files,
bool save_mobile_debug_info,
const std::function<size_t(const void*, size_t)>& writer_func);
} // namespace jit
} // namespace torch