File: trt_utils.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (55 lines) | stat: -rw-r--r-- 1,267 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#pragma once

#include <cstddef>
#include <iostream>
#include <memory>
#include <string>

#include <NvInfer.h>

#include "caffe2/core/logging.h"

namespace caffe2 { namespace tensorrt {

// Logger that forwards TensorRT (formerly GIE) diagnostic messages to caffe2 logging.
class TrtLogger : public nvinfer1::ILogger {
  using nvinfer1::ILogger::Severity;

 public:
  // `verbosity` is the least severe level that is still forwarded. TensorRT
  // orders Severity from most severe (kINTERNAL_ERROR) to least severe, so a
  // message is kept when `severity <= verbosity`; the default (kWARNING)
  // keeps internal errors, errors and warnings.
  TrtLogger(Severity verbosity = Severity::kWARNING) : _verbosity(verbosity) {}

  // Forwards `msg` to the caffe2 LOG macro matching `severity`, or drops it
  // when it is less severe than the configured verbosity.
  //
  // Declared noexcept because TensorRT 8+ declares ILogger::log noexcept and an
  // override may not have a looser exception specification; the stricter
  // specifier is still a valid override of older, non-noexcept signatures.
  void log(Severity severity, const char* msg) noexcept override {
    if (severity > _verbosity) {
      return;
    }
    switch (severity) {
      case Severity::kINTERNAL_ERROR:
      case Severity::kERROR:
        LOG(ERROR) << msg;
        break;
      case Severity::kWARNING:
        LOG(WARNING) << msg;
        break;
      default:
        // kINFO, plus kVERBOSE on TensorRT versions that define it. Routing
        // everything else to INFO means a kVERBOSE threshold actually produces
        // output instead of silently discarding those messages.
        LOG(INFO) << msg;
        break;
    }
  }

 private:
  // Severity threshold fixed at construction (see the constructor comment).
  Severity _verbosity;
};

// Deleter for TensorRT interface objects, intended for smart pointers (e.g.
// the shared_ptr returned by TrtObject). Null pointers are ignored.
//
// Types exposing a destroy() member are released through it, exactly as
// before. Types without destroy() fall back to plain `delete`. TensorRT 8
// deprecated destroy() in favour of delete, and TensorRT 10 removed it.
struct TrtDeleter {
  template <typename T>
  void operator()(T* obj) const {
    if (obj) {
      release(obj, 0);
    }
  }

 private:
  // Viable only when obj->destroy() is well-formed. The `int` tag makes this
  // an exact match for the literal 0, so it wins over the `long` fallback
  // whenever destroy() exists.
  template <typename T>
  static auto release(T* obj, int)
      -> decltype(static_cast<void>(obj->destroy())) {
    obj->destroy();
  }

  // Fallback for types that no longer provide destroy().
  template <typename T>
  static void release(T* obj, long) {
    delete obj;
  }
};

template <typename T>
inline std::shared_ptr<T> TrtObject(T* obj) {
  CAFFE_ENFORCE(obj, "Failed to create TensorRt object");
  return std::shared_ptr<T>(obj, TrtDeleter());
}

std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine(
    const std::string& onnx_model_str,
    TrtLogger* logger,
    size_t max_batch_size,
    size_t max_workspace_size,
    bool debug_builder);
}
}