File: test_benchmark.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (22 lines) | stat: -rw-r--r-- 491 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch

from torch_geometric.profile import benchmark
from torch_geometric.testing import withPackage


@withPackage('tabulate')
def test_benchmark(capfd):
    """Check that `benchmark` prints a timing table for a simple function.

    A trivial element-wise addition is benchmarked for a single step (plus
    one warm-up step) with the backward pass enabled, and the text captured
    from stdout is inspected for the table header and the row of the
    benchmarked function.

    Args:
        capfd: pytest fixture capturing output written to file descriptors
            1 and 2.
    """
    # NOTE: the nested function must stay named `add`; the assertions below
    # look for an `add` row in the printed table.
    def add(x, y):
        return x + y

    lhs = torch.randn(10)
    rhs = torch.randn(10)

    benchmark(
        funcs=[add],
        args=(lhs, rhs),
        num_steps=1,
        num_warmups=1,
        backward=True,
    )

    printed = capfd.readouterr().out

    expected_fragments = (
        '| Name   | Forward   | Backward   | Total   |',  # Header row.
        '| add    |',  # Row of the benchmarked function.
    )
    for fragment in expected_fragments:
        assert fragment in printed