File: common.cpp

package info (click to toggle)
vart 2.5-5
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 4,404 kB
  • sloc: cpp: 30,188; python: 7,493; sh: 969; makefile: 37; ansic: 36
file content (91 lines) | stat: -rwxr-xr-x 3,620 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91

/*
 * Copyright 2019 Xilinx Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common.h"

#include <cassert>
#include <numeric>
int getTensorShape(vart::Runner* runner, GraphInfo* shapes, int cntin,
                   int cntout) {
  auto outputTensors = runner->get_output_tensors();
  auto inputTensors = runner->get_input_tensors();
  if (shapes->output_mapping.empty()) {
    shapes->output_mapping.resize((unsigned)cntout);
    std::iota(shapes->output_mapping.begin(), shapes->output_mapping.end(), 0);
  }
  for (int i = 0; i < cntin; i++) {
    auto dim_num = inputTensors[i]->get_shape().size();
    if (dim_num == 4) {
      shapes->inTensorList[i].channel = inputTensors[i]->get_shape().at(3);
      shapes->inTensorList[i].width = inputTensors[i]->get_shape().at(2);
      shapes->inTensorList[i].height = inputTensors[i]->get_shape().at(1);
      shapes->inTensorList[i].size =
          inputTensors[i]->get_element_num() / inputTensors[0]->get_shape().at(0);
    } else if (dim_num == 2) {
      shapes->inTensorList[i].channel = inputTensors[i]->get_shape().at(1);
      shapes->inTensorList[i].width = 1;
      shapes->inTensorList[i].height = 1;
      shapes->inTensorList[i].size =
          inputTensors[i]->get_element_num() / inputTensors[0]->get_shape().at(0);
    }
  }
  for (int i = 0; i < cntout; i++) {
    auto dim_num = outputTensors[shapes->output_mapping[i]]->get_shape().size();
    if (dim_num == 4) {
      shapes->outTensorList[i].channel =
          outputTensors[shapes->output_mapping[i]]->get_shape().at(3);
      shapes->outTensorList[i].width =
          outputTensors[shapes->output_mapping[i]]->get_shape().at(2);
      shapes->outTensorList[i].height =
          outputTensors[shapes->output_mapping[i]]->get_shape().at(1);
      shapes->outTensorList[i].size =
          outputTensors[shapes->output_mapping[i]]->get_element_num() /
          outputTensors[shapes->output_mapping[0]]->get_shape().at(0);
    } else if (dim_num == 2) {
      shapes->outTensorList[i].channel =
          outputTensors[shapes->output_mapping[i]]->get_shape().at(1);
      shapes->outTensorList[i].width = 1;
      shapes->outTensorList[i].height = 1;
      shapes->outTensorList[i].size =
          outputTensors[shapes->output_mapping[i]]->get_element_num() /
          outputTensors[shapes->output_mapping[0]]->get_shape().at(0);
    }
  }
  return 0;
}

// Return the index of the first tensor in `tensors` whose name contains
// `name` as a substring, or -1 if none does.
//
// Matching is by substring, not equality, so `name` may be a fragment of the
// full XIR tensor name. The first match in `tensors` order wins.
// A miss trips the assert in debug builds. With NDEBUG defined the assert is
// compiled out and -1 is returned, so callers must check for it.
// The vector is taken by const reference to avoid copying it on every call.
static int find_tensor(const std::vector<const xir::Tensor*>& tensors,
                       const std::string& name) {
  int ret = -1;
  for (auto i = 0u; i < tensors.size(); ++i) {
    if (tensors[i]->get_name().find(name) != std::string::npos) {
      ret = static_cast<int>(i);
      break;
    }
  }
  assert(ret != -1);
  return ret;
}
// Fill `shapes` like the count-based overload, but choose the outputs by name.
//
// For each entry of output_names (in order), the first runner output whose
// tensor name contains that string is located via find_tensor. The resulting
// indices replace shapes->output_mapping. Previously they were appended, so a
// mapping that was already populated would have been read ahead of the new
// entries.
//
// Returns the count-based overload's status (0) on success. Returns -1,
// leaving `shapes` unmodified, if a name matches no output. That path is only
// reachable when NDEBUG disables find_tensor's assert; previously -1 was
// stored and later used as an index.
int getTensorShape(vart::Runner* runner, GraphInfo* shapes, int cntin,
                   std::vector<std::string> output_names) {
  // Query the runner once instead of once per name.
  const auto outputTensors = runner->get_output_tensors();
  decltype(shapes->output_mapping) mapping;
  mapping.reserve(output_names.size());
  for (const auto& name : output_names) {
    const int idx = find_tensor(outputTensors, name);
    if (idx < 0) {
      return -1;
    }
    mapping.push_back(idx);
  }
  shapes->output_mapping.swap(mapping);
  return getTensorShape(runner, shapes, cntin,
                        static_cast<int>(output_names.size()));
}