File: trt_utils.cc

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (62 lines) | stat: -rw-r--r-- 2,187 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include "caffe2/contrib/tensorrt/trt_utils.h"

#include <NvOnnxParser.h>

namespace caffe2 {
namespace tensorrt {

/// @brief Builds a TensorRT CUDA engine from a serialized ONNX model.
///
/// @param onnx_model_str     Serialized ONNX model; handed to the ONNX parser
///                           as (data(), size()).
/// @param logger             Logger shared by the TensorRT builder and the
///                           ONNX parser. It is dereferenced unconditionally,
///                           so it must not be null.
/// @param max_batch_size     Forwarded to IBuilder::setMaxBatchSize.
/// @param max_workspace_size Forwarded to setMaxWorkspaceSize, on the builder
///                           config (TensorRT >= 6 branch) or on the builder
///                           itself (legacy branch).
/// @param debug_builder      TensorRT >= 6 branch: sets BuilderFlag::kDEBUG on
///                           the config. Legacy branch: passed to
///                           IBuilder::setDebugSync.
/// @return The engine produced by the builder, wrapped by TrtObject.
/// @throws Via CAFFE_THROW when the ONNX parser rejects the model. The message
///         carries file, line, function, code and description of the LAST
///         error the parser recorded, or "Unknown Error" when none was recorded.
///
/// NOTE(review): whether a failed engine build surfaces as a null pointer or
/// as an exception depends on TrtObject (declared in trt_utils.h) — confirm.
/// NOTE(review): the TensorRT >= 6 branch is keyed on TENSORRT_VERSION_MAJOR,
/// which this file does not define. Confirm the build system defines it;
/// otherwise the legacy branch is always compiled.
std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine(
    const std::string& onnx_model_str,
    TrtLogger* logger,
    size_t max_batch_size,
    size_t max_workspace_size,
    bool debug_builder) {
  // Creation order (builder, [config], network, parser) is kept deliberately
  // so the wrappers are released in the same reverse order as before.
  auto builder = TrtObject(nvinfer1::createInferBuilder(*logger));

#if defined(TENSORRT_VERSION_MAJOR) && (TENSORRT_VERSION_MAJOR >= 6)
  auto config = TrtObject(builder->createBuilderConfig());
  // TensorRTOp doesn't support dynamic shapes yet
  const uint32_t explicit_batch_flag = 1U
      << static_cast<uint32_t>(
             nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  auto network = TrtObject(builder->createNetworkV2(explicit_batch_flag));
#else
  auto network = TrtObject(builder->createNetwork());
#endif

  auto parser = TrtObject(nvonnxparser::createParser(*network, *logger));
  const auto parse_ok =
      parser->parse(onnx_model_str.data(), onnx_model_str.size());
  if (!parse_ok) {
    const auto error_count = parser->getNbErrors();
    if (error_count <= 0) {
      // The parser failed without recording any diagnostic.
      CAFFE_THROW("TensorRTTransformer Unknown Error");
    }
    // Only the most recently recorded parser error is surfaced.
    const auto* last_error = parser->getError(error_count - 1);
    CAFFE_THROW(
        "TensorRTTransformer ERROR: ",
        last_error->file(),
        ":",
        last_error->line(),
        " In function ",
        last_error->func(),
        ":\n",
        "[",
        static_cast<int>(last_error->code()),
        "] ",
        last_error->desc());
  }

  // NOTE(review): for explicit-batch networks TensorRT documents maxBatchSize
  // as having no effect — verify for the targeted TensorRT version.
  builder->setMaxBatchSize(max_batch_size);

#if defined(TENSORRT_VERSION_MAJOR) && (TENSORRT_VERSION_MAJOR >= 6)
  config->setMaxWorkspaceSize(max_workspace_size);
  if (debug_builder) {
    config->setFlag(nvinfer1::BuilderFlag::kDEBUG);
  }
  config->setDefaultDeviceType(nvinfer1::DeviceType::kGPU);
  return TrtObject(builder->buildEngineWithConfig(*network, *config));
#else
  builder->setMaxWorkspaceSize(max_workspace_size);
  builder->setDebugSync(debug_builder);
  return TrtObject(builder->buildCudaEngine(*network));
#endif
}

} // namespace tensorrt
} // namespace caffe2