# ml_dtypes

[![Unittests](https://github.com/jax-ml/ml_dtypes/actions/workflows/test.yml/badge.svg)](https://github.com/jax-ml/ml_dtypes/actions/workflows/test.yml)
[![Wheel Build](https://github.com/jax-ml/ml_dtypes/actions/workflows/wheels.yml/badge.svg)](https://github.com/jax-ml/ml_dtypes/actions/workflows/wheels.yml)
[![PyPI version](https://badge.fury.io/py/ml_dtypes.svg)](https://badge.fury.io/py/ml_dtypes)

`ml_dtypes` is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:

- [`bfloat16`](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format):
  an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format
- 8-bit floating point representations, parameterized by number of exponent and
  mantissa bits, as well as the bias (if any) and representability of infinity,
  NaN, and signed zero.
  * `float8_e3m4`
  * `float8_e4m3`
  * `float8_e4m3b11fnuz`
  * `float8_e4m3fn`
  * `float8_e4m3fnuz`
  * `float8_e5m2`
  * `float8_e5m2fnuz`
  * `float8_e8m0fnu`
- Microscaling (MX) sub-byte floating point representations:
  * `float4_e2m1fn`
  * `float6_e2m3fn`
  * `float6_e3m2fn`
- Narrow integer encodings:
  * `int2`
  * `int4`
  * `uint2`
  * `uint4`

See below for specifications of these number formats.

## Installation

The `ml_dtypes` package is tested with Python versions 3.9-3.12, and can be installed
with the following command:
```
pip install ml_dtypes
```
To test your installation, you can run the following:
```
pip install absl-py pytest
pytest --pyargs ml_dtypes
```
To build from source, clone the repository and run:
```
git submodule init
git submodule update
pip install .
```

## Example Usage

```python
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
```
Importing `ml_dtypes` also registers the data types with NumPy, so that they can
be referred to by their string names:

```python
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
```

## Specifications of implemented floating point formats

### `bfloat16`

A `bfloat16` number is a single-precision (`float32`) number truncated to its upper 16 bits: it keeps the full 8-bit exponent, but only 7 of the 23 mantissa bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
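To make the "truncated `float32`" description concrete, the following sketch extracts the upper 16 bits of a `float32` bit pattern using only Python's standard `struct` module. The helper names `to_bfloat16_bits` and `from_bfloat16_bits` are hypothetical and not part of the `ml_dtypes` API. The sketch truncates, as described above; a production converter would typically round to nearest instead, so the last mantissa bit can differ.

```python
import struct

def float32_bits(x: float) -> int:
    """Return the 32-bit IEEE 754 single-precision pattern of x."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def to_bfloat16_bits(x: float) -> int:
    """Hypothetical helper: keep the upper 16 bits of the float32 pattern
    (1 sign bit, 8 exponent bits, 7 mantissa bits), i.e. plain truncation."""
    return float32_bits(x) >> 16

def from_bfloat16_bits(bits: int) -> float:
    """Hypothetical helper: widen a 16-bit pattern to float32 by appending 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]

b = to_bfloat16_bits(3.14159)
print(hex(b), from_bfloat16_bits(b))  # 0x4049 3.140625
```

Because the exponent field is unchanged, infinities and very large or small magnitudes survive the round trip; only mantissa precision is lost.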

### `float4_e2m1fn`

Exponent: 2, Mantissa: 1, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 4 bits (encoding: `0bSEEM`) using byte storage (higher 4
bits are unused). NaN representation is undefined.

Possible absolute values: [`0`, `0.5`, `1`, `1.5`, `2`, `3`, `4`, `6`]

### `float6_e2m3fn`

Exponent: 2, Mantissa: 3, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEMMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [`-7.5`; `7.5`]

### `float6_e3m2fn`

Exponent: 3, Mantissa: 2, bias: 3.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEEMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [`-28`; `28`]
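The three MX formats above share one decoding rule: no encodings are reserved for inf or NaN, and an exponent field of 0 denotes zero or a subnormal. The sketch below is a hypothetical pure-Python reference decoder (`decode_fn` is illustrative, not an `ml_dtypes` function); it masks off the unused high bits of the storage byte and recovers the value lists and ranges stated above.

```python
def decode_fn(code: int, exp_bits: int, man_bits: int, bias: int) -> float:
    """Hypothetical decoder for the MX sub-byte formats: no infinities or NaNs,
    subnormals when the exponent field is 0. Bits above the sign bit are ignored,
    as with byte storage."""
    width = 1 + exp_bits + man_bits
    code &= (1 << width) - 1                      # drop the unused high bits
    sign = code >> (exp_bits + man_bits)
    exp = (code >> man_bits) & ((1 << exp_bits) - 1)
    man = code & ((1 << man_bits) - 1)
    if exp == 0:                                  # subnormal: no implicit leading 1
        mag = man / (1 << man_bits) * 2.0 ** (1 - bias)
    else:                                         # normal: implicit leading 1
        mag = (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)
    return -mag if sign else mag

MX_FORMATS = {  # name: (exponent bits, mantissa bits, bias)
    "float4_e2m1fn": (2, 1, 1),
    "float6_e2m3fn": (2, 3, 1),
    "float6_e3m2fn": (3, 2, 3),
}

for name, (e, m, bias) in MX_FORMATS.items():
    values = sorted({decode_fn(c, e, m, bias) for c in range(1 << (1 + e + m))})
    print(name, min(values), max(values))
```

Run as-is, the loop reports ranges of ±6, ±7.5 and ±28, and decoding codes 0 through 7 with the `float4_e2m1fn` parameters yields exactly the eight absolute values listed for that format.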

### `float8_e3m4`

Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf.

### `float8_e4m3`

Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.

### `float8_e4m3b11fnuz`

Exponent: 4, Mantissa: 3, bias: 11.

Extended range: no inf, NaN represented by 0b1000'0000.

### `float8_e4m3fn`

Exponent: 4, Mantissa: 3, bias: 7.

Extended range: no inf, NaN represented by 0bS111'1111.

The `fn` suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type does not follow IEEE 754 conventions. The `f` indicates it has finite values only (no infinities). The `n` indicates it includes NaNs, but only at the outer end of the range (exponent and mantissa bits all set).
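Reclaiming the all-ones exponent for finite values is what extends the range. The sketch below is a hypothetical decoder (`decode_e4m3fn` is illustrative, not the library's implementation). Under IEEE 754 conventions the code `0x78` would be +inf; in `float8_e4m3fn` it decodes to 256, and since only `0bS111'1111` is reserved for NaN, the largest finite value is 448 (code `0x7E`).

```python
import math

def decode_e4m3fn(code: int) -> float:
    """Hypothetical decoder for float8_e4m3fn: bias 7, no infinities,
    NaN only when exponent and mantissa bits are all ones (0bS111'1111)."""
    sign = -1.0 if (code >> 7) & 1 else 1.0
    exp = (code >> 3) & 0xF
    man = code & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (man / 8) * 2.0 ** (1 - 7)
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)

print(decode_e4m3fn(0x78), decode_e4m3fn(0x7E))  # 256.0 448.0
```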

### `float8_e4m3fnuz`

8-bit floating point with 3-bit mantissa.

An 8-bit floating point type with 1 sign bit, 4 exponent bits and 3 mantissa bits. The suffix `fnuz` is consistent with LLVM/MLIR naming and is derived from the differences from IEEE floating point conventions. `F` is for "finite" (no infinities), `N` for a special NaN encoding, `UZ` for unsigned zero.

This type has the following characteristics:
 * bit encoding: S1E4M3 - `0bSEEEEMMM`
 * exponent bias: 8
 * infinities: Not supported
 * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - `0b10000000`
 * denormals when exponent is 0

### `float8_e5m2`

Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.
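For the IEEE 754-style formats (`float8_e3m4`, `float8_e4m3` and `float8_e5m2`), the all-ones exponent is reserved for inf and NaN, which caps the finite range. The following hypothetical decoder (`decode_ieee_float8` is illustrative only) derives each format's largest finite value from its exponent bits, mantissa bits and bias.

```python
import math

def decode_ieee_float8(code: int, exp_bits: int, man_bits: int, bias: int) -> float:
    """Hypothetical decoder for IEEE 754-style 8-bit formats (1 sign bit).
    An all-ones exponent encodes inf (mantissa 0) or NaN (mantissa != 0);
    an all-zeros exponent encodes zero and subnormals."""
    sign = -1.0 if (code >> (exp_bits + man_bits)) & 1 else 1.0
    exp = (code >> man_bits) & ((1 << exp_bits) - 1)
    man = code & ((1 << man_bits) - 1)
    if exp == (1 << exp_bits) - 1:
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

IEEE_FLOAT8 = {  # name: (exponent bits, mantissa bits, bias)
    "float8_e3m4": (3, 4, 3),
    "float8_e4m3": (4, 3, 7),
    "float8_e5m2": (5, 2, 15),
}
for name, (e, m, bias) in IEEE_FLOAT8.items():
    finite = [v for v in (decode_ieee_float8(c, e, m, bias) for c in range(256))
              if math.isfinite(v)]
    print(name, max(finite))  # largest finite value
```

With these parameters the loop reports largest finite values of 15.5, 240.0 and 57344.0, respectively.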

### `float8_e5m2fnuz`

8-bit floating point with 2-bit mantissa.

An 8-bit floating point type with 1 sign bit, 5 exponent bits and 2 mantissa bits. The suffix `fnuz` is consistent with LLVM/MLIR naming and is derived from the differences from IEEE floating point conventions. `F` is for "finite" (no infinities), `N` for a special NaN encoding, `UZ` for unsigned zero.

This type has the following characteristics:
 * bit encoding: S1E5M2 - `0bSEEEEEMM`
 * exponent bias: 16
 * infinities: Not supported
 * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - `0b10000000`
 * denormals when exponent is 0
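The three `fnuz` formats differ only in their field widths and bias. As a sketch, a single hypothetical decoder (`decode_fnuz`, not part of `ml_dtypes`) covers `float8_e4m3b11fnuz`, `float8_e4m3fnuz` and `float8_e5m2fnuz`: the pattern that would be negative zero (`0b1000'0000`) is the only NaN, so every other code is finite and `0x7F` is the largest value.

```python
import math

def decode_fnuz(code: int, exp_bits: int, man_bits: int, bias: int) -> float:
    """Hypothetical decoder for the 8-bit fnuz formats: no infinities, a single
    NaN at 0b1000'0000 (the would-be negative zero), subnormals at exponent 0."""
    if code == 0x80:
        return math.nan
    sign = -1.0 if (code >> 7) & 1 else 1.0
    exp = (code >> man_bits) & ((1 << exp_bits) - 1)
    man = code & ((1 << man_bits) - 1)
    if exp == 0:  # zero or subnormal
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

FNUZ = {  # name: (exponent bits, mantissa bits, bias)
    "float8_e4m3b11fnuz": (4, 3, 11),
    "float8_e4m3fnuz": (4, 3, 8),
    "float8_e5m2fnuz": (5, 2, 16),
}
for name, (e, m, bias) in FNUZ.items():
    print(name, decode_fnuz(0x7F, e, m, bias))  # largest finite value
```

This gives maxima of 30, 240 and 57344. Note that `float8_e5m2fnuz` reaches the same maximum as `float8_e5m2`: its bias is one larger, but it regains the exponent value that IEEE 754 reserves for inf and NaN.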

### `float8_e8m0fnu`

[OpenCompute MX](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
scale format E8M0, which has the following properties:
  * Unsigned format
  * 8 exponent bits
  * Exponent range from -127 to 127
  * No zero or infinity
  * Single NaN value (0xFF).
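Because `float8_e8m0fnu` has no sign or mantissa bits, every non-NaN code is a power of two. A minimal sketch of decoding and encoding such a scale follows, assuming the bias of 127 implied by the exponent range above; `decode_e8m0fnu` and `encode_e8m0fnu` are hypothetical helpers, not library functions.

```python
import math

def decode_e8m0fnu(code: int) -> float:
    """Hypothetical decoder: value = 2**(code - 127); 0xFF is the single NaN."""
    if code == 0xFF:
        return math.nan
    return 2.0 ** (code - 127)

def encode_e8m0fnu(scale: float) -> int:
    """Hypothetical encoder for a positive power of two in [2**-127, 2**127]."""
    mantissa, exponent = math.frexp(scale)  # scale == mantissa * 2**exponent
    if mantissa != 0.5:  # exact powers of two have frexp mantissa 0.5
        raise ValueError("only positive powers of two are representable")
    code = (exponent - 1) + 127
    if not 0 <= code <= 254:
        raise ValueError("scale outside the E8M0 range")
    return code

print(decode_e8m0fnu(127), encode_e8m0fnu(0.25))  # 1.0 125
```

Zero, negative values and non-powers of two are rejected, mirroring the "unsigned, no zero" properties listed above.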

## `int2`, `int4`, `uint2` and `uint4`

2- and 4-bit integer types, where each element is represented unpacked (i.e.,
padded up to a byte in memory).

NumPy does not support types smaller than a single byte: for example, the
distance between adjacent elements in an array (`.strides`) is expressed as
an integer number of bytes. Relaxing this restriction would be a considerable
engineering project. These types therefore use an unpacked representation, where
each element of the array is padded up to a byte in memory. The lower two or four
bits of each byte contain the representation of the number, whereas the remaining
upper bits are ignored.
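With the unpacked layout, an element is read by masking the low bits of its byte. The following hypothetical helper (`decode_int`, not an `ml_dtypes` function) assumes two's-complement encoding for the signed types, which gives `int4` a range of -8 to 7 and `int2` a range of -2 to 1.

```python
def decode_int(byte: int, bits: int, signed: bool) -> int:
    """Hypothetical decoder for an unpacked int2/int4/uint2/uint4 element:
    only the low `bits` bits of the byte are meaningful; the rest are ignored."""
    value = byte & ((1 << bits) - 1)           # keep the low bits
    if signed and value >= 1 << (bits - 1):    # two's-complement sign bit set
        value -= 1 << bits
    return value

print([decode_int(b, 4, signed=True) for b in (0x07, 0x08, 0x0F, 0xF3)])  # [7, -8, -1, 3]
```

The last input, `0xF3`, shows the upper bits being ignored: only the low nibble `0x3` contributes.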

## Quirks of low-precision arithmetic

If you're exploring the use of low-precision dtypes in your code, you should be
careful to anticipate when the precision loss might lead to surprising results.
One example is the behavior of aggregations like `sum`; consider this `bfloat16`
summation in NumPy (run with version 1.24.2):

```python
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
```
The true sum should be close to 5000, but NumPy returns exactly 256. This is
because `bfloat16` does not have the precision to increment `256` by values of
`1` or less:

```python
>>> bfloat16(256) + bfloat16(1)
256
```
After 256, the next representable value in bfloat16 is 258:

```python
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
```
For better results you can specify that the accumulation should happen in a
higher-precision type like `float32`:

```python
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
```
In contrast to NumPy, projects like [JAX](http://jax.readthedocs.io/), which support
low-precision arithmetic more natively, will often do these kinds of higher-precision
accumulations automatically:

```python
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
```
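The stalling effect can also be reproduced without NumPy. This pure-Python model is a sketch under two assumptions: every addition is rounded to `bfloat16` with round-half-to-even, and the reduction is strictly sequential (NumPy's actual reduction order may differ). The helpers `round_to_bfloat16`, `sum_bfloat16` and `sum_float_then_round` are hypothetical.

```python
import struct

def round_to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.
    Simplifications: x is first rounded to float32 by struct; NaN and overflow
    are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                       # lowest bit that bfloat16 keeps
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000    # round half to even, then truncate
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def sum_bfloat16(values):
    """Sequential sum with a bfloat16 accumulator (rounded after every add)."""
    acc = 0.0
    for v in values:
        acc = round_to_bfloat16(acc + round_to_bfloat16(v))
    return acc

def sum_float_then_round(values):
    """Accumulate in a Python float (double) and round to bfloat16 once at the end."""
    return round_to_bfloat16(sum(round_to_bfloat16(v) for v in values))

ones = [1.0] * 1000
print(sum_bfloat16(ones), sum_float_then_round(ones))        # 256.0 1000.0
print(round_to_bfloat16(257.0), round_to_bfloat16(258.0))    # 256.0 258.0
```

Summing 1000 ones stalls at 256 with a `bfloat16` accumulator, because `257` is a tie between `256` and `258` and rounds to the even neighbor. Accumulating in a wider type and rounding once yields 1000, which happens to be exactly representable in `bfloat16`.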

## License

*This is not an officially supported Google product.*

The `ml_dtypes` source code is licensed under the Apache 2.0 license
(see [LICENSE](LICENSE)). Pre-compiled wheels are built with the
[EIGEN](https://eigen.tuxfamily.org/) project, which is released under the
MPL 2.0 license (see [LICENSE.eigen](LICENSE.eigen)).