pytorch_geometric: NotImplementedError occurs when doing "to_hetero"

🐛 Describe the bug

I’ve been learning how to apply torch_geometric to heterogeneous graphs by following this YouTube tutorial: https://www.youtube.com/watch?v=qL09oshDKww

I did exactly what Giovanni does in the tutorial, and here is the code:

import torch
from torch_geometric.data import HeteroData
import numpy as np

authors = torch.rand((10, 8))
papers = torch.rand((20, 4))

authors_y = torch.rand(10).round()
write_from = torch.tensor(np.random.choice(10, 50, replace=True))
write_to = torch.tensor(np.random.choice(20, 50, replace=True))
write = torch.stack((write_from, write_to)).long()

cite_from = torch.tensor(np.random.choice(20, 15, replace=True))
cite_to = torch.tensor(np.random.choice(20, 15, replace=True))
cite = torch.stack((cite_from, cite_to)).long()

data = HeteroData(
    {"author": {"x": authors, "y": authors_y}, "paper": {"x": papers}},
    author__write__paper={"edge_index": write},
    paper__cite__paper={"edge_index": cite},
)

homogeneous_data = data.to_homogeneous()

import torch_geometric.transforms as T
from torch_geometric.nn import Sequential, Linear
from torch.nn import ReLU

transform = T.RandomNodeSplit()
data = transform(data)

from torch_geometric.nn import SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1,-1), hidden_channels)
        self.conv2 = SAGEConv((-1,-1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(hidden_channels=64, out_channels=2)
model = to_hetero(model, data.metadata(), aggr='sum')

And this is what I got:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Input In [44], in <cell line: 15>()
     12         return x
     14 model = GNN(hidden_channels=64, out_channels=2)
---> 15 model= to_hetero(model, data_backup.metadata(), aggr='sum')

File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:120, in to_hetero(module, metadata, aggr, input_map, debug)
     30 r"""Converts a homogeneous GNN model into its heterogeneous equivalent in
     31 which node representations are learned for each node type in
     32 :obj:`metadata[0]`, and messages are exchanged between each edge type in
   (...)
    117         transformation in debug mode. (default: :obj:`False`)
    118 """
    119 transformer = ToHeteroTransformer(module, metadata, aggr, input_map, debug)
--> 120 return transformer.transform()

File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/fx.py:157, in Transformer.transform(self)
    155     elif is_global_pooling_op(self.module, op, node.target):
    156         op = 'call_global_pooling_module'
--> 157     getattr(self, op)(node, node.target, node.name)
    159 # Remove all unused nodes in the computation graph, i.e., all nodes
    160 # which have been replaced by node type-wise or edge type-wise variants
    161 # but which are still present in the computation graph.
    162 # We do this by iterating over the computation graph in reversed order,
    163 # and try to remove every node. This does only succeed in case there
    164 # are no users of that node left in the computation graph.
    165 for node in reversed(list(self.graph.nodes)):

File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:308, in ToHeteroTransformer.call_method(self, node, target, name)
    306 self.graph.inserting_after(node)
    307 for key in self.metadata[int(self.is_edge_level(node))]:
--> 308     args, kwargs = self.map_args_kwargs(node, key)
    309     out = self.graph.create_node('call_method', target=target,
    310                                  args=args, kwargs=kwargs,
    311                                  name=f'{name}__{key2str(key)}')
    312     self.graph.inserting_after(out)

File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:414, in ToHeteroTransformer.map_args_kwargs(self, node, key)
    411     else:
    412         return value
--> 414 args = tuple(_recurse(v) for v in node.args)
    415 kwargs = {k: _recurse(v) for k, v in node.kwargs.items()}
    416 return args, kwargs

File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:414, in <genexpr>(.0)
    411     else:
    412         return value
--> 414 args = tuple(_recurse(v) for v in node.args)
    415 kwargs = {k: _recurse(v) for k, v in node.kwargs.items()}
    416 return args, kwargs

File ~/anaconda3/envs/play/lib/python3.10/site-packages/torch_geometric/nn/to_hetero_transformer.py:404, in ToHeteroTransformer.map_args_kwargs.<locals>._recurse(value)
    399         return (
    400             self.find_by_name(f'{value.name}__{key2str(key[0])}'),
    401             self.find_by_name(f'{value.name}__{key2str(key[-1])}'),
    402         )
    403     else:
--> 404         raise NotImplementedError
    405 elif isinstance(value, dict):
    406     return {k: _recurse(v) for k, v in value.items()}

NotImplementedError:

Environment

  • PyG version: 2.3.0
  • PyTorch version: 2.0.0+cu117
  • OS: Linux
  • Python version: 3.10.4
  • CUDA/cuDNN version: no CUDA/cuDNN
  • How you installed PyTorch and PyG (conda, pip, source): pip
  • Any other relevant information (e.g., version of torch-scatter):

About this issue

  • Original URL
  • State: closed
  • Created a year ago
  • Comments: 22 (11 by maintainers)

Most upvoted comments

This is fixable by adding reverse edge types:

data = T.ToUndirected()(data)

I will look into the reason for this error though.

I see. Currently you have to register the global pooling operator as a module; see here for an example.