sinkhorn-transformer: generation problem in a toy task

Here is the full script for my toy task (x -> xx, e.g. “abc” -> “abcabc”):

from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 100
ENC_SEQ_LEN = 16
DEC_SEQ_LEN = 40
NUM_TOKENS = 256 + 2  # 256 data values plus pad (0) and <sos> (1)
BUCKET_SIZE = 8

# helpers

# top-k filtering helper (defined here but not used in the rest of this script)
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs


def cycle():
    while True:
        # data tokens are drawn from [2, 258); 0 is reserved for padding, 1 for <sos>
        source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()

        # decoder target is the <sos> token followed by the source repeated twice
        target = torch.cat((source, source), dim=1)
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        target = torch.cat((prefix, target), dim=1)

        x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
        y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()

        yield (source, target, x_mask, y_mask)

# instantiate model

class MySinkhornTransformer(nn.Module):
    def __init__(self, num_tokens, dim, depth, heads, bucket_size, enc_max_seq_len, dec_max_seq_len):
        super().__init__()
        
        self.pad_token = 0
        self.sos_token = 1

        # encoder returns embeddings (return_embeddings=True) that serve as context for the decoder
        self.enc = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, bucket_size=bucket_size, max_seq_len=enc_max_seq_len,
                                         reversible=True, return_embeddings=True)
        # causal decoder that cross-attends to the encoder output
        self.dec = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, causal=True, bucket_size=bucket_size, max_seq_len=dec_max_seq_len,
                                         receives_context=True, context_bucket_size=bucket_size, reversible=True)
        self.dec = AutoregressiveWrapper(self.dec, pad_value=num_tokens-2)
    
    @torch.no_grad()
    def generate(self, x, x_mask):
        # encode the source, then decode autoregressively starting from the <sos> token
        context = self.enc(x, input_mask=x_mask)
        start_tokens = (torch.ones((x.shape[0], 1)) * self.sos_token).long().cuda()

        # sample 32 tokens, the length of the doubled source (2 * ENC_SEQ_LEN)
        return self.dec.generate(start_tokens, 32, context=context, context_mask=x_mask)

    def forward(self, x, y, x_mask, y_mask, return_loss):
        context = self.enc(x, input_mask=x_mask)
        return self.dec(y, context=context, input_mask=y_mask, context_mask=x_mask, return_loss=return_loss)


model = MySinkhornTransformer(num_tokens=NUM_TOKENS, dim=512, depth=1, heads=1, bucket_size=BUCKET_SIZE, enc_max_seq_len=ENC_SEQ_LEN, dec_max_seq_len=DEC_SEQ_LEN)
model.cuda()
# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # accumulate gradients over GRADIENT_ACCUMULATE_EVERY batches before each optimizer step
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        source, target, x_mask, y_mask = next(cycle())
        loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
        loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            source, target, x_mask, y_mask = next(cycle())
            loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        
        source, target, x_mask, y_mask = next(cycle())
        
        sample = model.generate(x=source, x_mask=x_mask)
        print("input:  ", source[0])
        print("model output:  ", sample[0])

After a few steps the loss becomes practically zero. I checked the logits during training and they seem to be OK, but during the generation phase the model outputs the pattern “x,x,x,x,x,y,y,y,y,y”, like “aaaabbbb” instead of “abcdabcd”. I was wondering what the underlying issue might be. Do you have any idea?
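
One way to narrow this down is to quantify how far the free-running samples are from the expected output, and where along the sequence they diverge. The following is a minimal diagnostic sketch, not part of the original script: it reuses cycle() and the constants above, assumes model.generate returns only the newly sampled tokens, and uses a made-up helper name match_rates.

@torch.no_grad()
def match_rates(model, n_batches=20):
    # compare generated samples against the expected doubled source,
    # overall and per position, to see where the output starts to diverge
    model.eval()
    per_pos = None
    for _ in range(n_batches):
        source, target, x_mask, _ = next(cycle())
        sample = model.generate(x=source, x_mask=x_mask)
        n = min(sample.shape[1], target.shape[1] - 1)
        expected = target[:, 1:1 + n]                              # drop the <sos> prefix
        match = (sample[:, :n] == expected).float().mean(dim=0)    # per-position accuracy
        per_pos = match if per_pos is None else per_pos + match
    per_pos = per_pos / n_batches
    return per_pos.mean().item(), per_pos.tolist()

overall, per_position = match_rates(model)
print(f'overall token match rate: {overall:.3f}')
print('per-position match rate:', per_position)

If the teacher-forced loss is near zero but this match rate is poor, the mismatch is confined to the sampling loop (how previously generated tokens and their positions are fed back), which would be consistent with the repeated-token pattern described above.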

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 18 (13 by maintainers)

Most upvoted comments

thanks! it’s been a learning experience

You have very good repos (this one and the reformer one), but I strongly suggest running toy tasks when implementing a paper. They catch bugs very well, especially for seq2seq. The rule of thumb is that a seq2seq model should be able to learn “x -> x” or “x -> xx” perfectly.
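
For reference, the even simpler “x -> x” identity-copy task mentioned above only needs a different data generator. Here is a minimal sketch, assuming the same conventions as the script in the question (token 0 reserved for padding, token 1 for <sos>, data tokens in [2, 258)); cycle_identity is a hypothetical name.

def cycle_identity():
    while True:
        source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()   # <sos>
        target = torch.cat((prefix, source), dim=1)          # <sos> followed by x itself
        x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
        y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()
        yield (source, target, x_mask, y_mask)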

my hope is I get more eyeballs on the project and perhaps someone (maybe you) will think of something lol

@lucidrains Thanks. And generally I think there is still something wrong, as the loss is still high for such a simple task, especially given that the input size is small.
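
As a rough yardstick for “high”: assuming the loss reported by the wrapper is the usual mean token-level cross-entropy in nats, a model that predicts uniformly over the NUM_TOKENS = 258 classes scores about ln(258) ≈ 5.55, while a model that has actually learned the copy task (the decoder can read the entire source through cross-attention) should approach zero.

import math

# cross-entropy of a uniform prediction over the vocabulary, as a reference point
uniform_baseline = math.log(NUM_TOKENS)   # ln(258) ≈ 5.55 nats per token
print(f'uniform-prediction baseline loss: {uniform_baseline:.2f}')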