1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
|
#pragma once
#include "caffe2/core/operator.h"
#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
template <typename Context>
void decay_adagrad_compute(
int N,
const float* w,
const float* g,
const float* m,
const float* v,
float* nw,
float* nm,
float* nv,
float beta1,
float beta2,
float eps_hat,
float weight_decay,
float c,
const float* lr,
Context* /*context*/) {
ConstEigenVectorArrayMap<float> w_arr(w, N);
ConstEigenVectorArrayMap<float> g_arr(g, N);
ConstEigenVectorArrayMap<float> m_arr(m, N);
ConstEigenVectorArrayMap<float> v_arr(v, N);
EigenVectorArrayMap<float> nw_arr(nw, N);
EigenVectorArrayMap<float> nm_arr(nm, N);
EigenVectorArrayMap<float> nv_arr(nv, N);
nm_arr = m_arr * beta1 + g_arr * (1.0f - beta1);
nv_arr = v_arr + g_arr.square();
nw_arr = w_arr + *lr * (nm_arr / c / (nv_arr.sqrt() + eps_hat) + weight_decay * w_arr);
}
/**
 * Operator that applies decay_adagrad_compute to its inputs.
 *
 * Inputs  (in order): PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER.
 * Outputs (in order): OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2.
 *
 * Arguments and their defaults when absent from the OperatorDef:
 *   beta1 = 0.9, beta2 = 0.999, epsilon = 1e-5, weight_decay = 0.0,
 *   bias_correction_first = true.
 *
 * decay_adagrad_compute takes float buffers, so the data pointers obtained
 * with data<T>() must be float-compatible.
 */
template <typename T, class Context>
class DecayAdagradOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  DecayAdagradOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)),
        weight_decay_(
            this->template GetSingleArgument<float>("weight_decay", 0.0f)),
        bias_correction_first_(this->template GetSingleArgument<bool>(
            "bias_correction_first",
            true)) {}

  /**
   * Validates input sizes, resizes the outputs to match their inputs, derives
   * the first-moment divisor from ITER, and runs the update.
   *
   * @return true once the update has been applied; validation failures are
   *         reported through CAFFE_ENFORCE.
   */
  bool RunOnDevice() override {
    // ITER lives on the CPU. Element 0 is read below, so the tensor must not
    // be empty.
    CAFFE_ENFORCE(OperatorBase::InputIsTensorType(ITER, CPU));
    CAFFE_ENFORCE(
        OperatorBase::Input<Tensor>(ITER, CPU).numel() > 0,
        "ITER must contain the iteration counter.");
    CAFFE_ENFORCE(Input(LR).numel() == 1);
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(PARAM).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_1).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_2).numel());

    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
    Output(OUTPUT_MOMENT_2)->ResizeLike(Input(MOMENT_2));

    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];
    const auto t = iter + 1;
    // Divisor for the first moment: 1 - beta1^t when bias correction is on,
    // otherwise 1. std::pow with an integer exponent evaluates in double.
    const auto c =
        bias_correction_first_ ? (T(1.) - std::pow(beta1_, t)) : 1.0;

    decay_adagrad_compute<Context>(
        Input(GRAD).numel(),
        Input(PARAM).template data<T>(),
        Input(GRAD).template data<T>(),
        Input(MOMENT_1).template data<T>(),
        Input(MOMENT_2).template data<T>(),
        Output(OUTPUT_PARAM)->template mutable_data<T>(),
        Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
        Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
        beta1_,
        beta2_,
        epsilon_,
        weight_decay_,
        // The compute kernel takes a float divisor; narrow explicitly.
        static_cast<float>(c),
        Input(LR).template data<T>(),
        &context_);
    return true;
  }

 protected:
  // In-class defaults mirror the argument defaults used by the constructor,
  // which always overwrites them.
  T beta1_{0.9};
  T beta2_{0.999};
  T epsilon_{1e-5};
  T weight_decay_{0.0};
  bool bias_correction_first_{true};

  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
};
} // namespace caffe2
|