tensorflow: TensorFlow 2.3.0 is much slower than PyTorch 1.4.0 in backward propagation (20 times slower)
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Win10
- CUDA/cuDNN version: CUDA 10.1/cuDNN 7
- TensorFlow version (use command below): 2.3.0
- Python version: 3.7.4
- GPU model and memory: Geforce GTX 1660Ti (6GB)
Describe the current behavior
- I tried to use both TF2 and PyTorch to train a SkipGram model with negative sampling, but noticed a huge training speed difference.
- TF2 takes ~0.5s to finish one step while PyTorch takes ~0.01s. One step refers to using one batch of data to run the forward and backward passes and update the weights.
- The main difference is in backward propagation.
- More specifically, ‘apply_gradients()’ in TF2 is much slower than its counterpart ‘step()’ in PyTorch. Does anyone know why?
-
Notes:
- Both versions use the GPU.
- tf.function is used.
- The training batch size and data are exactly the same.
- The model size (vocab_size, emb_size) is the same.
- TF2:

    @tf.function
    def train_step_tf(pos_w, pos_v, neg_v):
        print('-' * 50)
        with tf.GradientTape() as tape:
            loss = skip_gram_model_tf(pos_w, pos_v, neg_v)
        start_time = time.time()
        variables = skip_gram_model_tf.trainable_variables
        print('getting variables time = {}'.format(time.time() - start_time))
        start_time = time.time()
        gradients = tape.gradient(loss, variables)
        print('init gradient time = {}'.format(time.time() - start_time))
        start_time = time.time()
        optimizer.apply_gradients(zip(gradients, variables))  # !!!!! this is where the major speed difference happens !!!!!
        print('apply gradient time = {}'.format(time.time() - start_time))
        print('-' * 50)
        return loss
- PyTorch:

    def train_step_torch(pos_w, pos_v, neg_v):
        print('-' * 50)
        start_time = time.time()
        optimizer_torch.zero_grad()
        loss = skip_gram_model_torch.forward(pos_w, pos_v, neg_v)
        print('getting variables time = {}'.format(time.time() - start_time))
        start_time = time.time()
        loss.backward()
        print('init gradient time = {}'.format(time.time() - start_time))
        start_time = time.time()
        optimizer_torch.step()  # !!!! this is fast !!!!
        print('apply gradient time = {}'.format(time.time() - start_time))
        print('-' * 50)
        return loss
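Aside on the TF2 timings above: Python-level print() and time.time() calls inside a @tf.function only execute while the function is being traced, so timing the compiled step from outside the function gives more reliable wall-clock numbers. A minimal sketch, assuming the same train_step_tf and batch tensors from the snippet above are already defined (the warm-up call and n_steps are illustrative):

```python
import time

# Warm up: the first call traces and compiles the tf.function, so exclude it from timing.
_ = train_step_tf(pos_w, pos_v, neg_v)

n_steps = 100  # illustrative number of timed iterations
start = time.time()
for _ in range(n_steps):
    loss = train_step_tf(pos_w, pos_v, neg_v)
loss.numpy()  # block until pending GPU work has finished
print('avg step time = {:.4f}s'.format((time.time() - start) / n_steps))
```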
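Similarly for the PyTorch timings: CUDA kernels launch asynchronously, so time.time() around individual calls can under-report GPU work unless the device is synchronized. A minimal sketch of a synchronized per-step measurement, assuming the same model, optimizer, and batch tensors already live on the GPU:

```python
import time
import torch

torch.cuda.synchronize()  # make sure earlier GPU work has finished
start = time.time()

optimizer_torch.zero_grad()
loss = skip_gram_model_torch(pos_w, pos_v, neg_v)  # calling the module invokes forward()
loss.backward()
optimizer_torch.step()

torch.cuda.synchronize()  # wait for all launched kernels before stopping the clock
print('step time = {:.4f}s'.format(time.time() - start))
```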
Describe the expected behavior
I expect the performance shouldn’t be that different. Is there anything I’m missing in TF2?
Standalone code to reproduce the issue
Colab link of full code: https://colab.research.google.com/drive/17QsTkV271LPvo6aJf1QOGuumdU8OXmVj?usp=sharing
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 18 (3 by maintainers)
@jvishnuvardhan Thanks for the update. So glad to see the issue is resolved in 2.7.0! @mdanatg Thanks for taking this issue previously, also many thanks to the engineers who fixed this!!! 😃
The explicit placement was done two years ago, so it probably applies to all versions of keras.layers.Embedding. That said, it might go unnoticed in practice if, for example, the embedding table is too large to fit in GPU memory, or if the model is large enough to dominate the compute time.

I had a chance to have a closer look at the tf.function instance. I see the latest version of the colab does avoid excessive retracing, which is good. It’s still significantly slower than PyTorch, and we need to have a closer look at why. I had a suspicion that it’s because the GPU placement is inefficient, so I re-ran the tests on CPU. The TF tests are slower on GPU than on CPU, which is a strong indication that some part of the computation is placed on the CPU, causing a massive slowdown due to the relatively slow CPU->GPU transfers.
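One quick way to check that suspicion is to look at where the variables and ops actually land. A minimal sketch, assuming the skip_gram_model_tf object from the colab has already been built:

```python
import tensorflow as tf

# Log the device each op is assigned to (verbose; enable before running a step).
tf.debugging.set_log_device_placement(True)

# Print where each trainable variable lives; the embedding table showing up on
# .../device:CPU:0 while everything else is on GPU would explain the slow steps.
for v in skip_gram_model_tf.trainable_variables:
    print(v.name, '->', v.device)
```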
Another observation: if I enable XLA with tf.function(experimental_compile=True), it’s way faster than all three. But that errors out on GPU, yielding a hint to the problem:

Trying to access resource skip_gram_model_tf_18/w_emb/embeddings_34250 located in device /job:localhost/replica:0/task:0/device:CPU:0 from device /job:localhost/replica:0/task:0/device:GPU:0.

This is a sure indication that the embedding is incorrectly placed on the CPU, slowing everything down.
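For reference, a minimal sketch of the XLA experiment described above, using the TF 2.3-era experimental_compile flag (later TF releases renamed it jit_compile); the function body mirrors the training step from the report:

```python
# Hypothetical XLA-compiled variant of the training step. On this setup it raised
# the "Trying to access resource ... located in device ... CPU:0 from device ...
# GPU:0" error quoted above.
@tf.function(experimental_compile=True)  # jit_compile=True in newer TF versions
def train_step_tf_xla(pos_w, pos_v, neg_v):
    with tf.GradientTape() as tape:
        loss = skip_gram_model_tf(pos_w, pos_v, neg_v)
    variables = skip_gram_model_tf.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss
```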
@tomerk @fchollet @omalleyt12 for more thoughts on why the embedding might be misplaced
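For context, the kind of explicit placement being discussed looks roughly like the sketch below: a layer that pins its weight table to the CPU when it builds. This is an illustrative, simplified pattern, not the actual keras.layers.Embedding source:

```python
import tensorflow as tf

class CpuPinnedEmbedding(tf.keras.layers.Layer):
    """Hypothetical layer that pins its table to the CPU, mimicking the kind of
    explicit placement discussed above (simplified, not the real Keras code)."""

    def __init__(self, vocab_size, emb_size, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.emb_size = emb_size

    def build(self, input_shape):
        # The explicit device scope is what forces every lookup and gradient
        # update to hop between CPU and GPU during training.
        with tf.device('/CPU:0'):
            self.embeddings = self.add_weight(
                name='embeddings', shape=(self.vocab_size, self.emb_size))

    def call(self, ids):
        return tf.nn.embedding_lookup(self.embeddings, ids)
```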