1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
|
#pragma once
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <istream>
#include <mutex>
#include <ostream>
#include <unordered_set>
#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"
#include "caffe2/serialize/versions.h"
// Forward declaration of miniz's archive type. It lets this header hold the
// archive through std::unique_ptr<mz_zip_archive> (see the reader/writer
// members below) without defining it. Presumably this keeps the miniz header
// out of this public header — confirm.
extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}
// PyTorch containers are a special zip archive with the following layout
// archive_name.zip contains:
// archive_name/
// version # a file with a single decimal number written in ascii,
// # used to establish the version of the archive format
// model.json # overall model description, this is a json output of
// # ModelDef from torch.proto
// # the following names are by convention only, model.json will
// # refer to these files by full names
// tensors/
// 0 # flat storage for tensor data, meta-data about shapes, etc. is
// # in model.json
// 1
// ...
// # code entries will only exist for modules that have methods attached
// code/
// archive_name.py # serialized torch script code (python syntax, using
// # PythonPrint)
// archive_name_my_submodule.py # submodules have separate
// # files
//
// The PyTorchStreamWriter also ensures additional useful properties for these
// files
// 1. All files are stored uncompressed.
// 2. All files in the archive are aligned to 64 byte boundaries such that
// it is possible to mmap the entire file and get an aligned pointer to
// tensor data.
// 3. We universally write in ZIP64 format for consistency.
// The PyTorchStreamReader also provides additional properties:
// 1. It can read zip files that are created with common
// zip tools. This means that even though our writer doesn't compress files,
// the reader can still read files that were compressed.
// 2. It provides a getRecordOffset function which returns the offset into the
// raw file where file data lives. If the file was written with
// PyTorchStreamWriter it is guaranteed to be 64 byte aligned.
// PyTorchStreamReader/Writer handle checking the version number on the archive
// format and ensure that all files are written to an archive_name directory so
// they unzip cleanly.
// When developing this format we want to pay particular attention to the
// following use cases:
//
// -- Reading --
// 1) Reading with full random access
// a) Reading with file api's such as fread()
// b) mmaping the file and jumping around the mapped region
// 2) Reading with 1-pass sequential access
// -> A reader will need to build up a data structure of parsed structures
// as it reads
//
// -- Writing --
// 1) Writing with full random access
// 2) Writing with 1-pass sequential access
// -> We must take care not to require updating values that have already
// been written. We place the variable-length index at the end and do
// not put any indices into the header to fulfill this constraint.
// The model.json, which contains all the metadata information,
// should be written as the last file. One reason is that the size of tensor
// data is usually stable. As long as the shape and type of the tensor do not
// change, the size of the data won't change. On the other side, the size of the
// serialized model is likely to change, so we store it as the last record, and
// we don't need to move previous records when updating the model data.
// The zip format is sufficiently flexible to handle the above use-case.
// It puts its central directory at the end of the archive, and we write
// model.json as the last file when writing after we have accumulated all
// other information.
namespace caffe2 {
namespace serialize {
// Name of the archive record that holds the serialization id exposed by
// serializationId() on the reader and writer.
static constexpr const char* kSerializationIdRecordName{
    ".data/serialization_id"};

// Opaque helper owned by ChunkRecordIterator. It is only forward-declared in
// this header; its definition lives out of line.
struct MzZipReaderIterWrapper;
// Sequential reader over a single archive record, handing out the record's
// contents in chunks of at most `chunkSize` bytes.
//
// Only PyTorchStreamReader can construct one: the constructor is private and
// PyTorchStreamReader is a friend. The public factory is
// PyTorchStreamReader::createChunkReaderIter().
//
// Not copyable or movable. The unique_ptr member deletes the copy
// constructor, and the user-declared destructor suppresses the implicit move
// constructor. Bind the value returned by createChunkReaderIter() directly.
class TORCH_API ChunkRecordIterator {
 public:
  // Defined out of line, where MzZipReaderIterWrapper is a complete type
  // (required to destroy iter_).
  ~ChunkRecordIterator();

  // Read at most `chunkSize` bytes into `buf` and return the number of bytes
  // actually read. `buf` must therefore have room for `chunkSize` bytes.
  // NOTE(review): presumably returns 0 once the whole record has been
  // consumed — confirm against the definition.
  size_t next(void* buf);

  // Size of the record being iterated, as passed to the constructor
  // (presumably in bytes).
  size_t recordSize() const {
    return recordSize_;
  }

 private:
  ChunkRecordIterator(
      size_t recordSize,
      size_t chunkSize,
      std::unique_ptr<MzZipReaderIterWrapper> iter);

