# Basic test that vendoring works: the vendored `_compat` torch namespace
# must be importable and usable.
from .vendored._compat import (
is_torch_array,
is_torch_namespace,
torch as torch_compat,
)
import torch
def _test_torch():
    """Smoke-test the vendored ``_compat`` torch namespace.

    Creates tensors through the compat wrapper and checks that the dtypes it
    exposes compare equal to plain ``torch`` dtypes. It then calls
    ``expand_dims``, which the wrapper provides even though torch itself does
    not. Finally it confirms that the vendored ``is_torch_array`` /
    ``is_torch_namespace`` helpers recognise the results.
    """
    f32, f64 = torch.float32, torch.float64

    vec = torch_compat.asarray([1., 2., 3.])
    rng = torch_compat.arange(3, dtype=torch_compat.float64)

    # The compat namespace must expose dtypes equal to torch's own.
    assert f32 == torch_compat.float32 == vec.dtype
    assert f64 == torch_compat.float64 == rng.dtype

    # expand_dims is chosen deliberately: torch has no expand_dims, so this
    # call only works through the compat layer. Switch to another function
    # if torch ever adds one.
    expanded = torch_compat.expand_dims(vec, axis=0)
    assert f32 == torch_compat.float32 == expanded.dtype
    assert expanded.shape == (1, 3)
    assert isinstance(expanded.shape, torch.Size)

    # Everything produced above should be a torch.Tensor.
    for tensor in (vec, rng, expanded):
        assert isinstance(tensor, torch.Tensor)

    torch.testing.assert_close(expanded, torch.as_tensor([[1., 2., 3.]]))

    # The vendored detection helpers must recognise the tensor and both
    # the raw and the compat namespace.
    assert is_torch_array(expanded)
    assert is_torch_namespace(torch)
    assert is_torch_namespace(torch_compat)