wandb: wandb.watch doesn't work with modules that return dicts

wandb --version && python --version && uname

  • Weights and Biases version: 0.9.1
  • Python version: 3.7.7
  • Operating System: Manjaro Linux 20.0.3

Description

When I call wandb.watch(model) on one of my models, my program crashes with a StopIteration error. When I comment that line out, everything works fine.

My model does two things that could be considered non-standard: it uses some neural network layers from PyTorch Geometric, and it returns a dictionary instead of a tensor. In the traceback below, the error occurs exactly in a layer that returns a dict, so I believe the dict is the culprit.

What I Did

Traceback (most recent call last):
  File "/home/dodo/Code/dodonet/dodonet/training/run.py", line 434, in <module>
    OffPolicySMACRunner(trainer).run()
  File "/home/dodo/Code/dodonet/dodonet/training/run.py", line 109, in run
    output_dict = self.model.policy_net(current_state)
  File "/home/dodo/.anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dodo/Code/dodonet/dodonet/nn/nets.py", line 215, in forward
    xdict = self.action_layer(obs_by_class, enc_node_types)
  File "/home/dodo/.anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 559, in __call__
    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
StopIteration
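The last frame calls next() on a generator with no default value. next() raises StopIteration when the generator is exhausted, so this expression fails whenever none of the dict's top-level values is a tensor (for example, a dict of dicts, or a dict of lists of tensors). In PyTorch releases of that era, this line sits in the branch of Module.__call__ that only runs when backward hooks are registered on the module, which suggests wandb.watch registered one; the traceback does not show what action_layer actually returns. The sketch below is stdlib-only and uses a hypothetical FakeTensor class as a stand-in for torch.Tensor so it runs without PyTorch; the dict shapes are made up for illustration.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; not part of any library."""


def first_tensor(output):
    # Same expression as the failing frame: next() gets no default,
    # so an exhausted generator raises StopIteration.
    return next((v for v in output.values() if isinstance(v, FakeTensor)))


def first_tensor_or_none(output):
    # Passing a default to next() returns None instead of raising.
    return next((v for v in output.values() if isinstance(v, FakeTensor)), None)


# Hypothetical outputs: one with a tensor at the top level, one without.
flat = {"logits": FakeTensor(), "meta": "ok"}
nested = {"actions": {"unit_a": FakeTensor()}}  # tensor is only one level down

print(isinstance(first_tensor(flat), FakeTensor))  # True
print(first_tensor_or_none(nested))                # None

try:
    first_tensor(nested)
except StopIteration:
    print("StopIteration, as in the traceback")
```

The nested dict fails because the expression only inspects the dict's own values and does not descend into inner dicts.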

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 18 (5 by maintainers)

Most upvoted comments

@cola-nicolas Thanks so much for the repro! I was able to implement a fix in the PR listed above. If you want to try it yourself you can install wandb with:

pip install --upgrade git+https://github.com/wandb/client.git@bug/issue-1122#egg=wandb

This should make it into the next release, likely early next week.