keras-core: Possible bug: Torch backend backpropagation
I was attempting to replicate the demo_mnist_convnet example by replacing `keras_core.layers` with `torch.nn.Module`s inside a `keras_core.Model`. Although training with `model.fit` doesn't encounter any error, it seems that the model is not learning anything.

I further attempted to override the `train_step` function according to the custom_train_step_in_torch guide, but that did not fix the issue either. I also tried to explicitly call `torch.nn.Module.zero_grad(self)` as recommended in the same guide, without any effect.
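A minimal sketch of the kind of setup described, assuming the torch backend is active; the class, shapes, and layer choices are illustrative, not the original reproduction. Because Keras only tracks `Layer` attributes, the torch parameters are invisible to the optimizer, which is consistent with training running without error while the loss never improves:

```python
import torch
import keras_core

class TorchMLP(keras_core.Model):
    def __init__(self):
        super().__init__()
        # torch modules assigned directly as attributes; Keras does not
        # track their parameters the way it tracks Layer weights
        self.fc1 = torch.nn.Linear(784, 64)
        self.fc2 = torch.nn.Linear(64, 10)

    def call(self, x):
        x = torch.relu(self.fc1(x))
        return torch.softmax(self.fc2(x), dim=-1)

model = TorchMLP()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# model.fit(x_train, y_train) runs without error, but the optimizer has
# no variables to update, since the torch parameters go untracked:
print(len(model.trainable_weights))  # expected: 0
```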
About this issue
- State: closed
- Created a year ago
- Reactions: 4
- Comments: 20 (19 by maintainers)
Would you like to explore implementing a prototype of a `TorchModuleWrapper`? If it's simple, we can auto-wrap any module upon assignment to a Keras Layer with the torch backend.

Hi @fchollet, here's the `TorchModuleWrapper` modified as per your feedback (with some generous help from @ariG23498). Some observations with respect to the following:
It's probably not possible to build the parameters for lazy torch modules except by doing an eager `_ = self.module(sample_input)`, since they don't initialize their parameters unless there's a forward pass. This is what the official docs mention… Would love to know your thoughts on this.

PS: Apologies for the delayed response.
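For reference, a small runnable illustration of the lazy-module behavior described above, using `torch.nn.LazyLinear`:

```python
import torch
import torch.nn as nn

lazy = nn.LazyLinear(out_features=4)   # in_features inferred on first call
print(lazy.has_uninitialized_params())  # True
print(type(lazy.weight))  # torch.nn.parameter.UninitializedParameter

# Parameters only materialize after an eager forward pass:
_ = lazy(torch.zeros(1, 8))
print(lazy.weight.shape)  # torch.Size([4, 8])
print(type(lazy))         # torch.nn.Linear (the lazy class replaces itself)
```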
This is not intended to be supported at this time. Only `Layer` objects are tracked by Keras, not torch Modules.

@soumik12345 thanks – please open a PR and let's move the discussion to the PR. This should live in `backend/torch/torch_module_wrapper.py`. There is some light refactoring we will need.

Thanks for the analysis! Is there a way to merge both classes into a single one?
Instead of doing an eager `_ = self.module(sample_input)`, we should leverage `backend.compute_output_spec`, which will be more efficient.

Sure, let's add it. First, please investigate the case where not all module parameters are created upon instantiation (lazy modules). When that works, we can start a PR.
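One way to investigate the lazy case is to materialize the parameters inside the wrapper's `build()`, which receives the input shape. This is only a sketch of the eager approach discussed above; `_track_module_parameters` is a hypothetical helper, not keras_core API:

```python
import torch

def build(self, input_shape):
    # Handle lazy modules (e.g. nn.LazyLinear): their parameters only
    # exist after a forward pass, so run one with zeros of the right shape.
    if isinstance(self.module, torch.nn.modules.lazy.LazyModuleMixin):
        if self.module.has_uninitialized_params():
            sample_input = torch.zeros((1,) + tuple(input_shape[1:]))
            _ = self.module(sample_input)
    # hypothetical helper that registers module.parameters() with Keras
    self._track_module_parameters()
```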
@soumik12345 I have it working with a couple of simple changes: setting `self.built = True`, since torch module weights are already created (mind you, I guess this won't be the case with lazy modules; I have no idea how those work). https://colab.research.google.com/drive/1gG93Fb03Ef-77suS6b7PqgWtMOopNyyF?usp=sharing
Update – actually there is a problem with variable tracking. It doesn't yet work.

Update – fixed it by setting `variable._value` directly. It's training now.

This would be a very welcome feature: just being able to use `.fit()`, `.evaluate()`, etc. on, for example, all the models in https://docs.monai.io/en/stable/networks.html
It’s probably going to be something like this (warning: entirely untested):
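A hedged sketch of what such a wrapper might look like, based on the fixes described elsewhere in this thread (`self.built = True`, setting `variable._value` directly). The `_track_variable` call and the `Variable(initializer=...)` signature are assumptions about keras_core internals, not confirmed API:

```python
import torch
from keras_core import Variable
from keras_core.layers import Layer

class TorchModuleWrapper(Layer):
    """Wraps a torch.nn.Module so Keras tracks its parameters."""

    def __init__(self, module, **kwargs):
        super().__init__(**kwargs)
        self.module = module
        self._track_module_parameters()

    def _track_module_parameters(self):
        for param in self.module.parameters():
            # Register each torch parameter as a Keras variable so the
            # optimizer and layer.weights can see it.
            variable = Variable(
                initializer=param, trainable=param.requires_grad
            )
            # Point the Keras variable at the live torch tensor so
            # gradients flow to the module's own parameters.
            variable._value = param
            self._track_variable(variable)
        # torch modules create their weights at instantiation
        # (lazy modules excepted), so mark the layer as built.
        self.built = True

    def call(self, *args, **kwargs):
        return self.module(*args, **kwargs)
```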
Yes! This is something I would love to explore. This would be a valuable feature, making `keras_core` compatible with likely the entirety of the PyTorch ecosystem.

It's possible, but you'd need to wrap the module in a layer in a way that makes it track the underlying `parameters()`. We haven't implemented anything like this yet. But it's likely simple.
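For illustration, a hypothetical usage sketch once such a wrapper exists, dropping torch modules into an otherwise ordinary keras_core model (using the `TorchModuleWrapper` sketched above):

```python
import torch
import keras_core

model = keras_core.Sequential([
    keras_core.Input(shape=(784,)),
    TorchModuleWrapper(torch.nn.Linear(784, 64)),
    keras_core.layers.ReLU(),
    TorchModuleWrapper(torch.nn.Linear(64, 10)),
])
model.compile(
    optimizer="adam",
    # logits output, so compute the softmax inside the loss
    loss=keras_core.losses.SparseCategoricalCrossentropy(from_logits=True),
)
# model.fit(x_train, y_train, epochs=1)  # parameters now update as expected
```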