# Unit tests for `torch_geometric.inspector.Inspector`.
import inspect
from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union
import torch
from torch import Tensor
from torch_geometric.inspector import Inspector, Parameter, Signature
from torch_geometric.nn import GATConv, SAGEConv
from torch_geometric.typing import OptPairTensor
def test_eval_type() -> None:
    """Check that `Inspector.eval_type` maps type strings to typing objects."""
    # Type string handed to `eval_type` -> object it must compare equal to.
    expected_by_string = {
        'Tensor': Tensor,
        'List[Tensor]': List[Tensor],
        'Tuple[Tensor, int]': Tuple[Tensor, int],
        'Tuple[int, ...]': Tuple[int, ...],
    }

    sage_inspector = Inspector(SAGEConv)
    for type_string, expected in expected_by_string.items():
        assert sage_inspector.eval_type(type_string) == expected
def test_type_repr() -> None:
    """Check the string `Inspector.type_repr` returns for typing objects."""
    sage_inspector = Inspector(SAGEConv)

    # (type object, expected representation), checked in this order. A list
    # is used rather than a dict because some entries are equal type objects.
    cases = [
        (Any, 'typing.Any'),
        (Final, 'typing.Final'),
        (OptPairTensor, 'Tuple[Tensor, Optional[Tensor]]'),
        (Final[Optional[Tensor]], 'typing.Final[Optional[Tensor]]'),
        (Union[None, Tensor], 'Optional[Tensor]'),
        (Optional[Tensor], 'Optional[Tensor]'),
        (Set[Tensor], 'typing.Set[Tensor]'),
        (List, 'List'),
        (Tuple, 'Tuple'),
        (Set, 'typing.Set'),
        (Dict, 'typing.Dict'),
        (
            Dict[str, Tuple[Tensor, Tensor]],
            'typing.Dict[str, Tuple[Tensor, Tensor]]',
        ),
        (Tuple[int, ...], 'Tuple[int, ...]'),
        (Union[int, str, None], 'Union[int, str, None]'),
    ]

    for type_obj, expected_repr in cases:
        assert sage_inspector.type_repr(type_obj) == expected_repr
def test_inspector_sage_conv() -> None:
    """Exercise the `Inspector` API against `SAGEConv`.

    Covers the string form, `implements`, signature inspection of
    `message`, flattened parameter lookup, parameter data collection and
    the parameters reported for the `propagate` call.
    """
    sage_inspector = Inspector(SAGEConv)
    # The single `message` parameter expected throughout this test.
    x_j_param = Parameter('x_j', Tensor, 'Tensor', inspect._empty)

    assert str(sage_inspector) == 'Inspector(SAGEConv)'
    for hook in ('message', 'message_and_aggregate'):
        assert sage_inspector.implements(hook)

    signature = sage_inspector.inspect_signature(SAGEConv.message)
    assert isinstance(signature, Signature)
    assert signature.param_dict == {'x_j': x_j_param}
    assert signature.return_type == Tensor

    # Naming `message` twice is expected to yield its parameter only once.
    flat_params = sage_inspector.get_flat_params(['message', 'message'])
    assert flat_params == [x_j_param]
    assert sage_inspector.get_flat_param_names(['message']) == ['x_j']

    # `x_i` is supplied but only `x_j` is expected to be collected.
    inputs = {'x_j': torch.randn(5), 'x_i': torch.randn(5)}
    collected = sage_inspector.collect_param_data('message', inputs)
    assert len(collected) == 1
    assert torch.allclose(collected['x_j'], inputs['x_j'])

    propagate_params = sage_inspector.get_params_from_method_call(
        SAGEConv.propagate)
    assert propagate_params == {
        'x': Parameter('x', OptPairTensor, 'OptPairTensor', inspect._empty),
    }
