routing-transformer: Encoder-decoder fails at KMeans attention

I haven’t been able to dig into the root cause here yet, but I’m getting the following error when trying to run an encoder-decoder:

 File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/encoder_decoder.py", line 77, in generate
    return self.dec.generate(seq_out_start, max_seq_len, context = context, **{**dec_kwargs, **kwargs})
  File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/autoregressive_wrapper.py", line 71, in generate
    logits, _ = self.net(x, input_mask=input_mask, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/autopadder.py", line 33, in forward
    return self.net(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 614, in forward
    x, loss = self.routing_transformer(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 592, in forward
    x, loss = self.layers(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 200, in forward
    out, f_loss, g_loss =  _ReversibleFunction.apply(x, blocks, args)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 137, in forward
    x, f_loss, g_loss = block(x, **kwarg)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 80, in forward
    f_out, f_loss = cast_return(self.f(x2, record_rng=self.training, **f_args), requires_grad = False)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 53, in forward
    return self.net(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 121, in forward
    return self.fn(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 524, in forward
    global_out, loss = self.global_attn(q, k, v, query_mask = input_mask, key_mask = context_mask)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 390, in forward
    dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 339, in forward
    self.init(x)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 325, in init
    self.means.data.copy_(means)
RuntimeError: The size of tensor a (64) must match the size of tensor b (2) at non-singleton dimension 1

Here are my model params:

model = RoutingTransformerEncDec(
    enc_num_tokens=7000,
    dec_num_tokens=7000,
    dim=512,
    enc_ff_mult=4,
    dec_ff_mult=4,
    enc_depth=16,
    dec_depth=16,
    enc_heads=8,
    dec_heads=8,
    enc_max_seq_len=8192,
    dec_max_seq_len=8192,
    enc_window_size=128,
    dec_window_size=128,
    enc_causal=False,
    # dec_causal=True,  # the decoder is always causal
    enc_ff_dropout=0.05,
    dec_ff_dropout=0.05,
    enc_reversible=True,
    dec_reversible=True,
)
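
For reference, here is roughly how I trigger it: a minimal sketch with dummy token tensors. The shapes, the start token, and the generate() argument order (seq_in, seq_out_start, max_seq_len) are just my reading of the traceback above, so treat them as illustrative rather than the exact API.

import torch

# dummy source sequence and decoder start token (values and shapes are illustrative)
src = torch.randint(0, 7000, (1, 8192))
start_tokens = torch.ones((1, 1)).long()

# calling generate() before any training step / forward pass hits the kmeans init error above
generated = model.generate(src, start_tokens, 256)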

Most upvoted comments

@lucidrains Quick update: running with the new version fixed my training loss problem! Unfortunately I'm seeing some weird prediction results that I can't quite explain yet, and it's going to take me a bit longer to dig into why. I'm also going to play around with mixed attention head locality, thanks for the tip!
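
If I'm reading the tip right, "mixed attention head locality" means letting some heads do plain local attention while the rest keep routing attention. Here is a sketch of how I plan to try it; the *_n_local_attn_heads kwargs are my assumption based on the enc_/dec_ prefix convention in the params above, so don't take the names as gospel.

from routing_transformer import RoutingTransformerEncDec

# trimmed-down version of the model above; the two *_n_local_attn_heads kwargs are assumed
model = RoutingTransformerEncDec(
    enc_num_tokens=7000,
    dec_num_tokens=7000,
    dim=512,
    enc_depth=16,
    dec_depth=16,
    enc_heads=8,
    dec_heads=8,
    enc_max_seq_len=8192,
    dec_max_seq_len=8192,
    enc_window_size=128,
    dec_window_size=128,
    enc_n_local_attn_heads=4,  # assumed kwarg: 4 of the 8 encoder heads use local attention
    dec_n_local_attn_heads=4,  # assumed kwarg: 4 of the 8 decoder heads use local attention
)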

You're right that it's silly to run the generate() method before fitting; I only do it as a last sanity check to make sure I haven't done anything weird, like accidentally loading a checkpoint when I shouldn't have. Thanks for the fix!

After some more testing, it looks like this only happens if I run generate() before the first forward call of the model; something seems to go wrong with initializing the kmeans under those circumstances. I'd like to reproduce it with your script as well, to verify the problem isn't in my own code, but I haven't gotten to that yet.
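
In the meantime, something like the following should sidestep it until a fix lands: run one throwaway forward pass so the kmeans gets initialized before generating. This is only a sketch; the dummy shapes are illustrative and the forward kwargs follow the README training example, so treat them as assumptions.

import torch

# dummy batch whose only purpose is to trigger a first forward pass (shapes and values are illustrative)
src = torch.randint(0, 7000, (1, 8192))
tgt = torch.randint(0, 7000, (1, 8192))

with torch.no_grad():
    _ = model(src, tgt, return_loss=True)  # first call of the model, lets the kmeans initialize its means

# now generate() no longer hits the kmeans init error
start_tokens = torch.ones((1, 1)).long()
generated = model.generate(src, start_tokens, 256)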