File: uses_torch.py

package info (click to toggle)
python-array-api-compat 1.11.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 708 kB
  • sloc: python: 3,954; sh: 16; makefile: 15
file content (30 lines) | stat: -rw-r--r-- 966 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Basic test that vendoring works

from .vendored._compat import (
    is_torch_array,
    is_torch_namespace,
    torch as torch_compat,
)

import torch

def _test_torch():
    """Smoke-test the vendored ``array_api_compat`` torch wrapper.

    Creates tensors through the vendored ``torch`` compat namespace and checks:

    - their dtypes match the corresponding plain-``torch`` dtypes;
    - a compat-only function (``expand_dims``) produces the expected shape and
      values;
    - the results are ``torch.Tensor`` instances;
    - the vendored helper predicates recognise the result tensor and both
      namespaces.

    Raises ``AssertionError`` if any check fails.
    """
    floats = torch_compat.asarray([1., 2., 3.])
    doubles = torch_compat.arange(3, dtype=torch_compat.float64)

    # Each compat dtype must compare equal to the native torch dtype.
    # NOTE(review): the float32 expectation for ``floats`` relies on torch's
    # default floating dtype being float32 — confirm nothing changes it.
    expected_dtypes = (
        (floats, torch_compat.float32, torch.float32),
        (doubles, torch_compat.float64, torch.float64),
    )
    for tensor, compat_dtype, native_dtype in expected_dtypes:
        assert tensor.dtype == compat_dtype == native_dtype

    # Plain torch has no ``expand_dims``, so this exercises a function that
    # only the compat layer supplies. If torch ever adds it, switch this to
    # some other compat-only function.
    expanded = torch_compat.expand_dims(floats, axis=0)
    assert expanded.dtype == torch_compat.float32 == torch.float32
    assert expanded.shape == (1, 3)
    assert isinstance(expanded.shape, torch.Size)

    # Compat functions must hand back torch tensors, not some other type.
    for tensor in (floats, doubles, expanded):
        assert isinstance(tensor, torch.Tensor)

    torch.testing.assert_close(expanded, torch.as_tensor([[1., 2., 3.]]))

    # The vendored introspection helpers must accept both the tensor and the
    # native / compat namespaces.
    assert is_torch_array(expanded)
    assert is_torch_namespace(torch)
    assert is_torch_namespace(torch_compat)