#pragma once
#include <ctc.h>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "caffe2/core/common_cudnn.h"
#define CTC_CHECK(condition)          \
  do {                                \
    ctcStatus_t status = condition;   \
    CAFFE_ENFORCE_EQ(                 \
        status,                       \
        CTC_STATUS_SUCCESS,           \
        " Error at: ",                \
        __FILE__,                     \
        ":",                          \
        __LINE__,                     \
        ": ",                         \
        ::ctcGetStatusString(status)); \
  } while (0)

namespace caffe2 {

namespace detail {
template <typename Context>
ctcComputeInfo workspaceInfo(const Context& context);
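
// A per-context specialization of workspaceInfo is expected to be provided
// elsewhere (typically in the companion .cc/.cu file). A minimal CPU-side
// sketch, assuming the ctcComputeInfo struct from the warp-ctc release this
// header targets exposes a `loc` field and a `num_threads` field:
//
//   template <>
//   ctcComputeInfo workspaceInfo<CPUContext>(const CPUContext& /*context*/) {
//     ctcComputeInfo info;
//     info.loc = CTC_CPU;    // run warp-ctc on the host
//     info.num_threads = 1;  // number of CPU worker threads
//     return info;
//   }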
} // namespace detail

template <typename T, typename Context>
class CTCOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  CTCOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        is_test_(
            OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
    // Test mode emits [costs, workspace]; training mode additionally emits
    // gradients as the first output.
    CAFFE_ENFORCE(
        (is_test_ && OutputSize() == 2) || (!is_test_ && OutputSize() == 3));
  }

  bool RunOnDevice() override {
    // Inputs: activations laid out as [maxTimeSteps, minibatchSize,
    // alphabetSize], plus flat labels and per-item label lengths on CPU.
    const auto& inputs = Input(INPUTS);
    const auto maxTimeSteps = inputs.size(0);
    const auto minibatchSize = inputs.size(1);
    const auto alphabetSize = inputs.size(2);
    const auto& labels = OperatorBase::template Input<Tensor>(LABELS, CPU);
    const auto& labelLengths =
        OperatorBase::template Input<Tensor>(LABEL_LENGTHS, CPU);

    const int* inputLengthsData = nullptr;
    if (InputSize() == 4) {
      const auto& inputLengths =
          OperatorBase::template Input<Tensor>(INPUT_LENGTHS, CPU);
      inputLengthsData = inputLengths.template data<int>();
    } else {
      // Input lengths not passed in. Default to max timesteps for each item
      // in the minibatch. Use assign() so stale values from a previous run
      // with a different maxTimeSteps are overwritten.
      default_input_lengths_.assign(minibatchSize, maxTimeSteps);
      inputLengthsData = default_input_lengths_.data();
    }

    // outputs
    Tensor* gradients = nullptr;
    TensorCPU* costs;
    Tensor* workspace;
    if (!is_test_) {
      // [grads, costs, workspace] to maintain backward compatibility
      gradients = Output(0);
      gradients->ResizeLike(inputs);
      costs = OperatorBase::template Output<Tensor>(1, CPU);
      costs->ResizeLike(labelLengths);
      workspace = Output(2);
    } else {
      // [costs, workspace]
      costs = OperatorBase::template Output<Tensor>(0, CPU);
      costs->ResizeLike(labelLengths);
      workspace = Output(1);
    }

    // Ask warp-ctc how much scratch space it needs, then allocate it as a
    // byte tensor.
    size_t workspaceSizeBytes;
    CTC_CHECK(get_workspace_size(
        labelLengths.template data<int>(),
        inputLengthsData,
        alphabetSize,
        minibatchSize,
        detail::workspaceInfo(context_),
        &workspaceSizeBytes));
    workspace->Resize(workspaceSizeBytes);
    auto* workspaceData = workspace->template mutable_data<uint8_t>();

    if (is_test_ && labels.size(0) == 0) {
      // compute_ctc_loss doesn't handle empty labels well
      T* costsData = costs->template mutable_data<T>();
      for (int i = 0; i < costs->numel(); ++i) {
        costsData[i] = 0;
      }
      return true;
    }

    CTC_CHECK(compute_ctc_loss(
        inputs.template data<T>(),
        gradients ? gradients->template mutable_data<T>() : nullptr,
        labels.template data<int>(),
        labelLengths.template data<int>(),
        inputLengthsData,
        alphabetSize,
        minibatchSize,
        costs->template mutable_data<T>(),
        workspaceData,
        detail::workspaceInfo(context_)));
    return true;
  }

 private:
  bool is_test_;
  std::vector<int> default_input_lengths_;

  INPUT_TAGS(INPUTS, LABELS, LABEL_LENGTHS, INPUT_LENGTHS);
};

} // namespace caffe2

#undef CTC_CHECK
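
// A sketch of how this operator might be registered and exposed in a
// companion .cc file. REGISTER_CPU_OPERATOR and OPERATOR_SCHEMA are the
// standard Caffe2 registration macros; the exact input/output limits below
// are an assumption derived from INPUT_TAGS and the OutputSize() check above:
//
//   REGISTER_CPU_OPERATOR(CTC, CTCOp<float, CPUContext>);
//   OPERATOR_SCHEMA(CTC).NumInputs(3, 4).NumOutputs(2, 3);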