pyro: Several iterations into NUTS, sampling from a Categorical distribution unexpectedly returns multiple values
Issue Description
While running inference on the model below with NUTS/MCMC, I get the following error after roughly 100 iterations: ValueError: only one element tensors can be converted to Python scalars. After adding a print statement for the relevant variable, which is a draw from a Categorical distribution, it became clear that the variable had become a two-element tensor, hence the error. However, I don't see why a draw from a Categorical distribution should return a two-element tensor.
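The error message itself can be reproduced outside Pyro. The following is a minimal sketch, assuming only that `z` arrives as a tensor of shape (2, 1) like the `[[0] [1]]` value printed just before the crash. The names `z_scalar` and `z_pair` are illustrative and do not come from the model.

```python
import torch

# A single draw is a zero-dimensional tensor, so .item() works on it.
z_scalar = torch.tensor(1)
print(z_scalar.item())

# Hypothetical stand-in for the value printed right before the failure.
z_pair = torch.tensor([[0], [1]])
print(tuple(z_pair.shape))  # (2, 1)

# .item() is only defined for one-element tensors. torch 1.3 raises
# ValueError here; newer releases may raise RuntimeError instead,
# so both are caught.
try:
    z_pair.item()
    item_failed = False
except (ValueError, RuntimeError) as err:
    item_failed = True
    print(type(err).__name__, err)
```

So the exception comes from calling `.item()` on a value that is no longer a single draw. It does not come from the Categorical distribution itself.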
Environment
I am running this in Google Colab, using Python 3.6, pyro-ppl-0.5.1, torch 1.3.0.
Code Snippet
import torch
import pyro.distributions as dist
import pyro
from pyro.infer import NUTS, MCMC
from torch.distributions import constraints
from pyro import poutine

data1 = dist.MultivariateNormal(-5 * torch.ones(2), torch.eye(2)).sample([100])
data2 = dist.MultivariateNormal(5 * torch.ones(2), torch.eye(2)).sample([100])
data = torch.cat((data1, data2))

pyro.enable_validation(True)

N = len(data)
T = 2

def mix_weights(beta):
    weights = torch.zeros(beta.shape[0] + 1)
    for t in range(beta.shape[0]):
        weights[t] = beta[t] * torch.prod(1. - beta[:t], dim=0)
    weights[beta.shape[0]] = 1. - torch.sum(weights)
    return weights

def model(data):
    alpha = 1.2
    with pyro.plate("beta_plate", T-1):
        beta = pyro.sample("beta", dist.Beta(1, alpha))
    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu", dist.MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))
    for i in range(N):
        z = pyro.sample("z_{}".format(i), dist.Categorical(mix_weights(beta)))
        print (z.numpy())
        pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mu[z.item()], torch.eye(2)), obs=data[i])

nuts_kernel = NUTS(model, adapt_step_size=True)
mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=300)
mcmc.run(data)
samples = mcmc.get_samples()
print(samples)
Output:
Warmup: 0%| | 0/800 [00:00, ?it/s]0
0
0
0
0
0
1
1
0
1
0
1
0
0
1
0
1
0
0
0
0
0
0
0
0
0
0
1
0
0
1
0
1
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
0
0
0
0
1
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
0
0
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
1
0
0
1
0
0
0
0
0
1
0
0
0
1
0
1
1
0
1
0
0
0
1
0
0
0
0
0
1
0
0
0
0
0
0
1
1
0
0
0
0
0
0
0
1
1
0
1
0
0
0
0
0
0
0
0
0
1
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
[[0]
 [1]]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    142             try:
--> 143                 ret = self.fn(*args, **kwargs)
    144             except (ValueError, RuntimeError):

16 frames
ValueError: only one element tensors can be converted to Python scalars

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-19-63983afdb4c3> in model(data)
     22         z = pyro.sample("z_{}".format(i), dist.Categorical(mix_weights(beta)))
     23         print (z.numpy())
---> 24         pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mu[z.item()], torch.eye(2)), obs=data[i])
     25
     26 nuts_kernel = NUTS(model, adapt_step_size=True)

ValueError: only one element tensors can be converted to Python scalars
   Trace Shapes:
    Param Sites:
   Sample Sites:
beta_plate dist     |
          value   1 |
      beta dist   1 |
          value   1 |
  mu_plate dist     |
          value   2 |
        mu dist   2 | 2
          value   2 | 2
       z_0 dist     |
          value 2 1 |
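
The last row of the trace shows `z_0` with a value of shape `2 1` even though its distribution has an empty batch shape. One plausible reading, offered as an assumption and not as the maintainers' diagnosis, is that the discrete site is being enumerated in parallel, so `z` holds every category at once instead of a single draw. Under that assumption, indexing `mu` with the tensor itself (without `.item()`) keeps the extra dimensions. The sketch below uses plain torch with made-up means and is not a verified fix for this model.

```python
import torch

# Made-up component means laid out like `mu` in the model: shape (T, 2).
mu = torch.tensor([[-5.0, -5.0], [5.0, 5.0]])

# Shape (2, 1), matching the z_0 value shape reported in the trace.
z = torch.tensor([[0], [1]])

# Indexing with an integer tensor broadcasts over z, so the result
# has shape z.shape + (2,).
selected = mu[z]
print(tuple(selected.shape))  # (2, 1, 2)

# Each enumerated category picks out its own mean vector.
print(selected[0, 0], selected[1, 0])
```

In the model this would correspond to writing `mu[z]` in place of `mu[z.item()]`. Whether that change alone is enough for this model was not checked here. Pyro also provides `pyro.ops.indexing.Vindex` for broadcast-aware indexing when more batch dimensions are involved.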
About this issue
- State: closed
- Created 5 years ago
- Comments: 20 (8 by maintainers)
It does have continuous variables as well. I’ll create a small example program to illustrate the problem and open a separate issue. Thanks!
Thanks @fehiepsi, this worked perfectly.