# ml_dtypes

[![Unittests](https://github.com/jax-ml/ml_dtypes/actions/workflows/test.yml/badge.svg)](https://github.com/jax-ml/ml_dtypes/actions/workflows/test.yml)
[![Wheel Build](https://github.com/jax-ml/ml_dtypes/actions/workflows/wheels.yml/badge.svg)](https://github.com/jax-ml/ml_dtypes/actions/workflows/wheels.yml)
[![PyPI version](https://badge.fury.io/py/ml_dtypes.svg)](https://badge.fury.io/py/ml_dtypes)

`ml_dtypes` is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:

- [`bfloat16`](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format):
  an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format
- 8-bit floating point representations, parameterized by the number of exponent and
  mantissa bits, as well as the bias (if any) and the representability of infinity,
  NaN, and signed zero:
  * `float8_e3m4`
  * `float8_e4m3`
  * `float8_e4m3b11fnuz`
  * `float8_e4m3fn`
  * `float8_e4m3fnuz`
  * `float8_e5m2`
  * `float8_e5m2fnuz`
  * `float8_e8m0fnu`
- Microscaling (MX) sub-byte floating point representations:
  * `float4_e2m1fn`
  * `float6_e2m3fn`
  * `float6_e3m2fn`
- Narrow integer encodings:
  * `int2`
  * `int4`
  * `uint2`
  * `uint4`

See below for specifications of these number formats.

## Installation

The `ml_dtypes` package is tested with Python versions 3.9-3.12, and can be installed
with the following command:
```
pip install ml_dtypes
```
To test your installation, you can run the following:
```
pip install absl-py pytest
pytest --pyargs ml_dtypes
```
To build from source, clone the repository and run:
```
git submodule init
git submodule update
pip install .
```

## Example Usage

```python
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
```
Importing `ml_dtypes` also registers the data types with NumPy, so that they may
be referred to by their string names:

```python
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
```

## Specifications of implemented floating point formats

### `bfloat16`

A `bfloat16` number is a single-precision float truncated at 16 bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
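
For illustration, the truncation relationship can be seen in the bit patterns. The sketch below assumes NumPy's `.view()` works on `ml_dtypes` arrays the same way it does on built-in two-byte dtypes:

```python
>>> import numpy as np
>>> from ml_dtypes import bfloat16
>>> hex(np.array([1.5], dtype=np.float32).view(np.uint32)[0])  # full float32 bit pattern
'0x3fc00000'
>>> hex(np.array([1.5], dtype=bfloat16).view(np.uint16)[0])    # upper 16 bits only
'0x3fc0'
```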

### `float4_e2m1fn`

Exponent: 2, Mantissa: 1, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 4 bits (encoding: `0bSEEM`) using byte storage (higher 4
bits are unused). NaN representation is undefined.

Possible absolute values: [`0`, `0.5`, `1`, `1.5`, `2`, `3`, `4`, `6`]

### `float6_e2m3fn`

Exponent: 2, Mantissa: 3, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEMMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [`-7.5`; `7.5`]

### `float6_e3m2fn`

Exponent: 3, Mantissa: 2, bias: 3.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEEMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [`-28`; `28`]
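
The extreme values quoted above can be checked programmatically. A quick sketch, assuming `ml_dtypes.finfo` supports the MX formats in your installed version:

```python
>>> import ml_dtypes
>>> for name in ["float4_e2m1fn", "float6_e2m3fn", "float6_e3m2fn"]:
...     print(name, float(ml_dtypes.finfo(getattr(ml_dtypes, name)).max))
...
float4_e2m1fn 6.0
float6_e2m3fn 7.5
float6_e3m2fn 28.0
```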

### `float8_e3m4`

Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf.

### `float8_e4m3`

Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.

### `float8_e4m3b11fnuz`

Exponent: 4, Mantissa: 3, bias: 11.

Extended range: no inf, NaN represented by 0b1000'0000.

### `float8_e4m3fn`

Exponent: 4, Mantissa: 3, bias: 7.

Extended range: no inf, NaN represented by 0bS111'1111.

The `fn` suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type is not consistent with IEEE 754. The `f` indicates it holds finite values only; the `n` indicates it includes NaNs, but only at the outer end of the range.
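
Reusing the would-be infinity encodings for finite values gives `float8_e4m3fn` a larger maximum than the IEEE-style `float8_e4m3`. A short comparison sketch, again assuming `ml_dtypes.finfo` covers both types in your version:

```python
>>> import ml_dtypes
>>> float(ml_dtypes.finfo(ml_dtypes.float8_e4m3).max)    # IEEE style: inf/NaN exponent reserved
240.0
>>> float(ml_dtypes.finfo(ml_dtypes.float8_e4m3fn).max)  # fn: only one NaN mantissa pattern
448.0
```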

### `float8_e4m3fnuz`

8-bit floating point with a 3-bit mantissa.

An 8-bit floating point type with 1 sign bit, a 4-bit exponent, and a 3-bit mantissa. The suffix `fnuz` is consistent with LLVM/MLIR naming and is derived from the ways this type differs from IEEE floating point conventions: `F` is for "finite" (no infinities), `N` is for a special NaN encoding, and `UZ` is for unsigned zero (illustrated in the sketch after the list below).

This type has the following characteristics:
 * bit encoding: S1E4M3 - `0bSEEEEMMM`
 * exponent bias: 8
 * infinities: not supported
 * NaN: supported, with the sign bit set to 1 and all exponent and mantissa bits set to 0 - `0b10000000`
 * denormals when the exponent field is 0
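
As a quick illustration of the unsigned-zero behavior: negative zero has no encoding in this format, so it round-trips to positive zero. A small sketch (the behavior follows from the definition above):

```python
>>> import numpy as np
>>> from ml_dtypes import float8_e4m3fnuz
>>> z = np.array([-0.0], dtype=np.float32).astype(float8_e4m3fnuz)
>>> np.signbit(z.astype(np.float32))  # the sign of zero is not preserved
array([False])
```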

### `float8_e5m2`

Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.

### `float8_e5m2fnuz`

8-bit floating point with a 2-bit mantissa.

An 8-bit floating point type with 1 sign bit, a 5-bit exponent, and a 2-bit mantissa. The suffix `fnuz` is consistent with LLVM/MLIR naming and is derived from the ways this type differs from IEEE floating point conventions: `F` is for "finite" (no infinities), `N` is for a special NaN encoding, and `UZ` is for unsigned zero.

This type has the following characteristics:
 * bit encoding: S1E5M2 - `0bSEEEEEMM`
 * exponent bias: 16
 * infinities: not supported
 * NaN: supported, with the sign bit set to 1 and all exponent and mantissa bits set to 0 - `0b10000000`
 * denormals when the exponent field is 0

### `float8_e8m0fnu`

[OpenCompute MX](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
scale format E8M0, which has the following properties:
  * Unsigned format
  * 8 exponent bits
  * Exponent range from -127 to 127
  * No zero or infinity
  * Single NaN value (0xFF)
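
Because the format is exponent-only, every representable value is a power of two between 2^-127 and 2^127. A brief check, assuming `ml_dtypes.finfo` supports `float8_e8m0fnu` in your version:

```python
>>> import ml_dtypes
>>> info = ml_dtypes.finfo(ml_dtypes.float8_e8m0fnu)
>>> float(info.max) == 2.0 ** 127, float(info.smallest_normal) == 2.0 ** -127
(True, True)
```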

## `int2`, `int4`, `uint2` and `uint4`

2- and 4-bit integer types, where each element is represented unpacked (i.e.,
padded up to a byte in memory).

NumPy does not support types smaller than a single byte: for example, the
distance between adjacent elements in an array (`.strides`) is expressed as
an integer number of bytes. Relaxing this restriction would be a considerable
engineering project. These types therefore use an unpacked representation, in which
each element of the array is padded up to a byte in memory. The lower two or four
bits of each byte contain the representation of the number, while the remaining
upper bits are ignored.
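
The unpacked layout is easy to observe, since each element occupies a full byte. A small sketch:

```python
>>> import numpy as np
>>> from ml_dtypes import int4
>>> np.dtype(int4).itemsize
1
>>> np.zeros(4, dtype=int4).tobytes()  # four elements -> four bytes of storage
b'\x00\x00\x00\x00'
```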

## Quirks of low-precision arithmetic

If you're exploring the use of low-precision dtypes in your code, you should be
careful to anticipate when the precision loss might lead to surprising results.
One example is the behavior of aggregations like `sum`; consider this `bfloat16`
summation in NumPy (run with version 1.24.2):

```python
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
```
The true sum should be close to 5000, but NumPy returns exactly 256: this is
because `bfloat16` does not have the precision to increment `256` by values of
`1` or smaller:

```python
>>> bfloat16(256) + bfloat16(1)
256
```
After 256, the next representable value in bfloat16 is 258:

```python
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
```
For better results you can specify that the accumulation should happen in a
higher-precision type like `float32`:

```python
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
```
In contrast to NumPy, projects like [JAX](http://jax.readthedocs.io/), which support
low-precision arithmetic more natively, will often perform these kinds of higher-precision
accumulations automatically:

```python
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
```

## License

*This is not an officially supported Google product.*

The `ml_dtypes` source code is licensed under the Apache 2.0 license
(see [LICENSE](LICENSE)). Pre-compiled wheels are built with the
[EIGEN](https://eigen.tuxfamily.org/) project, which is released under the
MPL 2.0 license (see [LICENSE.eigen](LICENSE.eigen)).