tensorflow: tf.Keras BatchNormalization layer causing impossible loss/accuracy in GAN training
System information
- Have I written custom code: yes
- OS Platform and Distribution: Windows 10, version 1809, OS build 17763.1098
- Mobile device if the issue happens on mobile device: N/A
- TensorFlow installed from: binary
- TensorFlow version: 2.0.0
- Python version: Python 3.7.6
- Bazel version: N/A
- GCC/Compiler version: N/A
- CUDA/cuDNN version: 10.1, V10.1.105
- GPU model and memory: GeForce RTX 2080 Ti, 11GB
I’ve been getting unusual losses and accuracies when training GANs with batch normalization layers in the discriminator using tf.keras. GANs have an optimal objective function value of log(4), which occurs when the discriminator is completely unable to discern real samples from fakes and hence predicts 0.5 for all samples. When I include BatchNormalization layers in my discriminator, both the generator and the discriminator achieve near perfect scores (high accuracy, low loss), which is impossible in an adversarial setting.
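For reference (not part of the original report), here is a quick numeric check of where the log(4) figure comes from, assuming the discriminator loss is the usual binary cross-entropy summed over the real and fake batches:

```python
import numpy as np

# At the GAN equilibrium the discriminator outputs 0.5 for every sample, so
# the cross-entropy on the real batch and on the fake batch are each -log(0.5),
# giving a combined discriminator loss of 2 * log(2) = log(4) ≈ 1.386.
loss_real = -np.log(0.5)        # real samples, label 1, prediction 0.5
loss_fake = -np.log(1.0 - 0.5)  # fake samples, label 0, prediction 0.5
print(loss_real + loss_fake)    # ≈ 1.3863
print(np.log(4.0))              # ≈ 1.3863
```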
Without BatchNorm: The first figure shows the losses (y) per epoch (x) when BN is not used; occasional values slightly below the theoretical minimum are due to training being an iterative process. The second figure shows the accuracies when BN is not used, which settle at about 50% each. Both figures show reasonable values.
With BatchNorm: The first figure shows the losses (y) per epoch (x) when BN is used; note how the GAN objective, which shouldn’t fall below log(4), approaches 0. The second figure shows the accuracies when BN is used, with both approaching 100%. GANs are adversarial; the generator and discriminator can’t both have 100% accuracy.
I suspect this has something to do with setting `model.trainable = False`.
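For context, a minimal sketch of the usual tf.keras GAN recipe, showing where `model.trainable = False` enters the picture. The layer sizes and architectures below are placeholders, not the ones from the linked script:

```python
import tensorflow as tf

# Placeholder generator/discriminator; the real architectures are in the
# linked repro script. This only illustrates the freezing pattern.
generator = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(16,)),
    tf.keras.layers.Dense(2),
])
discriminator = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(2,)),
    tf.keras.layers.BatchNormalization(momentum=0.9),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
discriminator.compile(optimizer="adam", loss="binary_crossentropy",
                      metrics=["accuracy"])

# Freeze the discriminator inside the combined model so that only the
# generator's weights are updated on the GAN objective. In TF 2.x this also
# switches the BatchNormalization layer to inference mode (see the docs note
# quoted further down in this thread).
discriminator.trainable = False
combined = tf.keras.Sequential([generator, discriminator])
combined.compile(optimizer="adam", loss="binary_crossentropy",
                 metrics=["accuracy"])
```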
Standalone code to reproduce and visualize can be found here; set `BATCHNORM_MOMENTUM` to `0.9` to enable batchnorm or to `None` to disable it.
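Illustrative only: the helper below is hypothetical and not taken from the linked script; it just shows how such a flag would typically gate the BatchNormalization layers.

```python
import tensorflow as tf

BATCHNORM_MOMENTUM = 0.9  # set to None to disable batchnorm

def maybe_batchnorm(x, momentum=BATCHNORM_MOMENTUM):
    # Hypothetical helper: adds a BatchNormalization layer only when a
    # momentum value is given; with None the input passes through unchanged.
    if momentum is None:
        return x
    return tf.keras.layers.BatchNormalization(momentum=momentum)(x)
```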
Update: It seems that BatchNorm in the generator also causes this problem, though it has been harder to reproduce.
About this issue
- State: closed
- Created 4 years ago
- Comments: 21 (4 by maintainers)
There were some changes in the way `trainable` behaves between 1.x and 2.x; please take a look at the BatchNorm docs here under “About setting layer.trainable = False on a BatchNormalization layer”, as this may be causing the discrepancy.
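For illustration (not from the thread), a minimal sketch of the TF 2.x behaviour those docs describe: once `trainable` is set to `False`, a BatchNormalization layer normalizes with its moving statistics (inference mode) even when called with `training=True`, which was not the case in 1.x.

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(momentum=0.9)
x = np.random.normal(loc=5.0, scale=2.0, size=(8, 3)).astype("float32")

bn(x, training=True)           # trainable=True: normalizes with batch statistics
bn.trainable = False
frozen = bn(x, training=True)  # TF 2.x: moving statistics are used instead
# The frozen output mean stays far from 0 because the moving mean is still
# close to its initial value after a single update.
print(frozen.numpy().mean(axis=0))
```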
@gadagashwini Yeah, that’s the issue reproduced. The objective shouldn’t ever go much below the theoretical minimum.
Hi @ConorLazarou, I verified your claim. Yes, there is some issue, but it works well on TF 1.15 (the current default version on Colab).