keras: Multiprocessing: Failed to get device properties
I would like to use a keras model in a multiprocessing setup. The model is used in a generator, which produces data to train another model.
As long as I don’t use multiprocessing, everything works fine. But with multiprocessing, I get the following error:
E tensorflow/core/grappler/clusters/utils.cc:81] Failed to get device properties, error code: 3
I searched how to use Keras in a multithreaded context and found this: https://github.com/keras-team/keras/issues/5640
Apparently, I need to call _make_predict_function and get the TensorFlow graph.
I added this to my code:
before any training:
q_approximator = create_model()
q_approximator_fixed = create_model()
q_approximator._make_predict_function()
q_approximator_fixed._make_predict_function()
# only this one will be trained
q_approximator.compile(RMSprop(LEARNING_RATE, rho=RHO, epsilon=EPSILON), loss=huber_loss)
graph = tf.get_default_graph()
#graph = K.get_session().graph # this way also doesn't work
inside the generator:
with graph.as_default():
    q_values = q_approximator_fixed.predict([state.reshape(1, *INPUT_SHAPE),
                                             np.ones((1, NUM_ACTIONS))])
and finally, the training setup:
q_approximator.fit_generator(interaction_generator(q_approximator_fixed,
                                                   replay_memory,
                                                   exploration,
                                                   interaction_counter,
                                                   interaction_lock),
                             epochs=10, steps_per_epoch=BATCH_SIZE * 1000,
                             use_multiprocessing=True,
                             workers=1)
With just 1 worker and no multiprocessing it works fine. Multiple workers and no multiprocessing, also fine. But a single worker and multiprocessing makes the program crash with the above error message.
How can I use a Keras model in a multiprocessing context?
About this issue
- State: closed
- Created 6 years ago
- Comments: 15
Closing as this is resolved
Try reinstalling your graphics driver
I have the same issue. I tried using train_on_batch instead of fit and it still did not work. I’m not sure why the issue was closed.
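A workaround that is often suggested for this class of error, offered here as an untested sketch and not a confirmed fix for this issue: avoid sharing a model that was created in the parent process with a forked worker. As far as I understand, error code 3 is CUDA's initialization error, which typically appears when a child process created by fork tries to reuse GPU state the parent already initialized. Instead, each worker process builds its own model the first time it needs one. In the sketch below, build_model, get_model, predict and run_demo are hypothetical names, and build_model returns a plain Python callable as a stand-in so the example runs without Keras; in real code it would import Keras and call create_model() (and load the weights) inside the worker.

```python
import multiprocessing as mp
import os

# One model per worker process, created lazily on first use.
_model = None


def build_model():
    """Hypothetical stand-in for create_model() plus weight loading.

    Real code would import Keras *here*, inside the worker, so that no
    TensorFlow/CUDA state is inherited from the parent process.
    """
    pid = os.getpid()
    # Placeholder "model": returns the worker's PID and the sum of the input.
    return lambda state: (pid, sum(state))


def get_model():
    """Return this process's model, building it on the first call."""
    global _model
    if _model is None:
        _model = build_model()
    return _model


def predict(state):
    """Run a prediction with the model owned by the current process."""
    return get_model()(state)


def run_demo():
    """Map predictions over a pool of freshly spawned (not forked) workers."""
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        return pool.map(predict, [[1, 2], [3, 4], [5, 6]])
```

Call run_demo() from under an `if __name__ == "__main__":` guard, since the spawn start method re-imports the main module in every worker. A model built this way does not see weight updates made in the parent, so the weights of q_approximator_fixed would have to be passed to the worker explicitly, for example through a file.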