jax: TPU deadlock
Hello,
I am trying to train a Reformer model using Trax and JAX. The training fails on Google Colab because of its memory limitation. When I run it on a Google Cloud server + TPU, it hangs on the "trax.supervised.Trainer" call.
The warning is as follows:
2020-08-26 17:46:37.421334: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:601] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.
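Before digging into the model, it may be worth confirming that JAX actually sees all of the remote TPU cores. A minimal sanity check (only a sketch, assuming the same tpu_driver backend flags that are set in the code below):

import jax
print(jax.devices())       # should list the remote TPU cores
print(jax.device_count())  # typically 8 on a v2-8 / v3-8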
The code is very straightforward:
import requests
import os
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + "10.206.164.18"
print(config.FLAGS.jax_backend_target)
from tensorflow.compat.v1.io.gfile import GFile
import gin
import os
import jax
import trax
from trax.data import inputs
import numpy as np
import jax.numpy as jnp
from scipy.special import softmax
import sentencepiece as spm
from sentencepiece import SentencePieceProcessor
import random, glob, os
def fake_data():
    # Write a tiny amino-acid vocabulary that includes a [MASK] token.
    with open("vocab.txt", 'w') as f:
        f.write("[MASK]\nL\nA\nG\nV\nE\nS\nI\nK\nR\nD\nT\nP\nN\nQ\nF\nY\nM\nH\nC\nW\nX\nU\nB\nZ\nO")
    if not os.path.exists('dataset'):
        os.makedirs('dataset')
    # A handful of space-separated protein sequences used as fake training data.
    seqs = [
        "M A F S A E D V L K E Y D R R R R M E A L L L S L Y Y P N D R K L L D Y K E W S P P R V Q V E C P K A P V E W N N P P S E K G L I V G H F S G I K Y K G E K A Q A S E V D V N K M C C W V S K F K D A M R R Y Q G I Q T C K I P G K V L S D L D A K I K A Y N L T V E G V E G F V R Y S R V T K Q H V A A F L K E L R H S K Q Y E N V N L I H Y I L T D K R V D I Q H L E K D L V K D F K A L V E S A H R M R Q G H M I N V K Y I L Y Q L L K K H G H G P D G P D I L T V K T G S K G V L Y D D S F R K I Y T D L G W K F T P L",
        "M S I I G A T R L Q N D K S D T Y S A G P C Y A G G C S A F T P R G T C G K D W D L G E Q T C A S G F C T S Q P L C A R I K K T Q V C G L R Y S S K G K D P L V S A E W D S R G A P Y V R C T Y D A D L I D T Q A Q V D Q F V S M F G E S P S L A E R Y C M R G V K N T A G E L V S R V S S D A D P A G G W C R K W Y S A H R G P D Q D A A L G S F C I K N P G A A D C K C I N R A S D P V Y Q K V K T L H A Y P D Q C W Y V P C A A D V G E L K M G T Q R D T P T N C P T Q V C Q I V F N M L D D G S V T M D D V K N T I N C D F S K Y V P P P P P P K P T P P T P P T P P T P P T P P T P P T P P T P R P V H N R K V M F F V A G A V L V A I L I S T V R W",
        "M A S N T V S A Q G G S N R P V R D F S N I Q D V A Q F L L F D P I W N E Q P G S I V P W K M N R E Q A L A E R Y P E L Q T S E P S E D Y S G P V E S L E L L P L E I K L D I M Q Y L S W E Q I S W C K H P W L W T R W Y K D N V V R V S A I T F E D F Q R E Y A F P E K I Q E I H F T D T R A E E I K A I L E T T P N V T R L V I R R I D D M N Y N T H G D L G L D D L E F L T H L M V E D A C G F T D F W A P S L T H L T I K N L D M H P R W F G P V M D G I K S M Q S T L K Y L Y I F E T Y G V N K P F V Q W C T D N I E T F Y C T N S Y R Y E N V P R P I Y V W V L F Q E D E W H G Y R V E D N K F H R R Y M Y S T I L H K R D T D W V E N N P L K T P A Q V E M Y K F L L R I S Q L N R D G T G Y E S D S D P E N E H F D D E S F S S G E E D S S D E D D P T W A P D S D D S D W E T E T E E E P S V A A R I L E K G K L T I T N L M K S L G F K P K P K K I Q S I D R Y F C S L D S N Y N S E D E D F E Y D S D S E D D D S D S E D D C",
        "M Y Q A I N P C P Q S W Y G S P Q L E R E I V C K M S G A P H Y P N Y Y P V H P N A L G G A W F D T S L N A R S L T T T P S L T T C T P P S L A A C T P P T S L G M V D S P P H I N P P R R I G T L C F D F G S A K S P Q R C E C V A S D R P S T T S N T A P D T Y R L L I T N S K T R K N N Y G T C R L E P L T Y G I",
        "M A R P L L G K T S S V R R R L E S L S A C S I F F F L R K F C Q K M A S L V F L N S P V Y Q M S N I L L T E R R Q V D R A M G G S D D D G V M V V A L S P S D F K T V L G S A L L A V E R D M V H V V P K Y L Q T P G I L H D M L V L L T P I F G E A L S V D M S G A T D V M V Q Q I A T A G F V D V D P L H S S V S W K D N V S C P V A L L A V S N A V R T M M G Q P C Q V T L I I D V G T Q N I L R D L V N L P V E M S G D L Q V M A Y T K D P L G K V P A V G V S V F D S G S V Q K G D A H S V G A P D G L V S F H T H P V S S A V E L N Y H A G W P S N V D M S S L L T M K N L M H V V V A E E G L W T M A R T L S M Q R L T K V L T D A E K D V M R A A A F N L F L P L N E L R V M G T K D S N N K S L K T Y F E V F E T F T I G A L M K H S G V T P T A F V D R R W L D N T I Y H M G F I P W G R D M R F V V E Y D L D G T N P F L N T V P T L M S V K R K A K I Q E M F D N M V S R M V T S",
        "M N A K Y D T D Q G V G R M L F L G T I G L A V V V G G L M A Y G Y Y Y D G K T P S S G T S F H T A S P S F S S R Y R Y",
        "M R Y T V L I A L Q G A L L L L L L I D D G Q G Q S P Y P Y P G M P C N S S R Q C G L G T C V H S R C A H C S S D G T L C S P E D P T M V W P C C P E S S C Q L V V G L P S L V N H Y N C L P N Q C T D S S Q C P G G F G C M T R R S K C E L C K A D G E A C N S P Y L D W R K D K E C C S G Y C H T E A R G L E G V C I D P K K I F C T P K N P W Q L A P Y P P S Y H Q P T T L R P P T S L Y D S W L M S G F L V K S T T A P S T Q E E E D D Y",
        "M Q N P L P E V M S P E H D K R T T T P M S K E A N K F I R E L D K K P G D L A V V S D F V K R N T G K R L P I G K R S N L Y V R I C D L S G T I Y M G E T F I L E S W E E L Y L P E P T K M E V L G T L E S C C G I P P F P E W I V M V G E D Q C V Y A Y G D E E I L L F A Y S V K Q L V E E G I Q E T G I S Y K Y P D D I S D V D E E V L Q Q D E E I Q K I R K K T R E F V D K D A Q E F Q D F L N S L D A S L L S",
        "M D S L N E V C Y E Q I K G T F Y K G L F G D F P L I V D K K T G C F N A T K L C V L G G K R F V D W N K T L R S K K L I Q Y Y E T R C D I K T E S L L Y E I K G D N N D E I T K Q I T G T Y L P K E F I L D I A S W I S V E F Y D K C N N I I I N Y F V N E Y K T M D K K T L Q S K I N E V E E K M Q K L L N E K E E E L Q E K N D K I D E L I L F S K R M E E D R K K D R E M M I K Q E K M L R E L G I H L E D V S S Q N N E L I E K V D E Q V E Q N A V L N F K I D N I Q N K L E I A V E D R A P Q P K Q N L K R E R F I L L K R N D D Y Y P Y Y T I R A Q D I N A R S A L K R Q K N L Y N E V S V L L D L T C H P N S K T L Y V R V K D E L K Q K G V V F N L C K V S I S N S K I N E E E L I K A M E T I N D E K R D V",
    ]
    # Write the same set of sequences, repeated 50 times, into both training shards.
    for fname in ("dataset/train_0.txt", "dataset/train_1.txt"):
        with open(fname, 'w') as f:
            for i in range(50):
                for seq in seqs:
                    f.write(seq + "\n")
fake_data()
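# Quick check (illustration only) that the fake dataset was written to disk.
print(sorted(glob.glob("dataset/train*")))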
spm.SentencePieceTrainer.train(input='vocab.txt',
                               model_prefix='protein',
                               vocab_size=30,
                               model_type="word",
                               #user_defined_symbols="<MASK>",
                               pad_id=0,
                               unk_id=1,
                               bos_id=2,
                               eos_id=3,
                               pad_piece="[PAD]",
                               unk_piece="[UNK]",
                               bos_piece="[BOS]",
                               eos_piece="[EOS]")
tokenizer = spm.SentencePieceProcessor(model_file='protein.model')
train_files = glob.glob("dataset/train*",recursive=True)
random.shuffle(train_files)
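# Quick tokenizer sanity check (illustration only; the snippet below is made up,
# not taken from the dataset): an encode/decode round trip should preserve tokens.
sample_ids = tokenizer.encode("M A F S A E D", add_bos=True, add_eos=True)
print(sample_ids)
print(tokenizer.decode(sample_ids))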
def mask_seq(seq, mask_prob=0.15):
    # Replace a random subset of interior positions (excluding BOS/EOS) with the [MASK] id.
    seq = np.array(seq)
    minValue = 1
    maxValue = len(seq) - 2
    max_mask_tokens = int(maxValue * mask_prob + 0.5)
    randomlist = random.sample(range(minValue, maxValue), max_mask_tokens)
    seq_masked = seq
    seq_masked[randomlist] = tokenizer.encode("[MASK]")[0]
    return seq_masked
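# Usage sketch for mask_seq (illustration only, toy sequence): roughly mask_prob
# of the interior positions should come back as the [MASK] id.
toy_ids = tokenizer.encode("M A F S A E D V L K E Y", add_bos=True, add_eos=True)
print(toy_ids)
print(mask_seq(toy_ids, mask_prob=0.15))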
def get_seq(train_files):
    while True:
        for file in train_files:
            with open(file) as fp:
                for line in fp:
                    yield line
def get_batch(seq_gen, batch_length):
    # Pack consecutive tokenized sequences into one flat list of at most batch_length ids.
    batch = []
    while True:
        seq = next(seq_gen)
        seq_ids = tokenizer.encode(seq, add_bos=True, add_eos=True)
        new_batch_len = len(batch) + len(seq_ids)
        if new_batch_len <= batch_length:
            batch = batch + seq_ids
            continue
        # The current sequence does not fit: yield the packed batch and start a new one with it.
        next_batch = batch
        batch = seq_ids
        yield next_batch
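# Smoke test for get_batch (illustration only): pull one packed batch and check
# that it stays within the same budget the pipeline below uses.
test_batch = next(get_batch(get_seq(train_files), 1024 * 4))
print(len(test_batch), test_batch[:10])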
# Set up the data pipeline.
def my_inputs(n_devices):
    MAX_BATCH_LENGTH = 1024 * 4
    seq_gen = get_seq(train_files)
    batch_gen = get_batch(seq_gen, MAX_BATCH_LENGTH)
    while True:
        inputs = []
        targets = []
        mask = []
        for i in range(n_devices):
            batch_ids = next(batch_gen)
            masked_seq_ids = mask_seq(batch_ids)
            pad_amount = MAX_BATCH_LENGTH - len(batch_ids)
            inputs.append(np.pad(masked_seq_ids, (0, pad_amount)))
            targets.append(np.pad(batch_ids, (0, pad_amount)))
            mask.append(np.pad(np.ones_like(batch_ids, dtype=np.float32),
                               (0, pad_amount),
                               mode='constant'))
        inputs = np.stack(inputs)
        targets = np.stack(targets)
        mask = np.stack(mask)
        yield (inputs, targets, mask)
inp_gen_test = my_inputs(trax.fastmath.device_count())
res = next(inp_gen_test)
print(tokenizer.decode(res[0][0].tolist()))
print(tokenizer.decode(res[1][0].tolist()))
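# The leading dimension of each array should equal the number of accelerator
# cores; a mismatch there would be one plausible source of multi-core trouble.
print([x.shape for x in res])
print(trax.fastmath.device_count())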
# Configure hyperparameters.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib
# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.Reformer
n_layers = 15
n_heads = 16
dropout = 0.1
n_tokens = 40000 # They have used very small n_tokens = 2048
vocab_size= 30
d_model = 1024
d_ff = 4096
# Done
# Parameters for MultifactorSchedule:
# ==============================================================================
multifactor.constant = 0.088
multifactor.decay_factor = 0.5
multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
multifactor.steps_per_cycle = 100000
multifactor.steps_per_decay = 20000
multifactor.warmup_steps = 8000
# Done
# Parameters for Adam:
# ==============================================================================
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-09
Adam.weight_decay_rate = 1e-05
# Done
# Parameters for SelfAttention:
# ==============================================================================
#trax.layers.SelfAttention.attention_dropout = 0.05
#trax.layers.SelfAttention.chunk_len = 64
#trax.layers.SelfAttention.n_chunks_before = 1
#trax.layers.SelfAttention.n_parallel_heads = 1
trax.layers.SelfAttention.causal = False
trax.layers.SelfAttention.chunk_len = None
trax.layers.SelfAttention.masked = False
trax.layers.SelfAttention.n_chunks_after = 0
trax.layers.SelfAttention.n_chunks_before = 0
trax.layers.SelfAttention.n_parallel_heads = None
trax.layers.SelfAttention.predict_drop_len = 64
trax.layers.SelfAttention.predict_mem_len = 192
trax.layers.SelfAttention.share_qk = False
trax.layers.SelfAttention.use_python_loop = False
trax.layers.SelfAttention.use_reference_code = False
# Done
# Parameters for EncDecAttention:
# ==============================================================================
trax.layers.EncDecAttention.masked = True
trax.layers.EncDecAttention.n_parallel_heads = None
trax.layers.EncDecAttention.use_python_loop = False
trax.layers.EncDecAttention.use_reference_code = False
# Done
# Parameters for LSHSelfAttention:
# ==============================================================================
#LSHSelfAttention.attention_dropout = 0.0
#LSHSelfAttention.chunk_len = 64
#LSHSelfAttention.n_buckets = [64, 128]
#LSHSelfAttention.n_chunks_after = 0
#LSHSelfAttention.n_chunks_before = 1
#LSHSelfAttention.n_hashes = 1
#LSHSelfAttention.n_parallel_heads = 1
#LSHSelfAttention.predict_drop_len = 128
#LSHSelfAttention.predict_mem_len = 1024
# Done
# Parameters for Reformer:
# ==============================================================================
Reformer.d_model = %d_model
Reformer.d_ff = %d_ff
Reformer.dropout = %dropout
Reformer.ff_activation = @trax.layers.Relu
Reformer.max_len = %n_tokens
Reformer.mode = 'train'
Reformer.n_heads = %n_heads
Reformer.n_encoder_layers = %n_layers
Reformer.n_decoder_layers = %n_layers
Reformer.input_vocab_size = %vocab_size
""")
# Set up a Trainer.
output_dir = os.path.expanduser('train_dir/')
trainer = trax.supervised.Trainer(
    model=trax.models.Reformer,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)
# Run one training step, to make sure the model fits in memory.
# The first time trainer.train_epoch is called, it will JIT the entire network
# architecture, which takes around 2 minutes. The JIT-compiled model is saved
# so subsequent runs will be much faster than the first.
trainer.train_epoch(n_steps=1, n_eval_steps=1)
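One way to rule out a broken TPU connection, independently of Trax, would be a trivial pmap across all cores. This is only a sketch and assumes the backend flags from the top of the script are already set:

# If even this tiny pmap hangs, the problem is the TPU connection rather than
# the Reformer model itself.
x = jnp.arange(jax.local_device_count(), dtype=jnp.float32)
print(jax.pmap(lambda v: v * 2.0)(x))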
Any idea how I can solve this issue?
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 18 (2 by maintainers)
I am also getting the same error while trying to train the standard TransformerLM with a slight modification of the parameters (vocab_size=50000, max_len=1024). The same code works on the Colab + TPU setup, but when I try it with the GCE VM + Cloud TPU setup I get the same error.
I believe this triggers JAX compilation.
The very long compilation issue does not occur when using a GPU for the same command mentioned above.
My main issue here is the JAX compilation speed, which is very slow. I killed the process after leaving it running for an hour.
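For reference, here is a minimal sketch (assuming the same remote TPU backend flags) that can help separate compile time from run time; a slow first step is expected because of XLA compilation, but an hour-long hang suggests something else:

import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x)

x = jnp.ones((1024, 1024))
t0 = time.time()
step(x).block_until_ready()   # first call includes XLA compilation
print("first call:", time.time() - t0)
t0 = time.time()
step(x).block_until_ready()   # subsequent calls should be fast
print("second call:", time.time() - t0)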