keras: Multiprocessing: Failed to get device properties

I would like to use a Keras model in a multiprocessing setup. The model is used in a generator, which produces data to train another model.

As long as I don’t use multiprocessing, everything works fine. But with multiprocessing, I get the following error:

E tensorflow/core/grappler/clusters/utils.cc:81] Failed to get device properties, error code: 3

I searched how to use Keras in a multithreaded context and found this: https://github.com/keras-team/keras/issues/5640

Apparently, I need to call _make_predict_function() and get the TensorFlow graph. I added this to my code:

before any training:

q_approximator = create_model()
q_approximator_fixed = create_model()

q_approximator._make_predict_function()
q_approximator_fixed._make_predict_function()

# only this one will be trained
q_approximator.compile(RMSprop(LEARNING_RATE, rho=RHO, epsilon=EPSILON), loss=huber_loss)

graph = tf.get_default_graph()
#graph = K.get_session().graph         # this way also doesn't work

inside the generator:

with graph.as_default():
    q_values = q_approximator_fixed.predict([state.reshape(1, *INPUT_SHAPE),
                                             np.ones((1, NUM_ACTIONS))])

and finally, the training setup:

q_approximator.fit_generator(interaction_generator(q_approximator_fixed,
                                                   replay_memory,
                                                   exploration,
                                                   interaction_counter,
                                                   interaction_lock),
                             epochs=10, steps_per_epoch=BATCH_SIZE * 1000,
                             use_multiprocessing=True,
                             workers=1)

With a single worker and no multiprocessing, it works fine. Multiple workers without multiprocessing also work. But a single worker with use_multiprocessing=True makes the program crash with the error message above.

How can I use a Keras model in a multiprocessing context?

About this issue

  • State: closed
  • Created 6 years ago
  • Comments: 15

Most upvoted comments

Closing as this is resolved

Try reinstalling your graphics driver

I have the same issue. I tried using train_on_batch instead of fit and it still did not work. I’m not sure why the issue was closed.