File: float8.h

package info (click to toggle)
ml-dtypes 0.5.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,768 kB
  • sloc: ansic: 48,160; cpp: 26,737; python: 2,344; pascal: 514; makefile: 15
file content (1849 lines) | stat: -rw-r--r-- 68,005 bytes parent folder | download | duplicates (7)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef ML_DTYPES_FLOAT8_H_
#define ML_DTYPES_FLOAT8_H_

// 8-bit Floating Point Interchange Format, as described by
//   https://arxiv.org/abs/2209.05433
//   https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
//   https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>
#include <utility>

#ifdef __has_include
#if __has_include(<version>)
#include <version>
#endif
#endif

#if (defined(__cpp_lib_bitops) && __cpp_lib_bitops >= 201907L)
#include <bit>
#endif

#include "Eigen/Core"

namespace ml_dtypes {
namespace float8_internal {

// Forward-declarations of classes.
class float8_e3m4;
class float8_e4m3;
class float8_e4m3fn;
class float8_e4m3fnuz;
class float8_e4m3b11fnuz;
class float8_e5m2;
class float8_e5m2fnuz;
class float8_e8m0fnu;

// CRTP base class shared by the 8-bit floating point types in this header.
//
// `Derived` is the concrete float8 type. The value is stored as a single raw
// byte (`rep_`). Conversions to and from wider types go through the static
// `ConvertFrom` / `ConvertTo` templates declared below; their definitions are
// not part of this class. Arithmetic is performed in `float` and the result is
// converted back to `Derived`.
template <typename Derived>
class float8_base {
 protected:
  // Constructor tag to allow constexpr construction from bit representation.
  // The tag keeps this raw-bits constructor from being confused with the
  // value-converting constructors below.
  struct ConstructFromRepTag {};
  constexpr float8_base(uint8_t rep, ConstructFromRepTag) : rep_{rep} {}

 public:
  // Number of bits in the representation.
  static constexpr int kBits = 8;
  // Default construction yields the all-zero bit pattern.
  constexpr float8_base() : rep_(0) {}

  // Value conversions. Integers are first widened to float. Every overload
  // calls ConvertFrom with its default template arguments
  // (kSaturate = false, kTruncate = false) and keeps the resulting bits.
  template <typename T>
  explicit EIGEN_DEVICE_FUNC float8_base(
      T i, std::enable_if_t<std::is_integral_v<T>, int> = 0)
      : float8_base(ConvertFrom(static_cast<float>(i)).rep(),
                    ConstructFromRepTag{}) {}
  template <typename T>
  explicit EIGEN_DEVICE_FUNC float8_base(
      T f, std::enable_if_t<std::is_floating_point_v<T>, int> = 0)
      : float8_base(ConvertFrom(f).rep(), ConstructFromRepTag{}) {}
  explicit EIGEN_DEVICE_FUNC float8_base(Eigen::bfloat16 bf16)
      : float8_base(ConvertFrom(bf16).rep(), ConstructFromRepTag{}) {}
  explicit EIGEN_DEVICE_FUNC float8_base(Eigen::half f16)
      : float8_base(ConvertFrom(f16).rep(), ConstructFromRepTag{}) {}

  // Returns the raw 8-bit encoding.
  constexpr uint8_t rep() const { return rep_; }

  // Explicit conversion to other types goes through float.
  // NOTE(review): the default template argument is `std::enable_if<...>`, not
  // `std::enable_if_t<...>`. It therefore always names a valid type and does
  // not restrict T to arithmetic types. Confirm whether the constraint is
  // intended before tightening it, because callers may rely on the looser form.
  template <typename T,
            typename EnableIf = std::enable_if<std::is_arithmetic_v<T>>>
  explicit EIGEN_DEVICE_FUNC operator T() const {
    return static_cast<T>(static_cast<float>(derived()));
  }
  // Widening conversions. The float, bfloat16 and half conversions are
  // implicit; the double conversion is explicit.
  explicit EIGEN_DEVICE_FUNC operator double() const {
    return ConvertTo<double>(derived());
  }
  EIGEN_DEVICE_FUNC operator float() const {
    return ConvertTo<float>(derived());
  }
  EIGEN_DEVICE_FUNC operator Eigen::bfloat16() const {
    return ConvertTo<Eigen::bfloat16>(derived());
  }
  EIGEN_DEVICE_FUNC operator Eigen::half() const {
    return ConvertTo<Eigen::half>(derived());
  }
  // False only when every bit except the sign bit (0x80) is zero. Some derived
  // types in this header (e.g. the fnuz variants) redefine this.
  explicit EIGEN_DEVICE_FUNC operator bool() const {
    return (rep() & 0x7F) != 0;
  }

  // Negation by flipping the sign bit. Derived types whose encoding gives
  // 0x80 another meaning redefine this.
  constexpr Derived operator-() const {
    return Derived(static_cast<uint8_t>(rep() ^ 0x80), ConstructFromRepTag{});
  }

  // CRTP accessors to the concrete derived object.
  constexpr const Derived& derived() const {
    return *static_cast<const Derived*>(this);
  }

  constexpr Derived& derived() { return *static_cast<Derived*>(this); }

  // Builds a value directly from its raw bit pattern, with no conversion.
  static constexpr Derived FromRep(uint8_t rep) {
    return Derived(rep, ConstructFromRepTag{});
  }

  // Conversions allowing saturation and truncation.
  // The kSaturate / kTruncate flags select the variant. The definitions live
  // outside this class; presumably saturation clamps out-of-range inputs to the
  // finite range and truncation rounds toward zero. Confirm against the
  // definitions.
  template <bool kSaturate = false, bool kTruncate = false, typename From>
  static inline EIGEN_DEVICE_FUNC Derived ConvertFrom(From from);

  template <typename To, bool kSaturate = false, bool kTruncate = false>
  static inline EIGEN_DEVICE_FUNC To ConvertTo(Derived from);

