#include "wngrad_op.h"
namespace caffe2 {
REGISTER_CPU_OPERATOR(Wngrad, WngradOp<float, CPUContext>);
OPERATOR_SCHEMA(Wngrad)
    .NumInputs(4)
    .NumOutputs(2, 4)
    .AllowInplace({{0, 0}, {1, 1}})
    .SetDoc(R"DOC(
Computes the WnGrad update for an input gradient and accumulated
history. This operator implements the optimization algorithm
in https://arxiv.org/abs/1803.02865 by Wu, Ward and Bottou.
Concretely, given inputs (param, seq_b, grad, learning_rate), it
computes

    new_seq_b = seq_b + 1 / seq_b * norm(grad)^2
    effective_lr = learning_rate / (new_seq_b + epsilon)
    update = learning_rate * grad / (new_seq_b + epsilon)
    new_param = param + update

and returns (new_param, new_seq_b).
Optionally returns effective_lr and update as well.
)DOC")
    .Input(0, "param", "Parameters to be updated")
    .Input(1, "seq_b", "Seq_b history")
    .Input(2, "grad", "Gradient computed")
    .Input(3, "lr", "Learning rate")
    .Output(0, "output_param", "Updated parameters")
    .Output(1, "output_seq_b", "Updated seq_b")
    .Output(2, "output_effective_lr", "(optional) Effective learning rate")
    .Output(3, "output_update", "(optional) Actual update that is applied.")
    .Arg("epsilon", "Default 1e-5");
REGISTER_CPU_OPERATOR(SparseWngrad, SparseWngradOp<float, CPUContext>);
OPERATOR_SCHEMA(SparseWngrad)
    .NumInputs(5)
    .NumOutputs(2)
    .EnforceOneToOneInplace()
    .SetDoc(R"DOC(
This operator implements the sparse version of the optimization algorithm
in https://arxiv.org/abs/1803.02865 by Wu, Ward and Bottou.
Given inputs (param, seq_b, indices, grad, lr), it runs the dense WnGrad
update on (param, seq_b, grad, lr) for the rows of param selected by
indices, and returns (new_param, new_seq_b) as in the dense case.
)DOC")
    .Input(0, "param", "Parameters to be updated")
    .Input(1, "seq_b", "Seq_b history")
    .Input(2, "indices", "Sparse indices")
    .Input(3, "grad", "Gradient computed")
    .Input(4, "lr", "Learning rate")
    .Output(0, "output_param", "Updated parameters")
    .Output(1, "output_seq_b", "Updated seq_b")
    .Arg("epsilon", "Default 1e-5");
SHOULD_NOT_DO_GRADIENT(Wngrad);
SHOULD_NOT_DO_GRADIENT(SparseWngrad);
} // namespace caffe2