pyro: Precision issue with FloatTensor results in slow mixing in HMC
This issue was already discovered by @fehiepsi and discussed in https://github.com/uber/pyro/pull/1678, but I was recently bitten by it again, so this is to track it and discuss any solutions / fixes, if there is indeed something that can be fixed.
Below is an example model that works very well with the default tensor type torch.DoubleTensor (as set in the script), but is slow with torch.FloatTensor. The reason for the slowness is that during adaptation we keep decreasing the step size to a very small value, which results in very slow mixing of the chains, as can be seen from the diagnostics.
```python
import pymc3 as pm
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
import theano.tensor as tt
import torch

torch.set_printoptions(precision=10)
torch.set_default_tensor_type('torch.DoubleTensor')


def pm_model(data):
    with pm.Model() as model:
        p = pm.Beta('p', 1., 1.)
        p_print = tt.printing.Print('p')(p)
        pm.Binomial('obs', data['n'], p_print, observed=data['x'])
    return model


def pyro_model(data):
    p = pyro.sample('p', dist.Beta(1., 1.))
    pyro.sample('obs', dist.Binomial(data['n'], p), obs=data['x'])


def get_samples_pymc(data, num_samples=200, warmup_steps=200):
    data = {k: v.numpy() for k, v in data.items()}
    with pm_model(data):
        trace = pm.sample(draws=num_samples, tune=warmup_steps, chains=2)


def get_samples_pyro(data, num_samples=200, warmup_steps=200):
    nuts_kernel = NUTS(pyro_model,
                       adapt_step_size=True,
                       adapt_mass_matrix=True,
                       jit_compile=True,
                       full_mass=True)
    mcmc_run = MCMC(nuts_kernel,
                    num_samples=num_samples,
                    warmup_steps=warmup_steps,
                    num_chains=2).run(data)
    print(mcmc_run.marginal(sites=['p']).diagnostics())


data = {'n': torch.tensor(5000000.), 'x': torch.tensor(3849.)}

# get_samples_pymc(data)
get_samples_pyro(data)
```
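For reference, the slow behavior can be reproduced by flipping the default tensor type to single precision before calling get_samples_pyro. A minimal sketch of the switch, assuming the definitions above are in scope:

```python
# Same script as above, but with single precision as the default; step size
# adaptation then shrinks the step size to a very small value and mixing is slow.
torch.set_default_tensor_type('torch.FloatTensor')
get_samples_pyro(data)
```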
It seems that the root cause is that our precision when using float tensors is too low: the acceptance rate with FloatTensor is low, and we end up massively decreasing the step size, which in turn leads to slow sampling and mixing. This could point to (see the sketch after this list):
- Numerical stability issues while computing log prob or gradient for the Beta distribution in PyTorch.
- Numerical stability issues in the transform code that converts unconstrained values into the unit interval.
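For concreteness, here is a minimal sketch (mine, not part of the original report) that evaluates this model's log density and gradient at an unconstrained value near the posterior mode, in single vs. double precision. torch.sigmoid stands in for the unit-interval transform, and potential_energy is an illustrative helper rather than a Pyro API:

```python
# Illustrative comparison of the model's log density and gradient in float32 vs.
# float64, evaluated in unconstrained space (x = -7.2 is roughly the inverse
# sigmoid of 3849/5000000). The sigmoid transform and Jacobian term below are a
# stand-in for Pyro's unit-interval transform, not the actual internals.
import torch
import torch.nn.functional as F
import pyro.distributions as dist


def potential_energy(x_val, dtype):
    n = torch.tensor(5000000., dtype=dtype)
    obs = torch.tensor(3849., dtype=dtype)
    x = torch.tensor(x_val, dtype=dtype, requires_grad=True)
    p = torch.sigmoid(x)  # unconstrained -> (0, 1)
    # log |dp/dx| for the sigmoid transform
    log_jac = -F.softplus(-x) - F.softplus(x)
    log_joint = (dist.Beta(torch.tensor(1., dtype=dtype), torch.tensor(1., dtype=dtype)).log_prob(p)
                 + dist.Binomial(n, probs=p).log_prob(obs)
                 + log_jac)
    grad, = torch.autograd.grad(log_joint, x)
    return -log_joint.item(), -grad.item()


for dtype in (torch.float32, torch.float64):
    print(dtype, potential_energy(-7.2, dtype))
```

If the float32 values differ from the float64 ones by an amount comparable to the energy differences NUTS relies on for its accept/reject decisions (order one), that alone would explain the low acceptance rate.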
It would be nice if there were some way for us to at least detect this and throw a warning, to help users debug such cases.
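A rough sketch of what such a check could look like (entirely hypothetical: warn_if_step_size_tiny, the way the adapted step size would be obtained from the kernel, and the 1e-6 threshold are all illustrative, not existing Pyro APIs):

```python
# Hypothetical post-warmup check: warn the user when the adapted step size ends up
# suspiciously small, since that often points at precision problems rather than at
# a genuinely difficult posterior. The threshold is an arbitrary placeholder.
import warnings


def warn_if_step_size_tiny(adapted_step_size, threshold=1e-6):
    if adapted_step_size < threshold:
        warnings.warn(
            "Adapted step size {:.2e} is very small; this can indicate numerical "
            "precision issues (e.g. float32 log prob / gradients). Consider "
            "torch.set_default_tensor_type('torch.DoubleTensor').".format(adapted_step_size))
```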
~Additionally, I found that in some larger models, we were able to make much faster progress by decreasing the tree depth during adaptation to, say, 6, without affecting the inference results. We could consider doing so by default (i.e. have a different tree depth during adaptation which can be overridden by the user).~ [This is a separate issue that is worth discussing separately, and is not relevant to this model.]
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Comments: 17 (17 by maintainers)
@fehiepsi - There is an issue with the log_prob term actually, not the conversion into logits. I will send a fix shortly to torch distributions.

Thanks for digging into this, @fehiepsi. I'll play around with your script a bit to see what's happening here.
The following script shows the precision issue, I guess (or maybe it is acceptable).

Output (step_size, log_prob, grad):

And when we move near the typical set, x = -7.2 (the inverse transform of 3849/5000000), we get:

Note that 3849/5000000 = 0.0007698, which is not that small a probability even in single precision. Four things we can observe:
You are right that the transform code is likely not at play here. ~Any numerical issues in Beta.sample() shouldn't affect us, since we don't exercise the .sample() method, just the .log_prob() method, when computing the potential energy.~ (EDIT: This could impact the initial trace setting, but it is not an issue with this example, and we observe a low acceptance rate even after we reach the typical set, as @fehiepsi noted above.) My first guess would be an issue in computing the gradient of the Beta distribution (maybe the gradient of torch.lgamma, since everything else seems straightforward) - small discrepancies there could add up, leading to a diverging Hamiltonian trajectory. Will investigate further and report back.

That is precisely the issue and the reason for opening this issue. Just to emphasize the causal link: it is due to bad precision that we end up using a lower step size than is needed for this model, and it is due to the very low step size that we end up needing to explore the tree to a much greater depth, resulting in slow sampling and mixing. I didn't mean to suggest that there is an issue with our adaptation scheme at all. As I mentioned earlier, it is most likely an issue with numerical instability either in the distribution's log prob, its gradient computation, or the transform code.
EDIT: Modified the title to reflect the issue better.
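On the torch.lgamma guess above, here is a tiny check (mine, purely illustrative) of how much absolute precision lgamma and its gradient (digamma) retain in single precision at counts of the size used in this model:

```python
# Illustrative check: at n = 5,000,000, lgamma(n) is about 7e7, so a float32 result
# can only be resolved to within a few units. The gradient of lgamma is digamma(n).
import torch

for dtype in (torch.float32, torch.float64):
    n = torch.tensor(5000000., dtype=dtype, requires_grad=True)
    y = torch.lgamma(n)
    y.backward()
    print(dtype, y.item(), n.grad.item())
```

If the Binomial log_prob is assembled from lgamma terms of this magnitude, the cancellation between them in float32 can by itself cost several units of absolute accuracy, which is exactly the range that matters for HMC acceptance.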