reformer-pytorch: Possible bug in enc-dec attention?
In the encoder-decoder architecture, the encoder output is passed to the decoder as keys to be used in attention. Here (https://github.com/lucidrains/reformer-pytorch/blob/5f5bbf4fd5806f45d2cb3b7373021786b3b34e5b/reformer_pytorch/reformer_pytorch.py#L598) you concatenate the keys with x (where x is the decoder input) and then apply self-attention. Does it make sense to do self-attention over the decoder input and the encoder output together? Even in the trax code these two are handled separately (https://github.com/google/trax/blob/c7c47a14ef8ea5b260ac78c22cbadd6dc1fb605b/trax/models/reformer/reformer.py#L968): first self-attention is applied to the decoder input, and then a separate encoder-decoder attention is applied between the new decoder representation and the keys.
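To make the contrast concrete, here is a rough plain-PyTorch sketch of the two flows. This is not the library's actual code; the dimensions, the simple_attention helper and all weight names are made up for illustration, and causal masking / LSH are omitted:

import torch

def simple_attention(q, k, v):
    # plain scaled dot-product attention, just for illustration (no masking)
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return scores.softmax(dim=-1) @ v

dim = 64
enc_out = torch.randn(2, 10, dim)   # encoder output ("keys")
dec_in  = torch.randn(2, 12, dim)   # decoder input

# (a) what the linked line looks like to me: the encoder output is concatenated
# with the decoder input and one self-attention (one set of projections) runs
# over the mixed sequence
w_qkv = torch.nn.Linear(dim, dim * 3)
mixed = torch.cat((enc_out, dec_in), dim=1)
q, k, v = w_qkv(mixed).chunk(3, dim=-1)
dec_out_a = simple_attention(q, k, v)[:, enc_out.shape[1]:]   # keep decoder positions

# (b) the trax flow: self-attention over the decoder input first, then a separate
# encoder-decoder attention between the new decoder representation and the keys
dec_self = torch.nn.Linear(dim, dim * 3)
q, k, v = dec_self(dec_in).chunk(3, dim=-1)
dec_h = simple_attention(q, k, v)

dec_q  = torch.nn.Linear(dim, dim)
enc_kv = torch.nn.Linear(dim, dim * 2)
k, v = enc_kv(enc_out).chunk(2, dim=-1)
dec_out_b = simple_attention(dec_q(dec_h), k, v)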
I don't know if this is the reason or not, but I have a simple copy-reverse task where the loss stalls at 2.08. With the trax code, however, the loss gets close to 0 after a few steps.
import numpy as np
import torch
import tqdm
from reformer_pytorch import ReformerEncDec

def cycle():
    while True:
        source = torch.randint(2, 10, (32, 768)).long().cuda()
        target_np = np.flip(source.cpu().numpy(), axis=1).copy()  # reversed copy of the source as a numpy array
        target = torch.from_numpy(target_np).long().cuda()
        mask = torch.ones(32, 768).bool().cuda()
        yield (source, target, mask)
# First example: copy-reverse, 768 tokens, vocab size 256

# hyperparameters (placeholder values; the values actually used are not shown here)
NUM_BATCHES = 10000
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4

model = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 1,
    enc_max_seq_len = 768,
    enc_heads = 1,
    dec_num_tokens = 256,
    dec_depth = 1,
    dec_max_seq_len = 768,
    dec_heads = 1,
).cuda()
# model = TrainingWrapper(model)
model.cuda()

# optimizer
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        source, target, mask = next(cycle())
        loss = model(seq_in=source, seq_out=target, return_loss=True, enc_input_mask=mask)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()
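As a reference point for the plateau: the tokens above are drawn uniformly from 8 distinct values, and the cross-entropy of a model that only learns that uniform marginal (and not the reversal) is ln(8) ≈ 2.079, which seems to match the 2.08 the loss gets stuck at. A quick check:

import math

# source/target tokens come from torch.randint(2, 10, ...), i.e. 8 distinct values;
# predicting only this uniform marginal gives a cross-entropy of:
print(math.log(8))   # 2.0794..., essentially the 2.08 plateau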
About this issue
- State: closed
- Created 4 years ago
- Reactions: 2
- Comments: 18 (13 by maintainers)
Ok. Thank you 😃
If you run the encoder / decoder from my Sinkhorn project against the Reformer on full attention, you will see they take about the same number of iterations for the task you have, even though Sinkhorn keeps self-attention and contextual attention separate.
@py4 In a lot of projects with full attention, I have mixed self-attention and enc-dec attention together and still gotten great results. I think the gradients will allow the encoder to adjust the space of the contextual keys to align with the decoder.
No worries, I will build it the way you describe eventually, since I was able to get argument routing working with reversible networks in my other project. How about I put this as a final todo item on the Projects tab?
edit - https://github.com/lucidrains/reformer-pytorch/projects/2
I think the concept of sharing Q and K is different from what I'm talking about. What I'm saying is that when you want to calculate qk and v from x, you multiply x by projection matrices w_qk and w_v, right? Your x includes both the encoder output and the decoder input. It is not intuitive to me to use a single w for both encoder and decoder. In "self-attention" all tensors are homogeneous: they are all encoder representations, or all decoder representations. But here you have mixed them, and you are applying self-attention to the mixture of encoder and decoder representations.
And yes, I know that in the trax implementation they have not used LSH for enc-dec attention and are using full attention. But even with this full attention, they first calculate self-attention on the decoder (which can easily be replaced by LSH, as I did, and nothing bad happens) and then apply a separate enc-dec attention.
What I'm saying is that maybe we should separate w_qk and w_v for the encoder and the decoder when applying self-attention to the mixture of encoder and decoder representations.
Here https://github.com/lucidrains/reformer-pytorch/blob/7a2a6ab9c53eeb5ed76e927234ff26ac4b7ff263/reformer_pytorch/reformer_pytorch.py#L599, in x you have merged both the encoder representations and the decoder representations. I'm not sure you can extract q, k, v from a shared w_qk and w_v for both encoder and decoder representations.
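To illustrate what I mean by separating the projections, here is a rough sketch (not a patch; all names are made up and the LSH details are omitted):

import torch

dim = 64
enc_out = torch.randn(2, 10, dim)    # encoder output ("keys")
dec_in  = torch.randn(2, 12, dim)    # decoder input

# current behaviour as I read the linked line: one shared projection pair
# applied to the mixture of encoder and decoder representations
w_qk = torch.nn.Linear(dim, dim)
w_v  = torch.nn.Linear(dim, dim)
x = torch.cat((enc_out, dec_in), dim=1)
qk_shared, v_shared = w_qk(x), w_v(x)

# proposal: separate projections per source, concatenated afterwards, so the
# encoder and decoder representations do not have to share a single w
w_qk_enc, w_v_enc = torch.nn.Linear(dim, dim), torch.nn.Linear(dim, dim)
w_qk_dec, w_v_dec = torch.nn.Linear(dim, dim), torch.nn.Linear(dim, dim)
qk = torch.cat((w_qk_enc(enc_out), w_qk_dec(dec_in)), dim=1)
v  = torch.cat((w_v_enc(enc_out),  w_v_dec(dec_in)),  dim=1)
# qk and v would then go into the same LSH / full attention as before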