/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef ML_DTYPES_MXFLOAT_H_
#define ML_DTYPES_MXFLOAT_H_
// Microscaling (MX) floating point formats, as described in
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
//
// Note: this implements the underlying raw data types (e.g. E2M1FN), not the
// composite types (e.g. MXFP4).
#include <cstdint>
#include <limits>
#include "ml_dtypes/include/float8.h"
#include "Eigen/Core"
namespace ml_dtypes {
namespace mxfloat_internal {
// Use 8-bit storage for 6-bit and 4-bit types.
//
// CRTP base shared by the 6-bit MX formats. It layers on top of
// float8_internal::float8_base<Derived> and only re-defines the operations
// whose bit positions depend on the value being kBits wide: the zero test and
// unary negation.
template <typename Derived>
class mxfloat6_base : public float8_internal::float8_base<Derived> {
  friend class float8_internal::float8_base<Derived>;
  using Base = float8_internal::float8_base<Derived>;
  using Base::Base;

 public:
  // Number of meaningful bits inside the stored representation.
  static constexpr int kBits = 6;

  // True when any bit below the sign bit is set, so a representation that
  // carries only the sign bit still converts to false.
  explicit EIGEN_DEVICE_FUNC operator bool() const {
    constexpr int kMagnitudeMask = (1 << (kBits - 1)) - 1;  // 0x1F
    return (Base::rep() & kMagnitudeMask) != 0;
  }

  // Unary minus: toggles the sign bit (bit kBits - 1) and leaves the other
  // bits of the representation untouched.
  constexpr Derived operator-() const {
    constexpr int kSignBit = 1 << (kBits - 1);  // 0x20
    return Derived::FromRep(Base::rep() ^ kSignBit);
  }

  // Binary minus: forwards to the base-class operator, which the unary
  // overload declared above would otherwise hide.
  Derived operator-(const Derived& other) const {
    return Base::operator-(other);
  }
};
// CRTP base shared by the 4-bit MX formats. It layers on top of
// float8_internal::float8_base<Derived> and only re-defines the operations
// whose bit positions depend on the value being kBits wide: the zero test and
// unary negation.
template <typename Derived>
class mxfloat4_base : public float8_internal::float8_base<Derived> {
  friend class float8_internal::float8_base<Derived>;
  using Base = float8_internal::float8_base<Derived>;
  using Base::Base;

 public:
  // Number of meaningful bits inside the stored representation.
  static constexpr int kBits = 4;

  // True when any bit below the sign bit is set, so a representation that
  // carries only the sign bit still converts to false.
  explicit EIGEN_DEVICE_FUNC operator bool() const {
    constexpr int kMagnitudeMask = (1 << (kBits - 1)) - 1;  // 0x07
    return (Base::rep() & kMagnitudeMask) != 0;
  }

  // Unary minus: toggles the sign bit (bit kBits - 1) and leaves the other
  // bits of the representation untouched.
  constexpr Derived operator-() const {
    constexpr int kSignBit = 1 << (kBits - 1);  // 0x08
    return Derived::FromRep(Base::rep() ^ kSignBit);
  }

  // Binary minus: forwards to the base-class operator, which the unary
  // overload declared above would otherwise hide.
  Derived operator-(const Derived& other) const {
    return Base::operator-(other);
  }
};
// 6-bit float. Exponent: 2, Mantissa: 3, bias: 1.
// Extended range: no inf, no NaN.
class float6_e2m3fn : public mxfloat6_base<float6_e2m3fn> {
  friend class float8_internal::float8_base<float6_e2m3fn>;
  using Base = mxfloat6_base<float6_e2m3fn>;
  using Base::Base;

 public:
  // Explicit conversion from another type accepted by
  // RequiresIsDerivedFromFloat8Base: the argument goes through ConvertFrom and
  // the result is handed to a delegated constructor of this class.
  template <typename Float8,
            float8_internal::RequiresIsDerivedFromFloat8Base<Float8> = 0>
  explicit EIGEN_DEVICE_FUNC float6_e2m3fn(Float8 src)
      : float6_e2m3fn(ConvertFrom(src)) {}
};
// 6-bit float. Exponent: 3, Mantissa: 2, bias: 3.
// Extended range: no inf, no NaN.
class float6_e3m2fn : public mxfloat6_base<float6_e3m2fn> {
  friend class float8_internal::float8_base<float6_e3m2fn>;
  using Base = mxfloat6_base<float6_e3m2fn>;
  using Base::Base;

