netket: `QGTOnTheFly` is weirdly slow on a GPU

@chrisrothUT spotted a while ago that `QGTOnTheFly` is weirdly slow. To probe this, I’ve run two benchmarks for `QGTOnTheFly`, one with a Dense model and one with a GCNN, both on my laptop (4 CPU cores) and on a GPU cluster (1 K80 GPU). The first call takes a while in both cases because the solver has to be jit-compiled. On the CPU, the second call (with all inputs changed, as they would be in the VMC driver) is extremely fast (below 1 ms). On the GPU, however, it takes about as long as the first call, which suggests that either part of the jit compilation is repeated on every call or the compiled code runs very inefficiently.
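One way to tell whether the GPU slowdown comes from repeated compilation is to count how often a jitted function is traced: a Python side effect, such as incrementing a counter, runs only while JAX traces the function, not when a cached executable is reused. This is a minimal sketch, assuming a hypothetical `solve` (a conjugate-gradient solve of $(A^\top A + \text{shift}\,I)\,x = b$ with the matrix applied on the fly) as a stand-in; it is not NetKet's `QGTOnTheFly` code.

```python
import jax
from jax.scipy.sparse.linalg import cg

trace_count = 0  # incremented only when JAX (re)traces `solve`


@jax.jit
def solve(A, b, shift):
    # Hypothetical stand-in for the QGT solve, not NetKet's implementation.
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only

    def matvec(v):
        # Apply (A^T A + shift * I) without materializing the matrix.
        return A.T @ (A @ v) + shift * v

    x, _ = cg(matvec, b)
    return x


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

# First call: traces and compiles.
A1 = jax.random.normal(k1, (64, 32))
b1 = jax.random.normal(k2, (32,))
solve(A1, b1, 0.01).block_until_ready()

# Second call: new values with the same shapes and dtypes, as in a VMC loop.
A2 = jax.random.normal(k3, (64, 32))
b2 = jax.random.normal(k4, (32,))
x2 = solve(A2, b2, 0.02).block_until_ready()

print(trace_count)  # stays at 1 if the compiled executable was reused
```

If the counter stays at 1 while the second GPU call is still slow, recompilation is probably not the culprit. Enabling `jax.config.update("jax_log_compiles", True)` should also log every compilation.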

We don’t know how to probe this further. As a quick fix, we could set `backend='cpu'` in the jit decorators in `QGTOnTheFly`, but it would be better to figure out why this happens.
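An alternative way to keep a computation on the CPU, without relying on the `backend` argument of `jax.jit` (whose status has varied across JAX versions), is to commit the inputs to the CPU device: a jitted function then runs on the device its committed inputs live on. A minimal sketch, assuming a hypothetical `qgt_like_matvec` rather than NetKet's own functions:

```python
import jax
import jax.numpy as jnp

# Select the host CPU device explicitly, even on a machine with a GPU.
cpu = jax.devices("cpu")[0]


@jax.jit
def qgt_like_matvec(A, v):
    # Hypothetical stand-in: apply A^T A to v without forming A^T A.
    return A.T @ (A @ v)


# Committing the inputs to the CPU makes the jitted computation run there.
A = jax.device_put(jnp.ones((8, 4)), cpu)
v = jax.device_put(jnp.ones((4,)), cpu)
out = qgt_like_matvec(A, v)
print(out.devices())
```

This only sidesteps the GPU path; it does not explain why the GPU call is slow.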

Output of the scripts:

| Model | Machine | First call (s) | Second call (s) |
|---|---|---|---|
| Dense | laptop (CPU) | 9.841582 | 0.000266 |
| GCNN | laptop (CPU) | 12.325909 | 0.000301 |
| Dense | GPU | 18.779504 | 6.217923 |
| GCNN | GPU | 131.027075 | 120.383964 |

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 18 (7 by maintainers)

Most upvoted comments

And are you timing this correctly, given JAX’s asynchronous (lazy) evaluation?
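JAX dispatches work asynchronously, so a Python timer wrapped around a call may only measure how long it takes to enqueue the work, not to finish it. Calling `block_until_ready()` on the result makes the timer wait for the device. A minimal sketch, assuming a hypothetical `work` function in place of the actual QGT solve:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def work(x):
    # Hypothetical workload standing in for the QGT solve.
    return jnp.tanh(x @ x.T).sum()


def timed(f, *args):
    start = time.perf_counter()
    out = f(*args)
    out.block_until_ready()  # wait until the device has actually finished
    return out, time.perf_counter() - start


x1 = jnp.ones((512, 512)).block_until_ready()
x2 = (2.0 * jnp.ones((512, 512))).block_until_ready()

_, t_first = timed(work, x1)   # includes tracing and compilation
_, t_second = timed(work, x2)  # same shapes/dtypes: reuses the cached executable
print(f"First time took {t_first:f} seconds")
print(f"Second time took {t_second:f} seconds")
```

Without the `block_until_ready()` call, a sub-millisecond "second time" could simply reflect dispatch overhead rather than the actual solve.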