File: sgd.cpp

Package: pytorch-cuda 2.6.0+dfsg-7

#include <torch/csrc/jit/mobile/train/optim/sgd.h>

#include <torch/types.h>
#include <torch/utils.h>

#include <ATen/ATen.h>

namespace torch::jit::mobile {

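// SGDParamGroup couples a list of parameters with optional group-specific
// SGDOptions; SGD::add_param_group() substitutes the defaults when absent.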
bool SGDParamGroup::has_options() const {
  return options_ != nullptr;
}

SGDOptions& SGDParamGroup::options() {
  TORCH_CHECK(has_options());
  return *options_;
}

const SGDOptions& SGDParamGroup::options() const {
  TORCH_CHECK(has_options());
  return *options_;
}

void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
  options_ = std::move(options);
}

std::vector<Tensor>& SGDParamGroup::params() {
  return params_;
}

const std::vector<Tensor>& SGDParamGroup::params() const {
  return params_;
}

SGDOptions::SGDOptions(double lr) : lr_(lr) {}

bool operator==(const SGDOptions& lhs, const SGDOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.momentum() == rhs.momentum()) &&
      (lhs.dampening() == rhs.dampening()) &&
      (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.nesterov() == rhs.nesterov());
}

bool operator==(const SGDParamState& lhs, const SGDParamState& rhs) {
  return torch::equal(lhs.momentum_buffer(), rhs.momentum_buffer());
}

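// Adds a parameter group to the optimizer. Groups without their own options
// receive a clone of the optimizer-wide defaults; a parameter may belong to
// at most one group, and only leaf tensors can be optimized.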
void SGD::add_param_group(const SGDParamGroup& param_group) {
  for (const auto& param : param_group.params()) {
    TORCH_CHECK(param.is_leaf(), "can't optimize a non-leaf Tensor");
  }
  TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
  SGDParamGroup param_group_(param_group.params());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }
  for (const auto& p : param_group_.params()) {
    TORCH_CHECK(
        state_.count(p.unsafeGetTensorImpl()) == 0,
        "some parameters appear in more than one parameter group");
  }
  param_groups_.emplace_back(std::move(param_group_));
}

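// Detaches and zeroes every defined gradient across all parameter groups.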
void SGD::zero_grad() {
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (p.grad().defined()) {
        p.grad().detach_();
        p.grad().zero_();
      }
    }
  }
}

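// Performs one optimization step over all parameter groups. If a loss
// closure is given, it is re-evaluated with gradients enabled and its result
// is returned; otherwise an undefined tensor is returned.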
Tensor SGD::step(const LossClosure& closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    auto& options = static_cast<SGDOptions&>(group.options());
    auto weight_decay = options.weight_decay();
    auto momentum = options.momentum();
    auto dampening = options.dampening();
    auto nesterov = options.nesterov();

    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto d_p = p.grad().data();
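      // Weight decay folds L2 regularization into the gradient:
      // d_p += weight_decay * p.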
      if (weight_decay != 0) {
        d_p = d_p.add(p.data(), weight_decay);
      }
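      // SGD with momentum: the per-parameter buffer is created lazily on the
      // first step (initialized to d_p, without dampening) and thereafter
      // updated as buf = momentum * buf + (1 - dampening) * d_p. With
      // Nesterov, the applied update is d_p + momentum * buf; otherwise buf.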
      if (momentum != 0) {
        Tensor buf;
        auto param_state = state_.find(p.unsafeGetTensorImpl());
        if (param_state == state_.end()) {
          buf = torch::clone(d_p).detach();
          auto state = std::make_unique<SGDParamState>();
          state->momentum_buffer(buf);
          state_[p.unsafeGetTensorImpl()] = std::move(state);
        } else {
          buf = static_cast<SGDParamState&>(*param_state->second)
                    .momentum_buffer();
          buf.mul_(momentum).add_(d_p, 1 - dampening);
        }
        if (nesterov) {
          d_p = d_p.add(buf, momentum);
        } else {
          d_p = buf;
        }
      }
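      // Apply the update in place: p -= lr * d_p.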
      p.data().add_(d_p, -1 * options.lr());
    }
  }
  return loss;
}
} // namespace torch::jit::mobile
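
A minimal usage sketch, not part of sgd.cpp: it drives the optimizer defined above for one leaf tensor. It assumes the SGD(std::vector<Tensor>, SGDOptions) constructor, the chained SGDOptions setters (e.g. momentum()), and a step() overload callable without a closure, as declared in sgd.h; treat those exact signatures as assumptions rather than guarantees.

// Hypothetical example (assumed API, see note above).
#include <torch/csrc/jit/mobile/train/optim/sgd.h>
#include <torch/torch.h>

int main() {
  using torch::jit::mobile::SGD;
  using torch::jit::mobile::SGDOptions;

  // One trainable leaf tensor; add_param_group() rejects non-leaf tensors.
  auto w = torch::randn({4}, torch::requires_grad());

  // lr is mandatory; momentum, dampening, weight_decay and nesterov are off
  // by default, matching the fields compared by operator== above.
  SGD opt({w}, SGDOptions(/*lr=*/0.1).momentum(0.9));

  for (int i = 0; i < 100; ++i) {
    opt.zero_grad();            // clear gradients from the previous step
    auto loss = w.pow(2).sum(); // toy objective: drive w toward zero
    loss.backward();
    opt.step();                 // no closure: returns an undefined Tensor
  }
  return 0;
}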