File: test_noise_scheduler.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (34 lines) | stat: -rw-r--r-- 830 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import pytest
import torch

from torch_geometric.utils.noise_scheduler import (
    get_diffusion_beta_schedule,
    get_smld_sigma_schedule,
)


def test_get_smld_sigma_schedule():
    """Compare the SMLD sigma schedule with hard-coded reference values.

    The schedule is requested from ``sigma_max=1.0`` down to
    ``sigma_min=0.01`` over ten scales. Successive reference values shrink
    by a constant ratio (geometric spacing), starting at 1.0 and ending at
    0.01. The comparison uses ``torch.allclose`` with its default tolerances.
    """
    sigmas = get_smld_sigma_schedule(
        num_scales=10,
        sigma_max=1.0,
        sigma_min=0.01,
    )
    reference = torch.tensor([
        1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
        0.04641589, 0.02782559, 0.01668101, 0.01
    ])
    assert torch.allclose(sigmas, reference)


@pytest.mark.parametrize('schedule_type', [
    'linear',
    'quadratic',
    'constant',
    'sigmoid',
])
def test_get_diffusion_beta_schedule(schedule_type):
    """Check that every supported beta schedule has one entry per timestep.

    Only the output shape is checked: a 1-D tensor whose length equals
    ``num_diffusion_timesteps``. The individual beta values are not
    inspected here.
    """
    num_steps = 10
    betas = get_diffusion_beta_schedule(
        schedule_type,
        beta_start=0.1,
        beta_end=0.2,
        num_diffusion_timesteps=num_steps,
    )
    assert betas.size() == (num_steps, )