File: mxfp.py

package info (click to toggle)
python-bitstring 4.3.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,312 kB
  • sloc: python: 11,397; makefile: 8; sh: 7
file content (206 lines) | stat: -rw-r--r-- 9,058 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import array
import math
import struct
import bitarray
from bitstring.luts import mxfp_luts_compressed
import zlib
from typing import Optional


def round_to_nearest_ties_to_even(lut_int_to_float, lower: int, f: float) -> Optional[int]:
    """Round f to one of the two adjacent codes ``lower`` and ``lower + 1``.

    :param lut_int_to_float: Indexable table mapping a code to its float value.
    :param lower: The first of the two neighbouring codes to consider.
    :param f: The value to round.
    :return: The code whose value is nearest to f, with exact ties going to the
        even code, or None if f does not lie within the closed interval spanned
        by the two codes' values.
    """
    lo_code, hi_code = lower, lower + 1
    # Special case for LUTs without a negative zero: code 128 is always treated as 0.0.
    lo_val = lut_int_to_float[lo_code] if lo_code != 128 else 0.0
    hi_val = lut_int_to_float[hi_code]

    # Order the pair by value so that lo_val <= hi_val (negative codes run downwards).
    if hi_val < lo_val:
        lo_code, hi_code = hi_code, lo_code
        lo_val, hi_val = hi_val, lo_val

    # Exact hits on either end point need no rounding.
    if f == lo_val:
        return lo_code
    if f == hi_val:
        return hi_code

    # Anything outside the interval (including NaN) is not ours to round.
    if not (lo_val < f < hi_val):
        return None

    gap_below = f - lo_val
    gap_above = hi_val - f
    if gap_below != gap_above:
        return lo_code if gap_below < gap_above else hi_code
    # Exactly half way: pick whichever of the two codes is even.
    return hi_code if lo_code % 2 else lo_code


