x-transformers: XL-recurrence with RotaryEmbedding and mems not working correctly.
Note, this follows on from https://github.com/lucidrains/x-transformers/issues/216
I am trying to do XL-recurrence with:
- RotaryEmbedding
- attn_num_mem_kv > 0
- mems
- return_mems
I’m doing a test which checks that the outputs when passing mems=None and mems=torch.zeros(...) are the same. They are not.
I’m using the code below:
```python
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

lm = ContinuousTransformerWrapper(
    dim_in = 2,
    dim_out = 36,
    max_seq_len = 0,
    max_mem_len = 100,
    attn_layers = Encoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_pos_emb = True,
        attn_flash = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

x = torch.randn(B, 1024, 2)
length = torch.randint(100, x.shape[1], size=(x.shape[0],))
mask = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)   # mask out padded positions
mems = [torch.zeros(x.shape[0], M, D) for _ in range(depth)]          # all-zero mems, one per layer

# outputs should be identical whether mems are omitted or passed as all zeros
out1, mems1 = lm(x, mask=mask, return_mems=True)
out2, mems2 = lm(x, mask=mask, mems=mems, return_mems=True)

torch.testing.assert_close(out1, out2)
for m1, m2 in zip(mems1, mems2):
    torch.testing.assert_close(m1, m2)
```
I also tried changing https://github.com/lucidrains/x-transformers/blob/583c19dc0eb80182b0fa8ed2bfd3b22bcecbc374/x_transformers/x_transformers.py#L882-L884 to:

```python
if exists(input_mask) and exists(mem):
    attend = torch.any(mem)
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)
```

but that doesn't help. Any ideas?
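For context, the intent of that change is to prepend `mem.shape[-2]` slots to the key/value mask and only let the model attend to them when the supplied mems actually contain non-zero values. A standalone sketch of that masking behaviour (illustrative only, not the library code), using plain `torch.cat` instead of x-transformers' `pad_at_dim`:

```python
import torch

# Prepend `mem_len` slots to the key/value mask; enable them only when the
# supplied mems contain non-zero values.
B, seq_len, mem_len, dim = 1, 4, 3, 8

input_mask = torch.ones(B, seq_len, dtype=torch.bool)   # mask over the real tokens
mem = torch.zeros(B, mem_len, dim)                      # all-zero initial mems

attend = bool(torch.any(mem))                           # False -> memory slots get masked out
mem_mask = torch.full((B, mem_len), attend, dtype=torch.bool)
padded_mask = torch.cat((mem_mask, input_mask), dim=-1)

print(padded_mask)
# tensor([[False, False, False,  True,  True,  True,  True]])
```

The rationale is that all-zero mems should contribute nothing, so masking them out entirely ought to reproduce the mems=None output.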
About this issue
- State: closed
- Created 6 months ago
- Comments: 34 (34 by maintainers)
Commits related to this issue
- address https://github.com/lucidrains/x-transformers/issues/223 with memory masking and offset for rotary positions — committed to lucidrains/x-transformers by lucidrains 6 months ago
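The "offset for rotary positions" part of that commit refers to shifting the current tokens' positions by the memory length, so that relative distances between queries and the prepended memory keys stay consistent. A standalone sketch of the position bookkeeping (illustrative only, not the library's implementation):

```python
import torch

# When `mem_len` cached tokens are prepended to the keys/values, the keys span
# the full concatenated length while the queries (the current tokens) should be
# offset by `mem_len` so relative distances come out right.
mem_len, seq_len = 3, 5

k_pos = torch.arange(mem_len + seq_len)      # positions for the [mems, current] keys
q_pos = torch.arange(seq_len) + mem_len      # current tokens start after the mems

rel_dist = q_pos[:, None] - k_pos[None, :]   # what rotary / relative attention effectively uses
print(rel_dist[0])                           # first query is `mem_len` steps ahead of the oldest mem
# tensor([ 3,  2,  1,  0, -1, -2, -3, -4])
```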
I’m actually using a Decoder. I used Encoder for the repro to make things simpler.
ok, i’m going to close this issue, i think it is good now
@pfeatherstone you let me know what you see when you rerun the sandwich norm experiments. thinking about removing it
@pfeatherstone i don't think there would be any issue, just that a mem_mask would lead to more flexibility, and solve your problem with needing an initial zero mems, which i assume is onnx related

So I've fixed the issue where zero mems should behave the same as not attending to mems at all, and corrected the rotary embeddings. The second issue I've come across is that mems are recorded before the pre-norm layer normalization, yet on the next iteration they are prepended after it. I tested it and was getting gibberish. I've fixed that by recording the new mems exactly where the old mems are prepended. Now I get sensible results. FYI, I'm using sandwich norm, which uses pre-LN.
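To make the pre-norm ordering issue concrete, here is a schematic single pre-LN attention block (names and structure are hypothetical simplifications, not x-transformers' actual code), with the new mems recorded at the same point where the old ones are prepended:

```python
import torch
from torch import nn

class PreNormAttnBlock(nn.Module):
    """Schematic pre-LN block, for illustration only."""

    def __init__(self, dim, heads = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x, mem = None):
        normed = self.norm(x)                                     # pre-norm of the current tokens
        kv = normed if mem is None else torch.cat((mem, normed), dim = 1)

        # Cache the new mems at exactly the point where the old mems are
        # consumed, so what is recorded this step matches what will be
        # prepended on the next step.
        new_mem = normed.detach()

        out, _ = self.attn(normed, kv, kv, need_weights = False)
        return x + out, new_mem
```

Recording the un-normalized hiddens instead (while prepending them post-norm) is the mismatch described above.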
i’ll make the correction for rotary when i find some time
thanks for taking the initiative and working it out
@lucidrains Yes it worked!
So the total changes are:
- at line 882 of x_transformers.py
- at line 1257 of x_transformers.py
The absolute error is around 0.008 on average
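For reference, a comparison that tolerates that level of numerical noise (thresholds picked purely for illustration) could look like:

```python
# Hypothetical tolerances, for illustration only
torch.testing.assert_close(out1, out2, atol=1e-2, rtol=0.0)
```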
I will give it a go
If I use
with the suggested change it works.
If I use:
it attempts to use AbsolutePositionalEmbedding, which I don't really want.
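If the problem is the wrapper creating an absolute positional embedding, one option (assuming the `use_abs_pos_emb` keyword exposed by recent x-transformers versions) is to turn it off explicitly while keeping rotary embeddings, e.g.:

```python
from x_transformers import ContinuousTransformerWrapper, Encoder

# Assumes `use_abs_pos_emb` is available on the wrapper; this keeps rotary
# position embeddings but skips AbsolutePositionalEmbedding even with a
# non-zero max_seq_len.
lm = ContinuousTransformerWrapper(
    dim_in = 2,
    dim_out = 36,
    max_seq_len = 1024,
    max_mem_len = 100,
    use_abs_pos_emb = False,
    attn_layers = Encoder(
        dim = 512,
        depth = 4,
        heads = 4,
        rotary_pos_emb = True,
        attn_flash = True,
        attn_num_mem_kv = 20
    )
)
```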