pyro: trace's log_prob of models with Gamma/Dirichlet distributions incompatible with JIT
Many of our JIT tests for HMC / NUTS (e.g. test_gamma_normal, test_dirichlet_categorical, etc.) are not terminating on the pytorch-1.0 branch. The reason is that the step size becomes extremely small in the adaptation phase, and we end up crawling to a halt. Note that the tests run fine without JIT. This probably points to some incompatibility between the adaptation code and the assumptions that the JIT is making. This is likely a regression due to some of our internal changes / refactoring of the code, since these tests were mostly working earlier. cc @fehiepsi
About this issue
- State: closed
- Created 6 years ago
- Comments: 15 (15 by maintainers)
@neerajprad @eb8680 This is not a Pyro problem. I can replicate the bug in PyTorch.
I have raised the bug in the PyTorch Slack.
You are right that JIT grad is just incorrect for gamma and dirichlet models. Other models look fine. This seems like a PyTorch bug.
The bug is reported at: https://github.com/pytorch/pytorch/issues/13669
Some observations:
- If we return `log_prob` at `z`, then the grad is right. If we return `log_prob` at `obs`, things are still right. But if we return the sum of them (or any linear combination of them), then the grad is wrong.
- If we return `log_prob_at_z + 0 * log_prob_at_obs`, then the grad is increased by 1000 times. This is the number of data points in `obs`.
- [...] `.expand`. But it seems the problem has been solved in PyTorch master (no issue with jit of `fn2` in my gist).
@neerajprad I posted a bug I noticed in Slack. You might check it there. 😃
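To judge which gradient is "right" in the observations above, it helps to have a reference value that involves neither autograd nor the JIT. The sketch below is not the gist from the thread. It is a plain-Python finite-difference check on a hypothetical model with a Gamma prior on `z` and a Normal likelihood whose scale is `z`. The function names, parameter values, and the choice of model are all assumptions made for illustration. It shows that adding `0 * log_prob_at_obs` should leave the gradient with respect to `z` unchanged, however many data points `obs` contains.

```python
import math
import random


def log_prob_at_z(z, concentration=2.0, rate=1.0):
    """Log density of Gamma(concentration, rate) at z (hypothetical prior)."""
    return (concentration * math.log(rate)
            + (concentration - 1.0) * math.log(z)
            - rate * z
            - math.lgamma(concentration))


def log_prob_at_obs(z, obs, loc=3.0):
    """Summed log density of Normal(loc, scale=z) over every point in obs."""
    return sum(
        -((x - loc) ** 2) / (2.0 * z ** 2)
        - math.log(z)
        - 0.5 * math.log(2.0 * math.pi)
        for x in obs
    )


def central_diff(fn, z, eps=1e-6):
    """Central finite-difference estimate of the derivative of fn at z."""
    return (fn(z + eps) - fn(z - eps)) / (2.0 * eps)


random.seed(0)
# 1000 data points, matching the count mentioned in the thread (values are made up).
obs = [random.gauss(3.0, 0.5) for _ in range(1000)]
z = 0.7

# Gradient of the prior term alone.
grad_prior = central_diff(log_prob_at_z, z)

# Gradient when the likelihood is added with zero weight: should match grad_prior.
grad_zero_weight = central_diff(
    lambda t: log_prob_at_z(t) + 0.0 * log_prob_at_obs(t, obs), z
)

# Gradient of the full sum: expected to differ, since the likelihood depends on z.
grad_joint = central_diff(
    lambda t: log_prob_at_z(t) + log_prob_at_obs(t, obs), z
)

# Analytic gradient of the Gamma log density: (concentration - 1) / z - rate.
analytic_prior = (2.0 - 1.0) / z - 1.0
```

If an autograd or JIT gradient for the zero-weight combination differs from `grad_prior` by a factor close to `len(obs)`, that would be consistent with the second observation above.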
@fehiepsi - Great! You will need to check out the `pytorch-1.0` branch to get all the JIT related changes (see the corresponding PR #1431). You will also need to be on PyTorch master (or download the nightly build using `conda install torch_nightly -c pytorch`, though it might have a few perf issues that are being worked on). You can run all JIT tests using `make test-jit`. Since that takes a bit of time, for the purpose of debugging, I would suggest running the JIT tests directly from `test_hmc.py` and `test_nuts.py`.
All that does is JIT the potential energy computation when run the first time, and use the compiled version subsequently. You can toggle off the JIT warnings by setting `ignore_jit_warnings=True`, which should be the case for some of these tests. Let me know if you face any issues.
@neerajprad I would like to resolve this bug to learn a bit about jit (I have no idea how it works at the moment ^___^!). Could you let me know how to do these tests with jit?
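The reply above describes the JIT path as compiling the potential energy computation on the first run and reusing the compiled version afterwards. That caching pattern can be sketched in plain Python. This is an illustration of the pattern only, not Pyro's implementation: the class `LazyCompiled`, its `compiler` argument, the stand-in `fake_compiler`, and the toy `potential_energy` function are all hypothetical names.

```python
class LazyCompiled:
    """Compile a function on its first call and reuse the result afterwards."""

    def __init__(self, fn, compiler):
        self.fn = fn
        # Hypothetical hook: called once as compiler(fn, example_args).
        self.compiler = compiler
        self._compiled = None
        self.compile_count = 0

    def __call__(self, *args):
        if self._compiled is None:
            # First call: the actual arguments double as example inputs.
            self._compiled = self.compiler(self.fn, args)
            self.compile_count += 1
        # Every call, including the first, runs the compiled version.
        return self._compiled(*args)


def potential_energy(z):
    """Toy stand-in for a negative log joint; not a real Pyro potential."""
    return 0.5 * z * z


def fake_compiler(fn, example_args):
    """Stand-in for a real tracer: returns fn unchanged."""
    return fn


energy = LazyCompiled(potential_energy, fake_compiler)
values = [energy(z) for z in (1.0, 2.0, 3.0)]
```

With PyTorch, something like `lambda fn, args: torch.jit.trace(fn, args)` could play the role of `compiler`, since `torch.jit.trace` takes a callable and example inputs. Whether that matches what the HMC / NUTS kernels do internally is not stated in the thread.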