distributed: Possible memory leak when using LocalCluster

What happened: Memory usage of code using da.from_array and compute in a for loop grows over time when using a LocalCluster.

What you expected to happen: Memory usage should be approximately stable (subject to the GC).

Minimal Complete Verifiable Example:

import numpy as np
import dask.array as da
from dask.distributed import Client, LocalCluster


def f(x):
    return np.zeros(x.shape, dtype=x.dtype)


def wrapper(x):
    if not isinstance(x, da.Array):
        x = da.from_array(x, chunks=(1, -1, -1))

    return da.blockwise(f, ('nx', 'ny', 'nz'),
                        x, ('nx', 'ny', 'nz')).compute()


if __name__ == '__main__':

    cluster = LocalCluster(
        processes=True,
        n_workers=4,
        threads_per_worker=1
    )
    client = Client(cluster)

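    nx, ny, nz = 4, 512, 512

    # Each iteration builds and computes a fresh graph from a new NumPy array,
    # so memory usage should stay roughly flat across iterations.
    for i in range(500):
        x = np.random.randn(nx, ny, nz)
        wrapper(x)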

Anything else we need to know?:

The output of mprof run --multiprocess --include-children reproducer.py, visualised with mprof plot, in three configurations:

  • LocalCluster: [Figure: localcluster]
  • threads scheduler: [Figure: threads]
  • LocalCluster, but with x instantiated as da.random.standard_normal((nx, ny, nz), chunks=(1, -1, -1)) so that da.from_array is never called: [Figure: nofromarray]
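To compare such profiles numerically rather than by eye, one option is to fit a least-squares slope to memory samples taken once per loop iteration. The helper below, growth_per_iteration, is hypothetical (it is not part of mprof or dask), and it assumes the samples have already been collected in MiB by whatever profiler is at hand. A slope near zero corresponds to the expected stable usage; a clearly positive slope indicates steady growth.

```python
def growth_per_iteration(samples_mib):
    """Least-squares slope of memory samples, in MiB per loop iteration.

    Hypothetical helper: samples_mib is a list of memory readings (MiB),
    one per iteration, in the order they were taken.
    """
    n = len(samples_mib)
    if n < 2:
        raise ValueError("need at least two samples to estimate a slope")
    mean_x = (n - 1) / 2  # mean of the iteration indices 0..n-1
    mean_y = sum(samples_mib) / n
    num = sum((i - mean_x) * (y - mean_y) for i, y in enumerate(samples_mib))
    den = sum((i - mean_x) ** 2 for i in range(n))
    return num / den


# Invented numbers, for illustration only; not taken from the plots above.
stable = [100.0, 102.0, 99.0, 101.0, 100.0]
growing = [100.0, 110.0, 120.0, 130.0, 140.0]
print(f"stable:  {growth_per_iteration(stable):+.2f} MiB/iteration")
print(f"growing: {growth_per_iteration(growing):+.2f} MiB/iteration")
```

Fitting a slope rather than comparing first and last samples makes the estimate less sensitive to individual garbage-collection dips.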

Note that child 0 should be ignored in the LocalCluster plots, as it is an artifact of the profiling procedure. Memory usage is in fact still climbing in the last plot, but much more slowly than in the first case (using da.from_array). My current best guess is that either the Client or the Cluster is somehow holding a reference to an object that should otherwise have been garbage collected.

Environment:

  • Dask version: 2022.2.1
  • Python version: 3.8.10
  • Operating System: Ubuntu 20.04
  • Install method (conda, pip, source): pip

About this issue

  • State: open
  • Created 2 years ago
  • Comments: 33 (11 by maintainers)

Most upvoted comments

There are currently two PRs (only one of which will be merged) reworking some of the update_who_has flaws, see

These PRs should address the ever-increasing who_has/has_what sets. There might still be an issue with pending_data_per_worker / data_needed_per_worker, but these data structures have changed significantly since this issue was filed, so it is possible that this is no longer a problem.
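The failure mode described here, paired location sets that only ever grow, can be illustrated with a toy model. LocationIndex below is hypothetical and is not distributed's implementation; it only assumes the general shape the names suggest, with who_has mapping a key to the workers holding it and has_what mapping a worker to the keys it holds. If every compute call produces fresh keys and release is never called for them, both mappings grow by one entry per call.

```python
from collections import defaultdict


class LocationIndex:
    """Hypothetical toy model of paired who_has / has_what mappings.

    who_has:  key -> set of workers holding that key
    has_what: worker -> set of keys held by that worker
    Illustration only; this is not distributed's implementation.
    """

    def __init__(self):
        self.who_has = defaultdict(set)
        self.has_what = defaultdict(set)

    def add(self, key, worker):
        # Record the location on both sides so the mappings stay in sync.
        self.who_has[key].add(worker)
        self.has_what[worker].add(key)

    def release(self, key):
        # Drop the key from both mappings; skipping this step is what
        # lets the sets grow with every compute call.
        for worker in self.who_has.pop(key, set()):
            keys = self.has_what.get(worker)
            if keys is None:
                continue
            keys.discard(key)
            if not keys:
                del self.has_what[worker]


def simulate(n_calls, release):
    """Return (number of who_has keys, number of keys held by worker-0)."""
    index = LocationIndex()
    for i in range(n_calls):
        key = f"task-{i}"  # each compute call produces a fresh key
        index.add(key, "worker-0")
        if release:
            index.release(key)
    return len(index.who_has), len(index.has_what.get("worker-0", ()))


print(simulate(1000, release=True))   # bookkeeping ends up empty
print(simulate(1000, release=False))  # one stale entry per call
```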

I suggest rerunning the reproducers here once either of the above PRs is merged.

I am not sure, @zklaus. The “simple” short-term solution is to go back to using the 2022.01.1 versions of dask and distributed. I do not think I could implement a proper fix on my own right now: the surface area is too large and I am too unfamiliar with the innermost workings of the scheduler. This issue does seem to be affecting others, so hopefully a proper fix is coming.

I think https://github.com/dask/distributed/pull/5653 is the culprit. The following example demonstrates that worker.pending_data_per_worker grows with every compute call.

import numpy as np
import dask.array as da
import dask
from dask.distributed import Client, LocalCluster
from pympler import asizeof


def f(x):
    return np.empty_like(x)


def wrapper(x, n_elem_per_chunk):
    if not isinstance(x, da.Array):
        # One-dimensional input split into chunks of n_elem_per_chunk elements.
        x = da.from_array(x, chunks=(n_elem_per_chunk,))

    return da.blockwise(f, ('x',), x, ('x',))


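def check_growth(dask_worker):
    # Executed on each worker via client.run, which passes the Worker
    # instance as dask_worker; prints the size of each pending UniqueTaskHeap.
    for uth in dask_worker.pending_data_per_worker.values():
        uth_size = asizeof.asizeof(uth) / 1024 ** 2
        print(f"Size of UniqueTaskHeap: {uth_size:.2f} MiB.")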


if __name__ == '__main__':

    cluster = LocalCluster(
        processes=True,
        n_workers=1,
        threads_per_worker=1
    )
    client = Client(cluster)

    n_elem_per_chunk = 13107200  # 100MiB of float64.
    n_chunk = 1

    for i in range(10):
        x = np.random.randn(n_elem_per_chunk*n_chunk)

        result = wrapper(x, n_elem_per_chunk)

        dask.compute(result)

        client.run(check_growth)
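The growth reported by check_growth can be pictured with a much smaller model. ToyUniqueTaskHeap below is a hypothetical priority queue that ignores duplicate keys; it is not distributed's UniqueTaskHeap, only a stand-in under the assumption that the real structure de-duplicates entries in a similar way. If each compute call pushes a fresh key for a peer worker and nothing ever pops it, the heap grows by one entry per call, which is the pattern the reproducer above prints. The peer address is made up.

```python
import heapq
from collections import defaultdict


class ToyUniqueTaskHeap:
    """Hypothetical priority queue that keeps at most one entry per key.

    Not distributed's UniqueTaskHeap; a stand-in for illustration.
    """

    def __init__(self):
        self._heap = []      # (priority, key) tuples, smallest priority first
        self._known = set()  # keys currently stored, used for de-duplication

    def push(self, priority, key):
        if key not in self._known:
            heapq.heappush(self._heap, (priority, key))
            self._known.add(key)

    def pop(self):
        _, key = heapq.heappop(self._heap)
        self._known.discard(key)
        return key

    def __len__(self):
        return len(self._heap)


# One heap per peer worker, keyed by a made-up address.
PEER = "tcp://127.0.0.1:40000"
pending = defaultdict(ToyUniqueTaskHeap)

for call in range(10):
    # Each compute call adds a fresh key; in this leaky scenario
    # nothing ever pops it, so the heap keeps growing.
    pending[PEER].push(call, f"array-{call}")
    print(f"call {call}: {len(pending[PEER])} pending entries")
```

Because keys differ on every call, de-duplication cannot help: only removing entries once their tasks are done keeps the structure bounded.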