keras: Attention module not working with TimeDistributed layer
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 21.04
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): TensorFlow 2.5.0
- Python version: Python 3.6.9
- Bazel version (if compiling from source):
- GPU model and memory: GeForce GTX 1080 and 64 GB memory
Describe the problem.
I am trying to wrap the Keras Attention layer in a TimeDistributed layer, but it gives an error.
Describe the current behavior.
When I try to run the code, I get the following error:
NotImplementedError: Please run in eager mode or implement the compute_output_shape method on your layer (Attention).
Describe the expected behavior.
I expect the code snippet below to run without an error. The TimeDistributed layer should apply the Attention layer to each timestep of the provided tensors.
Contributing.
- Do you want to contribute a PR? (yes/no): No. I am willing to try to fix this if someone can offer guidance on how.
Standalone code to reproduce the issue.
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LSTM, Conv1D, TimeDistributed, Flatten, Concatenate, Attention
# Variable-length sequences of feature vectors.
query_input = tf.keras.Input(shape=(None, 1,100), dtype='float32')
value_input = tf.keras.Input(shape=(None, 4,100), dtype='float32')
print("query_seq_encoding: ", query_input)
print("value_seq_encoding: ", value_input)
# Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = TimeDistributed(tf.keras.layers.Attention())(
    [query_input, value_input])
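For comparison, the same Attention layer applied directly to symbolic inputs (without TimeDistributed) builds fine, because the functional API infers shapes by tracing call() rather than asking the layer for compute_output_shape. A hypothetical check, not part of the original report:

# Hypothetical comparison: Attention without TimeDistributed builds fine.
q = tf.keras.Input(shape=(1, 100), dtype='float32')
v = tf.keras.Input(shape=(4, 100), dtype='float32')
direct_attention = tf.keras.layers.Attention()([q, v])
print("direct attention shape:", direct_attention.shape)  # (None, 1, 100), i.e. [batch_size, Tq, dim]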
It seems like Attention does not implement compute_output_shape and uses the method provided by the Layer base class (which does not compute the shape but raises an error). But it's hard to implement compute_output_shape for the Attention layer because its output depends on a parameter of the call method: with call(inputs, return_attention_scores=False) it returns a single tensor of shape [batch_size, Tq, dim], and with call(inputs, return_attention_scores=True) it returns an additional tensor of shape [batch_size, Tq, Tv]. Hence, it's impossible to compute the output shape once (when building the model). Because of this, I guess compute_output_shape is not implemented yet. I'm not sure why this design was chosen; this is the only layer I know of whose call method has a parameter that changes the output's shape. For other layers, output shapes can vary, but they depend on attributes of the class that are defined during instantiation, so it's easy to compute the output shape once. For example, the output shape of an RNN may vary, but it depends on constructor parameters like LSTM(…, return_sequences=True).
So, I see two possible solutions for this issue:
- Implement compute_output_shape so that it supports parameters like return_attention_scores, without significant (and probably unwanted) changes to the keras codebase.
- Move return_attention_scores from the call method to the constructor of the Attention layer and implement compute_output_shape. I don't know if it breaks something or what the reason for the current design is, though. So guidance or approval of this approach is needed.
@mishc9 you can try this colab if you're interested in:
https://colab.research.google.com/drive/1X6zuidDqZqf4xM5YBEygPlyy6KVJwYC1?usp=sharing
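For concreteness, the shape difference described above (the call() flag changing what is returned) can be checked in eager mode; this is a hypothetical snippet, not from the thread:

# Hypothetical eager-mode check of how return_attention_scores changes the output.
attn = tf.keras.layers.Attention()
q = tf.random.normal((2, 3, 8))   # [batch_size, Tq, dim]
v = tf.random.normal((2, 5, 8))   # [batch_size, Tv, dim]
out = attn([q, v])                                        # single tensor, shape (2, 3, 8)
out, scores = attn([q, v], return_attention_scores=True)  # adds scores of shape (2, 3, 5), i.e. [batch_size, Tq, Tv]
print(out.shape, scores.shape)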
@tempdeltavalue I will try to reproduce this error and fix that. For the Attention layer it was enough to add a compute_output_shape method. I believe it would work for TFOpLambda as well.
Edit: TFOpLambda lacks compute_output_shape, so it falls back to the implementation of Layer.compute_output_shape, which doesn't work outside eager mode. But I don't know yet how to implement TFOpLambda.compute_output_shape for the general case. As a workaround, you could patch this class for your specific needs.
Which version of tensorflow are you using in this example?
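A minimal sketch of the kind of per-layer patch mentioned above, assuming return_attention_scores is never set to True; the subclass name and the exact shape logic are illustrative guesses, not the eventual Keras fix. Whether this alone makes the TimeDistributed snippet from the report work is a separate question, since TimeDistributed also has to handle the list input.

class PatchedAttention(tf.keras.layers.Attention):
    # Workaround sketch: expose a static output shape so wrappers such as
    # TimeDistributed can build without running eagerly. Only valid when
    # call() is used with return_attention_scores=False.
    def compute_output_shape(self, input_shape):
        # input_shape is [query_shape, value_shape] (optionally plus key_shape).
        query_shape = tf.TensorShape(input_shape[0]).as_list()
        value_shape = tf.TensorShape(input_shape[1]).as_list()
        # Dot-product attention returns [batch_size, Tq, dim], where dim is
        # the feature dimension of the value tensor.
        return tf.TensorShape(query_shape[:-1] + value_shape[-1:])

# Usage with the inputs from the reproducer above.
query_value_attention_seq = TimeDistributed(PatchedAttention())(
    [query_input, value_input])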