datasets: Slow dataloading with big datasets issue persists

Hi,

I reported slow data fetching when the data is large (#2210) a couple of weeks ago, and @lhoestq referred me to the fix (#2122). However, the problem seems to persist. Here are the profiling results:

  1. Running with 60GB
Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  517.96         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
model_backward                     	|  0.26144        	|100            	|  26.144         	|  5.0475         	|
model_forward                      	|  0.11123        	|100            	|  11.123         	|  2.1474         	|
get_train_batch                    	|  0.097121       	|100            	|  9.7121         	|  1.8751         	|
  2. Running with 600GB, datasets==1.6.0
Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  4563.2         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
get_train_batch                    	|  5.1279         	|100            	|  512.79         	|  11.237         	|
model_backward                     	|  4.8394         	|100            	|  483.94         	|  10.605         	|
model_forward                      	|  0.12162        	|100            	|  12.162         	|  0.26652        	|

I see that get_train_batch lags when the data is large. Could this be related to a different issue? I would be happy to provide any information needed to investigate.
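
For reference, tables in this format are produced by PyTorch Lightning's simple profiler. Below is a minimal, hypothetical sketch of how such a profile can be generated; the tiny model and random data are placeholders for the real training setup, not what was used for the numbers above.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


# Random stand-in data; in the report above this would be batches from the large dataset.
train_loader = DataLoader(
    TensorDataset(torch.randn(1024, 32), torch.randn(1024, 1)), batch_size=32
)

# profiler="simple" reports per-action timings such as get_train_batch,
# model_forward and model_backward once training stops.
trainer = pl.Trainer(profiler="simple", max_steps=100)
trainer.fit(TinyModel(), train_loader)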

About this issue

  • State: closed
  • Created 3 years ago
  • Reactions: 9
  • Comments: 70 (29 by maintainers)

Most upvoted comments

This has been a very interesting discussion to read. Are there any updates on it? I take it that the best option we have now is to shard our data into multiple datasets and concatenate them as shown above by @hwijeen.

If this solution proves to help, we can add an arrow files sharding for all big datasets directly integrated in load_dataset.
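
A hedged sketch of that manual sharding workaround (the paths and shard count below are illustrative): split the big dataset into smaller datasets once, then memory-map and concatenate them at load time.

from datasets import load_from_disk, concatenate_datasets

NUM_SHARDS = 8

# One-time preparation, e.g. right after building `big_dataset`:
#   for i in range(NUM_SHARDS):
#       big_dataset.shard(num_shards=NUM_SHARDS, index=i).save_to_disk(f"data/shard_{i}")

# At load time, memory-map each shard and concatenate the resulting datasets:
shards = [load_from_disk(f"data/shard_{i}") for i in range(NUM_SHARDS)]
dataset = concatenate_datasets(shards)
print(len(dataset))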

I’m hoping that by using the huggingface Dataset, the data loader will just index into the pyarrow table and the dataset won’t be loaded in full in each process (but we have to pay the cost of the load_data in each process presumably so that the data loader can index into the table on that process)?

Yes, your intuition is right 😃
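
A minimal sketch of that pattern, assuming a dataset saved with save_to_disk (the path and column names are illustrative): each DataLoader worker only reads the rows it is asked for from the memory-mapped Arrow table, so the full dataset is never materialized in RAM.

from datasets import load_from_disk
from torch.utils.data import DataLoader

ds = load_from_disk("path/to/dataset")  # memory-maps the Arrow file(s); nothing is read eagerly
ds = ds.with_format("torch", columns=["input_ids", "attention_mask"])  # illustrative columns

loader = DataLoader(ds, batch_size=32, num_workers=4)
for batch in loader:
    # each __getitem__ call fetches only the requested rows from the memory-mapped table
    break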

Unfortunately no. Thanks for running the benchmark though; it shows that your machine does a lot of read operations. This is not expected: on other machines it does almost no read operations, which enables very fast loading.

I did some tests on Google Colab and have the same issue. The first time the dataset's Arrow file is memory-mapped always takes a lot of time (the time seems linear in the dataset size). Reloading the dataset is then instantaneous, since the Arrow file has already been memory-mapped.

I also tried using the Arrow IPC file format (see #1933) instead of the streaming format that we currently use, but it didn't help.

Memory mapping is handled by the OS and depends on the disk you’re using, so I’m not sure we can do much about it. I’ll continue to investigate anyway, because I still don’t know why in some cases it would go through the entire file (high Blocks read as in your tests) and in other cases it would do almost no reading.

I managed to speed up the loading time (on the Lustre file system) by memory-mapping the Arrow shards in parallel (python preload_mmap.py, see the script below) and relying on the OS to cache the mapped pages.

Here are some results:

  1. Without caching (on a fresh node):
# Sequentially using datasets.load_from_disk
 Loading dataset_name: 1865.329 seconds
  2. After calling python preload_mmap.py once to cache (the first time, it takes around 90 seconds with 16 processes):
# return only the length of the table from each worker
# python preload_mmap.py p # use processes and return only len of table from workers
Loading dataset_name using num of (returning len) processes=16: 42.837 seconds
# python preload_mmap.py t  # use threads and return only len of table from workers
Loading dataset_name using num of (returning len) threads=16: 105.167 seconds

# return the whole table from each worker
# python preload_mmap.py p table # use processes and return tables from workers
Loading dataset_name using num of (returning table) processes=16: 367.917 seconds
# python preload_mmap.py t  table # use threads and return tables from workers
Loading dataset_name using num of (returning table) threads=16: 260.434 seconds

# Sequentially using datasets.load_from_disk (the dataset has only one split)
Loading dataset_name: 397.046 seconds

It seems that preloading the files in processes (without returning the table) speeds up subsequent load_from_disk calls. However, the communication time to return the tables for concatenation is high (I am not sure how they are pickled).

Threads are slower to mmap the table but faster to communicate. If this works on other file systems, it may be worth it to have the option to load the shards in parallel here.

# preload_mmap.py
import datasets
import os
from datasets.table import MemoryMappedTable, concat_tables
import glob
import logging
from time import perf_counter
from multiprocessing import Pool
import sys
import concurrent.futures
import functools

logger = datasets.logging.get_logger(__name__)


datasets.logging.set_verbosity_info()



class catchtime:
    # context to measure loading time: https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time
    def __init__(self, debug_print="Time", logger=logger):
        self.debug_print = debug_print
        self.logger = logger

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        self.time = perf_counter() - self.start
        readout = f"{self.debug_print}: {self.time:.3f} seconds"
        self.logger.info(readout)


def load_file(f, return_len=False):
    with catchtime(f"Loading {f}", logger=logger):
        ds = MemoryMappedTable.from_file(f)
    if return_len:  # process pool is slow to serialize
        return len(ds)
    return ds


def load_files(
    files, debug_name="dataset_name", num_proc=16, use_threads=False, return_len=False
):
    if use_threads:
        pool_cls = concurrent.futures.ThreadPoolExecutor
        pool_kwargs = {"max_workers": num_proc}
        debug_desc = "threads"
    else:
        pool_cls = Pool
        pool_kwargs = {"processes": num_proc}
        debug_desc = "processes"
    if return_len:
        debug_desc = "(returning len) " + debug_desc
    else:
        debug_desc = "(returning table) " + debug_desc
    with catchtime(
        f"Loading {debug_name} using num of {debug_desc}={num_proc}", logger=logger
    ):
        with pool_cls(**pool_kwargs) as pool:
            result = list(
                pool.map(functools.partial(load_file, return_len=return_len), files)
            )
    return result


def main(use_threads, return_len):
    datasets.logging.set_verbosity_info()
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    logger.info("Starting")

    jc = "datset_name"
    local = "dataset_path"
    split = "train"
    files = glob.glob(os.path.join(local, jc, split, "*.arrow"))
    files = sorted(files)

    ds = load_files(files, jc, use_threads=use_threads, return_len=return_len)
    if not return_len:
        with catchtime(f"concat_ tables"):
            ds = concat_tables(ds)

    logger.info("done")


if __name__ == "__main__":
    use_threads = False
    return_len = True

    print(
        "Usage: \n Threads: python preload_mmap.py t \n Threads and concatenate tables: python preload_mmap.py t table"
        "\n Processes: python preload_mmap.py p \n Processes and concatenate tables: python preload_mmap.py p table"
    )
    if len(sys.argv) > 1 and sys.argv[1].startswith("t"):
        use_threads = True
    if len(sys.argv) > 2:
        return_len = False

    main(use_threads, return_len)

I’m very happy that using this approach brought my data loading time down from 15 minutes to 45 seconds. My dataset is on the TB scale.

My file system is virtio-fs. I can’t know the real underlying file system because my code runs inside a virtual machine and I can’t access the host. I guess it is a distributed file system similar to Lustre. I don’t know why it is so slow, or why multithreading accelerates it, but it really does accelerate the load time, and that is vital for me. Thanks for your code.

I’m facing the same issue when loading a 900GB dataset (stored via save_to_disk): load_from_disk(path_to_dir) takes 1.5 hours and htop consistently shows high IO rates > 120 M/s.

Nice to see this method validated on multiple setups!

Would be cool to integrate multithreading when memory-mapping the Arrow files. I think this can be added here (for load_dataset):

https://github.com/huggingface/datasets/blob/796a47e388a5c5711a95bd649648608c18219ac5/src/datasets/arrow_reader.py#L199-L201

and here (for load_from_disk):

https://github.com/huggingface/datasets/blob/796a47e388a5c5711a95bd649648608c18219ac5/src/datasets/arrow_dataset.py#L1701-L1704

I can take some time next week to do it, but feel free to open a PR if you want to give it a try
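
A rough sketch of what such an integration could look like (this is not the actual datasets internals; the helper name mmap_tables_in_parallel is hypothetical and simply mirrors the preload script above, using threads):

from concurrent.futures import ThreadPoolExecutor

from datasets.table import MemoryMappedTable, concat_tables


def mmap_tables_in_parallel(arrow_files, max_workers=16):
    # Memory-map every Arrow shard in its own thread, then concatenate the tables.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        tables = list(pool.map(MemoryMappedTable.from_file, arrow_files))
    return concat_tables(tables)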

Can you guys reproduce the problem? Testing just load_dataset is enough, I think. In my example, load_dataset read 25% of the cached data (9/40), which is also abnormal. I’m not sure what the metadata’s data structure is, but I would expect it to be only a few KB (if it is a length, then 8 bytes per shard?).

Reproducing these issues is not easy on our side, given they depend on the setup.

For load_dataset it would be nice to be able to control the size of the batches written to disk; feel free to open an issue if that’s something you’d like to see, and we’ll discuss there how to do it.

I also faced the slow mmap when loading a large dataset (~2TB) using datasets.load_from_disk. The dataset was saved with around ~1100 shards of 2GB each. It seems to depend on the file system. I have access to two clusters:

That’s helpful information, thanks ! It seems like Lustre doesn’t read at full speed with the memory mapping in datasets

Are there any recommendations for the Lustre stripe size or stripe count, or for the shard size, to improve the loading speed?

I would try increasing the stripe size, in case the memory mapping does too much unnecessary readahead with the default value.

takes ~1h 30min, although I already cached it.

An alternative is to load the dataset as iterable but this is not implemented yet, see #5481

Is there any way I can instantiate the (also already cached) tokenized dataset directly without having to wait until raw_datasets is instantiated first?

If you want to skip that step, next time I’d recommend you to save the dataset somewhere after tokenization (e.g. using .save_to_disk()) and reload it from there instead of relying on the cache.

Though you could look for the cached arrow files in your cache and reload the data from there if you’re adventurous. You can use Dataset.from_file to reload a file, and then concatenate_datasets to concatenate all the chunks.
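
A hedged sketch of that adventurous route (the cache path pattern and dataset name are illustrative; the actual cache files have to be located manually):

import glob
import os

from datasets import Dataset, concatenate_datasets

# Locate the cached Arrow chunks produced by the tokenization step (pattern is illustrative).
pattern = os.path.expanduser("~/.cache/huggingface/datasets/my_dataset/**/cache-*.arrow")
cache_files = sorted(glob.glob(pattern, recursive=True))

chunks = [Dataset.from_file(f) for f in cache_files]  # each file is memory-mapped, not loaded in RAM
tokenized_dataset = concatenate_datasets(chunks)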

Cool !

But I’m curious why the dataset is so slow. Why is this happening? I do some preprocessing with the Dataset.map() function and it is quite fast: for 34GB of text, it takes around 1.5 hours to process the data (tokenization, chunk merging, etc.) with 6 workers. But once I use this preprocessed dataset, iteration is roughly 10x slower, e.g. 2.2 it/s vs 5 s/batch.

When you process an unshuffled dataset with map, you iterate over contiguous chunks of data, which is very fast. You get the best speed when you have an iterable dataset as well, when it’s based on shards of contiguous data.

This is fast because internally Arrow simply iterates over the record batches.

On the other hand, if you use a map-style dataset in PyTorch, then PyTorch samples uniformly from the files on your disk. This is slower for your disk, and also requires an extra step to get the location of the examples from an index.

Please also make sure to use the latest version of pyarrow to benefit from the best speed, or at least pyarrow 8.0.0 😃
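
A hedged sketch contrasting the two access patterns described above, using to_iterable_dataset (available in more recent datasets versions than the ones discussed earlier in this thread); the path and shard count are illustrative.

from datasets import load_from_disk
from torch.utils.data import DataLoader

ds = load_from_disk("path/to/tokenized_dataset")

# Map-style: shuffled indices hit random positions in the Arrow file (slow on disk).
map_style_loader = DataLoader(ds.with_format("torch"), batch_size=32, shuffle=True)

# Iterable-style: shards are read as contiguous record batches; shuffling uses a buffer.
iterable_ds = ds.to_iterable_dataset(num_shards=64).shuffle(buffer_size=10_000, seed=42)
streaming_loader = DataLoader(iterable_ds.with_format("torch"), batch_size=32)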

By any chance, do we have a better understanding of what’s happening?

I am encountering a similar problem: I have an Arrow file produced by HF datasets (shard + save_to_disk) and I am trying to load this dataset/arrow file with datasets.load_from_disk(the_dataset_folder). I noticed that the first time I load it, it is significantly slower than the subsequent times. Two days later, I will retry loading it, and it will be slow again…

After diving in a little bit, the gap happens in the _memory_mapped_arrow_table_from_file function, and in particular in the call to RecordBatchStreamReader.read_all: https://github.com/huggingface/datasets/blob/158917e24128afbbe0f03ce36ea8cd9f850ea853/src/datasets/table.py#L51. read_all is slow the first time (probably because of some operations that only happen once and are cached for a few hours?), but not on subsequent calls.

>>> import time
>>> import pyarrow as pa
>>> def _memory_mapped_arrow_table_from_file(filename):
...     memory_mapped_stream = pa.memory_map(filename)
...     opened_stream = pa.ipc.open_stream(memory_mapped_stream)
...     start_time = time.time()
...     _ = opened_stream.read_all()
...     print(f"{time.time()-start_time}")
...
>>> filename_slow = "train/00248-00249/cache-3d25861de64b93b5.arrow"
>>> _memory_mapped_arrow_table_from_file(filename_slow) # First time
0.24040865898132324
>>> _memory_mapped_arrow_table_from_file(filename_slow) # subsequent times
0.0006551742553710938
>>> _memory_mapped_arrow_table_from_file(filename_slow)
0.0006804466247558594
>>> _memory_mapped_arrow_table_from_file(filename_slow)
0.0009818077087402344

My setup:

  • datasets version: 2.3.3.dev0
  • Platform: Linux-4.18.0-305.57.1.el8_4.x86_64-x86_64-with-glibc2.17
  • Python version: 3.8.13
  • PyArrow version: 9.0.0
  • Pandas version: 1.4.2

I realize this might be an Apache Arrow question so I ask them, but wanted to leave a message here too.

I wasn’t able to reproduce this on a toy dataset of around 300GB:

import datasets as ds

s = ds.load_dataset("squad", split="train")
s4000 = ds.concatenate_datasets([s] * 4000)
print(ds.utils.size_str(s4000.data.nbytes))  # '295.48 GiB'

s4000.save_to_disk("tmp/squad_4000")
import psutil
import time
from datasets import load_from_disk

disk = "disk0"  # You may have to change your disk here
iocnt1 = psutil.disk_io_counters(perdisk=True)[disk]
time1 = time.time()

s4000_reloaded = load_from_disk("tmp/squad_4000")

time2 = time.time()
iocnt2 = psutil.disk_io_counters(perdisk=True)[disk]

print(f"Blocks read {iocnt2.read_count - iocnt1.read_count}")  # Blocks read 18
print(f"Elapsed time: {time2 - time1:.02f}s")  # Elapsed time: 14.60s

Could you run this on your side and tell me how much time it takes? Please run it when your machine is idle so that other processes don’t interfere.

I got these results on my MacBook Pro with datasets 1.6.2.