tensorboardX: Variable slice/index assignment graph breaking

I have been running into an issue when trying to create a graph of a Module in which some Variables have slice assignment operations applied to them. I have reduced the problem to the following example; ignore the commented-out V = x for now.

import torch
from torch.autograd import Variable
from tensorboardX import SummaryWriter


class DummyModule(torch.nn.Module):
    def forward(self, x):
        V = Variable(torch.Tensor(2, 2))
        V[0, 0] = x
        # V = x
        return torch.sum(V * 3)


x = Variable(torch.Tensor([1]), requires_grad=True)
r = DummyModule()(x)
r.backward()
print(x.grad)

w = SummaryWriter()
x = Variable(torch.Tensor([1]), requires_grad=True)
w.add_graph(DummyModule(), x, verbose=True)

The output is below. It shows that the gradients flow correctly, but the graph is not being connected. If I add another input Variable and other operations to the Module, add_graph() runs without throwing an error, but the graph shows a disconnected input for x. So I suppose the nature of this error is that the only available input Variable is being interpreted as disconnected.

Variable containing:
 3
[torch.FloatTensor of size (1,)]

Traceback (most recent call last):
  File "test_grad.py", line 21, in <module>
    w.add_graph(DummyModule(), x, verbose=True)
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/tensorboardX/writer.py", line 400, in add_graph
    self.file_writer.add_graph(graph(model, input_to_model, verbose))
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/tensorboardX/graph.py", line 44, in graph
    trace, _ = torch.jit.trace(model, args)
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/torch/jit/__init__.py", line 251, in trace
    return TracedModule(f, nderivs=nderivs)(*args, **kwargs)
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/torch/jit/__init__.py", line 287, in forward
    torch._C._tracer_exit(out_vars)
RuntimeError: /Users/filiped/pytorch/torch/csrc/jit/tracer.h:117: getTracingState: Assertion `state` failed.

Moreover, if you uncomment the line V = x and comment out the line above it, so that no slice/index assignment is performed, you get the expected output:

Variable containing:
 3
[torch.FloatTensor of size (1,)]

graph(%0 : Float(1)) {
  %1 : UNKNOWN_TYPE = Constant[value={3}](), scope: DummyModule
  %2 : Float(1) = Mul[broadcast=1](%0, %1), scope: DummyModule
  %3 : Float() = Sum(%2), scope: DummyModule
  return (%3);
}

This was all executed in PyTorch 0.4.
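
For comparison, here is a minimal sketch of the same computation written without any slice/index assignment. It is an assumption, not something confirmed in this thread, that add_graph() or torch.jit.trace handles this form; the sketch only shows that the gradient with respect to x is unchanged. The class name DummyModuleNoIndexAssign and the mask are hypothetical, and unlike the original (where the other entries of V are uninitialized) the remaining entries of V are zero here. It uses the torch.tensor factory instead of Variable.

```python
import torch


class DummyModuleNoIndexAssign(torch.nn.Module):
    # Hypothetical variant of DummyModule: same idea, but no V[0, 0] = x.
    def forward(self, x):
        # Constant one-hot mask marking position (0, 0).
        mask = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
        # Broadcasting x (shape (1,)) against the mask places x at (0, 0)
        # and zeros elsewhere, using only an out-of-place multiply.
        V = mask * x
        return torch.sum(V * 3)


x = torch.tensor([1.0], requires_grad=True)
r = DummyModuleNoIndexAssign()(x)
r.backward()
print(x.grad)
```

The gradient printed here should match the value 3 reported for the original module, since only the entry at (0, 0) depends on x.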

(Edits: Did a couple rounds of re-simplifying the example.)

About this issue

  • State: open
  • Created 6 years ago
  • Comments: 18 (10 by maintainers)

Most upvoted comments

Update: still not working with the PyTorch 0.4 release and tensorboardX master. Output from tensorboardX:

Error occurs, No graph saved
Checking if it's onnx problem...
Your model fails onnx too, please report to onnx team