File: separate.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (155 lines) | stat: -rw-r--r-- 5,587 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from collections.abc import Mapping, Sequence
from typing import Any, Type, TypeVar

from torch import Tensor

from torch_geometric import EdgeIndex, Index
from torch_geometric.data.data import BaseData
from torch_geometric.data.storage import BaseStorage
from torch_geometric.typing import SparseTensor, TensorFrame
from torch_geometric.utils import narrow

# Type variable tying the `cls` argument of `separate` to its return type.
T = TypeVar('T')


def separate(
    cls: Type[T],
    batch: Any,
    idx: int,
    slice_dict: Any,
    inc_dict: Any = None,
    decrement: bool = True,
) -> T:
    """Extract the element at position ``idx`` from ``batch``.

    Homogeneous and heterogeneous data objects are both supported, because
    every store of ``batch`` is processed on its own. Nested containers such
    as dictionaries and lists are delegated to ``_separate``.

    Args:
        cls: Class of the object to build. It is instantiated without
            arguments and then given the store layout of ``batch``.
        batch: The batched object to take the element from.
        idx: Position of the element inside ``batch``.
        slice_dict: Slice boundaries per attribute. For stores that carry a
            key (heterogeneous case) the outer level is indexed by that key.
        inc_dict: Increments per attribute, laid out like ``slice_dict``.
            Only accessed when ``decrement`` is true.
        decrement: Whether the increments from ``inc_dict`` are handed to
            ``_separate`` so they can be subtracted again.

    Returns:
        The newly created ``cls`` instance holding the separated attributes.
    """
    out = cls().stores_as(batch)

    # Walk the stores of the batch and of the result in lockstep:
    for src_store, dst_store in zip(batch.stores, out.stores):
        store_key = src_store._key
        is_hetero = store_key is not None

        if is_hetero:
            # Heterogeneous: slices are grouped under the store's key.
            store_slices = slice_dict[store_key]
            names = store_slices.keys()
        else:
            # Homogeneous: keep only the sliced attributes this store owns,
            # in the order in which `slice_dict` lists them.
            store_slices = slice_dict
            owned = set(src_store.keys())
            names = [name for name in slice_dict.keys() if name in owned]

        for name in names:
            slices = store_slices[name]

            # `inc_dict` is consulted lazily, and only when decrementing.
            if not decrement:
                incs = None
            elif is_hetero:
                incs = inc_dict[store_key][name]
            else:
                incs = inc_dict[name]

            dst_store[name] = _separate(name, src_store[name], idx, slices,
                                        incs, batch, src_store, decrement)

        # `num_nodes` cannot be recovered from the batch-wide node count
        # alone, so it is read from the per-element record when present:
        if hasattr(src_store, '_num_nodes'):
            dst_store.num_nodes = src_store._num_nodes[idx]

    return out


