File: test_isdtype.py

package info (click to toggle)
python-array-api-compat 1.11.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 708 kB
  • sloc: python: 3,954; sh: 16; makefile: 15
file content (115 lines) | stat: -rw-r--r-- 3,583 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
isdtype is not yet tested in the test suite, and it should extend properly to
non-spec dtypes
"""

import pytest

from ._helpers import import_, wrapped_libraries

# Check the known dtypes by their string names

def _spec_dtypes(library):
    """Return the set of dtype names exercised for ``library``.

    A new set is built on every call. For ``'torch'`` the names
    ``uint16``, ``uint32`` and ``uint64`` are left out, so ``uint8`` is the
    only unsigned integer name in that set; every other library gets the
    full set.
    """
    names = {
        'bool',
        'int8',
        'int16',
        'int32',
        'int64',
        'uint8',
        'uint16',
        'uint32',
        'uint64',
        'float32',
        'float64',
        'complex64',
        'complex128',
    }
    if library == 'torch':
        # The torch set omits the wider unsigned integer names.
        names -= {'uint16', 'uint32', 'uint64'}
    return names

# Maps each category string used as a ``kind`` in these tests to a predicate
# over dtype *names* (plain strings). Every predicate returns a real ``bool``.
dtype_categories = {
    'bool': lambda name: name == 'bool',
    'signed integer': lambda name: name.startswith('int'),
    'unsigned integer': lambda name: name.startswith('uint'),
    # Signed or unsigned: either prefix qualifies.
    'integral': lambda name: name.startswith(('int', 'uint')),
    # Substring match rather than prefix, so a name like 'bfloat16' counts.
    'real floating': lambda name: 'float' in name,
    'complex floating': lambda name: name.startswith('complex'),
    # Union of the integral, real floating and complex floating categories.
    'numeric': lambda name: any(
        dtype_categories[cat](name)
        for cat in ('integral', 'real floating', 'complex floating')
    ),
}

def isdtype_(dtype_, kind):
    """Compute the expected ``isdtype`` answer from names alone.

    ``dtype_`` is a dtype name. ``kind`` is either a key of
    ``dtype_categories`` or another dtype name. A category key is answered by
    its predicate; anything else is answered by string equality. The name
    'bool' is both a category and a dtype, but the two readings give the same
    answer.
    """
    predicate = dtype_categories.get(kind)
    res = dtype_ == kind if predicate is None else predicate(dtype_)
    # Guard against a predicate handing back a truthy non-bool value.
    assert type(res) is bool #  noqa: E721
    return res

@pytest.mark.parametrize("library", wrapped_libraries)
def test_isdtype_spec_dtypes(library):
    """Check the wrapped ``isdtype`` against the name-based reference.

    For every dtype name from ``_spec_dtypes(library)`` the wrapped
    namespace's ``isdtype`` is compared with ``isdtype_`` for:

    * every other dtype from the same set used as ``kind``,
    * every category string in ``dtype_categories``,
    * every two-element tuple of dtypes and/or category strings.
    """
    xp = import_(library, wrapper=True)

    isdtype = xp.isdtype

    # Everything below is independent of the dtype under test, so compute it
    # once instead of on every loop iteration.
    spec_dtypes = _spec_dtypes(library)
    # (name, kind object) pairs for the tuple test. A name that is also a
    # category key (this includes 'bool') is passed through as the string;
    # any other name is resolved to the namespace's dtype object.
    kinds = [
        (kind_, kind_ if kind_ in dtype_categories else getattr(xp, kind_))
        for kind_ in [*spec_dtypes, *dtype_categories]
    ]

    for dtype_ in spec_dtypes:
        # Bind the dtype under test here, explicitly, rather than relying on
        # the inner loop below to leave it bound for the later checks.
        dtype = getattr(xp, dtype_)

        for dtype2_ in spec_dtypes:
            dtype2 = getattr(xp, dtype2_)
            res = isdtype_(dtype_, dtype2_)
            assert isdtype(dtype, dtype2) is res, (dtype_, dtype2_)

        for cat in dtype_categories:
            res = isdtype_(dtype_, cat)
            assert isdtype(dtype, cat) == res, (dtype_, cat)

        # Basic tuple testing (the array-api testsuite will be more complete here)
        for kind1_, kind1 in kinds:
            for kind2_, kind2 in kinds:
                kind = (kind1, kind2)

                res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_)
                assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_))

# Dtype names that appear in neither set returned by _spec_dtypes. The
# additional-dtypes test looks each one up on the wrapped namespace.
additional_dtypes = ['float16', 'float128', 'complex256', 'bfloat16']

@pytest.mark.parametrize("library", wrapped_libraries)
@pytest.mark.parametrize("dtype_", additional_dtypes)
def test_isdtype_additional_dtypes(library, dtype_):
    """Check category classification of one extra dtype for one library.

    If the wrapped namespace has no attribute named ``dtype_`` the test
    returns without asserting anything. Otherwise the namespace's ``isdtype``
    must agree with ``isdtype_`` for every key of ``dtype_categories``.
    """
    xp = import_(library, wrapper=True)
    check = xp.isdtype

    if not hasattr(xp, dtype_):
        # NOTE(review): returning here makes the case report as passed, not
        # skipped; a pytest.skip call was left disabled at this spot.
        # Confirm that this is intended.
        return

    dtype_obj = getattr(xp, dtype_)
    for category in dtype_categories:
        expected = isdtype_(dtype_, category)
        assert check(dtype_obj, category) == expected, (dtype_, category)