  // Operators via float32.
  // Both operands are widened to float, and the float result is converted back
  // to Derived through the explicit floating-point constructor.
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
  operator+(const Derived& other) const {
    return Derived{float{derived()} + float{other}};
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
  operator-(const Derived& other) const {
    return Derived{float{derived()} - float{other}};
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
  operator*(const Derived& other) const {
    return Derived{float{derived()} * float{other}};
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
  operator/(const Derived& other) const {
    return Derived{float{derived()} / float{other}};
  }

  // Comparisons are built on Compare() below. When either side is NaN the
  // result is kUnordered, so ==, <, <=, > and >= return false and != returns
  // true.
  constexpr bool operator==(const Derived& other) const {
    return Compare(derived(), other) == Ordering::kEquivalent;
  }

  constexpr bool operator!=(const Derived& other) const {
    return Compare(derived(), other) != Ordering::kEquivalent;
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(
      const Derived& other) const {
    return Compare(derived(), other) == Ordering::kLess;
  }

  // Relies on kGreater and kUnordered both being numerically greater than
  // kEquivalent.
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(
      const Derived& other) const {
    return Compare(derived(), other) <= Ordering::kEquivalent;
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(
      const Derived& other) const {
    return Compare(derived(), other) == Ordering::kGreater;
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(
      const Derived& other) const {
    Ordering ordering = Compare(derived(), other);
    return ordering == Ordering::kGreater || ordering == Ordering::kEquivalent;
  }

  // Compound assignment.
  // Each operator evaluates the corresponding binary operator above and stores
  // the result into *this.
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator+=(
      const Derived& other) {
    derived() = derived() + other;
    return derived();
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator-=(
      const Derived& other) {
    derived() = derived() - other;
    return derived();
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator*=(
      const Derived& other) {
    derived() = derived() * other;
    return derived();
  }

  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived& operator/=(
      const Derived& other) {
    derived() = derived() / other;
    return derived();
  }

 private:
  // Splits x into (sign, magnitude) bytes. The magnitude is the bit pattern of
  // |x| (via Eigen::numext::abs). The sign is made of the bits that differ
  // between x and |x|, shifted left by CHAR_BIT - kBits (a shift of 0 when
  // kBits == 8).
  static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC std::pair<uint8_t, uint8_t>
  SignAndMagnitude(Derived x) {
    const uint8_t x_abs_bits =
        Eigen::numext::bit_cast<uint8_t>(Eigen::numext::abs(x));
    const uint8_t x_bits = Eigen::numext::bit_cast<uint8_t>(x);
    const uint8_t x_sign = (x_bits ^ x_abs_bits) << (CHAR_BIT - Derived::kBits);
    return {x_sign, x_abs_bits};
  }
  // Maps a sign-magnitude pair to a signed key whose integer ordering matches
  // the numeric ordering. When the sign byte's top bit is set, the magnitude is
  // inverted (m ^ -1 == -m - 1), so larger magnitudes give smaller keys.
  static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int8_t
  SignAndMagnitudeToTwosComplement(uint8_t sign, uint8_t magnitude) {
    return magnitude ^ (static_cast<int8_t>(sign) < 0 ? -1 : 0);
  }

  // Result of Compare(). The numeric values are relied on by operator<=.
  enum Ordering : int8_t {
    kLess = -1,
    kEquivalent = 0,
    kGreater = 1,
    kUnordered = 2,
  };

  // Three-way comparison:
  //   * NaN on either side gives kUnordered.
  //   * Two zero magnitudes (+0 vs -0) compare as kEquivalent.
  //   * Otherwise the two's-complement keys are compared.
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC friend Ordering Compare(
      const Derived& lhs, const Derived& rhs) {
    if (Eigen::numext::isnan(lhs) || Eigen::numext::isnan(rhs)) {
      return Ordering::kUnordered;
    }
    auto [lhs_sign, lhs_mag] = SignAndMagnitude(lhs);
    auto [rhs_sign, rhs_mag] = SignAndMagnitude(rhs);
    if (lhs_mag == 0 && rhs_mag == 0) {
      return Ordering::kEquivalent;
    }
    int8_t lhs_twos_complement =
        SignAndMagnitudeToTwosComplement(lhs_sign, lhs_mag);
    int8_t rhs_twos_complement =
        SignAndMagnitudeToTwosComplement(rhs_sign, rhs_mag);
    if (lhs_twos_complement < rhs_twos_complement) {
      return Ordering::kLess;
    }
    if (lhs_twos_complement > rhs_twos_complement) {
      return Ordering::kGreater;
    }
    return Ordering::kEquivalent;
  }

  // Raw 8-bit encoding of the value.
  uint8_t rep_;
};

template <typename T>
using RequiresIsDerivedFromFloat8Base =
    std::enable_if_t<std::is_base_of_v<float8_base<T>, T>, int>;

class float8_e3m4 : public float8_base<float8_e3m4> {
  // IEEE 754-style layout: 1 sign bit, 3 exponent bits (bias 3) and
  // 4 mantissa bits. Infinities and NaNs follow the IEEE encoding.
 private:
  using Base = float8_base<float8_e3m4>;
  friend class float8_base<float8_e3m4>;
  using Base::Base;

 public:
  // Converts from any other float8 type built on float8_base.
  template <typename Float8, RequiresIsDerivedFromFloat8Base<Float8> = 0>
  explicit EIGEN_DEVICE_FUNC float8_e3m4(Float8 other)
      : float8_e3m4(Base::ConvertFrom(other)) {}
};

class float8_e4m3 : public float8_base<float8_e4m3> {
  // IEEE 754-style layout: 1 sign bit, 4 exponent bits (bias 7) and
  // 3 mantissa bits. Infinities and NaNs follow the IEEE encoding.
 private:
  using Base = float8_base<float8_e4m3>;
  friend class float8_base<float8_e4m3>;
  using Base::Base;

 public:
  // Converts from any other float8 type built on float8_base.
  template <typename Float8, RequiresIsDerivedFromFloat8Base<Float8> = 0>
  explicit EIGEN_DEVICE_FUNC float8_e4m3(Float8 other)
      : float8_e4m3(Base::ConvertFrom(other)) {}
};

class float8_e4m3fn : public float8_base<float8_e4m3fn> {
  // Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
  // The range is extended relative to IEEE: there are no infinities, and the
  // only NaN encodings are 0bS111'1111.
  // The "fn" suffix matches the corresponding LLVM/MLIR type name and marks the
  // deviation from IEEE-754. "f" means finite values only; "n" means NaNs exist,
  // but only at the outer end of the range.
 private:
  using Base = float8_base<float8_e4m3fn>;
  friend class float8_base<float8_e4m3fn>;
  using Base::Base;

 public:
  // Converts from any other float8 type built on float8_base.
  template <typename Float8, RequiresIsDerivedFromFloat8Base<Float8> = 0>
  explicit EIGEN_DEVICE_FUNC float8_e4m3fn(Float8 other)
      : float8_e4m3fn(Base::ConvertFrom(other)) {}
};

class float8_e4m3b11fnuz : public float8_base<float8_e4m3b11fnuz> {
  // Exponent: 4, Mantissa: 3, bias: 11.
  // Extended range: no inf. NaN is represented by 0b1000'0000 (the bit pattern
  // that would otherwise be negative zero), so the only zero is 0b0000'0000.
 private:
  using Base = float8_base<float8_e4m3b11fnuz>;
  friend class float8_base<float8_e4m3b11fnuz>;
  using Base::Base;

 public:
  // Converts from any other float8 type built on float8_base.
  template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
  explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(T f8)
      : float8_e4m3b11fnuz(ConvertFrom(f8)) {}

  // Negation. Zero (0x00) and NaN (0x80) are returned unchanged, because
  // flipping the sign bit would turn one into the other.
  constexpr float8_e4m3b11fnuz operator-() const {
    if ((rep() & 0x7f) == 0x00) {
      return *this;
    }
    return Base::operator-();
  }

  // Declaring unary operator- above hides every Base::operator- overload, so
  // binary subtraction is re-exposed here. It carries the same
  // EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC annotations as the base operator it
  // forwards to, so the forwarding wrapper is as device-callable as the
  // operator it wraps.
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float8_e4m3b11fnuz
  operator-(const float8_e4m3b11fnuz& other) const {
    return Base::operator-(other);
  }

  // Only the all-zero pattern is false; NaN (0x80) converts to true.
  explicit EIGEN_DEVICE_FUNC operator bool() const { return rep() != 0; }
};

// Legacy name used in XLA (TODO(jewillco): remove).
using float8_e4m3b11 = float8_e4m3b11fnuz;

class float8_e4m3fnuz : public float8_base<float8_e4m3fnuz> {
  // 8-bit floating point with 3 bit mantissa.
  //
  // An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
  // mantissa. The suffix "fnuz" is consistent with LLVM/MLIR naming and is
  // derived from the differences to IEEE floating point conventions. `F` is
  // for "finite" (no infinities), `N` for a special NaN encoding, and `UZ` for
  // unsigned zero.
  //
  // This type has the following characteristics:
  // * bit encoding: S1E4M3 - `0bSEEEEMMM`
  // * exponent bias: 8
  // * infinities: Not supported
  // * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits
  // set to all 0s - `0b10000000`
  // * denormals when exponent is 0
 private:
  using Base = float8_base<float8_e4m3fnuz>;
  friend class float8_base<float8_e4m3fnuz>;
  using Base::Base;

 public:
  // Converts from any other float8 type built on float8_base.
  template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
  explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(T f8)
      : float8_e4m3fnuz(ConvertFrom(f8)) {}

  // Negation. Zero (0x00) and NaN (0x80) are returned unchanged, because
  // flipping the sign bit would turn one into the other.
  constexpr float8_e4m3fnuz operator-() const {
    if ((rep() & 0x7f) == 0x00) {
      return *this;
    }
    return Base::operator-();
  }

  // Declaring unary operator- above hides every Base::operator- overload, so
  // binary subtraction is re-exposed here. It carries the same
  // EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC annotations as the base operator it
  // forwards to, so the forwarding wrapper is as device-callable as the
  // operator it wraps.
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float8_e4m3fnuz
  operator-(const float8_e4m3fnuz& other) const {
    return Base::operator-(other);
  }

  // Only the all-zero pattern is false; NaN (0x80) converts to true.
  explicit EIGEN_DEVICE_FUNC operator bool() const { return rep() != 0; }
};

class float8_e5m2 : public float8_base<float8_e5m2> {
  // IEEE 754-style layout: 1 sign bit, 5 exponent bits (bias 15) and
  // 2 mantissa bits. Infinities and NaNs follow the IEEE encoding.
 private:
  using Base = float8_base<float8_e5m2>;
  friend class float8_base<float8_e5m2>;
  using Base::Base;

 public:
  // Converts from any other float8 type built on float8_base.
  template <typename Float8, RequiresIsDerivedFromFloat8Base<Float8> = 0>
  explicit EIGEN_DEVICE_FUNC float8_e5m2(Float8 other)
      : float8_e5m2(Base::ConvertFrom(other)) {}
};

class float8_e5m2fnuz : public float8_base<float8_e5m2fnuz> {
  // 8-bit floating point with 2 bit mantissa.
  //
  // An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
  // mantissa. The suffix "fnuz" is consistent with LLVM/MLIR naming and is
  // derived from the differences to IEEE floating point conventions. `F` is
  // for "finite" (no infinities), `N` for a special NaN encoding, and `UZ` for
  // unsigned zero.
  //
  // This type has the following characteristics:
  // * bit encoding: S1E5M2 - `0bSEEEEEMM`
  // * exponent bias: 16
  // * infinities: Not supported
  // * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits
  // set to all 0s - `0b10000000`
  // * denormals when exponent is 0
 private:
  using Base = float8_base<float8_e5m2fnuz>;
  friend class float8_base<float8_e5m2fnuz>;
  using Base::Base;

 public:
  // Converts from any other float8 type built on float8_base.
  template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
  explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(T f8)
      : float8_e5m2fnuz(ConvertFrom(f8)) {}

  // Negation. Zero (0x00) and NaN (0x80) are returned unchanged, because
  // flipping the sign bit would turn one into the other.
  constexpr float8_e5m2fnuz operator-() const {
    if ((rep() & 0x7f) == 0x00) {
      return *this;
    }
    return Base::operator-();
  }

  // Declaring unary operator- above hides every Base::operator- overload, so
  // binary subtraction is re-exposed here. It carries the same
  // EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC annotations as the base operator it
  // forwards to, so the forwarding wrapper is as device-callable as the
  // operator it wraps.
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float8_e5m2fnuz
  operator-(const float8_e5m2fnuz& other) const {
    return Base::operator-(other);
  }

  // Only the all-zero pattern is false; NaN (0x80) converts to true.
  explicit EIGEN_DEVICE_FUNC operator bool() const { return rep() != 0; }
};

class float8_e8m0fnu : public float8_base<float8_e8m0fnu> {
  // 8-bit floating point with 8 bit exponent, no sign and zero mantissa.
  //
  // See:
  // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
  //
  // An 8-bit floating point type with no sign bit, 8 bits exponent and 0 bits
  // mantissa. The suffix "fnu" is consistent with LLVM/MLIR naming and is
  // derived from the differences to IEEE floating point conventions. `F` is
  // for "finite" (no infinities), `N` for a special NaN encoding, and `U` for
  // unsigned.
  //
  // This type has the following characteristics:
  // * bit encoding: S0E8M0 - `0bEEEEEEEE`
  // * exponent bias: 127
  // * infinities: Not supported
  // * NaNs: Supported with exponent bits set to 1s - `0b11111111`
  // * zero: Not representable
 private:
  using Base = float8_base<float8_e8m0fnu>;
  friend class float8_base<float8_e8m0fnu>;
  using Base::Base;

 public:
  // Converts from any other float8 type built on float8_base.
  template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
  explicit EIGEN_DEVICE_FUNC float8_e8m0fnu(T f8)
      : float8_e8m0fnu(ConvertFrom(f8)) {}

  constexpr float8_e8m0fnu operator-() const {
    // No negative numbers supported in E8M0 => NaN
    return float8_e8m0fnu::FromRep(0xFF);
  }

  // Declaring unary operator- above hides every Base::operator- overload, so
  // binary subtraction is re-exposed here. It carries the same
  // EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC annotations as the base operator it
  // forwards to, so the forwarding wrapper is as device-callable as the
  // operator it wraps.
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float8_e8m0fnu
  operator-(const float8_e8m0fnu& other) const {
    return Base::operator-(other);
  }

  explicit EIGEN_DEVICE_FUNC operator bool() const {
    // No zero supported in E8M0 format.
    return true;
  }

  // Comparison simplified to uint8_t compare.
  // With no sign bit and no mantissa, the encoding grows monotonically with
  // the value. Once NaN (which makes every ordered comparison false) is
  // excluded, comparing the raw bytes is therefore enough.
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(
      const float8_e8m0fnu& other) const {
    if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) {
      return false;
    }
    return rep() < other.rep();
  }
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(
      const float8_e8m0fnu& other) const {
    if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) {
      return false;
    }
    return rep() <= other.rep();
  }
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(
      const float8_e8m0fnu& other) const {
    if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) {
      return false;
    }
    return rep() > other.rep();
  }
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(
      const float8_e8m0fnu& other) const {
    if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) {
      return false;
    }
    return rep() >= other.rep();
  }
};

// |x| usable in constant expressions (std::abs is not constexpr in C++17).
// NaN and -0.0 are returned unchanged.
constexpr double ConstexprAbs(double x) {
  if (x < 0.0) {
    return -x;
  }
  return x;
}

// ceil(x) usable in constant expressions.
// Two kinds of input are returned unchanged:
//   * values with magnitude >= 2^52, since every such double is already an
//     integer;
//   * NaN, since it fails every ordered comparison.
// Inputs in (-1, 0], including -0.0, produce +0.0.
constexpr double ConstexprCeil(double x) {
  constexpr double kIntegerThreshold = static_cast<double>(
      uint64_t{1} << (std::numeric_limits<double>::digits - 1));
  const bool fits_in_int64 = ConstexprAbs(x) < kIntegerThreshold;
  if (!fits_in_int64) {
    return x;
  }
  // Truncate toward zero, then step up by one if truncation moved x downward.
  const double truncated = static_cast<double>(static_cast<int64_t>(x));
  if (truncated < x) {
    return truncated + 1.0;
  }
  return truncated;
}

// floor(x) usable in constant expressions, defined as -ceil(-x).
// Inputs in [0, 1) produce -0.0, which compares equal to 0.0.
constexpr double ConstexprFloor(double x) {
  const double ceil_of_negation = ConstexprCeil(-x);
  return -ceil_of_negation;
}

constexpr double kLog10Of2 = 0.3010299956639812;

// digits10 per C17 5.2.4.2.2p11: the largest q such that any decimal number
// with q significant digits survives a round trip through a p-bit binary
// significand. Computed as floor((p - 1) * log10(2)).
constexpr int Digits10FromDigits(int digits) {
  const double scaled = (digits - 1) * kLog10Of2;
  return static_cast<int>(ConstexprFloor(scaled));
}

// max_digits10 per C17 5.2.4.2.2p11: the number of decimal digits n needed so
// that every p-bit value survives a round trip through n decimal digits.
// Computed as ceil(1 + p * log10(2)).
constexpr int MaxDigits10FromDigits(int digits) {
  const double bits_as_decimal = digits * kLog10Of2;
  return static_cast<int>(ConstexprCeil(1.0 + bits_as_decimal));
}

// min_exponent10 per C17 5.2.4.2.2p11: the most negative integer k such that
// 10^k is a normalized value. Computed as
// ceil(log10(2^(emin - 1))) == ceil((emin - 1) * log10(2)).
constexpr int MinExponent10FromMinExponent(int min_exponent) {
  const double log10_of_min_normal = (min_exponent - 1) * kLog10Of2;
  return static_cast<int>(ConstexprCeil(log10_of_min_normal));
}

// max_exponent10 per C17 5.2.4.2.2p11: the largest integer k such that 10^k is
// a finite representable value. Computed as
// floor(log10((1 - 2^-p) * 2^emax)) == floor(log10(1 - 2^-p) + emax*log10(2)).
constexpr int MaxExponent10FromMaxExponentAndDigits(int max_exponent,
                                                    int digits) {
  // log10(1 - 2^-p), indexed by p - 1. Only p in {1, 2, 3, 4, 5} is supported;
  // extend this table to handle wider significands.
  constexpr double kLog10OfOneMinusHalfUlp[] = {
      -0.3010299956639812,    // p = 1
      -0.12493873660829993,   // p = 2
      -0.057991946977686754,  // p = 3
      -0.028028723600243537,  // p = 4
      -0.013788284485633295,  // p = 5
  };
  const int table_index = digits - 1;
  const double log10_of_max =
      kLog10OfOneMinusHalfUlp[table_index] + max_exponent * kLog10Of2;
  return static_cast<int>(ConstexprFloor(log10_of_max));
}

// Structures for use in specializing std::numeric_limits.
// Traits shared by every float8 std::numeric_limits specialization. The
// per-format structs derive from this and add their format-specific members.
// (In C++17, static constexpr data members are implicitly inline.)
struct numeric_limits_float8_base {
  // NOLINTBEGIN: these names must match std::numeric_limits.
  // Classification of the type.
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  // Special values and subnormals.
  static constexpr bool has_quiet_NaN = true;
// has_denorm and has_denorm_loss are deprecated in C++23, so they are only
// declared for earlier language modes.
#if !defined(__cplusplus) || __cplusplus < 202302L
  static constexpr std::float_denorm_style has_denorm = std::denorm_present;
  static constexpr bool has_denorm_loss = false;
#endif
  // Rounding, radix and trapping behavior, mirrored from float.
  static constexpr std::float_round_style round_style = std::round_to_nearest;
  static constexpr int radix = std::numeric_limits<float>::radix;
  static constexpr bool traps = std::numeric_limits<float>::traps;
  static constexpr bool tinyness_before =
      std::numeric_limits<float>::tinyness_before;
  // NOLINTEND
};

// std::numeric_limits values for float8_e3m4 (1 sign bit, 3 exponent bits with
// bias 3, 4 mantissa bits, IEEE-style inf/NaN).
struct numeric_limits_float8_e3m4 : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 3;
  static inline constexpr const int kMantissaBits = 4;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  // Significand precision in bits, including the implicit leading bit.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  // std::numeric_limits expresses exponents for significands in [0.5, 1), so
  // min_exponent / max_exponent are one greater than the IEEE emin / emax.
  // * The smallest normal biased exponent is 1, giving emin = 1 - bias.
  // * The largest finite biased exponent is 0b110 (0b111 is reserved for
  //   inf/NaN), giving emax + 1 = 0b111 - bias.
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  static inline constexpr const int max_exponent = 0b111 - kExponentBias;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = true;
  static inline constexpr const bool has_infinity = true;
  static inline constexpr const bool has_signaling_NaN = true;
  // NOLINTEND

  // 1.0 * 2^(0b001 - 3) = 1.0 * 2^-2 = 1/4 (min normal)
  static constexpr float8_e3m4 min() {
    return float8_e3m4::FromRep(1 << kMantissaBits);
  }
  // -(1 + 0b1111 * 2^-4) * 2^(0b110 - 3) = -(1 + 15/16) * 2^3 = -15.5
  static constexpr float8_e3m4 lowest() {
    return float8_e3m4::FromRep(0b1'110'1111);
  }
  // (1 + 0b1111 * 2^-4) * 2^(0b110 - 3) = (1 + 15/16) * 2^3 = 15.5
  static constexpr float8_e3m4 max() {
    return float8_e3m4::FromRep(0b0'110'1111);
  }
  // (1 + 1/16) * 2^0 - 1.0 = 1.0 + 1/16 - 1.0 = 1/16
  // 1/16 is below min() = 1/4, so it is encoded as the denormal
  // 0b0'000'0100 = (4/16) * 2^-2 = 1/16.
  static constexpr float8_e3m4 epsilon() {
    return float8_e3m4::FromRep(0b0'000'0100);
  }
  // 1.0 * 2^-1 = 0.5 (biased exponent -1 + 3 = 2, mantissa 0)
  static constexpr float8_e3m4 round_error() {
    return float8_e3m4::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // +inf: all exponent bits set, mantissa zero.
  static constexpr float8_e3m4 infinity() {
    return float8_e3m4::FromRep(0b0'111'0000);
  }
  static constexpr float8_e3m4 quiet_NaN() {
    // IEEE 754-2019 6.2.1: "All binary NaN bit strings have the sign bit S set
    // to 0 or 1 and all the bits of the biased exponent field E set to 1
    // (see 3.4). A quiet NaN bit string should be encoded with the first bit
    // (d1) of the trailing significand field T being 1."
    return float8_e3m4::FromRep(0b0'111'1000);
  }
  static constexpr float8_e3m4 signaling_NaN() {
    // IEEE 754-2019 6.2.1: "A signaling NaN bit string should be encoded with
    // the first bit of the trailing significand field being 0."
    return float8_e3m4::FromRep(0b0'111'0100);
  }
  // 2^(-2) * 2^(-4) = 2^-6 = 1/64 (min denormal)
  static constexpr float8_e3m4 denorm_min() {
    return float8_e3m4::FromRep(0b0'000'0001);
  }
};

// std::numeric_limits values for float8_e4m3: an IEEE-754-style format with
// 1 sign bit, 4 exponent bits (bias 7) and 3 mantissa bits. The all-ones
// exponent encodes infinity (mantissa 0) or NaN (mantissa != 0).
struct numeric_limits_float8_e4m3 : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 7;
  static inline constexpr const int kMantissaBits = 3;

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  // The largest finite value has biased exponent 0b1110 (2^7), so
  // max_exponent is 7 + 1 == 0b1111 - kExponentBias == 8.
  static inline constexpr const int max_exponent = 0b1111 - kExponentBias;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = true;
  static inline constexpr const bool has_infinity = true;
  static inline constexpr const bool has_signaling_NaN = true;
  // NOLINTEND

  // 1.0 * 2^(0b0001 - 7) = 1.0 * 2^-6 = 1/64 (min normal)
  static constexpr float8_e4m3 min() {
    return float8_e4m3::FromRep(1 << kMantissaBits);
  }
  // -(1 + 0b111 * 2^-3) * 2^(0b1110 - 7) = -(1 + 7/8) * 2^7 = -240
  static constexpr float8_e4m3 lowest() {
    return float8_e4m3::FromRep(0b1'1110'111);
  }
  // (1 + 0b111 * 2^-3) * 2^(0b1110 - 7) = (1 + 7/8) * 2^7 = 240
  static constexpr float8_e4m3 max() {
    return float8_e4m3::FromRep(0b0'1110'111);
  }
  // 2^-kMantissaBits = 2^-3 = 0.125: the gap between 1.0 and the next value,
  // encoded as a normal number with biased exponent 7 - 3 = 4.
  static constexpr float8_e4m3 epsilon() {
    return float8_e4m3::FromRep((-kMantissaBits + kExponentBias)
                                << kMantissaBits);
  }
  // 1.0 * 2^-1 = 0.5
  static constexpr float8_e4m3 round_error() {
    return float8_e4m3::FromRep((-1 + kExponentBias) << kMantissaBits);
  }
  // All-ones exponent, zero mantissa.
  static constexpr float8_e4m3 infinity() {
    return float8_e4m3::FromRep(0b0'1111'000);
  }
  static constexpr float8_e4m3 quiet_NaN() {
    // IEEE 754-2019 6.2.1: "All binary NaN bit strings have the sign bit S set
    // to 0 or 1 and all the bits of the biased exponent field E set to 1
    // (see 3.4). A quiet NaN bit string should be encoded with the first bit
    // (d1) of the trailing significand field T being 1."
    return float8_e4m3::FromRep(0b0'1111'100);
  }
  static constexpr float8_e4m3 signaling_NaN() {
    // IEEE 754-2019 6.2.1: "A signaling NaN bit string should be encoded with
    // the first bit of the trailing significand field being 0."
    return float8_e4m3::FromRep(0b0'1111'001);
  }
  // 2^(1 - 7) * 2^-3 = 2^-9 = 1/512 (min denormal)
  static constexpr float8_e4m3 denorm_min() {
    return float8_e4m3::FromRep(0b0'0000'001);
  }
};

// std::numeric_limits values for float8_e4m3fn: 1 sign bit, 4 exponent bits
// (bias 7), 3 mantissa bits. "fn" = finite + NaN: there is no infinity, the
// all-ones exponent still holds finite values, and only S'1111'111 is NaN.
struct numeric_limits_float8_e4m3fn : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 7;
  static inline constexpr const int kMantissaBits = 3;

  // Assembles a raw encoding from its sign (bit 7), biased exponent and
  // mantissa fields.
  static constexpr int EncodeRep(int sign, int biased_exponent, int mantissa) {
    return (sign << 7) | (biased_exponent << kMantissaBits) | mantissa;
  }

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  // Extended format: the all-ones exponent is usable for finite values.
  static inline constexpr const int max_exponent =
      (0b1111 - kExponentBias) + 1;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // Smallest positive normal: 2^(1 - 7) = 2^-6 = 0.015625.
  static constexpr float8_e4m3fn min() {
    return float8_e4m3fn::FromRep(EncodeRep(0, 0b0001, 0));
  }
  // Most negative finite: -(1 + 6/8) * 2^(15 - 7) = -1.75 * 2^8 = -448.
  // Mantissa 0b111 at this exponent is NaN, hence 0b110.
  static constexpr float8_e4m3fn lowest() {
    return float8_e4m3fn::FromRep(EncodeRep(1, 0b1111, 0b110));
  }
  // Largest finite: (1 + 6/8) * 2^8 = 448.
  static constexpr float8_e4m3fn max() {
    return float8_e4m3fn::FromRep(EncodeRep(0, 0b1111, 0b110));
  }
  // Gap between 1.0 and the next value: 2^-kMantissaBits = 2^-3 = 0.125.
  static constexpr float8_e4m3fn epsilon() {
    return float8_e4m3fn::FromRep(
        EncodeRep(0, kExponentBias - kMantissaBits, 0));
  }
  // Maximum rounding error: 2^-1 = 0.5.
  static constexpr float8_e4m3fn round_error() {
    return float8_e4m3fn::FromRep(EncodeRep(0, kExponentBias - 1, 0));
  }
  // No infinity exists (has_infinity == false); the NaN encoding is returned.
  static constexpr float8_e4m3fn infinity() {
    return float8_e4m3fn::FromRep(EncodeRep(0, 0b1111, 0b111));
  }
  // The NaN encoding: every exponent and mantissa bit set.
  static constexpr float8_e4m3fn quiet_NaN() {
    return float8_e4m3fn::FromRep(EncodeRep(0, 0b1111, 0b111));
  }
  // No signaling NaN exists; same encoding as quiet_NaN().
  static constexpr float8_e4m3fn signaling_NaN() {
    return float8_e4m3fn::FromRep(EncodeRep(0, 0b1111, 0b111));
  }
  // Smallest positive subnormal: 2^(1 - 7 - 3) = 2^-9 = 0.001953125.
  static constexpr float8_e4m3fn denorm_min() {
    return float8_e4m3fn::FromRep(EncodeRep(0, 0, 1));
  }
};

// std::numeric_limits values for float8_e4m3b11fnuz: 1 sign bit, 4 exponent
// bits with bias 11, 3 mantissa bits. "fnuz" = finite, NaN, unsigned zero:
// there is no infinity and no negative zero; the would-be -0 pattern 0x80 is
// the only NaN.
struct numeric_limits_float8_e4m3b11fnuz : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 11;
  static inline constexpr const int kMantissaBits = 3;

  // Assembles a raw encoding from its sign (bit 7), biased exponent and
  // mantissa fields.
  static constexpr int EncodeRep(int sign, int biased_exponent, int mantissa) {
    return (sign << 7) | (biased_exponent << kMantissaBits) | mantissa;
  }

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  // Extended format: the all-ones exponent is usable for finite values.
  static inline constexpr const int max_exponent =
      (0b1111 - kExponentBias) + 1;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // Smallest positive normal: 2^(1 - 11) = 2^-10 = 0.0009765625.
  static constexpr float8_e4m3b11fnuz min() {
    return float8_e4m3b11fnuz::FromRep(EncodeRep(0, 0b0001, 0));
  }
  // Most negative finite: -(1 + 7/8) * 2^(15 - 11) = -1.875 * 2^4 = -30.
  static constexpr float8_e4m3b11fnuz lowest() {
    return float8_e4m3b11fnuz::FromRep(EncodeRep(1, 0b1111, 0b111));
  }
  // Largest finite: (1 + 7/8) * 2^4 = 30.
  static constexpr float8_e4m3b11fnuz max() {
    return float8_e4m3b11fnuz::FromRep(EncodeRep(0, 0b1111, 0b111));
  }
  // Gap between 1.0 and the next value: 2^-3 = 0.125.
  static constexpr float8_e4m3b11fnuz epsilon() {
    return float8_e4m3b11fnuz::FromRep(
        EncodeRep(0, kExponentBias - kMantissaBits, 0));
  }
  // Maximum rounding error: 2^-1 = 0.5.
  static constexpr float8_e4m3b11fnuz round_error() {
    return float8_e4m3b11fnuz::FromRep(EncodeRep(0, kExponentBias - 1, 0));
  }
  // No infinity exists (has_infinity == false); the NaN encoding is returned.
  static constexpr float8_e4m3b11fnuz infinity() {
    return float8_e4m3b11fnuz::FromRep(EncodeRep(1, 0, 0));
  }
  // The single NaN: sign bit set, everything else zero (0x80).
  static constexpr float8_e4m3b11fnuz quiet_NaN() {
    return float8_e4m3b11fnuz::FromRep(EncodeRep(1, 0, 0));
  }
  // No signaling NaN exists; same encoding as quiet_NaN().
  static constexpr float8_e4m3b11fnuz signaling_NaN() {
    return float8_e4m3b11fnuz::FromRep(EncodeRep(1, 0, 0));
  }
  // Smallest positive subnormal: 2^(1 - 11 - 3) = 2^-13 = 0.0001220703125.
  static constexpr float8_e4m3b11fnuz denorm_min() {
    return float8_e4m3b11fnuz::FromRep(EncodeRep(0, 0, 1));
  }
};

// std::numeric_limits values for float8_e4m3fnuz: 1 sign bit, 4 exponent bits
// with bias 8, 3 mantissa bits. "fnuz" = finite, NaN, unsigned zero: there is
// no infinity and no negative zero; the would-be -0 pattern 0x80 is the only
// NaN.
struct numeric_limits_float8_e4m3fnuz : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 8;
  static inline constexpr const int kMantissaBits = 3;

