pytorch_geometric: AssertionError raised in MetaPath2Vec without an explicit error message

🐛 Describe the bug

I am encountering inconsistent error messages when using the MetaPath2Vec module from the torch_geometric library. The issue arises when I set the values of walk_length and context_size. Depending on the values I choose, I receive different errors, one of which has no message at all, and this makes the problem hard to resolve.

My heterogeneous graph is shown below. I’m using the MetaPath2Vec module to compute node embeddings on it.

HeteroData(
  word={ x=[32390, 8] },
  root={ x=[2252, 8] },
  lemma={ x=[14944, 8] },
  tweet={
    x=[9682, 8],
    y=[9682],
  },
  (lemma, links, tweet)={ edge_index=[2, 88662] },
  (root, links, tweet)={ edge_index=[2, 80569] },
  (word, in, tweet)={ edge_index=[2, 93776] }
)

The configuration:

import torch
from torch_geometric.nn import MetaPath2Vec

# `data` is the HeteroData object printed above.
metapath = [
    ('word', 'in', 'tweet'),
    ('lemma', 'links', 'tweet'),
    ('root', 'links', 'tweet'),
]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MetaPath2Vec(data.edge_index_dict, embedding_dim=64,
                     metapath=metapath, walk_length=2, context_size=6,
                     walks_per_node=5, num_negative_samples=5,
                     sparse=True).to(device)

loader = model.loader(batch_size=128, shuffle=True, num_workers=6)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)
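The MetaPath2Vec docs describe `metapath` as a list of `(src_node_type, rel_type, dst_node_type)` tuples, and the random walk follows these edge types in order. Under that reading (an assumption about intended use, not one of the checks in the tracebacks below), each edge type should end on the node type where the next one starts. The plain-Python sketch below checks that property; `check_metapath_chain` is a hypothetical helper, not a PyG function, and the `rev_*` edge types in the alternative metapath assume reverse edges like those `T.ToUndirected()` adds to a HeteroData object. Newer PyG releases may perform a similar check themselves; the sketch does not rely on that.

```python
def check_metapath_chain(metapath):
    """Return the indices i where edge type i does not connect to edge type i + 1."""
    breaks = []
    for i, (prev, nxt) in enumerate(zip(metapath, metapath[1:])):
        # prev = (src, rel, dst): the next edge type should start at `dst`.
        if prev[-1] != nxt[0]:
            breaks.append(i)
    return breaks


# The metapath from the configuration above.
user_metapath = [
    ('word', 'in', 'tweet'),
    ('lemma', 'links', 'tweet'),
    ('root', 'links', 'tweet'),
]

# Hypothetical chained alternative that also starts and ends at 'tweet',
# using reverse edge types named 'rev_<relation>'.
cyclic_metapath = [
    ('tweet', 'rev_in', 'word'),
    ('word', 'in', 'tweet'),
    ('tweet', 'rev_links', 'lemma'),
    ('lemma', 'links', 'tweet'),
    ('tweet', 'rev_links', 'root'),
    ('root', 'links', 'tweet'),
]

print(check_metapath_chain(user_metapath))    # [0, 1]
print(check_metapath_chain(cyclic_metapath))  # []
```

In the original metapath, 'tweet' is followed by edge types starting at 'lemma' and 'root', so neither transition is a connected step.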

When walk_length + 1 is less than context_size (here 2 + 1 < 6), I receive an AssertionError with no message:

AssertionError                            Traceback (most recent call last)

<ipython-input-58-bbd9357b905c> in <cell line: 9>()
      7 
      8 device = 'cuda' if torch.cuda.is_available() else 'cpu'
----> 9 model = MetaPath2Vec(data.edge_index_dict, embedding_dim=64,
     10                      metapath=metapath, walk_length=2, context_size=6,
     11                      walks_per_node=5, num_negative_samples=5,

/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/models/metapath2vec.py in __init__(self, edge_index_dict, embedding_dim, metapath, walk_length, context_size, walks_per_node, num_negative_samples, num_nodes_dict, sparse)
     83             self.rowcount_dict[keys] = rowptr[1:] - rowptr[:-1]
     84 
---> 85         assert walk_length + 1 >= context_size
     86         if walk_length > len(metapath) and metapath[0][0] != metapath[-1][-1]:
     87             raise AttributeError(

AssertionError:
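The bare `assert walk_length + 1 >= context_size` fails with an empty message. As I read the PyG source, a walk of walk_length steps visits walk_length + 1 nodes, and positive samples are sliding windows of context_size consecutive nodes taken from that walk, so a window larger than the walk cannot be formed. The sketch below is a hypothetical stand-alone version of that check with an explicit message; `check_context_size` and `num_context_windows` are illustrative names, not PyG functions.

```python
def num_context_windows(walk_length: int, context_size: int) -> int:
    # A walk of `walk_length` steps visits walk_length + 1 nodes; sliding a
    # window of `context_size` nodes over it yields this many windows.
    return walk_length + 1 - context_size + 1


def check_context_size(walk_length: int, context_size: int) -> None:
    # Hypothetical replacement for the bare `assert` in the traceback,
    # raising with an explicit message instead.
    if walk_length + 1 < context_size:
        raise ValueError(
            f"'context_size' (got {context_size}) must be at most "
            f"'walk_length' + 1 (got {walk_length + 1}), because each walk "
            "visits only walk_length + 1 nodes")


check_context_size(walk_length=5, context_size=6)  # passes: 6 <= 5 + 1
try:
    check_context_size(walk_length=2, context_size=6)
except ValueError as err:
    print(err)
```

With walk_length=5 and context_size=6 each walk yields exactly one window; the failing configuration (walk_length=2) would need a window twice as long as the walk.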

When I increase walk_length (here to 5), I instead receive an AttributeError:

AttributeError                            Traceback (most recent call last)

<ipython-input-59-2d4888523ad8> in <cell line: 9>()
      7 
      8 device = 'cuda' if torch.cuda.is_available() else 'cpu'
----> 9 model = MetaPath2Vec(data.edge_index_dict, embedding_dim=64,
     10                      metapath=metapath, walk_length=5, context_size=6,
     11                      walks_per_node=5, num_negative_samples=5,

/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/models/metapath2vec.py in __init__(self, edge_index_dict, embedding_dim, metapath, walk_length, context_size, walks_per_node, num_negative_samples, num_nodes_dict, sparse)
     85         assert walk_length + 1 >= context_size
     86         if walk_length > len(metapath) and metapath[0][0] != metapath[-1][-1]:
---> 87             raise AttributeError(
     88                 "The 'walk_length' is longer than the given 'metapath', but "
     89                 "the 'metapath' does not denote a cycle")

AttributeError: The `walk_length` is longer than the given `metapath`, but the `metapath` does not denote a cycle

I’ve tried various values for walk_length and context_size, but I can’t find a combination that avoids both errors.
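Taken together, the two checks in the tracebacks explain the trap. The metapath has three edge types and is not a cycle (it starts at 'word' and ends at 'tweet'), so walk_length cannot exceed 3, and therefore context_size cannot exceed walk_length + 1 = 4. With this metapath, context_size=6 can never pass both checks, while a metapath that starts and ends at the same node type removes the bound on walk_length. The sketch below re-implements only those two conditions in plain Python (`is_valid_config` is a hypothetical name; it does not call PyG) and searches small values.

```python
from itertools import product


def is_valid_config(metapath, walk_length, context_size):
    # Check 1: the `assert walk_length + 1 >= context_size` from the traceback.
    if walk_length + 1 < context_size:
        return False
    # Check 2: the AttributeError branch. A walk longer than the metapath
    # must wrap around, which requires start and end node types to match.
    is_cycle = metapath[0][0] == metapath[-1][-1]
    if walk_length > len(metapath) and not is_cycle:
        return False
    return True


metapath = [
    ('word', 'in', 'tweet'),
    ('lemma', 'links', 'tweet'),
    ('root', 'links', 'tweet'),
]

feasible = [(w, c) for w, c in product(range(1, 11), range(1, 11))
            if is_valid_config(metapath, w, c)]
max_context = max(c for _, c in feasible)
print(max_context)  # 4
```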

Environment

  • PyG version:
  • PyTorch version:
  • OS:
  • Python version:
  • CUDA/cuDNN version:
  • How you installed PyTorch and PyG (conda, pip, source):
  • Any other relevant information (e.g., version of torch-scatter):

About this issue

  • State: closed
  • Created 8 months ago
  • Comments: 17 (8 by maintainers)


Most upvoted comments

How can I extract only the tweet embeddings, since those are what would be useful in my use case?

model('tweet') returns only the embeddings of the tweet nodes.

Related docs: https://pytorch-geometric.readthedocs.io/en/latest/notes/heterogeneous.html

You don’t need the y_index, because all tweet nodes already have associated labels, so:

z = model('tweet')