keras: How to parallelize fit_generator? (PicklingError)

I tried several ways but cannot get parallelization of sample/data generation to work. My attempts are collected in the gist below. Am I doing something wrong, or is there a bug?

https://gist.github.com/stmax82/283ef735c8e2601ef841de8b37243ee1

I suppose that my fourth try should be the correct one, but when I set pickle_safe=True, I get the error:

PicklingError: Can't pickle <function generator_queue.<locals>.data_generator_task at 0x000000001B042EA0>: attribute lookup data_generator_task on keras.engine.training failed
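
Roughly, the failing attempt looks like this (a hypothetical minimal sketch; the model and generator are placeholders, and the fit_generator signature is the Keras 1.x one matching the traceback above — see the gist for the actual code):

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Placeholder generator yielding (inputs, targets) batches forever.
def batch_generator():
    while True:
        yield np.random.rand(32, 10), np.random.rand(32, 1)

model = Sequential([Dense(1, input_dim=10)])
model.compile(optimizer='sgd', loss='mse')

# With pickle_safe=True Keras uses multiprocessing instead of threads;
# on Windows this fails with the PicklingError quoted above.
model.fit_generator(batch_generator(), samples_per_epoch=320, nb_epoch=2,
                    nb_worker=2, pickle_safe=True)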

About this issue

  • State: closed
  • Created 8 years ago
  • Reactions: 1
  • Comments: 15 (5 by maintainers)

Most upvoted comments

TL;DR: Do NOT set pickle_safe=True. You’re bound for trouble.

Extensive explanation: I’ve been investigating the way workers are used in the *_generator methods (see issues #5071, #6745). So far, my conclusion is that the way pickle_safe=True is implemented is flawed beyond recovery and should be avoided completely. Here’s what I’ve gathered:

  1. Generators are not picklable, and it seems they won’t be any time soon.
  2. If you’re a Windows user, you’re actually the lucky one, because the code won’t run at all. Since Windows has no fork(), the multiprocessing.Process is created (simplifying heavily) by spawning a whole new Python process, pickling the data the new process needs, and sending it over a pipe together with the other state required to simulate a fork(). (See this article for a more detailed and precise explanation of why that is necessary.)
  3. If you’re a Linux user, your problem is sneakier. The code will run just fine because, thanks to fork(), there is no need to pickle and unpickle the generator. However, an identical, independent clone of the original generator is created and used separately in each child process. This means your data is enqueued once per worker every time a worker calls next() and obtains a new batch.

Take this example:

import multiprocessing
import time

if __name__ == '__main__':
    def my_generator():
        i = 0
        while True:
            i += 1
            yield i

    gen = my_generator()
    queue = multiprocessing.Queue()

    def target_f():
        import os
        import sys
        import time
        # Redirect output so each child writes to its own file.
        sys.stdout = open(str(os.getpid()) + ".out", "w")
        sys.stderr = open(str(os.getpid()) + ".err", "w")
        time.sleep(1)
        # Under fork(), each child holds its own copy of `gen`,
        # so both calls to next() return 1.
        queue.put(next(gen))

    p1 = multiprocessing.Process(target=target_f)
    p2 = multiprocessing.Process(target=target_f)
    p1.start()
    time.sleep(0.3)
    p2.start()
    p1.join()
    p2.join()
    print(queue.qsize())
    print(queue.get())
    print(queue.get())

The output of this code under Linux is:

2
1
1

while under Windows it breaks with the error:

TypeError: can't pickle generator objects
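
That Windows failure is just the underlying pickle limitation (point 1 above) surfacing directly; you can reproduce it with nothing but the standard library:

import pickle

def g():
    yield 1

# Raises TypeError: can't pickle generator objects
# (worded "cannot pickle 'generator' object" on newer Python 3).
pickle.dumps(g())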

Considering all of this, I believe the pickle_safe argument is misleading, wrong, and potentially harmful, and IMO it should be removed altogether. Until then, stick to pickle_safe=False in your code to avoid headaches.
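
If you still need several workers with pickle_safe=False, a common community workaround (not part of Keras itself) is to wrap the generator so that concurrent next() calls from the worker threads are serialized, along these lines:

import threading

class ThreadSafeIterator:
    """Wraps a generator so multiple threads can pull batches
    from it without interleaving its internal state."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

    next = __next__  # Python 2 alias

# Usage (with a hypothetical batch_generator as above):
# model.fit_generator(ThreadSafeIterator(batch_generator()),
#                     samples_per_epoch=320, nb_epoch=2,
#                     nb_worker=4, pickle_safe=False)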