File: test_encoding.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (22 lines) | stat: -rw-r--r-- 623 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch

from torch_geometric.nn import PositionalEncoding, TemporalEncoding
from torch_geometric.testing import withDevice


@withDevice
def test_positional_encoding(device):
    """Check the repr and output shape of ``PositionalEncoding`` on ``device``.

    Three scalar inputs should each be mapped to a 64-dim encoding vector.
    """
    inputs = torch.tensor([1.0, 2.0, 3.0], device=device)
    module = PositionalEncoding(64, device=device)

    # The repr exposes only the number of output channels.
    assert str(module) == 'PositionalEncoding(64)'

    encoded = module(inputs)
    assert encoded.size() == (3, 64)


@withDevice
def test_temporal_encoding(device):
    """Check the repr and output shape of ``TemporalEncoding`` on ``device``.

    Three scalar inputs should each be mapped to a 64-dim encoding vector.
    """
    inputs = torch.tensor([1.0, 2.0, 3.0], device=device)
    module = TemporalEncoding(64, device=device)

    # The repr exposes only the number of output channels.
    assert str(module) == 'TemporalEncoding(64)'

    encoded = module(inputs)
    assert encoded.size() == (3, 64)