pyro: Several iterations into NUTS, sampling from a Categorical distribution unexpectedly returns multiple values

Issue Description

While running inference on a model (code below) with NUTS/MCMC, I get the following error after several (~100) iterations: ValueError: only one element tensors can be converted to Python scalars. I added a print statement for the relevant variable, which is a draw from a Categorical distribution. It showed that the variable had become a two-element tensor, which explains the error. However, I don’t see why a draw from a Categorical distribution should return a two-element tensor.
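The error itself can be reproduced in plain torch. The trace shapes at the end of the output list z_0 with value shape 2 1, and the last print before the traceback shows [[0] [1]]. One plausible reading, which is an assumption here, is that NUTS runs the model with the discrete site z_i enumerated in parallel, so z holds every support value at once instead of a single draw. The sketch below builds such a tensor by hand. enumerated_z is a hypothetical stand-in, not necessarily how Pyro constructs the value internally. The sketch then shows that .item() only accepts single-element tensors.

```python
import torch

# An ordinary draw from a 2-category Categorical is a 0-dim tensor, so .item() works.
single_draw = torch.tensor(1)
print(single_draw.item())  # 1

# Hypothetical stand-in for a parallel-enumerated z_i: both support values {0, 1}
# stacked along a new leftmost dim, shape (2, 1), like the "z_0 value 2 1" row.
enumerated_z = torch.arange(2).reshape(2, 1)
print(enumerated_z.numpy())  # same layout as the [[0] [1]] printed before the traceback

# .item() needs exactly one element. torch 1.3 raises ValueError here;
# newer torch releases raise RuntimeError instead, so catch both.
item_failed = False
try:
    enumerated_z.item()
except (ValueError, RuntimeError) as err:
    item_failed = True
    print(type(err).__name__, err)
```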

Environment

I am running this in Google Colab with Python 3.6, pyro-ppl 0.5.1, and torch 1.3.0.

Code Snippet

import torch
import pyro.distributions as dist
import pyro
from pyro.infer import NUTS, MCMC
from torch.distributions import constraints
from pyro import poutine

data1 = dist.MultivariateNormal(-5 * torch.ones(2), torch.eye(2)).sample([100])
data2 = dist.MultivariateNormal(5 * torch.ones(2), torch.eye(2)).sample([100])
data = torch.cat((data1, data2))

pyro.enable_validation(True)

N = len(data)
T = 2

def mix_weights(beta):
    weights = torch.zeros(beta.shape[0] + 1)
    for t in range(beta.shape[0]):
        weights[t] = beta[t] * torch.prod(1. - beta[:t], dim=0)
    weights[beta.shape[0]] = 1. - torch.sum(weights)
    return weights

def model(data):
    alpha = 1.2
    with pyro.plate("beta_plate", T-1):
        beta = pyro.sample("beta", dist.Beta(1, alpha))

    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu", dist.MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))

    for i in range(N):
        z = pyro.sample("z_{}".format(i), dist.Categorical(mix_weights(beta)))
        print(z.numpy())
        pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mu[z.item()], torch.eye(2)), obs=data[i])

nuts_kernel = NUTS(model, adapt_step_size=True)
mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=300)
mcmc.run(data)

samples = mcmc.get_samples()
print(samples)
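Separately from the bug, the mix_weights helper above fills a preallocated tensor one entry at a time. The same stick-breaking formula can also be written without a loop. Each weight is weight_t = beta_t * prod_{s<t} (1 - beta_s), and the last weight is the leftover stick, prod_s (1 - beta_s). The sketch below expresses this with cumprod and F.pad and copies the loop version from the snippet so the two can be compared. The vectorized variant (mix_weights_vectorized is a hypothetical name) is an illustrative alternative, not something proposed in the report or by the maintainers.

```python
import torch
import torch.nn.functional as F

def mix_weights_loop(beta):
    # Stick-breaking as written in the report: T-1 betas -> T weights.
    weights = torch.zeros(beta.shape[0] + 1)
    for t in range(beta.shape[0]):
        weights[t] = beta[t] * torch.prod(1. - beta[:t], dim=0)
    weights[beta.shape[0]] = 1. - torch.sum(weights)
    return weights

def mix_weights_vectorized(beta):
    # Running products (1 - b1), (1 - b1)(1 - b2), ... along the last dim.
    remaining = (1. - beta).cumprod(-1)
    # Pad beta with a trailing 1 and the running product with a leading 1, so
    # entry t is beta_t * prod_{s<t}(1 - beta_s) and the last entry is the leftover stick.
    return F.pad(beta, (0, 1), value=1.) * F.pad(remaining, (1, 0), value=1.)

print(mix_weights_vectorized(torch.tensor([0.5, 0.5])))  # tensor([0.5000, 0.2500, 0.2500])
```

Because cumprod and F.pad act on the last dimension, the vectorized variant also accepts a beta with extra leading batch dimensions. The loop version does not.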

Output:

Warmup:   0%|          | 0/800 [00:00, ?it/s]
0
0
0
0
0
0
1
1
0
1
0
1
0
0
1
0
1
0
0
0
0
0
0
0
0
0
0
1
0
0
1
0
1
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
0
0
0
0
1
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
0
0
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
1
0
0
1
0
0
0
0
0
1
0
0
0
1
0
1
1
0
1
0
0
0
1
0
0
0
0
0
1
0
0
0
0
0
0
1
1
0
0
0
0
0
0
0
1
1
0
1
0
0
0
0
0
0
0
0
0
1
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
[[0]
 [1]]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    142             try:
--> 143                 ret = self.fn(*args, **kwargs)
    144             except (ValueError, RuntimeError):

16 frames
ValueError: only one element tensors can be converted to Python scalars

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-19-63983afdb4c3> in model(data)
     22       z = pyro.sample("z_{}".format(i), dist.Categorical(mix_weights(beta)))
     23       print (z.numpy())
---> 24       pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mu[z.item()], torch.eye(2)), obs=data[i])
     25 
     26 nuts_kernel = NUTS(model, adapt_step_size=True)

ValueError: only one element tensors can be converted to Python scalars
  Trace Shapes:        
   Param Sites:        
  Sample Sites:        
beta_plate dist     |  
          value   1 |  
      beta dist   1 |  
          value   1 |  
  mu_plate dist     |  
          value   2 |  
        mu dist   2 | 2
          value   2 | 2
       z_0 dist     |  
          value 2 1 |  
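One possible workaround, assuming that the two-element z comes from parallel enumeration as the z_0 row above suggests, is to index mu with the tensor itself (mu[z]) rather than with mu[z.item()]. For an ordinary draw the two are equivalent. For an enumerated z of shape (2, 1), mu[z] has shape (2, 1, 2), and the observation log-probability comes out with shape (2, 1), one entry per candidate component. This is the kind of broadcastable shape that enumeration relies on. The sketch uses only torch, with made-up means and a made-up data point, and does not run Pyro. Pyro's mixture-model examples also generally vectorize the per-datum loop under a pyro.plate. That may be needed here as well, because 200 separately enumerated z_i sites may each require their own tensor dimension.

```python
import torch
from torch.distributions import MultivariateNormal

mu = torch.tensor([[-5., -5.], [5., 5.]])  # made-up component means, shape (T, 2)
x = torch.tensor([4.8, 5.3])               # one made-up data point, shape (2,)

# Ordinary draw: mu[z] selects one row, exactly like mu[z.item()].
z_draw = torch.tensor(1)
loc_draw = mu[z_draw]                      # shape (2,)

# Hypothetical enumerated z of shape (2, 1): mu[z] gathers one mean per
# candidate value and keeps the enumeration dim, giving shape (2, 1, 2).
z_enum = torch.arange(2).reshape(2, 1)
loc_enum = mu[z_enum]

# The identity covariance broadcasts against the batched loc.
lp_draw = MultivariateNormal(loc_draw, torch.eye(2)).log_prob(x)  # 0-dim
lp_enum = MultivariateNormal(loc_enum, torch.eye(2)).log_prob(x)  # shape (2, 1)
print(loc_enum.shape, lp_enum.shape)
```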

About this issue

  • State: closed
  • Created 5 years ago
  • Comments: 20 (8 by maintainers)

Most upvoted comments

It does have continuous variables as well. I’ll create a small example program to illustrate the problem and open a separate issue. Thanks!

Thanks @fehiepsi, this worked perfectly.