1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
|
#ifndef CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
#define CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
// Reuse helper logic from GatherOp, since BatchGather is the same as Gather with axis=1.
#include "caffe2/operators/gather_op.h"
namespace caffe2 {
// BatchGatherOp: gathers slices of DATA along axis 1 using INDICES (int32 or
// int64). All of the work is delegated to the shared GatherOp helper, because
// BatchGather is Gather with the axis fixed to 1.
template <class Context>
class BatchGatherOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  INPUT_TAGS(DATA, INDICES);

  // The optional "match_outer" argument (default false) is passed unchanged
  // to gather_helper::gather_impl.
  template <class... Args>
  explicit BatchGatherOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "match_outer", match_outer_, false) {}

  // Dispatches on the element type of INDICES (fetched as a CPU tensor).
  bool RunOnDevice() override {
    using SupportedIndexTypes = TensorTypes<int32_t, int64_t>;
    const auto& index_tensor = this->template Input<Tensor>(INDICES, CPU);
    return DispatchHelper<SupportedIndexTypes>::call(this, index_tensor);
  }

  template <typename TInd>
  bool DoRunWithType() {
    // BatchGather is a special case of Gather with the axis pinned to 1.
    constexpr int kBatchGatherAxis = 1;
    return gather_helper::gather_impl<TInd, Context>(
        this, DATA, INDICES, 0, kBatchGatherAxis, false, match_outer_);
  }

 protected:
  bool match_outer_;
};
// Gradient of BatchGather. It is also used as the gradient of Gather when an
// explicit "axis" argument is supplied (the default axis is 1).
//
// Inputs:  DATA    - only its shape and element type are used.
//          INDICES - int32 or int64 (fetched as a CPU tensor for dispatch).
//          GRAD    - gradient of the gathered output.
// Output:  a tensor shaped like DATA. It is zero everywhere except at the
//          gathered positions, where the matching GRAD slices are summed
//          (repeated indices accumulate).
// Args:    "axis" (int, default 1, may be negative to count from the back),
//          "match_outer" (bool, default false): when true, INDICES shares the
//          leading `axis` dims with DATA and each outer batch uses its own
//          slice of INDICES.
template <class Context>
class BatchGatherGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Constructor to receive axis in case it was passed for GatherOp gradient,
  // use default of 1 for batch gather otherwise.
  template <class... Args>
  explicit BatchGatherGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, 1),
        OP_SINGLE_ARG(bool, "match_outer", match_outer_, false) {}

  virtual ~BatchGatherGradientOp() noexcept {}

  // First dispatch: on the element type of INDICES.
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, this->template Input<Tensor>(INDICES, CPU));
  }

  // Second dispatch: on the element type of DATA. Only float is supported;
  // any other type ends up in DoRunWithOtherType2.
  template <typename TInd>
  bool DoRunWithType() {
    return DispatchHelper<
        TensorTypes2<float, GenericTensorImplementation>,
        TInd>::call(this, Input(DATA));
  }

  // Scatter-adds GRAD into a zero-initialized output shaped like DATA.
  template <typename TInd, typename TData>
  bool DoRunWithType2() {
    auto& data = Input(DATA);
    auto& indices = Input(INDICES);
    auto& grad = Input(GRAD);

    // ONNX allows a negative axis to index from the back; the accepted range
    // is [-r, r-1], where r is the rank of DATA.
    int axis = axis_;
    bool match_outer = match_outer_;
    if (axis < 0) {
      axis = data.dim() + axis;
    }

    CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D");
    // Reject axes that are still out of range after normalization. Otherwise
    // the shape helpers and data.size(axis) below would index past the rank.
    if (axis < 0 || axis >= data.dim()) {
      CAFFE_THROW(
          "axis ", axis_, " is out of range for DATA of rank ", data.dim());
    }

    // Outer dimensions of input data and gradient should be the same
    // because they are preserved for gathers with axis > 0.
    for (const auto acheck : c10::irange(axis)) {
      CAFFE_ENFORCE_EQ(
          data.size(acheck),
          grad.size(acheck),
          "batch gather outer dimensions should match");
    }

    auto* output = Output(0, data.sizes(), at::dtype<TData>());
    TData* out_data = output->template mutable_data<TData>();
    if (data.numel() <= 0) {
      return true;
    }
    // Positions never selected by INDICES receive a zero gradient.
    memset(out_data, 0, output->nbytes());

    const TData* grad_data = grad.template data<TData>();
    const TInd* idxs = indices.template data<TInd>();

    auto outer_dims_product = data.size_to_dim(axis);
    auto batch_size = data.size_from_dim(axis);
    auto block_size = data.size_from_dim(axis + 1);
    auto N = indices.numel();
    auto idx_inner_dims_product = indices.size_from_dim(axis);
    if (match_outer) {
      CAFFE_ENFORCE_GE(axis, 1, "Axis should be at least 1");
      for (const auto i : c10::irange(axis)) {
        CAFFE_ENFORCE_EQ(
            data.size(i),
            indices.size(i),
            "INDICES must have the same outer dims as DATA (before dim AXIS)");
      }
      // Each outer batch consumes only its own slice of INDICES.
      N = idx_inner_dims_product;
    }

    auto gathered_grad_batch_size = N * block_size;
    // GRAD must hold exactly one gathered slice per (outer batch, index)
    // pair. A smaller GRAD would otherwise be read past its end below.
    CAFFE_ENFORCE_EQ(
        grad.numel(),
        outer_dims_product * gathered_grad_batch_size,
        "GRAD size does not match the gathered output shape");

    // Check indexing bounds for every index the loop below can read. With
    // match_outer each outer batch reads a different slice of INDICES, so the
    // whole array (not only its first N entries) has to be validated.
    auto src_indexing_axis_dim = data.size(axis);
    gather_helper::check_indexarray_range<TInd>(
        idxs, indices.numel(), src_indexing_axis_dim, false);

    for (const auto batch : c10::irange(outer_dims_product)) {
      auto grad_batch_base = grad_data + batch * gathered_grad_batch_size;
      auto out_batch_base = out_data + batch * batch_size;
      // Without match_outer every batch reuses the same N indices.
      const auto idx_batch_offset =
          match_outer ? batch * idx_inner_dims_product : 0;
      for (const auto i : c10::irange(N)) {
        auto idx = idxs[idx_batch_offset + i];
        // Wrap negative indices to count from the end of the axis.
        if (idx < 0) {
          idx = idx + src_indexing_axis_dim;
        }
        if (block_size == 1) {
          out_batch_base[idx] += grad_batch_base[i];
        } else {
          // Accumulate (not overwrite) so that repeated indices sum up.
          math::Add(
              block_size,
              out_batch_base + idx * block_size,
              grad_batch_base + i * block_size,
              out_batch_base + idx * block_size,
              &context_);
        }
      }
    }
    return true;
  }

  // Fallback for DATA element types outside the DispatchHelper list.
  template <typename TInd>
  bool DoRunWithOtherType2() {
    CAFFE_THROW(
        "BatchGatherGradient is not implemented on tensor of type ",
        Input(DATA).meta().name(),
        ". Consider adding it as a type in the DispatchHelper list or "
        "implementing a generic version (which won't work for "
        "duplicated indices though)");
  }

  INPUT_TAGS(DATA, INDICES, GRAD);

 protected:
  int axis_;
  bool match_outer_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
|