sinkhorn-transformer: generation problem in a toy task
Here is the full script for my toy task (x -> xx, e.g. “abc” to “abcabc”):
```python
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np

import torch
import torch.optim as optim
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 100
ENC_SEQ_LEN = 16
DEC_SEQ_LEN = 40
NUM_TOKENS = 256 + 2
BUCKET_SIZE = 8

# helpers

def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

def cycle():
    while True:
        source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        # target: the source repeated twice, prefixed with the start-of-sequence token (1)
        target = torch.cat((source, source), 1)
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        target = torch.cat((prefix, target), axis=1)
        x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
        y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()
        yield (source, target, x_mask, y_mask)

# instantiate model

class MySinkhornTransformer(nn.Module):
    def __init__(self, num_tokens, dim, depth, heads, bucket_size, enc_max_seq_len, dec_max_seq_len):
        super().__init__()
        self.pad_token = 0
        self.sos_token = 1
        self.enc = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads,
                                         bucket_size=bucket_size, max_seq_len=enc_max_seq_len,
                                         reversible=True, return_embeddings=True)
        self.dec = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads,
                                         causal=True, bucket_size=bucket_size, max_seq_len=dec_max_seq_len,
                                         receives_context=True, context_bucket_size=bucket_size, reversible=True)
        self.dec = AutoregressiveWrapper(self.dec, pad_value=num_tokens-2)

    @torch.no_grad()
    def generate(self, x, x_mask):
        context = self.enc(x, input_mask=x_mask)
        start_tokens = (torch.ones((x.shape[0], 1)) * self.sos_token).long().cuda()
        return self.dec.generate(start_tokens, 32, context=context, context_mask=x_mask)

    def forward(self, x, y, x_mask, y_mask, return_loss):
        context = self.enc(x, input_mask=x_mask)
        return self.dec(y, context=context, input_mask=y_mask, context_mask=x_mask, return_loss=True)

model = MySinkhornTransformer(num_tokens=NUM_TOKENS, dim=512, depth=1, heads=1, bucket_size=BUCKET_SIZE,
                              enc_max_seq_len=ENC_SEQ_LEN, dec_max_seq_len=DEC_SEQ_LEN)
model.cuda()

# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        source, target, x_mask, y_mask = next(cycle())
        loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
        loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            source, target, x_mask, y_mask = next(cycle())
            loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        source, target, x_mask, y_mask = next(cycle())
        sample = model.generate(x=source, x_mask=x_mask)
        print("input: ", source[0])
        print("model output: ", sample[0])
```
After a few steps the loss becomes practically zero. I checked the logits during training and they seem to be OK, but during the generation phase the model outputs this pattern: “x,x,x,x,x,y,y,y,y,y”, e.g. “aaaabbbb” instead of “abcdabcd”. I was wondering what the underlying issue might be. Do you have any idea?
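One way to narrow this down is to compare teacher-forced greedy predictions (gold prefix fed in) with what the decoder produces when it has to consume its own output. The sketch below is a diagnostic only, not part of the original script, and it assumes the `AutoregressiveWrapper` exposes the wrapped decoder LM as `.net` (the attribute name may differ between versions):

```python
model.eval()
with torch.no_grad():
    source, target, x_mask, y_mask = next(cycle())
    context = model.enc(source, input_mask=x_mask)

    # teacher-forced: feed the gold prefix, take the argmax of the next-token logits
    logits = model.dec.net(target[:, :-1], context=context,
                           input_mask=y_mask[:, :-1], context_mask=x_mask)
    tf_pred = logits.argmax(dim=-1)

    # free-running: the decoder consumes its own predictions
    ar_pred = model.generate(x=source, x_mask=x_mask)

    print("gold:           ", target[0, 1:])
    print("teacher-forced: ", tf_pred[0])
    print("autoregressive: ", ar_pred[0])
```

If the teacher-forced row reproduces the target while the free-running row collapses into the “aaaabbbb” pattern, the near-zero loss is misleading: under teacher forcing the decoder is exploiting information it will not have at generation time (for instance, if the causal bucketing/sorting lets a position attend to later tokens), rather than genuinely predicting the next token from its own prefix.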
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 18 (13 by maintainers)
thanks! it’s been a learning experience
You have very good repos (this one and the reformer one), but I strongly suggest running toy tasks when implementing a paper. They catch bugs very well, especially for seq2seq. The rule of thumb is that a seq2seq model should be able to learn “x -> x” or “x -> xx” perfectly.
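For reference, the even simpler “x -> x” case only needs the target in `cycle()` to be the source itself (plus the start-of-sequence prefix), with the generation length in `generate()` set to ENC_SEQ_LEN instead of 32. A minimal sketch of that variant (the name `cycle_copy` is just for illustration):

```python
def cycle_copy():
    # "x -> x" copy task: the target is the source with an <sos> token prepended
    while True:
        source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        target = torch.cat((prefix, source), dim=1)
        x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
        y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()
        yield (source, target, x_mask, y_mask)
```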
my hope is I get more eyeballs on the project and perhaps someone (maybe you) will think of something lol
@lucidrains Thanks. And generally I think there is still something wrong, since the loss stays high for such a simple task, especially given how small the input is.