File: ForgettingMassCalculator.hpp

package info (click to toggle)
salmon 1.10.3%2Bds1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 35,088 kB
  • sloc: cpp: 200,707; ansic: 171,082; sh: 859; python: 792; makefile: 238
file content (145 lines) | stat: -rw-r--r-- 5,267 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#ifndef __FORGETTING_MASS_CALCULATOR__
#define __FORGETTING_MASS_CALCULATOR__

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <mutex>
#include <vector>

#include "spdlog/spdlog.h"

#include "SalmonMath.hpp"
#include "SalmonSpinLock.hpp"

/**
 * Computes the log(forgetting mass) used to weight successive mini-batches /
 * timesteps.  With forgetting factor ff, the log mass at 1-based timestep t
 * follows the recurrence
 *
 *   m(1) = log(1)
 *   m(t) = m(t-1) + ff * log(t - 1) - log(t^ff - 1)        for t >= 2
 *
 * Two interfaces are provided.  They share the same internal timestep counter
 * (batchNum_), so a given instance should use only one of them:
 *   - operator() returns m(t) for t = 1, 2, 3, ... on successive calls;
 *   - getLogMassAndTimestep() hands out 0-based timesteps k together with
 *     m(k + 1), and records every mass it hands out so it can later be read
 *     back with logMassAt(k) / cumulativeLogMassAt(k).
 *
 * Every public member function (other than construction / assignment) takes
 * the internal lock (std::mutex, or spin_lock on Apple platforms).
 */
class ForgettingMassCalculator {
public:
  ForgettingMassCalculator(double forgettingFactor = 0.65)
      : batchNum_(0), forgettingFactor_(forgettingFactor),
        logForgettingMass_(salmon::math::LOG_1) {}

  ForgettingMassCalculator(const ForgettingMassCalculator&) = delete;
  // NOTE(review): std::mutex is neither copyable nor movable, so on non-Apple
  // builds these defaulted move operations are defined as deleted (spin_lock is
  // presumably the same on Apple -- TODO confirm).  Kept for source
  // compatibility.
  ForgettingMassCalculator(ForgettingMassCalculator&&) = default;
  ForgettingMassCalculator& operator=(ForgettingMassCalculator&&) = default;
  ForgettingMassCalculator& operator=(const ForgettingMassCalculator&) = delete;

  /** Precompute the log(forgetting mass) and cumulative log(forgetting mass)
   * for the first numMiniBatches (0-based) timesteps, so that logMassAt(k) and
   * cumulativeLogMassAt(k) are available for every k < numMiniBatches.
   *
   * Entries that already exist are kept; calling this again only extends the
   * tables.  At least one entry (timestep 0, whose mass is log(1)) is always
   * present afterwards.
   *
   * @param numMiniBatches number of timesteps to precompute.
   * @return always true.
   */
  bool prefill(uint64_t numMiniBatches) {
    ScopedLock guard(ffMutex_);
    const uint64_t target = (numMiniBatches > 0) ? numMiniBatches : 1;
    logForgettingMasses_.reserve(target);
    cumulativeLogForgettingMasses_.reserve(target);
    while (logForgettingMasses_.size() < target) {
      appendNextMass();
    }
    return true;
  }

  /** Advance the internal 1-based timestep t and return m(t).
   * The first call returns log(1).
   */
  double operator()() {
    ScopedLock guard(ffMutex_);
    ++batchNum_;
    if (batchNum_ > 1) {
      logForgettingMass_ += logMassIncrement(batchNum_);
    }
    return logForgettingMass_;
  }

  /**
   *  Return the log(forgetting mass) and current timestep in the output
   *  variables logForgettingMass and currentMinibatchTimestep, then advance
   *  the timestep.  If we haven't pre-computed the forgetting mass for this
   *  timestep yet, then do it now (using the same recurrence as prefill(), so
   *  the result does not depend on how much was prefilled).
   *
   *  @param logForgettingMass        [out] log mass for the returned timestep.
   *  @param currentMinibatchTimestep [out] 0-based timestep handed out.
   */
  void getLogMassAndTimestep(double& logForgettingMass,
                             uint64_t& currentMinibatchTimestep) {
    ScopedLock guard(ffMutex_);
    // `while` rather than `if`: also covers a missing prefill() (empty
    // tables) and any counter advance made through operator().
    while (batchNum_ >= logForgettingMasses_.size()) {
      appendNextMass();
    }
    currentMinibatchTimestep = batchNum_;
    logForgettingMass = logForgettingMasses_[batchNum_];
    ++batchNum_;
  }

