File: encode_png.cpp

package info (click to toggle)
pytorch-vision 0.14.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 15,188 kB
  • sloc: python: 49,008; cpp: 10,019; sh: 610; java: 550; xml: 79; objc: 56; makefile: 32
file content (180 lines) | stat: -rw-r--r-- 4,844 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#include "encode_jpeg.h"

#include "common_png.h"

namespace vision {
namespace image {

#if !PNG_FOUND

// Fallback compiled when PNG_FOUND is false, i.e. torchvision was built
// without libpng. It never encodes anything: every call raises an error.
torch::Tensor encode_png(
    const torch::Tensor& data,
    int64_t compression_level) {
  (void)data;
  (void)compression_level;
  constexpr bool png_support_compiled = false;
  TORCH_CHECK(
      png_support_compiled,
      "encode_png: torchvision not compiled with libpng support");
}

#else

namespace {

// Growable heap buffer that accumulates the encoded PNG byte stream.
// `buffer` is allocated with malloc/realloc and must be released with free().
struct torch_mem_encode {
  char* buffer;
  size_t size;
};

// State shared between encode_png and the libpng error callback.
struct torch_png_error_mgr {
  const char* pngLastErrorMsg; /* error messages */
  // Owned copy of the last libpng error message; after an error,
  // pngLastErrorMsg points here. libpng may pass a message that lives in a
  // temporary buffer of one of its own frames (png_chunk_error formats into a
  // local array), which does not survive the longjmp, so the text is copied
  // before jumping.
  char pngLastErrorBuf[256];
  jmp_buf setjmp_buffer; /* for return to caller */
};

using torch_png_error_mgr_ptr = torch_png_error_mgr*;

// libpng error callback (registered through png_create_write_struct).
// Records a copy of `error_msg` and longjmps to the active setjmp point; it
// never returns to libpng, which libpng requires of error handlers.
void torch_png_error(png_structp png_ptr, png_const_charp error_msg) {
  /* The error pointer registered with libpng is a torch_png_error_mgr. */
  auto error_ptr =
      static_cast<torch_png_error_mgr_ptr>(png_get_error_ptr(png_ptr));

  // Copy (truncating if necessary) while the source is still valid.
  const char* src = error_msg != nullptr ? error_msg : "unknown libpng error";
  size_t n = 0;
  for (; n + 1 < sizeof(error_ptr->pngLastErrorBuf) && src[n] != '\0'; ++n) {
    error_ptr->pngLastErrorBuf[n] = src[n];
  }
  error_ptr->pngLastErrorBuf[n] = '\0';
  error_ptr->pngLastErrorMsg = error_ptr->pngLastErrorBuf;

  /* Return control to the setjmp point */
  longjmp(error_ptr->setjmp_buffer, 1);
}

// libpng write callback (registered through png_set_write_fn): appends
// `length` bytes from `data` to the torch_mem_encode set as the io pointer.
// On allocation failure it reports "Write Error" via png_error and leaves the
// already accumulated buffer in place, so the caller can still free it.
void torch_png_write_data(
    png_structp png_ptr,
    png_bytep data,
    png_size_t length) {
  auto* p = static_cast<torch_mem_encode*>(png_get_io_ptr(png_ptr));
  if (length == 0) {
    return; // nothing to append; avoids a realloc(..., 0) corner case
  }

  const size_t nsize = p->size + length;

  // realloc(NULL, n) behaves like malloc(n). Go through a temporary: storing
  // a NULL result directly into p->buffer would orphan the old allocation,
  // which the caller's error path could then never free.
  char* grown = static_cast<char*>(realloc(p->buffer, nsize));
  if (grown == nullptr) {
    png_error(png_ptr, "Write Error");
  }
  p->buffer = grown;

  /* copy new bytes to end of buffer */
  memcpy(p->buffer + p->size, data, length);
  p->size = nsize;
}

} // namespace

torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png");
  // Define compression structures and error handling
  png_structp png_write;
  png_infop info_ptr;
  struct torch_png_error_mgr err_ptr;

  // Define output buffer
  struct torch_mem_encode buf_info;
  buf_info.buffer = NULL;
  buf_info.size = 0;

  /* Establish the setjmp return context for my_error_exit to use. */
  if (setjmp(err_ptr.setjmp_buffer)) {
    /* If we get here, the PNG code has signaled an error.
     * We need to clean up the PNG object and the buffer.
     */
    if (info_ptr != NULL) {
      png_destroy_info_struct(png_write, &info_ptr);
    }

    if (png_write != NULL) {
      png_destroy_write_struct(&png_write, NULL);
    }

    if (buf_info.buffer != NULL) {
      free(buf_info.buffer);
    }

    TORCH_CHECK(false, err_ptr.pngLastErrorMsg);
  }

  // Check that the compression level is between 0 and 9
  TORCH_CHECK(
      compression_level >= 0 && compression_level <= 9,
      "Compression level should be between 0 and 9");

  // Check that the input tensor is on CPU
  TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");

  // Check that the input tensor dtype is uint8
  TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");

  // Check that the input tensor is 3-dimensional
  TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");

  // Get image info
  int channels = data.size(0);
  int height = data.size(1);
  int width = data.size(2);
  auto input = data.permute({1, 2, 0}).contiguous();

  TORCH_CHECK(
      channels == 1 || channels == 3,
      "The number of channels should be 1 or 3, got: ",
      channels);

  // Initialize PNG structures
  png_write = png_create_write_struct(
      PNG_LIBPNG_VER_STRING, &err_ptr, torch_png_error, NULL);

  info_ptr = png_create_info_struct(png_write);

  // Define custom buffer output
  png_set_write_fn(png_write, &buf_info, torch_png_write_data, NULL);

  // Set output image information
  auto color_type = channels == 1 ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB;
  png_set_IHDR(
      png_write,
      info_ptr,
      width,
      height,
      8,
      color_type,
      PNG_INTERLACE_NONE,
      PNG_COMPRESSION_TYPE_DEFAULT,
      PNG_FILTER_TYPE_DEFAULT);

  // Set image compression level
  png_set_compression_level(png_write, compression_level);

  // Write file header
  png_write_info(png_write, info_ptr);

  auto stride = width * channels;
  auto ptr = input.data_ptr<uint8_t>();

  // Encode PNG file
  for (int y = 0; y < height; ++y) {
    png_write_row(png_write, ptr);
    ptr += stride;
  }

  // Write EOF
  png_write_end(png_write, info_ptr);

  // Destroy structures
  png_destroy_write_struct(&png_write, &info_ptr);

  torch::TensorOptions options = torch::TensorOptions{torch::kU8};
  auto outTensor = torch::empty({(long)buf_info.size}, options);

  // Copy memory from png buffer, since torch cannot get ownership of it via
  // `from_blob`
  auto outPtr = outTensor.data_ptr<uint8_t>();
  std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel());
  free(buf_info.buffer);

  return outTensor;
}

#endif

} // namespace image
} // namespace vision