  const size_t recordSize_;
  const size_t chunkSize_;
  // Presumably the current read position within the record. The default
  // member initializer guarantees it is never read indeterminate, even if
  // the out-of-line constructor omits it from its init list. An explicit
  // mem-initializer still takes precedence.
  size_t offset_ = 0;
  std::unique_ptr<MzZipReaderIterWrapper> iter_;

  friend class PyTorchStreamReader;
};
// Reader for PyTorch zip-based containers. It can be opened from a file path,
// a std::istream, or a ReadAdapterInterface.
class TORCH_API PyTorchStreamReader final {
 public:
  explicit PyTorchStreamReader(const std::string& file_name);
  explicit PyTorchStreamReader(std::istream* in);
  explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in);

  // Read the whole record `name`. Returns {data pointer, size}.
  std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);

  // Multi-threaded variant of getRecord(name). It may additionally read
  // through `additionalReaders`.
  std::tuple<at::DataPtr, size_t> getRecord(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);

  // In-place read: write record `name` into the caller-provided buffer `dst`
  // of `n` bytes. NOTE(review): the return value is presumably the number of
  // bytes read — confirm.
  size_t getRecord(const std::string& name, void* dst, size_t n);

  // In-place read using multiple threads. When `additionalReaders` is empty,
  // this calls getRecord(name, dst, n) with the default reader. Intended for
  // reading large tensors.
  size_t getRecord(
      const std::string& name,
      void* dst,
      size_t n,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);

  // In-place chunked read of record `name` into `dst` (`n` bytes).
  // NOTE(review): presumably reads at most `chunk_size` bytes at a time
  // through the staging buffer `buf` and copies each chunk into `dst` with
  // `memcpy_func` (nullptr by default). Confirm the fallback copy routine and
  // the required size of `buf` in the definition.
  size_t getRecord(
      const std::string& name,
      void* dst,
      size_t n,
      size_t chunk_size,
      void* buf,
      const std::function<void(void*, const void*, size_t)>& memcpy_func =
          nullptr);

  // Concurrently read record `name` into `dst` (`n` bytes) using multiple
  // readers. The default reader and each entry of `additionalReaders` read a
  // different contiguous range of the record, and each writes to the same
  // range of `dst`.
  //
  // Example: for a 10-byte record and two additional readers:
  //   - the default reader reads and writes [0,4);
  //   - the first additional reader reads and writes [4,8);
  //   - the second additional reader reads and writes [8,10).
  //
  // When `additionalReaders` is empty, this falls back to a plain read with
  // the default reader. Intended for reading large tensors.
  size_t getRecordMultiReaders(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
      void* dst,
      size_t n);

  // Size of record `name`.
  size_t getRecordSize(const std::string& name);
  // Offset of record `name`'s data within the raw archive file. The offset is
  // 64-byte aligned when the archive was written by PyTorchStreamWriter.
  size_t getRecordOffset(const std::string& name);
  // Whether the archive contains a record called `name`.
  bool hasRecord(const std::string& name);
  // Names of the records in the archive.
  std::vector<std::string> getAllRecords();

  // Create an iterator over record `name` (of size `recordSize`) that yields
  // chunks of at most `chunkSize` bytes.
  ChunkRecordIterator createChunkReaderIter(
      const std::string& name,
      const size_t recordSize,
      const size_t chunkSize);

  ~PyTorchStreamReader();

  // Archive format version. It is stored signed internally, so the
  // conversion to the unsigned return type is spelled out.
  uint64_t version() const {
    return static_cast<uint64_t>(version_);
  }

  // Serialization id associated with the archive (see
  // kSerializationIdRecordName). NOTE(review): presumably empty when the
  // archive has no such record — confirm.
  const std::string& serializationId() const {
    return serialization_id_;
  }

