algorithmic-efficiency: imagenet resnet pytorch slow steps/second

The imagenet resnet pytorch workload seems to be running at least twice as slow as the jax workload

Description

The graph below shows the jax and pytorch runs. Both ran for 31h, but the pytorch run only completed about 100K steps.
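For a rough sense of scale, here is a back-of-the-envelope conversion of those numbers into steps/second (a sketch only; the ~2x factor for jax is inferred from the issue title, not measured here):

# Approximate steps/second implied by the numbers above.
pytorch_steps = 100_000           # steps completed by the pytorch run
wall_clock_s = 31 * 3600          # 31 hours in seconds

print(pytorch_steps / wall_clock_s)   # ~0.9 steps/s
# If jax is roughly 2x faster, as the title suggests, that would be on the
# order of ~1.8 steps/s, i.e. roughly 200K steps in the same 31h window.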

On kasimbeg-4, the following command was run:

torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
    --framework=pytorch \
    --data_dir=~/local_data/imagenet/pytorch \
    --imagenet_v2_data_dir=~ \
    --experiment_dir=~ \
    --experiment_name=target_setting \
    --workload=imagenet_resnet \
    --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nesterov.py \
    --tuning_search_space=reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json > results_imagenet_resnet_pytorch.txt 2>&1

About this issue

  • Original URL
  • State: closed
  • Created 2 years ago
  • Comments: 21 (21 by maintainers)

Most upvoted comments

With the recent change @mikerabbat suggested, I was able to bridge this gap (on a 0.5h run):

PyTorch:

----------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                       |  Mean Duration (s)    |  Std Duration (s)     |  Num Calls            |  Total Time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                        |  -----                |  -----                |  9636                 |  1971.5               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------------------------
|  Train                        |  1969.7               |  0.0                  |  1                    |  1969.7               |  99.905               |
|  Update parameters            |  0.33362              |  0.060603             |  4814                 |  1606.0               |  81.461               |
|  Evaluation                   |  37.14                |  7.4504               |  4                    |  148.56               |  7.5352               |
|  Data selection               |  0.0022307            |  0.04729              |  4814                 |  10.738               |  0.54468              |
|  Initializing model           |  4.3086               |  0.0                  |  1                    |  4.3086               |  0.21854              |
|  Initializing dataset         |  4.1837               |  0.0                  |  1                    |  4.1837               |  0.2122               |
|  Initializing optimizer       |  0.27054              |  0.0                  |  1                    |  0.27054              |  0.013722             |
----------------------------------------------------------------------------------------------------------------------------------------------------------

Jax:

----------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                       |  Mean Duration (s)    |  Std Duration (s)     |  Num Calls            |  Total Time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                        |  -----                |  -----                |  10652                |  1954.4               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------------------------
|  Train                        |  1949.4               |  0.0                  |  1                    |  1949.4               |  99.742               |
|  Update parameters            |  0.3337               |  0.51689              |  5322                 |  1776.0               |  90.867               |
|  Data selection               |  0.0044586            |  0.066115             |  5322                 |  23.729               |  1.2141               |
|  Evaluation                   |  21.125               |  9.5886               |  4                    |  84.499               |  4.3234               |
|  Initializing model           |  10.519               |  0.0                  |  1                    |  10.519               |  0.53823              |
|  Initializing dataset         |  1.0959               |  0.0                  |  1                    |  1.0959               |  0.056073             |
|  Initializing optimizer       |  0.89024              |  0.0                  |  1                    |  0.89024              |  0.04555              |
----------------------------------------------------------------------------------------------------------------------------------------------------------
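For context, the two tables above are per-action timing summaries; a minimal sketch of the kind of timing wrapper that produces output like this is below. The class and method names here are illustrative, not the actual algorithmic-efficiency profiler API.

# Illustrative per-action profiler sketch (hypothetical names, not the
# benchmark's actual Profiler class).
import time
from collections import defaultdict
from contextlib import contextmanager

class ActionProfiler:
    def __init__(self):
        self.durations = defaultdict(list)

    @contextmanager
    def profile(self, action):
        # Time one occurrence of an action, e.g. 'Update parameters'.
        start = time.monotonic()
        try:
            yield
        finally:
            self.durations[action].append(time.monotonic() - start)

    def summary(self):
        # Print mean duration, call count, total time, and percentage
        # of overall time for each action, largest total first.
        total = sum(sum(d) for d in self.durations.values())
        for action, d in sorted(self.durations.items(),
                                key=lambda kv: -sum(kv[1])):
            print(f"{action:<25} mean={sum(d)/len(d):.4f}s "
                  f"calls={len(d)} total={sum(d):.1f}s "
                  f"pct={100 * sum(d) / total:.3f}%")

# Usage: wrap each training action, e.g.
#   with profiler.profile('Update parameters'):
#       update_params(...)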

It seems like removing the sync batch norm had the largest effect. I can submit a PR for this, or wait for the official one.
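For reference, the change amounts to keeping ordinary per-GPU BatchNorm layers instead of converting them to SyncBatchNorm under DistributedDataParallel. A minimal sketch of the two variants is below; the model construction and setup are placeholders, not the workload's actual code.

# Sketch of the BatchNorm choice under DistributedDataParallel.
import os
import torch
import torch.distributed as dist
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)

model = torchvision.models.resnet50().cuda()

# Slower variant: convert every BatchNorm layer to SyncBatchNorm, which
# synchronizes batch statistics across all workers on each forward pass.
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Faster variant (the change discussed above): keep per-GPU BatchNorm, so
# each worker computes statistics from its local batch only.
model = DDP(model, device_ids=[local_rank])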

I have not used it before, but I can implement it after fixing the remaining ViT issues (which I think are almost done).

Thanks for looking into this! Let me know if I can help with investigating the speed difference.