1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
|
#ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
#define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
#include <set>
#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include <c10/util/irange.h>
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(MergeIdLists);
namespace caffe2 {
/**
 * Merges several id lists, sample by sample, into one deduplicated id list.
 *
 * Inputs are consumed as (LENGTHS, VALUES) pairs: Input(2k) is a 1-D int32
 * tensor of per-sample lengths and Input(2k + 1) is the 1-D tensor of ids
 * those lengths segment. Every LENGTHS tensor must have the same number of
 * elements (the batch size).
 *
 * Output(0) is an int32 tensor shaped like the first LENGTHS input and holds
 * the number of distinct ids per sample. Output(1) is a 1-D tensor of the
 * distinct ids of every sample, concatenated in sample order; within a sample
 * the ids come out in ascending order because they are collected in a
 * std::set.
 */
template <class Context>
class MergeIdListsOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MergeIdListsOp);

  /**
   * Performs the merge for id type T.
   *
   * @tparam T element type of the VALUES tensors and of Output(1).
   * @return true on success; validation failures are raised through the
   *         CAFFE_ENFORCE_* macros.
   */
  template <typename T>
  bool DoRunWithType() {
    const auto& first_lengths = Input(0);
    CAFFE_ENFORCE_EQ(first_lengths.dim(), 1, "LENGTHS should be 1-D");
    const auto batch_size = first_lengths.numel();

    auto* out_lengths = Output(0, first_lengths.sizes(), at::dtype<int32_t>());
    auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();

    const int num_lists = InputSize() / 2;

    /**
     * Validate every (LENGTHS, VALUES) pair and add up the VALUES sizes.
     * That sum is an upper bound on the merged output size; it is kept in
     * 64 bits so that large inputs cannot overflow the counter.
     */
    int64_t total_values = 0;
    for (const auto list : c10::irange(num_lists)) {
      const auto& lengths = Input(2 * list);
      CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTHS should be 1-D");
      CAFFE_ENFORCE_EQ(lengths.numel(), batch_size, "LENGTHS should be equal");
      const auto& values = Input(2 * list + 1);
      CAFFE_ENFORCE_EQ(values.dim(), 1, "VALUES should be 1-D");
      total_values += values.numel();
    }

    // Allocate for the worst case (no duplicates); shrunk to fit at the end.
    auto* out_values = Output(1, {total_values}, at::dtype<T>());
    T* out_values_data = out_values->template mutable_data<T>();

    int64_t pos = 0;
    // TODO(badri): Use unordered_set if performance is an issue
    // (note: that would drop the ascending order within each sample).
    std::set<T> deduped;
    // offsets[k] is the read cursor into the VALUES tensor of pair k.
    std::vector<int64_t> offsets(num_lists, 0);

    for (const auto sample : c10::irange(batch_size)) {
      for (const auto list : c10::irange(num_lists)) {
        const auto& lengths = Input(2 * list);
        const auto* lengths_data = lengths.template data<int32_t>();
        const auto& values = Input(2 * list + 1);
        const T* values_data = values.template data<T>();

        const int64_t length = lengths_data[sample];
        // Reject segments that would read outside the VALUES buffer.
        CAFFE_ENFORCE_GE(length, 0, "LENGTHS should be non-negative");
        const int64_t segment_end = offsets[list] + length;
        CAFFE_ENFORCE_LE(
            segment_end,
            values.numel(),
            "Sum of LENGTHS exceeds the size of VALUES");

        deduped.insert(
            values_data + offsets[list], values_data + segment_end);
        offsets[list] = segment_end;
      }

      for (const auto& val : deduped) {
        out_values_data[pos++] = val;
      }
      out_lengths_data[sample] = static_cast<int32_t>(deduped.size());
      deduped.clear();
    }

    // Trim the worst-case allocation down to the ids actually written.
    out_values->Resize(pos);
    return true;
  }

  /**
   * Entry point: checks that the inputs form at least one complete
   * (LENGTHS, VALUES) pair, then dispatches on the element type of the first
   * VALUES tensor (int32_t or int64_t).
   */
  bool RunOnDevice() override {
    const int num_inputs = InputSize();
    // Defensive: Input(1) below and the pairwise loops in DoRunWithType
    // assume a non-empty, even-sized input list.
    CAFFE_ENFORCE_GE(
        num_inputs, 2, "Expected at least one (LENGTHS, VALUES) input pair");
    CAFFE_ENFORCE_EQ(
        num_inputs % 2, 0, "Inputs should be (LENGTHS, VALUES) pairs");
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(1));
  }
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
|