1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
|
#pragma once
#include <c10/core/SymFloatNodeImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <limits>
#include <memory>
namespace c10 {
// NB: this is actually double precision; we're using the Python naming here
//
// A SymFloat stores either a concrete double (`data_`) or a symbolic node
// (`ptr_`). is_symbolic() reports which of the two is active by testing
// `ptr_`.
class C10_API SymFloat {
 public:
  // Wraps a concrete value; `ptr_` is left default-constructed.
  /*implicit*/ SymFloat(double d) : data_(d) {}

  // Wraps a symbolic node. `data_` is filled with quiet NaN as a placeholder,
  // which is what as_float_unchecked() returns for a value built this way.
  SymFloat(SymFloatNode ptr)
      : data_(std::numeric_limits<double>::quiet_NaN()),
        ptr_(std::move(ptr)) {}

  // Default value: concrete 0.0 with a default-constructed `ptr_`.
  SymFloat() : data_(0.0) {}

  // Returns the raw node pointer obtained from `ptr_.get()`; `ptr_` itself is
  // not modified.
  SymFloatNodeImpl* toSymFloatNodeImplUnowned() const {
    return ptr_.get();
  }

  // Rvalue-only: hands back the raw pointer produced by releasing `ptr_`.
  // NOTE(review): presumably the caller becomes responsible for the returned
  // pointer's reference -- confirm against intrusive_ptr::release().
  SymFloatNodeImpl* release() && {
    return std::move(ptr_).release();
  }

  // Conversions to and from the node representation; defined out of line.
  SymFloatNode toSymFloatNodeImpl() const;
  static c10::SymFloat toSymFloat(SymFloatNode sin);

  // Returns the stored double, raising via TORCH_CHECK if this value is
  // symbolic. Use as_float_unchecked() to skip the check.
  double expect_float() const {
    TORCH_CHECK(
        !is_symbolic(),
        "SymFloat::expect_float() called on a symbolic value; "
        "a concrete double was expected");
    return data_;
  }

  // Arithmetic operators; declared here and defined out of line. Operands are
  // taken by value.
  SymFloat operator+(SymFloat) const;
  SymFloat operator-(SymFloat) const;
  SymFloat operator*(SymFloat) const;
  SymFloat operator/(SymFloat) const;

  // N.B. It's important to keep this definition in the header,
  // as we expect `if` checks on it to be folded for mobile builds,
  // where `is_symbolic` is always false.
  C10_ALWAYS_INLINE bool is_symbolic() const {
    return ptr_;
  }

  // Returns `data_` without checking is_symbolic().
  double as_float_unchecked() const {
    return data_;
  }

 private:
  // TODO: optimize to union
  double data_;
  SymFloatNode ptr_;
};
// Stream-insertion operator for SymFloat. Only declared here; the output
// format is determined by the out-of-line definition. `s` is taken by value.
C10_API std::ostream& operator<<(std::ostream& os, SymFloat s);
} // namespace c10
|