1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
|
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
# Test that finfo methods on FloatFormat agree with numpy/ml_dtypes
import ml_dtypes
import numpy as np
import pytest
from gfloat import decode_float, round_float
from gfloat.formats import *
@pytest.mark.parametrize(
"fmt,npfmt",
[
(format_info_ocp_e5m2, ml_dtypes.float8_e5m2),
(format_info_ocp_e4m3, ml_dtypes.float8_e4m3fn),
(format_info_binary16, np.float16),
(format_info_bfloat16, ml_dtypes.bfloat16),
],
)
def test_finfo(fmt: FormatInfo, npfmt: np.dtype) -> None:
assert fmt.eps == ml_dtypes.finfo(npfmt).eps
assert fmt.epsneg == ml_dtypes.finfo(npfmt).epsneg
assert fmt.max == ml_dtypes.finfo(npfmt).max
assert fmt.maxexp == ml_dtypes.finfo(npfmt).maxexp
def test_constants() -> None:
assert format_info_p3109(1).smallest_subnormal == 2.0**-62
assert format_info_p3109(4).smallest_subnormal == 2.0**-10
assert format_info_p3109(7).smallest_subnormal == 2.0**-6
|