1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
|
#include "decode_jpeg_cuda.h"

#include <mutex>
#include <string>

#include <ATen/ATen.h>
#if NVJPEG_FOUND
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <nvjpeg.h>
#endif
namespace vision {
namespace image {
#if !NVJPEG_FOUND
// Fallback compiled when torchvision is built without nvJPEG: every call
// fails with an explanatory error, whatever the arguments are.
torch::Tensor decode_jpeg_cuda(
    const torch::Tensor& data,
    ImageReadMode mode,
    torch::Device device) {
  // The arguments are intentionally unused in this build configuration.
  (void)data;
  (void)mode;
  (void)device;
  TORCH_CHECK(
      false,
      "decode_jpeg_cuda: torchvision not compiled with nvJPEG support");
}
#else
namespace {
static nvjpegHandle_t nvjpeg_handle = nullptr;
}
torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device) {
C10_LOG_API_USAGE_ONCE(
"torchvision.csrc.io.image.cuda.decode_jpeg_cuda.decode_jpeg_cuda");
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
TORCH_CHECK(
!data.is_cuda(),
"The input tensor must be on CPU when decoding with nvjpeg")
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(device.is_cuda(), "Expected a cuda device")
int major_version;
int minor_version;
nvjpegStatus_t get_major_property_status =
nvjpegGetProperty(MAJOR_VERSION, &major_version);
nvjpegStatus_t get_minor_property_status =
nvjpegGetProperty(MINOR_VERSION, &minor_version);
TORCH_CHECK(
get_major_property_status == NVJPEG_STATUS_SUCCESS,
"nvjpegGetProperty failed: ",
get_major_property_status);
TORCH_CHECK(
get_minor_property_status == NVJPEG_STATUS_SUCCESS,
"nvjpegGetProperty failed: ",
get_minor_property_status);
if ((major_version < 11) || ((major_version == 11) && (minor_version < 6))) {
TORCH_WARN_ONCE(
"There is a memory leak issue in the nvjpeg library for CUDA versions < 11.6. "
"Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda').");
}
at::cuda::CUDAGuard device_guard(device);
// Create global nvJPEG handle
static std::once_flag nvjpeg_handle_creation_flag;
std::call_once(nvjpeg_handle_creation_flag, []() {
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);
if (create_status != NVJPEG_STATUS_SUCCESS) {
// Reset handle so that one can still call the function again in the
// same process if there was a failure
free(nvjpeg_handle);
nvjpeg_handle = nullptr;
}
TORCH_CHECK(
create_status == NVJPEG_STATUS_SUCCESS,
"nvjpegCreateSimple failed: ",
create_status);
}
});
// Create the jpeg state
nvjpegJpegState_t jpeg_state;
nvjpegStatus_t state_status =
nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state);
TORCH_CHECK(
state_status == NVJPEG_STATUS_SUCCESS,
"nvjpegJpegStateCreate failed: ",
state_status);
auto datap = data.data_ptr<uint8_t>();
// Get the image information
int num_channels;
nvjpegChromaSubsampling_t subsampling;
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];
nvjpegStatus_t info_status = nvjpegGetImageInfo(
nvjpeg_handle,
datap,
data.numel(),
&num_channels,
&subsampling,
widths,
heights);
if (info_status != NVJPEG_STATUS_SUCCESS) {
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status);
}
if (subsampling == NVJPEG_CSS_UNKNOWN) {
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling");
}
int width = widths[0];
int height = heights[0];
nvjpegOutputFormat_t ouput_format;
int num_channels_output;
switch (mode) {
case IMAGE_READ_MODE_UNCHANGED:
num_channels_output = num_channels;
// For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will
// not properly decode RGB images (it's fine for grayscale), so we set
// output_format manually here
if (num_channels == 1) {
ouput_format = NVJPEG_OUTPUT_Y;
} else if (num_channels == 3) {
ouput_format = NVJPEG_OUTPUT_RGB;
} else {
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(
false,
"When mode is UNCHANGED, only 1 or 3 input channels are allowed.");
}
break;
case IMAGE_READ_MODE_GRAY:
ouput_format = NVJPEG_OUTPUT_Y;
num_channels_output = 1;
break;
case IMAGE_READ_MODE_RGB:
ouput_format = NVJPEG_OUTPUT_RGB;
num_channels_output = 3;
break;
default:
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(
false, "The provided mode is not supported for JPEG decoding on GPU");
}
auto out_tensor = torch::empty(
{int64_t(num_channels_output), int64_t(height), int64_t(width)},
torch::dtype(torch::kU8).device(device));
// nvjpegImage_t is a struct with
// - an array of pointers to each channel
// - the pitch for each channel
// which must be filled in manually
nvjpegImage_t out_image;
for (int c = 0; c < num_channels_output; c++) {
out_image.channel[c] = out_tensor[c].data_ptr<uint8_t>();
out_image.pitch[c] = width;
}
for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) {
out_image.channel[c] = nullptr;
out_image.pitch[c] = 0;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index());
nvjpegStatus_t decode_status = nvjpegDecode(
nvjpeg_handle,
jpeg_state,
datap,
data.numel(),
ouput_format,
&out_image,
stream);
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(
decode_status == NVJPEG_STATUS_SUCCESS,
"nvjpegDecode failed: ",
decode_status);
return out_tensor;
}
#endif // NVJPEG_FOUND
} // namespace image
} // namespace vision
|