File: trt_utils.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (55 lines) | stat: -rw-r--r-- 1,267 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#pragma once

#include <iostream>
#include <memory>
#include <string>

#include <NvInfer.h>

#include "caffe2/core/logging.h"

namespace caffe2 {
namespace tensorrt {

// Forwards TensorRT diagnostics (TensorRT was formerly called "GIE") to the
// caffe2 logging macros.
//
// A message is emitted only if its severity is at least as severe as the
// configured verbosity. Severity values are compared with `<=`, i.e. lower
// enumerators are treated as more severe (kINTERNAL_ERROR is the most severe
// level handled here).
class TrtLogger : public nvinfer1::ILogger {
  using nvinfer1::ILogger::Severity;

 public:
  // @param verbosity  least severe level that is still logged; defaults to
  //                   kWARNING, i.e. warnings and errors.
  TrtLogger(Severity verbosity = Severity::kWARNING) : _verbosity(verbosity) {}

  // Receives one diagnostic message from TensorRT.
  //
  // `noexcept` is compatible with TensorRT releases whose ILogger::log is not
  // noexcept (an overrider may be stricter than its base) and is required by
  // releases that declare ILogger::log noexcept.
  void log(Severity severity, const char* msg) noexcept override {
    if (severity > _verbosity) {
      return;  // Less severe than the requested verbosity: drop.
    }
    if (severity == Severity::kINTERNAL_ERROR || severity == Severity::kERROR) {
      LOG(ERROR) << msg;
    } else if (severity == Severity::kWARNING) {
      LOG(WARNING) << msg;
    } else {
      // kINFO and any less severe level (e.g. kVERBOSE, on TensorRT versions
      // that define it) that passed the verbosity filter above. Logging them
      // at INFO ensures that asking for a more verbose level actually yields
      // output instead of the message being discarded.
      LOG(INFO) << msg;
    }
  }

 private:
  // Least severe level that log() will emit.
  Severity _verbosity;
};

// Deleter that releases TensorRT objects by calling their destroy() method
// (never `delete`). Null pointers are ignored.
struct TrtDeleter {
  template <typename T>
  void operator()(T* obj) const {
    if (obj) {
      obj->destroy();
    }
  }
};

// Takes ownership of a freshly created TensorRT object and returns it in a
// shared_ptr that releases it through TrtDeleter.
//
// A null `obj` is treated as a creation failure and raises via CAFFE_ENFORCE.
// If allocating the shared_ptr control block throws, std::shared_ptr invokes
// the deleter on `obj`, so the object is not leaked.
template <typename T>
inline std::shared_ptr<T> TrtObject(T* obj) {
  CAFFE_ENFORCE(obj, "Failed to create TensorRt object");
  return std::shared_ptr<T>(obj, TrtDeleter());
}

// Returns a TensorRT CUDA engine built from `onnx_model_str`.
//
// Only the declaration lives in this header. The parameter descriptions below
// are inferred from their names; NOTE(review): confirm against the definition.
//   onnx_model_str      presumably a serialized ONNX model.
//   logger              logger presumably handed to TensorRT; not owned by the
//                       callee (raw pointer), so it should outlive the call.
//   max_batch_size      presumably the builder's maximum batch size.
//   max_workspace_size  presumably the builder's workspace limit (bytes?).
//   debug_builder       presumably toggles builder debug behavior.
std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine(
    const std::string& onnx_model_str,
    TrtLogger* logger,
    size_t max_batch_size,
    size_t max_workspace_size,
    bool debug_builder);

}  // namespace tensorrt
}  // namespace caffe2