class MXFPFormat:
    """Defining an MXFP micro-scaling floating point format.

    A format is described by its exponent width, mantissa width, exponent bias and
    overflow policy ('saturate', or any other value for the non-saturating behaviour).
    Each value is encoded as one sign bit, then the exponent bits, then the mantissa
    bits, so the total width is ``1 + exp_bits + mantissa_bits``.

    The lookup tables ``lut_int_to_float`` and ``lut_float16_to_mxfp`` start out as
    None. Populate them with ``decompress_luts()`` or ``create_luts()`` before
    converting values.
    """

    def __init__(self, exp_bits: int, mantissa_bits: int, bias: int, mxfp_overflow: str):
        """Store the format parameters and work out the codes used when a value overflows.

        :param exp_bits: Number of exponent bits.
        :param mantissa_bits: Number of mantissa bits.
        :param bias: Exponent bias.
        :param mxfp_overflow: 'saturate' to clamp to the largest finite value; any other
            value selects NaN (e4m3) or infinity (e5m2) on overflow.
        """
        self.exp_bits = exp_bits
        self.mantissa_bits = mantissa_bits
        self.bias = bias
        self.mxfp_overflow = mxfp_overflow

        # Default clamps: every exponent and mantissa bit set, without / with the sign bit.
        self.pos_clamp_value = (1 << (self.exp_bits + self.mantissa_bits)) - 1
        self.neg_clamp_value = (1 << (1 + self.exp_bits + self.mantissa_bits)) - 1

        # Special cases for e4m3 and e5m2
        if self.exp_bits == 4 and self.mantissa_bits == 3:
            if self.mxfp_overflow == 'saturate':
                self.pos_clamp_value = 0b01111110  # 448
                self.neg_clamp_value = 0b11111110  # -448
            else:
                self.pos_clamp_value = self.neg_clamp_value = 0b11111111  # NaN
        if self.exp_bits == 5 and self.mantissa_bits == 2:
            if self.mxfp_overflow == 'saturate':
                self.pos_clamp_value = 0b01111011  # 57344
                self.neg_clamp_value = 0b11111011  # -57344
            else:
                self.pos_clamp_value = 0b01111100  # +inf
                self.neg_clamp_value = 0b11111100  # -inf

        # If we calculate these LUTs now it creates a bootstrap problem in generate_luts.py.
        self.lut_float16_to_mxfp = None
        self.lut_int_to_float = None

    def __str__(self):
        return f"MXFPFormat(exp_bits={self.exp_bits}, mantissa_bits={self.mantissa_bits}, bias={self.bias}, mxfp_overflow='{self.mxfp_overflow}')"

    def decompress_luts(self):
        """Fill both LUTs from the zlib-compressed tables stored for this format."""
        int_to_float_compressed, float16_to_mxfp_compressed = mxfp_luts_compressed[(self.exp_bits, self.mantissa_bits, self.bias, self.mxfp_overflow)]
        self.lut_float16_to_mxfp = zlib.decompress(float16_to_mxfp_compressed)
        dec = zlib.decompress(int_to_float_compressed)
        # The int -> float table is stored as packed little-endian 32-bit floats.
        self.lut_int_to_float = struct.unpack(f'<{len(dec) // 4}f', dec)

    def create_luts(self):
        """Compute both LUTs from scratch rather than decompressing the stored ones."""
        # Order matters: building the float16 LUT reads self.lut_int_to_float.
        self.lut_int_to_float = self.createLUT_for_int_to_float()
        self.lut_float16_to_mxfp = self.createLUT_for_float16_to_mxfp()

    def float_to_int(self, f: float) -> int:
        """Given a Python float convert to the best mxfp float (expressed as an int) that represents it.

        Requires ``lut_float16_to_mxfp`` to have been populated, except for values too
        large to pack as a float16, which are mapped straight to the clamp codes.
        """
        # First convert the float to a float16, then a 16 bit uint
        try:
            b = struct.pack('>e', f)
        except (OverflowError, struct.error):
            # Return the largest representable positive or negative value
            return self.pos_clamp_value if f > 0 else self.neg_clamp_value

        f16_int = int.from_bytes(b, byteorder='big')
        # Then use this as an index to our large LUT
        return self.lut_float16_to_mxfp[f16_int]

    def slow_float_to_int(self, f: float) -> int:
        """Convert a float to its mxfp code by scanning ``lut_int_to_float``.

        Slow, but easier to follow than the LUT-based ``float_to_int``. The returned int
        holds the bit pattern of the nearest representable value (ties to even); values
        beyond the representable range give the positive or negative clamp code.
        """
        length = 1 + self.exp_bits + self.mantissa_bits
        values = 1 << length
        # First get the NaN case out of the way
        if math.isnan(f):
            # 0xff is NaN for both 8-bit formats (e5m2 and e4m3). The narrower formats
            # don't support NaN, and 0xff is outside their code range, so it serves as
            # an invalid value that can be detected later.
            return 0xff
        # This is so we can distinguish between 0.0 and -0.0
        is_positive = math.copysign(1.0, f) == 1.0
        if is_positive:
            # Positive, so top bit is not set
            for i in range(values // 2 - 1):
                upper = self.lut_int_to_float[i + 1]
                if upper == float('inf'):
                    break
                x = round_to_nearest_ties_to_even(self.lut_int_to_float, i, f)
                if x is not None:
                    return x
            return self.pos_clamp_value
        else:
            # Negative, so top bit is set
            for i in range(values // 2, values - 1):
                lower = self.lut_int_to_float[i + 1]
                if lower == float('-inf'):
                    break
                x = round_to_nearest_ties_to_even(self.lut_int_to_float, i, f)
                if x is not None:
                    return x
            # Clip to negative max
            return self.neg_clamp_value

    def createLUT_for_int_to_float(self) -> array.array:
        """Create a LUT to convert an int representing an MXFP float into a Python float.

        :return: An ``array.array('f')`` with one entry per code, indexed by the code.
        """
        length = 1 + self.exp_bits + self.mantissa_bits
        exponent_mask = (1 << self.exp_bits) - 1
        mantissa_mask = (1 << self.mantissa_bits) - 1
        implicit_one = 1 << self.mantissa_bits
        i2f = []
        for i in range(1 << length):
            # Split the code into its fields with plain integer arithmetic:
            # [sign | exponent | mantissa], most significant bit first.
            sign = i >> (length - 1)
            exponent = (i >> self.mantissa_bits) & exponent_mask
            significand = i & mantissa_mask
            if exponent == 0:
                # Subnormal: no implicit leading one, and a fixed exponent.
                exponent = -self.bias + 1
            else:
                # Normal: restore the implicit leading one and remove the bias.
                significand |= implicit_one
                exponent -= self.bias
            f = float(significand) / (2.0 ** self.mantissa_bits)
            f *= 2 ** exponent
            if length == 8:
                # Some special cases
                if self.exp_bits == 5:
                    if i in [0b01111100, 0b11111100]:
                        f = float('inf')
                    if i in [0b01111101, 0b11111101, 0b01111110, 0b11111110, 0b01111111, 0b11111111]:
                        f = float('nan')
                if self.exp_bits == 4:
                    if i in [0b01111111, 0b11111111]:
                        f = float('nan')
            i2f.append(f if not sign else -f)
        return array.array('f', i2f)

    def createLUT_for_float16_to_mxfp(self) -> bytes:
        """Create a LUT to convert a float16 into a MXFP format.

        Requires ``lut_int_to_float`` to have been populated already.

        :return: 65536 bytes; entry ``n`` is the mxfp code for the float16 whose
            big-endian bit pattern is ``n``.
        """
        # Used to create the LUT that was compressed and stored for the fp8 code
        length = 1 + self.exp_bits + self.mantissa_bits
        if length == 8:
            # The 8-bit formats use gfloat for rounding. It is imported here as it is
            # only needed when generating the tables.
            import gfloat
            from gfloat.formats import format_info_ocp_e5m2, format_info_ocp_e4m3
            fi = format_info_ocp_e5m2 if self.exp_bits == 5 else format_info_ocp_e4m3

            fp16_to_fp8 = bytearray(1 << 16)
            for i in range(1 << 16):
                b = struct.pack('>H', i)
                f, = struct.unpack('>e', b)
                fp = gfloat.round_float(fi, f, sat=self.mxfp_overflow == 'saturate')
                if math.isnan(fp):
                    fp8_i = 0b11111111
                else:
                    # Special case for negative zero
                    if fp == 0.0 and math.copysign(1.0, fp) == -1.0:
                        fp8_i = 0b10000000
                    else:
                        fp8_i = self.lut_int_to_float.index(fp)
                fp16_to_fp8[i] = fp8_i
            return bytes(fp16_to_fp8)
        else:
            # The 4 and 6 bit formats use the slow table scan instead.
            assert length in [4, 6]
            fp16_to_fp8 = bytearray(1 << 16)
            for i in range(1 << 16):
                b = struct.pack('>H', i)
                f, = struct.unpack('>e', b)
                fp8_i = self.slow_float_to_int(f)
                fp16_to_fp8[i] = fp8_i
            return bytes(fp16_to_fp8)


# Formats narrower than a byte (4 and 6 bits in total). These always saturate.
e2m1mxfp_fmt = MXFPFormat(
    exp_bits=2, mantissa_bits=1, bias=1, mxfp_overflow='saturate',
)
e2m3mxfp_fmt = MXFPFormat(
    exp_bits=2, mantissa_bits=3, bias=1, mxfp_overflow='saturate',
)
e3m2mxfp_fmt = MXFPFormat(
    exp_bits=3, mantissa_bits=2, bias=3, mxfp_overflow='saturate',
)

# 8-bit formats that clamp to the largest finite value on overflow.
e4m3mxfp_saturate_fmt = MXFPFormat(
    exp_bits=4, mantissa_bits=3, bias=7, mxfp_overflow='saturate',
)
e5m2mxfp_saturate_fmt = MXFPFormat(
    exp_bits=5, mantissa_bits=2, bias=15, mxfp_overflow='saturate',
)

# 8-bit formats that overflow to NaN (e4m3) or infinity (e5m2) instead.
e4m3mxfp_overflow_fmt = MXFPFormat(
    exp_bits=4, mantissa_bits=3, bias=7, mxfp_overflow='overflow',
)
e5m2mxfp_overflow_fmt = MXFPFormat(
    exp_bits=5, mantissa_bits=2, bias=15, mxfp_overflow='overflow',
)


def decompress_luts():
    """Decompress the stored lookup tables for every module-level MXFP format."""
    all_formats = (
        e2m1mxfp_fmt,
        e2m3mxfp_fmt,
        e3m2mxfp_fmt,
        e4m3mxfp_saturate_fmt,
        e5m2mxfp_saturate_fmt,
        e4m3mxfp_overflow_fmt,
        e5m2mxfp_overflow_fmt,
    )
    for fmt in all_formats:
        fmt.decompress_luts()