"""Tests for ``ConfigMixin``: exporting a module's config via ``config()``
and rebuilding the module via ``from_config()``.
"""
from dataclasses import asdict, dataclass, is_dataclass
import torch
from torch_geometric.config_mixin import ConfigMixin
from torch_geometric.config_store import clear_config_store, register
def teardown_function() -> None:
    """Empty the config store.

    Named after pytest's module-level ``teardown_function`` hook, so this is
    presumably run after each test function in this module, which keeps the
    ``@register`` side effects from leaking between tests.
    """
    clear_config_store()
@dataclass
class Dataclass:
    """Minimal two-field dataclass used as a nested constructor argument."""

    # First integer payload field.
    x: int
    # Second integer payload field.
    y: int
class Base(torch.nn.Module, ConfigMixin):
    """Common parent combining ``torch.nn.Module`` with ``ConfigMixin``.

    Adds no behavior of its own; the tests call ``Base.from_config`` to
    rebuild instances of its registered subclass.
    """
@register(with_target=True)
class Module(Base):
    """Module registered in the config store with ``with_target=True``.

    Args:
        x: Integer stored unchanged on ``self.x``.
        data: Value stored unchanged on ``self.data``.
    """

    def __init__(self, x: int, data: Dataclass):
        super().__init__()
        # Both arguments are kept verbatim so tests can read them back.
        self.data = data
        self.x = x
def test_config_mixin() -> None:
    """Check that a registered module round-trips through its config.

    Builds a ``Module``, exports its config with ``config()``, and rebuilds
    the module with ``from_config()`` called on both ``Module`` and ``Base``:
    first from the dataclass config (with and without overrides), then from
    the ``asdict`` form of that config.
    """
    def expect_dataclass_model(model, x: int, data_x: int,
                               data_y: int) -> None:
        # Rebuilt from the dataclass config: `data` is still a `Dataclass`.
        assert isinstance(model, Module)
        assert model.x == x
        assert isinstance(model.data, Dataclass)
        assert model.data.x == data_x
        assert model.data.y == data_y

    def expect_dict_model(model) -> None:
        # Rebuilt from the plain-dict config: `data` comes back as a dict.
        assert isinstance(model, Module)
        assert model.x == 0
        assert isinstance(model.data, dict)
        assert model.data['x'] == 1
        assert model.data['y'] == 2

    cfg = Module(0, Dataclass(x=1, y=2)).config()

    # The exported config mirrors the constructor arguments and records the
    # import path of the class it was created from.
    assert is_dataclass(cfg)
    assert cfg.x == 0
    assert isinstance(cfg.data, Dataclass)
    assert cfg.data.x == 1
    assert cfg.data.y == 2
    assert cfg._target_ == 'test_config_mixin.Module'

    # Without overrides, both entry points yield an equivalent `Module`.
    for cls in (Module, Base):
        expect_dataclass_model(cls.from_config(cfg), 0, 1, 2)

    # A positional override replaces `x`.
    expect_dataclass_model(Base.from_config(cfg, 3), 3, 1, 2)

    # A keyword override replaces `data`.
    overridden = Base.from_config(cfg, data=Dataclass(x=2, y=3))
    expect_dataclass_model(overridden, 0, 2, 3)

    # The same round trip starting from the config converted to a dict.
    cfg_dict = asdict(cfg)
    for cls in (Module, Base):
        expect_dict_model(cls.from_config(cfg_dict))
|