1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
|
#pragma once
#include "caffe2/core/common_omp.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// Declaration of the RMSProp update routine, templated on the device Context.
// Only the declaration is visible in this header; the definition is provided
// elsewhere (presumably one per Context -- confirm against the .cc/.cu files).
//
// Parameter roles, as established by the call in RmsPropOp::RunOnDevice:
//   N        - number of elements in the gradient input.
//   g        - gradient input data (read-only).
//   ms       - mean-squares input data (read-only).
//   mom      - momentum input data (read-only).
//   ng       - destination buffer for the output gradient.
//   nms      - destination buffer for the output mean squares.
//   nmom     - destination buffer for the output momentum.
//   decay    - value of the operator's "decay" argument.
//   momentum - value of the operator's "momentum" argument.
//   epsilon  - value of the operator's "epsilon" argument.
//   lr       - data pointer of the learning-rate input, which the operator
//              enforces to hold exactly one element.
//   context  - the calling operator's device context.
//
// NOTE(review): the exact update formula is not visible from this header;
// consult the definition before relying on any particular numerical form.
template <typename Context>
void rmsprop_update(
int N,
const float* g,
const float* ms,
const float* mom,
float* ng,
float* nms,
float* nmom,
float decay,
float momentum,
float epsilon,
const float* lr,
Context* context);
// Operator that validates its inputs, sizes its outputs, and forwards the
// raw buffers to rmsprop_update<Context>().
//
// Inputs  (see INPUT_TAGS):  GRAD, MEAN_SQUARES, MOMENTUM, LR
// Outputs (see OUTPUT_TAGS): OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM
//
// Operator arguments (all read as float):
//   "decay"    - default 0.9
//   "momentum" - default 0.0
//   "epsilon"  - default 1e-5
//
// NOTE(review): rmsprop_update() is declared with float pointers, so this
// class can presumably only be instantiated with T = float -- confirm before
// registering it for any other element type.
template <typename T, class Context>
class RmsPropOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Reads the hyper-parameters from the operator definition, falling back to
  // the defaults listed in the class comment when an argument is absent.
  RmsPropOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        decay_(this->template GetSingleArgument<float>("decay", 0.9f)),
        momentum_(this->template GetSingleArgument<float>("momentum", 0.0f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  // Runs one update step. Enforces that LR holds exactly one element and that
  // GRAD, MEAN_SQUARES and MOMENTUM all have the same element count, resizes
  // each output like its corresponding input, then calls rmsprop_update().
  // Returns true once the update call has been issued.
  bool RunOnDevice() override {
    CAFFE_ENFORCE(Input(LR).numel() == 1);
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MEAN_SQUARES).numel());
    // Index the inputs with the input tag (MOMENTUM), not an output tag.
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENTUM).numel());

    Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
    Output(OUTPUT_MEAN_SQUARES)->ResizeLike(Input(MEAN_SQUARES));
    Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));

    rmsprop_update<Context>(
        Input(GRAD).numel(),
        Input(GRAD).template data<T>(),
        Input(MEAN_SQUARES).template data<T>(),
        Input(MOMENTUM).template data<T>(),
        Output(OUTPUT_GRAD)->template mutable_data<T>(),
        Output(OUTPUT_MEAN_SQUARES)->template mutable_data<T>(),
        Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
        decay_,
        momentum_,
        epsilon_,
        Input(LR).template data<T>(),
        &context_);
    return true;
  }

 protected:
  // In-class defaults mirror the constructor's argument defaults; the
  // constructor always overwrites them with the parsed argument values.
  T decay_{0.9};
  T momentum_{0.0};
  T epsilon_{1e-5};
  INPUT_TAGS(GRAD, MEAN_SQUARES, MOMENTUM, LR);
  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MEAN_SQUARES, OUTPUT_MOMENTUM);
};
}
|