#pragma once
#include <ATen/core/ivalue.h>
#include <c10/util/ArrayRef.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>

namespace torch {
namespace jit {

using TypeResolver =
    std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;

using ObjLoader = std::function<
    c10::intrusive_ptr<c10::ivalue::Object>(at::StrongTypePtr, IValue)>;
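
// A minimal sketch (not part of this header) of what a TypeResolver might
// look like; the CompilationUnit lookup shown here is only an illustration.
// An ObjLoader is analogous: given a StrongTypePtr and the pickled state
// IValue, it constructs the corresponding ivalue::Object.
//
//   auto cu = std::make_shared<CompilationUnit>();
//   TypeResolver resolver = [cu](const c10::QualifiedName& qn) {
//     return c10::StrongTypePtr(cu, cu->get_type(qn));
//   };
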
class DeserializationStorageContext;

// [unpickler refactor] there is some cruft around PickleOpCode::BUILD,
// PickleOpCode::NEWOBJ, and the last_opcode_ member below that should be
// deleted at some point; the Pickler no longer produces these, and they are
// only handled to support models saved before 1.1.
class TORCH_API Unpickler {
  TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);
  using TypeParserT = c10::TypePtr (*)(const std::string&);

 public:
  // Tensors inside the pickle are references into the tensor_table.
  // class_resolver resolves strong class types, while type_resolver_
  // resolves any JIT type; they are not merged here because some use cases
  // need a strong class type that type_resolver_ cannot return.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Unpickler(
      std::function<size_t(char*, size_t)> reader,
      TypeResolver type_resolver,
      c10::ArrayRef<at::Tensor> tensor_table,
      TypeParserT type_parser = defaultTypeParser)
      : reader_(std::move(reader)),
        tensor_table_(tensor_table),
        type_resolver_(std::move(type_resolver)),
        use_storage_device_(false),
        type_parser_(type_parser),
        version_(caffe2::serialize::kProducedFileFormatVersion) {}
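
  // A usage sketch for this constructor (the `data` buffer and `tensors`
  // table are assumptions for illustration, not part of the API):
  //
  //   std::string data = ...;                 // pickle bytes
  //   std::vector<at::Tensor> tensors = ...;  // companion tensor table
  //   size_t bytes_read = 0;
  //   Unpickler unpickler(
  //       [&](char* buf, size_t n) -> size_t {
  //         n = std::min(n, data.size() - bytes_read);
  //         std::memcpy(buf, data.data() + bytes_read, n);
  //         bytes_read += n;
  //         return n;
  //       },
  //       /*type_resolver=*/nullptr,
  //       tensors);
  //   IValue value = unpickler.parse_ivalue();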

  // Tensors inside the pickle contain meta-data; the raw tensor
  // data is retrieved by calling `read_record`.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Unpickler(
      std::function<size_t(char*, size_t)> reader,
      TypeResolver type_resolver,
      ObjLoader obj_loader,
      std::function<at::DataPtr(const std::string&)> read_record,
      c10::optional<at::Device> device,
      bool use_storage_device = false,
      TypeParserT type_parser = defaultTypeParser,
      std::shared_ptr<DeserializationStorageContext> storage_context = nullptr)
      : reader_(std::move(reader)),
        tensor_table_(),
        type_resolver_(std::move(type_resolver)),
        obj_loader_(std::move(obj_loader)),
        read_record_(std::move(read_record)),
        // NOLINTNEXTLINE(performance-move-const-arg)
        device_(std::move(device)),
        use_storage_device_(use_storage_device),
        type_parser_(type_parser),
        storage_context_(std::move(storage_context)),
        version_(caffe2::serialize::kProducedFileFormatVersion) {}
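
  // A usage sketch for this constructor, assuming the pickle's storages live
  // in a zip archive opened through a caffe2::serialize::PyTorchStreamReader
  // named `stream_reader` (the record-name scheme is the caller's concern):
  //
  //   auto read_record = [&](const std::string& name) -> at::DataPtr {
  //     return std::get<0>(stream_reader.getRecord(name));
  //   };
  //   Unpickler unpickler(
  //       reader,
  //       type_resolver,
  //       obj_loader,
  //       std::move(read_record),
  //       /*device=*/c10::nullopt);
  //   IValue value = unpickler.parse_ivalue();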

  // Consume the pickle stream, producing an IValue from the contents.
  // Type Tags: the unpickler will restore the type tags on List and Dict
  // objects when possible, i.e. when the IValue is an Object; otherwise,
  // Dict and List objects will end up with Any as their tag.
  // If you know the type of the IValue, tags can be restored with
  // restoreAccurateTypeTags.
  IValue parse_ivalue();
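
  // For example (illustrative only), if the caller knows that the pickled
  // value is a List[int]:
  //
  //   IValue value = unpickler.parse_ivalue();
  //   restoreAccurateTypeTags(value, c10::ListType::ofInts());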

  // [type tag serialization]
  // This is used to determine whether to restore type tags by recursively
  // descending into the returned stack object (if version_number <= 2), or,
  // if version_number >= 3, to use the type strings included in the pickle
  // archive for container types. By default this is set to
  // `kProducedFileFormatVersion`, so unless you're loading a pickle file
  // alongside a corresponding `version` file, you don't need to set the
  // version manually.
  void set_version(uint64_t version_number) {
    version_ = version_number;
  }
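
  // For example, defaultTypeParser("List[int]") is expected to yield the JIT
  // List[int] type; callers that need different parsing rules can pass their
  // own TypeParserT via the `type_parser` constructor argument.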
  static c10::TypePtr defaultTypeParser(const std::string& str) {
    ScriptTypeParser parser;
    return parser.parseType(str);
  }

 private:
  // No arguments ensures that a template argument must be specified
  // so that the number of bytes read / type read is explicit
  template <typename T>
  T read() {
    T item;
    if (sizeof(T) <= buffer_remaining_) {
      // Fast path: entirely from buffer.
      memcpy(&item, buffer_.data() + buffer_pos_, sizeof(T));
      buffer_remaining_ -= sizeof(T);
      buffer_pos_ += sizeof(T);
    } else {
      // Don't over-template the slow path, to avoid code size bloat.
      readSlowWithBuffer(reinterpret_cast<char*>(&item), sizeof(T));
    }
    return item;
  }
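
  // Slow path for read<T>(): copies whatever is left in buffer_ into `dest`,
  // then refills the buffer through reader_ until `sz` bytes are satisfied.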
  void readSlowWithBuffer(char* dest, size_t sz);
  std::string readBytes(size_t num_bytes);
  double readFloat();
  void readGlobal(
      const std::string& module_name,
      const std::string& class_name);
  void rebuildTensor(bool quantized);
  void rebuildSparseTensor();
#ifdef USE_DISTRIBUTED
  void rebuildRRef();
#endif
  PickleOpCode readInstruction();
  PickleOpCode readOpCode() {
    return static_cast<PickleOpCode>(read<uint8_t>());
  }
  std::string readString();
  void readList(IValue list_ivalue);
  void setInput(size_t memo_id);
  void run();

  // Returns the number of bytes read. This should statefully
  // remember the position. Don't call reader_ directly.
  std::function<size_t(char*, size_t)> reader_;
  // Small buffer to avoid calling reader_ on a per-byte basis.
  std::array<char, 256> buffer_;
  size_t buffer_pos_{0};
  size_t buffer_remaining_{0};

  std::vector<IValue> stack_;

  // globals are represented on the stack as IValue integer indices
  // into this list
  std::vector<std::function<void(void)>> globals_;
  std::vector<IValue> memo_table_;
  std::vector<size_t> marks_;
  c10::ArrayRef<at::Tensor> tensor_table_;

  // When deserializing types on lists and dicts, cache the type here
  // so we don't have to parse the same type multiple times. Strings
  // are already de-duplicated and replaced with BINGETs in the
  // pickler, so we can just use the actual data pointer of each string.
  std::unordered_map<std::string, c10::TypePtr> type_cache_;

  // optionally nullptr, needs to be present for creating classes
  TypeResolver type_resolver_;
  ObjLoader obj_loader_;
  IValue empty_tuple_;

  std::function<at::DataPtr(const std::string&)> read_record_;
  c10::optional<at::Device> device_;
  // When set to true, Unpickler will ignore the pickled device and use the
  // device of the DataPtr returned by the read_record_ function. The default
  // value of this flag is false.
  const bool use_storage_device_;

  TypeParserT type_parser_{defaultTypeParser};

  // Used for torch.package to enable sharing of storages across
  // ScriptModules and eager modules
  std::shared_ptr<DeserializationStorageContext> storage_context_;

  // See [type tag serialization]
  uint64_t version_;
};
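
// Walks `root` guided by `type_tag` and restores accurate element types on any
// contained List and Dict values (see [type tag serialization] above); without
// this, containers unpickled from older archives remain tagged as Any.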
void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag);
} // namespace jit
} // namespace torch