File: test_config_mixin.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (86 lines) | stat: -rw-r--r-- 2,113 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from dataclasses import asdict, dataclass, is_dataclass

import torch

from torch_geometric.config_mixin import ConfigMixin
from torch_geometric.config_store import clear_config_store, register


def teardown_function() -> None:
    """Clear the config store after each test function in this module.

    pytest invokes a module-level ``teardown_function`` after every test
    function, so config-store state does not carry over between tests.
    """
    clear_config_store()


@dataclass()
class Dataclass:
    """Small two-field dataclass passed as the ``data`` argument of
    ``Module`` to exercise nested dataclasses inside a config."""

    x: int  # first positional field
    y: int  # second positional field


class Base(torch.nn.Module, ConfigMixin):
    """Shared parent combining ``torch.nn.Module`` with ``ConfigMixin``.

    The test calls ``Base.from_config`` to check that a config can be
    instantiated through a parent class, not only the registered subclass.
    """


@register(with_target=True)
class Module(Base):
    """Registered ``Base`` subclass whose constructor arguments are what
    the test expects to find in ``config()`` (including ``_target_``).

    Args:
        x: Plain integer argument, stored as ``self.x``.
        data: Nested dataclass argument, stored as ``self.data``.
    """
    def __init__(self, x: int, data: Dataclass):
        super().__init__()
        # Assigned left to right: ``x`` first, then ``data``.
        self.x, self.data = x, data


def _check_model(
    model: Module,
    x: int,
    data_x: int,
    data_y: int,
    data_type: type = Dataclass,
) -> None:
    """Assert ``model`` is a ``Module`` holding the expected arguments.

    Args:
        model: The object returned by a ``from_config`` call.
        x: Expected value of ``model.x``.
        data_x: Expected ``x`` entry of ``model.data``.
        data_y: Expected ``y`` entry of ``model.data``.
        data_type: Expected type of ``model.data``. Use ``Dataclass`` for
            dataclass configs and ``dict`` for configs converted with
            ``asdict`` (the nested value is then expected to stay a dict).
    """
    assert isinstance(model, Module)
    assert model.x == x
    assert isinstance(model.data, data_type)
    if data_type is dict:
        assert model.data['x'] == data_x
        assert model.data['y'] == data_y
    else:
        assert model.data.x == data_x
        assert model.data.y == data_y


def test_config_mixin() -> None:
    """Round-trip ``Module`` through ``config()`` and ``from_config()``."""
    x = 0
    data = Dataclass(x=1, y=2)

    model = Module(x, data)
    cfg = model.config()
    assert is_dataclass(cfg)
    assert cfg.x == 0
    assert isinstance(cfg.data, Dataclass)
    assert cfg.data.x == 1
    assert cfg.data.y == 2
    # Derive the expected target from this module's import name instead
    # of hard-coding 'test_config_mixin': the name depends on how pytest
    # imports the file (rootdir/package layout, ``--import-mode``).
    assert cfg._target_ == f'{__name__}.Module'

    # Both the registered class and its parent can rebuild the model; the
    # parent presumably resolves ``Module`` through ``_target_``.
    model = Module.from_config(cfg)
    _check_model(model, x=0, data_x=1, data_y=2)

    model = Base.from_config(cfg)
    _check_model(model, x=0, data_x=1, data_y=2)

    # Extra positional / keyword arguments override stored config values.
    model = Base.from_config(cfg, 3)
    _check_model(model, x=3, data_x=1, data_y=2)

    model = Base.from_config(cfg, data=Dataclass(x=2, y=3))
    _check_model(model, x=0, data_x=2, data_y=3)

    # A plain-dict config is accepted too; nested values stay dicts.
    cfg = asdict(cfg)

    model = Module.from_config(cfg)
    _check_model(model, x=0, data_x=1, data_y=2, data_type=dict)

    model = Base.from_config(cfg)
    _check_model(model, x=0, data_x=1, data_y=2, data_type=dict)