import warnings
from typing import Any, Dict, List, Optional, Type
import torch
from torch import Tensor
from torch_geometric.typing import NodeType


def get_embeddings(
model: torch.nn.Module,
*args: Any,
**kwargs: Any,
) -> List[Tensor]:
"""Returns the output embeddings of all
:class:`~torch_geometric.nn.conv.MessagePassing` layers in
:obj:`model`.
Internally, this method registers forward hooks on all
:class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`,
and runs the forward pass of the :obj:`model` by calling
:obj:`model(*args, **kwargs)`.
Args:
model (torch.nn.Module): The message passing model.
*args: Arguments passed to the model.
**kwargs (optional): Additional keyword arguments passed to the model.
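
    Example:
        A minimal usage sketch (the two-layer
        :class:`~torch_geometric.nn.models.GCN` and the random inputs below
        are illustrative only):

        .. code-block:: python

            import torch

            from torch_geometric.nn import GCN

            model = GCN(in_channels=16, hidden_channels=32, num_layers=2)
            x = torch.randn(10, 16)
            edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

            # One embedding per 'MessagePassing' layer (here: two):
            embeddings = get_embeddings(model, x, edge_index)
            assert len(embeddings) == 2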
"""
from torch_geometric.nn import MessagePassing
embeddings: List[Tensor] = []
def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
# Clone output in case it will be later modified in-place:
outputs = outputs[0] if isinstance(outputs, tuple) else outputs
assert isinstance(outputs, Tensor)
embeddings.append(outputs.clone())
hook_handles = []
for module in model.modules(): # Register forward hooks:
if isinstance(module, MessagePassing):
hook_handles.append(module.register_forward_hook(hook))
if len(hook_handles) == 0:
warnings.warn("The 'model' does not have any 'MessagePassing' layers",
stacklevel=2)
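    # Run the forward pass in evaluation mode and with gradients disabled,
    # restoring the original training mode afterwards: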
training = model.training
model.eval()
with torch.no_grad():
model(*args, **kwargs)
model.train(training)
for handle in hook_handles: # Remove hooks:
handle.remove()
    return embeddings


def get_embeddings_hetero(
model: torch.nn.Module,
supported_models: Optional[List[Type[torch.nn.Module]]] = None,
*args: Any,
**kwargs: Any,
) -> Dict[NodeType, List[Tensor]]:
"""Returns the output embeddings of all
:class:`~torch_geometric.nn.conv.MessagePassing` layers in a heterogeneous
:obj:`model`, organized by edge type.
Internally, this method registers forward hooks on all modules that process
heterogeneous graphs in the model and runs the forward pass of the model.
For heterogeneous models, the output is a dictionary where each key is a
node type and each value is a list of embeddings from different layers.
Args:
model (torch.nn.Module): The heterogeneous GNN model.
supported_models (List[Type[torch.nn.Module]], optional): A list of
supported model classes. If not provided, defaults to
[HGTConv, HANConv, HeteroConv].
*args: Arguments passed to the model.
**kwargs (optional): Additional keyword arguments passed to the model.
Returns:
Dict[NodeType, List[Tensor]]: A dictionary mapping each node type to
a list of embeddings from different layers.
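
    Example:
        A minimal usage sketch (the single
        :class:`~torch_geometric.nn.conv.HeteroConv` layer and the random
        inputs below are illustrative only):

        .. code-block:: python

            import torch

            from torch_geometric.nn import HeteroConv, SAGEConv

            model = HeteroConv({
                ('paper', 'cites', 'paper'): SAGEConv(16, 32),
            })
            x_dict = {'paper': torch.randn(8, 16)}
            edge_index_dict = {
                ('paper', 'cites', 'paper'):
                torch.tensor([[0, 1, 2], [1, 2, 3]]),
            }

            # 'supported_models' is a positional argument, so pass 'None'
            # to fall back to the default layer types:
            out = get_embeddings_hetero(model, None, x_dict, edge_index_dict)
            # out == {'paper': [<Tensor of shape [8, 32]>]}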
"""
from torch_geometric.nn import HANConv, HeteroConv, HGTConv
if not supported_models:
supported_models = [HGTConv, HANConv, HeteroConv]
# Dictionary to store node embeddings by type
node_embeddings_dict: Dict[NodeType, List[Tensor]] = {}
# Hook function to capture node embeddings
    def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
        # Only capture outputs that are dictionaries mapping node types to
        # embeddings; any other output is ignored:
        if isinstance(outputs, dict) and outputs:
            # Store the embeddings for each node type, cloning them in case
            # they are later modified in-place:
            for node_type, embedding in outputs.items():
                if node_type not in node_embeddings_dict:
                    node_embeddings_dict[node_type] = []
                node_embeddings_dict[node_type].append(embedding.clone())
    # List to store hook handles:
    hook_handles = []
    # Register forward hooks on all modules that produce heterogeneous
    # (per-node-type) outputs:
    for module in model.modules():
        # Handle native heterogeneous layers, e.g., HGTConv, HANConv and
        # HeteroConv:
        if isinstance(module, tuple(supported_models)):
            hook_handles.append(module.register_forward_hook(hook))
        else:
            # Handle heterogeneous models generated by calling to_hetero()
            # on a homogeneous model, whose sub-layers are wrapped in
            # ModuleDict containers:
            submodules = list(module.children())
            contains_module_dict = any(
                isinstance(submodule, torch.nn.ModuleDict)
                for submodule in submodules
            )
            if contains_module_dict:
                hook_handles.append(module.register_forward_hook(hook))
    if len(hook_handles) == 0:
        warnings.warn(
            "The 'model' does not have any heterogeneous "
            "'MessagePassing' layers", stacklevel=2)
# Run the model forward pass
training = model.training
model.eval()
with torch.no_grad():
model(*args, **kwargs)
model.train(training)
# Clean up hooks
for handle in hook_handles:
handle.remove()
return node_embeddings_dict