#ifndef CAFFE2_FB_OPERATORS_CC_BMM_BG_H_
#define CAFFE2_FB_OPERATORS_CC_BMM_BG_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

using T = float;
using TInd = int;
using Engine = DefaultEngine;
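
// Fused concat + batch matmul + batch gather:
// Input(0) holds gather indices; Inputs 1..N are feature tensors of identical
// shape. For each batch entry, the N feature rows are stacked into an
// N x embed_size matrix, which is multiplied by its own transpose to produce
// the N x N matrix of pairwise dot products; the entries selected by the
// indices are then copied into the output.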
template <class Context>
class ConcatBatchMatMulBatchGatherOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  ConcatBatchMatMulBatchGatherOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override;

 protected:
  int axis_ = 1;
  int add_axis_ = 1;
  bool trans_a_ = false;
  bool trans_b_ = true;
  bool broadcast_ = false;
};
template <class Context>
bool ConcatBatchMatMulBatchGatherOp<Context>::RunOnDevice() {
  auto& indices = Input(0);
  auto& input_zero = Input(1);
  int adj_size = input_zero.dim() + 1;
  int canonical_axis = 1;
  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
  for (int i = 2; i < InputSize(); ++i) {
    CAFFE_ENFORCE(
        Input(i).dtype() == input_zero.dtype(),
        "All inputs must have the same type, expected: ",
        input_zero.dtype().name(),
        " but got: ",
        Input(i).dtype().name(),
        " for input: ",
        i);
  }
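
  // Collapse the dimensions before the concat axis into a batch size and the
  // dimensions at/after it into a per-example embedding size.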
  int before = 1, after = 1;
  for (int i = 0; i < input_zero.dim(); ++i) {
    int dim = input_zero.dim32(i);
    if (i < canonical_axis) {
      before *= dim;
    } else { // i > canonical_axis || i == canonical_axis && add_axis_
      after *= dim;
    }
    // check the input dims are compatible.
    for (int j = 2; j < InputSize(); ++j) {
      int dim_j = Input(j).dim32(i);
      CAFFE_ENFORCE(
          dim == dim_j,
          "Expect dimension = ",
          dim,
          " got ",
          dim_j,
          " at axis = ",
          i,
          " for input: ",
          j,
          ". The input tensors can only have different dimensions "
          "when arg 'add_axis' = 0 and along the axis = ",
          canonical_axis,
          " <",
          input_zero.sizes(),
          "> vs <",
          Input(j).sizes(),
          ">.");
    }
  }
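
  // ndata feature inputs of embed_size each yield one ndata x ndata product
  // matrix per batch entry; `indices` selects which entries of that matrix
  // are written to the output.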
  auto ndata = InputSize() - 1;
  auto batch_size = before;
  auto embed_size = after;
  auto gather_size = indices.sizes()[0];

  vector<int64_t> output_dims;
  output_dims.push_back(batch_size);
  output_dims.insert(
      output_dims.begin() + 1, indices.sizes().begin(), indices.sizes().end());
  auto* output = Output(0, output_dims, at::dtype<T>());
  // std::stringstream ss;
  // ss << "[";
  // for (int i = 0; i < output_dims.size(); i++) ss << output_dims[i];
  // ss << "]";
  // LOG(INFO) << "output size: " << ss.str();
  auto* output_data = output->template mutable_data<T>();
  auto* indices_data = indices.template data<TInd>();
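
  // Each OpenMP thread allocates its own scratch buffers; the batch entries
  // are divided across threads by the parallel for below.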
#pragma omp parallel
  {
    std::vector<T> scratch_input(ndata * embed_size);
    std::vector<T> scratch_output(ndata * ndata);
#pragma omp for
    for (int b = 0; b < batch_size; ++b) {
      // concat input to scratch
      for (int i = 1; i < InputSize(); ++i) {
        auto* input_data = Input(i).template data<T>();
        memcpy(
            &scratch_input[(i - 1) * embed_size],
            input_data + b * embed_size,
            embed_size * Input(i).itemsize());
      }
      // scratch_output = scratch_input * scratch_input^T: the ndata x ndata
      // matrix of pairwise dot products between the concatenated rows.
      math::Gemm<T, Context, Engine>(
          CblasNoTrans,
          CblasTrans,
          ndata,
          ndata,
          embed_size,
          1,
          &scratch_input[0],
          &scratch_input[0],
          0,
          &scratch_output[0],
          &context_);
      // do gather
      int64_t output_offset = b * gather_size;
      for (int i = 0; i < gather_size; i++) {
        output_data[output_offset + i] = scratch_output[indices_data[i]];
      }
    }
  }
  return true;
}
} // namespace caffe2
#endif // CAFFE2_FB_OPERATORS_CC_BMM_BG_H_