tensorflow: Keras scikit-learn wrapper not compatible with keras functional model

System information

  • TensorFlow version (use command below): TF version: 2.0, 1.15
  • Python version: 3.6 , 3.7

Describe the current behavior

When a Keras functional API model is used through the Keras scikit-learn wrapper, calling predict() on the wrapper crashes. See: https://github.com/tensorflow/tensorflow/blob/13f2db1e7071ae109d2f51c7202867a154f587d2/tensorflow/python/keras/wrappers/scikit_learn.py#L241

Describe the expected behavior

The wrapper's predict() should work for all Keras model types, not only Sequential.

Code to reproduce the issue

import numpy as np
import tensorflow as tf
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier

def build_model():
  # Functional API model: 2 input features -> 2-class softmax output.
  inputs = tf.keras.layers.Input(shape=(2,))
  pred = tf.keras.layers.Dense(2, activation='softmax')(inputs)
  model = tf.keras.models.Model(inputs=inputs, outputs=pred)
  model.compile(loss='categorical_crossentropy', metrics=['accuracy'])
  return model

X = np.array([[1,2],[3,1]])
Y = np.array([[1,0], [0,1]])
model = build_model()
model.fit(X, Y)
print(model.predict(X))  # this works

model_wrapped = KerasClassifier(build_model)
model_wrapped.fit(X, Y)
model_wrapped.predict(X)  # this crashes

Output:

Train on 2 samples
2/2 [==============================] - 0s 62ms/sample - loss: 1.1024 - acc: 0.5000
[[0.62487346 0.37512657]
 [0.8205698 0.17943017]]
Train on 2 samples
2/2 [==============================] - 0s 64ms/sample - loss: 0.2733 - acc: 1.0000

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-14-48bacae97b80> in <module>()
     19 model_wrapped = KerasClassifier(build_model)
     20 model_wrapped.fit(X, Y)
---> 21 model_wrapped.predict(X)  # this crashes
     22 
     23 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/wrappers/scikit_learn.py in predict(self, x, **kwargs)
    239     """
    240     kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
--> 241     classes = self.model.predict_classes(x, **kwargs)
    242     return self.classes_[classes]
    243 

AttributeError: 'Model' object has no attribute 'predict_classes'
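
The traceback shows the wrapper calling predict_classes(), which only Sequential provides. A rough sketch of how class indices could instead be derived from predict() alone is below. It is not the wrapper's actual fix: StubFunctionalModel and predict_class_indices are hypothetical names made up for this example, the probabilities are invented, and the argmax / 0.5-threshold rule is only my understanding of what Sequential.predict_classes does.

```python
import numpy as np


class StubFunctionalModel:
    """Hypothetical stand-in for a functional-API model.

    Like the 'Model' object in the traceback, it offers predict()
    but has no predict_classes() method.
    """

    def predict(self, x):
        # Made-up class probabilities, one row per sample.
        return np.array([[0.2, 0.8], [0.9, 0.1]])


def predict_class_indices(model, x):
    """Return the most probable class index per sample using only predict()."""
    proba = np.asarray(model.predict(x))
    if proba.ndim == 2 and proba.shape[-1] > 1:
        # Softmax-style output: pick the column with the highest probability.
        return proba.argmax(axis=-1)
    # Single sigmoid-style output: threshold at 0.5 (an assumption).
    return (proba > 0.5).astype("int32").ravel()


# Map indices back to labels the same way the wrapper does with classes_.
classes_ = np.array(["cat", "dog"])
indices = predict_class_indices(StubFunctionalModel(), np.zeros((2, 2)))
labels = classes_[indices]
print(indices, labels)
```

With a real model, the same function could be called from a small KerasClassifier subclass that overrides predict(), but that part is untested here.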

About this issue

  • Original URL
  • State: closed
  • Created 5 years ago
  • Comments: 19 (7 by maintainers)

Most upvoted comments


ModuleNotFoundError                       Traceback (most recent call last)
Cell In[14], line 3
      1 import keras
      2 import tensorflow
----> 3 from tensorflow.keras.wrappers.scikit_learn import KerasClassifier

ModuleNotFoundError: No module named 'tensorflow.keras.wrappers'

Hello, I have successfully installed TensorFlow, but Python still raises this error on the line "from tensorflow.keras.wrappers.scikit_learn import KerasClassifier": ModuleNotFoundError: No module named 'tensorflow.keras.wrappers'. Could anyone help with this? Thanks!
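
As far as I can tell, tensorflow.keras.wrappers is no longer shipped in newer TensorFlow releases, which would explain this error even with a working install. A minimal sketch for checking which wrapper module is importable in the current environment, using only the standard library, is below. The helper names module_available and pick_wrapper_module are made up for this example; scikeras.wrappers is the import path of the SciKeras package linked in this thread.

```python
import importlib.util


def module_available(name):
    """Return True if the dotted module name can be located, False otherwise."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, AttributeError, ValueError):
        # Raised when a parent package is missing or is not a real package.
        return False


def pick_wrapper_module():
    """Return the first importable scikit-learn wrapper module, or None."""
    candidates = (
        "scikeras.wrappers",                       # separate SciKeras package
        "tensorflow.keras.wrappers.scikit_learn",  # legacy location from this issue
    )
    for name in candidates:
        if module_available(name):
            return name
    return None


print("Wrapper module found:", pick_wrapper_module())
```

If this prints None, neither wrapper is installed; installing the scikeras package from PyPI is probably the simplest way forward.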

@karimmohraz, Can you please confirm if we can close this issue with respect to Adrian’s comment? Thanks!

Oops, it seems I missed part of above discussions. Great work, thank you!

This Python package also seems to support the functional API: https://pypi.org/project/scikeras/

Yep it does! It also supports subclassed models and quite a few other things.

Hi everyone,

From discussion with the Keras team, it looks like #37201 may not be merged and instead the functionality may be made into a separate package that will be easier to maintain and have more flexibility. The wrappers would then be deprecated from Keras/TF. This is not final, but I created a package to test out the idea.

Links:

  • PyPI: https://pypi.org/project/scikeras
  • Source: https://github.com/adriangb/scikeras

Please take a look and let me know if this satisfies the use cases in this PR. This should allow for input splitting/output joining similar to the gist that @karimmohraz references above, except it is done in a more programmatic way.

Any input is welcome.