  // Assembles a raw encoding from its sign (bit 7), biased exponent and
  // mantissa fields.
  static constexpr int EncodeRep(int sign, int biased_exponent, int mantissa) {
    return (sign << 7) | (biased_exponent << kMantissaBits) | mantissa;
  }

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  // Extended format: the all-ones exponent is usable for finite values.
  static inline constexpr const int max_exponent =
      (0b1111 - kExponentBias) + 1;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // Smallest positive normal: 2^(1 - 8) = 2^-7 = 0.0078125 (0x08).
  static constexpr float8_e4m3fnuz min() {
    return float8_e4m3fnuz::FromRep(EncodeRep(0, 0b0001, 0));
  }
  // Most negative finite: -(1 + 7/8) * 2^(15 - 8) = -240 (0xFF).
  static constexpr float8_e4m3fnuz lowest() {
    return float8_e4m3fnuz::FromRep(EncodeRep(1, 0b1111, 0b111));
  }
  // Largest finite: (1 + 7/8) * 2^7 = 240 (0x7F).
  static constexpr float8_e4m3fnuz max() {
    return float8_e4m3fnuz::FromRep(EncodeRep(0, 0b1111, 0b111));
  }
  // Gap between 1.0 and the next value: 2^-3 = 0.125.
  static constexpr float8_e4m3fnuz epsilon() {
    return float8_e4m3fnuz::FromRep(
        EncodeRep(0, kExponentBias - kMantissaBits, 0));
  }
  // Maximum rounding error: 2^-1 = 0.5.
  static constexpr float8_e4m3fnuz round_error() {
    return float8_e4m3fnuz::FromRep(EncodeRep(0, kExponentBias - 1, 0));
  }
  // No infinity exists (has_infinity == false); the NaN encoding is returned.
  static constexpr float8_e4m3fnuz infinity() {
    return float8_e4m3fnuz::FromRep(EncodeRep(1, 0, 0));
  }
  // The single NaN: sign bit set, everything else zero (0x80).
  static constexpr float8_e4m3fnuz quiet_NaN() {
    return float8_e4m3fnuz::FromRep(EncodeRep(1, 0, 0));
  }
  // No signaling NaN exists; same encoding as quiet_NaN().
  static constexpr float8_e4m3fnuz signaling_NaN() {
    return float8_e4m3fnuz::FromRep(EncodeRep(1, 0, 0));
  }
  // Smallest positive subnormal: 2^(1 - 8 - 3) = 2^-10 (0x01).
  static constexpr float8_e4m3fnuz denorm_min() {
    return float8_e4m3fnuz::FromRep(EncodeRep(0, 0, 1));
  }
};

// std::numeric_limits values for float8_e5m2: an IEEE-754-style format with
// 1 sign bit, 5 exponent bits (bias 15) and 2 mantissa bits. The all-ones
// exponent encodes infinity (mantissa 0) or NaN (mantissa != 0).
struct numeric_limits_float8_e5m2 : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 15;
  static inline constexpr const int kMantissaBits = 2;

