torchgeo: Trainers: num_workers > 0 results in pickling error on macOS/Windows

I have a problem running torchgeo on the EuroSAT MS (multispectral) TIF images. I wrote some simple code:

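import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import EuroSAT, stack_samples
from torchgeo.trainers import ClassificationTask

# euro_root: path to the local EuroSAT directory (defined elsewhere)
eurosat = EuroSAT(euro_root, split="train", download=False)
# RandomGeoSampler expects a GeoDataset; EuroSAT is a non-geospatial (chip) dataset
# sampler = RandomGeoSampler(eurosat, size=64, length=10000)
# dataloader = DataLoader(eurosat, batch_size=128, sampler=sampler, collate_fn=stack_samples)
# Standalone loader; trainer.fit() below uses the datamodule's own dataloaders instead
dataloader = DataLoader(eurosat, batch_size=128, collate_fn=stack_samples)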
num_classes = 10
channels = 13
num_workers = 4
batch_size = 4
backbone = "resnet50"
weights = "imagenet"
lr = 0.01
lr_schedule_patience = 5
epochs = 50
datamodule = EuroSATDataModule(
    root_dir=euro_root,
    batch_size=batch_size,
    num_workers=num_workers,
)
task = ClassificationTask(
    classification_model=backbone,
    weights=weights,
    num_classes=num_classes,
    in_channels=channels,
    loss="ce",
    learning_rate=lr,
    learning_rate_schedule_patience=lr_schedule_patience
)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    save_last=True,
)
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
)

# Train
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=epochs
)
trainer.fit(model=task, datamodule=datamodule)

But I get this error on the trainer.fit line:

.........
_pickle.PicklingError: Can't pickle <class 'nn.BatchNorm2d.BatchNorm2d'>: import of module 'nn.BatchNorm2d' failed
..........
OSError: [Errno 22] Invalid argument: 'C:\\Users\\mohsc\\PycharmProjects\\pythonProject\\<input>'

Am I missing a preparation step or something? Any help would be appreciated. By the way, the tutorials still lack essential sample code and documentation for cases like this one with MS images. Thanks!

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 33 (7 by maintainers)

Most upvoted comments

This is likely a multiprocessing issue. The reason @calebrob6 can’t reproduce this is that multiprocessing uses a different default start method on macOS/Windows than on Linux: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods

If I’m correct, I should be able to reproduce this on macOS. Let me give it a shot.
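The difference in start methods can be sketched with the standard library alone, no torchgeo involved. Under spawn (the default on Windows and, since Python 3.8, on macOS), each worker starts a fresh interpreter, re-imports the main module, and receives its work by pickling, so anything sent to a worker must be importable by name and top-level code must sit behind an if __name__ == "__main__": guard. Linux has traditionally used fork, where workers inherit the parent's memory instead. The <input> in the OSError path above appears to indicate the code was run from an interactive console, whose main "file" a spawned worker cannot re-import; that is an inference, not something confirmed in the thread. The helper names below are hypothetical.

```python
import multiprocessing as mp
import pickle


def square(x: int) -> int:
    # Module-level function: a "spawn" worker can re-import it by name.
    return x * x


def is_picklable(obj) -> bool:
    # Under "spawn", everything sent to a worker must survive pickle.dumps().
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


def run_with_spawn(values):
    # Force the start method Windows (and macOS since Python 3.8) uses by default.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        return pool.map(square, values)


if __name__ == "__main__":
    # Spawned workers import this file under a different module name, so this
    # block runs only in the parent process and is not re-executed per worker.
    print("default start method:", mp.get_start_method())
    print(is_picklable(square), is_picklable(lambda x: x))  # True False
    print(run_with_spawn([1, 2, 3]))  # [1, 4, 9]
```

Applied to the original post, the same pattern would mean saving the code as a .py file and moving everything up to and including trainer.fit(...) into a function that is called only under the guard.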

I will try on Linux and report back.

A bit obvious, but this error now also happens in the trainers tutorial for folks running macOS (and probably also Windows).

This seems to work:

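# checkpoint_callback, early_stopping_callback, csv_logger, experiment_dir and
# in_tests are defined earlier in the trainers tutorial
trainer = pl.Trainer(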
    accelerator="mps",
    devices=1,
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=[csv_logger],
    default_root_dir=experiment_dir,
    min_epochs=1,
    max_epochs=10,
    fast_dev_run=in_tests
)

There are some warnings, though:

/Users/calkoen/miniconda3/envs/torchgeo/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
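The warning suggests this working configuration loaded data with few or no worker processes. If falling back to in-process loading is acceptable, one hedged way to make that choice explicit is a small helper that keeps worker processes only under fork. choose_num_workers is a hypothetical name, not part of torchgeo or Lightning, and the fallback trades speed for robustness (the warning above is the cost).

```python
import multiprocessing as mp
import os
from typing import Optional


def choose_num_workers(requested: int, start_method: Optional[str] = None) -> int:
    """Hypothetical helper: keep `requested` workers only under "fork".

    Under "spawn" and "forkserver", workers do not inherit the parent's memory,
    so the dataset has to be pickled; that is the setting in which this issue
    appears, so fall back to loading data in the main process (0 workers).
    """
    method = start_method or mp.get_start_method()
    if method != "fork":
        return 0
    # Never ask for more workers than there are CPUs.
    return min(requested, os.cpu_count() or 1)
```

With the code from the original post, it would be passed as num_workers=choose_num_workers(4) to EuroSATDataModule.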

I can confirm now: it runs OK on Linux.

The thing that’s odd to me is that multiprocessing (and therefore pickling) only happens within the data loader, but there is no batch norm in the dataset/data module. It’s almost like it’s trying to pickle the ResNet inside the dataset for some reason…
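One way to test that hypothesis: under spawn, a DataLoader with num_workers > 0 pickles its dataset to send it to each worker, so anything the dataset references travels with it. The sketch below uses only the standard library; find_unpicklable_attributes and ToyDataset are hypothetical names for illustration. It tries to pickle each attribute of an object in the main process and reports the ones that fail. It only inspects objects that have a __dict__, and only one level deep.

```python
import pickle
from typing import Any, Dict


def find_unpicklable_attributes(obj: Any) -> Dict[str, str]:
    """Return {attribute name: error} for attributes of obj that fail to pickle."""
    failures: Dict[str, str] = {}
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # pickling can raise several exception types
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


class ToyDataset:
    # Hypothetical stand-in for a dataset; not torchgeo's EuroSAT.
    def __init__(self) -> None:
        self.root = "data"
        self.transform = lambda sample: sample  # lambdas cannot be pickled by reference


if __name__ == "__main__":
    for name, error in find_unpicklable_attributes(ToyDataset()).items():
        print(name, "->", error)
```

Calling it on the EuroSAT dataset or the datamodule from the original post, before trainer.fit, would show whether anything unexpected, such as a model, is attached to the data pipeline.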