DALI: Iterator doesn't give data out after the first epoch when using GenericIterator and ExternalSource.
Hey, I’m facing an issue with my pipeline, which currently takes values from the ExternalSource operator. My issue is very similar to the one seen in issue #1134 here. I see that the error was fixed in #1136 for the case where the data is exhausted. I have a DALIGenericIterator that wraps the pipeline, with auto_reset set to True and a last_batch_policy. I have tried raising StopIteration within iter_setup as mentioned in those issues. Apologies that some parts look like pseudocode; I’ve tried to condense the main parts.
class MultiTaskPipeline(Pipeline):
    """DALI pipeline combining a file reader with two external inputs.

    Images and labels come from ``ops.FileReader``; per-sample crop
    descriptors and identifiers are fed through two ``ops.ExternalSource``
    operators from ``external_data`` in :meth:`iter_setup`.

    Args:
        file_list: File list handed to ``ops.FileReader``.
        std_train_config: Training configuration (not read in the visible
            part of ``__init__``).
        batch_size: Batch size forwarded to ``Pipeline``.
        num_threads: Number of worker threads forwarded to ``Pipeline``.
        device_id: GPU id forwarded to ``Pipeline``.
        external_data: Iterable yielding the crop/identifier batches. Its
            iterator must raise ``StopIteration`` at the end of every epoch,
            otherwise :meth:`iter_setup` can never signal the epoch boundary.
    """

    def __init__(self, file_list, std_train_config, batch_size, num_threads, device_id, external_data):
        super(MultiTaskPipeline, self).__init__(batch_size, num_threads, device_id)
        self.input = ops.FileReader(file_list=file_list, random_shuffle=False)
        self.crops = ops.ExternalSource()
        self.identifier = ops.ExternalSource()
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        # NOTE(review): self.image_size is not assigned anywhere in the
        # visible code -- confirm it is set before this line runs.
        self.res = ops.Resize(
            resize_x=self.image_size,
            resize_y=self.image_size,
            interp_type=types.INTERP_TRIANGULAR,
            device="gpu",
        )
        self.external_data = external_data
        self.iterator = iter(external_data)
        ####### other init functions of DALI operators

    def define_graph(self):
        """Build the graph; return (images, labels, crops, identifiers)."""
        jpegs, labels = self.input(name='reader')
        # Keep the ExternalSource output nodes: iter_setup feeds data into them.
        self.crop_dims = self.crops()
        self.identifier_batch = self.identifier()
        # Split crop_dims along axis 1 into two 2-element pieces and flatten.
        # presumably consumed by the elided transformations below -- verify.
        anchor = fn.reshape(fn.slice(self.crop_dims, 0, 2, axes=[1]), shape=[-1])
        shape = fn.reshape(fn.slice(self.crop_dims, 2, 2, axes=[1]), shape=[-1])
        anchor = self.cast(anchor)
        shape = self.cast(shape)
        # Reuse the decoder created in __init__ (same device and output type)
        # instead of instantiating a second, identical operator here.
        jpegs = self.decode(jpegs)
        jpegs = self.res(jpegs)
        #other transformations applied
        jpegs = self.swapaxes(jpegs)
        labels = labels.gpu()
        labels = self.to_int64(labels)
        return (jpegs, labels, self.crop_dims, self.identifier_batch)

    def iter_setup(self):
        """Feed the next external batch into both ExternalSource nodes.

        Raises:
            StopIteration: when ``external_data`` is exhausted. The iterator
                is rewound first so the next epoch starts from the beginning.
        """
        try:
            # Only next() can signal the end of the epoch; keep the try minimal.
            # NOTE(review): takes the first element of the yielded item and
            # unpacks it into (crops, identifier) -- confirm against the
            # generator's actual output layout.
            crops, identifier = list(next(self.iterator))[0]
        except StopIteration:
            # Rewind for the next epoch and propagate the signal untouched.
            # Feeding a batch here would consume data belonging to the next
            # epoch and shift it out of step with the file reader.
            self.iterator = iter(self.external_data)
            raise
        self.feed_input(self.crop_dims, crops)
        self.feed_input(self.identifier_batch, identifier)
# Generators that supply the crop/identifier batches for each split.
train_multi_iter = SampleIteratorClass(data, train_batch_size, std_train_config)
val_multi_iter = SampleIteratorClass(val)

# DALI pipelines. MultiTaskPipeline takes std_train_config as its second
# positional argument; leaving it out shifts every later argument by one and
# leaves device_id unfilled.
train_pipe = MultiTaskPipeline(
    train_file_list,
    std_train_config,
    int(train_batch_size),
    4,
    gpu_id,
    external_data=train_multi_iter,
)
# NOTE(review): the validation pipeline reuses train_batch_size -- confirm
# that this is intended.
val_pipe = MultiTaskPipeline(
    val_file_list,
    std_train_config,
    int(train_batch_size),
    4,
    gpu_id,
    external_data=val_multi_iter,
)

# NOTE(review): both loaders receive `multi_pipe`, which is not defined in the
# visible code, while train_pipe / val_pipe are never passed anywhere --
# confirm which pipeline each loader should wrap.
train_loader = CustomGenericIterator(
    std_train_config,
    trainset.dataset_len,
    pipelines=multi_pipe,
    output_map=["images", "labels", "crops", "identifier"],
    reader_name="reader",
    auto_reset=True,
    last_batch_policy=LastBatchPolicy.PARTIAL,
    dynamic_shape=True,
)
val_loader = CustomGenericIterator(
    std_train_config,
    valset.dataset_len,
    pipelines=multi_pipe,
    output_map=["images", "labels", "crops", "identifier"],
    reader_name="reader",
    auto_reset=True,
    last_batch_policy=LastBatchPolicy.PARTIAL,
    dynamic_shape=True,
)
It just seems to suspend after the first epoch. Kindly help me out here. Thanks in advance!
About this issue
- Original URL
- State: closed
- Created 3 years ago
- Comments: 17 (9 by maintainers)
@JanuszL That did the trick! Thanks a lot! I couldn't figure out how to properly propagate and throw the exception.