tensorflow: Recurrent Dropout is Wrong
I’ve reviewed one design in depth and two others superficially, but Keras/TF’s `recurrent_dropout` does not implement any of them; publication links are below.
---
I see some potentially severe problems with TF’s implementation in light of the papers I’ve read, which explicitly argue against the scheme used. That said: what is TensorFlow / Keras’s rationale for its own implementation?
---
The implementation is also inconsistent - see the code excerpts below (and the sketch that follows them). The docstring only mentions a performance difference between `implementation==1` and `implementation==2`, but there is also a reproducibility and design difference: `implementation==1` uses a different recurrent dropout mask per gate, whereas `implementation==2` uses a single shared mask. The two are neither theoretically nor practically identical.
The second issue (the inconsistency) is fixable via a docstring update, but the first (the dropout scheme itself) involves significant changes to the recurrent dropout logic for LSTM, GRU, and possibly other RNNs. That said: is TensorFlow / Keras open to changing its base implementations of recurrent dropout? If so, I can clarify (1) in detail, and possibly do the re-implementing myself in a PR, per paper 1.
Inconsistency: `implementation==1` vs. `implementation==2`
```python
# LSTM cell excerpt: four distinct recurrent masks, one per gate
if 0. < self.recurrent_dropout < 1.:  # implementation==1
    h_tm1_i = h_tm1 * rec_dp_mask[0]
    h_tm1_f = h_tm1 * rec_dp_mask[1]
    h_tm1_c = h_tm1 * rec_dp_mask[2]
    h_tm1_o = h_tm1 * rec_dp_mask[3]
```

```python
# LSTM cell excerpt: one mask shared by all gates
if 0. < self.recurrent_dropout < 1.:  # implementation==2
    h_tm1 *= rec_dp_mask[0]
```
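To make the reproducibility / design difference concrete, here’s a toy NumPy sketch of the two masking schemes (not the Keras source; `units`, `rate`, the mask helper, and the state values are illustrative assumptions of mine):

```python
import numpy as np

rng = np.random.default_rng(0)
units, rate = 4, 0.5
h_tm1 = rng.normal(size=(1, units))      # previous hidden state (toy values)

def dropout_mask(shape, rate, rng):
    # Inverted dropout: zero units with probability `rate`, upscale the rest by 1/(1-rate)
    keep = rng.random(shape) >= rate
    return keep.astype(float) / (1.0 - rate)

# implementation==1 analogue: four independent masks, one per LSTM gate (i, f, c, o)
masks = [dropout_mask(h_tm1.shape, rate, rng) for _ in range(4)]
h_i, h_f, h_c, h_o = (h_tm1 * m for m in masks)

# implementation==2 analogue: a single mask shared by all four gates
shared = dropout_mask(h_tm1.shape, rate, rng)
h_shared = h_tm1 * shared

# Each gate sees a differently-thinned h_tm1 under ==1, but the same thinned vector
# under ==2; both the noise model and the random-number consumption differ, so the
# two settings generally produce different results even with identical seeds.
print(h_i, h_f, h_c, h_o, sep="\n")
print(h_shared)
```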
Source codes: keras – tf.keras
Publications:
@qlzh727 It should now be clear that `recurrent_dropout` is indeed off with regard to scaling, and possibly with regard to `stateful` dropout masking - but I’d like to make a stronger claim: paper 1’s implementation should replace the current one, or at least be added as an option to the base implementation. To this end, I present a case in favor of 1 in this post.
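For reference, the scaling I mean is the standard inverted-dropout rescaling, which `tf.nn.dropout` also applies; a minimal check (the toy tensor and seed are purely illustrative):

```python
import tensorflow as tf

x = tf.ones((1, 8))
rate = 0.5
y = tf.nn.dropout(x, rate=rate, seed=0)

# Surviving units are upscaled by 1 / (1 - rate), so kept entries equal 2.0 here.
# It is this upscaling, re-applied to the recurrent term at every timestep, that I
# argue interacts badly with bounded activations.
print(y.numpy())
```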
THEORETICAL:

- Improved gradient persistence along time: with a bounded activation like `sigmoid` or `tanh` in the hidden-to-hidden transformation, inverted dropout upscales pre-activations, which drives activations closer to saturation, diminishing the gradient. The effect is amplified upon recurrence - especially for long sequences, inhibiting the learning of long-term dependencies. (My observation)
- Improved inference-time performance: for `t` timesteps, assuming the favorable case of all gates equal to 1 to preserve the hidden state, the hidden state vector at time t, `h_t`, under paper 2’s recurrent transformation is a nested composition of `t` `tanh` activations whose pre-activations are each upscaled by `1 / (1 - recurrent_dropout)` during training, an upscaling absent at inference. Higher `recurrent_dropout` accelerated divergence. (A toy illustration follows this list.)
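As a rough illustration of the compounding (a toy `tanh` recurrence of my own, not paper 2’s exact equations; the weights `U`, the constant input drive, the rate, and the all-gates-equal-1 simplification are all assumptions), re-applying the `1 / (1 - rate)` upscaling at every step pushes activations toward saturation, where the local gradient `1 - tanh(x)^2` vanishes:

```python
import numpy as np

rng = np.random.default_rng(0)
units, rate, T = 32, 0.5, 200
U = rng.normal(scale=1.0 / np.sqrt(units), size=(units, units))  # toy recurrent weights

def saturation_stats(upscale):
    """Plain tanh recurrence with per-step recurrent dropout.
    Returns (mean |h|, fraction of units with |h| > 0.9), averaged over the last 100 steps."""
    h = np.zeros(units)
    abs_h, frac_sat = [], []
    for t in range(T):
        mask = (rng.random(units) >= rate).astype(float)
        if upscale:
            mask /= 1.0 - rate             # inverted-dropout rescaling, re-applied every step
        h = np.tanh(U @ (h * mask) + 0.5)  # constant drive stands in for the input term
        if t >= T - 100:
            abs_h.append(np.mean(np.abs(h)))
            frac_sat.append(np.mean(np.abs(h) > 0.9))
    return np.mean(abs_h), np.mean(frac_sat)

# The upscaled run drives a visibly larger fraction of units into the saturated region.
print("with 1/(1-rate) upscaling:", saturation_stats(upscale=True))
print("without upscaling        :", saturation_stats(upscale=False))
```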
EMPIRICAL:

The authors demonstrate performance equal to or better than 2’s implementation across a number of applications, including word-level language modeling, character-level language modeling, named entity recognition, Twitter sentiment analysis, and a demonstrative synthetic task. Some distinguishing findings:
> (iv) applying dropout to hidden state updates rather than hidden states in some cases leads to a perplexity decrease by more than 30 points; (v) our approach is effective even when combined with the forward dropout - for LSTMs we are able to bring down perplexity on validation set from 130 to 91.6
>
> – pg. 6, Word-level Language Modeling
Per-step mask sampling was shown to work as well as or better than per-sequence sampling across all tasks, and even for 2’s implementation on word-level language modeling.
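To make the quoted idea concrete, here’s my reading of it as a single LSTM step: dropout is applied to the candidate update `g` rather than to `h_tm1` per gate. This is a minimal NumPy sketch under my own assumptions (weight shapes, naming, the per-step Bernoulli mask, and the absence of rescaling are illustrative; it is neither the authors’ code nor a proposed Keras API):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step_update_dropout(x, h_tm1, c_tm1, W, U, b, rate, rng):
    """One LSTM step with dropout on the candidate update (my sketch of paper 1's idea),
    instead of masking h_tm1 once per gate as in the current implementation==1."""
    z = x @ W + h_tm1 @ U + b                   # all four gate pre-activations at once
    i, f, g, o = np.split(z, 4, axis=-1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    mask = (rng.random(g.shape) >= rate).astype(g.dtype)  # per-step Bernoulli mask
    c = f * c_tm1 + i * (g * mask)              # only the *update* is thinned;
                                                # c_tm1 itself is never zeroed out
    h = o * np.tanh(c)
    return h, c

# Toy usage with illustrative shapes
rng = np.random.default_rng(0)
units, dim = 8, 4
W = rng.normal(size=(dim, 4 * units))
U = rng.normal(size=(units, 4 * units))
b = np.zeros(4 * units)
x = rng.normal(size=(1, dim))
h = c = np.zeros((1, units))
h, c = lstm_step_update_dropout(x, h, c, W, U, b, rate=0.25, rng=rng)
```

Whether and where to rescale (train-time inverted scaling vs. test-time scaling) is exactly the point under discussion, so the sketch deliberately leaves the mask unscaled.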
While everything presented so far should be sufficient, I offer to take it a step further: I can fill a gap in the sequence-to-sequence domain. I’m working on a seizure EEG classification task with 16-channel recordings of 240,000 timesteps, totaling 65+ GB of data. My work has been extensive, and I have a large database of hundreds of tested architectures, the majority CNN + RNN. The stacked RNNs are typically fed ~1500 timesteps post dimensionality reduction, in a `stateful` manner - far more than what’s been tested in any of the papers.
I can compare the current dropout scheme against 1’s, and report results; I’ll also be using my visualization package to inspect the gradients directly.