# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
import pytest
import numpy as np
from numpy.typing import NDArray
import torch
from mx.mx_ops import quantize_mx_op
from mx.formats import ElemFormat
from gfloat import (
BlockFormatInfo,
RoundMode,
quantize_block,
compute_scale_amax,
encode_block,
)
from gfloat.formats import *
@pytest.mark.parametrize(
    "mx_round,gf_round",
    [
        ("even", RoundMode.TiesToEven),
        ("nearest", RoundMode.TiesToAway),
    ],
)
@pytest.mark.parametrize(
    "mx_etype,gf_etype",
    [
        (ElemFormat.int8, format_info_ocp_int8),
        (ElemFormat.fp6_e3m2, format_info_ocp_e3m2),
        (ElemFormat.fp4_e2m1, format_info_ocp_e2m1),
    ],
)
@pytest.mark.parametrize(
    "A",
    [
        pytest.param(np.arange(32) / 2 - 5, id="tennish"),
        pytest.param(np.zeros(32), id="zeros"),
    ],
)
def test_mx(
    mx_etype: ElemFormat,
    gf_etype: FormatInfo,
    mx_round: str,
    gf_round: RoundMode,
    A: NDArray[np.float64],
) -> None:
    """Check that ``quantize_block`` agrees with ``quantize_mx_op``.

    The same 32-element array is quantized once through the ``mx`` package
    and once through gfloat, and the two results are compared with
    ``np.testing.assert_allclose``.

    Args:
        mx_etype: Element format passed to ``quantize_mx_op``.
        gf_etype: gfloat element format paired with ``mx_etype``.
        mx_round: Rounding mode name passed to ``quantize_mx_op``.
        gf_round: gfloat rounding mode paired with ``mx_round``.
        A: Input values; each parametrized case supplies 32 float64 values.
    """
    # MX side: block configuration, then quantize along axis 0.
    mx_config = {
        "block_size": 32,
        "scale_bits": 8,
        "shared_exp_method": "max",
        "mx_flush_fp32_subnorms": False,
        "custom_cuda": False,
    }
    mx_input = torch.tensor(A)
    expected = quantize_mx_op(mx_input, mx_config, mx_etype, axes=0, round=mx_round)

    # GFloat side: a 32-element block format using the e8m0 scale format.
    block_format = BlockFormatInfo("test", gf_etype, 32, format_info_ocp_e8m0)
    actual = quantize_block(block_format, A, compute_scale_amax, gf_round)

    # Both libraries should produce the same dequantized values.
    np.testing.assert_allclose(actual, expected)
def test_mx_exceptions() -> None:
    """Check the amax scale for tiny inputs and the out-of-range scale errors.

    Verifies that ``compute_scale_amax`` returns ``2.0**-127`` for a block
    filled with ``2.0**-139``, and that ``encode_block`` raises a
    ``ValueError`` matching "out of range" when given a scale of twice
    ``stype.max`` or half ``stype.min``.
    """
    block_format = BlockFormatInfo(
        "test", format_info_ocp_e2m1, 32, format_info_ocp_e8m0
    )
    tiny_values = np.full(32, 2.0**-139)

    # NOTE(review): 2.0**-127 is presumably the smallest scale the scale
    # format can represent -- confirm against compute_scale_amax.
    amax_scale = compute_scale_amax(block_format.etype.emax, tiny_values)
    assert amax_scale == 2.0**-127

    # A scale above the scale format's maximum must be rejected.
    # list() consumes the result of encode_block inside the raises context.
    with pytest.raises(ValueError, match="out of range"):
        list(encode_block(block_format, block_format.stype.max * 2, tiny_values))

    # A non-zero scale below the scale format's minimum must be rejected too.
    assert not block_format.stype.is_signed
    too_small = block_format.stype.min / 2
    assert too_small != 0
    with pytest.raises(ValueError, match="out of range"):
        list(encode_block(block_format, too_small, tiny_values))
|