  // Retrieve the log(forgetting mass) at a particular timestep.  This
  // function assumes that the forgetting mass has already been computed
  // for this timestep --- otherwise, this will result in a fatal error
  // (the process exits with status 1).
  double logMassAt(uint64_t timestep) {
    {
      // Lock: another thread may be appending (and reallocating) the table
      // inside getLogMassAndTimestep().
      ScopedLock guard(ffMutex_);
      if (timestep < logForgettingMasses_.size()) {
        return logForgettingMasses_[timestep];
      }
    }
    // Report and terminate outside the lock.  spdlog::get returns null when
    // the logger has not been registered.
    if (auto logger = spdlog::get("jointLog")) {
      logger->error("Requested forgetting mass for timestep {} "
                    "where it has not yet been computed.  This "
                    "likely means that the ForgettingMassCalculator "
                    "class was being used incorrectly!  Please "
                    "report this crash on GitHub!\n",
                    timestep);
    }
    std::exit(1);
  }

  // Retrieve the cumulative log(forgetting mass) at a particular timestep.
  // This function assumes that the forgetting mass has already been computed
  // for this timestep --- otherwise, this will result in a fatal error
  // (the process exits with status 1).
  double cumulativeLogMassAt(uint64_t timestep) {
    {
      ScopedLock guard(ffMutex_);
      if (timestep < cumulativeLogForgettingMasses_.size()) {
        return cumulativeLogForgettingMasses_[timestep];
      }
    }
    if (auto logger = spdlog::get("jointLog")) {
      logger->error("Requested cumulative forgetting mass for timestep {} "
                    "where it has not yet been computed.  This "
                    "likely means that the ForgettingMassCalculator "
                    "class was being used incorrectly!  Please "
                    "report this crash on GitHub!\n",
                    timestep);
    }
    std::exit(1);
  }

  /** Current value of the shared timestep counter: the next 0-based timestep
   * getLogMassAndTimestep() will hand out, or the last 1-based timestep
   * returned by operator(). */
  uint64_t getCurrentTimestep() {
    ScopedLock guard(ffMutex_);
    return batchNum_;
  }

private:
  /** Increment of the log mass when moving from 1-based timestep t - 1 to
   * timestep t (requires t >= 2): ff * log(t - 1) - log(t^ff - 1). */
  double logMassIncrement(uint64_t t) const {
    return forgettingFactor_ * std::log(static_cast<double>(t - 1)) -
           std::log(std::pow(static_cast<double>(t), forgettingFactor_) - 1);
  }

  /** Append the mass and cumulative mass for the next 0-based index
   * k = logForgettingMasses_.size(), which corresponds to 1-based timestep
   * k + 1 (the convention operator() uses).  Caller must hold ffMutex_. */
  void appendNextMass() {
    if (logForgettingMasses_.empty()) {
      logForgettingMasses_.push_back(salmon::math::LOG_1);
      cumulativeLogForgettingMasses_.push_back(salmon::math::LOG_1);
      return;
    }
    const uint64_t nextIndex = logForgettingMasses_.size();
    const double fm =
        logForgettingMasses_.back() + logMassIncrement(nextIndex + 1);
    logForgettingMasses_.push_back(fm);
    // Running logAdd of all masses so far (presumably log-sum-exp).
    cumulativeLogForgettingMasses_.push_back(
        salmon::math::logAdd(cumulativeLogForgettingMasses_.back(), fm));
  }

  uint64_t batchNum_;
  double forgettingFactor_;
  double logForgettingMass_; // only advanced by operator()
  std::vector<double> logForgettingMasses_;
  std::vector<double> cumulativeLogForgettingMasses_;
#if defined __APPLE__
  using ScopedLock = spin_lock::scoped_lock;
  spin_lock ffMutex_;
#else
  using ScopedLock = std::lock_guard<std::mutex>;
  std::mutex ffMutex_;
#endif
};

#endif //__FORGETTING_MASS_CALCULATOR__