File: test_3259_to_torch_from_torch.py

package info (click to toggle)
python-awkward 2.8.9-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 24,932 kB
  • sloc: python: 178,875; cpp: 33,828; sh: 432; makefile: 21; javascript: 8
file content (72 lines) | stat: -rw-r--r-- 2,296 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np
import pytest

import awkward as ak

to_torch = ak.operations.to_torch
from_torch = ak.operations.from_torch

torch = pytest.importorskip("torch")

a = np.arange(2 * 2 * 2, dtype=np.float64).reshape(2, 2, 2)
b = np.arange(2 * 2 * 2).reshape(2, 2, 2)

array = np.arange(2 * 3 * 5).reshape(2, 3, 5)
content2 = ak.contents.NumpyArray(array.reshape(-1))
inneroffsets = ak.index.Index64(np.array([0, 5, 10, 15, 20, 25, 30]))
outeroffsets = ak.index.Index64(np.array([0, 3, 6]))


def test_to_torch():
    # a basic test for a 4 dimensional array
    array1 = ak.Array([a, b])
    i = 0
    for sub_array in [
        [[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]],
        [[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]],
    ]:
        assert to_torch(array1)[i].tolist() == sub_array
        i += 1

    # test that the data types are remaining the same (float64 in this case)
    assert array1.layout.to_backend_array().dtype.name in str(to_torch(array1).dtype)

    # try a listoffset array inside a listoffset array
    array2 = ak.contents.ListOffsetArray(
        outeroffsets, ak.contents.ListOffsetArray(inneroffsets, content2)
    )
    assert to_torch(array2)[0].tolist() == [
        [0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9],
        [10, 11, 12, 13, 14],
    ]
    assert to_torch(array2)[1].tolist() == [
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
    ]

    # try just a python list
    array3 = [3, 1, 4, 1, 9, 2, 6]
    assert to_torch(array3).tolist() == [3, 1, 4, 1, 9, 2, 6]


array1 = torch.tensor([[1.0, -1.0], [1.0, -1.0]], dtype=torch.float32)
array2 = torch.tensor(np.array([[1, 2, 3], [4, 5, 6]]))


def test_from_torch():
    # Awkward.to_list() == Tensor.tolist()
    assert from_torch(array1).to_list() == array1.tolist()

    assert from_torch(array2).to_list() == array2.tolist()

    # test that the data types are remaining the same (int64 in this case)
    assert from_torch(array1).layout.dtype.name in str(array1.dtype)

    # test that the data types are remaining the same (float32 in this case)
    assert from_torch(array2).layout.dtype.name in str(array2.dtype)