keras-core: Possible bug: Torch backend backpropagation
I was attempting to replicate the demo_mnist_convnet example, replacing the `keras_core.layers` with `torch.nn.Module` instances inside a `keras_core.Model`. Training with `model.fit` runs without error, but the model does not appear to learn anything.
I then tried overriding the `train_step` function as described in the custom_train_step_in_torch guide, but that did not fix the issue. I also tried explicitly calling `torch.nn.Module.zero_grad(self)`, as recommended in the same guide, to no effect.
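A simplified, hypothetical reproduction of the setup described here (a small MLP rather than the full convnet; all names and shapes are illustrative, not the issue author's code): plain `torch.nn` modules assigned inside a `keras_core.Model`, where `fit()` runs but the torch parameters never reach the optimizer.

```python
import torch
import keras_core

# Hypothetical reproduction sketch: torch.nn modules used in place of
# keras_core.layers inside a keras_core.Model (torch backend).
class TorchMLP(keras_core.Model):
    def __init__(self):
        super().__init__()
        # Plain torch modules are not Keras Layers, so Keras does not track
        # their parameters -- the optimizer ends up with nothing to update.
        self.fc1 = torch.nn.Linear(784, 64)
        self.fc2 = torch.nn.Linear(64, 10)

    def call(self, x):
        x = torch.relu(self.fc1(x))
        return torch.softmax(self.fc2(x), dim=-1)

model = TorchMLP()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# model.fit(x_train, y_train, epochs=1)  # runs without error, but the loss does not improve
```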
About this issue
- Original URL
- State: closed
- Created a year ago
- Reactions: 4
- Comments: 20 (19 by maintainers)
Would you like to explore implementing a prototype of a `TorchModuleWrapper`? If it's simple, we can auto-wrap any module upon assignment to a Keras Layer with the torch backend.

Hi @fchollet Here's the `TorchModuleWrapper` modified as per your feedback (with some generous help from @ariG23498). Some observations with respect to the following:
It's probably not possible to build the parameters for lazy torch modules except by doing an eager `_ = self.module(sample_input)`, since they don't initialize their parameters unless there's a forward pass. This is what the official docs mention… Would love to know your thoughts on this.
PS: Apologies for the delayed response.
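For reference, a minimal illustration of the lazy-module behaviour described above: a `torch.nn.LazyLinear` keeps its parameters uninitialized until a real forward pass runs, which is why an eager call is needed to build them.

```python
import torch

# A LazyLinear has no materialized weights until it sees real data.
lazy = torch.nn.LazyLinear(out_features=8)
print(list(lazy.parameters()))   # UninitializedParameter objects, no shapes yet

_ = lazy(torch.randn(2, 16))     # eager forward pass materializes the weights
print([p.shape for p in lazy.parameters()])  # [torch.Size([8, 16]), torch.Size([8])]
```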
This is not intended to be supported at this time. Only `Layer` objects are tracked by Keras, not torch Modules.

@soumik12345 thanks – please open a PR and let's move the discussion to the PR. This should live in `backend/torch/torch_module_wrapper.py`. There is some light refactoring we will need.

Thanks for the analysis! Is there a way to merge both classes into a single one?
Instead of doing an eager `_ = self.module(sample_input)`, we should leverage `backend.compute_output_spec`, which will be more efficient.

Sure, let's add it. First, please investigate the case where not all module parameters are created upon instantiation (lazy modules). When that works, we can start a PR.
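One way a wrapper's `build()` could handle the lazy-module case flagged above, as a hedged sketch (the class name and structure are illustrative, not the proposed implementation):

```python
import torch
from keras_core import layers

class LazyAwareWrapper(layers.Layer):
    """Illustrative only: run one eager forward pass for lazy torch modules."""

    def __init__(self, module, **kwargs):
        super().__init__(**kwargs)
        self.module = module

    def build(self, input_shape):
        has_lazy_params = any(
            isinstance(p, torch.nn.parameter.UninitializedParameter)
            for p in self.module.parameters()
        )
        if has_lazy_params:
            # Lazy modules only materialize their weights on a forward pass,
            # so trigger one with a dummy batch before Keras tracks them.
            sample_input = torch.zeros((1, *input_shape[1:]))
            _ = self.module(sample_input)
        self.built = True
```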
@soumik12345 I have it working with a couple of simple changes: `self.built = True`, since torch module weights are already created (mind you, I guess this won't be the case with lazy modules, I have no idea how those work). https://colab.research.google.com/drive/1gG93Fb03Ef-77suS6b7PqgWtMOopNyyF?usp=sharing
Update – actually there is a problem with variable tracking. It doesn’t yet work.
Update – fixed it by setting `variable._value` directly. It's training now.

Would be a very welcome feature, just being able to use `.fit()`, `.evaluate()` etc. on, for example, all the models in https://docs.monai.io/en/stable/networks.html
It’s probably going to be something like this (warning: entirely untested):
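A minimal sketch of what such a wrapper might look like (hypothetical, not the author's snippet and not the implementation that was eventually merged): a `Layer` subclass that registers the wrapped module's `parameters()` as trainable Keras variables, reusing the `variable._value` idea from earlier in the thread.

```python
import torch
import keras_core
from keras_core import layers


class TorchModuleWrapper(layers.Layer):
    """Hypothetical sketch only -- not the snippet from the thread."""

    def __init__(self, module, **kwargs):
        super().__init__(**kwargs)
        self.module = module
        # Torch modules (lazy ones aside) create their weights at construction
        # time, so the layer can be considered built immediately.
        self.built = True
        self._track_module_parameters()

    def _track_module_parameters(self):
        # Expose each torch parameter as a Keras variable so that
        # `layer.weights` and the optimizer can see it. Assigning `_value`
        # mirrors the internal trick discussed above and may not match the
        # final API.
        self.torch_params = []
        for param in self.module.parameters():
            variable = keras_core.Variable(
                initializer=param.detach().cpu().numpy(),
                trainable=param.requires_grad,
            )
            variable._value = param
            self.torch_params.append(variable)

    def call(self, inputs):
        return self.module(inputs)
```

As noted in the comments above, this assumes the module's weights already exist when the wrapper is constructed; lazy modules would still need the eager forward pass sketched earlier.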
Yes! This is something I would love to explore.
This would be a valuable feature, making `keras_core` compatible with likely the entirety of the PyTorch ecosystem.

It's possible, but you'd need to wrap the module in a layer in a way that makes it track the underlying `parameters()`. We haven't implemented anything like this yet. But it's likely simple.
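Hypothetical usage of such a wrapper, in the spirit of the requests above: wrap an arbitrary torch network (for example a MONAI model) and drive it with `compile()` / `fit()` / `evaluate()`. This assumes the `TorchModuleWrapper` sketch from the earlier comment; the stand-in network and data names are illustrative.

```python
import torch
import keras_core

# Hypothetical: any torch.nn.Module -- here a stand-in Sequential instead of a
# real MONAI network -- wrapped so Keras training loops can drive it.
torch_net = torch.nn.Sequential(
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
    torch.nn.Softmax(dim=-1),
)

class WrappedClassifier(keras_core.Model):
    def __init__(self, module):
        super().__init__()
        self.backbone = TorchModuleWrapper(module)  # wrapper from the sketch above (hypothetical)

    def call(self, x):
        return self.backbone(x)

model = WrappedClassifier(torch_net)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# model.fit(x_train, y_train, epochs=2)
# model.evaluate(x_test, y_test)
```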