tensorflow: TensorFlow 60-80% slower than PyTorch on training Wide ResNet

cc @tfboyd

From https://github.com/tensorflow/tensorflow/issues/7187#issuecomment-295502315

On an AWS p2.xlarge, using the tensorflow/tensorflow:1.0.1-devel-gpu Docker image as a base, I see ~270 ms per batch while training a WRN-16-4 without dropout on CIFAR-10.

Using the PyTorch implementation at https://github.com/xternalz/WideResNet-pytorch, I instead see ~150 ms per batch for the same model.
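As a quick sanity check on the headline number, the sketch below treats the quoted 270 ms and 150 ms as per-batch step times (the reproduction steps log batch timings) and assumes a batch size of 128, which is taken from the cuDNN traces later in the thread. Under those assumptions TensorFlow takes 80% longer per step, the upper end of the title's 60-80% range.

```python
# Step times quoted above, treated as per-batch times (milliseconds).
# Batch size 128 is an assumption taken from the cuDNN traces in this thread.
TF_MS = 270.0
TORCH_MS = 150.0
BATCH_SIZE = 128


def percent_slower(slow_ms: float, fast_ms: float) -> float:
    """Extra time the slow step takes, as a percentage of the fast step."""
    return (slow_ms / fast_ms - 1.0) * 100.0


def images_per_second(step_ms: float, batch_size: int) -> float:
    """Training throughput implied by a per-batch step time."""
    return batch_size / (step_ms / 1000.0)


if __name__ == "__main__":
    print(f"TensorFlow is {percent_slower(TF_MS, TORCH_MS):.0f}% slower")
    print(f"TensorFlow: {images_per_second(TF_MS, BATCH_SIZE):.0f} img/s")
    print(f"PyTorch: {images_per_second(TORCH_MS, BATCH_SIZE):.0f} img/s")
```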

My implementation of Wide ResNets uses NCHW and fused batch norm. It does use feed_dict for data loading, but I’ve observed with nvidia-smi that my GPU utilization stays near 100%.
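NCHW and NHWC differ only in how a 4-D activation tensor is laid out in memory: NCHW stores each channel's HxW plane contiguously, while NHWC (the default data_format of TensorFlow's tf.nn.conv2d) places all channels of one pixel next to each other. A minimal pure-Python sketch of the index arithmetic, using hypothetical helper names, shows what a layout conversion such as tf.transpose(x, [0, 3, 1, 2]) actually reorders:

```python
def nchw_offset(n, c, h, w, C, H, W):
    # Channel-major: each channel's HxW plane is contiguous.
    return ((n * C + c) * H + h) * W + w


def nhwc_offset(n, c, h, w, C, H, W):
    # Channel-minor: all C channels of one pixel are adjacent.
    return ((n * H + h) * W + w) * C + c


def nhwc_to_nchw(buf, N, C, H, W):
    """Reorder a flat NHWC buffer into a flat NCHW buffer."""
    out = [None] * (N * C * H * W)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    src = nhwc_offset(n, c, h, w, C, H, W)
                    out[nchw_offset(n, c, h, w, C, H, W)] = buf[src]
    return out


# One image, 2 channels, 2x2 pixels; each value is its own NHWC position.
print(nhwc_to_nchw(list(range(8)), N=1, C=2, H=2, W=2))
# -> [0, 2, 4, 6, 1, 3, 5, 7]
```

After conversion, channel 0's four pixels (NHWC positions 0, 2, 4, 6) sit next to each other, which is the per-channel plane layout the thread treats as the faster choice for GPU convolutions.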


To reproduce:

  • Build the Docker image:
$ docker build -t dl-papers .
  • Run the Docker image using NVIDIA Docker:
$ nvidia-docker run --rm -it dl-papers /bin/bash
  • Run the TF WRN-16-4 training:
# python -m dl_papers.wide_resnet.train cifar10
  • Observe the logged batch timings, then kill the process.
  • In the same Docker container, set up the PyTorch Wide ResNet example:
# cd ..
# pip install http://download.pytorch.org/whl/cu80/torch-0.1.11.post5-cp27-none-linux_x86_64.whl
# pip install torchvision tensorboard_logger
# git clone https://github.com/xternalz/WideResNet-pytorch.git
# cd WideResNet-pytorch
  • Run PyTorch training:
# python train.py --dataset cifar10 --layers 16 --widen-factor 4 -p 1
  • Observe logged batch timings.
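When comparing the logged timings, it helps to measure both frameworks the same way. Below is a minimal, framework-agnostic sketch in which step_fn is a hypothetical callable that runs one training step. It discards warm-up steps (graph setup and cuDNN autotuning make early steps unrepresentative) and reports the median. It assumes step_fn blocks until the GPU work is finished: in PyTorch that means calling torch.cuda.synchronize() inside it, because CUDA kernels are launched asynchronously, while TensorFlow's sess.run only returns once the fetched values have been computed.

```python
import statistics
import time
from typing import Callable, List


def time_steps(step_fn: Callable[[], None], num_steps: int = 50,
               warmup: int = 10) -> List[float]:
    """Run step_fn repeatedly and return per-step wall times in milliseconds."""
    for _ in range(warmup):
        step_fn()  # not timed: setup and autotuning happen here
    times_ms = []
    for _ in range(num_steps):
        start = time.perf_counter()
        step_fn()
        times_ms.append((time.perf_counter() - start) * 1000.0)
    return times_ms


def summarize(times_ms: List[float]) -> str:
    """Median is less sensitive than the mean to occasional slow steps."""
    return (f"median {statistics.median(times_ms):.1f} ms, "
            f"min {min(times_ms):.1f} ms over {len(times_ms)} steps")


if __name__ == "__main__":
    # Stand-in step; replace with one real training step per framework.
    print(summarize(time_steps(lambda: time.sleep(0.001), num_steps=5, warmup=1)))
```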

About this issue

  • State: closed
  • Created 7 years ago
  • Comments: 39 (39 by maintainers)

Most upvoted comments

I’d like to highlight that I think the biggest gap here is in developer experience (DX). We started migrating from Theano to TensorFlow the first week of this year. This is partly my frustration speaking, but the process looked something like this:

  1. Write naive model implementations, which end up using NHWC and unfused batch norm.
  2. See documentation that suggests NCHW is faster for convolutions (this predates the perf guide). Refactor models to support both NCHW and NHWC, because we still need to support CPU inference. Observe that the models run ~6x slower. Give up on this for a bit.
  3. See documentation that fused batch norm is faster. Switch back from tf.layers.batch_normalization to tf.contrib.layers.batch_norm. This shows a performance improvement.
  4. Try NCHW again, and now see a performance improvement.
     a. Realize the earlier slowdown was due to https://github.com/tensorflow/tensorflow/issues/7551
  5. See the performance guide. Convert input pipeline to use queues instead of feed_dict. See negligible speedup.
  6. Compare to straightforward PyTorch impl, and observe that the PyTorch impl runs much faster.
  7. Learn through this issue that we need to enable an undocumented environment flag (TF_ENABLE_WINOGRAD_NONFUSED) to access the fastest cuDNN mode for 3x3 convolutions.
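The environment flag in step 7 is TF_ENABLE_WINOGRAD_NONFUSED, named at the end of this thread. A minimal sketch of setting it from Python, assuming TensorFlow 1.x; the helper name is hypothetical. Setting the variable before TensorFlow is imported is the conservative choice, on the assumption that TensorFlow may cache the value once it has read it.

```python
import os
import sys


def enable_nonfused_winograd() -> None:
    """Set TF_ENABLE_WINOGRAD_NONFUSED=1 for this process (TensorFlow 1.x)."""
    if "tensorflow" in sys.modules:
        # Assumption: TF may already have read (and cached) the variable.
        print("warning: tensorflow is already imported; the flag may be ignored",
              file=sys.stderr)
    os.environ["TF_ENABLE_WINOGRAD_NONFUSED"] = "1"


enable_nonfused_winograd()
# import tensorflow as tf  # import TensorFlow only after the flag is set
```

Equivalently, set it in the shell when launching training, e.g. TF_ENABLE_WINOGRAD_NONFUSED=1 python -m dl_papers.wide_resnet.train cifar10.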

As a developer, this is a really suboptimal experience.

My impression is that the modal TF example or published code is at something like our step (1) above, in that it uses NHWC, uses unfused batch norm, and doesn’t enable the non-fused Winograd convolution. Correspondingly, performance is quite far from optimal.

By contrast, though with a smaller sample size, the PyTorch examples I’ve seen generally do “the right thing” performance-wise and run quickly out of the box. (Also, the fact that PyTorch has a single built-in layer API makes separate PyTorch examples far more consistent in how the code reads.)

I’m very grateful for your help in tracking down these issues, but I really wish the out-of-the-box experience were better, and that it didn’t take so much work to get to this point.

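I managed to get a set of traces from PyTorch and TensorFlow showing which convolution algorithms and shapes they use. In the PyTorch trace, Input is N C H W and Weight is out-channels in-channels kH kW; lines prefixed bd and bf are backward-data and backward-filter calls.

PyTorch: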

Input: 128 3 32 32 0 Weight: 16 3 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 1
Input: 128 16 32 32 0 Weight: 160 16 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 6
Input: 128 160 32 32 0 Weight: 160 160 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 6
Input: 128 16 32 32 0 Weight: 160 16 1 1 0 pad: 0 0 0 Stride: 1 1 0 -> Algo: 1
Input: 128 160 32 32 0 Weight: 320 160 3 3 0 pad: 1 1 0 Stride: 2 2 0 -> Algo: 1
Input: 128 320 16 16 0 Weight: 320 320 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 7
Input: 128 160 32 32 0 Weight: 320 160 1 1 0 pad: 0 0 0 Stride: 2 2 0 -> Algo: 1
Input: 128 320 16 16 0 Weight: 640 320 1 1 0 pad: 0 0 0 Stride: 2 2 0 -> Algo: 1
Input: 128 320 16 16 0 Weight: 640 320 3 3 0 pad: 1 1 0 Stride: 2 2 0 -> Algo: 1
Input: 128 640 8 8 0 Weight: 640 640 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 7

bd Input: 128 640 8 8 0 Weight: 640 640 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 5
bf Input: 128 640 8 8 0 Weight: 640 640 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 5

bd Input: 128 320 16 16 0 Weight: 640 320 3 3 0 pad: 1 1 0 Stride: 2 2 0 -> Algo: 0
bf Input: 128 320 16 16 0 Weight: 640 320 3 3 0 pad: 1 1 0 Stride: 2 2 0 -> Algo: 1

bd Input: 128 320 16 16 0 Weight: 640 320 1 1 0 pad: 0 0 0 Stride: 2 2 0 -> Algo: 0
bf Input: 128 320 16 16 0 Weight: 640 320 1 1 0 pad: 0 0 0 Stride: 2 2 0 -> Algo: 3

bd Input: 128 320 16 16 0 Weight: 320 320 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 5
bf Input: 128 320 16 16 0 Weight: 320 320 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 5

bd Input: 128 160 32 32 0 Weight: 320 160 1 1 0 pad: 0 0 0 Stride: 2 2 0 -> Algo: 0
bf Input: 128 160 32 32 0 Weight: 320 160 1 1 0 pad: 0 0 0 Stride: 2 2 0 -> Algo: 0

bd Input: 128 160 32 32 0 Weight: 320 160 3 3 0 pad: 1 1 0 Stride: 2 2 0 -> Algo: 0
bf Input: 128 160 32 32 0 Weight: 320 160 3 3 0 pad: 1 1 0 Stride: 2 2 0 -> Algo: 3

bd Input: 128 160 32 32 0 Weight: 160 160 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 4
bf Input: 128 160 32 32 0 Weight: 160 160 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 5

bd Input: 128 16 32 32 0 Weight: 160 16 1 1 0 pad: 0 0 0 Stride: 1 1 0 -> Algo: 1
bf Input: 128 16 32 32 0 Weight: 160 16 1 1 0 pad: 0 0 0 Stride: 1 1 0 -> Algo: 3

bd Input: 128 16 32 32 0 Weight: 160 16 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 4
bf Input: 128 16 32 32 0 Weight: 160 16 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 3

bf Input: 128 3 32 32 0 Weight: 16 3 3 3 0 pad: 1 1 0 Stride: 1 1 0 -> Algo: 0
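To line the PyTorch trace up against the TensorFlow one, a hypothetical parser for the line format above can turn each call into a record and recompute the output size with the standard convolution formula floor((in + 2*pad - kernel) / stride) + 1. The regex is an assumption fitted to the lines shown, and the trailing 0 in each field is ignored because its meaning is not stated. For the 3x3 stride-2 call on a 32x32 input with pad 1, the formula gives a 16x16 output using symmetric padding, so PyTorch hands cuDNN the unmodified 32x32 input.

```python
import re
from typing import NamedTuple, Optional, Tuple

# Fitted to the trace lines above; the trailing "0" in each field is ignored.
TRACE_RE = re.compile(
    r"^(?:(?P<kind>bd|bf)\s+)?"
    r"Input: (?P<n>\d+) (?P<c>\d+) (?P<h>\d+) (?P<w>\d+) \d+ "
    r"Weight: (?P<k>\d+) \d+ (?P<kh>\d+) (?P<kw>\d+) \d+ "
    r"pad: (?P<ph>\d+) (?P<pw>\d+) \d+ "
    r"Stride: (?P<sh>\d+) (?P<sw>\d+) \d+ -> Algo: (?P<algo>\d+)$"
)


class ConvCall(NamedTuple):
    kind: str                      # "bd", "bf", or "" when unprefixed
    in_shape: Tuple[int, int, int, int]  # (N, C, H, W)
    out_channels: int
    kernel: Tuple[int, int]        # (kH, kW)
    pad: Tuple[int, int]           # symmetric padding per side
    stride: Tuple[int, int]
    algo: int


def parse_line(line: str) -> Optional[ConvCall]:
    m = TRACE_RE.match(line.strip())
    if m is None:
        return None
    g = {k: int(v) for k, v in m.groupdict().items() if k != "kind"}
    return ConvCall(m.group("kind") or "", (g["n"], g["c"], g["h"], g["w"]),
                    g["k"], (g["kh"], g["kw"]), (g["ph"], g["pw"]),
                    (g["sh"], g["sw"]), g["algo"])


def output_hw(call: ConvCall) -> Tuple[int, ...]:
    """floor((in + 2*pad - kernel) / stride) + 1 for each spatial dimension."""
    _, _, h, w = call.in_shape
    return tuple((size + 2 * p - k) // s + 1
                 for size, p, k, s in zip((h, w), call.pad, call.kernel, call.stride))


call = parse_line("Input: 128 160 32 32 0 Weight: 320 160 3 3 0 "
                  "pad: 1 1 0 Stride: 2 2 0 -> Algo: 1")
print(call.algo, output_hw(call))  # -> 1 (16, 16)
```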

TensorFlow:

Conv accepts: 128, 3, (32, 32), 16, (3, 3), (1, 1), (2, 2), 1, 0 -> (1, 0)
Conv accepts: 128, 16, (32, 32), 160, (3, 3), (1, 1), (2, 2), 1, 0 -> (6, 0)
Conv accepts: 128, 160, (32, 32), 160, (3, 3), (1, 1), (2, 2), 1, 0 -> (6, 0)
Conv accepts: 128, 16, (32, 32), 160, (1, 1), (1, 1), (0, 0), 1, 0 -> (1, 0)
Conv accepts: 128, 160, (33, 33), 320, (3, 3), (2, 2), (1, 1), 1, 0 -> (1, 0)
Conv accepts: 128, 320, (16, 16), 320, (3, 3), (1, 1), (2, 2), 1, 0 -> (7, 0)
Conv accepts: 128, 160, (32, 32), 320, (1, 1), (2, 2), (0, 0), 1, 0 -> (1, 0)
Conv accepts: 128, 320, (16, 16), 640, (1, 1), (2, 2), (0, 0), 1, 0 -> (1, 0)
Conv accepts: 128, 320, (17, 17), 640, (3, 3), (2, 2), (1, 1), 1, 0 -> (1, 0)
Conv accepts: 128, 640, (8, 8), 640, (3, 3), (1, 1), (2, 2), 1, 0 -> (7, 0)

ConvBwdData accepts: 128, 640, (8, 8), 640, (3, 3), (1, 1), (2, 2), 1, 0 -> (5, 0)
ConvBwdFilter accepts: 128, 640, (8, 8), 640, (3, 3), (1, 1), (2, 2), 1, 0 -> (5, 0)

ConvBwdData accepts: 128, 320, (17, 17), 640, (3, 3), (2, 2), (1, 1), 1, 0 -> (0, 0)
ConvBwdFilter accepts: 128, 320, (17, 17), 640, (3, 3), (2, 2), (1, 1), 1, 0 -> (1, 0)

ConvBwdData accepts: 128, 320, (16, 16), 640, (1, 1), (2, 2), (0, 0), 1, 0 -> (0, 0)
ConvBwdFilter accepts: 128, 320, (16, 16), 640, (1, 1), (2, 2), (0, 0), 1, 0 -> (3, 0)

ConvBwdData accepts: 128, 320, (16, 16), 320, (3, 3), (1, 1), (2, 2), 1, 0 -> (5, 0)
ConvBwdFilter accepts: 128, 320, (16, 16), 320, (3, 3), (1, 1), (2, 2), 1, 0 -> (5, 0)

ConvBwdData accepts: 128, 160, (32, 32), 320, (1, 1), (2, 2), (0, 0), 1, 0 -> (0, 0)
ConvBwdFilter accepts: 128, 160, (32, 32), 320, (1, 1), (2, 2), (0, 0), 1, 0 -> (0, 0)

ConvBwdData accepts: 128, 160, (33, 33), 320, (3, 3), (2, 2), (1, 1), 1, 0 -> (0, 0)
ConvBwdFilter accepts: 128, 160, (33, 33), 320, (3, 3), (2, 2), (1, 1), 1, 0 -> (3, 0)

ConvBwdData accepts: 128, 160, (32, 32), 160, (3, 3), (1, 1), (2, 2), 1, 0 -> (4, 0)
ConvBwdFilter accepts: 128, 160, (32, 32), 160, (3, 3), (1, 1), (2, 2), 1, 0 -> (5, 0)

ConvBwdData accepts: 128, 16, (32, 32), 160, (1, 1), (1, 1), (0, 0), 1, 0 -> (1, 0)
ConvBwdFilter accepts: 128, 16, (32, 32), 160, (1, 1), (1, 1), (0, 0), 1, 0 -> (3, 0)

ConvBwdData accepts: 128, 16, (32, 32), 160, (3, 3), (1, 1), (2, 2), 1, 0 -> (4, 0)
ConvBwdFilter accepts: 128, 16, (32, 32), 160, (3, 3), (1, 1), (2, 2), 1, 0 -> (3, 0)

ConvBwdFilter accepts: 128, 3, (32, 32), 16, (3, 3), (1, 1), (2, 2), 1, 0 -> (0, 0)
# batch size, input chan, (input shape), out chan, (kernel shape), (stride), (padding), _, _ -> algo
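The algorithm numbers in both traces are cuDNN enum values. The sketch below maps them to names, assuming the enum order declared in cudnn.h for cuDNN 5/6 (cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t, cudnnConvolutionBwdFilterAlgo_t); check it against the header you actually build with. Only the first number of each "-> (a, b)" pair in the TensorFlow trace is decoded; the second is left uninterpreted.

```python
# Enum order assumed from cudnn.h (cuDNN 5/6); only values 0-7 / 0-5 are listed.
FWD_ALGOS = ["IMPLICIT_GEMM", "IMPLICIT_PRECOMP_GEMM", "GEMM", "DIRECT",
             "FFT", "FFT_TILING", "WINOGRAD", "WINOGRAD_NONFUSED"]
BWD_DATA_ALGOS = ["ALGO_0", "ALGO_1", "FFT", "FFT_TILING",
                  "WINOGRAD", "WINOGRAD_NONFUSED"]
BWD_FILTER_ALGOS = ["ALGO_0", "ALGO_1", "FFT", "ALGO_3",
                    "WINOGRAD", "WINOGRAD_NONFUSED"]

# Op names as they appear at the start of each TensorFlow trace line.
TABLES = {"Conv": FWD_ALGOS, "ConvBwdData": BWD_DATA_ALGOS,
          "ConvBwdFilter": BWD_FILTER_ALGOS}


def decode(line: str) -> tuple:
    """Return (op, algorithm name) for one TensorFlow trace line."""
    op = line.split(" accepts:")[0].strip()
    # "... -> (7, 0)": keep only the first number of the trailing pair.
    algo = int(line.rsplit("-> (", 1)[1].split(",")[0])
    return op, TABLES[op][algo]


for line in [
    "Conv accepts: 128, 640, (8, 8), 640, (3, 3), (1, 1), (2, 2), 1, 0 -> (7, 0)",
    "ConvBwdData accepts: 128, 640, (8, 8), 640, (3, 3), (1, 1), (2, 2), 1, 0 -> (5, 0)",
    "ConvBwdFilter accepts: 128, 640, (8, 8), 640, (3, 3), (1, 1), (2, 2), 1, 0 -> (5, 0)",
]:
    print(decode(line))
```

Read this way, the 3x3 stride-1 convolutions on the 16x16 and 8x8 feature maps use WINOGRAD_NONFUSED in all three passes (forward 7, backward 5 and 5), i.e. the non-fused Winograd variants that TF_ENABLE_WINOGRAD_NONFUSED enables.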

With TF_ENABLE_WINOGRAD_NONFUSED=1, TensorFlow always chooses the same algorithm as PyTorch. However, the log shows that for the 3x3 stride-2 convolutions TensorFlow calls cuDNN with irregular input shapes (33x33 and 17x17 rather than 32x32 and 16x16). This may cost some performance.
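One plausible source of the 33x33 and 17x17 shapes is TensorFlow's "SAME" padding rule: for a 3x3 kernel with stride 2 on an even-sized input, the total padding is odd, so TensorFlow pads 0 before and 1 after. cuDNN's 2-D convolution descriptor takes a single symmetric padding value per dimension, so TensorFlow presumably pads the input explicitly by one row and column and then calls cuDNN with zero padding. The sketch reproduces the documented SAME formula; cudnn_input_size is a hypothetical model of that workaround, not TensorFlow's actual code.

```python
import math
from typing import Tuple


def same_padding(in_size: int, kernel: int, stride: int) -> Tuple[int, int, int]:
    """TF 'SAME' padding along one dimension: (out_size, pad_before, pad_after)."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    before = total // 2
    return out_size, before, total - before


def cudnn_input_size(in_size: int, kernel: int, stride: int) -> Tuple[int, int]:
    """Hypothetical model: if SAME padding is asymmetric, pad the extra row or
    column explicitly and pass the remaining symmetric padding to cuDNN.
    Returns (input size handed to cuDNN, symmetric pad per side)."""
    _, before, after = same_padding(in_size, kernel, stride)
    if before == after:
        return in_size, before
    return in_size + (after - before), before


for size, k, s in [(32, 3, 2), (16, 3, 2), (32, 3, 1), (32, 1, 2)]:
    print((size, k, s), cudnn_input_size(size, k, s))
```

With the padded 33x33 input and zero padding, cuDNN's output is (33 - 3) // 2 + 1 = 16, the expected 16x16, and the 1x1 stride-2 convolutions need no padding at all, which matches their unchanged 32x32 and 16x16 inputs in the log. Note that SAME padding is not the same computation as PyTorch's padding=1 here: PyTorch pads one pixel on every side, so its stride-2 windows are offset by one pixel relative to TensorFlow's.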