 public:
  // Explicit conversion from another type accepted by
  // RequiresIsDerivedFromFloat8Base: the argument goes through ConvertFrom and
  // the result is handed to a delegated constructor of this class.
  template <typename Float8,
            float8_internal::RequiresIsDerivedFromFloat8Base<Float8> = 0>
  explicit EIGEN_DEVICE_FUNC float6_e3m2fn(Float8 src)
      : float6_e3m2fn(ConvertFrom(src)) {}
};
// 4-bit float. Exponent: 2, Mantissa: 1, bias: 1.
// Extended range: no inf, no NaN.
class float4_e2m1fn : public mxfloat4_base<float4_e2m1fn> {
  friend class float8_internal::float8_base<float4_e2m1fn>;
  using Base = mxfloat4_base<float4_e2m1fn>;
  using Base::Base;

 public:
  // Explicit conversion from another type accepted by
  // RequiresIsDerivedFromFloat8Base: the argument goes through ConvertFrom and
  // the result is handed to a delegated constructor of this class.
  template <typename Float8,
            float8_internal::RequiresIsDerivedFromFloat8Base<Float8> = 0>
  explicit EIGEN_DEVICE_FUNC float4_e2m1fn(Float8 src)
      : float4_e2m1fn(ConvertFrom(src)) {}
};
// Common properties for specializing std::numeric_limits.
//
// E is the number of exponent bits and M the number of mantissa bits. The
// per-type structs derive from this template and add the value-returning
// members (min(), max(), epsilon(), ...).
template <int E, int M>
struct numeric_limits_mxfloat_tpl {
 protected:
  static constexpr int kMantissaBits = M;
  static constexpr int kExponentBias = (1 << (E - 1)) - 1;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.

  // Type classification.
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;

  // Special values: none of them are available.
  static constexpr bool has_infinity = false;
  static constexpr bool has_quiet_NaN = false;
  static constexpr bool has_signaling_NaN = false;
  // The denorm members are only emitted when compiling below C++23.
#if !defined(__cplusplus) || __cplusplus < 202302L
  static constexpr std::float_denorm_style has_denorm = std::denorm_present;
  static constexpr bool has_denorm_loss = false;
#endif

  // Rounding and arithmetic model.
  static constexpr std::float_round_style round_style = std::round_to_nearest;
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;

  // Precision: the stored mantissa bits plus one.
  static constexpr int digits = kMantissaBits + 1;
  static constexpr int digits10 = float8_internal::Digits10FromDigits(digits);
  static constexpr int max_digits10 =
      float8_internal::MaxDigits10FromDigits(digits);
  static constexpr int radix = std::numeric_limits<float>::radix;

  // Exponent range, derived from the bias.
  static constexpr int min_exponent = 2 - kExponentBias;
  static constexpr int min_exponent10 =
      float8_internal::MinExponent10FromMinExponent(min_exponent);
  static constexpr int max_exponent = kExponentBias + 2;
  static constexpr int max_exponent10 =
      float8_internal::MaxExponent10FromMaxExponentAndDigits(max_exponent,
                                                             digits);

  // Copied from float.
  static constexpr bool traps = std::numeric_limits<float>::traps;
  static constexpr bool tinyness_before =
      std::numeric_limits<float>::tinyness_before;