  // Assembles a raw encoding from its sign (bit 7), biased exponent and
  // mantissa fields.
  static constexpr int EncodeRep(int sign, int biased_exponent, int mantissa) {
    return (sign << 7) | (biased_exponent << kMantissaBits) | mantissa;
  }

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  static inline constexpr const int max_exponent = 0b11111 - kExponentBias;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = true;
  static inline constexpr const bool has_infinity = true;
  static inline constexpr const bool has_signaling_NaN = true;
  // NOLINTEND

  // Smallest positive normal: 2^(1 - 15) = 2^-14 = 0.00006103515625.
  static constexpr float8_e5m2 min() {
    return float8_e5m2::FromRep(EncodeRep(0, 0b00001, 0));
  }
  // Most negative finite: -(1 + 3/4) * 2^(30 - 15) = -1.75 * 2^15 = -57344.
  static constexpr float8_e5m2 lowest() {
    return float8_e5m2::FromRep(EncodeRep(1, 0b11110, 0b11));
  }
  // Largest finite: (1 + 3/4) * 2^15 = 57344.
  static constexpr float8_e5m2 max() {
    return float8_e5m2::FromRep(EncodeRep(0, 0b11110, 0b11));
  }
  // Gap between 1.0 and the next value: 2^-kMantissaBits = 2^-2 = 0.25.
  static constexpr float8_e5m2 epsilon() {
    return float8_e5m2::FromRep(
        EncodeRep(0, kExponentBias - kMantissaBits, 0));
  }
  // Maximum rounding error: 2^-1 = 0.5.
  static constexpr float8_e5m2 round_error() {
    return float8_e5m2::FromRep(EncodeRep(0, kExponentBias - 1, 0));
  }
  // All-ones exponent, zero mantissa.
  static constexpr float8_e5m2 infinity() {
    return float8_e5m2::FromRep(EncodeRep(0, 0b11111, 0b00));
  }
  // IEEE 754-2019 6.2.1: NaNs have an all-ones exponent; a quiet NaN should
  // have the first trailing-significand bit set.
  static constexpr float8_e5m2 quiet_NaN() {
    return float8_e5m2::FromRep(EncodeRep(0, 0b11111, 0b10));
  }
  // IEEE 754-2019 6.2.1: a signaling NaN should have the first
  // trailing-significand bit clear (and a non-zero payload).
  static constexpr float8_e5m2 signaling_NaN() {
    return float8_e5m2::FromRep(EncodeRep(0, 0b11111, 0b01));
  }
  // Smallest positive subnormal: 2^(1 - 15 - 2) = 2^-16 = 0.0000152587890625.
  static constexpr float8_e5m2 denorm_min() {
    return float8_e5m2::FromRep(EncodeRep(0, 0, 1));
  }
};

// std::numeric_limits values for float8_e5m2fnuz: 1 sign bit, 5 exponent bits
// with bias 16, 2 mantissa bits. "fnuz" = finite, NaN, unsigned zero: there is
// no infinity and no negative zero; the would-be -0 pattern 0x80 is the only
// NaN.
struct numeric_limits_float8_e5m2fnuz : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 16;
  static inline constexpr const int kMantissaBits = 2;

