PySyft: Adam causes errors in training loop
[Related to #1909]
I tracked the error down to `optimizer = optim.Adam(model.parameters(), lr=1e-3)` in my code; `optimizer = optim.SGD(model.parameters(), lr=1e-3)` works fine.
My guess is that this is related to my using PyTorch 1.0.
The code can be found at https://github.com/2fasc/Distributed_Malaria/blob/master/src/federated_training.py
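For context, here is a minimal sketch of the kind of federated loop that triggers this. It is not the linked script: the toy data, model, and worker names are placeholders, and it assumes the PySyft 0.1.x-era `TorchHook`/`VirtualWorker` API. Running it with `optim.Adam` produces the traceback below, while `optim.SGD` runs fine.

```python
import torch
from torch import nn, optim
import torch.nn.functional as F
import syft as sy

hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")

# Toy data split across two virtual workers.
data_bob = torch.randn(32, 2).send(bob)
target_bob = torch.randint(0, 2, (32,)).send(bob)
data_alice = torch.randn(32, 2).send(alice)
target_alice = torch.randint(0, 2, (32,)).send(alice)

model = nn.Linear(2, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # swapping in optim.SGD works

for x, y in [(data_bob, target_bob), (data_alice, target_alice)]:
    model.send(x.location)           # move the model to the worker holding this batch
    optimizer.zero_grad()
    loss = F.nll_loss(F.log_softmax(model(x), dim=1), y)
    loss.backward()
    optimizer.step()                 # with Adam, the TypeError below is raised here
    model.get()                      # bring the updated model back
```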
Traceback (most recent call last):
  File "/home/fasc/Documents/Distributed_Malaria/src/federated_training.py", line 143, in <module>
    simple_federated_model()
  File "/home/fasc/Documents/Distributed_Malaria/src/federated_training.py", line 66, in simple_federated_model
    federated=True,
  File "/home/fasc/Documents/Distributed_Malaria/src/auxiliaries.py", line 111, in train
    optimizer.step()
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/torch/optim/adam.py", line 94, in step
    exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/frameworks/torch/hook.py", line 650, in overloaded_native_method
    response = method(*new_args, **new_kwargs)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/frameworks/torch/hook.py", line 486, in overloaded_pointer_method
    response = owner.send_command(location, command)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/workers/base.py", line 364, in send_command
    _ = self.send_msg(MSGTYPE.CMD, message, location=recipient)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/workers/base.py", line 198, in send_msg
    bin_response = self._send_msg(bin_message, location)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/workers/virtual.py", line 6, in _send_msg
    return location._recv_msg(message)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/workers/virtual.py", line 9, in _recv_msg
    return self.recv_msg(message)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/workers/base.py", line 229, in recv_msg
    response = self._message_router[msg_type](contents)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/workers/base.py", line 316, in execute_command
    getattr(_self, command_name)(*args, **kwargs)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/frameworks/torch/hook.py", line 636, in overloaded_native_method
    raise route_method_exception(e, self, args, kwargs)
  File "/home/fasc/miniconda3/envs/malaria/lib/python3.6/site-packages/syft/frameworks/torch/hook.py", line 630, in overloaded_native_method
    response = method(*args, **kwargs)
TypeError: addcmul_() takes 2 positional arguments but 3 were given
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Comments: 21 (12 by maintainers)
@LaRiffle I've experienced the increasing-loss issue with Adam as well.
But since it wasn't raising errors, I tried to see where things went wrong numerically.
Somehow, in Adam's `step()` method, replacing the `addcmul_()` call and the `addcdiv_()` call with equivalent element-wise operations made it work, and the loss decreased with no problem.
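For reference, here is a sketch of the kind of step-by-step rewrite described, based on the `addcmul_()`/`addcdiv_()` calls in torch 1.0's `adam.py`. The exact lines the comment linked to are not preserved here, so treat this as a reconstruction rather than the commenter's actual patch; the tensors and hyperparameters are stand-ins.

```python
import torch

# Stand-ins for the tensors inside Adam's step(); shapes and values are placeholders.
p = torch.nn.Parameter(torch.randn(3))
grad = torch.randn(3)
exp_avg = torch.zeros(3)
exp_avg_sq = torch.zeros(3)
beta1, beta2, eps, step_size = 0.9, 0.999, 1e-8, 1e-3

# torch 1.0's adam.py uses fused in-place ops, which are what fail through the PySyft hook:
#   exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
#   p.data.addcdiv_(-step_size, exp_avg, denom)

# The same updates written out "step by step" with plain element-wise ops:
exp_avg.mul_(beta1).add_((1 - beta1) * grad)
exp_avg_sq.mul_(beta2).add_((1 - beta2) * grad * grad)
denom = exp_avg_sq.sqrt().add_(eps)
p.data.add_(-step_size * exp_avg / denom)
```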
So, it seems like the `addcmul_()` and `addcdiv_()` functions have issues working with PySyft, but doing what those functions do step by step somehow fixes them.

Hey @2fasc - it seems like the solution here is to have different optimizers for different machines (one for bob, another for alice, etc.). Eventually we'll write custom Federated optimizers, but for now this is the solution 😃
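A minimal sketch of that per-worker workaround, reusing the toy bob/alice setup from the sketch above (all names are placeholders, and this is an illustration of the suggestion rather than an official recipe): each worker gets its own model copy and its own Adam instance, so optimizer state never has to mix locations.

```python
# One model copy and one Adam instance per worker, so each optimizer's state
# (exp_avg, exp_avg_sq) stays on the same worker as the gradients it corrects.
workers = [bob, alice]
models = {w.id: nn.Linear(2, 2).send(w) for w in workers}
optimizers = {w.id: optim.Adam(models[w.id].parameters(), lr=1e-3) for w in workers}

for x, y in [(data_bob, target_bob), (data_alice, target_alice)]:
    wid = x.location.id
    model, optimizer = models[wid], optimizers[wid]
    optimizer.zero_grad()
    loss = F.nll_loss(F.log_softmax(model(x), dim=1), y)
    loss.backward()
    optimizer.step()   # optimizer state and gradients now share a location
```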
Just a small report, I’m running into a similar error, but it appears a few lines earlier:
The following error appears:
So it seems to be a more general problem than just specific functions like `addcmul_()` and `addcdiv_()`.

Ok, here is the trouble: Adam uses momentum (more precisely, running first and second moments of the gradients), which means it keeps running averages of past gradients and, for each batch, corrects the current gradient based on the old ones. When the batch owner changes (as in Part 8 of the tutorial, in the middle of an epoch), you now have gradients from alice that you want to correct with moments built from old gradients owned by bob. This is not possible, because the data needs to be at the same location, so it raises an error that is a bit tricky to track down. This is also why momentum is not supported so far on SGD.
The fix for this would probably require rewriting the optimizers. This is an important project, and it could perhaps be tied to the notion of an aggregator needed for Federated or Secure Averaging.
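For illustration, here is a minimal sketch of the aggregation step such a rewrite would revolve around: plain parameter averaging after each worker has trained its own copy. This is an assumption about the intended design, not an existing PySyft API, and the function name is hypothetical.

```python
import torch

def average_models(global_model, worker_models):
    """Naive federated-averaging aggregator: write the mean of the workers'
    parameters into the global model. Assumes the worker models have already
    been retrieved with .get(), so every tensor involved is local."""
    with torch.no_grad():
        for name, param in global_model.named_parameters():
            stacked = torch.stack([dict(m.named_parameters())[name].data for m in worker_models])
            param.copy_(stacked.mean(dim=0))
    return global_model
```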
Thank you for reporting the error!