1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
|
#pragma once
#include <vector>
#include <fbgemm/FbgemmConvert.h>
#include "caffe2/operators/elementwise_ops.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
// Declaration of the global flag FLAGS_caffe2_fbgemm_fake_fp16_clamp; the
// matching C10_DEFINE_bool presumably lives in a .cc file elsewhere. The
// functors below pass it as the 4th (`clamp`) argument of
// fbgemm::RoundToFloat16 when emulating fp16 rounding.
C10_DECLARE_bool(caffe2_fbgemm_fake_fp16_clamp);
namespace caffe2 {
using namespace std;
template <class Context>
struct ReluFakeFp16Functor {
template <typename T>
bool operator()(const int N, const T* X, T* Y, Context* /* unused */) const {
std::vector<float> X_fp16(N);
fbgemm::RoundToFloat16(
X, X_fp16.data(), N, FLAGS_caffe2_fbgemm_fake_fp16_clamp);
EigenVectorMap<T>(Y, N) =
ConstEigenVectorMap<float>(X_fp16.data(), N).cwiseMax(T(0));
return true;
}
};
/// Element-wise square with fp16-emulated input and output.
///
/// The operand is rounded to fp16 precision, squared with math::Sqr on the
/// rounded float values, and the result is rounded back to fp16 precision in
/// place. Both roundings pass --caffe2_fbgemm_fake_fp16_clamp as the clamp
/// argument of fbgemm::RoundToFloat16.
///
/// @param N        number of elements in X and Y
/// @param X        input buffer of N elements
/// @param Y        output buffer of N elements (also the rounding target)
/// @param context  forwarded to math::Sqr
/// @return always true
template <class Context>
struct SqrFakeFp16Functor {
  template <typename T>
  bool operator()(const int N, const T* X, T* Y, Context* context) const {
    const bool clamp = FLAGS_caffe2_fbgemm_fake_fp16_clamp;

    // Quantize the operand so the square sees fp16-representable input.
    std::vector<float> rounded_input(N);
    fbgemm::RoundToFloat16(X, rounded_input.data(), N, clamp);

    math::Sqr(N, rounded_input.data(), Y, context);

    // A product of fp16 values is not generally fp16-representable, so the
    // float result is rounded back in place.
    fbgemm::RoundToFloat16(Y, Y, N, clamp);
    return true;
  }
};
/// Sigmoid, Y = 1 / (1 + exp(-X)), with fp16-emulated input and output.
///
/// X is rounded to fp16 precision, the sigmoid is evaluated in float with
/// Eigen array ops, and Y is rounded back to fp16 precision in place. Both
/// roundings pass --caffe2_fbgemm_fake_fp16_clamp as the clamp argument of
/// fbgemm::RoundToFloat16.
/// NOTE(review): "Ideal" presumably contrasts this exact-math version with
/// the LUT-based fake_fp16::CalcSigmoidByLUT — confirm.
///
/// @param N  number of elements in X and Y
/// @param X  input buffer of N elements
/// @param Y  output buffer of N elements
/// @return always true
struct SigmoidFakeIdealFp16Functor {
  template <typename T>
  bool operator()(const int N, const T* X, T* Y, CPUContext* /* unused */)
      const {
    const bool clamp = FLAGS_caffe2_fbgemm_fake_fp16_clamp;
    std::vector<float> X_fp16(N);
    // The input rounding previously omitted the clamp flag (so it always used
    // RoundToFloat16's default). It now honours the flag, like the output
    // rounding below and every other fake-fp16 functor in this header.
    fbgemm::RoundToFloat16(X, X_fp16.data(), N, clamp);
    EigenVectorArrayMap<T>(Y, N) =
        T(1) / (T(1) + (-ConstEigenVectorArrayMap<T>(X_fp16.data(), N)).exp());
    fbgemm::RoundToFloat16(Y, Y, N, clamp);
    return true;
  }
};
/// Hyperbolic tangent with fp16-emulated input and output.
///
/// X is rounded to fp16 precision, tanh is evaluated by math::Tanh on the
/// rounded float values, and Y is rounded back to fp16 precision in place.
/// Both roundings pass --caffe2_fbgemm_fake_fp16_clamp as the clamp argument
/// of fbgemm::RoundToFloat16.
///
/// @param N        number of elements in X and Y
/// @param X        input buffer of N elements
/// @param Y        output buffer of N elements
/// @param context  forwarded to math::Tanh
/// @return always true
struct TanhFakeIdealFp16Functor {
  template <typename T>
  bool operator()(const int N, const T* X, T* Y, CPUContext* context) const {
    const bool clamp = FLAGS_caffe2_fbgemm_fake_fp16_clamp;

    std::vector<float> quantized(N);
    fbgemm::RoundToFloat16(X, quantized.data(), N, clamp);

    math::Tanh<T, CPUContext>(N, quantized.data(), Y, context);

    // Bring the float result back onto the fp16 grid.
    fbgemm::RoundToFloat16(Y, Y, N, clamp);
    return true;
  }
};
} // namespace caffe2
// Scalar activation helpers operating on at::Half values. Only declarations
// are visible here; the definitions live in another translation unit.
// NOTE(review): the "ByLUT" suffix suggests lookup-table approximations —
// confirm precision/range against the implementation file.
// NOTE(review): at::Half is not declared by this header's own includes; it is
// presumably pulled in transitively — verify.
namespace fake_fp16 {

// Sigmoid of `value`, presumably via a lookup table.
at::Half CalcSigmoidByLUT(at::Half value);

// Swish of `value`, presumably via a lookup table.
at::Half CalcSwishByLUT(at::Half value);

// Swish of `value`, presumably via a lookup table with cubic interpolation
// (inferred from the name only — confirm).
at::Half CalcSwishByLUTCubic(at::Half value);

// Tanh of `value`, presumably via a lookup table.
at::Half CalcTanhByLUT(at::Half value);

} // namespace fake_fp16
|