pyro: trace's log_prob of models with Gamma/Dirichlet distributions incompatible with JIT

Many of our JIT tests in HMC / NUTS, e.g. test_gamma_normal, test_dirichlet_categorical, etc., are not terminating on the pytorch-1.0 branch. The reason is that the step size becomes extremely small during the adaptation phase, so sampling slows to a crawl. Note that the tests run fine without the JIT. This probably points to some incompatibility between the adaptation code and the assumptions the JIT is making. It is likely a regression from some of our internal changes / refactoring, since these tests were mostly working earlier. cc @fehiepsi


Most upvoted comments

@neerajprad @eb8680 This is not a Pyro problem. I can replicate the bug in PyTorch:

import torch
import torch.autograd as autograd
import torch.distributions as dist

data = torch.zeros(1000)

def fn(z):
    # Gamma(1, 1) log-density at z plus a simple reduction over `data`.
    a = dist.Gamma(1., 1.)
    return a.log_prob(z).sum() + (z.log() - data).sum()

z = torch.tensor(1., requires_grad=True)
fn_jit = torch.jit.trace(fn, (z,))
print(fn(z))                           # tensor(-1.)
print(fn_jit(z))                       # tensor(-1.)
print(autograd.grad(fn(z), (z,)))      # (tensor(999.),)  -- correct
print(autograd.grad(fn_jit(z), (z,)))  # (tensor(0.),)    -- wrong

I have raised the bug in the PyTorch Slack.

You are right that the JIT grad is simply incorrect for Gamma and Dirichlet models; other models look fine. This seems like a PyTorch bug.

Some observations:

  • This only happens with some distributions, such as Gamma and Dirichlet; other distributions seem fine.
  • The problem shows up in the returned value: if we return only the log_prob at z, the grad is right, and if we return only the log_prob at obs, the grad is also right. But if we return their sum (or any linear combination of them), the grad is wrong (see the sketch after this list).
  • If we return log_prob_at_z + 0 * log_prob_at_obs, the grad is inflated by a factor of 1000, which is the number of data points in obs.
  • The “grad is off by 1000x” behavior is related to a well-known bug in these distributions' .expand, but it seems that has been fixed in pytorch master (no issue with the JIT of fn2 in my gist).
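
A minimal sketch of the comparison described above; the helper names fn_z_only and fn_combined are mine, not from the gist. On the affected PyTorch builds the traced gradient reportedly only diverges from the eager one in the combined case:

import torch
import torch.autograd as autograd
import torch.distributions as dist

data = torch.zeros(1000)

def fn_z_only(z):
    # Only the Gamma log_prob at z: traced and eager grads agree.
    return dist.Gamma(1., 1.).log_prob(z).sum()

def fn_combined(z):
    # Gamma log_prob at z plus a reduction over the 1000-element data tensor.
    return dist.Gamma(1., 1.).log_prob(z).sum() + (z.log() - data).sum()

z = torch.tensor(1., requires_grad=True)
for f in (fn_z_only, fn_combined):
    traced = torch.jit.trace(f, (z,))
    eager_grad, = autograd.grad(f(z), (z,))
    traced_grad, = autograd.grad(traced(z), (z,))
    print(f.__name__, eager_grad.item(), traced_grad.item())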

@neerajprad I pinged about the bug I noticed on Slack. You can check it there. 😃

@fehiepsi - Great! You will need to check out the pytorch-1.0 branch to get all the JIT-related changes (see the corresponding PR #1431). You will also need to be on PyTorch master (or install the nightly build with conda install torch_nightly -c pytorch, though it may have a few perf issues that are being worked on). You can run all the JIT tests with make test-jit. Since that takes a while, for debugging I would suggest running the JIT tests directly from test_hmc.py and test_nuts.py.

All that does is JIT-compile the potential energy computation on the first run and reuse the compiled version afterwards. You can turn off the JIT warnings by setting ignore_jit_warnings=True, which should already be the case for some of these tests. Let me know if you face any issues.
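
For reference, a rough sketch of what enabling the JIT for a NUTS run looks like; the Gamma-Normal model below is illustrative rather than the actual test fixture, and it assumes the jit_compile / ignore_jit_warnings keyword arguments described above:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

def model(data):
    # Illustrative Gamma-Normal model (not the real test_gamma_normal fixture).
    rate = pyro.sample("rate", dist.Gamma(3.0, 3.0))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Normal(0.0, 1.0 / rate.sqrt()), obs=data)

data = torch.randn(100) + 1.0
# jit_compile traces the potential energy computation on the first step and
# reuses the compiled graph; ignore_jit_warnings silences tracer warnings.
kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
MCMC(kernel, num_samples=200, warmup_steps=100).run(data)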

@neerajprad I would like to resolve this bug to learn a bit about the JIT (I have no idea how it works at the moment ^___^!). Could you let me know how to run these tests with the JIT?