# ml_dtypes

[![Unittests](https://github.com/jax-ml/ml_dtypes/actions/workflows/test.yml/badge.svg)](https://github.com/jax-ml/ml_dtypes/actions/workflows/test.yml)
[![Wheel Build](https://github.com/jax-ml/ml_dtypes/actions/workflows/wheels.yml/badge.svg)](https://github.com/jax-ml/ml_dtypes/actions/workflows/wheels.yml)
[![PyPI version](https://badge.fury.io/py/ml_dtypes.svg)](https://badge.fury.io/py/ml_dtypes)

`ml_dtypes` is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:

- [`bfloat16`](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format):
  an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format
- 8-bit floating point representations, parameterized by the number of exponent
  and mantissa bits, as well as the bias (if any) and the representability of
  infinity, NaN, and signed zero:
  * `float8_e3m4`
  * `float8_e4m3`
  * `float8_e4m3b11fnuz`
  * `float8_e4m3fn`
  * `float8_e4m3fnuz`
  * `float8_e5m2`
  * `float8_e5m2fnuz`
  * `float8_e8m0fnu`
- Microscaling (MX) sub-byte floating point representations:
  * `float4_e2m1fn`
  * `float6_e2m3fn`
  * `float6_e3m2fn`
- Narrow integer encodings:
  * `int2`
  * `int4`
  * `uint2`
  * `uint4`

See below for specifications of these number formats.

## Installation

The `ml_dtypes` package is tested with Python versions 3.9-3.12, and can be
installed with the following command:
```
pip install ml_dtypes
```
To test your installation, you can run the following:
```
pip install absl-py pytest
pytest --pyargs ml_dtypes
```
To build from source, clone the repository and run:
```
git submodule init
git submodule update
pip install .
```

## Example Usage

```python
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
```
Importing `ml_dtypes` also registers the data types with NumPy, so that they may
be referred to by their string name:

```python
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
```
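The extended types interoperate with the standard NumPy dtypes through ordinary
casting, and the package also provides `finfo` and `iinfo` helpers for querying
the parameters of each format, analogous to `np.finfo` and `np.iinfo`. Below is
a minimal sketch; it assumes `ml_dtypes.finfo` exposes the usual NumPy
attributes (`max`, `eps`, `smallest_normal`), and the values noted in the
comments are those implied by the format definition rather than captured
output:

```python
import numpy as np
import ml_dtypes

# Round-trip between float32 and bfloat16 via ordinary casting.
x = np.linspace(0.0, 1.0, 5, dtype=np.float32)
x_bf16 = x.astype(ml_dtypes.bfloat16)  # lossy: bfloat16 keeps 8 significand bits
x_back = x_bf16.astype(np.float32)

# Query the parameters of a format, analogous to np.finfo.
info = ml_dtypes.finfo(ml_dtypes.float8_e5m2)
print(info.max, info.eps, info.smallest_normal)
# For e5m2 these should be 57344.0, 0.25, and 2**-14 respectively.
```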
## Specifications of implemented floating point formats

### `bfloat16`

A `bfloat16` number is a single-precision float truncated at 16 bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.

### `float4_e2m1fn`

Exponent: 2, Mantissa: 1, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 4 bits (encoding: `0bSEEM`) using byte storage (upper 4
bits are unused). NaN representation is undefined.

Possible absolute values: [`0`, `0.5`, `1`, `1.5`, `2`, `3`, `4`, `6`]

### `float6_e2m3fn`

Exponent: 2, Mantissa: 3, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEMMM`) using byte storage (upper 2
bits are unused). NaN representation is undefined.

Possible values range: [`-7.5`; `7.5`]

### `float6_e3m2fn`

Exponent: 3, Mantissa: 2, bias: 3.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEEMM`) using byte storage (upper 2
bits are unused). NaN representation is undefined.

Possible values range: [`-28`; `28`]

### `float8_e3m4`

Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf.

### `float8_e4m3`

Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.

### `float8_e4m3b11fnuz`

Exponent: 4, Mantissa: 3, bias: 11.

Extended range: no inf, NaN represented by 0b1000'0000.

### `float8_e4m3fn`

Exponent: 4, Mantissa: 3, bias: 7.

Extended range: no inf, NaN represented by 0bS111'1111.

The `fn` suffix is for consistency with the corresponding LLVM/MLIR type,
signaling that this type is not consistent with IEEE 754. The `f` indicates it
holds finite values only (no infinities); the `n` indicates it includes NaN,
but only at the outer range of the encoding.

### `float8_e4m3fnuz`

8-bit floating point with a 3-bit mantissa.

An 8-bit floating point type with 1 sign bit, 4 exponent bits, and 3 mantissa
bits. The suffix `fnuz` is consistent with LLVM/MLIR naming and reflects the
ways the type departs from IEEE floating point conventions: `f` for finite
values only (no infinities), `n` for a single special NaN encoding, and `uz`
for unsigned zero.

This type has the following characteristics:
 * bit encoding: S1E4M3 - `0bSEEEEMMM`
 * exponent bias: 8
 * infinities: not supported
 * NaN: supported, with the sign bit set to 1 and all exponent and mantissa
   bits set to 0 - `0b10000000`
 * denormals when the exponent is 0

### `float8_e5m2`

Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.

### `float8_e5m2fnuz`

8-bit floating point with a 2-bit mantissa.

An 8-bit floating point type with 1 sign bit, 5 exponent bits, and 2 mantissa
bits. The suffix `fnuz` is consistent with LLVM/MLIR naming and reflects the
ways the type departs from IEEE floating point conventions: `f` for finite
values only (no infinities), `n` for a single special NaN encoding, and `uz`
for unsigned zero.

This type has the following characteristics:
 * bit encoding: S1E5M2 - `0bSEEEEEMM`
 * exponent bias: 16
 * infinities: not supported
 * NaN: supported, with the sign bit set to 1 and all exponent and mantissa
   bits set to 0 - `0b10000000`
 * denormals when the exponent is 0

### `float8_e8m0fnu`

[OpenCompute MX](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
scale format E8M0, which has the following properties:
 * Unsigned format
 * 8 exponent bits
 * Exponent range from -127 to 127
 * No zero or infinity
 * Single NaN value (0xFF).

## `int2`, `int4`, `uint2` and `uint4`

2- and 4-bit integer types, where each element is represented unpacked (i.e.,
padded up to a byte in memory).

NumPy does not support types smaller than a single byte: for example, the
distance between adjacent elements in an array (`.strides`) is expressed as an
integer number of bytes. Relaxing this restriction would be a considerable
engineering project. These types therefore use an unpacked representation, in
which each element of the array is padded up to a byte in memory. The lower two
or four bits of each byte contain the representation of the number, and the
remaining upper bits are ignored.
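As a short illustration of this unpacked layout (a sketch only; it assumes
`ml_dtypes.iinfo` mirrors `np.iinfo` for the narrow integer types):

```python
import numpy as np
import ml_dtypes
from ml_dtypes import int4, uint4

# Each element occupies a full byte even though only 4 bits carry the value.
x = np.array([-8, -1, 0, 7], dtype=int4)
print(x.dtype.itemsize)  # 1 byte per element
print(x.strides)         # (1,): adjacent elements are one byte apart

# iinfo reports the representable range of the narrow integer types.
print(ml_dtypes.iinfo(int4).min, ml_dtypes.iinfo(int4).max)    # -8 7
print(ml_dtypes.iinfo(uint4).min, ml_dtypes.iinfo(uint4).max)  # 0 15
```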
## Quirks of low-precision Arithmetic

If you're exploring the use of low-precision dtypes in your code, you should be
careful to anticipate when the precision loss might lead to surprising results.
One example is the behavior of aggregations like `sum`; consider this `bfloat16`
summation in NumPy (run with version 1.24.2):

```python
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
```
The true sum should be close to 5000, but NumPy returns exactly 256: this is
because `bfloat16` does not have the precision to increment `256` by values less
than `1`:

```python
>>> bfloat16(256) + bfloat16(1)
256
```
After 256, the next representable value in bfloat16 is 258:

```python
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
```
For better results you can specify that the accumulation should happen in a
higher-precision type like `float32`:

```python
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
```
In contrast to NumPy, projects like [JAX](http://jax.readthedocs.io/), which
support low-precision arithmetic more natively, will often do these kinds of
higher-precision accumulations automatically:

```python
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
```

## License

*This is not an officially supported Google product.*

The `ml_dtypes` source code is licensed under the Apache 2.0 license
(see [LICENSE](LICENSE)). Pre-compiled wheels are built with the
[EIGEN](https://eigen.tuxfamily.org/) project, which is released under the
MPL 2.0 license (see [LICENSE.eigen](LICENSE.eigen)).