1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
|
#pragma once
#include <array>
#include <cerrno>
#include <cstddef>
#include <cstring>
#include <fstream>
#include <istream>
#include <memory>
#include <c10/core/CPUAllocator.h>
#include <c10/core/impl/alloc_cpu.h>
#include <caffe2/serialize/read_adapter_interface.h>
#if defined(HAVE_MMAP)
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif
/**
* @file
*
* Helpers for identifying file formats when reading serialized data.
*
* Note that these functions are declared inline because they will typically
* only be called from one or two locations per binary.
*/
namespace torch {
namespace jit {
/**
* The format of a file or data stream.
*/
enum class FileFormat {
UnknownFileFormat = 0,
FlatbufferFileFormat,
ZipFileFormat,
};
/// The size of the buffer to pass to #getFileFormat(), in bytes.
constexpr size_t kFileFormatHeaderSize = 8;
constexpr size_t kMaxAlignment = 16;
/**
* Returns the likely file format based on the magic header bytes in @p header,
* which should contain the first bytes of a file or data stream.
*/
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline FileFormat getFileFormat(const char* data) {
// The size of magic strings to look for in the buffer.
static constexpr size_t kMagicSize = 4;
// Bytes 4..7 of a Flatbuffer-encoded file produced by
// `flatbuffer_serializer.h`. (The first four bytes contain an offset to the
// actual Flatbuffer data.)
static constexpr std::array<char, kMagicSize> kFlatbufferMagicString = {
'P', 'T', 'M', 'F'};
static constexpr size_t kFlatbufferMagicOffset = 4;
// The first four bytes of a ZIP file.
static constexpr std::array<char, kMagicSize> kZipMagicString = {
'P', 'K', '\x03', '\x04'};
// Note that we check for Flatbuffer magic first. Since the first four bytes
// of flatbuffer data contain an offset to the root struct, it's theoretically
// possible to construct a file whose offset looks like the ZIP magic. On the
// other hand, bytes 4-7 of ZIP files are constrained to a small set of values
// that do not typically cross into the printable ASCII range, so a ZIP file
// should never have a header that looks like a Flatbuffer file.
if (std::memcmp(
data + kFlatbufferMagicOffset,
kFlatbufferMagicString.data(),
kMagicSize) == 0) {
// Magic header for a binary file containing a Flatbuffer-serialized mobile
// Module.
return FileFormat::FlatbufferFileFormat;
} else if (std::memcmp(data, kZipMagicString.data(), kMagicSize) == 0) {
// Magic header for a zip file, which we use to store pickled sub-files.
return FileFormat::ZipFileFormat;
}
return FileFormat::UnknownFileFormat;
}
/**
* Returns the likely file format based on the magic header bytes of @p data.
* If the stream position changes while inspecting the data, this function will
* restore the stream position to its original offset before returning.
*/
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline FileFormat getFileFormat(std::istream& data) {
FileFormat format = FileFormat::UnknownFileFormat;
std::streampos orig_pos = data.tellg();
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<char, kFileFormatHeaderSize> header;
data.read(header.data(), header.size());
if (data.good()) {
format = getFileFormat(header.data());
}
data.seekg(orig_pos, data.beg);
return format;
}
/**
* Returns the likely file format based on the magic header bytes of the file
* named @p filename.
*/
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline FileFormat getFileFormat(const std::string& filename) {
std::ifstream data(filename, std::ifstream::binary);
return getFileFormat(data);
}
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static void file_not_found_error() {
std::stringstream message;
message << "Error while opening file: ";
if (errno == ENOENT) {
message << "no such file or directory" << std::endl;
} else {
message << "error no is: " << errno << std::endl;
}
TORCH_CHECK(false, message.str());
}
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline std::tuple<std::shared_ptr<char>, size_t> get_file_content(
const char* filename) {
#if defined(HAVE_MMAP)
int fd = open(filename, O_RDONLY);
if (fd < 0) {
// failed to open file, chances are it's no such file or directory.
file_not_found_error();
}
struct stat statbuf {};
fstat(fd, &statbuf);
size_t size = statbuf.st_size;
void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
close(fd);
auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
#else
FILE* f = fopen(filename, "rb");
if (f == nullptr) {
file_not_found_error();
}
fseek(f, 0, SEEK_END);
size_t size = ftell(f);
fseek(f, 0, SEEK_SET);
// make sure buffer size is multiple of alignment
size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
fread(data.get(), size, 1, f);
fclose(f);
#endif
return std::make_tuple(data, size);
}
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
std::istream& in) {
// get size of the stream and reset to orig
std::streampos orig_pos = in.tellg();
in.seekg(orig_pos, std::ios::end);
const long size = in.tellg();
in.seekg(orig_pos, in.beg);
// read stream
// NOLINT make sure buffer size is multiple of alignment
size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
in.read(data.get(), size);
// reset stream to original position
in.seekg(orig_pos, in.beg);
return std::make_tuple(data, size);
}
// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
static inline std::tuple<std::shared_ptr<char>, size_t> get_rai_content(
caffe2::serialize::ReadAdapterInterface* rai) {
size_t buffer_size = (rai->size() / kMaxAlignment + 1) * kMaxAlignment;
std::shared_ptr<char> data(
static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
rai->read(
0, data.get(), rai->size(), "Loading ReadAdapterInterface to bytes");
return std::make_tuple(data, buffer_size);
}
} // namespace jit
} // namespace torch
|