pyro: Memory leak while using Beta distribution

Issue Description

I’m training a simple feedforward neural network that outputs the parameters of a Beta distribution. I draw samples from that distribution and pass them to BCELoss. The problem is that host memory (RAM) usage grows with the number of training iterations until the machine starts thrashing, leaving me no choice but to force a shutdown (code below). GPU usage stays stable.
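One way to put numbers on the growth described above is to log the peak resident memory of the process every few iterations. This is a minimal sketch using only the standard library. The names `peak_rss_mb`, `log_memory_growth`, and `step_fn` are hypothetical, and it assumes Linux, where `ru_maxrss` is reported in kilobytes.

```python
import resource


def peak_rss_mb():
    """Peak resident set size of this process in MB (assumes Linux: ru_maxrss is in KB)."""
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0


def log_memory_growth(step_fn, num_iters=20, every=5):
    """Call step_fn num_iters times and record (iteration, peak RSS) every `every` steps."""
    readings = []
    for i in range(num_iters):
        step_fn()  # hypothetical: one forward/backward/optimizer step
        if i % every == 0:
            readings.append((i, peak_rss_mb()))
    return readings


# Stand-in step that does nothing; replace with one real training iteration.
readings = log_memory_growth(lambda: None, num_iters=10, every=5)
for iteration, mb in readings:
    print("iter %d: peak RSS %.1f MB" % (iteration, mb))
```

A peak that keeps climbing across readings for the Beta model but not for the baseline would support the claim that the sampling path is responsible.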

Environment

  • OS and Python version: Ubuntu 14.04, Python 2.7
  • PyTorch version: torch==0.4.1
  • Pyro version: 0.2.1

Code Snippet

Here is the network architecture:

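from pyro.distributions.torch import Beta
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
# imports used by the training loop further down
import os
from torch.optim import Adam
from torch.utils.data import DataLoader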

class WNetBeta(nn.Module):
	def __init__(self, in_ch, out_ch, C0=32):
		super(WNetBeta, self).__init__()
		self.conv0 = conv(in_ch, C0, (1, 3, 3), padding=(0, 1, 1))

		self.resblock1 = ResBlock(C0, C0)
		self.resblock2 = ResBlock(C0, C0)
		self.tconv1 = tconv(C0, C0)
		
		self.avgsample = nn.AvgPool3d((1, 2, 2), stride=(1, 2, 2))

		self.resblock3 = ResBlock(C0, C0)
		self.resblock4 = ResBlock(C0, C0)

		self.tconv2 = tconv(C0, C0)

		self.pconv1 = pconv(C0, out_ch, (1, 3, 3), padding=(0, 1, 1))

	def forward(self, x):
		data = self.conv0(x)
		data = self.resblock1(data)
		data = self.resblock2(data)
		data = self.tconv1(data)
		data = self.avgsample(data)
		
		data = self.resblock3(data)
		data = self.resblock4(data)
		data = self.tconv2(data)
		out = self.pconv1(data)

		## get alpha, beta
		logalpha = out[:, 0]
		logbeta  = out[:, 1]

		alpha = torch.exp(logalpha)
		beta  = torch.exp(logbeta)

		# Get samples now
		T = 10
		alphar = alpha.unsqueeze(1).repeat(1, T, 1, 1, 1, 1)
		betar  =  beta.unsqueeze(1).repeat(1, T, 1, 1, 1, 1)
		m = Beta(alphar, betar)
		p = m.rsample().mean(1)

		return p, logalpha, logbeta

And here is my training loop:

	model = WNetBeta(4, 2).cuda()
	# define optimizer
	optim = Adam(model.parameters(), lr=2e-5, weight_decay=1e-7)
	#optim = SGD(model.parameters(), lr=1e-5, momentum=0.3)
	if os.path.exists(optim_name):
		optim.load_state_dict(torch.load(optim_name))

	# dataloader
	train_ds = Loader('train')
	train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)
	val_ds = Loader('val')
	val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0)

	# loss function
	loss_fn = nn.BCELoss() 

	# training loop
	for i in range(EPOCHS):
		model.train()
		for j, (patch, whole_label) in enumerate(train_loader):
			# get patch
			optim.zero_grad()
			patch = Variable(patch).cuda()
			whole_label = Variable(whole_label).cuda()

			p, log_alpha, log_beta = model(patch)

			loss_val = loss_fn(p, whole_label)
			loss_val.backward()
			# Step
			optim.step()

Thank you in advance.

Edit: I’m not posting the definitions of the other modules (ResBlock, pconv, etc.) because they are just ordinary convolution layers with different activations, so I don’t think they are relevant and they would only clutter things up.

This problem doesn’t occur with a “baseline” that simply outputs a logit trained with BCELoss, so a faulty data loader is unlikely to be the cause of the leak.

About this issue

  • State: closed
  • Created 6 years ago
  • Comments: 17 (7 by maintainers)

Most upvoted comments

@rohitrango Just to clarify what might be going on: Pyro’s Beta doesn’t actually implement GPU reparameterization. Instead, import pyro.distributions triggers a patch to PyTorch’s underlying gradients that does a GPU->CPU->GPU conversion for that computation. I’m curious whether you get the same effect by using the PyTorch Beta but including import pyro.distributions.torch_patch at the top of your script.
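The experiment suggested above could look roughly like the following. This is a sketch, not a confirmed reproduction: it assumes Pyro is installed (the import is guarded so the script still runs without it), the tensor shapes are made up, and it falls back to the CPU when no GPU is available.

```python
import torch
from torch.distributions import Beta

# Assumption: Pyro is installed. Importing this module applies Pyro's patches
# to PyTorch as a side effect; nothing from it is referenced directly.
try:
    import pyro.distributions.torch_patch  # noqa: F401
    patched = True
except ImportError:
    patched = False  # the sketch still runs, just without the patch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Made-up parameters standing in for the network outputs.
log_alpha = torch.zeros(4, device=device, requires_grad=True)
log_beta = torch.zeros(4, device=device, requires_grad=True)

# Plain PyTorch Beta instead of pyro.distributions.torch.Beta.
m = Beta(log_alpha.exp(), log_beta.exp())
sample = m.rsample()       # reparameterized sample, shape (4,)
sample.sum().backward()    # gradients flow back to log_alpha and log_beta

print("patch applied:", patched)
```

If memory grows with the patch imported but not without it, that would point at the GPU->CPU->GPU path rather than at the Beta wrapper itself.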

@rohitrango Can you create a smaller example to replicate the problem? Here are my suggestions:

  • Try training for just a few iterations (to avoid the shutdown).
  • Try not to use repeat before sampling. I don’t understand why you need repeat here (instead of passing sample_shape to rsample), and I have seen many problems with repeat/expand + log_prob in PyTorch 0.4.1.
  • Try calling .contiguous() at various places.
  • Try PyTorch 1.0rc and the Pyro dev version.
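The second suggestion above, drawing several samples through `sample_shape` rather than repeating the parameters, can be sketched as follows. The helper name `mean_beta_sample` and the tensor shapes are hypothetical, and plain `torch.distributions.Beta` on the CPU is used to keep the example self-contained. Note that `rsample` puts the sample dimension first, so the mean is taken over dimension 0 rather than 1.

```python
import torch
from torch.distributions import Beta


def mean_beta_sample(alpha, beta, num_samples=10):
    """Average num_samples reparameterized Beta draws without repeating the parameters."""
    m = Beta(alpha, beta)
    # Result shape is (num_samples,) + alpha.shape; no repeat()/expand() needed.
    samples = m.rsample((num_samples,))
    return samples.mean(0)


torch.manual_seed(0)
# Made-up stand-ins for the network's log-parameter outputs.
log_alpha = torch.zeros(1, 3, 8, 8, requires_grad=True)
log_beta = torch.zeros(1, 3, 8, 8, requires_grad=True)

p = mean_beta_sample(log_alpha.exp(), log_beta.exp(), num_samples=10)
p.sum().backward()  # gradients reach log_alpha and log_beta

print(p.shape)
```

Because the parameters are never tiled, the autograd graph holds one copy of alpha and beta instead of T copies, which also makes it easier to tell whether repeat is involved in the leak.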

@rohitrango Can you try reproducing your issue with PyTorch 0.4.0?