x-transformers: XL-recurrence with RotaryEmbedding and mems not working correctly.

Note, this follows on from https://github.com/lucidrains/x-transformers/issues/216

I am trying to do XL-recurrence with:

  • RotaryEmbedding
  • attn_num_mem_kv > 0
  • mems and return_mems

I’m running a test that checks the outputs are the same whether I pass mems=None or mems=torch.zeros(...). They are not. I’m using the code below:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

lm = ContinuousTransformerWrapper(
    dim_in              = 2,
    dim_out             = 36,
    max_seq_len         = 0,
    max_mem_len         = 100,
    attn_layers = Encoder(
        dim             = 512,
        depth           = 4,
        heads           = 4,
        rotary_pos_emb  = True,
        attn_flash      = True,
        attn_num_mem_kv = 20
    )
)

B, M, D, depth = 1, lm.max_mem_len, lm.attn_layers.dim, lm.attn_layers.depth

x       = torch.randn(B, 1024, 2)
length  = torch.randint(100, x.shape[1], size=(x.shape[0],))             # per-sample valid lengths
mask    = torch.arange(x.shape[1]).unsqueeze(0) < length.unsqueeze(-1)   # (B, 1024) boolean padding mask
mems    = [torch.zeros(x.shape[0], M, D) for _ in range(depth)]          # all-zero mems, one per layer

out1, mems1 = lm(x, mask=mask, return_mems=True)
out2, mems2 = lm(x, mask=mask, mems=mems, return_mems=True)
torch.testing.assert_close(out1, out2)
for m1, m2 in zip(mems1, mems2):
    torch.testing.assert_close(m1, m2)

I also tried changing https://github.com/lucidrains/x-transformers/blob/583c19dc0eb80182b0fa8ed2bfd3b22bcecbc374/x_transformers/x_transformers.py#L882-L884

to

if exists(input_mask) and exists(mem):
    attend = torch.any(mem)
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)

but that doesn’t help. Any ideas?
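
For reference, the intent of that change, spelled out with plain tensors (this is just what I think the masking should amount to, not the library helper itself):

import torch

B, T, M, D = 1, 8, 4, 64

input_mask = torch.ones(B, T, dtype = torch.bool)    # mask over the current tokens
mem        = torch.zeros(B, M, D)                    # all-zero mems, e.g. on the first segment

# attend to the M prepended memory positions only when the mems hold non-zero content
attend    = bool(torch.any(mem))
mem_mask  = torch.full((B, M), attend, dtype = torch.bool)
full_mask = torch.cat((mem_mask, input_mask), dim = -1)

print(full_mask)   # first M entries are False here, since the mems are all zeros

i.e. the M prepended memory positions should only be unmasked once the mems actually carry state.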


Most upvoted comments

@pfeatherstone noticed you are using an Encoder instead of a Decoder in your example code. Do you have a working model based on this idea?

I’m actually using a Decoder. I used an Encoder in the repro to keep things simple.

ok, I’m going to close this issue, I think it is good now

@pfeatherstone let me know what you see when you rerun the sandwich norm experiments. Thinking about removing it.

@pfeatherstone I don’t think there would be any issue, just that a mem_mask would lead to more flexibility and solve your problem of needing initial zero mems, which I assume is ONNX-related.

So I’ve fixed the issue where zero mems should behave the same as not attending to mems at all, and corrected the rotary embeddings. The second issue I’ve come across is that mems are recorded before the pre-norm layer normalization, yet on the next iteration they are prepended after it. I tested it and was getting gibberish. I’ve fixed this by recording new mems exactly where old mems are prepended, and now I get sensible results. FYI, I’m using sandwich norm, which uses pre-LN.
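
To illustrate what I mean by recording new mems exactly where old mems are prepended, here is a toy pre-LN block (just a sketch of the ordering, not the library’s actual code):

import torch
import torch.nn as nn

class PreLNBlockWithMems(nn.Module):
    """Toy pre-LN attention block: the hidden state saved as a new mem is taken
    from the same point in the residual stream where old mems are prepended."""

    def __init__(self, dim, heads = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x, mem = None):
        # old mems are prepended to the un-normalized residual stream here ...
        kv = x if mem is None else torch.cat((mem, x), dim = 1)

        # ... so the new mem is recorded from that same representation,
        # not from a tensor taken on the other side of the normalization
        new_mem = x.detach()

        normed_q  = self.norm(x)
        normed_kv = self.norm(kv)
        out, _ = self.attn(normed_q, normed_kv, normed_kv)
        return x + out, new_mem

block = PreLNBlockWithMems(dim = 64)
seg1, seg2 = torch.randn(1, 10, 64), torch.randn(1, 10, 64)
_, mem = block(seg1)            # first segment: no mems yet
out, _ = block(seg2, mem = mem) # second segment attends over [mem, seg2]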

I’ll make the correction for rotary when I find some time.

thanks for taking the initiative and working it out

@lucidrains Yes it worked!

So the total changes are, at line 882 of x_transformers.py:

if exists(input_mask) and exists(mem):
    # only attend to the prepended memory positions when the mems hold non-zero content
    attend = torch.any(mem)
    input_mask = pad_at_dim(input_mask, (mem.shape[-2], 0), dim = -1, value = attend)

and at line 1257 of x_transformers.py:

if not exists(rotary_pos_emb) and exists(self.rotary_pos_emb):
    # rotary positions span the prepended mems (-M .. -1) and the current segment (0 .. T-1)
    M = max(m.shape[1] if exists(m) else 0 for m in mems)
    T = x.shape[1]
    t = torch.arange(-M, T, device = x.device)
    rotary_pos_emb = self.rotary_pos_emb.forward(t)
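
The point of the -M offset is that the memory tokens get the M positions immediately to the left of the current segment, so the rotary relative offsets between current queries and memory keys stay consistent with simply prepending the mems. A quick check of the index range (only illustrating the indices, not the library call):

import torch

M, T = 3, 5                  # memory length, current segment length
t = torch.arange(-M, T)      # positions handed to the rotary embedding
print(t)                     # tensor([-3, -2, -1,  0,  1,  2,  3,  4])
# first M entries -> the prepended memory slots, remaining T -> the current tokens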

The absolute error is around 0.008 on average
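
In other words, the outputs only agree up to a small numerical tolerance, so the equality check in the repro needs a tolerance along these lines (values illustrative, not tuned):

# compare with explicit tolerances rather than exact equality
torch.testing.assert_close(out1, out2, rtol = 0.0, atol = 1e-2)
for m1, m2 in zip(mems1, mems2):
    torch.testing.assert_close(m1, m2, rtol = 0.0, atol = 1e-2)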

I will give it a go

If I use

use_abs_pos_emb=True,
rotary_pos_emb=False

with the suggested change, it works.

If I use:

rotary_pos_emb=False

it attempts to use AbsolutePositionalEmbedding, which I don’t really want.
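
For reference, what I’d want in that case is to turn the absolute embedding off explicitly when rotary is off, something like this (assuming ContinuousTransformerWrapper honours use_abs_pos_emb, as in the earlier config):

# disable the absolute positional embedding explicitly, so that turning rotary off
# does not fall back to AbsolutePositionalEmbedding
lm = ContinuousTransformerWrapper(
    dim_in          = 2,
    dim_out         = 36,
    max_seq_len     = 0,
    max_mem_len     = 100,
    use_abs_pos_emb = False,
    attn_layers = Encoder(
        dim             = 512,
        depth           = 4,
        heads           = 4,
        rotary_pos_emb  = False,
        attn_flash      = True,
        attn_num_mem_kv = 20
    )
)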