File: test_frcnn_tracing.cpp

package info (click to toggle)
pytorch-vision 0.14.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 15,188 kB
  • sloc: python: 49,008; cpp: 10,019; sh: 610; java: 550; xml: 79; objc: 56; makefile: 32
file content (58 lines) | stat: -rw-r--r-- 1,590 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <iostream>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>
#include <torchvision/vision.h>
#include <torchvision/ops/nms.h>


int main() {
  torch::DeviceType device_type;
  device_type = torch::kCPU;

  torch::jit::script::Module module;
  try {
    std::cout << "Loading model\n";
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load("fasterrcnn_resnet50_fpn.pt");
    std::cout << "Model loaded\n";
  } catch (const torch::Error& e) {
    std::cout << "error loading the model\n";
    return -1;
  } catch (const std::exception& e) {
    std::cout << "Other error: " << e.what() << "\n";
    return -1;
  }

  // TorchScript models require a List[IValue] as input
  std::vector<torch::jit::IValue> inputs;

  // Faster RCNN accepts a List[Tensor] as main input
  std::vector<torch::Tensor> images;
  images.push_back(torch::rand({3, 256, 275}));
  images.push_back(torch::rand({3, 256, 275}));

  inputs.push_back(images);
  auto output = module.forward(inputs);

  std::cout << "ok\n";
  std::cout << "output" << output << "\n";

  if (torch::cuda::is_available()) {
    // Move traced model to GPU
    module.to(torch::kCUDA);

    // Add GPU inputs
    images.clear();
    inputs.clear();

    torch::TensorOptions options = torch::TensorOptions{torch::kCUDA};
    images.push_back(torch::rand({3, 256, 275}, options));
    images.push_back(torch::rand({3, 256, 275}, options));

    inputs.push_back(images);
    auto output = module.forward(inputs);

    std::cout << "ok\n";
    std::cout << "output" << output << "\n";
  }
  return 0;
}