  // Assembles a raw encoding from its sign (bit 7), biased exponent and
  // mantissa fields.
  static constexpr int EncodeRep(int sign, int biased_exponent, int mantissa) {
    return (sign << 7) | (biased_exponent << kMantissaBits) | mantissa;
  }

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  // Extended format: the all-ones exponent is usable for finite values.
  static inline constexpr const int max_exponent =
      (0b11111 - kExponentBias) + 1;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // Smallest positive normal: 2^(1 - 16) = 2^-15 (0x04).
  static constexpr float8_e5m2fnuz min() {
    return float8_e5m2fnuz::FromRep(EncodeRep(0, 0b00001, 0));
  }
  // Most negative finite: -(1 + 3/4) * 2^(31 - 16) = -57344 (0xFF).
  static constexpr float8_e5m2fnuz lowest() {
    return float8_e5m2fnuz::FromRep(EncodeRep(1, 0b11111, 0b11));
  }
  // Largest finite: (1 + 3/4) * 2^15 = 57344 (0x7F).
  static constexpr float8_e5m2fnuz max() {
    return float8_e5m2fnuz::FromRep(EncodeRep(0, 0b11111, 0b11));
  }
  // Gap between 1.0 and the next value: 2^-2 = 0.25.
  static constexpr float8_e5m2fnuz epsilon() {
    return float8_e5m2fnuz::FromRep(
        EncodeRep(0, kExponentBias - kMantissaBits, 0));
  }
  // Maximum rounding error: 2^-1 = 0.5.
  static constexpr float8_e5m2fnuz round_error() {
    return float8_e5m2fnuz::FromRep(EncodeRep(0, kExponentBias - 1, 0));
  }
  // No infinity exists (has_infinity == false); the NaN encoding is returned.
  static constexpr float8_e5m2fnuz infinity() {
    return float8_e5m2fnuz::FromRep(EncodeRep(1, 0, 0));
  }
  // The single NaN: sign bit set, everything else zero (0x80).
  static constexpr float8_e5m2fnuz quiet_NaN() {
    return float8_e5m2fnuz::FromRep(EncodeRep(1, 0, 0));
  }
  // No signaling NaN exists; same encoding as quiet_NaN().
  static constexpr float8_e5m2fnuz signaling_NaN() {
    return float8_e5m2fnuz::FromRep(EncodeRep(1, 0, 0));
  }
  // Smallest positive subnormal: 2^(1 - 16 - 2) = 2^-17 (0x01).
  static constexpr float8_e5m2fnuz denorm_min() {
    return float8_e5m2fnuz::FromRep(EncodeRep(0, 0, 1));
  }
};

// std::numeric_limits values for float8_e8m0fnu (OCP MX E8M0): an unsigned,
// exponent-only format with bias 127. Byte b (b != 0xFF) encodes 2^(b - 127);
// 0xFF is NaN. There is no zero, no infinity and no subnormal range.
struct numeric_limits_float8_e8m0fnu : public numeric_limits_float8_base {
 private:
  static inline constexpr const int kExponentBias = 127;
  static inline constexpr const int kMantissaBits = 0;

  // Encodes 2^exponent; meaningful for exponent in [-127, 127].
  static constexpr float8_e8m0fnu PowerOfTwo(int exponent) {
    return float8_e8m0fnu::FromRep((exponent + kExponentBias)
                                   << kMantissaBits);
  }

 public:
  // NOLINTBEGIN: these names must match std::numeric_limits.
  static inline constexpr const bool is_signed = false;
// has_denorm and has_denorm_loss are deprecated in C++23.
#if !defined(__cplusplus) || __cplusplus < 202302L
  static inline constexpr const std::float_denorm_style has_denorm =
      std::denorm_absent;
#endif
  static inline constexpr const int digits = kMantissaBits + 1;
  static inline constexpr const int digits10 = Digits10FromDigits(digits);
  static inline constexpr const int max_digits10 =
      MaxDigits10FromDigits(digits);
  // 2^-127 (encoding 0x00) is the smallest valid value and counts as normal.
  static inline constexpr const int min_exponent = -kExponentBias + 1;
  static inline constexpr const int min_exponent10 =
      MinExponent10FromMinExponent(min_exponent);
  // Encoding 0xFF (which would be 2^128) is reserved for NaN, so the largest
  // value is 2^127 (0xFE).
  static inline constexpr const int max_exponent = kExponentBias + 1;
  static inline constexpr const int max_exponent10 =
      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
  static inline constexpr const bool is_iec559 = false;
  static inline constexpr const bool has_infinity = false;
  static inline constexpr const bool has_signaling_NaN = false;
  // NOLINTEND

  // Smallest value: 2^-127 (0x00).
  static constexpr float8_e8m0fnu min() { return PowerOfTwo(-kExponentBias); }
  // The format is unsigned, so lowest() coincides with min().
  static constexpr float8_e8m0fnu lowest() {
    return PowerOfTwo(-kExponentBias);
  }
  // Largest value: 2^127 (0xFE).
  static constexpr float8_e8m0fnu max() { return PowerOfTwo(kExponentBias); }
  // The value after 1.0 is 2.0, so epsilon is 2^0 = 1.
  static constexpr float8_e8m0fnu epsilon() {
    return PowerOfTwo(-kMantissaBits);
  }
  // 2^-1 = 0.5.
  static constexpr float8_e8m0fnu round_error() { return PowerOfTwo(-1); }
  // No infinity exists (has_infinity == false); the NaN encoding is returned.
  static constexpr float8_e8m0fnu infinity() {
    return float8_e8m0fnu::FromRep(0xFF);
  }
  // The single NaN: all bits set.
  static constexpr float8_e8m0fnu quiet_NaN() {
    return float8_e8m0fnu::FromRep(0xFF);
  }
  // No signaling NaN exists; same encoding as quiet_NaN().
  static constexpr float8_e8m0fnu signaling_NaN() {
    return float8_e8m0fnu::FromRep(0xFF);
  }
  // No subnormals: the smallest value 2^-127 doubles as denorm_min().
  static constexpr float8_e8m0fnu denorm_min() {
    return PowerOfTwo(-kExponentBias);
  }
};

}  // namespace float8_internal
}  // namespace ml_dtypes

namespace std {
// Standard-library overrides.  Note that these are picked up by Eigen as well.
// Each specialization inherits every member from the matching
// ml_dtypes::float8_internal::numeric_limits_float8_* struct, so those structs
// are the single source of each format's constants (digits, exponent range,
// min/max/epsilon, NaN/infinity encodings). Only explicit specializations like
// these may be added to namespace std.
template <>
struct numeric_limits<ml_dtypes::float8_internal::float8_e3m4>
    : public ml_dtypes::float8_internal::numeric_limits_float8_e3m4 {};

template <>
struct numeric_limits<ml_dtypes::float8_internal::float8_e4m3>
    : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3 {};

template <>
struct numeric_limits<ml_dtypes::float8_internal::float8_e4m3fn>
    : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3fn {};

template <>
struct numeric_limits<ml_dtypes::float8_internal::float8_e4m3b11fnuz>
    : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3b11fnuz {};

template <>
struct numeric_limits<ml_dtypes::float8_internal::float8_e4m3fnuz>
    : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3fnuz {};

template <>
struct numeric_limits<ml_dtypes::float8_internal::float8_e5m2>
    : public ml_dtypes::float8_internal::numeric_limits_float8_e5m2 {};

template <>
struct numeric_limits<ml_dtypes::float8_internal::float8_e5m2fnuz>
    : public ml_dtypes::float8_internal::numeric_limits_float8_e5m2fnuz {};

template <>
struct numeric_limits<ml_dtypes::float8_internal::float8_e8m0fnu>
    : public ml_dtypes::float8_internal::numeric_limits_float8_e8m0fnu {};
}  // namespace std

