File: elementwise_ops_utils.cc

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (124 lines) | stat: -rw-r--r-- 3,388 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include "caffe2/operators/elementwise_ops_utils.h"

namespace caffe2 {
namespace elementwise_ops_utils {

// Computes the (pre, n, post) factorization used by axis-based legacy
// broadcasting, where B is aligned against A starting at dimension `axis`.
//
// Leading and trailing size-1 dimensions of B are skipped and do not need to
// match A; B's remaining "core" dimensions must equal the A dimensions they
// line up with.
//
// Args:
//   A: the larger operand; must have at least as many dimensions as B.
//   B: the operand being broadcast against A.
//   axis: index of the A dimension aligned with B's first dimension, or -1
//     to right-align B with A (i.e. axis = A.dim() - B.dim()).
// Returns:
//   {pre, n, post}, where `pre` is the product of A's dimensions before B's
//   core, `n` is the product of B's core dimensions (1 if B has no core, e.g.
//   all of B's dimensions are 1), and `post` is the product of A's dimensions
//   after B's core.
// Fails a CAFFE_ENFORCE check if B has more dimensions than A, if `axis` is
// outside [0, A.dim() - B.dim()], or if a core dimension of B differs from
// the aligned dimension of A.
std::tuple<size_t, size_t, size_t>
ComputeLegacyBroadcastSizes(const Tensor& A, const Tensor& B, int axis) {
  CAFFE_ENFORCE_GE(
      A.dim(),
      B.dim(),
      "If you are doing broadcasting, input1 should have "
      "a smaller or equal number of dimensions.");
  if (axis == -1) {
    axis = A.dim() - B.dim();
  }
  // The trailing space in the first literal is required: adjacent string
  // literals are concatenated verbatim.
  CAFFE_ENFORCE(
      axis >= 0 && axis <= A.dim() - B.dim(),
      "Broadcast axis should be in the range of "
      "[0, A.ndim() - B.ndim()], but axis = ",
      axis);

  // Skip B's leading size-1 dimensions; they are folded into `pre`.
  int b_dim_start = 0;
  while (b_dim_start < B.dim() && B.size(b_dim_start) == 1) {
    ++b_dim_start;
  }
  // Skip B's trailing size-1 dimensions; they are folded into `post`.
  int b_dim_end = B.dim() - 1;
  while (b_dim_end >= b_dim_start && B.size(b_dim_end) == 1) {
    --b_dim_end;
  }
  size_t pre = 1, n = 1, post = 1;
  for (int i = 0; i < axis + b_dim_start; ++i) {
    pre *= A.size(i);
  }
  for (int i = b_dim_start; i <= b_dim_end; ++i) {
    CAFFE_ENFORCE_EQ(
        A.size(i + axis), B.size(i), "Broadcast dimension mismatch.");
    n *= B.size(i);
  }
  for (int i = axis + b_dim_end + 1; i < A.dim(); ++i) {
    post *= A.size(i);
  }
  return std::make_tuple(pre, n, post);
}

// Returns the output shape produced by broadcasting A_dims against B_dims.
//
// The two shapes are right-aligned. Where both inputs have a dimension, the
// pair must be equal or contain a 1 (enforced); a pair containing 0 yields 0,
// otherwise the larger extent is used. Dimensions that exist only in the
// longer shape are copied through unchanged.
std::vector<int> ComputeBinaryBroadcastForwardDims(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims) {
  const int a_rank = static_cast<int>(A_dims.size());
  const int b_rank = static_cast<int>(B_dims.size());
  const int out_rank = std::max(a_rank, b_rank);
  // Shifts that right-align each input inside the output shape.
  const int a_shift = out_rank - a_rank;
  const int b_shift = out_rank - b_rank;

  std::vector<int> C_dims(out_rank);
  // Walk from the trailing dimension toward the leading one.
  for (int pos = out_rank - 1; pos >= 0; --pos) {
    const int a_pos = pos - a_shift;
    const int b_pos = pos - b_shift;
    if (a_pos < 0) {
      C_dims[pos] = B_dims[b_pos];
    } else if (b_pos < 0) {
      C_dims[pos] = A_dims[a_pos];
    } else {
      const int A_dim = A_dims[a_pos];
      const int B_dim = B_dims[b_pos];
      CAFFE_ENFORCE(A_dim == B_dim || A_dim == 1 || B_dim == 1);
      C_dims[pos] = (A_dim == 0 || B_dim == 0) ? 0 : std::max(A_dim, B_dim);
    }
  }
  return C_dims;
}

// For a right-aligned broadcast of A_dims against B_dims, collects into
// *A_axes / *B_axes the output axes along which A / B respectively is
// broadcast. These are the leading axes the shorter input lacks, plus the
// aligned axes where that input has extent 1 and the other input's extent
// differs. Both outputs are cleared first and end up in ascending order.
// Each aligned pair must be equal or contain a 1 (enforced).
void ComputeBinaryBroadcastBackwardAxes(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims,
    std::vector<int>* A_axes,
    std::vector<int>* B_axes) {
  A_axes->clear();
  B_axes->clear();
  const int a_rank = static_cast<int>(A_dims.size());
  const int b_rank = static_cast<int>(B_dims.size());
  const int ndim = std::max(a_rank, b_rank);

  // Visit output axes from last to first; entries are appended in descending
  // order and flipped at the end.
  for (int k = ndim - 1; k >= 0; --k) {
    const int i = k - (ndim - a_rank);
    const int j = k - (ndim - b_rank);
    if (i < 0) {
      // A has no dimension at this output axis.
      A_axes->push_back(k);
      continue;
    }
    if (j < 0) {
      // B has no dimension at this output axis.
      B_axes->push_back(k);
      continue;
    }
    CAFFE_ENFORCE(A_dims[i] == B_dims[j] || A_dims[i] == 1 || B_dims[j] == 1);
    if (A_dims[i] == B_dims[j]) {
      continue;
    }
    if (A_dims[i] == 1) {
      A_axes->push_back(k);
    }
    if (B_dims[j] == 1) {
      B_axes->push_back(k);
    }
  }
  std::reverse(A_axes->begin(), A_axes->end());
  std::reverse(B_axes->begin(), B_axes->end());
}

// Left-pads A_dims and B_dims with 1s so both reach the rank of the longer
// shape, writing the results to *A_back_dims and *B_back_dims (any previous
// contents are replaced).
void ComputeBinaryBroadcastBackwardDims(
    const std::vector<int>& A_dims,
    const std::vector<int>& B_dims,
    std::vector<int>* A_back_dims,
    std::vector<int>* B_back_dims) {
  const auto rank = std::max(A_dims.size(), B_dims.size());
  A_back_dims->assign(rank, 1);
  B_back_dims->assign(rank, 1);
  // Right-align each input inside its output; the leading slots keep the 1s.
  std::copy_backward(A_dims.cbegin(), A_dims.cend(), A_back_dims->end());
  std::copy_backward(B_dims.cbegin(), B_dims.cend(), B_back_dims->end());
}

} // namespace elementwise_ops_utils
} // namespace caffe2