PySyft: Error in Federated Learning when Batch Normalization layer is used
Hi, I am trying to replicate the tutorial (Part 06 - Federated Learning on MNIST using a CNN.ipynb) with the CIFAR10 dataset and a network model I defined myself. However, it fails with the following error when the nn.BatchNorm2d() layer is executed:
RuntimeError: running_mean should contain 3 elements not [32]
If I don't use batch normalization layers in my model, it works fine. If I use batch normalization layers without any federated learning, it also works perfectly. I need batch normalization to get good accuracy, and I also need this example to work with ResNet (which also contains batch normalization layers), so I am not sure how to work around this.
I am using a Google Colab notebook with the installation example you gave (and I also tested on my own computer with Python 3.6 and PyTorch 1.1.0).
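I believe the failure can be reproduced with a much shorter snippet along these lines (shapes and worker names are arbitrary, and I have only run the full example further below):
# Untested minimal sketch: send a BatchNorm2d layer and a matching pointer tensor
# to a VirtualWorker and call the layer, which I expect to hit the same RuntimeError.
import torch
import torch.nn as nn
import syft as sy

hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob")

bn = nn.BatchNorm2d(32)
bn.send(bob)  # send the layer to the remote worker
x_ptr = torch.randn(4, 32, 16, 16).send(bob)  # a 32-channel batch as a pointer tensor
out = bn(x_ptr)  # expected to fail with "running_mean should contain 3 elements ..."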
The full example code and the related error are below:
from __future__ import print_function
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.__version__
from torchvision import datasets, transforms
from torch.autograd import Variable
import syft as sy # <-- NEW: import the Pysyft library
hook = sy.TorchHook(torch) # <-- NEW: hook PyTorch ie add extra functionalities to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob") # <-- NEW: define remote worker bob
alice = sy.VirtualWorker(hook, id="alice") # <-- NEW: and alice
class Arguments():
def __init__(self):
self.batch_size = 64
self.test_batch_size = 1000
self.epochs = 10
self.lr = 0.01
self.momentum = 0.5
self.no_cuda = False
self.seed = 1
self.log_interval = 30
self.save_model = False
args = Arguments()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))
.federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
class Net(nn.Module):
def __init__(self, dropout=0.0):
super(Net, self).__init__()
self.dropout = dropout
self.conv_layer = nn.Sequential(
# Conv Layer block 1
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# Conv Layer block 2
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Dropout2d(p=0.05), # 0.05
# Conv Layer block 3
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fc_layer = nn.Sequential(
nn.Dropout(p=0.1), # 0.1
nn.Linear(4096, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 512),
nn.ReLU(inplace=True),
nn.Dropout(p=0.1), # 0.1
nn.Linear(512, 10)
)
def forward(self, x):
x = self.conv_layer(x)
x = x.view(x.size(0), -1)
x = self.fc_layer(x)
return x
def train(args, model, device, federated_train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(federated_train_loader): # <-- now it is a distributed dataset
model.send(data.location) # <-- NEW: send the model to the right location
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
model.get() # <-- NEW: get the model back
if batch_idx % args.log_interval == 0:
loss = loss.get() # <-- NEW: get the loss back
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
100. * batch_idx / len(federated_train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment
for epoch in range(1, args.epochs + 1):
train(args, model, device, federated_train_loader, optimizer, epoch)
test(args, model, device, test_loader)
if (args.save_model):
torch.save(model.state_dict(), "CIFAR10_cnn.pt")
The complete output (from google colab):
WARNING: Logging before flag parsing goes to stderr.
W0813 08:46:34.190028 140520788572032 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/usr/local/lib/python3.6/dist-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'
W0813 08:46:34.211576 140520788572032 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.
0it [00:00, ?it/s]
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz
170500096it [00:04, 40609204.86it/s]
---------------------------------------------------------------------------
PureTorchTensorFoundError Traceback (most recent call last)
/content/PySyft/syft/frameworks/torch/tensors/interpreters/native.py in handle_func_command(cls, command)
300 new_args, new_kwargs, new_type, args_type = syft.frameworks.torch.hook_args.unwrap_args_from_function(
--> 301 cmd, args, kwargs, return_args_type=True
302 )
16 frames
PureTorchTensorFoundError:
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
1695 return torch.batch_norm(
1696 input, weight, bias, running_mean, running_var,
-> 1697 training, momentum, eps, torch.backends.cudnn.enabled
1698 )
1699
RuntimeError: running_mean should contain 3 elements not 32
One possible workaround for this problem is to divide the net and keep the batch normalization layers on the client side (a rough sketch of what I mean follows below). However, I do not want to rewrite large ResNet architectures and make the code more complex.
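Roughly, the split would look something like this (purely illustrative class names and layer sizes, not code I actually run); only the BatchNorm-free part would ever be sent to the workers, and the extra data movement this implies in the training loop is exactly the complexity I would like to avoid:
# Illustrative sketch only: divide the network so that BatchNorm stays local.
class RemotePart(nn.Module):
    """Convolution blocks without BatchNorm; this part could be federated."""
    def __init__(self):
        super(RemotePart, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
    def forward(self, x):
        return self.features(x)

class LocalPart(nn.Module):
    """BatchNorm and the classifier head, kept on the client side."""
    def __init__(self):
        super(LocalPart, self).__init__()
        self.bn = nn.BatchNorm2d(32)
        self.fc = nn.Linear(32 * 16 * 16, 10)
    def forward(self, x):
        x = self.bn(x)
        return self.fc(x.view(x.size(0), -1))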
I hope you can help me with this issue.
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Comments: 17 (5 by maintainers)
@bussfromspace @karlhigley This issue seems to be solved. Can you close it?
@ratmcu I think my syft version was an old one, and I updated it to the new version, which uses PyTorch 1.4.0. This issue is solved. Thanks for helping!
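For anyone else who lands on this issue, a quick way to confirm you are past the affected versions before retrying (the exact version numbers may differ on your side):
# Sanity check: print the installed versions. Upgrading syft (e.g. pip install -U syft)
# to a release built against PyTorch 1.4.0 is what resolved the error above.
import syft
import torch
print("syft:", syft.__version__)
print("torch:", torch.__version__)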