namespace ml_dtypes {
namespace float8_internal {

// Magnitude of an e3m4 value: clear the sign bit (bit 7) and keep the
// exponent and mantissa fields, so NaN payloads survive.
constexpr inline float8_e3m4 abs(const float8_e3m4& a) {
  constexpr int kMagnitudeMask = 0b0'111'1111;
  return float8_e3m4::FromRep(a.rep() & kMagnitudeMask);
}

// NaN check: any magnitude encoding strictly above +infinity (all-ones
// exponent with a non-zero mantissa) is a NaN.
constexpr inline bool(isnan)(const float8_e3m4& a) {
  const auto infinity_rep = std::numeric_limits<float8_e3m4>::infinity().rep();
  return infinity_rep < abs(a).rep();
}

// Magnitude of an e4m3 value: clear the sign bit (bit 7) and keep the
// exponent and mantissa fields, so NaN payloads survive.
constexpr inline float8_e4m3 abs(const float8_e4m3& a) {
  constexpr int kMagnitudeMask = 0b0'1111'111;
  return float8_e4m3::FromRep(a.rep() & kMagnitudeMask);
}

// NaN check: any magnitude encoding strictly above +infinity (all-ones
// exponent with a non-zero mantissa) is a NaN.
constexpr inline bool(isnan)(const float8_e4m3& a) {
  const auto infinity_rep = std::numeric_limits<float8_e4m3>::infinity().rep();
  return infinity_rep < abs(a).rep();
}

// Free-functions for use with ADL and in Eigen.
// Magnitude of an e4m3fn value: clear the sign bit (bit 7).
constexpr inline float8_e4m3fn abs(const float8_e4m3fn& a) {
  constexpr int kMagnitudeMask = 0b0'1111'111;
  return float8_e4m3fn::FromRep(a.rep() & kMagnitudeMask);
}

// NaN check: once the sign is dropped, only the all-ones pattern (the
// quiet_NaN() encoding) is NaN; e4m3fn has no infinity.
constexpr inline bool(isnan)(const float8_e4m3fn& a) {
  const auto nan_rep = std::numeric_limits<float8_e4m3fn>::quiet_NaN().rep();
  return nan_rep == abs(a).rep();
}

// Magnitude of an e4m3b11fnuz value. 0x80 (the "negative zero" pattern) is
// this format's only NaN, so an input whose non-sign bits are all zero (+0 or
// NaN) is returned untouched; clearing its sign would turn NaN into zero.
constexpr inline float8_e4m3b11fnuz abs(const float8_e4m3b11fnuz& a) {
  const auto magnitude = a.rep() & 0b0'1111'111;
  if (magnitude == 0) {
    return float8_e4m3b11fnuz::FromRep(a.rep());
  }
  return float8_e4m3b11fnuz::FromRep(magnitude);
}

// NaN check: exactly the quiet_NaN() encoding (0x80).
constexpr inline bool(isnan)(const float8_e4m3b11fnuz& a) {
  const auto nan_rep =
      std::numeric_limits<float8_e4m3b11fnuz>::quiet_NaN().rep();
  return nan_rep == a.rep();
}

// Magnitude of an e4m3fnuz value. 0x80 is the only NaN and there is no -0, so
// an input whose non-sign bits are all zero (+0 or NaN) is returned unchanged;
// anything else has its sign bit cleared.
constexpr inline float8_e4m3fnuz abs(const float8_e4m3fnuz& a) {
  const auto magnitude = a.rep() & 0x7F;
  if (magnitude == 0) {
    return float8_e4m3fnuz::FromRep(a.rep());
  }
  return float8_e4m3fnuz::FromRep(magnitude);
}

// NaN check: abs() leaves the NaN encoding intact, so compare its result with
// quiet_NaN().
constexpr inline bool(isnan)(const float8_e4m3fnuz& a) {
  const auto nan_rep = std::numeric_limits<float8_e4m3fnuz>::quiet_NaN().rep();
  return nan_rep == abs(a).rep();
}

// Magnitude of an e5m2 value: clear the sign bit (bit 7) and keep the
// exponent and mantissa fields, so NaN payloads survive.
constexpr inline float8_e5m2 abs(const float8_e5m2& a) {
  constexpr int kMagnitudeMask = 0b0'11111'11;
  return float8_e5m2::FromRep(a.rep() & kMagnitudeMask);
}

// NaN check: any magnitude encoding strictly above +infinity (all-ones
// exponent with a non-zero mantissa) is a NaN.
constexpr inline bool(isnan)(const float8_e5m2& a) {
  const auto infinity_rep = std::numeric_limits<float8_e5m2>::infinity().rep();
  return infinity_rep < abs(a).rep();
}

constexpr inline float8_e5m2fnuz abs(const float8_e5m2fnuz& a) {
  return (a.rep() & 0x7F) == 0 ? float8_e5m2fnuz::FromRep(a.rep())
                               : float8_e5m2fnuz::FromRep(a.rep() & 0x7F);
}

constexpr inline bool(isnan)(const float8_e5m2fnuz& a) {
  return a.rep() == 0x80;
}

// e8m0fnu has no sign bit, so every value is already its own magnitude.
constexpr inline float8_e8m0fnu abs(const float8_e8m0fnu& a) {
  return a;
}

// NaN check: 0xFF (every bit set) is the only NaN encoding.
constexpr inline bool(isnan)(const float8_e8m0fnu& a) {
  constexpr int kNaNRep = 0xff;
  return kNaNRep == a.rep();
}

// Infinity check shared by all float8 types. Formats whose numeric_limits
// report has_infinity == false can never be infinite; for the others, the
// sign-stripped encoding is compared with infinity().
template <typename Float8>
constexpr inline bool(isinf)(const float8_base<Float8>& a) {
  if constexpr (!std::numeric_limits<Float8>::has_infinity) {
    return false;
  } else {
    const auto infinity_rep = std::numeric_limits<Float8>::infinity().rep();
    return infinity_rep == abs(a.derived()).rep();
  }
}

// A float8 value is finite unless it is NaN or infinite. isnan is evaluated
// first; isinf is only consulted for non-NaN inputs.
template <typename Float8>
constexpr inline bool(isfinite)(const float8_base<Float8>& a) {
  return !(isnan(a.derived()) || isinf(a.derived()));
}

// Streams a float8 value by printing its conversion to float.
template <typename Float8>
std::ostream& operator<<(std::ostream& os, const float8_base<Float8>& f8) {
  const float as_float = static_cast<float>(f8.derived());
  return os << as_float;
}

//==============================================================================
// Inline conversion routines between float8 and other types.
//==============================================================================

// Returns true iff `x` has exactly one bit set (1, 2, 4, 8, ...). Subtracting
// one clears the lowest set bit and sets every bit below it, so `x & (x - 1)`
// is zero only when no other bit was set. Intended for non-negative integers.
template <typename T>
constexpr bool IsPowerOfTwo(T x) {
  if (x == 0) {
    return false;
  }
  return (x & (x - 1)) == 0;
}
// Rounds `size` up to the smallest power of two that is >= `size`.
// Non-positive sizes are returned unchanged. Requires size <= 2^30 so the
// doubling below cannot overflow.
constexpr int RoundUpToPowerOfTwo(int size) {
  int power = 1;
  while (power < size) {
    power <<= 1;
  }
  return size <= 0 ? size : power;
}

// Helper for getting a bytes size which is a power of two: `value` is `Size`
// rounded up to the next power of two (e.g. 3 -> 4, 5..8 -> 8, 9..16 -> 16).
// Previously only 3 and 5..7 were special-cased, so sizes such as 9 were
// passed through unchanged despite not being powers of two.
template <int Size>
struct NextPowerOfTwo {
  static constexpr int value = RoundUpToPowerOfTwo(Size);
};

// Unsigned integer type occupying exactly `kByteCount` bytes, as provided by
// Eigen's get_integer_by_size; used as the raw bit representation of a type.
template <int kByteCount>
using GetUnsignedInteger =
    typename Eigen::numext::get_integer_by_size<kByteCount>::unsigned_type;

// Converts between two floating-point types.
//
// Primary template, declared only: conversions are implemented by partial
// specializations. kSaturate selects clamping of out-of-range results instead
// of producing inf/NaN; kTruncate selects dropping excess mantissa bits
// instead of rounding to nearest-even. `EnableIf` is a SFINAE hook, used by
// the generic specialization to exclude From == To.
template <typename From, typename To, bool kSaturate, bool kTruncate,
          typename EnableIf = void>
struct ConvertImpl;

// Same-type "conversion": the value passes through untouched.
template <typename Scalar>
struct IdentityConversion {
  static EIGEN_DEVICE_FUNC inline Scalar run(Scalar value) {
    return value;
  }
};

// ConvertImpl for From == To. It is spelled out for every (kSaturate,
// kTruncate) combination to avoid ambiguities with the other ConvertImpl
// partial specializations (the generic one excludes From == To via enable_if).
template <typename Scalar, bool kSaturate, bool kTruncate>
struct ConvertImpl<Scalar, Scalar, /*kSaturate=*/kSaturate,
                   /*kTruncate=*/kTruncate>
    : public IdentityConversion<Scalar> {};

// Bit-layout description of a floating-point type `Float`, derived from its
// size, signedness and Eigen digit count. Defaults follow the IEEE-754
// convention (bias = 2^(exponent_bits - 1) - 1); Traits<> specializations
// override fields for formats that deviate from it.
template <typename Float>
struct TraitsBase {
  // Unsigned integer with the same width as `Float`.
  using BitsType = GetUnsignedInteger<sizeof(Float)>;
  static constexpr bool kIsSigned = std::numeric_limits<Float>::is_signed;
  // Whether the format has an encoding for zero.
  static constexpr bool kHasZero = true;