def test_inspector_gat_conv() -> None:
    """Exercise the `Inspector` API against `GATConv`.

    Mirrors the `SAGEConv` test: string form, `implements`, signature
    inspection of `message`, flattened parameter lookup, parameter data
    collection and the parameters reported for the `propagate` call.
    """
    inspector = Inspector(GATConv)
    assert str(inspector) == 'Inspector(GATConv)'
    assert inspector.implements('message')
    assert not inspector.implements('message_and_aggregate')

    out = inspector.inspect_signature(GATConv.message)
    assert isinstance(out, Signature)
    assert out.param_dict == {
        'x_j': Parameter('x_j', Tensor, 'Tensor', inspect._empty),
        'alpha': Parameter('alpha', Tensor, 'Tensor', inspect._empty),
    }
    assert out.return_type == Tensor

    # Naming `message` twice is expected to yield each parameter only once.
    assert inspector.get_flat_params(['message', 'message']) == [
        Parameter('x_j', Tensor, 'Tensor', inspect._empty),
        Parameter('alpha', Tensor, 'Tensor', inspect._empty),
    ]
    assert inspector.get_flat_param_names(['message']) == ['x_j', 'alpha']

    kwargs = {'x_j': torch.randn(5), 'alpha': torch.randn(5)}
    data = inspector.collect_param_data('message', kwargs)
    assert len(data) == 2
    assert torch.allclose(data['x_j'], kwargs['x_j'])
    assert torch.allclose(data['alpha'], kwargs['alpha'])

    # Reference the method through the class under test (`GATConv`) rather
    # than through `SAGEConv`, so this test does not depend on another layer.
    assert inspector.get_params_from_method_call(GATConv.propagate) == {
        'x': Parameter('x', OptPairTensor, 'OptPairTensor', inspect._empty),
        'alpha': Parameter('alpha', Tensor, 'Tensor', inspect._empty),
    }
def test_get_params_from_method_call() -> None:
    """Check `Inspector.get_params_from_method_call` on four local classes.

    Each nested class presents the `propagate` parameters differently; the
    assertions pin what the inspector is expected to report for each:

    1. a `propagate_type` class attribute mapping names to types,
    2. a `propagate_type` comment inside the class body,
    3. a literal `propagate` call inside `forward`, with positional index
       `0` and the keyword `size` excluded,
    4. nothing at all, which is expected to give an empty dict.

    NOTE(review): cases 2 and 3 carry no runtime information beyond their
    source text, so the inspector presumably reads the class source; keep
    the bodies of these classes textually unchanged -- confirm against the
    `Inspector` implementation before reformatting them.
    """

    # Case 1: parameters declared through a class attribute.
    class FromMethodCall1:
        propagate_type = {'x': Tensor}

    inspector = Inspector(FromMethodCall1)
    assert inspector.get_params_from_method_call('propagate') == {
        'x': Parameter('x', Tensor, 'Tensor', inspect._empty),
    }

    # Case 2: parameters declared only through a comment in the class body.
    class FromMethodCall2:
        # propagate_type: (x: Tensor)
        pass

    inspector = Inspector(FromMethodCall2)
    assert inspector.get_params_from_method_call('propagate') == {
        'x': Parameter('x', Tensor, 'Tensor', inspect._empty),
    }

    # Case 3: parameters taken from the call site inside `forward`.
    class FromMethodCall3:
        def forward(self) -> None:
            self.propagate(  # type: ignore
                torch.randn(5, 5),
                x=None,
                size=None,
            )

    inspector = Inspector(FromMethodCall3)
    # Skip the first positional argument and the `size` keyword, leaving `x`.
    exclude = [0, 'size']
    assert inspector.get_params_from_method_call('propagate', exclude) == {
        'x': Parameter('x', Tensor, 'Tensor', inspect._empty),
    }

    # Case 4: no declaration and no call site -> no parameters reported.
    class FromMethodCall4:
        pass

    inspector = Inspector(FromMethodCall4)
    assert inspector.get_params_from_method_call('propagate') == {}
|