File: gpu_decoder.cpp

package info (click to toggle)
pytorch-vision 0.21.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 20,228 kB
  • sloc: python: 65,904; cpp: 11,406; ansic: 2,459; java: 550; sh: 265; xml: 79; objc: 56; makefile: 33
file content (65 lines) | stat: -rw-r--r-- 2,068 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include "gpu_decoder.h"
#include <c10/cuda/CUDAGuard.h>
#include <exception>

/* Open `src_file` with the demuxer, select the CUDA device `dev`, retain that
 * device's primary context and initialise the decoder with it.
 *
 * src_file: handed to the demuxer as a C string.
 * dev: device whose index is stored in `device` and whose primary context is
 *      retained into `ctx`.
 *
 * If initialising the decoder throws, the primary-context reference retained
 * here is released before the exception is rethrown. The destructor does not
 * run for a partially constructed object, so nothing else would release it.
 */
GPUDecoder::GPUDecoder(std::string src_file, torch::Device dev)
    : demuxer(src_file.c_str()) {
  at::cuda::CUDAGuard device_guard(dev);
  device = device_guard.current_device().index();
  check_for_cuda_errors(
      cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__);
  try {
    decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec()));
  } catch (...) {
    // Best-effort rollback: the original exception is the one worth
    // reporting, so the result of the release call is deliberately ignored.
    static_cast<void>(cuDevicePrimaryCtxRelease(device));
    throw;
  }
  initialised = true;
}

/* Release the decoder and, if construction completed, drop the
 * primary-context reference taken by the constructor.
 *
 * Destructors are implicitly noexcept, so an exception escaping from here
 * would terminate the process. Failures are therefore reported as warnings
 * instead of being propagated.
 * NOTE(review): this assumes check_for_cuda_errors (defined outside this
 * file) signals failure by throwing -- confirm against its definition.
 */
GPUDecoder::~GPUDecoder() {
  try {
    at::cuda::CUDAGuard device_guard(device);
    try {
      decoder.release();
    } catch (const std::exception& e) {
      // Keep going: the context reference below must still be released.
      TORCH_WARN("GPUDecoder: failed to release the decoder: ", e.what());
    }
    if (initialised) {
      check_for_cuda_errors(
          cuDevicePrimaryCtxRelease(device), __LINE__, __FILE__);
    }
  } catch (const std::exception& e) {
    TORCH_WARN(
        "GPUDecoder: error while releasing CUDA resources: ", e.what());
  } catch (...) {
    TORCH_WARN("GPUDecoder: unknown error while releasing CUDA resources");
  }
}

/* Demux and decode until a frame is available, then return it.
 *
 * Each iteration asks the demuxer for the next chunk of data, feeds it to the
 * decoder and fetches a frame from the decoder. The loop ends as soon as the
 * fetched tensor is non-empty, or when the demuxer reports zero bytes.
 *
 * Returns the last fetched tensor. It has numel() == 0 when the loop ended
 * because the demuxer produced no more bytes without a frame being fetched.
 */
torch::Tensor GPUDecoder::decode() {
  unsigned long videoBytes = 0;
  uint8_t* video = nullptr;
  at::cuda::CUDAGuard device_guard(device);
  torch::Tensor frame;
  do {
    demuxer.demux(&video, &videoBytes);
    decoder.decode(video, videoBytes);
    frame = decoder.fetch_frame();
  } while (frame.numel() == 0 && videoBytes > 0);
  return frame;
}

/* Move the demuxer to `timestamp`.
 *
 * When `keyframes_only` is true the demuxer is called with no seek flags;
 * otherwise AVSEEK_FLAG_ANY is passed.
 */
void GPUDecoder::seek(double timestamp, bool keyframes_only) {
  int seek_flags = 0;
  if (!keyframes_only) {
    seek_flags = AVSEEK_FLAG_ANY;
  }
  demuxer.seek(timestamp, seek_flags);
}

/* Report stream metadata as a nested dictionary.
 *
 * The result holds a single "video" entry whose value maps "duration" and
 * "fps" to the values returned by the demuxer.
 */
c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
    get_metadata() const {
  using StreamInfo = c10::Dict<std::string, double>;

  StreamInfo video_info;
  video_info.insert("duration", demuxer.get_duration());
  video_info.insert("fps", demuxer.get_fps());

  c10::Dict<std::string, StreamInfo> result;
  result.insert("video", video_info);
  return result;
}

/* Register GPUDecoder as a custom class named "GPUDecoder" in the
 * `torchvision` library.
 *
 * Exposed members:
 *   - a constructor taking (std::string, torch::Device)
 *   - "seek"         bound to GPUDecoder::seek
 *   - "get_metadata" bound to GPUDecoder::get_metadata
 *   - "next"         bound to GPUDecoder::decode (the exposed name differs
 *                    from the C++ method name)
 */
TORCH_LIBRARY(torchvision, m) {
  m.class_<GPUDecoder>("GPUDecoder")
      .def(torch::init<std::string, torch::Device>())
      .def("seek", &GPUDecoder::seek)
      .def("get_metadata", &GPUDecoder::get_metadata)
      .def("next", &GPUDecoder::decode);
}