pytorch_geometric: NotImplementedError raised when calling "to_hetero"
🐛 Describe the bug
I’ve been learning how to apply torch_geometric to heterogeneous graphs, and this is the video I followed on YouTube: https://www.youtube.com/watch?v=qL09oshDKww
Basically I did exactly what Giovanni did in the tutorial, and here is the code:
```python
import numpy as np
import torch
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData
from torch_geometric.nn import SAGEConv, to_hetero

# Node features: 10 authors with 8 features each, 20 papers with 4 features each
authors = torch.rand((10, 8))
papers = torch.rand((20, 4))
authors_y = torch.rand(10).round()

# 50 random author -> paper "write" edges, edge_index of shape [2, 50]
write_from = torch.tensor(np.random.choice(10, 50, replace=True))
write_to = torch.tensor(np.random.choice(20, 50, replace=True))
write = torch.concat((write_from, write_to)).reshape(-1, 50).long()

# 15 random paper -> paper "cite" edges, edge_index of shape [2, 15]
cite_from = torch.tensor(np.random.choice(20, 15, replace=True))
cite_to = torch.tensor(np.random.choice(20, 15, replace=True))
cite = torch.concat((cite_from, cite_to)).reshape(-1, 15).long()

data = HeteroData({"author": {"x": authors, "y": authors_y}, "paper": {"x": papers}},
                  author__write__paper={"edge_index": write},
                  paper__cite__paper={"edge_index": cite})
homogeneous_data = data.to_homogeneous()

transform = T.RandomNodeSplit()
data = transform(data)


class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


model = GNN(hidden_channels=64, out_channels=2)
# The notebook passed data_backup.metadata(); data_backup is not defined in this
# snippet, and data.metadata() has the same node and edge types.
model = to_hetero(model, data.metadata(), aggr='sum')  # raises NotImplementedError
```
And this is what I got:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Input In [44], in <cell line: 15>()
12 return x
14 model = GNN(hidden_channels=64, out_channels=2)
---> 15 model= to_hetero(model, data_backup.metadata(), aggr='sum')
File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:120, in to_hetero(module, metadata, aggr, input_map, debug)
30 r"""Converts a homogeneous GNN model into its heterogeneous equivalent in
31 which node representations are learned for each node type in
32 :obj:`metadata[0]`, and messages are exchanged between each edge type in
(...)
117 transformation in debug mode. (default: :obj:`False`)
118 """
119 transformer = ToHeteroTransformer(module, metadata, aggr, input_map, debug)
--> 120 return transformer.transform()
File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/fx.py:157, in Transformer.transform(self)
155 elif is_global_pooling_op(self.module, op, node.target):
156 op = 'call_global_pooling_module'
--> 157 getattr(self, op)(node, node.target, node.name)
159 # Remove all unused nodes in the computation graph, i.e., all nodes
160 # which have been replaced by node type-wise or edge type-wise variants
161 # but which are still present in the computation graph.
162 # We do this by iterating over the computation graph in reversed order,
163 # and try to remove every node. This does only succeed in case there
164 # are no users of that node left in the computation graph.
165 for node in reversed(list(self.graph.nodes)):
File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:308, in ToHeteroTransformer.call_method(self, node, target, name)
306 self.graph.inserting_after(node)
307 for key in self.metadata[int(self.is_edge_level(node))]:
--> 308 args, kwargs = self.map_args_kwargs(node, key)
309 out = self.graph.create_node('call_method', target=target,
310 args=args, kwargs=kwargs,
311 name=f'{name}__{key2str(key)}')
312 self.graph.inserting_after(out)
File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:414, in ToHeteroTransformer.map_args_kwargs(self, node, key)
411 else:
412 return value
--> 414 args = tuple(_recurse(v) for v in node.args)
415 kwargs = {k: _recurse(v) for k, v in node.kwargs.items()}
416 return args, kwargs
File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:414, in <genexpr>(.0)
411 else:
412 return value
--> 414 args = tuple(_recurse(v) for v in node.args)
415 kwargs = {k: _recurse(v) for k, v in node.kwargs.items()}
416 return args, kwargs
File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:404, in ToHeteroTransformer.map_args_kwargs.<locals>._recurse(value)
399 return (
400 self.find_by_name(f'{value.name}__{key2str(key[0])}'),
401 self.find_by_name(f'{value.name}__{key2str(key[-1])}'),
402 )
403 else:
--> 404 raise NotImplementedError
405 elif isinstance(value, dict):
406 return {k: _recurse(v) for k, v in value.items()}
NotImplementedError:
Environment
- PyG version: 2.3.0
- PyTorch version: 2.0.0+cu117
- OS: Linux
- Python version: 3.10.4
- CUDA/cuDNN version: no CUDA/cuDNN
- How you installed PyTorch and PyG (conda, pip, source): pip
- Any other relevant information (e.g., version of torch-scatter):
About this issue
- State: closed
- Created a year ago
- Comments: 22 (11 by maintainers)
This is fixable by adding reverse edge types.
I will look into the reason for this error though.
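A plausible reading of the traceback, offered as an assumption rather than the maintainers' diagnosis: `to_hetero` gives each convolution one output per destination node type. In the original metadata, "author" is never a destination, so the `.relu()` call after `conv1` has no "author" input to map, and `map_args_kwargs` raises `NotImplementedError`. The dependency-free helper below (hypothetical, not part of PyG) lists node types that never receive messages. It expects metadata in the `(node_types, edge_types)` layout returned by `HeteroData.metadata()`.

```python
from typing import List, Sequence, Tuple

EdgeType = Tuple[str, str, str]  # (source type, relation, destination type)


def nodes_without_incoming(node_types: Sequence[str],
                           edge_types: Sequence[EdgeType]) -> List[str]:
    """Return the node types that are never the destination of any edge type."""
    destinations = {dst for _, _, dst in edge_types}
    return [nt for nt in node_types if nt not in destinations]


# Metadata of the graph in the report, written out by hand.
metadata = (
    ["author", "paper"],
    [("author", "write", "paper"), ("paper", "cite", "paper")],
)
print(nodes_without_incoming(*metadata))  # ['author']
```

If this list is non-empty, adding reverse edge types should make it empty.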
I see. Currently you have to register the global pooling operator as a module; see the example linked in the original thread.