  // NOLINTEND
};
// Value-returning std::numeric_limits members for float6_e2m3fn.
// Bit patterns are written as sign'exponent'mantissa (1 + 2 + 3 bits); the
// comment above each member gives the value that pattern encodes.
struct numeric_limits_float6_e2m3fn : public numeric_limits_mxfloat_tpl<2, 3> {
  // 1.0 * 2^(0) = 1
  static constexpr float6_e2m3fn min() {
    return float6_e2m3fn::FromRep(0b0'01'000);
  }
  // -1.875 * 2^(2) = -7.5
  static constexpr float6_e2m3fn lowest() {
    return float6_e2m3fn::FromRep(0b1'11'111);
  }
  // 1.875 * 2^(2) = 7.5
  static constexpr float6_e2m3fn max() {
    return float6_e2m3fn::FromRep(0b0'11'111);
  }
  // 0.125 * 2^(0) = 0.125
  static constexpr float6_e2m3fn epsilon() {
    return float6_e2m3fn::FromRep(0b0'00'001);
  }
  // 0.5 * 2^(0) = 0.5
  // round_style is round_to_nearest, for which std::numeric_limits defines
  // the maximum rounding error as 0.5 (it was previously 0.25).
  static constexpr float6_e2m3fn round_error() {
    return float6_e2m3fn::FromRep(0b0'00'100);
  }
  // 0.125 * 2^(0) = 0.125
  static constexpr float6_e2m3fn denorm_min() {
    return float6_e2m3fn::FromRep(0b0'00'001);
  }
  // Conversion from NaNs is implementation-defined (by MX specification).
  // Both NaN accessors return the pattern with only the sign bit set.
  static constexpr float6_e2m3fn quiet_NaN() {
    return float6_e2m3fn::FromRep(0b1'00'000);
  }
  static constexpr float6_e2m3fn signaling_NaN() {
    return float6_e2m3fn::FromRep(0b1'00'000);
  }
  // has_infinity is false; this returns the same pattern as max().
  static constexpr float6_e2m3fn infinity() {
    return float6_e2m3fn::FromRep(0b0'11'111);
  }
};
// Value-returning std::numeric_limits members for float6_e3m2fn.
// Bit patterns are written as sign'exponent'mantissa (1 + 3 + 2 bits); the
// comment above each member gives the value that pattern encodes.
struct numeric_limits_float6_e3m2fn : public numeric_limits_mxfloat_tpl<3, 2> {
  // 1.0 * 2^(-2) = 0.25
  static constexpr float6_e3m2fn min() {
    return float6_e3m2fn::FromRep(0b0'001'00);
  }
  // -1.75 * 2^(4) = -28
  static constexpr float6_e3m2fn lowest() {
    return float6_e3m2fn::FromRep(0b1'111'11);
  }
  // 1.75 * 2^(4) = 28
  static constexpr float6_e3m2fn max() {
    return float6_e3m2fn::FromRep(0b0'111'11);
  }
  // 1.0 * 2^(-2) = 0.25
  static constexpr float6_e3m2fn epsilon() {
    return float6_e3m2fn::FromRep(0b0'001'00);
  }
  // 1.0 * 2^(-1) = 0.5
  // round_style is round_to_nearest, for which std::numeric_limits defines
  // the maximum rounding error as 0.5 (it was previously 1.0).
  static constexpr float6_e3m2fn round_error() {
    return float6_e3m2fn::FromRep(0b0'010'00);
  }
  // 0.25 * 2^(-2) = 0.0625
  static constexpr float6_e3m2fn denorm_min() {
    return float6_e3m2fn::FromRep(0b0'000'01);
  }
  // Conversion from NaNs is implementation-defined (by MX specification).
  // Both NaN accessors return the pattern with only the sign bit set.
  static constexpr float6_e3m2fn quiet_NaN() {
    return float6_e3m2fn::FromRep(0b1'000'00);
  }
  static constexpr float6_e3m2fn signaling_NaN() {
    return float6_e3m2fn::FromRep(0b1'000'00);
  }
  // has_infinity is false; this returns the same pattern as max().
  static constexpr float6_e3m2fn infinity() {
    return float6_e3m2fn::FromRep(0b0'111'11);
  }
};
// Value-returning std::numeric_limits members for float4_e2m1fn.
// Bit patterns are written as sign'exponent'mantissa (1 + 2 + 1 bits); the
// comment above each member gives the value that pattern encodes.
struct numeric_limits_float4_e2m1fn : public numeric_limits_mxfloat_tpl<2, 1> {
  // 1.0 * 2^(0) = 1
  static constexpr float4_e2m1fn min() {
    return float4_e2m1fn::FromRep(0b0'01'0);
  }
  // -1.5 * 2^(2) = -6
  static constexpr float4_e2m1fn lowest() {
    return float4_e2m1fn::FromRep(0b1'11'1);
  }
  // 1.5 * 2^(2) = 6
  static constexpr float4_e2m1fn max() {
    return float4_e2m1fn::FromRep(0b0'11'1);
  }
  // 0.5 * 2^(0) = 0.5
  static constexpr float4_e2m1fn epsilon() {
    return float4_e2m1fn::FromRep(0b0'00'1);
  }
  // 0.5 * 2^(0) = 0.5
  // round_style is round_to_nearest, for which std::numeric_limits defines
  // the maximum rounding error as 0.5 (it was previously 1.0).
  static constexpr float4_e2m1fn round_error() {
    return float4_e2m1fn::FromRep(0b0'00'1);
  }
  // 0.5 * 2^(0) = 0.5
  static constexpr float4_e2m1fn denorm_min() {
    return float4_e2m1fn::FromRep(0b0'00'1);
  }
  // Conversion from NaNs is implementation-defined (by MX specification).
  // Both NaN accessors return the pattern with only the sign bit set.
  static constexpr float4_e2m1fn quiet_NaN() {
    return float4_e2m1fn::FromRep(0b1'00'0);
  }
  static constexpr float4_e2m1fn signaling_NaN() {
    return float4_e2m1fn::FromRep(0b1'00'0);
  }
  // has_infinity is false; this returns the same pattern as max().
  static constexpr float4_e2m1fn infinity() {
    return float4_e2m1fn::FromRep(0b0'11'1);
  }
};
// Free-functions for use with ADL and in Eigen.

// Absolute value: keeps the exponent and mantissa bits of the representation
// and clears every bit from the sign bit upwards.
constexpr inline float6_e2m3fn abs(const float6_e2m3fn& a) {
  constexpr int kMagnitudeBits = 0b0'11'111;
  return float6_e2m3fn::FromRep(a.rep() & kMagnitudeBits);
}
// Always false. The name is parenthesized, presumably so that a function-like
// `isnan` macro cannot expand here (TODO confirm).
constexpr inline bool(isnan)(const float6_e2m3fn&) { return false; }
// Absolute value: keeps the exponent and mantissa bits of the representation
// and clears every bit from the sign bit upwards.
constexpr inline float6_e3m2fn abs(const float6_e3m2fn& a) {
  constexpr int kMagnitudeBits = 0b0'111'11;
  return float6_e3m2fn::FromRep(a.rep() & kMagnitudeBits);
}
// Always false. The name is parenthesized, presumably so that a function-like
// `isnan` macro cannot expand here (TODO confirm).
constexpr inline bool(isnan)(const float6_e3m2fn&) { return false; }
// Absolute value: keeps the exponent and mantissa bits of the representation
// and clears every bit from the sign bit upwards.
constexpr inline float4_e2m1fn abs(const float4_e2m1fn& a) {
  constexpr int kMagnitudeBits = 0b0'11'1;
  return float4_e2m1fn::FromRep(a.rep() & kMagnitudeBits);
}
// Always false. The name is parenthesized, presumably so that a function-like
// `isnan` macro cannot expand here (TODO confirm).
constexpr inline bool(isnan)(const float4_e2m1fn&) { return false; }
// Define traits required for floating point conversion.
//
// T is the MX type being described, E its exponent width and M its mantissa
// width (both in bits). The constants are declared on top of whatever
// float8_internal::TraitsBase<T> already provides.
template <typename T, int E, int M>
struct TraitsBase : public float8_internal::TraitsBase<T> {
  // Field widths.
  static constexpr int kExponentBits = E;
  static constexpr int kMantissaBits = M;
  // One sign bit plus the two fields.
  static constexpr int kBits = 1 + E + M;
  static constexpr int kExponentBias = (1 << (E - 1)) - 1;
  // E one-bits shifted to sit directly above the mantissa field.
  static constexpr uint8_t kExponentMask = ((1 << E) - 1) << M;
};
} // namespace mxfloat_internal
// Exported types.
using float6_e2m3fn = mxfloat_internal::float6_e2m3fn;
using float6_e3m2fn = mxfloat_internal::float6_e3m2fn;
using float4_e2m1fn = mxfloat_internal::float4_e2m1fn;
} // namespace ml_dtypes
// Standard library overrides.
namespace std {

// Each specialization adds nothing of its own: every member is inherited from
// the matching ml_dtypes::mxfloat_internal::numeric_limits_* struct.
template <>
struct numeric_limits<ml_dtypes::float6_e2m3fn>
    : ml_dtypes::mxfloat_internal::numeric_limits_float6_e2m3fn {};

template <>
struct numeric_limits<ml_dtypes::float6_e3m2fn>
    : ml_dtypes::mxfloat_internal::numeric_limits_float6_e3m2fn {};

template <>
struct numeric_limits<ml_dtypes::float4_e2m1fn>
    : ml_dtypes::mxfloat_internal::numeric_limits_float4_e2m1fn {};

}  // namespace std
// Conversion traits.
namespace ml_dtypes {
namespace float8_internal {

// Traits<T> for each MX type, instantiated from mxfloat_internal::TraitsBase
// with that type's <exponent bits, mantissa bits>.
template <>
struct Traits<float6_e2m3fn>
    : mxfloat_internal::TraitsBase<float6_e2m3fn, 2, 3> {};

template <>
struct Traits<float6_e3m2fn>
    : mxfloat_internal::TraitsBase<float6_e3m2fn, 3, 2> {};

template <>
struct Traits<float4_e2m1fn>
    : mxfloat_internal::TraitsBase<float4_e2m1fn, 2, 1> {};

}  // namespace float8_internal
}  // namespace ml_dtypes
// Eigen library overrides.
namespace Eigen {
namespace numext {

// Defines signbit(const Type&) for one MX type.
//
// The stored byte is shifted left by (8 - Type::kBits) so that the type's sign
// bit lands in bit 7 of an int8_t; shifting that right by 7 then spreads it
// over the whole byte (this relies on the signed right shift being
// arithmetic). The resulting byte is bit_cast back to Type rather than
// returned as a bool.
//
// Comments are kept outside the macro body so the line continuations stay
// intact.
#define MXFLOAT_EIGEN_SIGNBIT_IMPL(Type)                                      \
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Type signbit(const Type& x) {         \
    const int8_t sign_in_msb = bit_cast<int8_t, Type>(x)                      \
                               << (8 - Type::kBits);                          \
    return bit_cast<Type, int8_t>(sign_in_msb >> 7);                          \
  }

MXFLOAT_EIGEN_SIGNBIT_IMPL(ml_dtypes::float6_e2m3fn)
MXFLOAT_EIGEN_SIGNBIT_IMPL(ml_dtypes::float6_e3m2fn)
MXFLOAT_EIGEN_SIGNBIT_IMPL(ml_dtypes::float4_e2m1fn)

#undef MXFLOAT_EIGEN_SIGNBIT_IMPL

}  // namespace numext
// Work-around for isinf/isnan/isfinite issue on aarch64.
namespace internal {

// Emits explicit specializations of isfinite_impl, isinf_impl and isnan_impl
// for one MX type. All three ignore their argument and return a constant:
// every value is reported as finite, never as inf and never as NaN.
#define MXFLOAT_EIGEN_ISFINITE_IMPL(Type)                            \
  template <>                                                        \
  EIGEN_DEVICE_FUNC inline bool isfinite_impl<Type>(const Type&) {   \
    return true;                                                     \
  }                                                                  \
  template <>                                                        \
  EIGEN_DEVICE_FUNC inline bool isinf_impl<Type>(const Type&) {      \
    return false;                                                    \
  }                                                                  \
  template <>                                                        \
  EIGEN_DEVICE_FUNC inline bool isnan_impl<Type>(const Type&) {      \
    return false;                                                    \
  }

MXFLOAT_EIGEN_ISFINITE_IMPL(ml_dtypes::float6_e2m3fn)
MXFLOAT_EIGEN_ISFINITE_IMPL(ml_dtypes::float6_e3m2fn)
MXFLOAT_EIGEN_ISFINITE_IMPL(ml_dtypes::float4_e2m1fn)

#undef MXFLOAT_EIGEN_ISFINITE_IMPL

}  // namespace internal
} // namespace Eigen
#endif // ML_DTYPES_MXFLOAT_H_