File: encode_jpeg.cpp

package info (click to toggle)
pytorch-vision 0.21.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 20,228 kB
  • sloc: python: 65,904; cpp: 11,406; ansic: 2,459; java: 550; sh: 265; xml: 79; objc: 56; makefile: 33
file content (113 lines) | stat: -rw-r--r-- 3,256 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include "encode_jpeg.h"

#include "common_jpeg.h"

namespace vision {
namespace image {

#if !JPEG_FOUND

// Stand-in used when JPEG_FOUND is false: there is no encoder to call, so
// every invocation fails through TORCH_CHECK with a fixed message. Both
// arguments are ignored.
torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
  static constexpr const char* kMissingLibjpegMessage =
      "encode_jpeg: torchvision not compiled with libjpeg support";
  TORCH_CHECK(false, kMissingLibjpegMessage);
}

#else
// Integer type of the out_size argument that jpeg_mem_dest() expects.
// libjpeg releases newer than 9b declare it as size_t; 9b and older (and any
// build where JPEG_LIB_VERSION_MAJOR is not defined at all) use unsigned long.
#if defined(JPEG_LIB_VERSION_MAJOR) && \
    (JPEG_LIB_VERSION_MAJOR > 9 ||     \
     (JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR > 2))
using JpegSizeType = size_t;
#else
using JpegSizeType = unsigned long;
#endif

using namespace detail;

// Encodes an image tensor into an in-memory JPEG byte stream with libjpeg.
//
// Parameters:
//   data    - 3-dimensional uint8 CPU tensor indexed as
//             (channels, height, width); channels must be 1 (encoded as
//             JCS_GRAYSCALE) or 3 (encoded as JCS_RGB).
//   quality - quality setting handed to jpeg_set_quality().
//
// Returns:
//   A 1-dimensional uint8 tensor wrapping the buffer produced by libjpeg. The
//   tensor takes ownership of that buffer and releases it with free().
//
// Errors:
//   Fails through TORCH_CHECK when the input does not satisfy the constraints
//   above, or when libjpeg reports an error (the message recorded in
//   jerr.jpegLastErrorMsg is forwarded).
torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg");

  // All validation and tensor preparation happens BEFORE setjmp() is armed.
  // A longjmp() back to the setjmp() point must not skip over the
  // construction of an object with a non-trivial destructor (here: the
  // contiguous `input` tensor); doing so is undefined behaviour and would
  // leak the tensor.

  // Check that the input tensor is on CPU
  TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");

  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");

  // Check that the input tensor is 3-dimensional
  TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");

  // Validate the channel count on the full-width value, and before paying for
  // the permute/contiguous copy below.
  const int64_t channels = data.size(0);
  TORCH_CHECK(
      channels == 1 || channels == 3,
      "The number of channels should be 1 or 3, got: ",
      channels);

  // Tensor sizes are int64_t while the values handed to libjpeg are narrower.
  // Reject sizes that would not survive the narrowing instead of silently
  // encoding an image with wrapped-around dimensions.
  const int64_t height64 = data.size(1);
  const int64_t width64 = data.size(2);
  const int height = static_cast<int>(height64);
  const int width = static_cast<int>(width64);
  TORCH_CHECK(
      height == height64 && width == width64,
      "Image dimensions are too large to be encoded as JPEG, got: ",
      height64,
      "x",
      width64);

  // Interleave the channels (CHW -> HWC) so each row can be passed to
  // jpeg_write_scanlines() as one contiguous scanline.
  auto input = data.permute({1, 2, 0}).contiguous();

  // Number of bytes per scanline in `input`.
  const int64_t stride = width64 * channels;

  // Define compression structures and error handling
  struct jpeg_compress_struct cinfo{};
  struct torch_jpeg_error_mgr jerr{};

  // Define buffer to write JPEG information to and its size
  JpegSizeType jpegSize = 0;
  uint8_t* jpegBuf = nullptr;

  cinfo.err = jpeg_std_error(&jerr.pub);
  jerr.pub.error_exit = torch_jpeg_error_exit;

  /* Establish the setjmp return context for torch_jpeg_error_exit to use. */
  if (setjmp(jerr.setjmp_buffer)) {
    /* If we get here, the JPEG code has signaled an error.
     * We need to clean up the JPEG object and the buffer.
     */
    jpeg_destroy_compress(&cinfo);
    if (jpegBuf != nullptr) {
      free(jpegBuf);
    }

    TORCH_CHECK(false, static_cast<const char*>(jerr.jpegLastErrorMsg));
  }

  // Initialize JPEG structure
  jpeg_create_compress(&cinfo);

  // Set output image information
  cinfo.image_width = width;
  cinfo.image_height = height;
  cinfo.input_components = static_cast<int>(channels);
  cinfo.in_color_space = channels == 1 ? JCS_GRAYSCALE : JCS_RGB;

  jpeg_set_defaults(&cinfo);
  // jpeg_set_quality() takes the quality as int; make the narrowing explicit.
  jpeg_set_quality(&cinfo, static_cast<int>(quality), TRUE);

  // Save JPEG output to a buffer
  jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize);

  // Start JPEG compression
  jpeg_start_compress(&cinfo, TRUE);

  auto ptr = input.data_ptr<uint8_t>();

  // Encode JPEG file, one scanline per call
  while (cinfo.next_scanline < cinfo.image_height) {
    jpeg_write_scanlines(&cinfo, &ptr, 1);
    ptr += stride;
  }

  jpeg_finish_compress(&cinfo);
  jpeg_destroy_compress(&cinfo);

  // Hand the buffer over to the returned tensor, which frees it with
  // ::free. The size is converted with an explicit 64-bit cast: `long` is
  // only 32 bits wide on some platforms.
  torch::TensorOptions options = torch::TensorOptions{torch::kU8};
  auto out_tensor = torch::from_blob(
      jpegBuf, {static_cast<int64_t>(jpegSize)}, ::free, options);
  jpegBuf = nullptr;
  return out_tensor;
}
#endif

} // namespace image
} // namespace vision