mlflow: [BUG] mlflow.keras.log_model raises an error for a Keras model with a tensorflow_probability layer

Willingness to contribute

Yes. I would be willing to contribute a fix for this bug with guidance from the MLflow community.

MLflow version

1.26.0

System information

  • MacOS Monterey with M1 Chip
  • Python 3.8.13

Describe the problem

I have a Keras model whose output layer is tfp.layers.DistributionLambda(tfd.Poisson)(input) (see https://www.tensorflow.org/probability/api_docs/python/tfp/layers/DistributionLambda and https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Poisson for reference).

To store the model locally, I simply call model.save(destination_path). To load it again, I use:

import keras
import tensorflow as tf
# tf.exp is used as an activation, so it must be supplied via custom_objects
model = keras.models.load_model(file_path, compile=False, custom_objects={'exp': tf.exp})

However, when I try to log it to MLflow with mlflow.keras.log_model(keras_model, path), I get the following error:

Traceback (most recent call last):
  File "/Users/user/dev/project/src/utils/mlflow_utils.py", line 54, in log
    model_info = mlflow.keras.log_model(model, 'model', keras_module=keras, custom_objects={'exp': tf.exp})
  File "/Users/user/miniforge3/envs/py38/lib/python3.8/site-packages/mlflow/keras.py", line 416, in log_model
    return Model.log(
  File "/Users/user/miniforge3/envs/py38/lib/python3.8/site-packages/mlflow/models/model.py", line 294, in log
    flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
  File "/Users/user/miniforge3/envs/py38/lib/python3.8/site-packages/mlflow/keras.py", line 268, in save_model
    keras_model.save(model_path, **kwargs)
  File "/Users/user/miniforge3/envs/py38/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/user/miniforge3/envs/py38/lib/python3.8/contextlib.py", line 120, in __exit__
    next(self.gen)
  File "/Users/user/miniforge3/envs/py38/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 224, in __call__
    distribution, _ = super(DistributionLambda, self).__call__(
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

I also tried passing the custom_objects parameter, i.e. mlflow.keras.log_model(keras_model, path, custom_objects={'exp': tf.exp}), but I still get the same error.
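
A possible reason custom_objects does not help here: as far as I understand, Keras stores a plain-function activation such as tf.exp by its name ("exp") and only resolves that name when the model is loaded, which is why load_model needs custom_objects={'exp': tf.exp}. The traceback above, however, fails inside keras_model.save, before any loading happens. The stdlib-only sketch below models that name lookup with made-up helpers (serialize_activation, deserialize_activation, BUILTIN_ACTIVATIONS); it illustrates the assumption and is not Keras' actual implementation.

```python
# Hypothetical, stdlib-only model of name-based activation (de)serialization.
# It mimics the idea that a plain-function activation is saved by its __name__
# and looked up again at load time. This is NOT Keras code.
import math

# Stand-in for the activations a framework knows by name (assumption).
BUILTIN_ACTIVATIONS = {"relu": lambda v: max(0.0, v)}


def exp(v):
    """Stand-in for tf.exp; only its __name__ ("exp") matters here."""
    return math.exp(v)


def serialize_activation(fn):
    # Only the function name is written to the saved config.
    return fn.__name__


def deserialize_activation(name, custom_objects=None):
    # At load time the name must be resolvable; custom_objects extends the lookup.
    lookup = dict(BUILTIN_ACTIVATIONS)
    lookup.update(custom_objects or {})
    if name not in lookup:
        raise ValueError(f"Unknown activation function: {name}")
    return lookup[name]


saved_name = serialize_activation(exp)  # saving never consults custom_objects
try:
    deserialize_activation(saved_name)
except ValueError as err:
    print(err)  # Unknown activation function: exp
restored = deserialize_activation(saved_name, custom_objects={"exp": exp})
print(restored(0.0))  # 1.0
```

In this model, custom_objects only matters on the loading side, so passing it cannot change whether the save step succeeds.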

Could you please help me?

Tracking information

No response

Code to reproduce issue


import mlflow
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.datasets import load_iris

tfd = tfp.distributions

X, y = load_iris(return_X_y=True)

# Small dense network whose positive output is used as the Poisson rate
inputs = tf.keras.layers.Input(shape=(X.shape[1],))
x = tf.keras.layers.Dense(30, activation="relu")(inputs)
output = tf.keras.layers.Dense(1, activation=tf.exp)(x)
# The model outputs a tfd.Poisson distribution instead of a plain tensor
p_y = tfp.layers.DistributionLambda(tfd.Poisson)(output)
model = tf.keras.Model(inputs=inputs, outputs=p_y)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)

# Negative log-likelihood of the targets under the predicted distribution
def nll(y_true, y_pred):
    return -y_pred.log_prob(y_true)

model.compile(optimizer, loss=nll)
model.fit(x=X.astype(float), y=y.astype(float), epochs=5)


active_experiment = "active-custom-experiment"
mlflow.set_experiment(active_experiment)
with mlflow.start_run():
    # Fails with OperatorNotAllowedInGraphError (see the stack trace below)
    mlflow.keras.log_model(model, "model")

Stack trace

---------------------------------------------------------------------------
OperatorNotAllowedInGraphError            Traceback (most recent call last)
Input In [25], in <cell line: 2>()
      1 mlflow.set_experiment(active_experiment)
      2 with mlflow.start_run():
----> 3     mlflow.keras.log_model(model, 'model')

File ~/miniforge3/envs/py38/lib/python3.8/site-packages/mlflow/keras.py:416, in log_model(keras_model, artifact_path, conda_env, code_paths, custom_objects, keras_module, registered_model_name, signature, input_example, await_registration_for, pip_requirements, extra_pip_requirements, **kwargs)
    404 if signature is not None:
    405     warnings.warn(
    406         "The pyfunc inference behavior of Keras models logged "
    407         "with signatures differs from the behavior of Keras "
   (...)
    414         "a Pandas DataFrame output in response to a Pandas DataFrame input."
    415     )
--> 416 return Model.log(
    417     artifact_path=artifact_path,
    418     flavor=mlflow.keras,
    419     keras_model=keras_model,
    420     conda_env=conda_env,
    421     code_paths=code_paths,
    422     custom_objects=custom_objects,
    423     keras_module=keras_module,
    424     registered_model_name=registered_model_name,
    425     signature=signature,
    426     input_example=input_example,
    427     await_registration_for=await_registration_for,
    428     pip_requirements=pip_requirements,
    429     extra_pip_requirements=extra_pip_requirements,
    430     **kwargs,
    431 )

File ~/miniforge3/envs/py38/lib/python3.8/site-packages/mlflow/models/model.py:294, in Model.log(cls, artifact_path, flavor, registered_model_name, await_registration_for, **kwargs)
    292 run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
    293 mlflow_model = cls(artifact_path=artifact_path, run_id=run_id)
--> 294 flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
    295 mlflow.tracking.fluent.log_artifacts(local_path, artifact_path)
    296 try:

File ~/miniforge3/envs/py38/lib/python3.8/site-packages/mlflow/keras.py:268, in save_model(keras_model, path, conda_env, code_paths, mlflow_model, custom_objects, keras_module, signature, input_example, pip_requirements, extra_pip_requirements, **kwargs)
    266         shutil.copyfile(src=f.name, dst=model_path)
    267 else:
--> 268     keras_model.save(model_path, **kwargs)
    270 # update flavor info to mlflow_model
    271 mlflow_model.add_flavor(
    272     FLAVOR_NAME,
    273     keras_module=keras_module.__name__,
   (...)
    277     code=code_dir_subpath,
    278 )

File ~/miniforge3/envs/py38/lib/python3.8/site-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File ~/miniforge3/envs/py38/lib/python3.8/contextlib.py:120, in _GeneratorContextManager.__exit__(self, type, value, traceback)
    118 if type is None:
    119     try:
--> 120         next(self.gen)
    121     except StopIteration:
    122         return False

File ~/miniforge3/envs/py38/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py:224, in DistributionLambda.__call__(self, inputs, *args, **kwargs)
    222 def __call__(self, inputs, *args, **kwargs):
    223   self._enter_dunder_call = True
--> 224   distribution, _ = super(DistributionLambda, self).__call__(
    225       inputs, *args, **kwargs)
    226   self._enter_dunder_call = False
    227   return distribution

OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
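
One plausible reading of the last frame (a hypothesis, not a confirmed diagnosis): line 224 of distribution_layer.py unpacks the result of super().__call__ into two names, and Python tuple unpacking calls iter() on that result. If, while Keras traces the layer during saving, the call returns a single symbolic tensor instead of a (distribution, value) tuple, the unpacking would raise exactly this "Iterating over a symbolic tf.Tensor" error. The pure-Python sketch below reproduces that mechanism with stand-in names (OperatorNotAllowedInGraph, SymbolicTensor, layer_call, dunder_call are all hypothetical) and does not use TensorFlow.

```python
# Hypothetical, stdlib-only illustration of the failing line
#     distribution, _ = super(DistributionLambda, self).__call__(inputs, ...)
# Tuple unpacking calls iter() on the right-hand side. An object that refuses
# iteration (like a symbolic graph tensor) makes the unpacking itself fail.


class OperatorNotAllowedInGraph(TypeError):
    """Stand-in for TensorFlow's OperatorNotAllowedInGraphError."""


class SymbolicTensor:
    """Stand-in for a symbolic tf.Tensor inside a traced graph."""

    def __iter__(self):
        raise OperatorNotAllowedInGraph(
            "Iterating over a symbolic `tf.Tensor` is not allowed"
        )


def layer_call(traced):
    # Normal path: the layer returns a (distribution, value) tuple.
    # Traced path (assumption): a single symbolic tensor comes back instead.
    if traced:
        return SymbolicTensor()
    return ("distribution", "value")


def dunder_call(traced):
    # Mirrors `distribution, _ = super().__call__(...)` from the traceback.
    distribution, _ = layer_call(traced)
    return distribution


print(dunder_call(traced=False))  # distribution
try:
    dunder_call(traced=True)
except OperatorNotAllowedInGraph as err:
    print(f"{type(err).__name__}: {err}")
```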

Other info / logs

No response

What component(s) does this bug affect?

  • area/artifacts: Artifact stores and artifact logging
  • area/build: Build and test infrastructure for MLflow
  • area/docs: MLflow documentation pages
  • area/examples: Example code
  • area/model-registry: Model Registry service, APIs, and the fluent client calls for Model Registry
  • area/models: MLmodel format, model serialization/deserialization, flavors
  • area/pipelines: Pipelines, Pipeline APIs, Pipeline configs, Pipeline Templates
  • area/projects: MLproject format, project running backends
  • area/scoring: MLflow Model server, model deployment tools, Spark UDFs
  • area/server-infra: MLflow Tracking server backend
  • area/tracking: Tracking Service, tracking client APIs, autologging

What interface(s) does this bug affect?

  • area/uiux: Front-end, user experience, plotting, JavaScript, JavaScript dev server
  • area/docker: Docker use across MLflow’s components, such as MLflow Projects and MLflow Models
  • area/sqlalchemy: Use of SQLAlchemy in the Tracking Service or Model Registry
  • area/windows: Windows support

What language(s) does this bug affect?

  • language/r: R APIs and clients
  • language/java: Java APIs and clients
  • language/new: Proposals for new client languages

What integration(s) does this bug affect?

  • integrations/azure: Azure and Azure ML integrations
  • integrations/sagemaker: SageMaker integrations
  • integrations/databricks: Databricks integrations

About this issue

  • State: open
  • Created 2 years ago
  • Comments: 15 (7 by maintainers)

Most upvoted comments

@BenWilson2 @dbczumar @harupy @WeichenXu123 Please assign a maintainer and start triaging this issue.