1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
#include <caffe2/ideep/ideep_utils.h>
using namespace caffe2;
namespace {
// Element-wise momentum SGD step over N floats.
//
//   N         number of elements processed in each buffer
//   g         incoming gradient
//   m         incoming momentum
//   ng        receives the adjusted gradient
//   nm        receives the new momentum
//   lr        pointer to the learning rate; only lr[0] is read
//   momentum  momentum coefficient
//   nesterov  selects the Nesterov form of the update
//   param     when non-null, ng[i] is subtracted from param[i]
//
// Plain form:     nm = ng = lr * g + momentum * m
// Nesterov form:  nm = momentum * m + lr * g
//                 ng = (1 + momentum) * nm - momentum * m
//
// Every iteration reads g[i] and m[i] before it writes ng[i] and nm[i], so
// ng may be the same buffer as g, and nm the same buffer as m.
void momentum_sgd_update(
    const int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    const float momentum,
    const bool nesterov,
    float* param) {
  const float step_size = lr[0];
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
  for (int idx = 0; idx < N; ++idx) {
    // Capture the old momentum up front; nm[idx] may overwrite it below.
    const float prev_momentum = m[idx];
    if (nesterov) {
      const float next_momentum = momentum * prev_momentum + step_size * g[idx];
      nm[idx] = next_momentum;
      ng[idx] = (1 + momentum) * next_momentum - momentum * prev_momentum;
    } else {
      const float update = step_size * g[idx] + momentum * prev_momentum;
      nm[idx] = update;
      ng[idx] = update;
    }
    if (param != nullptr) {
      param[idx] -= ng[idx];
    }
  }
}
// IDEEP operator that turns a gradient and a momentum tensor into an adjusted
// gradient and an updated momentum tensor by calling momentum_sgd_update().
// No parameter tensor is involved: the kernel is invoked with a null param
// pointer.
//
// Inputs:  GRAD, MOMENTUM (IDEEP tensors), LR (CPU tensor with one element).
// Outputs: OUTPUT_GRAD, OUTPUT_MOMENTUM.
// Arguments: "momentum" (float, 0.0 when absent) and "nesterov" (int, 0 when
// absent, stored as bool).
class IDEEPMomentumSGDOp final : public IDEEPOperator {
 public:
  USE_IDEEP_DEF_ALIASES();
  USE_IDEEP_OPERATOR_FUNCTIONS();

  IDEEPMomentumSGDOp(const OperatorDef& operator_def, Workspace* ws)
      : IDEEPOperator(operator_def, ws),
        momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.0)),
        nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}

  bool RunOnDevice() override {
    CAFFE_ENFORCE(Input(GRAD).get_nelems() == Input(MOMENTUM).get_nelems());

    const auto& grad_in = Input(GRAD);
    const auto& momentum_in = Input(MOMENTUM);

    // An output that differs from its input is (re)initialised from the
    // input's descriptor before anything is written to it.
    auto* grad_out = Output(OUTPUT_GRAD);
    if (grad_in != *grad_out) {
      grad_out->init(grad_in.get_descriptor());
    }
    auto* momentum_out = Output(OUTPUT_MOMENTUM);
    if (momentum_in != *momentum_out) {
      momentum_out->init(momentum_in.get_descriptor());
    }

    // TODO: Use itensor after 0-dim is supported. Now use CPU tensor.
    const auto& lr = OperatorBase::Input<TensorCPU>(LR, CPU);
    CAFFE_ENFORCE(lr.numel() == 1);

    // Raw float views of the tensor buffers, taken after the outputs have
    // been initialised.
    const float* grad_src = static_cast<float*>(grad_in.get_data_handle());
    const float* momentum_src =
        static_cast<float*>(momentum_in.get_data_handle());
    float* grad_dst = static_cast<float*>(grad_out->get_data_handle());
    float* momentum_dst = static_cast<float*>(momentum_out->get_data_handle());

    momentum_sgd_update(
        grad_in.get_nelems(),
        grad_src,
        momentum_src,
        grad_dst,
        momentum_dst,
        lr.template data<float>(),
        momentum_,
        nesterov_,
        /*param=*/nullptr);
    return true;
  }

 protected:
  // Momentum coefficient; the constructor always overwrites the in-class
  // default with the "momentum" argument.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-non-private-member-variables-in-classes)
  float momentum_ = 0.9f;
  // True when the "nesterov" argument is non-zero.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool nesterov_;
  INPUT_TAGS(GRAD, MOMENTUM, LR);
  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM);
};
// IDEEP operator that performs a momentum SGD step and also applies it to a
// parameter tensor: momentum_sgd_update() computes the adjusted gradient and
// the new momentum, then subtracts the adjusted gradient from the parameter
// buffer.
//
// Inputs:  GRAD, MOMENTUM (IDEEP tensors), LR (CPU tensor with one element),
//          PARAM.
// Outputs: OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM.
// Arguments: "momentum" (float, 0.0 when absent) and "nesterov" (int, 0 when
// absent, stored as bool).
//
// NOTE(review): the PARAM input tensor is never read here; the parameter is
// updated purely through OUTPUT_PARAM's buffer. Presumably the operator is
// meant to run in place on the parameter -- confirm against the operator
// schema.
class IDEEPMomentumSGDUpdateOp final : public IDEEPOperator {
 public:
  USE_IDEEP_DEF_ALIASES();
  USE_IDEEP_OPERATOR_FUNCTIONS();

  IDEEPMomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
      : IDEEPOperator(operator_def, ws),
        momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.0)),
        nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}

  bool RunOnDevice() override {
    CAFFE_ENFORCE(Input(GRAD).get_nelems() == Input(MOMENTUM).get_nelems());
    // The kernel writes one float per gradient element through the
    // OUTPUT_PARAM buffer, so reject a parameter tensor of any other size
    // instead of writing past (or short of) its storage.
    CAFFE_ENFORCE(
        Output(OUTPUT_PARAM)->get_nelems() == Input(GRAD).get_nelems());

    // An output that differs from its input is (re)initialised from the
    // input's descriptor before anything is written to it.
    if (Input(GRAD) != *Output(OUTPUT_GRAD)) {
      Output(OUTPUT_GRAD)->init(Input(GRAD).get_descriptor());
    }
    if (Input(MOMENTUM) != *Output(OUTPUT_MOMENTUM)) {
      Output(OUTPUT_MOMENTUM)->init(Input(MOMENTUM).get_descriptor());
    }

    // TODO: Use itensor after 0-dim is supported. Now use CPU tensor.
    const auto& lr = OperatorBase::Input<TensorCPU>(LR, CPU);
    CAFFE_ENFORCE(lr.numel() == 1);

    momentum_sgd_update(
        Input(GRAD).get_nelems(),
        static_cast<float*>(Input(GRAD).get_data_handle()),
        static_cast<float*>(Input(MOMENTUM).get_data_handle()),
        static_cast<float*>(Output(OUTPUT_GRAD)->get_data_handle()),
        static_cast<float*>(Output(OUTPUT_MOMENTUM)->get_data_handle()),
        lr.template data<float>(),
        momentum_,
        nesterov_,
        // Non-null param pointer: the kernel also does param[i] -= ng[i].
        static_cast<float*>(Output(OUTPUT_PARAM)->get_data_handle()));
    return true;
  }

 protected:
  // Momentum coefficient; the constructor always overwrites the in-class
  // default with the "momentum" argument.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-non-private-member-variables-in-classes)
  float momentum_ = 0.9f;
  // True when the "nesterov" argument is non-zero.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  bool nesterov_;
  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
};
// Bind the operator names MomentumSGD and MomentumSGDUpdate to the IDEEP
// implementations defined above via REGISTER_IDEEP_OPERATOR.
REGISTER_IDEEP_OPERATOR(MomentumSGD, IDEEPMomentumSGDOp);
REGISTER_IDEEP_OPERATOR(MomentumSGDUpdate, IDEEPMomentumSGDUpdateOp);
} // namespace
|