  static constexpr int kBits = CHAR_BIT * sizeof(Float);
  // Stored mantissa bits: `digits` counts the implicit leading bit as well.
  static constexpr int kMantissaBits = Eigen::NumTraits<Float>::digits() - 1;
  // Everything that is neither mantissa nor sign is exponent, so an unsigned
  // format gains one extra exponent bit.
  static constexpr int kExponentBits =
      kBits - kMantissaBits - (kIsSigned ? 1 : 0);
  static constexpr BitsType kMantissaMask = (BitsType{1} << kMantissaBits) - 1;
  static constexpr BitsType kExponentMask = ((BitsType{1} << kExponentBits) - 1)
                                            << kMantissaBits;
  static constexpr int kExponentBias = (1 << (kExponentBits - 1)) - 1;
};

// Per-type layout traits used by ConvertImpl; the IEEE-style defaults from
// TraitsBase apply unless a specialization below overrides a field.
template <typename Float>
struct Traits : public TraitsBase<Float> {};

// e4m3b11fnuz: 4 exponent bits, but a bias of 11 instead of the default 7.
template <>
struct Traits<float8_e4m3b11fnuz> : public TraitsBase<float8_e4m3b11fnuz> {
  static constexpr int kExponentBias = 11;
};

// e4m3fnuz: bias is one above the IEEE default (7 + 1 = 8).
template <>
struct Traits<float8_e4m3fnuz> : public TraitsBase<float8_e4m3fnuz> {
  using Base = TraitsBase<float8_e4m3fnuz>;
  static constexpr int kExponentBias = 1 + Base::kExponentBias;
};

// e5m2fnuz: bias is one above the IEEE default (15 + 1 = 16).
template <>
struct Traits<float8_e5m2fnuz> : public TraitsBase<float8_e5m2fnuz> {
  using Base = TraitsBase<float8_e5m2fnuz>;
  static constexpr int kExponentBias = 1 + Base::kExponentBias;
};

// e8m0fnu: the OCP MX E8M0 format description has no encoding for zero.
template <>
struct Traits<float8_e8m0fnu> : public TraitsBase<float8_e8m0fnu> {
  using Base = TraitsBase<float8_e8m0fnu>;
  static constexpr bool kHasZero = false;
};

template <typename Bits>
constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff,
                                             bool use_implicit_bit) {
  // Round to nearest even by adding a bias term.
  // Consider a bit pattern
  //   FFF...FLRTT...T,
  // where bits RTT...T need to be rounded-off.  We add a bias term to the
  // bit pattern s.t. a carry is introduced to round up only if
  // - L is 1, R is 1, OR
  // - L is 0, R is 1, any T is one.
  // We do this by adding L to a bit pattern consisting of all T = 1.
  //
  // When rounding to zero mantissa (E8M0 type), the L bit is implicitly 1 (do
  // not use the exponent bits for rounding). Add only the R bit in this case.
  Bits bias = !use_implicit_bit
                  ? ((bits >> roundoff) & 1) + (Bits{1} << (roundoff - 1)) - 1
                  : Bits{1} << (roundoff - 1);
  return bits + bias;
}

#if (defined(__cpp_lib_bitops) && __cpp_lib_bitops >= 201907L)
using std::countl_zero;
#else
// Portable countl_zero fallback for pre-C++20 standard libraries.
//
// Binary search: while the upper half of the remaining `shift`-sized window
// is non-zero, discard the lower half and credit the skipped bits; stop once a
// single nibble is left and finish with a 16-entry table of nibble leading
// zero counts (index 15 is the literal's terminating '\0').
template <typename UInt>
static constexpr inline int CountlZeroFallback(UInt x, int width) {
  int zeroes = width - 4;
  for (int shift = width / 2; shift >= 4; shift /= 2) {
    if (x >> shift) {
      zeroes -= shift;
      x >>= shift;
    }
  }
  return "\4\3\2\2\1\1\1\1\0\0\0\0\0\0\0"[x] + zeroes;
}
static constexpr inline int countl_zero(uint64_t x) {
  return CountlZeroFallback(x, 64);
}
static constexpr inline int countl_zero(uint32_t x) {
  return CountlZeroFallback(x, 32);
}
static constexpr inline int countl_zero(uint16_t x) {
  return CountlZeroFallback(x, 16);
}
static constexpr inline int countl_zero(uint8_t x) {
  return CountlZeroFallback(x, 8);
}
#endif

// Generic conversion between two distinct floating-point-like types, operating
// directly on their bit patterns.
//
// Stages (each may return early):
//   1. Inf / NaN map to `To`'s infinity() / quiet_NaN(), keeping the sign.
//   2. Zero keeps its sign if `To` has a zero; otherwise it becomes
//      denorm_min() (kSaturate) or NaN.
//   3. Negative inputs to an unsigned `To` become lowest() (kSaturate) or NaN.
//   4. If `To` reaches further towards zero (smaller min_exponent), `From`
//      subnormals are renormalized (or stay subnormal) in `To`.
//   5. If `To` has a larger min_exponent, small `From` values are shifted into
//      `To` subnormals, or become zero when shifted out entirely.
//   6. Otherwise the mantissa is rounded (unless kTruncate), the exponent is
//      re-biased, and values above `To`'s highest() map to highest()
//      (kSaturate) or infinity().
template <typename From, typename To, bool kSaturate, bool kTruncate>
struct ConvertImpl<From, To, kSaturate, kTruncate,
                   std::enable_if_t<!std::is_same_v<From, To>>> {
  using FromTraits = Traits<From>;
  using FromBits = typename FromTraits::BitsType;
  static constexpr bool kFromIsSigned = FromTraits::kIsSigned;
  static constexpr bool kFromHasZero = FromTraits::kHasZero;
  static constexpr int kFromBits = FromTraits::kBits;
  static constexpr int kFromMantissaBits = FromTraits::kMantissaBits;
  static constexpr int kFromExponentBits = FromTraits::kExponentBits;
  static constexpr int kFromExponentBias = FromTraits::kExponentBias;
  static constexpr FromBits kFromExponentMask = FromTraits::kExponentMask;

  using ToTraits = Traits<To>;
  using ToBits = typename ToTraits::BitsType;
  static constexpr bool kToIsSigned = ToTraits::kIsSigned;
  static constexpr bool kToHasZero = ToTraits::kHasZero;
  static constexpr int kToBits = ToTraits::kBits;
  static constexpr int kToMantissaBits = ToTraits::kMantissaBits;
  static constexpr int kToExponentBits = ToTraits::kExponentBits;
  static constexpr int kToExponentBias = ToTraits::kExponentBias;
  static constexpr ToBits kToExponentMask = ToTraits::kExponentMask;

  // `WideBits` is wide enough to accommodate the largest exponent and mantissa
  // in either `From` or `To`.
  static constexpr int kWideBits =
      (std::max(kToMantissaBits, kFromMantissaBits)) +  // Max significand.
      (std::max(kToExponentBits, kFromExponentBits));   // Max exponent.
  static constexpr int kWideBytesRaw = (kWideBits + (CHAR_BIT - 1)) / CHAR_BIT;
  // Need a power of two (i.e. not 3 bytes).
  static constexpr int kWideBytes = NextPowerOfTwo<kWideBytesRaw>::value;

  using WideBits = GetUnsignedInteger<kWideBytes>;
  static_assert(!std::is_void_v<WideBits>,
                "`WideBits` type can not be void type.");

  // Amount to add to a `From` biased exponent to obtain the `To` one.
  static constexpr int kExponentOffset = kToExponentBias - kFromExponentBias;
  // Positive when `To` has more mantissa bits (shift left), negative when the
  // mantissa shrinks (round, then shift right).
  static constexpr int kDigitShift = kToMantissaBits - kFromMantissaBits;

  static EIGEN_DEVICE_FUNC inline To run(From from) {
    // Shift bits to destination type, without sign bit.
    // (`&& kFromIsSigned` forces the sign to false for unsigned sources.)
    const bool from_sign_bit =
        Eigen::numext::bit_cast<FromBits>(from) >> (kFromBits - 1) &&
        kFromIsSigned;
    const FromBits from_bits =
        Eigen::numext::bit_cast<FromBits>(Eigen::numext::abs(from));

    // Special values, preserving sign.
    if (Eigen::numext::isinf(from)) {
      return from_sign_bit ? -Eigen::NumTraits<To>::infinity()
                           : Eigen::NumTraits<To>::infinity();
    }
    if (Eigen::numext::isnan(from)) {
      return from_sign_bit ? -Eigen::NumTraits<To>::quiet_NaN()
                           : Eigen::NumTraits<To>::quiet_NaN();
    }
    // Dealing with zero, when `From` has one.
    if (from_bits == 0 && kFromHasZero) {
      if constexpr (kToHasZero) {
        // Keep the sign, if `To` supports it.
        return from_sign_bit && kToIsSigned ? -To{} : To{};
      } else {
        return kSaturate ? std::numeric_limits<To>::denorm_min()
                         : Eigen::NumTraits<To>::quiet_NaN();
      }
    }
    // `To` unsigned floating format: NaN or saturate.
    if constexpr (!kToIsSigned && kFromIsSigned) {
      if (from_sign_bit) {
        return kSaturate ? std::numeric_limits<To>::lowest()
                         : Eigen::NumTraits<To>::quiet_NaN();
      }
    }

    const int biased_from_exponent = from_bits >> kFromMantissaBits;
    const bool to_zero_mantissa = kToMantissaBits == 0;

    // `To` supports more exponents near zero which means that some subnormal
    // values in `From` may become normal.
    if constexpr (std::numeric_limits<To>::min_exponent <
                  std::numeric_limits<From>::min_exponent) {
      if (biased_from_exponent == 0) {
        // Subnormals.
        WideBits bits = from_bits;

        // Determine exponent in target type.
        // msb = index of the highest set mantissa bit of the subnormal.
        const int msb =
            sizeof(from_bits) * CHAR_BIT - countl_zero(from_bits) - 1;
        const int normalization_factor = kFromMantissaBits - msb;
        const int biased_exponent = kExponentOffset - normalization_factor + 1;
        if (biased_exponent <= 0) {
          // Result is subnormal.  Adjust the subnormal bits to account for
          // the difference in exponent bias.
          if constexpr (kExponentOffset < sizeof(WideBits) * CHAR_BIT) {
            bits <<= kExponentOffset;
          }
        } else {
          // Result is normal. Shift the mantissa to account for the number of
          // leading zero digits, and clear the hidden bit.
          bits <<= normalization_factor;
          bits &= ~(WideBits{1} << kFromMantissaBits);
          // Insert the exponent bits.
          bits |= static_cast<WideBits>(biased_exponent) << kFromMantissaBits;
        }

        // Truncate/round mantissa if necessary.
        if constexpr (kDigitShift >= 0) {
          bits <<= kDigitShift;
        } else {
          if constexpr (!kTruncate) {
            // When converting float to e8m0, the bits represent a denormal,
            // so don't use the implicit mantissa bit for rounding.
            bits = RoundBitsToNearestEven(
                bits, -kDigitShift, to_zero_mantissa && kExponentOffset != 0);
          }
          bits >>= -kDigitShift;
        }
        To to = Eigen::numext::bit_cast<To>(static_cast<ToBits>(bits));
        return from_sign_bit ? -to : to;
      }
    }
    // `To` supports fewer exponents near zero which means that some values in
    // `From` may become subnormal.
    if constexpr (std::numeric_limits<To>::min_exponent >
                  std::numeric_limits<From>::min_exponent) {
      const int unbiased_exponent = biased_from_exponent - kFromExponentBias;
      const int biased_to_exponent = unbiased_exponent + kToExponentBias;
      // Subnormals and zero.
      if (biased_to_exponent <= 0) {
        // Round and shift mantissa down.
        // Zero exponent valid if From has no zero representation.
        FromBits from_has_leading_one =
            (biased_from_exponent > 0 || !kFromHasZero ? 1 : 0);
        int exponent_shift =
            -kDigitShift - biased_to_exponent + from_has_leading_one;
        // Insert the implicit leading 1 bit on the mantissa for normalized
        // inputs.
        FromBits rounded_from_bits =
            (from_bits & FromTraits::kMantissaMask) |
            (from_has_leading_one << kFromMantissaBits);
        // Stays 0 when every significant bit would be shifted out.
        ToBits bits = 0;
        if (exponent_shift > 0) {
          // To avoid UB, limit rounding and shifting to the full mantissa plus
          // leading 1.
          if (exponent_shift <= kFromMantissaBits + 1) {
            if constexpr (!kTruncate) {
              // NOTE: we need to round again from the original from_bits,
              // otherwise the lower precision bits may already be lost.  There
              // is an edge-case where rounding to a normalized value would
              // normally round down, but for a subnormal, we need to round up.
              rounded_from_bits = RoundBitsToNearestEven(rounded_from_bits,
                                                         exponent_shift, false);
            }
            bits = rounded_from_bits >> exponent_shift;
          }
        } else {
          bits = rounded_from_bits << -exponent_shift;
        }
        // Insert sign and return.
        To to = Eigen::numext::bit_cast<To>(bits);
        return from_sign_bit ? -to : to;
      }
    }

    // Round the mantissa if it is shrinking.
    WideBits rounded_from_bits = from_bits;
    if constexpr (kDigitShift < 0) {
      if constexpr (!kTruncate) {
        rounded_from_bits =
            RoundBitsToNearestEven(from_bits, -kDigitShift, to_zero_mantissa);
      }
      // Zero-out tail bits.
      rounded_from_bits &= ~((WideBits{1} << (-kDigitShift)) - 1);
    }

    // Re-bias the exponent.
    rounded_from_bits += static_cast<WideBits>(kExponentOffset)
                         << kFromMantissaBits;

    ToBits bits;
    // Check for overflows by aligning the significands. We always align the
    // narrower significand to the wider significand.
    const WideBits kToHighestRep =
        Eigen::numext::bit_cast<ToBits>(Eigen::NumTraits<To>::highest());
    WideBits aligned_highest{kToHighestRep};
    if constexpr (kDigitShift < 0) {
      aligned_highest <<= -kDigitShift;
      // Shift down, all dropped bits should already be zero.
      bits = static_cast<ToBits>(rounded_from_bits >> -kDigitShift);
    } else if constexpr (kDigitShift >= 0) {
      // Shift up, inserting zeros in the newly created digits.
      rounded_from_bits <<= kDigitShift;
      bits = static_cast<ToBits>(rounded_from_bits);
    }

    To to = Eigen::numext::bit_cast<To>(bits);
    // `From` supports larger values than `To`, we may overflow.
    // (Compared lexicographically: max_exponent first, then digits.)
    if constexpr (std::make_pair(std::numeric_limits<To>::max_exponent,
                                 std::numeric_limits<To>::digits) <
                  std::make_pair(std::numeric_limits<From>::max_exponent,
                                 std::numeric_limits<From>::digits)) {
      if (rounded_from_bits > aligned_highest) {
        // Overflowed values map to highest or infinity depending on kSaturate.
        to = kSaturate ? Eigen::NumTraits<To>::highest()
                       : Eigen::NumTraits<To>::infinity();
      }
    }
    // Insert sign bit.
    return from_sign_bit ? -to : to;
  }
};

// Saturation has no impact when casting e4m3fn to e5m2: every finite e4m3fn
// value stays finite in e5m2, so there is never an overflow to clamp. The
// saturating variant therefore just delegates to the non-saturating one with
// the same truncation mode.
template <bool kTruncate>
struct ConvertImpl<float8_e4m3fn, float8_e5m2, true, kTruncate> {
  static EIGEN_DEVICE_FUNC inline float8_e5m2 run(float8_e4m3fn from) {
    using NonSaturating =
        ConvertImpl<float8_e4m3fn, float8_e5m2, /*kSaturate=*/false, kTruncate>;
    return NonSaturating::run(from);
  }
};

// Converts Eigen::half to float8_e5m2.
//
// float8_e5m2 has the same sign bit, exponent field and exponent bias as
// Eigen::half, so an e5m2 bit pattern is simply the upper byte of the
// corresponding half bit pattern. The conversion keeps that upper byte. Unless
// kTruncate is set, it first rounds away the discarded low byte to
// nearest-even.
template <bool kSaturate, bool kTruncate>
struct ConvertImpl<Eigen::half, float8_e5m2, kSaturate, kTruncate> {
  static EIGEN_DEVICE_FUNC inline float8_e5m2 run(Eigen::half from) {
    uint16_t half_bits = Eigen::numext::bit_cast<uint16_t>(from);
    const uint16_t magnitude_bits = half_bits & 0x7FFF;

    // An all-ones exponent field marks Inf (exactly 0x7C00 once the sign is
    // stripped) or NaN (anything above it). Neither is rounded.
    if (magnitude_bits >= 0x7C00) {
      if (magnitude_bits == 0x7C00) {
        return float8_e5m2::FromRep(half_bits >> 8);
      }
      // IEEE 754-2019 6.2.1: "A quiet NaN bit string should be encoded with the
      // first bit (d1) of the trailing significand field T being 1."
      // IEEE 754-2019 6.2.3: "Conversion of a quiet NaN to a floating-point
      // format of the same or a different radix that does not allow the payload
      // to be preserved, shall return a quiet NaN [...]"
      // Forcing the leading e5m2 mantissa bit also stops a NaN whose surviving
      // payload bits are all zero from turning into the Inf encoding.
      return float8_e5m2::FromRep((half_bits >> 8) | 0b0'00000'10);
    }

    if constexpr (!kTruncate) {
      half_bits = RoundBitsToNearestEven(half_bits, 8, false);
      // The rounding carry can push the largest finite halves into the Inf
      // encoding. When saturating, clamp those to the largest finite e5m2
      // value and keep the input's sign.
      if constexpr (kSaturate) {
        const float8_e5m2 largest_finite =
            Eigen::NumTraits<float8_e5m2>::highest();
        const uint16_t rounded_magnitude = half_bits & 0x7F00;
        const uint16_t largest_magnitude = static_cast<uint16_t>(
            static_cast<uint16_t>(largest_finite.rep()) << 8);
        if (rounded_magnitude > largest_magnitude) {
          const bool negative = (half_bits & 0x8000) != 0;
          return negative ? -largest_finite : largest_finite;
        }
      }
    }
    return float8_e5m2::FromRep(half_bits >> 8);
  }
};

// Widening e5m2 to Eigen::half is exact. The 8-bit e5m2 pattern becomes the
// high byte of the 16-bit half pattern, and the extra half mantissa bits are
// zero-filled. Every e5m2 encoding, including infinities and NaNs, therefore
// maps to the half value with the same layout. The conversion is a shift, so
// kSaturate and kTruncate have no effect.
template <bool kSaturate, bool kTruncate>
struct ConvertImpl<float8_e5m2, Eigen::half, kSaturate, kTruncate> {
  static EIGEN_DEVICE_FUNC inline Eigen::half run(float8_e5m2 from) {
    const uint16_t high_byte = static_cast<uint16_t>(from.rep());
    const uint16_t half_bits = static_cast<uint16_t>(high_byte << 8);
    return Eigen::numext::bit_cast<Eigen::half>(half_bits);
  }
};

// Converts `from` to the float8 type `Derived`.
//
// Sources no wider than double go straight to the matching ConvertImpl.
// Wider floating-point sources (float80, binary128, ...) are first narrowed to
// float and then converted from float. Rounding twice (wide -> float ->
// float8) can change the result compared with a single rounding. To avoid
// this, the first step rounds to odd rather than to nearest, following Boldo,
// Sylvie, and Guillaume Melquiond. "When double rounding is odd." 17th IMACS
// World Congress. 2005.
template <typename Derived>
template <bool kSaturate, bool kTruncate, typename From>
EIGEN_DEVICE_FUNC Derived float8_base<Derived>::ConvertFrom(const From from) {
  constexpr bool kWiderThanDouble =
      std::is_floating_point_v<From> && sizeof(From) > sizeof(double);
  if constexpr (!kWiderThanDouble) {
    return ConvertImpl<From, Derived, kSaturate, kTruncate>::run(from);
  } else {
    // Sanity checks on the formats involved in the two-step conversion.
    static_assert(std::numeric_limits<From>::digits >=
                  std::numeric_limits<float>::digits + 2);
    static_assert(std::numeric_limits<float>::min_exponent >=
                  std::numeric_limits<From>::min_exponent + 2);
    static_assert(std::numeric_limits<float>::is_iec559);
    static_assert(std::numeric_limits<float>::radix == 2);

    // Operate on magnitudes and re-apply the sign at the very end.
    const bool negative = std::signbit(from);
    const From wide_magnitude = std::fabs(from);

    // Narrow with the default round-to-nearest, then compare against the
    // original to see in which direction (if any) that rounding went.
    const float nearest = static_cast<float>(wide_magnitude);
    const From nearest_widened = static_cast<From>(nearest);
    uint32_t odd_bits = Eigen::numext::bit_cast<uint32_t>(nearest);

    // The nearest float already is the round-to-odd result if the narrowing
    // was exact, if the input was NaN (the NaN must be preserved as-is), or if
    // the nearest float is itself odd. Otherwise it is even, and the odd
    // neighbour on the other side of the exact value is wanted: one step up if
    // nearest lies below the exact value, one step down if it lies above. For
    // non-negative IEEE floats, +/-1 on the bit pattern moves to the adjacent
    // representable value.
    const bool already_odd = wide_magnitude == nearest_widened ||
                             std::isnan(nearest) || (odd_bits & 1) != 0;
    if (!already_odd) {
      if (wide_magnitude > nearest_widened) {
        ++odd_bits;
      } else {
        --odd_bits;
      }
    }

    const float odd_magnitude = Eigen::numext::bit_cast<float>(odd_bits);
    return ConvertImpl<float, Derived, kSaturate, kTruncate>::run(
        negative ? -odd_magnitude : odd_magnitude);
  }
}

// Converts this float8 value (`from`, of type Derived) to `To` by dispatching
// to the ConvertImpl specialization for the (Derived, To) pair with the
// requested saturation and truncation behaviour.
template <typename Derived>
template <typename To, bool kSaturate, bool kTruncate>
EIGEN_DEVICE_FUNC To float8_base<Derived>::ConvertTo(Derived from) {
  using Impl = ConvertImpl<Derived, To, kSaturate, kTruncate>;
  return Impl::run(from);
}

}  // namespace float8_internal

