algorithmic-efficiency: imagenet resnet pytorch slow steps/second
The imagenet resnet pytorch workload seems to be running at least twice as slow as the jax workload.
Description
The graph below shows the jax and pytorch runs. Both ran for 31h, but the pytorch run completed only about 100K steps.
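To put a number on the pytorch side, here is a quick back-of-the-envelope calculation using only the figures quoted above (31h of wall-clock time, about 100K steps). The jax step count is not stated in this issue, so the jax figure below is only the lower bound implied by the "at least twice as slow" estimate, not a measurement. The names `steps_per_second`, `pytorch_rate` and `implied_min_jax_rate` are made up for this sketch.

```python
# Figures quoted in the issue text (the step count is approximate).
RUN_HOURS = 31
PYTORCH_STEPS = 100_000


def steps_per_second(steps: int, hours: float) -> float:
    """Average throughput over a run of the given wall-clock length."""
    return steps / (hours * 3600)


pytorch_rate = steps_per_second(PYTORCH_STEPS, RUN_HOURS)

# Assumption: the "at least twice as slow" estimate holds, so the jax run
# would be at or above twice the pytorch rate. This is implied, not measured.
implied_min_jax_rate = 2 * pytorch_rate

print(f"pytorch: {pytorch_rate:.2f} steps/s")
print(f"jax (implied lower bound): {implied_min_jax_rate:.2f} steps/s")
```

So the pytorch run averaged roughly 0.9 steps/second, and the jax run would have to be at about 1.8 steps/second or more for the 2x estimate to hold.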

On kasimbeg-4, I ran:
torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
--framework=pytorch \
--data_dir=~/local_data/imagenet/pytorch \
--imagenet_v2_data_dir=~ \
--experiment_dir=~ \
--experiment_name=target_setting \
--workload=imagenet_resnet \
--submission_path=reference_algorithms/target_setting_algorithms/pytorch_nesterov.py \
--tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json > results_imagenet_resnet_pytorch.txt 2>&1
About this issue
- State: closed
- Created 2 years ago
- Comments: 21 (21 by maintainers)
With the recent change @mikerabbat suggested, I was able to bridge this gap (on a 0.5h run):
PyTorch:
Jax:
It seems like removing the sync batch norm had the largest effect. I can submit the PR or wait for the official PR.
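For readers unfamiliar with the setting, here is a minimal sketch of what toggling sync batch norm can look like in PyTorch. It is not the workload's actual code and not the change from the PR. The helper `maybe_sync_batch_norm`, the flag `use_sync_bn`, and the toy model are hypothetical names used only for illustration. The one real API call is `torch.nn.SyncBatchNorm.convert_sync_batchnorm`.

```python
import torch.nn as nn


def maybe_sync_batch_norm(model: nn.Module, use_sync_bn: bool) -> nn.Module:
    """Hypothetical helper: optionally swap BatchNorm layers for SyncBatchNorm."""
    if use_sync_bn:
        # SyncBatchNorm synchronizes batch statistics across processes during
        # training, which adds cross-device communication to every step.
        return nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # Plain BatchNorm computes statistics per device, with no extra communication.
    return model


def count_sync_bn_layers(model: nn.Module) -> int:
    """Count SyncBatchNorm layers, to check which variant a model uses."""
    return sum(isinstance(m, nn.SyncBatchNorm) for m in model.modules())


def make_toy_model() -> nn.Module:
    # Toy model for illustration only, not the ResNet used in the workload.
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())


synced = maybe_sync_batch_norm(make_toy_model(), use_sync_bn=True)
local = maybe_sync_batch_norm(make_toy_model(), use_sync_bn=False)
print(count_sync_bn_layers(synced), count_sync_bn_layers(local))
```

Whether dropping the synchronization is acceptable depends on the per-device batch size and on matching the jax workload's behavior, so treat this as a way to reproduce the comparison rather than a recommendation.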
I have not used it before, but I can implement it after fixing the remaining ViT issues (which I think is almost done).
Thanks for looking into this! Let me know if I can help with anything regarding investigating the speed difference.