datasets: Iterating over a DataLoader backed by HF datasets hangs forever

Describe the bug

I am using an Amazon SageMaker notebook (Amazon Linux 2) with a Python 3.10 based conda environment. I have a dataset in Parquet format stored locally. When I try to iterate over it, the DataLoader hangs forever. Note that the same code works seamlessly in a Python 3.6 based conda environment. What should my next steps be?

Steps to reproduce the bug

import time

from datasets import load_dataset
from torch.utils.data import DataLoader

train_dataset = load_dataset(
    "parquet",
    data_files={'train': tr_data_path + '*.parquet'},
    split='train',
    streaming=True,
).with_format('torch')

# collate_fn is a DataLoader argument, not a load_dataset argument
train_dataloader = DataLoader(
    train_dataset, batch_size=2, num_workers=0,
    collate_fn=streaming_data_collate_fn,
)

# Time how long it takes to pull 1000 batches; on the affected
# environment this loop never makes progress.
t = time.time()
iter_ = 0
for batch in train_dataloader:
    iter_ += 1
    if iter_ == 1000:
        break

print(time.time() - t)

Expected behavior

The snippet should run normally, iterating through 1000 batches and printing the elapsed time.

Environment info

datasets: 2.14.0
pyarrow: 12.0.0
torch: 2.0.0
Python: 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 20:08:06) [GCC 11.3.0]

!uname -r
5.10.178-162.673.amzn2.x86_64

About this issue

  • State: closed
  • Created a year ago
  • Comments: 15 (6 by maintainers)

Most upvoted comments

Update: It seems the hang is more of a torch multiprocessing issue combined with something host specific. I had to scour the relevant PyTorch issues and stumbled across these threads:

  1. https://github.com/pytorch/pytorch/issues/102494
  2. https://github.com/pytorch/pytorch/issues/102269
  3. https://github.com/pytorch/pytorch/issues/99625

Out of the suggested solutions, the one that worked in my case was:

os.environ['KMP_AFFINITY'] = "disabled"

It is working for now, though I have no clue why. I just hope it does not get stuck when I do actual model training; I will update by tomorrow.
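
For completeness, a minimal sketch of how the workaround can be applied; placing it before the torch import is an assumption on my part, based on OpenMP reading KMP_AFFINITY when its runtime initializes:

import os

# Assumption: set this before importing torch so the Intel OpenMP
# runtime picks it up at initialization rather than after.
os.environ['KMP_AFFINITY'] = "disabled"

import torch
from torch.utils.data import DataLoader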

collate_fn is applied after the torch formatting step, so I think the only option when working with an IterableDataset is to remove the with_format call and perform the conversion from Python values to PyTorch tensors inside collate_fn itself. A regular (map-style) Dataset supports with_format("numpy"), which should make this conversion faster.
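
To make that concrete, here is a minimal sketch, reusing tr_data_path from the snippet above and assuming the parquet columns hold numeric values; the body of streaming_data_collate_fn is illustrative, not the author's actual implementation:

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

def streaming_data_collate_fn(batch):
    # Without with_format, each streamed example arrives as a plain
    # Python dict; stack each column of the batch into a tensor here.
    return {key: torch.tensor([example[key] for example in batch])
            for key in batch[0]}

train_dataset = load_dataset(
    "parquet",
    data_files={'train': tr_data_path + '*.parquet'},
    split='train',
    streaming=True,
)  # note: no .with_format('torch')

train_dataloader = DataLoader(
    train_dataset, batch_size=2, num_workers=0,
    collate_fn=streaming_data_collate_fn,
)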