pyro: HMM model using HMC/NUTS is slow

Since we started using einsum to evaluate the joint log density of a trace with discrete parameters, we can enumerate discrete sites in many more models under HMC/NUTS without going OOM.
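
For concreteness, here is a toy sketch of what einsum-based enumeration looks like (a made-up 2-state, 3-step HMM in plain PyTorch, not Pyro code; a real backend would contract in log space for numerical stability):

```python
import torch

K, T = 2, 3
init = torch.softmax(torch.randn(K), -1)       # p(z_1)
trans = torch.softmax(torch.randn(K, K), -1)   # p(z_t | z_{t-1})
obs = torch.rand(T, K)                         # p(x_t | z_t) for fixed observations

# One einsum sums out all hidden states at once; explicit enumeration
# would have to visit K**T state sequences.
lik = torch.einsum("i,i,ij,j,jk,k->",
                   init, obs[0], trans, obs[1], trans, obs[2])
log_joint = lik.log()
```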

For HMMs, however, NUTS/HMC is extremely slow to the point of being unusable beyond a few time steps, say 10. Refer to this test for profiling.

There are a few issues that I noticed:

  • ~NUTS crawls to a halt due to an extremely small step size in the warmup phase itself. I suppose the issue might be with starting from a random initialization, but I am not sure. It might be worth initializing with the MAP estimate and then running NUTS so that it does not hang with an unworkably small step size. I think we have seen this issue with a few other models too, e.g. #1470. cc. @fehiepsi~ I don’t observe this issue on pytorch 1.0.
  • Converting the values returned from the integrator into a model trace (_get_trace) also takes more than 3s. While comparatively small, I believe this can be optimized, if we assume our models are static, by using a different data structure inside HMC so that we do not need to re-run the model each time (see the sketch below).
  • The trace log density evaluation (and gradient computation) takes the bulk of the time, as expected. It is not immediately clear how this can be improved, given that we need to call it on every integrator step, many times over, to generate even a single sample; it does seem like NUTS will remain slow until we can make this much faster. Profiling viz below.
[Profiling screenshot: screen shot 2018-11-01 at 5.24.41 pm]
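
As a minimal sketch of the static-model idea from the second bullet above (a hypothetical helper, not Pyro’s actual HMC internals): record the trace skeleton once, then overwrite sample values in place instead of re-executing the model on every integrator step.

```python
import pyro.poutine as poutine

class StaticTraceCache:
    """Hypothetical cache for static models: site names, shapes, and
    dependencies are assumed to never change between integrator steps."""

    def __init__(self, model, *args, **kwargs):
        # Run the model once to record the trace structure.
        self.trace = poutine.trace(model).get_trace(*args, **kwargs)

    def update(self, values):
        # Overwrite sample values in place instead of re-running the model.
        for name, value in values.items():
            self.trace.nodes[name]["value"] = value
        return self.trace
```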

Most upvoted comments

The only thing I can think of to worry about in using ArviZ is that we are writing the library with no regard for Python 2. In particular, we use matplotlib 3.x, which is Python 3 only, and the rest of the Python data science infrastructure seems to be phasing out Python 2 support over the next year, so we did not want to start a new project with that technical debt. I understand this may hurt adoption in legacy stacks!

Beyond that, please tag one of us here or open an issue if we can help with the integration at all. We have found xarray and netcdf to be very natural ways of storing inference data.

I’m jumping in here because I’ve also seen serious performance issues with NUTS/HMC.

In my testing, performance starts out adequate and then begins to drop precipitously after around 20 iterations. I have observed that, while this is occurring, the step size is increasing.

To me, the interesting thing is that the performance isn’t constant: it declines over time. That suggests the issue is not limited to the time it takes to calculate the energy, which shouldn’t vary much from iteration to iteration.

I have two suspicions. The first is that the HMC/NUTS implementation is trying too hard to increase the step size, so it ends up producing lots and lots of divergences. The second is that this has to do with memory fragmentation, because of the very large number of tensors that are created as intermediate steps and then retained for the gradient calculation.

It is surprising to me that the distributions’ log_prob alone takes 40s

In the HMM example, the distribution log_prob computation is merely a gather, i.e. a memory copy; all the actual computation is done by sum-product contraction when combining the log_prob tensors from multiple sites, i.e. matmul and einsum.
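
An illustrative sketch of the contrast (stand-in tensors, not Pyro internals): the per-site log_prob is an index lookup, while combining sites is a contraction. The exp/log round-trip below stands in for a numerically stable log-space backend.

```python
import torch

K = 3
log_trans = torch.randn(K, K).log_softmax(-1)   # stand-in transition logits

# log_prob of an enumerated discrete site: pure indexing into a precomputed
# table, i.e. a gather / memory copy with no arithmetic.
prev = torch.arange(K).view(K, 1)               # enumerated previous states
curr = torch.arange(K).view(1, K)               # enumerated current states
site_logp = log_trans[prev, curr]               # shape (K, K)

# Combining per-site log_prob tensors is where the FLOPs go: a sum-product
# contraction over the shared state index, i.e. matmul/einsum.
a, b = torch.randn(K, K), torch.randn(K, K)
combined = torch.einsum("ij,jk->ik", a.exp(), b.exp()).log()
```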

@fehiepsi - I updated the example in the prof-hmm branch. It should be in tests/perf/prof_hmc_hmm.py. I don’t think there are any immediate TODOs for this one; it is more of an enhancement issue than a perf issue. One thing we could experiment with in the future is JITing the grad step itself (once PyTorch supports it).
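
As a rough sketch of the JIT idea (using a stand-in quadratic potential, not the HMM model): the potential-energy computation can be traced today, while the gradient step itself still runs outside the JIT.

```python
import torch

def potential_fn(z):
    return 0.5 * (z * z).sum()   # stand-in for the model's negative log joint

z0 = torch.randn(5)
traced_potential = torch.jit.trace(potential_fn, (z0,))

z = torch.randn(5, requires_grad=True)
traced_potential(z).backward()   # gradients still flow through the traced graph
print(z.grad)                    # equals z for this quadratic potential
```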

@fritzo - You can run the profiler using:

```sh
python -m cProfile -o hmm.prof tests/perf/prof_hmc_hmm.py
```

on the prof-hmm branch. Attaching the .prof file. I have turned off step size adaptation so that the run does not take too long. Most of the time is actually taken by einsum itself, so I am not sure there is much room for optimization here.

hmm.tar.gz
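
For reference, one way to inspect a .prof dump like the attached one, using only the standard library:

```python
import pstats

stats = pstats.Stats("hmm.prof")
# Show the 20 entries with the largest cumulative time.
stats.sort_stats("cumulative").print_stats(20)
```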

@fehiepsi - This is a known issue that we would like to address, or at least conclude that we cannot run HMM-type models under the current constraints. I just cc’d you as an FYI - you shouldn’t feel compelled to work on this! 😃

I tested with pytorch 0.4.0 and did not observe the slowness.

The HMM test only samples 10 traces, but if you run it for more than 10 time steps, you will see the step size getting very small and sampling making extremely slow progress. This is without using the JIT. My hunch is that this could be due to a bad initialization, with transforms warping the potential energy surface in a way that makes trajectories extremely unstable, so we keep lowering the step size and progress becomes extremely slow. This is just a guess, though, and needs to be investigated further.

Stan does not use MAP. Previously, PyMC3 used MAP, but now they don’t by default.

Even if it is not enabled by default, I am interested in exploring whether initializing with the MAP estimate does better on these kinds of models. If so, it would be useful to provide an optional kwarg initialize_with_map=False to the HMC/NUTS kernels.
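
A hedged sketch of what MAP initialization could look like, assuming a differentiable potential_fn (the negative log joint over unconstrained parameters); map_estimate and potential_fn here are illustrative names, not an existing Pyro API.

```python
import torch

def map_estimate(potential_fn, init_params, steps=500, lr=0.01):
    # Minimize the potential energy (negative log joint) from the given start.
    params = {k: v.clone().requires_grad_(True) for k, v in init_params.items()}
    optim = torch.optim.Adam(list(params.values()), lr=lr)
    for _ in range(steps):
        optim.zero_grad()
        potential_fn(params).backward()
        optim.step()
    # Detached MAP point, usable as the starting state for HMC/NUTS.
    return {k: v.detach() for k, v in params.items()}
```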

implement statistics such as effective number of samples / Gelman-Rubin convergence diagnostic instead.

That will be really useful! You should also check out arviz (which implements traceplot and diagnostics like Gelman-Rubin), and this PR https://github.com/arviz-devs/arviz/pull/309 by @ColCarroll, which extends support for Pyro.
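
For context, a minimal self-contained arviz workflow with toy draws (the linked PR adds a from_pyro converter so a fitted Pyro MCMC object could be used directly instead of from_dict):

```python
import numpy as np
import arviz as az

# Toy posterior: 2 chains x 100 draws of a single scalar parameter.
draws = {"mu": np.random.randn(2, 100)}
data = az.from_dict(posterior=draws)

print(az.summary(data))   # includes effective sample size and r_hat (Gelman-Rubin)
az.plot_trace(data)
```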

About initialization: Stan does not use MAP. Previously, PyMC3 used MAP, but now they don’t by default. There might be empirical reasons for their decisions (which I can’t track down). In my opinion, allowing users to set starting points is enough; from my experience, it is extremely useful for some models: when I got stuck with a randomized initialization, I set the initialization to the mean, and things went smoothly. These starting points can come from intuition, from the mean of the priors, or from MAP.