#ifndef __FORGETTING_MASS_CALCULATOR__
#define __FORGETTING_MASS_CALCULATOR__

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <mutex>
#include <vector>

#include "SalmonMath.hpp"
#include "SalmonSpinLock.hpp"
#include "spdlog/spdlog.h"

class ForgettingMassCalculator {
public:
  ForgettingMassCalculator(double forgettingFactor = 0.65)
      : batchNum_(0), forgettingFactor_(forgettingFactor),
        logForgettingMass_(salmon::math::LOG_1), logForgettingMasses_({}),
        cumulativeLogForgettingMasses_({}) {}

  // std::mutex is neither copyable nor movable, so on builds that use it
  // these defaulted move operations are defined as deleted.
  ForgettingMassCalculator(const ForgettingMassCalculator&) = delete;
  ForgettingMassCalculator(ForgettingMassCalculator&&) = default;
  ForgettingMassCalculator& operator=(ForgettingMassCalculator&&) = default;
  ForgettingMassCalculator& operator=(const ForgettingMassCalculator&) = delete;

  /** Precompute the log(forgetting mass) and cumulative log(forgetting mass)
   * for the first numMiniBatches batches / timesteps. Entry k of each table
   * holds the value for (1-based) batch k + 1; batch 1 is always present.
   */
  bool prefill(uint64_t numMiniBatches) {
    logForgettingMasses_.reserve(numMiniBatches);
    cumulativeLogForgettingMasses_.reserve(numMiniBatches);
    // Batch 1 has mass 1, i.e. log mass LOG_1.
    double fm = salmon::math::LOG_1;
    logForgettingMasses_.push_back(fm);
    cumulativeLogForgettingMasses_.push_back(fm);
    // Batches 2..numMiniBatches; the loop runs zero times when
    // numMiniBatches <= 1, so there is no unsigned underflow.
    for (uint64_t i = 2; i <= numMiniBatches; ++i) {
      fm += forgettingFactor_ * std::log(static_cast<double>(i - 1)) -
            std::log(std::pow(static_cast<double>(i), forgettingFactor_) - 1);
      logForgettingMasses_.push_back(fm);
      // fill in cumulative mass
      cumulativeLogForgettingMasses_.push_back(
          salmon::math::logAdd(cumulativeLogForgettingMasses_.back(), fm));
    }
    return true;
  }
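
  // How the prefill recurrence works (a worked sketch; the step-size reading
  // at the end is an interpretation, not something this file states).
  // Write alpha = forgettingFactor_ and m_b for the forgetting mass of
  // 1-based batch b, so that entry k of logForgettingMasses_ is log(m_{k+1}):
  //
  //   m_1 = 1
  //   m_b = m_{b-1} * (b - 1)^alpha / (b^alpha - 1),    for b >= 2
  //
  // or, in log space, exactly as the prefill loop accumulates it:
  //
  //   log m_b = log m_{b-1} + alpha * log(b - 1) - log(b^alpha - 1)
  //
  // With eta_b = b^(-alpha), the ratio (b - 1)^alpha / (b^alpha - 1) equals
  // eta_b / (eta_{b-1} * (1 - eta_b)), the form a decaying step size eta_b
  // produces. For the default alpha = 0.65, m_2 = 1 / (2^0.65 - 1), roughly
  // 1.757, so batch 2 already carries more (unnormalized) weight than batch 1.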

  double operator()() {
#if defined __APPLE__
    spin_lock::scoped_lock sl(ffMutex_);
#else
    std::lock_guard<std::mutex> lock(ffMutex_);
#endif
    ++batchNum_;
    if (batchNum_ > 1) {
      logForgettingMass_ +=
          forgettingFactor_ * std::log(static_cast<double>(batchNum_ - 1)) -
          std::log(std::pow(static_cast<double>(batchNum_), forgettingFactor_) -
                   1);
    }
    return logForgettingMass_;
  }

  /**
   * Return the log(forgetting mass) and current timestep in the output
   * variables logForgettingMass and currentMinibatchTimestep. If we
   * haven't pre-computed the forgetting mass for the next timestep yet,
   * then do it now.
   */
  void getLogMassAndTimestep(double& logForgettingMass,
                             uint64_t& currentMinibatchTimestep) {
#if defined __APPLE__
    spin_lock::scoped_lock sl(ffMutex_);
#else
    std::lock_guard<std::mutex> lock(ffMutex_);
#endif
    if (batchNum_ >= logForgettingMasses_.size()) {
      if (logForgettingMasses_.empty()) {
        // prefill() was never called: batch 1 has log mass LOG_1.
        logForgettingMasses_.push_back(salmon::math::LOG_1);
        cumulativeLogForgettingMasses_.push_back(salmon::math::LOG_1);
      } else {
        // Entry k holds the mass for batch k + 1 (as in prefill), so the
        // entry being added here belongs to batch b = batchNum_ + 1.
        double b = static_cast<double>(batchNum_ + 1);
        double fm = logForgettingMasses_.back() +
                    forgettingFactor_ * std::log(b - 1.0) -
                    std::log(std::pow(b, forgettingFactor_) - 1.0);
        logForgettingMasses_.push_back(fm);
        cumulativeLogForgettingMasses_.push_back(
            salmon::math::logAdd(cumulativeLogForgettingMasses_.back(), fm));
      }
    }
    currentMinibatchTimestep = batchNum_;
    logForgettingMass = logForgettingMasses_[batchNum_];
    ++batchNum_;
  }
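
  // Usage sketch (a hypothetical caller; `numBatches` and the variable names
  // are illustrative only). Prefill once, then let each worker claim the
  // next timestep and look its masses up again later by index:
  //
  //   ForgettingMassCalculator fmCalc(0.65);
  //   fmCalc.prefill(numBatches);
  //   double logMass;
  //   uint64_t ts;
  //   fmCalc.getLogMassAndTimestep(logMass, ts);  // first call: ts == 0,
  //                                               // logMass == LOG_1
  //   double logCum = fmCalc.cumulativeLogMassAt(ts);
  //
  // Avoid mixing this with operator(): operator() advances batchNum_
  // without extending the tables, so the two would disagree on indices.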

  // Retrieve the log(forgetting mass) at a particular timestep. This
  // function assumes that the forgetting mass has already been computed
  // for this timestep --- otherwise, this will result in a fatal error.
  double logMassAt(uint64_t timestep) {
    if (timestep < logForgettingMasses_.size()) {
      return logForgettingMasses_[timestep];
    } else {
      spdlog::get("jointLog")
          ->error("Requested forgetting mass for timestep {} "
                  "where it has not yet been computed. This "
                  "likely means that the ForgettingMassCalculator "
                  "class was being used incorrectly! Please "
                  "report this crash on GitHub!\n",
                  timestep);
      std::exit(1);
      return salmon::math::LOG_0;
    }
  }

  // Retrieve the cumulative log(forgetting mass) at a particular timestep.
  // This function assumes that the forgetting mass has already been computed
  // for this timestep --- otherwise, this will result in a fatal error.
  double cumulativeLogMassAt(uint64_t timestep) {
    if (timestep < cumulativeLogForgettingMasses_.size()) {
      return cumulativeLogForgettingMasses_[timestep];
    } else {
      spdlog::get("jointLog")
          ->error("Requested cumulative forgetting mass for timestep {} "
                  "where it has not yet been computed. This "
                  "likely means that the ForgettingMassCalculator "
                  "class was being used incorrectly! Please "
                  "report this crash on GitHub!\n",
                  timestep);
      std::exit(1);
      return salmon::math::LOG_0;
    }
  }
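
  // The cumulative table is a running log-sum-exp of the per-batch masses:
  //
  //   cumulativeLogMassAt(k) = log( sum_{j=0..k} exp(logMassAt(j)) )
  //
  // built one entry at a time as C_k = logAdd(C_{k-1}, logMassAt(k)), with
  // C_0 = LOG_1. This reading assumes salmon::math::logAdd(x, y) computes
  // log(exp(x) + exp(y)), as its name suggests; SalmonMath.hpp holds the
  // actual definition.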

  uint64_t getCurrentTimestep() { return batchNum_; }

private:
  uint64_t batchNum_;
  double forgettingFactor_;
  double logForgettingMass_;
  std::vector<double> logForgettingMasses_;
  std::vector<double> cumulativeLogForgettingMasses_;
#if defined __APPLE__
  spin_lock ffMutex_;
#else
  std::mutex ffMutex_;
#endif
};

#endif //__FORGETTING_MASS_CALCULATOR__