// Exported types: re-export the float8 formats defined in float8_internal so
// that users can name them directly as ml_dtypes::float8_*.
using float8_e3m4 = float8_internal::float8_e3m4;
using float8_e4m3 = float8_internal::float8_e4m3;
using float8_e5m2 = float8_internal::float8_e5m2;

using float8_e4m3fn = float8_internal::float8_e4m3fn;

using float8_e4m3b11fnuz = float8_internal::float8_e4m3b11fnuz;
using float8_e4m3fnuz = float8_internal::float8_e4m3fnuz;
using float8_e5m2fnuz = float8_internal::float8_e5m2fnuz;

using float8_e8m0fnu = float8_internal::float8_e8m0fnu;

}  // namespace ml_dtypes

// Work-around for isinf/isnan/isfinite issue on aarch64.
//
// Explicitly specializes Eigen's internal isinf_impl / isnan_impl /
// isfinite_impl for every exported float8 type. Each specialization forwards
// to the corresponding predicate in ml_dtypes::float8_internal.
namespace Eigen::internal {

// float8_e3m4
template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e3m4>(
    const ml_dtypes::float8_e3m4& x) {
  return ml_dtypes::float8_internal::isinf(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e3m4>(
    const ml_dtypes::float8_e3m4& x) {
  return ml_dtypes::float8_internal::isnan(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e3m4>(
    const ml_dtypes::float8_e3m4& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

// float8_e4m3
template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e4m3>(
    const ml_dtypes::float8_e4m3& x) {
  return ml_dtypes::float8_internal::isinf(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e4m3>(
    const ml_dtypes::float8_e4m3& x) {
  return ml_dtypes::float8_internal::isnan(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e4m3>(
    const ml_dtypes::float8_e4m3& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

// float8_e4m3fn
template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e4m3fn>(
    const ml_dtypes::float8_e4m3fn& x) {
  return ml_dtypes::float8_internal::isinf(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e4m3fn>(
    const ml_dtypes::float8_e4m3fn& x) {
  return ml_dtypes::float8_internal::isnan(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e4m3fn>(
    const ml_dtypes::float8_e4m3fn& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

// float8_e4m3b11fnuz
template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e4m3b11fnuz>(
    const ml_dtypes::float8_e4m3b11fnuz& x) {
  return ml_dtypes::float8_internal::isinf(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e4m3b11fnuz>(
    const ml_dtypes::float8_e4m3b11fnuz& x) {
  return ml_dtypes::float8_internal::isnan(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e4m3b11fnuz>(
    const ml_dtypes::float8_e4m3b11fnuz& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

// float8_e4m3fnuz
template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e4m3fnuz>(
    const ml_dtypes::float8_e4m3fnuz& x) {
  return ml_dtypes::float8_internal::isinf(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e4m3fnuz>(
    const ml_dtypes::float8_e4m3fnuz& x) {
  return ml_dtypes::float8_internal::isnan(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e4m3fnuz>(
    const ml_dtypes::float8_e4m3fnuz& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

// float8_e5m2
template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e5m2>(
    const ml_dtypes::float8_e5m2& x) {
  return ml_dtypes::float8_internal::isinf(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e5m2>(
    const ml_dtypes::float8_e5m2& x) {
  return ml_dtypes::float8_internal::isnan(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e5m2>(
    const ml_dtypes::float8_e5m2& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

// float8_e5m2fnuz
template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e5m2fnuz>(
    const ml_dtypes::float8_e5m2fnuz& x) {
  return ml_dtypes::float8_internal::isinf(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e5m2fnuz>(
    const ml_dtypes::float8_e5m2fnuz& x) {
  return ml_dtypes::float8_internal::isnan(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e5m2fnuz>(
    const ml_dtypes::float8_e5m2fnuz& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

// float8_e8m0fnu
template <>
EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e8m0fnu>(
    const ml_dtypes::float8_e8m0fnu& x) {
  return ml_dtypes::float8_internal::isinf(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e8m0fnu>(
    const ml_dtypes::float8_e8m0fnu& x) {
  return ml_dtypes::float8_internal::isnan(x);
}
template <>
EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e8m0fnu>(
    const ml_dtypes::float8_e8m0fnu& x) {
  return ml_dtypes::float8_internal::isfinite(x);
}

}  // namespace Eigen::internal

#endif  // ML_DTYPES_FLOAT8_H_