frugally-deep: LSTM Encoder-Decoder results in frugally-deep differ from the Keras/TensorFlow LSTM Encoder-Decoder (missing support for initial_state)

Hi @Dobiasd

I have been working on the Encoder-Decoder model for Vehicle Path Forecasting since you added support for return_state and show_tensor5 for LSTM-based models. The project workflow was described in a previous issue. After some experiments, the LSTM-based encoder and decoder models give me no problems related to return_state = True or show_tensor5, which confirms that the frugally-deep fixes work. However, I have been unable to replicate the results I obtained with the Keras/TensorFlow models.

The frugally-deep fdeep_encoder_model_NT returns exactly the same encoder_hidden_state and encoder_cell_state as its TF + Keras counterpart (encoder_model.hdf5). However, fdeep_decoder_model_NT does not give me the same decoder_hidden_state and decoder_cell_state output states as its TF + Keras counterpart 😦

Specifically, I built the decoder inference model with TF + Keras (see earlier comments in this issue for the corresponding code) and then converted it from .hdf5 to .json so it could be ported into the C++ application, exactly as with the encoder model.

The encoder states match: [screenshot: encoder states, frugally-deep vs. Keras]

However, both the frugally-deep decoder_hidden_state and decoder_cell_state differ from the corresponding Keras decoder_hidden_state and decoder_cell_state: [screenshot: decoder states, frugally-deep vs. Keras]

As expected, this results in a wrong bounding-box prediction: [screenshot: frugally-deep prediction] which does not match the corresponding Keras result: [screenshot: Keras prediction]

I do not really know what is happening with fdeep_decoder_model_NT, so I have several hypotheses in mind:

  • I trained another model using LSTM instead of CuDNNLSTM layers to check whether the problem lies in the CuDNNLSTM implementation. The problem persists with both CuDNNLSTM and LSTM: fdeep_encoder_model works well, but fdeep_decoder_model still makes wrong predictions (both the returned states and the next bbox prediction).
  • I am now working on main.cpp. Maybe the problem lies in how I build the fdeep::tensor5 and fdeep::tensor5s objects that I feed into the ported models. However, both models run fine: the script does not crash at any step; the decoder model simply makes inaccurate predictions of future bounding boxes.
  • I am puzzled by the following: in main.cpp the decoder prediction is made with auto decoder_outputs = decoder_model.predict({target_seq, encoder_states.at(0), encoder_states.at(1)});, where encoder_states.at(0) and encoder_states.at(1) are h_enc and c_enc respectively. When I swapped the encoder states at the decoder input, like this: auto decoder_outputs = decoder_model.predict({target_seq, encoder_states.at(1), encoder_states.at(0)});, I obtained exactly the same predicted_next_box, even though the order of the decoder input states had changed.
  • Finally, apart from the wrong values of h_dec and c_dec returned by fdeep_decoder_model, I noticed that h_dec (from both frugally-deep AND Keras) lies in the range [-1, 1], but that is not the case for the c_dec cell state: in Keras, c_dec takes values in [-11, 11], whereas in frugally-deep it stays within [-1, 1]. In addition, following your earlier suggestion that internal scaling can cause this kind of issue, I inspected fdeep_encoder_model.json and found some initializer parameters that use Variance_Scaling; maybe these cause errors at inference time. I suspect this may be the root of the problem, but I have no idea how to get the correct h_dec and c_dec, with the correct values and in the same ranges as in Keras.
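A note on the ranges in the last point: in a standard LSTM step the hidden state is h = o·tanh(c), where the output gate o lies in [0, 1], so |h| ≤ 1 by construction, while the cell state c = f·c_prev + i·g is a running sum that is never squashed and can grow well past 1. The pure-Python sketch below is a single-unit LSTM with hypothetical, hand-picked weights (not taken from the models in this issue). It assumes tanh activations and sigmoid gates; the recurrent activation in Keras (sigmoid, or hard_sigmoid in some Keras versions) also maps into [0, 1], so the bound on h holds either way. It illustrates the textbook update, not any specific Keras or frugally-deep code path.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h, c, w):
    """One step of a single-unit LSTM with scalar input x and scalar state (h, c).

    w maps each gate name to a hypothetical (input_weight, recurrent_weight, bias).
    """
    def pre(gate):
        wx, wh, b = w[gate]
        return wx * x + wh * h + b

    i = sigmoid(pre("i"))         # input gate, in (0, 1)
    f = sigmoid(pre("f"))         # forget gate, in (0, 1)
    g = math.tanh(pre("g"))       # candidate value, in (-1, 1)
    o = sigmoid(pre("o"))         # output gate, in (0, 1)
    c_new = f * c + i * g         # cell state: accumulates, never squashed
    h_new = o * math.tanh(c_new)  # hidden state: squashed, so |h| < 1
    return h_new, c_new


# Hypothetical weights that keep the forget and input gates almost fully open.
weights = {"i": (1.0, 0.0, 5.0), "f": (0.0, 0.0, 5.0),
           "g": (1.0, 0.0, 0.0), "o": (0.0, 0.0, 0.0)}

h, c = 0.0, 0.0
for t in range(10):
    h, c = lstm_step(2.0, h, c, weights)
    print(f"t={t}  h={h:+.3f}  c={c:+.3f}")
```

With both gates held almost open, c ends up above 9 after 10 steps while h stays below 0.5, so a cell-state range of roughly [-10, 10] or [-11, 11] next to a hidden-state range of [-1, 1] is ordinary LSTM behaviour rather than, on its own, a sign of a scaling problem.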

Here is the main.cpp file I am running to test the results. Any comment or suggestion about the code would be welcome!

#include <fdeep/fdeep.hpp>
#include <vector>
#include <fstream>
#include <iostream>

int main()
{
	// Loading the previously trained models
	const auto encoder_model = fdeep::load_model("fdeep_encoder_model_NT.json");
	std::cout << "Encoder Model Loaded!" << std::endl;
	const auto decoder_model = fdeep::load_model("fdeep_decoder_model_NT.json");
	std::cout << "Decoder Model Loaded!" << std::endl;
	// Batch_size = 1, num_timesteps = 10 and num_features = 4
	fdeep::shape5 in_traj_shape(1,1,1,10,4);
	// Loading a sample sequence trajectory into tensor5 data structure
	const std::vector<float> src_traj  = {1728, 715, 191, 221,
					1717, 710, 202, 215,
					1706, 704, 206, 198,
					1695, 700, 217, 196,
					1687, 696, 228, 183,
					1680, 689, 240, 181,
					1668, 668, 240, 198,
					1661, 668, 243, 194,
					1650, 664, 251, 189,
					1635, 660, 266, 181};
	// Input trajectory from vector to tensor5 data structure
	const fdeep::shared_float_vec shared_traj(fplus::make_shared_ref<fdeep::float_vec>(src_traj));
	const fdeep::tensor5 encoder_inputs(in_traj_shape, shared_traj);
	std::cout << "Trajectory #0!" << fdeep::show_tensor5(encoder_inputs) << std::endl;
	// Using loaded encoder model to predict encoder output states
	// Then encoder_states can be fed as input tensors into decoder_model
	const auto encoder_states = encoder_model.predict({encoder_inputs});
	// Printing for debugging purposes
	std::cout << "h_enc: "<< fdeep::show_tensor5(encoder_states.at(0)) << std::endl;
	std::cout << "c_enc: "<< fdeep::show_tensor5(encoder_states.at(1)) << std::endl;
	// Creating a SOS input sequence token to signal decoder model to start making predictions
	fdeep::shape5 bbox_shape(1,1,1,1,4);
	// Loading the SOS token into a tensor5 data structure
	const std::vector<float> SOS_token  = {9999.0, 9999.0, 9999.0, 9999.0};
	const fdeep::shared_float_vec shared_SOS_token(fplus::make_shared_ref<fdeep::float_vec>(SOS_token));
	fdeep::tensor5 target_seq(bbox_shape, shared_SOS_token);
	// In Python we have: Prediction, h, c = decoder_model.predict([target_seq] + state)
	// The input order must match the Keras decoder model: [target_seq, h_enc, c_enc]
	auto decoder_outputs = decoder_model.predict({target_seq, encoder_states.at(0), encoder_states.at(1)});
	// Printing for debugging purposes
	std::cout << "h_dec: "<< fdeep::show_tensor5(decoder_outputs.at(1)) << std::endl;
	std::cout << "c_dec: "<< fdeep::show_tensor5(decoder_outputs.at(2)) << std::endl;
	std::cout << "Predicted next bounding box!" << fdeep::show_tensor5(decoder_outputs.at(0)) << std::endl;
}

The fdeep_encoder_model_NT.json model imported into the C++ application is available to download and inspect from an earlier comment. The fdeep_decoder_model_NT.json can be downloaded from the following link: Decoder model: https://drive.google.com/open?id=1hwrjcnNfWaqQI0o8TmJKtfsAwj6zd9aq

I would really appreciate any help with this issue. I am puzzled because the encoder model works perfectly but the decoder model does not: the Keras and frugally-deep decoder models produce different results, and the frugally-deep output predictions cannot be used at all.

About this issue

  • State: closed
  • Created 5 years ago
  • Comments: 25 (19 by maintainers)

Most upvoted comments

For now, I cannot commit to this. If I find the time I will let you know.

Hi @Dobiasd

Thank you very much! First, thanks for pinpointing the missing support for initial_state in LSTM layers as the root cause of the issue. Second, I will definitely try the new frugally-deep 0.8.5 version with the fixes in the Vehicle Path Prediction project and see how it performs. Finally, it is a pleasure to contribute to the improvement of frugally-deep. IMHO, tools and developments like this are really appreciated by the entire Machine Learning community.

I will continue the journey with the frugally-deep library and hope to share interesting results (once the project works as expected 😃).

Keep up the great work @Dobiasd!

I reduced the model further:

from keras.layers import Input, LSTM
from keras.models import Model

decoder_inputs = Input(shape=(None, 4))

decoder_state_input_h = Input(shape=(12,))
decoder_state_input_c = Input(shape=(12,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

decoder_outputs, state_h, state_c = LSTM(units=12,
                                         return_sequences=True,
                                         return_state=True)(decoder_inputs,
                                                            initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]

decoder_model = Model(inputs=[decoder_inputs] + decoder_states_inputs, outputs=decoder_states)
decoder_model.save("decoder_model.h5", include_optimizer=False)

When loading this model in C++, the tests still fail.

The problem is initial_state. Frugally-deep does not support it, but it also does not complain about it when converting the model.
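A silently dropped initial_state is consistent with both symptoms reported earlier in this issue: swapping h_enc and c_enc at the decoder input changes nothing, and c_dec stays within [-1, 1] in C++. The pure-Python sketch below runs one step of a single-unit LSTM with arbitrary, hypothetical weights; the flag honour_initial_state is invented for illustration and only models "state used" versus "state replaced by zeros". It is not frugally-deep's actual code path.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h, c):
    # Single-unit LSTM with fixed, hypothetical weights (illustration only).
    i = sigmoid(0.5 * x + 0.8 * h + 0.1)   # input gate
    f = sigmoid(0.3 * x - 0.6 * h + 1.0)   # forget gate
    g = math.tanh(0.7 * x + 0.9 * h)       # candidate value
    o = sigmoid(0.2 * x + 0.4 * h)         # output gate
    c_new = f * c + i * g
    return o * math.tanh(c_new), c_new


def decoder(x, h_init, c_init, honour_initial_state=True):
    # A converter that drops initial_state behaves like honour_initial_state=False:
    # the layer starts from a zero state, whatever is passed in.
    if not honour_initial_state:
        h_init, c_init = 0.0, 0.0
    return lstm_step(x, h_init, c_init)


x, h_enc, c_enc = 0.5, 0.9, -7.0  # hypothetical input and encoder states
print(decoder(x, h_enc, c_enc))          # uses the encoder states
print(decoder(x, c_enc, h_enc))          # swapped inputs: a different result
print(decoder(x, h_enc, c_enc, False))   # state ignored ...
print(decoder(x, c_enc, h_enc, False))   # ... so swapping changes nothing
```

When the state is replaced by zeros, one step gives c = f·0 + i·g = i·g with i in (0, 1) and g in (-1, 1), so c cannot leave (-1, 1), whereas a decoder that honours initial_state starts from c_enc (around ±10 here) and keeps much of it.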

I’ll try to add support for it, and get back to you with the results.

OK, with the following adjustment, the script trains and saves a decoder_model.h5:

HDF5_in_dir = "."
HDF5_in_filename = "10i_10o_10s_8-1-E-W-LBV.hdf5"
model_saving_dir = "."
base_saving_dir = "."

Then, after running python3 convert_model.py decoder_model.h5 decoder_model.json, calling fdeep::load_model("decoder_model.json"); in C++ fails.

So now I have something reproducible that I can work with.

Thanks a lot for your patience. 🙂

I’ll get back to you with the results.

Hi @Dobiasd

Sorry for the late response! We have been trying a couple of other options to integrate Python models for vehicle trajectory prediction and C++ optimized YOLO-V3 (Darknet) for object detection. For the time being, I am really interested in the frugally-deep implementation, and therefore, these are the files you requested for debugging purposes. I will try to be as concise as possible in the explanation 😃

These are the two models used for the tests, encoder_model.h5 and decoder_model.h5. You can download them using the following links:

These are the minimal inference scripts, one for Python (Keras) and one for C++ (frugally-deep):

  • Minimal inference test example (Python):
import numpy as np
from keras.models import load_model
np.set_printoptions(suppress=True)

encoder_model = load_model("encoder_model.h5")
print("Encoder model loaded!")
decoder_model = load_model("decoder_model.h5")
print("Decoder model loaded!")
source_trajectory  = np.array([[1728, 715, 191, 221],
			[1717, 710, 202, 215],
			[1706, 704, 206, 198],
			[1695, 700, 217, 196],	
			[1687, 696, 228, 183],
			[1680, 689, 240, 181],
			[1668, 668, 240, 198],
			[1661, 668, 243, 194],
			[1650, 664, 251, 189],
			[1635, 660, 266, 181]])
output_trajectory_length = source_trajectory.shape[0]
num_output_features = source_trajectory.shape[1]
print("Source trajectory : \n {}".format(source_trajectory))
src_traj_batched = np.expand_dims(source_trajectory, axis = 0)
state = encoder_model.predict(src_traj_batched)
print("h_enc: \n {}".format(np.round(state[0], 2)))
print("c_enc: \n {}".format(np.round(state[1], 2)))
sos_token = 9999
target_trajectory = np.array([sos_token, sos_token, sos_token, sos_token]).astype("uint16").reshape(1,1, num_output_features)
prediction, h, c = decoder_model.predict([target_trajectory] + state)
print("h_dec = \n {}".format(np.round(h, 2)))
print("c_dec = \n {}".format(np.round(c, 2)))
print("Predicted next bounding box = {}".format(np.array(prediction).astype("uint16")))
  • Equivalent Minimal Inference test example (Frugally-C++):
#include <fdeep/fdeep.hpp>
#include <vector>
#include <fstream>
#include <iostream>

int main()
{
   const auto encoder_model = fdeep::load_model("fdeep_encoder_model.json");
   std::cout << "Encoder Model Loaded!" << std::endl;
   const auto decoder_model = fdeep::load_model("fdeep_decoder_model.json");
   std::cout << "Decoder Model Loaded!" << std::endl;
   const int num_input_features = 4;
   const int input_trajectory_length = 10;
   fdeep::shape5 input_trajectory_shape(1, 1, 1, input_trajectory_length, num_input_features);
   const std::vector<float> source_trajectory  = {1728, 715, 191, 221,
				1717, 710, 202, 215,
				1706, 704, 206, 198,
				1695, 700, 217, 196,
				1687, 696, 228, 183,
				1680, 689, 240, 181,
				1668, 668, 240, 198,
				1661, 668, 243, 194,
				1650, 664, 251, 189,
				1635, 660, 266, 181};
   const fdeep::shared_float_vec shared_trajectory(fplus::make_shared_ref<fdeep::float_vec>(source_trajectory));
   const fdeep::tensor5 encoder_inputs(input_trajectory_shape, shared_trajectory);
   std::cout << "Trajectory #0!" << fdeep::show_tensor5(encoder_inputs) << std::endl;
   const auto encoder_states = encoder_model.predict({encoder_inputs});
   std::cout << "h_enc: "<< fdeep::show_tensor5(encoder_states.at(0)) << std::endl;
   std::cout << "c_enc: "<< fdeep::show_tensor5(encoder_states.at(1)) << std::endl;
   fdeep::shape5 bbox_shape(1, 1, 1, 1, num_input_features);
   const std::vector<float> SOS_token  = {9999, 9999, 9999, 9999};
   const fdeep::shared_float_vec shared_SOS_token(fplus::make_shared_ref<fdeep::float_vec>(SOS_token));
   fdeep::tensor5 target_sequence(bbox_shape, shared_SOS_token);
   auto decoder_outputs = decoder_model.predict({target_sequence, encoder_states.at(0), encoder_states.at(1)});
   std::cout << "h_dec: "<< fdeep::show_tensor5(decoder_outputs.at(1)) << std::endl;
   std::cout << "c_dec: "<< fdeep::show_tensor5(decoder_outputs.at(2)) << std::endl;
   std::cout << "Predicted next bounding box!" << fdeep::show_tensor5(decoder_outputs.at(0)) << std::endl;
}
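Both minimal examples stop after a single decoder step. Assuming the decoder was trained to take its own previous output as input (the SOS token suggests so), a full forecast of output_trajectory_length boxes feeds each predicted box back as the next decoder input and carries h and c forward. The sketch below shows that loop in plain Python; decode_trajectory and toy_step are hypothetical names, and toy_step is only a stand-in for one call to decoder_model.predict([target_seq] + state), not a real model.

```python
SOS = [9999.0, 9999.0, 9999.0, 9999.0]  # start-of-sequence token, as in the issue


def decode_trajectory(step, h, c, length):
    """Autoregressive decoding: feed each prediction back as the next input."""
    x = SOS
    boxes = []
    for _ in range(length):
        x, h, c = step(x, h, c)  # the prediction becomes the next decoder input
        boxes.append(x)
    return boxes


def toy_step(x, h, c):
    # Hypothetical stand-in for decoder_model.predict([x, h, c]): here h simply
    # holds the last observed box, and each call shifts the previous box 11 px left.
    prev = h if x == SOS else x
    return [prev[0] - 11.0, prev[1], prev[2], prev[3]], h, c


last_box = [1635.0, 660.0, 266.0, 181.0]  # last row of the sample trajectory
for box in decode_trajectory(toy_step, last_box, None, 3):
    print(box)
```

In main.cpp the same loop would reuse decoder_outputs.at(0) as the next target_seq and decoder_outputs.at(1) and decoder_outputs.at(2) as the next h and c.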

One thing I noticed when analysing the test results concerns the values and their scales. For the encoder, encoder_model.h5 and fdeep_encoder_model.json produce exactly the same encoder_states, so naturally their values also lie in the same ranges: h_enc in [-1, 1] and c_enc in [-10, 10], in both Python and frugally-deep.

However, the same does not hold for decoder_outputs = [prediction, h_dec, c_dec]. The h_dec values differ between Keras/Python and frugally-deep/C++, although their range is the same in both implementations, [-1, 1].

The worst case is c_dec! Here both the values and their range differ between the Keras/Python and frugally-deep/C++ implementations: in Keras, c_dec lies in [-11, 11] (I really do not know why…), but in frugally-deep it lies in [-1, 1].
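To make comparisons like these quantitative instead of eyeballing printed tensors, a small helper can report the largest element-wise difference and the value range of each implementation. This is a stdlib-only sketch; compare_states and the tolerance atol=1e-4 are hypothetical choices, and the numbers in the usage lines are placeholders, not results from this issue. With NumPy available, np.allclose(a, b, atol=...) and np.abs(a - b).max() do the same job.

```python
def compare_states(name, keras_vals, fdeep_vals, atol=1e-4):
    """Compare two flattened tensors element-wise and print a one-line summary."""
    if len(keras_vals) != len(fdeep_vals):
        raise ValueError(f"{name}: length mismatch "
                         f"({len(keras_vals)} vs {len(fdeep_vals)})")
    worst = max(abs(a - b) for a, b in zip(keras_vals, fdeep_vals))
    ok = worst <= atol
    print(f"{name}: max |diff| = {worst:.3g}, "
          f"keras range [{min(keras_vals):.3g}, {max(keras_vals):.3g}], "
          f"fdeep range [{min(fdeep_vals):.3g}, {max(fdeep_vals):.3g}] "
          f"-> {'match' if ok else 'MISMATCH'}")
    return ok


# Placeholder values only, to show the output format.
compare_states("h_enc", [0.12, -0.98, 0.45], [0.12, -0.98, 0.45])
compare_states("c_dec", [10.7, -3.2, 0.4], [0.95, -0.6, 0.1])
```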

I created a small Colaboratory notebook in which I can run the Keras/Python minimal example script. Please feel free to use it if needed 😃

Maybe something inside the frugally-deep decoder's prediction routine transforms the decoder's hidden state values and/or their scale, so that fdeep_decoder_model gives wrong bounding-box predictions and wrong next decoder states.

I hope this information and these files are enough for you to debug and dig deeper into the issue. Please feel free to ask for more information (or any files you may need for debugging). I will be waiting for your insights. Thank you very much @Dobiasd.