wandb: wandb.watch doesn't work with modules that return dicts
wandb --version && python --version && uname
- Weights and Biases version: 0.9.1
- Python version: 3.7.7
- Operating System: Manjaro Linux 20.0.3
Description
When I call wandb.watch(model) on a model of mine, I get a StopIteration exception and my program crashes. When I comment that line out, everything works fine.
Two things my model does that could be considered non-standard are the use of some neural network layers from PyTorch Geometric and the output of a dictionary instead of a tensor. In the traceback I provide below, the error occurs exactly in a layer that returns a dict, so I believe the dict is the culprit.
What I Did
Traceback (most recent call last):
  File "/home/dodo/Code/dodonet/dodonet/training/run.py", line 434, in <module>
    OffPolicySMACRunner(trainer).run()
  File "/home/dodo/Code/dodonet/dodonet/training/run.py", line 109, in run
    output_dict = self.model.policy_net(current_state)
  File "/home/dodo/.anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dodo/Code/dodonet/dodonet/nn/nets.py", line 215, in forward
    xdict = self.action_layer(obs_by_class, enc_node_types)
  File "/home/dodo/.anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 559, in __call__
    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
StopIteration
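The last frame of the traceback evaluates `next((v for v in var.values() if isinstance(v, torch.Tensor)))`. In Python, `next()` on a generator that yields nothing raises `StopIteration`. So the error is consistent with a module returning a dict that has no tensor among its top-level values. The sketch below illustrates only that behaviour. It assumes the tensors sit one level down in a nested dict, which is a guess about the model and not something the report confirms. `FakeTensor`, `first_tensor`, and `first_tensor_or_none` are hypothetical names, used so the snippet runs without PyTorch installed.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (assumption, not the real class)."""


def first_tensor(output):
    # Same shape as the expression in the failing frame: take the first
    # top-level dict value that is a tensor. Raises StopIteration if none is.
    return next(v for v in output.values() if isinstance(v, FakeTensor))


def first_tensor_or_none(output):
    # General Python idiom: a default passed to next() is returned
    # instead of raising when the generator is empty.
    return next((v for v in output.values() if isinstance(v, FakeTensor)), None)


flat = {"q_values": FakeTensor()}                # tensor at the top level
nested = {"agents": {"q_values": FakeTensor()}}  # tensor only one level down

# A dict with a top-level tensor is handled without error.
assert isinstance(first_tensor(flat), FakeTensor)

# A dict without any top-level tensor reproduces the StopIteration.
try:
    first_tensor(nested)
    raised = False
except StopIteration:
    raised = True

print("StopIteration raised:", raised)
print("With default:", first_tensor_or_none(nested))
```

The second helper only shows the standard `next(iterator, default)` idiom; it is not a claim about how the fix mentioned later in this thread works.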
About this issue
- State: closed
- Created 4 years ago
- Comments: 18 (5 by maintainers)
@cola-nicolas Thanks so much for the repro! I was able to implement a fix in the PR listed above. If you want to try it yourself you can install wandb with:
This should make it into the next release, likely early next week.