File: fp8.py

package info (click to toggle)
python-bitstring 4.3.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,312 kB
  • sloc: python: 11,397; makefile: 8; sh: 7
file content (97 lines) | stat: -rw-r--r-- 3,768 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
The 8-bit float formats used here are from a proposal supported by Graphcore, AMD and Qualcomm.
See https://arxiv.org/abs/2206.02915

"""

import struct
import zlib
import array
import bitarray
from bitstring.luts import binary8_luts_compressed
import math


class Binary8Format:
    """8-bit floating point formats based on draft IEEE binary8.

    A code is decoded as 1 sign bit, ``exp_bits`` exponent bits and
    ``7 - exp_bits`` fraction bits, most significant bit first. Three codes
    are special: 0b10000000 is NaN (there is no negative zero), 0b01111111 is
    +infinity and 0b11111111 is -infinity.

    The lookup tables ``lut_binary8_to_float`` and ``lut_float16_to_binary8``
    are not created by ``__init__``; call ``decompress_luts()`` or
    ``create_luts()`` before converting values.
    """

    def __init__(self, exp_bits: int, bias: int):
        self.exp_bits = exp_bits
        self.bias = bias
        # Codes returned by float_to_int8() when the input is too large in
        # magnitude to be packed as a float16. The decoder in
        # createLUT_for_binary8_to_float() maps them to +inf and -inf.
        self.pos_clamp_value = 0b01111111
        self.neg_clamp_value = 0b11111111

    def __str__(self) -> str:
        return f"Binary8Format(exp_bits={self.exp_bits}, bias={self.bias})"

    def decompress_luts(self) -> None:
        """Load both lookup tables from their zlib-compressed stored copies.

        The tables are looked up in ``binary8_luts_compressed`` under the key
        ``(exp_bits, bias)``.
        """
        binary8_to_float_compressed, float16_to_binary8_compressed = binary8_luts_compressed[(self.exp_bits, self.bias)]
        self.lut_float16_to_binary8 = zlib.decompress(float16_to_binary8_compressed)
        dec = zlib.decompress(binary8_to_float_compressed)
        # The float table is stored as consecutive little-endian 32-bit floats.
        self.lut_binary8_to_float = struct.unpack(f'<{len(dec) // 4}f', dec)

    def create_luts(self) -> None:
        """Build both lookup tables from scratch instead of decompressing them."""
        # Order matters: the float16 -> binary8 table is built by searching
        # the binary8 -> float table, so the latter must exist first.
        self.lut_binary8_to_float = self.createLUT_for_binary8_to_float()
        self.lut_float16_to_binary8 = self.createLUT_for_float16_to_binary8()

    def float_to_int8(self, f: float) -> int:
        """Given a Python float convert to the best float8 (expressed as an integer in 0-255 range).

        The value is first converted to a float16 and the resulting 16-bit
        pattern is used as an index into ``lut_float16_to_binary8``.

        Raises AttributeError if the lookup tables have not been loaded with
        ``decompress_luts()`` or ``create_luts()``.
        """
        # First convert the float to a float16, then a 16 bit uint
        try:
            b = struct.pack('>e', f)
        except (OverflowError, struct.error):
            # Too large in magnitude for a float16: return the clamp code,
            # which decodes to plus or minus infinity.
            return self.pos_clamp_value if f > 0 else self.neg_clamp_value
        f16_int = int.from_bytes(b, byteorder='big')
        # Then use this as an index to our large LUT
        return self.lut_float16_to_binary8[f16_int]

    def createLUT_for_float16_to_binary8(self) -> bytes:
        """Create the 65536-entry LUT mapping every float16 bit pattern to a float8 code.

        Used to create the LUT that was compressed and stored for the fp8 code.
        Needs the ``gfloat`` package and requires ``lut_binary8_to_float`` to
        be populated already.

        Raises ValueError if a rounded value is not present in
        ``lut_binary8_to_float``.
        """
        import gfloat
        fi = gfloat.formats.format_info_p3109(8 - self.exp_bits)
        # Reverse mapping built once, rather than a linear .index() scan for
        # each of the 65536 inputs. setdefault keeps the first code for a
        # value, which is what .index() would have returned.
        code_for_value = {}
        for code, value in enumerate(self.lut_binary8_to_float):
            code_for_value.setdefault(value, code)
        fp16_to_fp8 = bytearray(1 << 16)
        for i in range(1 << 16):
            b = struct.pack('>H', i)
            f, = struct.unpack('>e', b)
            # float() keeps the dictionary key a plain hashable scalar.
            fp = float(gfloat.round_float(fi, f))
            if math.isnan(fp):
                # NaN never compares equal, so it cannot be looked up by value.
                fp8_i = 0b10000000
            else:
                fp8_i = code_for_value.get(fp)
                if fp8_i is None:
                    raise ValueError(f"{fp!r} is not in the binary8 to float table of {self}")
            fp16_to_fp8[i] = fp8_i
        return bytes(fp16_to_fp8)

    def createLUT_for_binary8_to_float(self) -> array.array:
        """Create a LUT to convert an int in range 0-255 representing a float8 into a Python float.

        Returns an ``array.array`` of type code ``'f'`` with 256 entries.

        Raises ValueError if ``exp_bits`` is not in the range 1 to 7.
        """
        if not 0 < self.exp_bits < 8:
            raise ValueError(f"exp_bits must be between 1 and 7, not {self.exp_bits}")
        frac_bits = 7 - self.exp_bits
        exp_mask = (1 << self.exp_bits) - 1
        frac_mask = (1 << frac_bits) - 1
        hidden_bit = 1 << frac_bits
        i2f = []
        for i in range(256):
            sign = i >> 7
            exponent = (i >> frac_bits) & exp_mask
            significand = i & frac_mask
            if exponent == 0:
                # Subnormal: no implicit leading one, fixed minimum exponent.
                exponent = -self.bias + 1
            else:
                # Normal: restore the implicit leading one.
                significand |= hidden_bit
                exponent -= self.bias
            f = float(significand) / (2.0 ** frac_bits)
            f *= 2 ** exponent
            i2f.append(f if not sign else -f)
        # One special case for minus zero
        i2f[0b10000000] = float('nan')
        # and for plus and minus infinity
        i2f[0b01111111] = float('inf')
        i2f[0b11111111] = float('-inf')
        return array.array('f', i2f)


# We create the two supported formats, named here as sign.exponent.fraction bit counts:
# p4binary_fmt is 1.4.3 (4 exponent bits) and p3binary_fmt is 1.5.2 (5 exponent bits).
# Their lookup tables are not loaded here; see decompress_luts() below.
p4binary_fmt = Binary8Format(exp_bits=4, bias=8)
p3binary_fmt = Binary8Format(exp_bits=5, bias=16)


def decompress_luts():
    """Decompress the stored lookup tables of both module-level binary8 formats."""
    for fmt in (p4binary_fmt, p3binary_fmt):
        fmt.decompress_luts()