hydra: [Bug] Cannot use plugin's conf dataclasses in structured config

๐Ÿ› Bug

Description

Hi,

First of all, many thanks for the wonderful tool you've made. It is extremely helpful when a production environment is needed. I ran into an error when trying to use OptunaSweeperConf through the Python API (structured configs) instead of the traditional YAML files.

Checklist

  • I checked on the latest version of Hydra
  • I created a minimal repro (See this for tips).

To reproduce

** Minimal Code/Config snippet to reproduce **

# Standard libraries
from dataclasses import dataclass

# Third-party libraries
import hydra
from hydra.conf import HydraConf
from hydra.core.config_store import ConfigStore
from hydra_plugins.hydra_optuna_sweeper.config import OptunaSweeperConf
from omegaconf import DictConfig


@dataclass
class ExampleConfig():
    hydra: HydraConf = HydraConf(sweeper=OptunaSweeperConf())


ConfigStore.instance().store(name='config', node=ExampleConfig)


@hydra.main(config_path=None, config_name='config')
def main(config: DictConfig) -> None:
    """Main function."""
    print(config)


if __name__ == '__main__':
    main()

And the stack trace with the returned error:

** Stack trace/error message **

In 'config': Validation error while composing config:
Merge error: OptunaSweeperConf is not a subclass of BasicSweeperConf. value: {'_target_': 'hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper', 'defaults': [{'sampler': 'tpe'}], 'sampler': '???', 'direction': <Direction.minimize: 1>, 'storage': None, 'study_name': None, 'n_trials': 20, 'n_jobs': 2, 'search_space': {}}
    full_key: 
    object_type=dict

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Expected Behavior

I would like to override the hydra config group directly in the structured config. To do so, I import HydraConf from hydra and install the hydra-optuna-sweeper plugin. The error seems to indicate that OptunaSweeperConf must inherit from the BasicSweeperConf structured config (current implementation: https://github.com/facebookresearch/hydra/blob/v1.1.1/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/config.py#L133). The idea is to be able to override the hydra config group directly from a Python script. Does that make sense?

System information

  • Hydra Version : 1.1.1
  • Hydra Optuna Sweeper Version : 1.1.1
  • Python version : 3.8.11
  • Virtual environment type and version : MiniConda environment
  • Operating system : LinuxMint

About this issue

  • State: open
  • Created 3 years ago
  • Comments: 22 (7 by maintainers)

Most upvoted comments

@Jasha10 you are right! 🤣 I updated the previous comment.

OK, I took another look at this. The tl;dr is that the plugins are being imported twice, and as a result OmegaConf fails the validation here.

  1. OmegaConf fails here, complaining that OptunaSweeperConf is not a subclass of OptunaSweeperConf. In debug mode, comparing dest_obj_type and src_obj_type, I can see that although they are the same class by name, they have different ids (meaning they are different type objects because of the double import).
  2. This seems to happen only to plugins, since they are imported twice: 1) the first time in the user application, 2) the second time here.
  3. This is probably not an Optuna-only issue; it likely affects all plugins.

I also found an SO question that describes a similar issue here
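To illustrate the point about class identity, here is a minimal sketch using only the standard library. It does not use Hydra or OmegaConf: the module name demo_conf and the class SweeperConf are hypothetical stand-ins, and the double import is simulated by executing the same source file twice into two separate module objects.

```python
import importlib.util
import tempfile
from pathlib import Path

# Hypothetical stand-in for a plugin's config module.
SOURCE = "class SweeperConf:\n    pass\n"


def load_fresh(path: str, name: str):
    """Execute the file at `path` into a brand-new module object."""
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "demo_conf.py"
    path.write_text(SOURCE)
    # Simulate the double import: the same source is executed twice.
    first = load_fresh(str(path), "demo_conf")
    second = load_fresh(str(path), "demo_conf")

# Same qualified name, but two distinct class objects.
same_name = first.SweeperConf.__qualname__ == second.SweeperConf.__qualname__
same_object = first.SweeperConf is second.SweeperConf
is_subclass = issubclass(first.SweeperConf, second.SweeperConf)
print(same_name, same_object, is_subclass)  # True False False
```

Any validation that relies on issubclass between the two copies will fail in the same way, even though both classes print with the same name.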

Worth trying to figure out the root cause.

Hi @jieru-hu, sorry, I think my explanation was not clear enough. I am perfectly aware of how to configure a sweeper using YAML files (this is well documented) or directly through the command line. My question was about configuring a sweeper directly using the Python API (structured config). But your answers suggest that this is either not possible or not offered by Hydra.

To put it in some context, I'm using Hydra to run deep learning trainings with PyTorch-Lightning like this:

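from dataclasses import dataclass

import hydra
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate

# The *Conf classes used below are assumed to be defined or imported elsewhere.


@dataclass
class ExampleConfig: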

    datamodule = DataModuleConf(
        train_dataset=RandomDatasetConf(),
        val_dataset=RandomDatasetConf(),
        test_dataset=RandomDatasetConf(),
        num_workers=0,
        batch_size=8
    )

    module = ModuleConf(
        model=RandomModelConf(),
        loss=CrossEntropyLossConf(),
        optimizer=SGDConf(lr=0.1),
        scheduler=StepLRConf(step_size=1),
    )

    trainer = TrainerConf(
        callbacks=[LearningRateMonitorConf()],
        max_epochs=20
    )


ConfigStore.instance().store(name='config', node=ExampleConfig)


@hydra.main(config_path=None, config_name='config')
def main(config) -> None:
    """Main function."""
    # Instantiate datamodule
    datamodule = instantiate(config=config.datamodule)

    # Instantiate module
    module = instantiate(config=config.module)

    # Instantiate trainer
    trainer = instantiate(config=config.trainer)

    # Run training
    trainer.fit(module, datamodule=datamodule)


if __name__ == '__main__':
    main()

I really like using the Python API because it lets me use an IDE (like VSCode) to manipulate structured configurations with auto-completion. I would also like to be able to override HydraConf directly in that file, so that a single file centralizes all the config used for training. I managed to override the run dir by adding the hydra config group like so:

    hydra: HydraConf = HydraConf(run=RunDir('./test'))

But if I try to do the same thing with a sweeper (like Optuna), it fails (see the error at the beginning of the thread). I understand that Hydra relies essentially on manipulating and composing YAML files, but I also see value in manipulating the structured configs (the schemas) directly through a Python script. Maybe I am wrong 😃.

PS: there was an attempt in https://github.com/pytorch/hydra-torch/blob/main/examples/mnist_00.py to use structured configs directly through a script. I think we can go a step further with the example above. What do you think?

Hi @Jasha10,

I ran some tests using your PR with the following example (simply implementing the sphere example, but using the Python API):

# Standard libraries
from dataclasses import dataclass, field
from typing import Any, List

# Third-party libraries
import hydra
from hydra.conf import HydraConf
from hydra.core.config_store import ConfigStore
from hydra_plugins.hydra_optuna_sweeper.config import OptunaSweeperConf, TPESamplerConfig
from omegaconf import DictConfig



@dataclass
class ExampleConfig:
    defaults: List[Any] = field(
        default_factory=lambda: [
            {"override hydra/sweeper": "optuna"},
        ]
    )

    hydra: HydraConf = HydraConf(
        sweeper=OptunaSweeperConf(
            n_trials=20,
            n_jobs=1,
            storage=None,
            study_name='sphere',
            direction='minimize',
            sampler=TPESamplerConfig(seed=123),
            search_space={
                'x': {'type': 'float', 'low': -5.5, 'high': 5.5, 'step': 0.5},
                'y': {'type': 'categorical', 'choices': [-5, 0, 5]},
            }
        )
    )

    x: float = 1.0
    y: float = 1.0
    z: float = 1.0


ConfigStore.instance().store(name="config", node=ExampleConfig)


@hydra.main(config_path=None, config_name="config")
def main(config: DictConfig) -> float:
    """Main function."""
    x: float = config.x
    y: float = config.y
    return x**2 + y**2


if __name__ == "__main__":
    main()

Running it with the command python sphere.py --multirun, I get this output:

[I 2022-02-08 10:04:15,274] A new study created in memory with name: sphere
[2022-02-08 10:04:15,274][HYDRA] Study name: sphere
[2022-02-08 10:04:15,274][HYDRA] Storage: None
[2022-02-08 10:04:15,274][HYDRA] Sampler: TPESampler
[2022-02-08 10:04:15,274][HYDRA] Directions: ['minimize']
[2022-02-08 10:04:15,277][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:15,277][HYDRA]        #0 : x=2.5 y=5
[2022-02-08 10:04:15,446][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:15,446][HYDRA]        #1 : x=2.5 y=0
[2022-02-08 10:04:15,559][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:15,559][HYDRA]        #2 : x=0.0 y=5
[2022-02-08 10:04:15,671][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:15,671][HYDRA]        #3 : x=-0.5 y=5
[2022-02-08 10:04:15,782][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:15,782][HYDRA]        #4 : x=-3.5 y=5
[2022-02-08 10:04:15,894][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:15,894][HYDRA]        #5 : x=1.5 y=-5
[2022-02-08 10:04:16,008][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:16,008][HYDRA]        #6 : x=2.5 y=0
[2022-02-08 10:04:16,118][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:16,118][HYDRA]        #7 : x=-2.5 y=-5
[2022-02-08 10:04:16,230][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:16,230][HYDRA]        #8 : x=-1.0 y=-5
[2022-02-08 10:04:16,342][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:16,342][HYDRA]        #9 : x=-1.0 y=0
[2022-02-08 10:04:16,455][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:16,455][HYDRA]        #10 : x=5.5 y=0
[2022-02-08 10:04:16,568][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:16,568][HYDRA]        #11 : x=5.0 y=0
[2022-02-08 10:04:16,689][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:16,689][HYDRA]        #12 : x=-5.0 y=0
[2022-02-08 10:04:16,802][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:16,802][HYDRA]        #13 : x=4.0 y=0
[2022-02-08 10:04:16,918][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:16,918][HYDRA]        #14 : x=1.5 y=0
[2022-02-08 10:04:17,094][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:17,094][HYDRA]        #15 : x=-2.0 y=0
[2022-02-08 10:04:17,208][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:17,208][HYDRA]        #16 : x=1.5 y=0
[2022-02-08 10:04:17,322][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:17,322][HYDRA]        #17 : x=1.0 y=0
[2022-02-08 10:04:17,437][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:17,437][HYDRA]        #18 : x=-5.5 y=0
[2022-02-08 10:04:17,551][HYDRA] Launching 1 jobs locally
[2022-02-08 10:04:17,551][HYDRA]        #19 : x=-3.5 y=-5
[2022-02-08 10:04:17,662][HYDRA] Best parameters: {'x': -1.0, 'y': 0}
[2022-02-08 10:04:17,662][HYDRA] Best value: 1.0

This looks fine to me. 😃

Nice detective work!!

Here is a minimal reproducer for the issue based on the comment above:

# min_repo1.py
import importlib
import pkgutil

from hydra_plugins.hydra_optuna_sweeper.config import OptunaSweeperConf

mdl = importlib.import_module("hydra_plugins")
for importer, modname, ispkg in pkgutil.walk_packages(
    path=mdl.__path__, prefix=mdl.__name__ + ".", onerror=lambda x: None
):
    m = importer.find_module(modname)
    loaded_mod = m.load_module(modname)
    if loaded_mod.__name__ == "hydra_plugins.hydra_optuna_sweeper.config":
        assert loaded_mod.OptunaSweeperConf is OptunaSweeperConf, "comparison failed"
        print("assertion passed!")

$ python min_repo1.py
Traceback (most recent call last):
  File "min_repo1.py", line 14, in <module>
    assert loaded_mod.OptunaSweeperConf is OptunaSweeperConf, "comparison failed"
AssertionError: comparison failed

I have an idea for a diff to solve the issue.

$ diff min_repo1.py min_repo2.py
1c1
< # min_repo1.py
---
> # min_repo2.py
11,12c11
<     m = importer.find_module(modname)
<     loaded_mod = m.load_module(modname)
---
>     loaded_mod = importlib.import_module(modname)
$ python min_repo2.py
assertion passed!

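By doing the import with importlib.import_module instead of the find_module/load_module calls on the importer returned by pkgutil, we get the same module as is loaded by the import statement at the top of the file, because import_module returns the module already cached in sys.modules.

Hello 😃 Do you have any update on this topic? Or do you need some help investigating further?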
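The same idea can be sketched without Hydra installed. The example below is only an illustration under stated assumptions: demo_plugins and DemoConf are hypothetical names for a throwaway package created in a temporary directory, standing in for hydra_plugins and a plugin's config class. It walks the package with pkgutil.walk_packages but loads each module through importlib.import_module, then checks that the class seen during discovery is the same object the application imported.

```python
import importlib
import pkgutil
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Build a throwaway package: demo_plugins/config.py (hypothetical names).
    pkg_dir = Path(tmp) / "demo_plugins"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text("")
    (pkg_dir / "config.py").write_text("class DemoConf:\n    pass\n")

    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        # 1) What the user application does: a plain import statement.
        from demo_plugins.config import DemoConf

        # 2) What a plugin scanner could do: walk the package, but load each
        #    module with import_module so the sys.modules cache is reused.
        pkg = importlib.import_module("demo_plugins")
        discovered = {}
        for _finder, modname, _ispkg in pkgutil.walk_packages(
            path=pkg.__path__, prefix=pkg.__name__ + ".", onerror=lambda x: None
        ):
            discovered[modname] = importlib.import_module(modname)
    finally:
        sys.path.remove(tmp)

# The discovered class is the very same object the application imported.
identical = discovered["demo_plugins.config"].DemoConf is DemoConf
print(identical)  # True
```

Because the module is executed only once, there is a single class object and issubclass checks between the application's import and the discovered module behave as expected.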

Sorry @mayeroa for getting back to you late, and thanks for explaining your use case further.

As @omry suspected earlier, it does seem to be related to how the two classes are imported. In particular, the application fails when merging the configs here: although the two Conf classes have the same type, issubclass returns False.

I tried moving the ConfigStore call to the plugin's __init__.py, like the following:

--- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/__init__.py
+++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/__init__.py
@@ -1,3 +1,14 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
 __version__ = "1.2.0dev1"
+
+from hydra_plugins.hydra_optuna_sweeper.config import OptunaSweeperConf
+
+from hydra.core.config_store import ConfigStore
+
+ConfigStore.instance().store(
+    group="hydra/sweeper",
+    name="optuna",
+    node=OptunaSweeperConf,
+    provider="optuna_sweeper",
+)

This seems to resolve the issue, and the config override worked when running @Jasha10's example above:

$ python /Users/jieru/workspace/hydra-fork/hydra/examples/tutorials/structured_configs/1_minimal/my_app.py --cfg hydra -p hydra.sweeper.n_trials
100

I'm not sure if this is the best way to fix this though. @omry, thoughts?

This is probably related to how plugins are discovered and loaded. There are two different OptunaSweeperConf classes loaded.

Hi @mayeroa, please try using the defaults list instead. Something like the following should work:

from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class ExampleConfig():
    defaults: List[Any] = field(
        default_factory=lambda: [
            {"override hydra/sweeper": "optuna"},
        ]
    )

Run the application and you should see Optuna's config:

$ python my_app.py --cfg hydra -p hydra.sweeper
# @package hydra.sweeper
sampler:
  _target_: optuna.samplers.TPESampler
  seed: null
  consider_prior: true
  prior_weight: 1.0
  consider_magic_clip: true
  consider_endpoints: false
  n_startup_trials: 10
  n_ei_candidates: 24
  multivariate: false
  warn_independent_sampling: true
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
direction: minimize
storage: null
study_name: null
n_trials: 20
n_jobs: 2
search_space: {}

Also, the cfg object you get via hydra.main does not contain Hydra's config. To access it, try something like the following:

import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig


@hydra.main(config_path=None, config_name='config')
def main(config: DictConfig) -> None:
    """Main function."""
    print(HydraConfig.get().sweeper)