#ifndef CAFFE2_FB_OPERATORS_CC_BMM_BG_H_
#define CAFFE2_FB_OPERATORS_CC_BMM_BG_H_

#include <cstring>
#include <vector>

#include "c10/util/irange.h"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

using T = float;
using TInd = int;
using Engine = DefaultEngine;

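// Fuses three steps into one operator: (1) for each batch item, concatenate
// the K = InputSize() - 1 embedding inputs (each [batch_size, embed_size])
// into a K x embed_size matrix, (2) multiply that matrix by its own
// transpose to form a K x K similarity matrix, and (3) gather entries of
// that matrix by the (flattened) indices given in Input(0).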
template <class Context>
class ConcatBatchMatMulBatchGatherOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  ConcatBatchMatMulBatchGatherOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}
  bool RunOnDevice() override;

 protected:
  int axis_ = 1;
  int add_axis_ = 1;
  bool trans_a_ = false;
  bool trans_b_ = true;
  bool broadcast_ = false;
};

template <class Context>
bool ConcatBatchMatMulBatchGatherOp<Context>::RunOnDevice() {
  auto& indices = Input(0);
  auto& input_zero = Input(1);
  // adj_size allows the concat axis to equal dim() when an axis is added.
  int adj_size = input_zero.dim() + 1;
  int canonical_axis = 1;
  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
  for (const auto i : c10::irange(2, InputSize())) {
    CAFFE_ENFORCE(
        Input(i).dtype() == input_zero.dtype(),
        "All inputs must have the same type, expected: ",
        input_zero.dtype().name(),
        " but got: ",
        Input(i).dtype().name(),
        " for input: ",
        i);
  }
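  // Collapse the input shape around the concat axis: dims before the axis
  // form the batch ("before"), dims from the axis onward form the embedding
  // ("after"). With axis = 1 and 2-D inputs this is [batch_size, embed_size].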
  int before = 1, after = 1;
  for (const auto i : c10::irange(input_zero.dim())) {
    int dim = input_zero.dim32(i);
    if (i < canonical_axis) {
      before *= dim;
    } else { // i > canonical_axis || i == canonical_axis && add_axis_
      after *= dim;
    }
    // Check that the input dims are compatible.
    for (const auto j : c10::irange(2, InputSize())) {
      int dim_j = Input(j).dim32(i);
      CAFFE_ENFORCE(
          dim == dim_j,
          "Expect dimension = ",
          dim,
          " got ",
          dim_j,
          " at axis = ",
          i,
          " for input: ",
          j,
          ". The input tensors can only have different dimensions "
          "when arg 'add_axis' = 0 and along the axis = ",
          canonical_axis,
          " <",
          input_zero.sizes(),
          "> vs <",
          Input(j).sizes(),
          ">.");
    }
  }
  auto ndata = InputSize() - 1;
  auto batch_size = before;
  auto embed_size = after;
  // Gather over every element of `indices`, so use numel() rather than only
  // the first dimension; this keeps the per-batch write stride consistent
  // with the output shape built below (batch_size x indices dims).
  auto gather_size = indices.numel();

  vector<int64_t> output_dims;
  output_dims.push_back(batch_size);
  output_dims.insert(
      output_dims.begin() + 1, indices.sizes().begin(), indices.sizes().end());
  auto* output = Output(0, output_dims, at::dtype<T>());
  auto* output_data = output->template mutable_data<T>();
  auto* indices_data = indices.template data<TInd>();
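
  // Each OpenMP thread gets its own scratch buffers: scratch_input holds the
  // per-batch concatenated embeddings (ndata x embed_size) and scratch_output
  // the resulting ndata x ndata similarity matrix.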
#pragma omp parallel
  {
    std::vector<T> scratch_input(ndata * embed_size);
    std::vector<T> scratch_output(ndata * ndata);

#pragma omp for
    for (int b = 0; b < batch_size; ++b) {
      // Concatenate the b-th row of every input into the scratch buffer.
      for (const auto i : c10::irange(1, InputSize())) {
        auto* input_data = Input(i).template data<T>();
        memcpy(
            &scratch_input[(i - 1) * embed_size],
            input_data + b * embed_size,
            embed_size * Input(i).itemsize());
      }
      // scratch_output = scratch_input * scratch_input^T, the ndata x ndata
      // similarity matrix for this batch item (alpha = 1, beta = 0).
      math::Gemm<T, Context, Engine>(
          CblasNoTrans,
          CblasTrans,
          ndata,
          ndata,
          embed_size,
          1.0f,
          &scratch_input[0],
          &scratch_input[0],
          0.0f,
          &scratch_output[0],
          &context_);
      // Gather the requested entries of the similarity matrix into the output.
      int64_t output_offset = b * gather_size;
      for (const auto i : c10::irange(gather_size)) {
        output_data[output_offset + i] = scratch_output[indices_data[i]];
      }
    }
  }
  return true;
}
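
// Usage sketch (illustrative; the registered operator name and schema live
// in the corresponding .cc file): given K inputs X1..XK of shape (B, D) and
// a 1-D INDICES tensor selecting entries of the flattened K x K similarity
// matrix,
//   ConcatBatchMatMulBatchGather(["INDICES", "X1", ..., "XK"], ["OUT"])
// produces OUT of shape (B, len(INDICES)).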
} // namespace caffe2
#endif // CAFFE2_FB_OPERATORS_CC_BMM_BG_H_