def _separate(
    key: str,
    values: Any,
    idx: int,
    slices: Any,
    incs: Any,
    batch: BaseData,
    store: BaseStorage,
    decrement: bool,
) -> Any:
    """Extract the part of ``values`` that belongs to element ``idx``.

    The handling depends on the type of ``values``; checks run in this order:

    * ``Tensor``: narrowed to ``slices[idx]:slices[idx + 1]`` along the
      dimension reported by ``batch.__cat_dim__`` (dimension 0 is used, and
      squeezed away afterwards, when that is ``None``). ``Index`` and
      ``EdgeIndex`` metadata is restored from ``_cat_metadata`` if available,
      and ``incs[idx]`` is subtracted when decrementing applies.
    * ``SparseTensor`` (only if ``decrement`` is true): narrowed along every
      dimension reported by ``batch.__cat_dim__``. With ``decrement`` false
      it falls through to the later checks.
    * ``TensorFrame``: sliced as ``values[start:end]``.
    * ``Mapping``: each entry is separated recursively; a ``dict`` is
      returned.
    * Sequence whose first item is a non-empty, non-string sequence starting
      with a ``Tensor``/``SparseTensor``: item ``idx`` of every inner
      sequence is returned.
    * Sequence whose first item is a ``Tensor``/``SparseTensor``: each item
      is separated recursively.
    * Anything else: ``values[idx]``.

    Both sequence cases additionally require ``slices`` to be a sequence.

    Args:
        key: Attribute name, forwarded to ``batch.__cat_dim__``.
        values: The batched attribute value.
        idx: Position of the element to extract.
        slices: Slice boundaries matching the structure of ``values``.
        incs: Increments matching the structure of ``values``; may be
            ``None``.
        batch: Batched object that provides ``__cat_dim__``.
        store: Store that ``values`` was read from.
        decrement: Whether increments are subtracted again.

    Returns:
        The separated value.
    """
    if isinstance(values, Tensor):
        # Cut this element's range out of the concatenated tensor.
        key = str(key)
        cat_dim = batch.__cat_dim__(key, values, store)
        lo, hi = int(slices[idx]), int(slices[idx + 1])
        piece = narrow(values, cat_dim or 0, lo, hi - lo)
        if cat_dim is None:
            # The attribute was batched along a new leading dimension.
            piece = piece.squeeze(0)

        if isinstance(values, Index):
            meta = values._cat_metadata
            if meta is not None:
                # Restore the `Index` metadata recorded for this element:
                piece._dim_size = meta.dim_size[idx]
                piece._is_sorted = meta.is_sorted[idx]

        if isinstance(values, EdgeIndex):
            meta = values._cat_metadata
            if meta is not None:
                # Restore the `EdgeIndex` metadata recorded for this element:
                piece._sparse_size = meta.sparse_size[idx]
                piece._sort_order = meta.sort_order[idx]
                piece._is_undirected = meta.is_undirected[idx]

        if not decrement or incs is None:
            return piece

        # Scalar increments of zero are skipped; increments with more than
        # one dimension are always applied.
        if incs.dim() > 1 or int(incs[idx]) != 0:
            piece = piece - incs[idx].to(piece.device)

        return piece

    if isinstance(values, SparseTensor) and decrement:
        # NOTE: `cat_dim` may be a tuple to allow for diagonal stacking, in
        # which case the tensor is narrowed once per listed dimension.
        key = str(key)
        cat_dim = batch.__cat_dim__(key, values, store)
        if isinstance(cat_dim, int):
            dims = (cat_dim, )
        else:
            dims = cat_dim

        sparse = values
        for pos, dim in enumerate(dims):
            lo, hi = int(slices[idx][pos]), int(slices[idx + 1][pos])
            sparse = sparse.narrow(dim, lo, hi - lo)
        return sparse

    if isinstance(values, TensorFrame):
        key = str(key)
        lo, hi = int(slices[idx]), int(slices[idx + 1])
        return values[lo:hi]

    if isinstance(values, Mapping):
        # Separate every entry with its own slices and increments.
        separated = {}
        for name, item in values.items():
            separated[name] = _separate(
                name,
                item,
                idx,
                slices=slices[name],
                incs=incs[name] if decrement else None,
                batch=batch,
                store=store,
                decrement=decrement,
            )
        return separated

    if isinstance(values, Sequence):
        first = values[0]
        tensor_types = (Tensor, SparseTensor)

        is_nested = (isinstance(first, Sequence)
                     and not isinstance(first, str) and len(first) > 0
                     and isinstance(first[0], tensor_types)
                     and isinstance(slices, Sequence))
        if is_nested:
            # List of lists: pick entry `idx` from every inner list.
            return [inner[idx] for inner in values]

        is_flat = (not isinstance(values, str)
                   and isinstance(first, tensor_types)
                   and isinstance(slices, Sequence))
        if is_flat:
            # List of tensors: separate each one with its own slices.
            return [
                _separate(
                    key,
                    item,
                    idx,
                    slices=slices[pos],
                    incs=incs[pos] if decrement else None,
                    batch=batch,
                    store=store,
                    decrement=decrement,
                ) for pos, item in enumerate(values)
            ]

    # Everything else is assumed to hold one entry per batched element.
    return values[idx]