DeepSpeed: [BUG] `zero.Init` fails when some class is brought in dynamically
Once the issue of https://github.com/microsoft/DeepSpeed/issues/2811 is resolved, we immediately hit a new problem with dynamically imported classes.
This time the situation can first be demonstrated with real code:
```python
from transformers import AutoModel

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = AutoModel.from_config(config)
```
though usually it's much more complex - I will show a more complete example below.
Here is a simple way to repro this use case with a simple dynamic class creation:
```python
import torch
import deepspeed

ds_config = dict(train_batch_size=1, zero_optimization=dict(stage=3))

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    # this is the same as `class MyModel(torch.nn.Module): pass` but it's done dynamically
    x = type("MyModel", (torch.nn.Module,), {})
    model = x()

deepspeed_engine, *_ = deepspeed.initialize(model=model, config_params=ds_config)
```
```
$ deepspeed --num_gpus 1 test2.py
Traceback (most recent call last):
  File "test2.py", line 9, in <module>
    model = x()
  File "/mnt/nvme0/code/github/00optimize/DeepSpeed-optim-grad-accessors/deepspeed/runtime/zero/partition_parameters.py", line 414, in __exit__
    shutdown_init_context()
  File "/mnt/nvme0/code/github/00optimize/DeepSpeed-optim-grad-accessors/deepspeed/runtime/zero/partition_parameters.py", line 454, in shutdown_init_context
    _disable_class(subclass)
  File "/mnt/nvme0/code/github/00optimize/DeepSpeed-optim-grad-accessors/deepspeed/runtime/zero/partition_parameters.py", line 450, in _disable_class
    cls.__init__ = cls._old_init
AttributeError: type object 'MyModel' has no attribute '_old_init'
```
The problem here is that the `zero.Init` context, on entry, overrides the `__init__` methods of all `torch.nn.Module` subclasses. But because the `MyModel` class hadn't been defined yet, `torch.nn.Module` didn't have it among its descendants, so when `zero.Init` traversed the `torch.nn.Module` subclasses it couldn't install the override. On exit, however, it tries to restore the original `__init__` - yet it was never saved!
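The mechanism above can be reproduced in pure Python without DeepSpeed. In this sketch (all names - `Base`, `_enable_class`, `all_subclasses` - are illustrative stand-ins, not DeepSpeed's actual API), we patch `__init__` on every subclass known at "enter" time, then create a class dynamically and watch the restore step fail with exactly the same `AttributeError`:

```python
class Base:  # stand-in for torch.nn.Module
    def __init__(self):
        pass

class Known(Base):  # defined before the context is "entered"
    pass

def _enable_class(cls):
    # save the original __init__ and install a wrapper (toy analog of
    # zero.Init's weight-partitioning override)
    cls._old_init = cls.__init__
    def wrapped_init(self, *args, **kwargs):
        cls._old_init(self, *args, **kwargs)
        self.partitioned = True
    cls.__init__ = wrapped_init

def _disable_class(cls):
    # raises AttributeError if the class was never patched
    cls.__init__ = cls._old_init

def all_subclasses(cls):
    subs = set(cls.__subclasses__())
    for sub in cls.__subclasses__():
        subs |= all_subclasses(sub)
    return subs

# "enter" the context: patch every subclass known *right now*
for sub in all_subclasses(Base):
    _enable_class(sub)

# a class created dynamically *inside* the context is never patched
Late = type("Late", (Base,), {})

# "exit" the context: restoring blows up on the late-defined class
errors = []
for sub in all_subclasses(Base):
    try:
        _disable_class(sub)
    except AttributeError as e:
        errors.append(str(e))

print(errors)  # only Late fails: it has no _old_init
```

`Known` round-trips cleanly, but `Late` reproduces the `has no attribute '_old_init'` failure from the traceback.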
There are 2 problems here - the failure above, but of course even if I fix it by guarding the restore:

```python
if hasattr(cls, '_old_init'): cls.__init__ = cls._old_init
```

the remaining problem is that the weights of this class will not be sharded during `zero.Init`.
So the only solution the user has here is that they can't rely on `AutoModel`. They have to figure out which concrete class they need and import it explicitly.
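Continuing the toy analog from above (again, `Base`, `patch`, `unpatch` are illustrative names, not DeepSpeed's API), the workaround amounts to making sure the concrete class already exists when the context is entered - the equivalent of importing, say, the concrete model class instead of letting `AutoModel` import it lazily inside the context:

```python
class Base:  # stand-in for torch.nn.Module
    def __init__(self):
        self.partitioned = False

# workaround: define/import the concrete class BEFORE the context is entered,
# so the subclass traversal can find it
class Concrete(Base):
    pass

def patch(cls):
    cls._old_init = cls.__init__
    def wrapped(self, *args, **kwargs):
        cls._old_init(self, *args, **kwargs)
        self.partitioned = True  # stand-in for ZeRO-3 weight sharding
    cls.__init__ = wrapped

def unpatch(cls):
    cls.__init__ = cls._old_init

# "enter": Concrete is already a known subclass, so it gets patched
for sub in Base.__subclasses__():
    patch(sub)

m = Concrete()
print(m.partitioned)  # the override ran: True

# "exit": restore succeeds because _old_init was saved
for sub in Base.__subclasses__():
    unpatch(sub)
```

Both the patch and the restore work because the class was visible to the traversal at enter time.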
Just fixing the attribute issue in the traceback would actually make things worse, as it'd sweep the problem under the carpet.
If the above condition (`type object 'MyModel' has no attribute '_old_init'`) occurs, DeepSpeed could assert and ask the user whether they were loading any modules dynamically inside the `zero.Init` context, and explain that all `torch.nn.Module` subclasses that get used must be imported into the namespace before `zero.Init` runs.
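A hypothetical sketch of what that assertion could look like (the `restore_init` name and error wording are mine, not DeepSpeed's actual code): instead of an opaque `AttributeError`, the exit path detects the unpatched class and raises actionable guidance.

```python
class Base:  # stand-in for torch.nn.Module
    pass

# a class created dynamically inside the context, hence never patched
Late = type("Late", (Base,), {})

def restore_init(cls):
    # hypothetical guard: fail with an explanation instead of AttributeError
    if not hasattr(cls, "_old_init"):
        raise RuntimeError(
            f"zero.Init: class {cls.__name__!r} was defined or imported inside "
            "the zero.Init context, so its __init__ was never wrapped and its "
            "weights were NOT partitioned. Import all torch.nn.Module "
            "subclasses before entering the zero.Init context."
        )
    cls.__init__ = cls._old_init

try:
    restore_init(Late)
    raised = None
except RuntimeError as e:
    raised = str(e)

print(raised)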
This could be very tricky for complex models that import dynamically inside their constructors. Again I'll use the example of `VisionEncoderDecoderModel`, which in essence works like this:
```python
VisionEncoderDecoderModel.from_pretrained(...)
```

which calls this constructor inside `zero.Init`:

```python
class VisionEncoderDecoderModel(...):
    def __init__(self, config):
        self.encoder = AutoModel.from_config(config.encoder)
        self.decoder = AutoModelForCausalLM.from_config(config.decoder)
```
So here the user of `VisionEncoderDecoderModel` has no idea that the model internally calls on other models that will be imported dynamically at run time, after the `zero.Init` context has been entered. So this is going to break, and there are more and more models of this kind.
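This composite failure mode can also be shown with the pure-Python toy from earlier (names are illustrative, not DeepSpeed's API): an outer class is known and patched, but its constructor creates a submodule from a class that only comes into existence at call time, `AutoModel`-style, so the submodule silently misses the override.

```python
class Base:  # stand-in for torch.nn.Module
    def __init__(self):
        self.partitioned = False

class Composite(Base):
    # like VisionEncoderDecoderModel: constructs a submodule whose class is
    # only created at call time (AutoModel-style dynamic import)
    def __init__(self):
        super().__init__()
        inner_cls = type("Inner", (Base,), {})
        self.inner = inner_cls()

def patch(cls):
    # toy analog of zero.Init's __init__ override
    orig = cls.__init__
    def wrapped(self, *args, **kwargs):
        orig(self, *args, **kwargs)
        self.partitioned = True  # stand-in for weight sharding
    cls.__init__ = wrapped

# "enter" the context: only Composite exists at this point
for sub in list(Base.__subclasses__()):
    patch(sub)

m = Composite()
print(m.partitioned, m.inner.partitioned)  # outer patched, inner silently missed
```

The outer model looks fine, which is exactly why this failure is so easy to miss: the unsharded weights live one level down, in a submodule the user never named.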
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 18 (18 by maintainers)
I was only able to test a few cases, I will continue testing/studying your work tomorrow, @tohtana - thank you for working on it.
Hello @stas00,
I just submitted a PR (#2989) to fix the problems addressed in this issue (and #2812).
I checked the fix with some patterns: nested `zero.Init()` and a dynamically defined class. The code in the issue regarding `VisionEncoderDecoderModel` also ran after a few small fixes (it didn't converge in those settings).

As we discussed, the following pattern is not supported - we need another `zero.Init()` before the instantiation of the new class. In this pattern, `zero.Init()` finds that the new class wasn't properly processed and throws an error.

I appreciate your support for our project. Any feedback is welcome.
@stas00 Thank you for the clarification. I understand that the code defines new attributes. Sorry, I just didn't properly understand the meaning of "substitute" and used the word to mean setting a value somewhere.

Since it is difficult to estimate the cost of implementing my approach using a metaclass, we still want to implement your approach now. I think it is clear that your proposal can solve the issues of nested `zero.Init` and dynamically defined classes.

As @tjruwase mentioned, I will also check whether a case like the vision transformer you mentioned works.