#include <torch/csrc/jit/mobile/train/optim/sgd.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <ATen/ATen.h>
namespace torch::jit::mobile {
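
// A SGDParamGroup owns a set of parameter tensors and, optionally, its own
// SGDOptions; a group without options falls back to the optimizer defaults.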
bool SGDParamGroup::has_options() const {
  return options_ != nullptr;
}

SGDOptions& SGDParamGroup::options() {
  TORCH_CHECK(has_options());
  return *options_;
}

const SGDOptions& SGDParamGroup::options() const {
  TORCH_CHECK(has_options());
  return *options_;
}

void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
  options_ = std::move(options);
}

std::vector<Tensor>& SGDParamGroup::params() {
  return params_;
}

const std::vector<Tensor>& SGDParamGroup::params() const {
  return params_;
}
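
// Only the learning rate is required; the remaining hyperparameters keep the
// defaults declared in sgd.h.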
SGDOptions::SGDOptions(double lr) : lr_(lr) {}
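
// Two option sets compare equal when all of their hyperparameters match.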
bool operator==(const SGDOptions& lhs, const SGDOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.momentum() == rhs.momentum()) &&
      (lhs.dampening() == rhs.dampening()) &&
      (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.nesterov() == rhs.nesterov());
}
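
// Per-parameter state compares equal when the momentum buffers hold the same
// values.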
bool operator==(const SGDParamState& lhs, const SGDParamState& rhs) {
  return torch::equal(lhs.momentum_buffer(), rhs.momentum_buffer());
}
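
// Adds a new parameter group. Every parameter must be a leaf tensor, a group
// without its own options receives a clone of the defaults, and no parameter
// may already belong to another group.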
void SGD::add_param_group(const SGDParamGroup& param_group) {
  for (const auto& param : param_group.params()) {
    TORCH_CHECK(param.is_leaf(), "can't optimize a non-leaf Tensor");
  }
  TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
  SGDParamGroup param_group_(param_group.params());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }
  for (const auto& p : param_group_.params()) {
    TORCH_CHECK(
        state_.count(p.unsafeGetTensorImpl()) == 0,
        "some parameters appear in more than one parameter group");
  }
  param_groups_.emplace_back(std::move(param_group_));
}
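
// Detaches and zeroes the gradient of every parameter that has one.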
void SGD::zero_grad() {
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (p.grad().defined()) {
        p.grad().detach_();
        p.grad().zero_();
      }
    }
  }
}
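
// Performs a single SGD update. If a loss closure is supplied, it is run with
// gradients enabled and its result is returned.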
Tensor SGD::step(const LossClosure& closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
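  // Each group is updated with its own hyperparameters.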
  for (auto& group : param_groups_) {
    auto& options = static_cast<SGDOptions&>(group.options());
    auto weight_decay = options.weight_decay();
    auto momentum = options.momentum();
    auto dampening = options.dampening();
    auto nesterov = options.nesterov();
    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
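      // Start from the raw gradient; weight decay adds weight_decay * p.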
      auto d_p = p.grad().data();
      if (weight_decay != 0) {
        d_p = d_p.add(p.data(), weight_decay);
      }
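      // With momentum, keep a per-parameter buffer: it is initialized to a
      // detached clone of the gradient, then updated as
      // buf = momentum * buf + (1 - dampening) * d_p.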
      if (momentum != 0) {
        Tensor buf;
        auto param_state = state_.find(p.unsafeGetTensorImpl());
        if (param_state == state_.end()) {
          buf = torch::clone(d_p).detach();
          auto state = std::make_unique<SGDParamState>();
          state->momentum_buffer(buf);
          state_[p.unsafeGetTensorImpl()] = std::move(state);
        } else {
          buf = static_cast<SGDParamState&>(*param_state->second)
                    .momentum_buffer();
          buf.mul_(momentum).add_(d_p, 1 - dampening);
        }
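        // Nesterov momentum looks ahead along the buffer (d_p + momentum * buf);
        // plain momentum replaces the gradient with the buffer.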
        if (nesterov) {
          d_p = d_p.add(buf, momentum);
        } else {
          d_p = buf;
        }
      }
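      // Update: p <- p - lr * d_p.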
      p.data().add_(d_p, -1 * options.lr());
    }
  }
  return loss;
}
} // namespace torch::jit::mobile