DeepSpeed: [BUG] `zero.Init` fails when some class is brought in dynamically

Once the issue of https://github.com/microsoft/DeepSpeed/issues/2811 is resolved, we immediately hit a new problem with dynamically imported classes.

This time the situation can first be demonstrated with real code:

import deepspeed
from transformers import AutoModel

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = AutoModel.from_config(config)

though usually it’s much more complex - I will show a more complete example below:

Here is a simple way to repro this use case with a simple dynamic class creation:

import torch
import deepspeed

ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    # this is the same as `class MyModel(torch.nn.Module): pass` but it's done dynamically 
    x = type("MyModel", (torch.nn.Module,), {})
    model = x()

deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
$ deepspeed --num_gpus 1 test2.py

Traceback (most recent call last):
  File "test2.py", line 9, in <module>
    model = x()
  File "/mnt/nvme0/code/github/00optimize/DeepSpeed-optim-grad-accessors/deepspeed/runtime/zero/partition_parameters.py", line 414, in __exit__
    shutdown_init_context()
  File "/mnt/nvme0/code/github/00optimize/DeepSpeed-optim-grad-accessors/deepspeed/runtime/zero/partition_parameters.py", line 454, in shutdown_init_context
    _disable_class(subclass)
  File "/mnt/nvme0/code/github/00optimize/DeepSpeed-optim-grad-accessors/deepspeed/runtime/zero/partition_parameters.py", line 450, in _disable_class
    cls.__init__ = cls._old_init
AttributeError: type object 'MyModel' has no attribute '_old_init'

The problem here is that when the zero.Init context is entered it overrides the __init__ method of every torch.nn.Module subclass, but because the MyModel class hadn't been defined yet, torch.nn.Module doesn't have it among its descendants, so when zero.Init traverses the torch.nn.Module descendants it can't install the override for it.

But on exit it tries to restore the original __init__ of every descendant, including MyModel, whose __init__ was never saved!
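To make the mechanism concrete, here is a minimal sketch (not DeepSpeed's actual code, just an illustration of the enter/exit logic described above) of why a class created inside the context is missed:

import torch

def _all_subclasses(cls):
    # recursively collect every currently known subclass of `cls`
    for sub in cls.__subclasses__():
        yield sub
        yield from _all_subclasses(sub)

class InitContextSketch:
    def __enter__(self):
        # the subclass tree is walked *now* - MyModel doesn't exist yet,
        # so it never gets `_old_init` nor the patched `__init__`
        for cls in _all_subclasses(torch.nn.Module):
            cls._old_init = cls.__init__
            cls.__init__ = self._wrap_init(cls.__init__)

    def __exit__(self, *exc):
        # by now MyModel *is* in the subclass tree, but `_old_init` was
        # never saved for it -> the AttributeError in the traceback above
        for cls in _all_subclasses(torch.nn.Module):
            cls.__init__ = cls._old_init

    @staticmethod
    def _wrap_init(old_init):
        def new_init(module, *args, **kwargs):
            old_init(module, *args, **kwargs)
            # the real code would partition the freshly created parameters here
        return new_init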

There are 2 problems here. The first is the failure above, but even if I fix it by checking:

if hasattr(cls, '_old_init'): cls.__init__ = cls._old_init

the second problem remains: the weights of this class will not be sharded during zero.Init.

So the only workaround the user has here is to not rely on AutoModel: they have to figure out which actual class they need and import it explicitly.
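For example, a sketch of that workaround, assuming the model in question happens to be a BERT encoder (BertModel here is only an illustrative choice):

import deepspeed
from transformers import BertConfig, BertModel  # the concrete class, not AutoModel

ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    # BertModel is already a registered torch.nn.Module subclass at this point,
    # so zero.Init has patched its __init__ and its weights get sharded
    model = BertModel(BertConfig())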

Just fixing the attribute issue in the traceback would actually make things worse, as it'd sweep the problem under the carpet.

If the condition type object 'MyModel' has no attribute '_old_init' occurs, DeepSpeed could assert and ask the user whether they were loading any modules dynamically inside the zero.Init context, and explain that all used torch.nn.Module subclasses must be imported into the scope where zero.Init runs.
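Something along these lines (just a sketch of the suggested check, not actual DeepSpeed code):

def _disable_class(cls):
    if not hasattr(cls, '_old_init'):
        raise RuntimeError(
            f"{cls} was not visible when zero.Init() was entered, so its weights "
            "were not sharded. Did you import or define any torch.nn.Module "
            "subclasses dynamically inside the zero.Init() context? All such "
            "classes must be imported before zero.Init() runs."
        )
    cls.__init__ = cls._old_init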

This could be very tricky for complex models that import dynamically inside their constructors. Again I'll use the example of VisionEncoderDecoderModel, which in essence looks like this:

VisionEncoderDecoderModel.from_pretrained(...)

which ends up calling this constructor inside zero.Init:

class VisionEncoderDecoderModel(...):
    def __init__(self, config):
        self.encoder = AutoModel.from_config(config.encoder)
        self.decoder = AutoModelForCausalLM.from_config(config.decoder)

so here the user of VisionEncoderDecoderModel has no idea that the model internally instantiates other models that will be imported dynamically at run-time, after the zero.Init context has already been entered.

So this is going to break, and there are more and more models of this kind.

@tjruwase, @samyam

About this issue

  • State: closed
  • Created a year ago
  • Comments: 18 (18 by maintainers)

Most upvoted comments

I was only able to test a few cases; I will continue testing/studying your work tomorrow, @tohtana - thank you for working on it.

Hello @stas00,

I just submitted a PR (#2989) to fix the problems addressed in this issue (and #2812).

I checked the fix with some patterns. Nested zero.Init():

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        m1 = torch.nn.Linear(1,1)

deepspeed_engine, *_ = deepspeed.initialize(model=m1, config_params=ds_config)

Dynamically defined class:

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    
    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(1,1)

    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model = MyModel()

deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)

The code in the issue regarding VisionEncoderDecoderModel also ran after a few small fixes (though it didn't converge with those settings).

As we discussed, the following pattern is not supported. We need another zero.Init() before the instantiation of the new class.

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(1,1)
    model = MyModel()

deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)

In this pattern, zero.Init() finds that the new class wasn’t properly processed and throws an error.

RuntimeError: `<class '__main__.MyModel'>' was not properly set up for sharding by zero.Init(). A subclass of torch.nn.Module must be defined before zero.Init() where an instance of the class is created.

I appreciate your support for our project. Any feedback is welcome.

@stas00 Thank you for the clarification. I understand that the code defines new attributes. Sorry, I just didn't properly understand the meaning of "substitute" and used the word to mean setting a value somewhere.

Since it is difficult to estimate the cost of implementing my metaclass-based approach, we would like to implement your approach for now. I think it is clear that your proposal can solve the issues of nested zero.Init and dynamically defined classes.

As @tjruwase mentioned, I will also check whether cases like the vision transformer you mentioned work.