1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
|
#include <torch/csrc/jit/mobile/file_format.h>
#include <gtest/gtest.h>
#include <sstream>
// Unit tests for getFileFormat() from torch/csrc/jit/mobile/file_format.h.
// The tests live in torch::jit so getFileFormat and FileFormat can be named
// without qualification.
namespace torch {
namespace jit {

TEST(FileFormatTest, IdentifiesFlatbufferStream) {
  // Create data whose initial bytes look like a Flatbuffer stream.
  std::stringstream data;
  data << "abcd" // First four bytes don't matter.
       << "PTMF" // Magic string.
       << "efgh"; // Trailing bytes don't matter.

  // The data should be identified as Flatbuffer.
  EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat);
}

TEST(FileFormatTest, IdentifiesZipStream) {
  // Create data whose initial bytes look like a ZIP stream.
  std::stringstream data;
  data << "PK\x03\x04" // Magic string.
       << "abcd" // Trailing bytes don't matter.
       << "efgh";

  // The data should be identified as ZIP.
  EXPECT_EQ(getFileFormat(data), FileFormat::ZipFileFormat);
}

TEST(FileFormatTest, FlatbufferTakesPrecedence) {
  // Since the Flatbuffer and ZIP magic bytes are at different offsets,
  // the same data could be identified as both. Demonstrate that Flatbuffer
  // takes precedence. (See details in file_format.h)
  std::stringstream data;
  data << "PK\x03\x04" // ZIP magic string.
       << "PTMF" // Flatbuffer magic string.
       << "abcd"; // Trailing bytes don't matter.

  // The data should be identified as Flatbuffer.
  EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat);
}

TEST(FileFormatTest, HandlesUnknownStream) {
  // Create data that doesn't look like any known format.
  std::stringstream data;
  data << "abcd"
       << "efgh"
       << "ijkl";

  // The data should be classified as unknown.
  EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat);
}

TEST(FileFormatTest, ShortStreamIsUnknown) {
  // Create data with fewer than kFileFormatHeaderSize (8) bytes.
  std::stringstream data;
  data << "ABCD";

  // The data should be classified as unknown.
  EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat);
}

TEST(FileFormatTest, EmptyStreamIsUnknown) {
  // Create an empty stream.
  std::stringstream data;

  // The data should be classified as unknown.
  EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat);
}

TEST(FileFormatTest, BadStreamIsUnknown) {
  // Create a stream with valid Flatbuffer data.
  std::stringstream data;
  data << "abcd"
       << "PTMF" // Flatbuffer magic string.
       << "efgh";

  // Demonstrate that the data would normally be identified as Flatbuffer.
  EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat);

  // Mark the stream as bad, and demonstrate that it is in an error state.
  data.setstate(std::stringstream::badbit);
  EXPECT_TRUE(data.bad());
  EXPECT_FALSE(data.good());

  // The same bytes should now be classified as unknown, because the stream
  // can no longer be read.
  EXPECT_EQ(getFileFormat(data), FileFormat::UnknownFileFormat);
}

TEST(FileFormatTest, StreamOffsetIsObservedAndRestored) {
  // Create data with a Flatbuffer header at a non-zero offset into the stream.
  std::stringstream data;
  // Add initial padding.
  data << "PADDING";
  // Hold the offset in the stream's own offset type. The seek and the final
  // position check then compare like with like, instead of mixing an
  // unsigned size_t with the signed std::streampos returned by tellg().
  const auto offset =
      static_cast<std::stringstream::off_type>(data.str().size());
  // Add a valid Flatbuffer header.
  data << "abcd"
       << "PTMF" // Flatbuffer magic string.
       << "efgh";

  // Seek just after the padding.
  data.seekg(offset, std::stringstream::beg);
  // Demonstrate that the stream points to the beginning of the Flatbuffer data,
  // not to the padding.
  EXPECT_EQ(data.peek(), 'a');

  // The data should be identified as Flatbuffer.
  EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat);

  // The stream position should be where it was before identification.
  EXPECT_EQ(static_cast<std::stringstream::off_type>(data.tellg()), offset);
}

TEST(FileFormatTest, HandlesMissingFile) {
  // A missing file should be classified as unknown.
  EXPECT_EQ(
      getFileFormat("NON_EXISTENT_FILE_4965c363-44a7-443c-983a-8895eead0277"),
      FileFormat::UnknownFileFormat);
}

} // namespace jit
} // namespace torch
|