torchgeo: Trainers: num_workers > 0 results in pickling error on macOS/Windows
I have a problem running TorchGeo on the EuroSAT MS (multispectral) TIFF images. I wrote this simple script:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import EuroSAT, stack_samples
from torchgeo.trainers import ClassificationTask

euro_root = "data/eurosat"  # hypothetical path: point this at your EuroSAT download

eurosat = EuroSAT(euro_root, split="train", download=False)
# sampler = RandomGeoSampler(eurosat, size=64, length=10000)
# dataloader = DataLoader(eurosat, batch_size=128, sampler=sampler, collate_fn=stack_samples)
dataloader = DataLoader(eurosat, batch_size=128, collate_fn=stack_samples)

num_classes = 10
channels = 13  # EuroSAT MS has 13 Sentinel-2 bands
num_workers = 4  # any value > 0 triggers the error below on macOS/Windows
batch_size = 4
backbone = "resnet50"
weights = "imagenet"
lr = 0.01
lr_schedule_patience = 5
epochs = 50

datamodule = EuroSATDataModule(
    root_dir=euro_root,
    batch_size=batch_size,
    num_workers=num_workers,
)
task = ClassificationTask(
    classification_model=backbone,
    weights=weights,
    num_classes=num_classes,
    in_channels=channels,
    loss="ce",
    learning_rate=lr,
    learning_rate_schedule_patience=lr_schedule_patience,
)

# Keep only the best checkpoint (by validation loss) plus the last one
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    save_last=True,
)
# Stop if validation loss has not improved for 10 epochs
early_stopping_callback = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
)

# Train
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=epochs,
)
trainer.fit(model=task, datamodule=datamodule)
But I get this error on the trainer.fit(...) line:
.........
_pickle.PicklingError: Can't pickle <class 'nn.BatchNorm2d.BatchNorm2d'>: import of module 'nn.BatchNorm2d' failed
..........
OSError: [Errno 22] Invalid argument: 'C:\\Users\\mohsc\\PycharmProjects\\pythonProject\\<input>'
Am I missing a preparation step or something? Help would be appreciated. By the way, the tutorials still lack essential sample code and documentation for cases like this one with MS images. Thanks!
About this issue
- State: closed
- Created 2 years ago
- Comments: 33 (7 by maintainers)
I will try on Linux and report back.
This is likely a multiprocessing issue. The reason @calebrob6 can’t reproduce this is that multiprocessing uses a different default start method on macOS/Windows (spawn) than on Linux (fork): https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
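To make the start-method difference concrete, here is a minimal standard-library sketch. The helper names default_start_method and workers_need_pickling are hypothetical (not part of TorchGeo or PyTorch), and it assumes the DataLoader workers follow Python's default start method. Under spawn, each worker starts a fresh interpreter, receives the dataset by pickling, and re-imports the main module. That last step likely also explains the OSError naming '<input>': when the script is run from an interactive console, there is no real file for the workers to re-import.

```python
import multiprocessing as mp
import sys


def default_start_method() -> str:
    # Per the multiprocessing docs, the first entry is the platform default:
    # "spawn" on Windows and on macOS (Python 3.8+), "fork" on Linux up to Python 3.13.
    return mp.get_all_start_methods()[0]


def workers_need_pickling(method: str) -> bool:
    # With "spawn" (and "forkserver"), workers start from a fresh interpreter, so
    # the dataset and everything it references must be pickled and sent over.
    # With "fork", a worker inherits a copy of the parent's memory instead.
    return method != "fork"


if __name__ == "__main__":
    # Under "spawn", every worker re-imports the main module, so code that starts
    # workers (for example trainer.fit with num_workers > 0) belongs under this guard.
    method = default_start_method()
    print(f"{sys.platform}: default start method is {method!r}; "
          f"pickling required: {workers_need_pickling(method)}")
```

Run as a script, this should report spawn on macOS and Windows and fork on Linux with Python 3.13 or earlier, matching the platforms on which the error does and does not show up in this thread.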
If I’m correct, I should be able to reproduce this on macOS. Let me give it a shot.
A bit obvious, but this error now also happens in the trainers tutorial for folks running macOS (and probably also Windows).
This seems to work:
Although there are warnings:
I can confirm now: it runs OK on Linux.
The thing that’s odd to me is that multiprocessing (and therefore pickling) only happens within the data loader, but there is no batch norm in the dataset/data module. It’s almost like it’s trying to pickle the ResNet inside the dataset for some reason…
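One way to test the hypothesis that something unpicklable is hiding inside the dataset or data module is to call pickle.dumps on it directly, since that is roughly what spawned workers need. The sketch below is a hedged diagnostic, not TorchGeo code: is_picklable and the two toy classes are hypothetical stand-ins, and the lambda plays the role of whatever unpicklable object (such as the BatchNorm2d class named in the traceback) might be referenced.

```python
import pickle
from typing import Any


def is_picklable(obj: Any) -> bool:
    """Return True if obj can be serialized with pickle.dumps.

    Spawned worker processes receive the dataset this way, so anything that
    fails here would likely also fail when a DataLoader starts "spawn" workers.
    """
    try:
        pickle.dumps(obj)
    except Exception:  # PicklingError, AttributeError, TypeError, ...
        return False
    return True


class PlainDataset:
    """Toy dataset holding only plain data: pickles fine."""

    def __init__(self, n: int) -> None:
        self.items = list(range(n))


class DatasetWithHiddenRef:
    """Toy dataset that also references a lambda.

    The lambda stands in for any unpicklable object the dataset happens to
    reference; the standard pickle module cannot serialize lambdas.
    """

    def __init__(self, n: int) -> None:
        self.items = list(range(n))
        self.transform = lambda x: x * 2


if __name__ == "__main__":
    print(is_picklable(PlainDataset(3)))          # True
    print(is_picklable(DatasetWithHiddenRef(3)))  # False
```

Applied to the objects from the original report (for example is_picklable(eurosat) or is_picklable(datamodule)), a False result would point at the object that drags the unpicklable reference into the workers.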