File: test_rpc.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (122 lines) | stat: -rw-r--r-- 3,848 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import socket

import torch

import torch_geometric.distributed.rpc as rpc
from torch_geometric.distributed import LocalFeatureStore
from torch_geometric.distributed.dist_context import DistContext
from torch_geometric.distributed.rpc import RPCRouter
from torch_geometric.testing import onlyDistributedTest


def run_rpc_feature_test(
    world_size: int,
    rank: int,
    feature: LocalFeatureStore,
    partition_book: torch.Tensor,
    master_port: int,
):
    r"""Worker entry point: set up RPC and verify cross-partition lookups.

    Meant to run in one spawned process per rank. It initializes the RPC
    layer, checks the partition-to-worker mapping and the router, wires
    :obj:`feature` up for RPC-based access and verifies that looking up
    node IDs of both partitions returns the expected features.

    NOTE: The expected feature tensors below assume the two-partition fixture
    built in :meth:`test_dist_feature_lookup`:

    * Partition 0 owns IDs ``0..255`` with features ``1`` and then ``2``.
    * Partition 1 owns IDs ``256..511`` with features ``0``.

    Args:
        world_size (int): Number of participating processes (one partition
            per process).
        rank (int): Rank of this process, also used as its partition index.
        feature (LocalFeatureStore): The feature store of this partition.
        partition_book (torch.Tensor): Maps each global node ID to the
            partition that owns it.
        master_port (int): Port of the RPC master on ``localhost``.
    """
    group_name = 'dist-feature-test'

    # 1) Initialize the context info:
    current_ctx = DistContext(
        rank=rank,
        global_rank=rank,
        world_size=world_size,
        global_world_size=world_size,
        group_name=group_name,
    )

    rpc.init_rpc(
        current_ctx=current_ctx,
        master_addr='localhost',
        master_port=master_port,
    )

    # 2) Collect all workers. Each partition is expected to be served by
    # exactly one worker named '<group_name>-<rank>':
    partition_to_workers = rpc.rpc_partition_to_workers(
        current_ctx, world_size, rank)

    expected_workers = [[f'{group_name}-{i}'] for i in range(world_size)]
    assert partition_to_workers == expected_workers

    # 3) Find the mapping between worker and partition ID:
    rpc_router = RPCRouter(partition_to_workers)

    for partition_idx in range(world_size):
        worker = rpc_router.get_to_worker(partition_idx=partition_idx)
        assert worker == f'{group_name}-{partition_idx}'

    meta = {
        'edge_types': None,
        'is_hetero': False,
        'node_types': None,
        # Must agree with `feature.num_partitions` set below:
        'num_parts': world_size,
    }

    feature.num_partitions = world_size
    feature.partition_idx = rank
    feature.node_feat_pb = partition_book
    feature.meta = meta
    # Presumably disables the local-only fast path so that IDs owned by
    # other partitions are fetched through `rpc_router` (TODO confirm):
    feature.local_only = False
    feature.set_rpc_router(rpc_router)

    # Global node IDs of partition 0 and partition 1 (fixture layout):
    global_id0 = torch.arange(128 * 2)
    global_id1 = torch.arange(128 * 2) + 128 * 2

    # Lookup the features from stores including locally and remotely.
    # `lookup_features` returns an object whose `.wait()` yields the tensor:
    tensor0 = feature.lookup_features(global_id0)
    tensor1 = feature.lookup_features(global_id1)

    # Expected results, mirroring the tensors stored by the fixture:
    cpu_tensor0 = torch.cat([torch.ones(128, 1024), torch.ones(128, 1024) * 2])
    cpu_tensor1 = torch.cat([torch.zeros(128, 1024), torch.zeros(128, 1024)])

    # Verify:
    assert torch.allclose(cpu_tensor0, tensor0.wait())
    assert torch.allclose(cpu_tensor1, tensor1.wait())

    rpc.shutdown_rpc()
    assert rpc.rpc_is_initialized() is False


@onlyDistributedTest
def test_dist_feature_lookup():
    r"""Checks that :class:`LocalFeatureStore` serves features of nodes that
    are owned by another partition via RPC.

    Two feature stores (partitions 0 and 1) are filled with distinct
    features and handed to two spawned processes running
    :meth:`run_rpc_feature_test`. All feature assertions run inside those
    worker processes. An exception raised in a spawned process does not
    propagate to the parent, so this test must check the workers' exit
    codes. Without that check, a failing worker would go unnoticed.
    """
    cpu_tensor0 = torch.cat([torch.ones(128, 1024), torch.ones(128, 1024) * 2])
    cpu_tensor1 = torch.cat([torch.zeros(128, 1024), torch.zeros(128, 1024)])

    # Global node IDs:
    global_id0 = torch.arange(128 * 2)
    global_id1 = torch.arange(128 * 2) + 128 * 2

    # Set the partition book for two features (partition 0 and 1):
    partition_book = torch.cat([
        torch.zeros(128 * 2, dtype=torch.long),
        torch.ones(128 * 2, dtype=torch.long),
    ])

    # Put the test tensor into the different feature stores with IDs:
    feature0 = LocalFeatureStore()
    feature0.put_global_id(global_id0, group_name=None)
    feature0.put_tensor(cpu_tensor0, group_name=None, attr_name='x')

    feature1 = LocalFeatureStore()
    feature1.put_global_id(global_id1, group_name=None)
    feature1.put_tensor(cpu_tensor1, group_name=None, attr_name='x')

    mp_context = torch.multiprocessing.get_context('spawn')

    # Ask the OS for a free port by binding to port 0. The socket is then
    # closed so the RPC master can bind the port (small race window, which
    # is acceptable in tests):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.settimeout(1)
        s.bind(('127.0.0.1', 0))
        port = s.getsockname()[1]

    w0 = mp_context.Process(target=run_rpc_feature_test,
                            args=(2, 0, feature0, partition_book, port))
    w1 = mp_context.Process(target=run_rpc_feature_test,
                            args=(2, 1, feature1, partition_book, port))

    w0.start()
    w1.start()
    w0.join()
    w1.join()

    # Failures inside the workers only surface as a non-zero exit code:
    assert w0.exitcode == 0, f'Worker 0 failed (exit code {w0.exitcode})'
    assert w1.exitcode == 0, f'Worker 1 failed (exit code {w1.exitcode})'