keras-core: Possible bug: Torch backend backpropagation

I was attempting to replicate the demo_mnist_convnet example by replacing keras_core.layers with torch.nn.Module inside a keras_core.Model. Although training with model.fit runs without any errors, the model does not appear to learn anything.

I also attempted to override the train_step function following the custom_train_step_in_torch guide, but that did not fix the issue either. Explicitly calling torch.nn.Module.zero_grad(self), as recommended in the same guide, had no effect.
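
For context, a minimal sketch of the pattern described above (hypothetical, not the exact Colab code):

import torch
import torch.nn as nn
import keras_core as keras

# Hypothetical minimal model in the spirit of the report: torch.nn modules
# assigned as attributes of a keras_core.Model. With the torch backend this
# runs, but Keras does not track the torch parameters, so fit() has nothing
# to optimize.
class TorchMLP(keras.Model):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # torch modules, not Keras layers
        self.fc2 = nn.Linear(128, 10)

    def call(self, x):
        x = torch.relu(self.fc1(x))
        return torch.softmax(self.fc2(x), dim=-1)

model = TorchMLP()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# model.fit(x_train, y_train, ...)  # runs without errors, but the loss never improves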

Colab to reproduce the issue

About this issue

  • Original URL
  • State: closed
  • Created a year ago
  • Reactions: 4
  • Comments: 20 (19 by maintainers)

Most upvoted comments

Would you like to explore implementing a prototype of a TorchModuleWrapper? If it’s simple, we can auto-wrap any module upon assignment to a Keras Layer with the torch backend.

Hi @fchollet, here's the TorchModuleWrapper modified as per your feedback (with some generous help from @ariG23498):

import torch
import torch.nn as nn
import keras_core as keras


class TorchModuleWrapper(keras.layers.Layer):
    def __init__(self, module, name=None):
        super().__init__(name=name)
        # Keep the wrapped module on the GPU (torch device placement)
        self.module = module.to("cuda")
        # Lazy modules only create their parameters on the first forward pass
        self.lazy = isinstance(self.module, nn.modules.lazy.LazyModuleMixin)
        if not self.lazy:
            self.track_module_parameters()

    def parameters(self, recurse=True):
        # Expose the underlying torch parameters so torch APIs keep working
        return self.module.parameters(recurse=not self.lazy)

    def track_module_parameters(self):
        # Register each torch parameter as a Keras variable that shares its storage
        for param in self.module.parameters():
            variable = keras.Variable(
                initializer=param, trainable=param.requires_grad
            )
            variable._value = param
            self._track_variable(variable)
        self.built = True

    def build(self, input_shape, *args, **kwargs):
        if not self.lazy:
            self._build_by_run_for_single_pos_arg(args)
            self._build_by_run_for_kwargs(kwargs)
        else:
            # Lazy modules need an eager forward pass to materialize their parameters
            sample_input = torch.ones(*input_shape).to("cuda")
            _ = self.module(sample_input)
        self.track_module_parameters()

    def call(self, inputs, *args, **kwargs):
        if not self.built:
            self.build(inputs.shape[1:])
        return self.module.forward(inputs, *args, **kwargs)
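
For illustration, a minimal usage sketch of the wrapper (hypothetical, not from the Colab; it assumes a CUDA device is available, since the wrapper hard-codes .to("cuda"), and that training data is prepared elsewhere):

import torch
import torch.nn as nn
import keras_core as keras

# Hypothetical model mixing an eagerly-built and a lazily-built torch module,
# both wrapped so that Keras tracks their parameters.
class WrappedNet(keras.Model):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.fc1 = TorchModuleWrapper(nn.Linear(28 * 28, 128))
        self.fc2 = TorchModuleWrapper(nn.LazyLinear(10))  # parameters created on first forward

    def call(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = WrappedNet()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# model.fit(x_train, y_train, ...)  # the wrapped torch parameters should now receive updates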

Some observations with regard to the following:

Instead of doing an eager _ = self.module(sample_input), we should leverage backend.compute_output_spec which will be more efficient.

It’s probably not possible to build the parameters of lazy torch modules except by doing an eager _ = self.module(sample_input), since they don’t initialize their parameters until there’s a forward pass. This is what the official docs mention:

Modules that lazily initialize parameters, or “lazy modules”, derive the shapes of their parameters from the first input(s) to their forward method. Until that first forward they contain torch.nn.UninitializedParameters that should not be accessed or used, and afterward they contain regular torch.nn.Parameters.
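
A quick illustration of that behaviour with torch.nn.LazyLinear:

import torch
import torch.nn as nn

layer = nn.LazyLinear(out_features=8)
print(type(layer.weight))     # <class 'torch.nn.parameter.UninitializedParameter'>
_ = layer(torch.ones(4, 16))  # the first forward pass infers in_features=16
print(layer.weight.shape)     # torch.Size([8, 16]) -- now a regular nn.Parameter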

Would love to know your thoughts on this.

PS: Apologies for the delayed response.

replacing keras_core.layers with torch.nn.Module inside a keras_core.Model

This is not intended to be supported at this time. Only Layer objects are tracked by Keras, not torch Modules.

@soumik12345 thanks – please open a PR and let’s move the discussion to the PR. This should live in backend/torch/torch_module_wrapper.py. There is some light refactoring we will need.

Thanks for the analysis! Is there a way to merge both classes into a single one?

Instead of doing an eager _ = self.module(sample_input), we should leverage backend.compute_output_spec which will be more efficient.

Also, should this be an officially supported feature on Keras Core? If not a feature, maybe this could be mentioned in a guide?

Sure, let’s add it. First, please investigate the case where not all module parameters are created upon instantiation (lazy module). When that works, we can start a PR.

@soumik12345 I have it working with a couple of simple changes:

  • The module wrapper can set self.built = True since torch module weights are already created (mind you, I guess this won’t be the case with lazy modules, I have no idea how those work)
  • There was a device placement issue, as is common with torch. I placed the module on GPU to make it work.

https://colab.research.google.com/drive/1gG93Fb03Ef-77suS6b7PqgWtMOopNyyF?usp=sharing

Update – actually there is a problem with variable tracking. It doesn’t yet work.

Update – fixed it by setting variable._value directly. It’s training now.
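
Concretely, the tracking fix amounts to a pattern like this (a sketch, not the exact Colab code; module is the wrapped torch module and layer the wrapping Keras layer):

# Create a Keras variable per torch parameter and point its underlying value
# at the parameter itself, so the optimizer updates the tensor the module uses.
for param in module.parameters():
    variable = keras.Variable(initializer=param, trainable=param.requires_grad)
    variable._value = param
    layer._track_variable(variable)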

This would be a very welcome feature: being able to just use .fit(), .evaluate(), etc. on, for example, all of the models in https://docs.monai.io/en/stable/networks.html.

It’s probably going to be something like this (warning: entirely untested):

class TorchModuleWrapper(Layer):
    def __init__(self, module, name=None):
        super().__init__(name=name)
        self.module = module

    def parameters(self, recurse=True):
        return self.module.parameters(recurse=recurse)

    def build(self, _):
        if not self.built:
            for param in self.module.parameters():
                variable = Variable(value=param, trainable=param.requires_grad)
                self._track_variable(variable)
        self.built = True

    def call(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)

Would you like to explore implementing a prototype of a TorchModuleWrapper? If it’s simple, we can auto-wrap any module upon assignment to a Keras Layer with the torch backend.

Yes! This is something I would love to explore.

It’s possible, but you’d need to wrap the module in a layer in a way that makes it track the underlying parameters(). We haven’t implemented anything like this yet. But it’s likely simple.

This would be a valuable feature, making keras_core compatible with likely the entirety of the PyTorch ecosystem.
