jax: TPU deadlock

Hello,

I am trying to train a Reformer model using Trax and JAX. Training fails on Google Colab because of memory limits. When I run it on a Google Cloud server with a TPU, it hangs at “trax.supervised.Trainer”.

The warning is as follows:

2020-08-26 17:46:37.421334: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:601] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.

The code is very straightforward:

import requests
import os

from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + "10.206.164.18"
print(config.FLAGS.jax_backend_target)

from tensorflow.compat.v1.io.gfile import GFile
import gin
import jax
import trax
from trax.data import inputs

import numpy as np
import jax.numpy as jnp

from scipy.special import softmax

import sentencepiece as spm
from sentencepiece import SentencePieceProcessor

import glob
import random

def fake_data():
  with open("vocab.txt",'w') as f:
    f.write("[MASK]\nL\nA\nG\nV\nE\nS\nI\nK\nR\nD\nT\nP\nN\nQ\nF\nY\nM\nH\nC\nW\nX\nU\nB\nZ\nO")

  if not os.path.exists('dataset'):
    os.makedirs('dataset')

  with open("dataset/train_0.txt",'w') as f:
    for i in range(50):
      f.write("M A F S A E D V L K E Y D R R R R M E A L L L S L Y Y P N D R K L L D Y K E W S P P R V Q V E C P K A P V E W N N P P S E K G L I V G H F S G I K Y K G E K A Q A S E V D V N K M C C W V S K F K D A M R R Y Q G I Q T C K I P G K V L S D L D A K I K A Y N L T V E G V E G F V R Y S R V T K Q H V A A F L K E L R H S K Q Y E N V N L I H Y I L T D K R V D I Q H L E K D L V K D F K A L V E S A H R M R Q G H M I N V K Y I L Y Q L L K K H G H G P D G P D I L T V K T G S K G V L Y D D S F R K I Y T D L G W K F T P L\n")
      f.write("M S I I G A T R L Q N D K S D T Y S A G P C Y A G G C S A F T P R G T C G K D W D L G E Q T C A S G F C T S Q P L C A R I K K T Q V C G L R Y S S K G K D P L V S A E W D S R G A P Y V R C T Y D A D L I D T Q A Q V D Q F V S M F G E S P S L A E R Y C M R G V K N T A G E L V S R V S S D A D P A G G W C R K W Y S A H R G P D Q D A A L G S F C I K N P G A A D C K C I N R A S D P V Y Q K V K T L H A Y P D Q C W Y V P C A A D V G E L K M G T Q R D T P T N C P T Q V C Q I V F N M L D D G S V T M D D V K N T I N C D F S K Y V P P P P P P K P T P P T P P T P P T P P T P P T P P T P P T P R P V H N R K V M F F V A G A V L V A I L I S T V R W\n")
      f.write("M A S N T V S A Q G G S N R P V R D F S N I Q D V A Q F L L F D P I W N E Q P G S I V P W K M N R E Q A L A E R Y P E L Q T S E P S E D Y S G P V E S L E L L P L E I K L D I M Q Y L S W E Q I S W C K H P W L W T R W Y K D N V V R V S A I T F E D F Q R E Y A F P E K I Q E I H F T D T R A E E I K A I L E T T P N V T R L V I R R I D D M N Y N T H G D L G L D D L E F L T H L M V E D A C G F T D F W A P S L T H L T I K N L D M H P R W F G P V M D G I K S M Q S T L K Y L Y I F E T Y G V N K P F V Q W C T D N I E T F Y C T N S Y R Y E N V P R P I Y V W V L F Q E D E W H G Y R V E D N K F H R R Y M Y S T I L H K R D T D W V E N N P L K T P A Q V E M Y K F L L R I S Q L N R D G T G Y E S D S D P E N E H F D D E S F S S G E E D S S D E D D P T W A P D S D D S D W E T E T E E E P S V A A R I L E K G K L T I T N L M K S L G F K P K P K K I Q S I D R Y F C S L D S N Y N S E D E D F E Y D S D S E D D D S D S E D D C\n")
      f.write("M Y Q A I N P C P Q S W Y G S P Q L E R E I V C K M S G A P H Y P N Y Y P V H P N A L G G A W F D T S L N A R S L T T T P S L T T C T P P S L A A C T P P T S L G M V D S P P H I N P P R R I G T L C F D F G S A K S P Q R C E C V A S D R P S T T S N T A P D T Y R L L I T N S K T R K N N Y G T C R L E P L T Y G I\n")
      f.write("M A R P L L G K T S S V R R R L E S L S A C S I F F F L R K F C Q K M A S L V F L N S P V Y Q M S N I L L T E R R Q V D R A M G G S D D D G V M V V A L S P S D F K T V L G S A L L A V E R D M V H V V P K Y L Q T P G I L H D M L V L L T P I F G E A L S V D M S G A T D V M V Q Q I A T A G F V D V D P L H S S V S W K D N V S C P V A L L A V S N A V R T M M G Q P C Q V T L I I D V G T Q N I L R D L V N L P V E M S G D L Q V M A Y T K D P L G K V P A V G V S V F D S G S V Q K G D A H S V G A P D G L V S F H T H P V S S A V E L N Y H A G W P S N V D M S S L L T M K N L M H V V V A E E G L W T M A R T L S M Q R L T K V L T D A E K D V M R A A A F N L F L P L N E L R V M G T K D S N N K S L K T Y F E V F E T F T I G A L M K H S G V T P T A F V D R R W L D N T I Y H M G F I P W G R D M R F V V E Y D L D G T N P F L N T V P T L M S V K R K A K I Q E M F D N M V S R M V T S\n")
      f.write("M N A K Y D T D Q G V G R M L F L G T I G L A V V V G G L M A Y G Y Y Y D G K T P S S G T S F H T A S P S F S S R Y R Y\n")
      f.write("M R Y T V L I A L Q G A L L L L L L I D D G Q G Q S P Y P Y P G M P C N S S R Q C G L G T C V H S R C A H C S S D G T L C S P E D P T M V W P C C P E S S C Q L V V G L P S L V N H Y N C L P N Q C T D S S Q C P G G F G C M T R R S K C E L C K A D G E A C N S P Y L D W R K D K E C C S G Y C H T E A R G L E G V C I D P K K I F C T P K N P W Q L A P Y P P S Y H Q P T T L R P P T S L Y D S W L M S G F L V K S T T A P S T Q E E E D D Y\n")
      f.write("M Q N P L P E V M S P E H D K R T T T P M S K E A N K F I R E L D K K P G D L A V V S D F V K R N T G K R L P I G K R S N L Y V R I C D L S G T I Y M G E T F I L E S W E E L Y L P E P T K M E V L G T L E S C C G I P P F P E W I V M V G E D Q C V Y A Y G D E E I L L F A Y S V K Q L V E E G I Q E T G I S Y K Y P D D I S D V D E E V L Q Q D E E I Q K I R K K T R E F V D K D A Q E F Q D F L N S L D A S L L S\n")
      f.write("M D S L N E V C Y E Q I K G T F Y K G L F G D F P L I V D K K T G C F N A T K L C V L G G K R F V D W N K T L R S K K L I Q Y Y E T R C D I K T E S L L Y E I K G D N N D E I T K Q I T G T Y L P K E F I L D I A S W I S V E F Y D K C N N I I I N Y F V N E Y K T M D K K T L Q S K I N E V E E K M Q K L L N E K E E E L Q E K N D K I D E L I L F S K R M E E D R K K D R E M M I K Q E K M L R E L G I H L E D V S S Q N N E L I E K V D E Q V E Q N A V L N F K I D N I Q N K L E I A V E D R A P Q P K Q N L K R E R F I L L K R N D D Y Y P Y Y T I R A Q D I N A R S A L K R Q K N L Y N E V S V L L D L T C H P N S K T L Y V R V K D E L K Q K G V V F N L C K V S I S N S K I N E E E L I K A M E T I N D E K R D V\n") 

    with open("dataset/train_1.txt",'w') as f:
      for i in range(50):
        f.write("M A F S A E D V L K E Y D R R R R M E A L L L S L Y Y P N D R K L L D Y K E W S P P R V Q V E C P K A P V E W N N P P S E K G L I V G H F S G I K Y K G E K A Q A S E V D V N K M C C W V S K F K D A M R R Y Q G I Q T C K I P G K V L S D L D A K I K A Y N L T V E G V E G F V R Y S R V T K Q H V A A F L K E L R H S K Q Y E N V N L I H Y I L T D K R V D I Q H L E K D L V K D F K A L V E S A H R M R Q G H M I N V K Y I L Y Q L L K K H G H G P D G P D I L T V K T G S K G V L Y D D S F R K I Y T D L G W K F T P L\n")
        f.write("M S I I G A T R L Q N D K S D T Y S A G P C Y A G G C S A F T P R G T C G K D W D L G E Q T C A S G F C T S Q P L C A R I K K T Q V C G L R Y S S K G K D P L V S A E W D S R G A P Y V R C T Y D A D L I D T Q A Q V D Q F V S M F G E S P S L A E R Y C M R G V K N T A G E L V S R V S S D A D P A G G W C R K W Y S A H R G P D Q D A A L G S F C I K N P G A A D C K C I N R A S D P V Y Q K V K T L H A Y P D Q C W Y V P C A A D V G E L K M G T Q R D T P T N C P T Q V C Q I V F N M L D D G S V T M D D V K N T I N C D F S K Y V P P P P P P K P T P P T P P T P P T P P T P P T P P T P P T P R P V H N R K V M F F V A G A V L V A I L I S T V R W\n")
        f.write("M A S N T V S A Q G G S N R P V R D F S N I Q D V A Q F L L F D P I W N E Q P G S I V P W K M N R E Q A L A E R Y P E L Q T S E P S E D Y S G P V E S L E L L P L E I K L D I M Q Y L S W E Q I S W C K H P W L W T R W Y K D N V V R V S A I T F E D F Q R E Y A F P E K I Q E I H F T D T R A E E I K A I L E T T P N V T R L V I R R I D D M N Y N T H G D L G L D D L E F L T H L M V E D A C G F T D F W A P S L T H L T I K N L D M H P R W F G P V M D G I K S M Q S T L K Y L Y I F E T Y G V N K P F V Q W C T D N I E T F Y C T N S Y R Y E N V P R P I Y V W V L F Q E D E W H G Y R V E D N K F H R R Y M Y S T I L H K R D T D W V E N N P L K T P A Q V E M Y K F L L R I S Q L N R D G T G Y E S D S D P E N E H F D D E S F S S G E E D S S D E D D P T W A P D S D D S D W E T E T E E E P S V A A R I L E K G K L T I T N L M K S L G F K P K P K K I Q S I D R Y F C S L D S N Y N S E D E D F E Y D S D S E D D D S D S E D D C\n")
        f.write("M Y Q A I N P C P Q S W Y G S P Q L E R E I V C K M S G A P H Y P N Y Y P V H P N A L G G A W F D T S L N A R S L T T T P S L T T C T P P S L A A C T P P T S L G M V D S P P H I N P P R R I G T L C F D F G S A K S P Q R C E C V A S D R P S T T S N T A P D T Y R L L I T N S K T R K N N Y G T C R L E P L T Y G I\n")
        f.write("M A R P L L G K T S S V R R R L E S L S A C S I F F F L R K F C Q K M A S L V F L N S P V Y Q M S N I L L T E R R Q V D R A M G G S D D D G V M V V A L S P S D F K T V L G S A L L A V E R D M V H V V P K Y L Q T P G I L H D M L V L L T P I F G E A L S V D M S G A T D V M V Q Q I A T A G F V D V D P L H S S V S W K D N V S C P V A L L A V S N A V R T M M G Q P C Q V T L I I D V G T Q N I L R D L V N L P V E M S G D L Q V M A Y T K D P L G K V P A V G V S V F D S G S V Q K G D A H S V G A P D G L V S F H T H P V S S A V E L N Y H A G W P S N V D M S S L L T M K N L M H V V V A E E G L W T M A R T L S M Q R L T K V L T D A E K D V M R A A A F N L F L P L N E L R V M G T K D S N N K S L K T Y F E V F E T F T I G A L M K H S G V T P T A F V D R R W L D N T I Y H M G F I P W G R D M R F V V E Y D L D G T N P F L N T V P T L M S V K R K A K I Q E M F D N M V S R M V T S\n")
        f.write("M N A K Y D T D Q G V G R M L F L G T I G L A V V V G G L M A Y G Y Y Y D G K T P S S G T S F H T A S P S F S S R Y R Y\n")
        f.write("M R Y T V L I A L Q G A L L L L L L I D D G Q G Q S P Y P Y P G M P C N S S R Q C G L G T C V H S R C A H C S S D G T L C S P E D P T M V W P C C P E S S C Q L V V G L P S L V N H Y N C L P N Q C T D S S Q C P G G F G C M T R R S K C E L C K A D G E A C N S P Y L D W R K D K E C C S G Y C H T E A R G L E G V C I D P K K I F C T P K N P W Q L A P Y P P S Y H Q P T T L R P P T S L Y D S W L M S G F L V K S T T A P S T Q E E E D D Y\n")
        f.write("M Q N P L P E V M S P E H D K R T T T P M S K E A N K F I R E L D K K P G D L A V V S D F V K R N T G K R L P I G K R S N L Y V R I C D L S G T I Y M G E T F I L E S W E E L Y L P E P T K M E V L G T L E S C C G I P P F P E W I V M V G E D Q C V Y A Y G D E E I L L F A Y S V K Q L V E E G I Q E T G I S Y K Y P D D I S D V D E E V L Q Q D E E I Q K I R K K T R E F V D K D A Q E F Q D F L N S L D A S L L S\n")
        f.write("M D S L N E V C Y E Q I K G T F Y K G L F G D F P L I V D K K T G C F N A T K L C V L G G K R F V D W N K T L R S K K L I Q Y Y E T R C D I K T E S L L Y E I K G D N N D E I T K Q I T G T Y L P K E F I L D I A S W I S V E F Y D K C N N I I I N Y F V N E Y K T M D K K T L Q S K I N E V E E K M Q K L L N E K E E E L Q E K N D K I D E L I L F S K R M E E D R K K D R E M M I K Q E K M L R E L G I H L E D V S S Q N N E L I E K V D E Q V E Q N A V L N F K I D N I Q N K L E I A V E D R A P Q P K Q N L K R E R F I L L K R N D D Y Y P Y Y T I R A Q D I N A R S A L K R Q K N L Y N E V S V L L D L T C H P N S K T L Y V R V K D E L K Q K G V V F N L C K V S I S N S K I N E E E L I K A M E T I N D E K R D V\n") 


fake_data()

spm.SentencePieceTrainer.train(input='vocab.txt', 
                               model_prefix='protein', 
                               vocab_size=30, 
                               model_type="word", 
                               #user_defined_symbols="<MASK>",
                               pad_id=0,
                               unk_id=1,
                               bos_id=2,
                               eos_id=3,
                               pad_piece="[PAD]",
                               unk_piece="[UNK]",
                               bos_piece="[BOS]",
                               eos_piece="[EOS]")


tokenizer = spm.SentencePieceProcessor(model_file='protein.model')

train_files = glob.glob("dataset/train*",recursive=True)
random.shuffle(train_files)

def mask_seq(seq,mask_prob=0.15):
  seq = np.array(seq)
  minValue = 1
  maxValue = len(seq) - 2
  max_mask_tokens = int(maxValue * mask_prob + 0.5)
  randomlist = random.sample(range(minValue, maxValue), max_mask_tokens)
  seq_masked = seq
  seq_masked[randomlist] = tokenizer.encode("[MASK]")[0]
  return seq_masked

def get_seq(train_files):
  while True:
    for file in train_files:
      with open(file) as fp: 
        for line in fp:
          yield line

def get_batch(seq_gen, batch_length):

  batch = []
  while True:
    seq = next(seq_gen)
    seq_ids = tokenizer.encode(seq,add_bos=True,add_eos=True)
    
    new_batch_len = len(batch) + len(seq_ids)

    if new_batch_len <= batch_length :
      batch = batch + seq_ids
      continue

    next_batch = batch
    batch = seq_ids

    yield next_batch

# Set up the data pipeline.
def my_inputs(n_devices):

  MAX_BATCH_LENGTH = 1024*4

  seq_gen = get_seq(train_files)
  batch_gen = get_batch(seq_gen,MAX_BATCH_LENGTH)

  while True:

    inputs = []
    targets = []
    mask = []
    for i in range(n_devices):
      batch_ids = next(batch_gen)
      masked_seq_ids = mask_seq(batch_ids)
      pad_amount = MAX_BATCH_LENGTH - len(batch_ids)

      inputs.append(np.pad(masked_seq_ids, (0,pad_amount)))
      targets.append(np.pad(batch_ids, (0,pad_amount)))
      mask.append(np.pad(np.ones_like(batch_ids, dtype=np.float32),
                          (0,pad_amount),
                          mode='constant'))
    inputs = np.stack(inputs)
    targets = np.stack(targets)
    mask = np.stack(mask)
    yield (inputs, targets, mask)

inp_gen_test = my_inputs(trax.fastmath.device_count())

res = next(inp_gen_test)

print(tokenizer.decode(res[0][0].tolist()))

print(tokenizer.decode(res[1][0].tolist()))


# Configure hyperparameters.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.Reformer
n_layers = 15
n_heads = 16
dropout = 0.1
n_tokens = 40000 # They have used very small n_tokens = 2048
vocab_size= 30
d_model = 1024
d_ff = 4096

# Done
# Parameters for MultifactorSchedule:
# ==============================================================================
multifactor.constant = 0.088
multifactor.decay_factor = 0.5
multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
multifactor.steps_per_cycle = 100000
multifactor.steps_per_decay = 20000
multifactor.warmup_steps = 8000


# Done
# Parameters for Adam:
# ==============================================================================
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-09
Adam.weight_decay_rate = 1e-05

# Done
# Parameters for SelfAttention:
# ==============================================================================
#trax.layers.SelfAttention.attention_dropout = 0.05
#trax.layers.SelfAttention.chunk_len = 64
#trax.layers.SelfAttention.n_chunks_before = 1
#trax.layers.SelfAttention.n_parallel_heads = 1
trax.layers.SelfAttention.causal = False
trax.layers.SelfAttention.chunk_len = None
trax.layers.SelfAttention.masked = False
trax.layers.SelfAttention.n_chunks_after = 0
trax.layers.SelfAttention.n_chunks_before = 0
trax.layers.SelfAttention.n_parallel_heads = None
trax.layers.SelfAttention.predict_drop_len = 64
trax.layers.SelfAttention.predict_mem_len = 192
trax.layers.SelfAttention.share_qk = False
trax.layers.SelfAttention.use_python_loop = False
trax.layers.SelfAttention.use_reference_code = False

# Done
# Parameters for EncDecAttention:
# ==============================================================================
trax.layers.EncDecAttention.masked = True
trax.layers.EncDecAttention.n_parallel_heads = None
trax.layers.EncDecAttention.use_python_loop = False
trax.layers.EncDecAttention.use_reference_code = False

# Done
# Parameters for LSHSelfAttention:
# ==============================================================================
#LSHSelfAttention.attention_dropout = 0.0
#LSHSelfAttention.chunk_len = 64
#LSHSelfAttention.n_buckets = [64, 128]
#LSHSelfAttention.n_chunks_after = 0
#LSHSelfAttention.n_chunks_before = 1
#LSHSelfAttention.n_hashes = 1
#LSHSelfAttention.n_parallel_heads = 1
#LSHSelfAttention.predict_drop_len = 128
#LSHSelfAttention.predict_mem_len = 1024

# Done
# Parameters for Reformer:
# ==============================================================================
Reformer.d_model = %d_model
Reformer.d_ff = %d_ff
Reformer.dropout = %dropout
Reformer.ff_activation = @trax.layers.Relu
Reformer.max_len = %n_tokens
Reformer.mode = 'train'
Reformer.n_heads = %n_heads
Reformer.n_encoder_layers = %n_layers
Reformer.n_decoder_layers = %n_layers
Reformer.input_vocab_size = %vocab_size


""")

# Set up a Trainer.
output_dir = os.path.expanduser('train_dir/')

trainer = trax.supervised.Trainer(
    model=trax.models.Reformer,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)


# Run one training step, to make sure the model fits in memory.
# The first time trainer.train_epoch is called, it will JIT the entire network
# architecture, which takes around 2 minutes. The JIT-compiled model is saved
# so subsequent runs will be much faster than the first.
trainer.train_epoch(n_steps=1, n_eval_steps=1)

Any idea how I can solve this issue?
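
For testing the data pipeline separately from the TPU, the masking and packing steps above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code from the post: the names mask_tokens and pack_sequences are hypothetical, the first and last positions are assumed to be BOS/EOS and are never masked, and a sequence longer than the limit is truncated (dropping it would be an equally valid choice).

```python
import random


def mask_tokens(token_ids, mask_id, mask_prob=0.15, rng=random):
    """Return a masked copy of token_ids; the input is left unchanged.

    The first and last positions (assumed here to be BOS/EOS) are never masked.
    """
    masked = list(token_ids)
    candidates = range(1, len(masked) - 1)  # interior positions only
    n_mask = min(len(candidates), int(len(candidates) * mask_prob + 0.5))
    for pos in rng.sample(candidates, n_mask):
        masked[pos] = mask_id
    return masked


def pack_sequences(sequences, max_len):
    """Greedily concatenate token-id lists into batches of at most max_len ids.

    A single sequence longer than max_len is truncated (an assumption), so no
    emitted batch can exceed max_len and later padding stays non-negative.
    """
    batch = []
    for seq in sequences:
        seq = list(seq)[:max_len]
        if batch and len(batch) + len(seq) > max_len:
            yield batch
            batch = []
        batch = batch + seq
    if batch:  # flush the remainder when the input is finite
        yield batch
```

Both helpers work on ordinary lists, so they can be checked on a laptop before any device is involved. pack_sequences also accepts an endless generator such as get_seq, in which case the final flush is simply never reached.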

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 18 (2 by maintainers)

Most upvoted comments

I am also getting the same error while trying to train the standard TransformerLM with slightly modified parameters (vocab_size=50000, max_len=1024). The same code works on the Colab + TPU setup, but on the GCE VM + Cloud TPU setup I get the same error.

  1. Can you tell which line of code is causing the long compile? The trainer:
trainer = trax.supervised.Trainer(
    model=trax.models.Reformer,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)

I believe this triggers JAX compilation.
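
One way to confirm which call is stuck, without assuming anything about Trax internals, is Python's standard-library faulthandler module, which can print the stack of every thread at a fixed interval. The sketch below is a generic helper; the name run_with_stack_dumps is hypothetical. It only shows Python frames, so a long XLA compile or TPU execution would appear as whichever Python call is waiting on it.

```python
import faulthandler
import sys


def run_with_stack_dumps(fn, every_seconds=60.0, file=None):
    """Call fn() and, while it runs, dump all thread stacks periodically.

    `file` must be a real file object with a file descriptor; it defaults to
    sys.stderr. The periodic dump is cancelled as soon as fn() returns or raises.
    """
    target = sys.stderr if file is None else file
    faulthandler.dump_traceback_later(every_seconds, repeat=True, file=target)
    try:
        return fn()
    finally:
        faulthandler.cancel_dump_traceback_later()
```

With the script above, this could wrap either the Trainer construction or the first step, for example run_with_stack_dumps(lambda: trainer.train_epoch(n_steps=1, n_eval_steps=1)). The repeated traceback then shows where the program sits while the "TPU Execute is taking a long time" warning is printed.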

  2. For the GPU issue, I figured out the problem. I was installing with “pip install --upgrade jax jaxlib”, which doesn’t support CUDA. After following the README, JAX can now see the GPUs. I think the default JAX installation should support both TPUs and GPUs; that would be easier for users.

The very long compilation does not occur on the GPU with the same command mentioned above.

My main issue here is the JAX compilation speed, which is extremely slow. I killed the process after leaving it running for 1 hour.
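
To put the slow compilation in context, here is a rough back-of-the-envelope estimate for the configuration in the original post (d_model=1024, d_ff=4096, 16 heads, 15 encoder and 15 decoder layers, inputs padded to 4096 tokens). It assumes a standard encoder-decoder Transformer layout and ignores biases, layer norms and embeddings, so the actual Trax Reformer may differ. The attention figure applies only if scores are computed unchunked over the full sequence. The function name is hypothetical.

```python
def rough_transformer_estimate(d_model, d_ff, n_heads, n_enc, n_dec, seq_len,
                               bytes_per_value=4):
    """Approximate sizes for a standard encoder-decoder Transformer.

    Biases, layer norms and embeddings are ignored (an assumption).
    """
    attn = 4 * d_model * d_model          # Q, K, V and output projections
    ff = 2 * d_model * d_ff               # two dense layers
    enc_params = n_enc * (attn + ff)
    dec_params = n_dec * (2 * attn + ff)  # self-attention + enc-dec attention
    # One full (unchunked) attention score tensor: heads x seq_len x seq_len.
    scores_bytes = n_heads * seq_len * seq_len * bytes_per_value
    return {"params": enc_params + dec_params,
            "attention_scores_bytes": scores_bytes}


est = rough_transformer_estimate(d_model=1024, d_ff=4096, n_heads=16,
                                 n_enc=15, n_dec=15, seq_len=4096)
print(f"{est['params'] / 1e6:.0f}M parameters")
print(f"{est['attention_scores_bytes'] / 2**30:.1f} GiB per attention score tensor")
```

If these numbers are in the right range, trying a much smaller configuration first (fewer layers, shorter sequences) would show whether the program is deadlocked or just very slow to compile and run at this size.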