ignite: This is breaking my IterableDataset use case: is it a bug or intended?

Looking at the following line of code: https://github.com/pytorch/ignite/blob/523798c13a9465f4c95565cb54ae3c5aa0786ac6/ignite/engine/engine.py#L742 It breaks out of the loop without resetting the dataloader (i.e., without calling self.set_data(self.state.dataloader)), which breaks all of my use cases where my own dataloader implementation raises StopIteration only once. Is this a bug or intended?
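For illustration (not from the original report), here is a minimal sketch of the kind of one-shot data source being described: its iterator is created only once, so after the first full pass it stays exhausted, and a second epoch sees no data unless someone re-installs it. OneShotLoader is a hypothetical stand-in, not the reporter's actual code.

class OneShotLoader:
    """Hypothetical stand-in for a dataloader that can only be consumed once."""
    def __init__(self, data):
        self._it = iter(data)  # single underlying iterator, never re-created

    def __iter__(self):
        return self._it        # NOT a fresh iterator on each call

loader = OneShotLoader(range(8))
print(list(loader))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(list(loader))  # [] -- already exhausted, so a second pass gets nothing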

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 15

Most upvoted comments

I would like to, but there is some proprietary code I cannot share. I'll look into it on my end. Also, let me know if you find something! Thanks a lot for helping.

@vfdev-5 I haven’t had time to work on this. Feel free to close the issue for now.

@snie2012 here is working code using DDP with the "gloo" backend on 4 processes, 1 node:

Code
import os

import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader, IterableDataset

from ignite.engine import Engine


class MyIterableDataset(IterableDataset):
    def __init__(self, offset):
        super().__init__()
        self.start = offset
        self.end = offset + 7  # unknown to the user

    def __iter__(self):
        return iter(range(self.start, self.end))

if __name__ == "__main__":
    import ignite
    print(ignite.__version__)

    device = "cpu"    
    local_rank = os.environ["LOCAL_RANK"]
    
    dist.init_process_group("gloo", init_method="env://")
    
    rank = dist.get_rank()
    ws = dist.get_world_size()
    
    time.sleep(rank * 0.01)
    print("Dist Info: ", rank, ws)
    dataset = MyIterableDataset(offset=rank * ws * 10)    
    data_loader = DataLoader(dataset, num_workers=0, batch_size=4)
    
    model = nn.Linear(10, 10).to(device)
    model = nn.parallel.DistributedDataParallel(model)
    
    opt = optim.SGD(model.parameters(), lr=0.01)

    def foo(e, b):
        # dummy training step on random data
        opt.zero_grad()
        x = torch.rand(10).to(device)
        y_pred = model(x)
        loss = y_pred.sum()
        loss.backward()
        opt.step()

        # stagger prints so per-rank output does not interleave
        time.sleep(rank * 0.01)
        print("{}:: {}-{}: {}".format(rank, e.state.epoch, e.state.iteration, b))

    engine = Engine(foo)
    # epoch_length=None: determined automatically when the iterable is first exhausted
    engine.run(data_loader, epoch_length=None, max_epochs=5)

    dist.destroy_process_group()    
Output:
$ python -m torch.distributed.launch --nproc_per_node=4 --use_env issue-1094-iterable-dataset-DDP.py 

*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
0.4.0.dev20200603
0.4.0.dev20200603
0.4.0.dev20200603
0.4.0.dev20200603
Dist Info:  0 4
Dist Info:  1 4
Dist Info:  2 4
Dist Info:  3 4
0:: 1-1: tensor([0, 1, 2, 3])
1:: 1-1: tensor([40, 41, 42, 43])
2:: 1-1: tensor([80, 81, 82, 83])
3:: 1-1: tensor([120, 121, 122, 123])
0:: 1-2: tensor([4, 5, 6])
1:: 1-2: tensor([44, 45, 46])
2:: 1-2: tensor([84, 85, 86])
3:: 1-2: tensor([124, 125, 126])
0:: 2-3: tensor([0, 1, 2, 3])
1:: 2-3: tensor([40, 41, 42, 43])
2:: 2-3: tensor([80, 81, 82, 83])
3:: 2-3: tensor([120, 121, 122, 123])
0:: 2-4: tensor([4, 5, 6])
1:: 2-4: tensor([44, 45, 46])
2:: 2-4: tensor([84, 85, 86])
3:: 2-4: tensor([124, 125, 126])
0:: 3-5: tensor([0, 1, 2, 3])
1:: 3-5: tensor([40, 41, 42, 43])
2:: 3-5: tensor([80, 81, 82, 83])
3:: 3-5: tensor([120, 121, 122, 123])
0:: 3-6: tensor([4, 5, 6])
1:: 3-6: tensor([44, 45, 46])
2:: 3-6: tensor([84, 85, 86])
3:: 3-6: tensor([124, 125, 126])
0:: 4-7: tensor([0, 1, 2, 3])
1:: 4-7: tensor([40, 41, 42, 43])
2:: 4-7: tensor([80, 81, 82, 83])
3:: 4-7: tensor([120, 121, 122, 123])
0:: 4-8: tensor([4, 5, 6])
1:: 4-8: tensor([44, 45, 46])
2:: 4-8: tensor([84, 85, 86])
3:: 4-8: tensor([124, 125, 126])
0:: 5-9: tensor([0, 1, 2, 3])
1:: 5-9: tensor([40, 41, 42, 43])
2:: 5-9: tensor([80, 81, 82, 83])
3:: 5-9: tensor([120, 121, 122, 123])
0:: 5-10: tensor([4, 5, 6])
1:: 5-10: tensor([44, 45, 46])
2:: 5-10: tensor([84, 85, 86])
3:: 5-10: tensor([124, 125, 126])
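For completeness, one possible user-side workaround for a one-shot loader like the one described in the question (a sketch, not something proposed in this thread): re-install a fresh data source at every epoch boundary via Engine.set_data, so the engine never keeps pulling from an exhausted iterator. make_loader below is a hypothetical factory standing in for whatever rebuilds the real dataloader.

from ignite.engine import Engine, Events

def make_loader():
    # hypothetical factory: returns a fresh one-shot iterator on each call
    return iter(range(8))

def step(engine, batch):
    pass  # real training step goes here

engine = Engine(step)

@engine.on(Events.EPOCH_COMPLETED)
def reset_data(engine):
    # replace the exhausted iterator before the next epoch starts
    engine.set_data(make_loader())

engine.run(make_loader(), epoch_length=4, max_epochs=3)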