  // Toggle loading of debug-symbol records (enabled by default).
  // NOTE(review): presumably getRecord skips such records when disabled —
  // confirm in the definition.
  void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
    load_debug_symbol_ = should_load_debug_symbol;
  }

  // Size threshold for using `additionalReaders`. NOTE(review): presumably
  // smaller records are read with the default reader only — confirm.
  void setAdditionalReaderSizeThreshold(size_t size) {
    additional_reader_size_threshold_ = size;
  }

 private:
  // Shared constructor logic; defined out of line.
  void init();
  // NOTE(review): presumably reads up to `n` bytes at absolute position `pos`
  // of the underlying input into `buf` and returns the count read.
  size_t read(uint64_t pos, char* buf, size_t n);
  // NOTE(review): presumably checks the archive for an error after operation
  // `what`, with `info` as extra context for the message.
  void valid(const char* what, const char* info = "");
  // NOTE(review): presumably the index of record `name` in the zip archive.
  size_t getRecordID(const std::string& name);

  // Out-of-line read callback that needs access to this class's private
  // state.
  friend size_t
  istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);

  std::unique_ptr<mz_zip_archive> ar_;
  // Top-level directory inside the zip, without and with a trailing '/'.
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  // Underlying input that the archive is read from.
  std::shared_ptr<ReadAdapterInterface> in_;
  // Default member initializer so it is never read indeterminate. The
  // constructors/init() may still assign the real value.
  int64_t version_ = 0;
  // NOTE(review): presumably serializes access to ar_/in_ across threads —
  // confirm in the definitions.
  std::mutex reader_lock_;
  bool load_debug_symbol_ = true;
  std::string serialization_id_;
  // Default member initializer so it is never read indeterminate. The
  // constructors, init(), or setAdditionalReaderSizeThreshold() may override
  // it. NOTE(review): confirm the intended default in the .cpp.
  size_t additional_reader_size_threshold_ = 0;
};
// Writer for PyTorch zip-based containers. NOTE(review): output presumably
// goes to the file `archive_name` (via file_stream_) or through the
// user-supplied `writer_func` callback (via writer_func_) — confirm.
class TORCH_API PyTorchStreamWriter final {
 public:
  // `compute_crc32` defaults to true. NOTE(review): presumably disables CRC-32
  // computation when false — confirm.
  explicit PyTorchStreamWriter(
      const std::string& archive_name,
      bool compute_crc32 = true);
  explicit PyTorchStreamWriter(
      const std::function<size_t(const void*, size_t)> writer_func,
      bool compute_crc32 = true);

  // NOTE(review): presumably raises the recorded format version (version_)
  // to at least `version` — confirm.
  void setMinVersion(const uint64_t version);

  // Add record `name` holding `size` bytes from `data`. Records are stored
  // uncompressed unless `compress` is true.
  void writeRecord(
      const std::string& name,
      const void* data,
      size_t size,
      bool compress = false);

  // Finish the archive. NOTE(review): presumably writes the trailing
  // metadata and sets finalized() — confirm.
  void writeEndOfFile();

  // Names of the records written so far.
  const std::unordered_set<std::string>& getAllWrittenRecords();

  bool finalized() const {
    return finalized_;
  }

  // Const so they can be queried through a const PyTorchStreamWriter&.
  const std::string& archiveName() const {
    return archive_name_;
  }
  const std::string& serializationId() const {
    return serialization_id_;
  }

  ~PyTorchStreamWriter();

 private:
  // Out-of-line helpers; see their definitions for details.
  void setup(const std::string& file_name);
  void valid(const char* what, const char* info = "");
  // NOTE(review): presumably writes serialization_id_ to the
  // kSerializationIdRecordName record — confirm.
  void writeSerializationId();

  // NOTE(review): presumably the byte offset of the next write.
  size_t current_pos_ = 0;
  std::unordered_set<std::string> files_written_;
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  // Scratch buffer; its type matches detail::getPadding's `padding_buf`.
  std::string padding_;
  std::ofstream file_stream_;
  std::function<size_t(const void*, size_t)> writer_func_;
  uint64_t combined_uncomp_crc32_ = 0;
  std::string serialization_id_;
  bool compute_crc32_;
  // This number will be updated when the model has operators
  // that have valid upgraders.
  uint64_t version_ = kMinProducedFileFormatVersion;
  bool finalized_ = false;
  bool err_seen_ = false;

  // Out-of-line write callback that needs access to this class's private
  // state.
  friend size_t ostream_write_func(
      void* pOpaque,
      uint64_t file_ofs,
      const void* pBuf,
      size_t n);
};
namespace detail {

// Byte boundary the writer aligns each record's data to. This alignment
// makes it possible to mmap the archive and get aligned pointers to tensor
// data.
constexpr uint64_t kFieldAlignment{64};

// Padding helper for the writer.
//
// It produces the data appended to a record's local-header extra field so
// that the record's data begins on a kFieldAlignment-byte boundary. The
// padding bytes are written into `padding_buf`.
//
// NOTE(review): judging by the signature:
//   - `cursor` is the current write position;
//   - `filename_size` and `size` are the record name length and payload size;
//   - the returned size_t is presumably the resulting padding/extra length.
// Confirm against the definition, which is not in this header.
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf);

} // namespace detail
} // namespace serialize
} // namespace caffe2
|