tensorflow: FusedBatchNorm does not support 3D Filters
The current implementation of tf.contrib.layers.batch_norm is quite slow with the default parameters. Recent builds of TF support fused=True, which forces the use of the faster nn.fused_batch_norm. However, this method does not support 3D filters (rank-5 inputs, as produced by 3D convolutions) when computing the mean and variance. The normal (slower) variant with fused=False does not have this problem.
FusedBatchNorm:
# The following error is raised when trying to batch-normalize 3D filters:
elif original_rank not in [2, 4]:
  raise ValueError('Inputs %s has unsupported rank. '
                   'Expected 2 or 4 but got %d' % (inputs.name, original_rank))
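Because batch norm computes one mean and variance per channel over all the other axes, a channels-last rank-5 input (NDHWC) can in principle be folded into a rank-4 NHWC shape before the fused call and unfolded afterwards without changing the statistics. The NumPy sketch below checks that equivalence for training-mode batch norm. It assumes channels-last data (an NCDHW layout would need a different fold), the helper names are hypothetical, and in TensorFlow the folding would be done with tf.reshape around the fused op; it is an illustration, not the change that was made in TF.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-3):
    """Training-mode batch norm with per-channel statistics (channels last)."""
    axes = tuple(range(x.ndim - 1))  # every axis except the channel axis
    mean = x.mean(axis=axes)         # shape (C,)
    var = x.var(axis=axes)           # population variance, shape (C,)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def batch_norm_5d_via_4d(x, gamma, beta, eps=1e-3):
    """Hypothetical helper: normalize NDHWC by folding depth into height."""
    n, d, h, w, c = x.shape
    x4 = x.reshape(n, d * h, w, c)    # rank 4, same elements per channel
    y4 = batch_norm_train(x4, gamma, beta, eps)
    return y4.reshape(n, d, h, w, c)  # restore the original 5D shape

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 5, 6))  # N, D, H, W, C
gamma = rng.normal(size=6)
beta = rng.normal(size=6)

direct = batch_norm_train(x, gamma, beta)
folded = batch_norm_5d_via_4d(x, gamma, beta)
print(np.allclose(direct, folded))  # True
```

The fold only merges axes that are reduced anyway, so each channel still sees exactly the same set of values.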
By contrast, the default non-fused implementation computes tf.nn.moments over these axes:
axis = list(range(inputs_rank - 1))
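This also works for 3D filters, of course: for a rank-5 input it reduces over axes [0, 1, 2, 3], i.e. every axis except the channel axis. See the code: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/layers.py#L483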
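To make the axis computation concrete, the sketch below evaluates the same list(range(inputs_rank - 1)) expression for a few ranks and uses it to reduce a rank-5 tensor. NumPy stands in for tf.nn.moments here, and the helper names are hypothetical.

```python
import numpy as np

def moments_axes(inputs_rank):
    # Same expression as the non-fused path: reduce over every axis except
    # the last (channel) axis, whatever the rank of the input.
    return list(range(inputs_rank - 1))

def moments(x):
    """Per-channel mean and population variance, similar to tf.nn.moments."""
    axes = tuple(moments_axes(x.ndim))
    return x.mean(axis=axes), x.var(axis=axes)

for rank in (2, 4, 5):
    print(rank, moments_axes(rank))
# 2 [0]
# 4 [0, 1, 2]
# 5 [0, 1, 2, 3]

x = np.ones((2, 3, 4, 5, 8))  # a 3D-conv activation: N, D, H, W, C
mean, var = moments(x)
print(mean.shape, var.shape)  # (8,) (8,)
```

Nothing in the expression depends on the rank being 2 or 4, which is why the non-fused path handles 5D inputs without any special case.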
About this issue
- State: closed
- Created 8 years ago
- Comments: 15 (3 by maintainers)
I don’t think this issue is fully resolved: the lack of a fused batch norm can be a critical performance issue for networks with 3D convolutions. I have made the modification to _fused_batch_norm suggested by @zhangyaobit. However, I also see the TODO comment in normalization.py, which points to a more fundamental change. I am willing to send a PR if that TODO is not going to be addressed any time soon.