File: test_permute.py

package info (click to toggle)
pytorch-sparse 0.6.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 984 kB
  • sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
file content (17 lines) | stat: -rw-r--r-- 580 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import pytest
import torch

from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, tensor


@pytest.mark.parametrize('device', devices)
def test_permute(device):
    """Check that ``SparseTensor.permute`` reorders rows and columns alike.

    For a square matrix ``A`` and a permutation ``perm``, the expected result
    is ``A[perm][:, perm]``: entry ``(i, j)`` of the output is entry
    ``(perm[i], perm[j])`` of the input.
    """
    row, col = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], torch.long, device)
    value = tensor([1, 2, 3, 4, 5], torch.float, device)
    adj = SparseTensor(row=row, col=col, value=value)

    # Keep the permutation on the same device as ``adj``, like every other
    # input in this test, so the parametrized device is exercised end to end.
    perm = tensor([1, 0, 2], torch.long, device)
    out = adj.permute(perm)

    row, col, value = out.coo()
    assert row.tolist() == [0, 1, 1, 2, 2]
    assert col.tolist() == [1, 0, 1, 0, 2]
    assert value.tolist() == [3, 2, 1, 4, 5]

    # Cross-check against the dense definition of a symmetric permutation.
    assert torch.equal(out.to_dense(), adj.to_dense()[perm][:, perm])