1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
|
#include "encode_jpeg.h"
#include "common_jpeg.h"
namespace vision {
namespace image {
#if !JPEG_FOUND
// Fallback compiled when torchvision is built without libjpeg: there is no
// encoder available, so every call fails with an explanatory error.
torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
  TORCH_CHECK(
      false,
      "encode_jpeg: torchvision not compiled with libjpeg support");
}
#else
// jpeg_mem_dest() takes its out_size argument as `size_t*` starting with
// libjpeg 9c (major 9, minor > 2), but as `unsigned long*` in 9b and earlier
// and whenever JPEG_LIB_VERSION_MAJOR is not defined. JpegSizeType matches
// whichever the headers declare, so the size variable can be passed directly.
#if defined(JPEG_LIB_VERSION_MAJOR) &&             \
    (JPEG_LIB_VERSION_MAJOR > 9 ||                 \
     (JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR > 2))
using JpegSizeType = size_t;
#else
using JpegSizeType = unsigned long;
#endif
using namespace detail;

// Encodes an image tensor as JPEG using libjpeg's in-memory destination.
//
// Args:
//   data: uint8 CPU tensor of rank 3, indexed as (channels, height, width);
//         channels must be 1 (encoded as grayscale) or 3 (encoded as RGB).
//   quality: JPEG quality factor. Values outside [1, 100] are clamped into
//            that range, mirroring the clamping jpeg_set_quality applies.
//
// Returns a 1-D uint8 tensor holding the encoded bytes. The tensor takes
// ownership of the libjpeg-allocated buffer and releases it with ::free.
//
// Errors: input validation failures and any error reported by libjpeg are
// raised through TORCH_CHECK.
torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg");

  // Validate the input before any libjpeg state exists.
  TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
  TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
  TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");

  const int64_t channels = data.size(0);
  const int64_t height = data.size(1);
  const int64_t width = data.size(2);
  // Checked before the permute below so invalid input costs no copy.
  TORCH_CHECK(
      channels == 1 || channels == 3,
      "The number of channels should be 1 or 3, got: ",
      channels);

  // CHW -> HWC, so each image row is one contiguous interleaved scanline.
  // This MUST happen before setjmp: a longjmp back to the setjmp point must
  // not skip the destructor of a live non-trivial object (undefined behavior
  // in C++; in practice the tensor's storage would be leaked).
  auto input = data.permute({1, 2, 0}).contiguous();

  // Clamp before narrowing to int so out-of-range int64 values are not
  // silently truncated.
  const int jpegQuality =
      quality < 1 ? 1 : (quality > 100 ? 100 : static_cast<int>(quality));

  // Output buffer and its size, filled in by libjpeg's memory destination.
  JpegSizeType jpegSize = 0;
  uint8_t* jpegBuf = nullptr;

  // Compression state and error handling. From here until the libjpeg calls
  // finish, only trivially destructible locals may be introduced.
  struct jpeg_compress_struct cinfo{};
  struct torch_jpeg_error_mgr jerr{};
  cinfo.err = jpeg_std_error(&jerr.pub);
  jerr.pub.error_exit = torch_jpeg_error_exit;

  // torch_jpeg_error_exit returns control here when libjpeg signals an error
  // (presumably via longjmp on jerr.setjmp_buffer — defined in common_jpeg).
  if (setjmp(jerr.setjmp_buffer)) {
    // Release the libjpeg object and any partially written output buffer,
    // then surface the libjpeg message as a torch error.
    jpeg_destroy_compress(&cinfo);
    if (jpegBuf != nullptr) {
      free(jpegBuf);
    }
    TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg);
  }

  jpeg_create_compress(&cinfo);

  // Describe the source image.
  cinfo.image_width = static_cast<JDIMENSION>(width);
  cinfo.image_height = static_cast<JDIMENSION>(height);
  cinfo.input_components = static_cast<int>(channels);
  cinfo.in_color_space = channels == 1 ? JCS_GRAYSCALE : JCS_RGB;

  jpeg_set_defaults(&cinfo);
  jpeg_set_quality(&cinfo, jpegQuality, TRUE);

  // Write the compressed stream into a malloc'ed buffer owned by us.
  jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize);
  jpeg_start_compress(&cinfo, TRUE);

  // Feed one scanline (width * channels bytes) per call.
  const int64_t stride = width * channels;
  uint8_t* row = input.data_ptr<uint8_t>();
  while (cinfo.next_scanline < cinfo.image_height) {
    jpeg_write_scanlines(&cinfo, &row, 1);
    row += stride;
  }

  jpeg_finish_compress(&cinfo);
  jpeg_destroy_compress(&cinfo);

  // Hand the buffer to the tensor; int64_t (not long) avoids truncating the
  // size on platforms where long is 32 bits.
  torch::TensorOptions options = torch::TensorOptions{torch::kU8};
  auto out_tensor = torch::from_blob(
      jpegBuf, {static_cast<int64_t>(jpegSize)}, ::free, options);
  jpegBuf = nullptr;
  return out_tensor;
}
#endif
} // namespace image
} // namespace vision
|