tensorflow: FusedBatchNorm does not support 3D Filters

The current implementation of tf.contrib.layers.batch_norm is quite slow with the default parameters. Recent builds of TF support fused=True, which forces the use of the faster tf.nn.fused_batch_norm. However, the fused path does not support 3D filters (rank-5 inputs): it raises an error before computing the mean and variance. The regular (slower) variant with fused=False does not have this problem.

FusedBatchNorm:

# The following error is raised when trying to batch-normalize 3D filters
elif original_rank not in [2, 4]:
      raise ValueError('Inputs %s has unsupported rank. \
          Expected 2 or 4 but got %d' % (inputs.name, original_rank))

The default, non-fused implementation computes tf.nn.moments over these axes:

axis = list(range(inputs_rank - 1))

Because it reduces over every axis except the last one, this also works for 3D filters. See code: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/layers.py#L483
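
To make the difference concrete, here is a minimal sketch in plain Python that mirrors the two snippets quoted above. The helper names moments_axes and fused_rank_supported are hypothetical and are not part of TensorFlow. The shape comments assume a channels-last layout.

```python
def moments_axes(inputs_rank):
    """Axes reduced by the non-fused path: every axis except the last one."""
    return list(range(inputs_rank - 1))


def fused_rank_supported(original_rank):
    """Mirror of the rank check quoted above: only rank 2 or 4 is accepted."""
    return original_rank in [2, 4]


# Assumed layouts (channels last):
#   rank 4: [batch, height, width, channels]         (2D filters)
#   rank 5: [batch, depth, height, width, channels]  (3D filters)
for rank in (2, 4, 5):
    print(rank, moments_axes(rank), fused_rank_supported(rank))
```

For rank 5 the non-fused path simply reduces over one more axis, while the fused path's rank check rejects the input.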

About this issue

  • State: closed
  • Created 8 years ago
  • Reactions: 1
  • Comments: 15 (3 by maintainers)

Most upvoted comments

I don’t think the issue is fully resolved, and the lack of fused batch norm can be a critical performance issue for networks with 3D convolutions. I have made the modification in _fused_batch_norm suggested by @zhangyaobit. However, I still see the to-do comment in normalization.py, which is the more fundamental problem. I am willing to send a PR if that to-do will not be implemented any time soon.
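
The exact modification is not quoted in this thread. One plausible workaround, shown here only as an assumption, is to collapse a rank-5 input to rank 4 before the fused op and restore the shape afterwards. The NumPy sketch below does not call TensorFlow; it only checks that the collapse leaves the per-channel mean and variance unchanged for a channels-last tensor. The names channel_moments and collapse_to_4d are hypothetical.

```python
import numpy as np


def channel_moments(x):
    """Per-channel mean and variance over all axes except the last one."""
    axes = tuple(range(x.ndim - 1))
    return x.mean(axis=axes), x.var(axis=axes)


def collapse_to_4d(x):
    """Reshape [N, D, H, W, C] to [N, D*H, W, C]; the channel axis is untouched."""
    n, d, h, w, c = x.shape
    return x.reshape(n, d * h, w, c)


rng = np.random.default_rng(0)
x5 = rng.standard_normal((2, 3, 4, 5, 6))  # assumed [N, D, H, W, C] layout
x4 = collapse_to_4d(x5)

mean5, var5 = channel_moments(x5)
mean4, var4 = channel_moments(x4)

# The statistics agree because each channel sees the same set of values.
stats_match = np.allclose(mean5, mean4) and np.allclose(var5, var4)
print(x4.shape, stats_match)
```

Since the rank-4 view carries the same per-channel statistics, a fused op that accepts rank 4 could in principle normalize it, and the result could then be reshaped back to the original rank-5 shape.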