openfold: training speed is about 2x slower than the JAX trainable version (Uni-Fold)

Device: 1 A100 with 40 GB memory, CUDA 11.3. Compared against https://github.com/dptech-corp/Uni-Fold using the model_2 setting and the same data (only one sample, fed through the DummyDataLoader in OpenFold).

Following this issue, https://github.com/aqlaboratory/openfold/issues/19, I disabled clear_cache_between_blocks and DeepSpeed CPU offloading (a sketch of these changes follows below). The commit I used is https://github.com/aqlaboratory/openfold/commit/c4d9f57f9005f3e9e0325eff97b8232e328b4813
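For concreteness, here is a minimal sketch of what those two changes might look like. The `model_config` signature and the exact config paths for `clear_cache_between_blocks` are assumptions based on OpenFold's `config.py` around that commit and may differ; the DeepSpeed part simply corresponds to leaving the CPU-offload entries out of the ZeRO config passed to the trainer.

```python
# Hypothetical sketch: turn off per-block cache clearing in the OpenFold config
# and build a DeepSpeed ZeRO config without CPU offloading. Key names and paths
# are assumptions -- check config.py and deepspeed_config.json at your commit.
from openfold.config import model_config

cfg = model_config("model_2", train=True)  # assumed signature

# Assumed locations of the flag; it may also exist only on one of these stacks.
cfg.model.evoformer_stack.clear_cache_between_blocks = False
cfg.model.extra_msa.extra_msa_stack.clear_cache_between_blocks = False

# DeepSpeed config with no optimizer/parameter CPU offload
# (i.e. the "offload_*" blocks are simply omitted, so everything stays on GPU).
deepspeed_cfg = {
    "zero_optimization": {
        "stage": 2,
    },
    "train_micro_batch_size_per_gpu": 1,
}
```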

speed per example:

|          | FP32    | FP16  |
| -------- | ------- | ----- |
| OpenFold | 24.5 s  | 17 s  |
| Uni-Fold | 13.25 s | 8.9 s |

Is that expected? Are there any tricks to get a further speed-up?
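In case it helps anyone reproducing the comparison, below is a rough sketch of how per-example step time can be measured in FP32 vs. FP16. The tiny `nn.Linear` model, batch shapes, and step count are placeholders for illustration only, not OpenFold code.

```python
# Rough per-step timing sketch (not OpenFold-specific). Swap the placeholder
# model and batch for the real model and a real training batch to compare.
import time
import torch

device = torch.device("cuda")
model = torch.nn.Linear(1024, 1024).to(device)   # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = torch.randn(128, 1024, device=device)
target = torch.randn(128, 1024, device=device)

def time_steps(use_fp16: bool, n_steps: int = 20) -> float:
    """Average wall-clock time of one optimizer step, in seconds."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_steps):
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_fp16):
            loss = torch.nn.functional.mse_loss(model(batch), target)
        # Real FP16 training would normally also use torch.cuda.amp.GradScaler;
        # for timing purposes the plain backward pass is enough.
        loss.backward()
        optimizer.step()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_steps

print(f"FP32: {time_steps(False):.3f} s/step, FP16: {time_steps(True):.3f} s/step")
```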


Most upvoted comments

BTW @guolinke the recycling number bug is now fixed. The fix requires a little bit of extra data processing, and so it comes with a performance penalty of about half a second. I’m trying to think of ways to improve it.

@lhatsk would you mind moving this bfloat16 stuff into a new issue?

The mmcif cache isn’t required, but the template mmCIFs are. I’ll send those over now.

Sent.

Our A100 results were obtained using the following:

  • CUDA Driver 465.19.01
  • CUDA 11.3 Update 1 (11.3.1.005)
  • cuBLAS 11.5.1.109 (part of CUDA 11.3 U1)
  • cuDNN 8.2.1.32
  • NCCL 2.9.9
  • PyTorch 1.9.0a0+c3d40fd

and with cache clearing disabled (but using the real dataloader).
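For anyone checking their own machine against this setup, the corresponding versions can be read directly out of PyTorch (a generic snippet, nothing OpenFold-specific):

```python
# Print the toolchain versions relevant to the comparison above.
import torch

print("PyTorch:", torch.__version__)
print("CUDA (build):", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())
print("NCCL:", torch.cuda.nccl.version())
print("GPU:", torch.cuda.get_device_name(0))
```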