tensorflow: Unable to use keras.applications within TF Estimator

I am having problems using keras.applications within an Estimator. After much experimentation, I came up with a minimal version that demonstrates the problem, shown below. I presume it is caused by Keras keeping only one session at a time, while the Estimator API sets up separate sessions for training and evaluation.

This is my original post on Stack Overflow.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes, see below
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 1.12.0
  • Python version: 3.6.7

Describe the current behavior

The estimator apparently does not learn the data, which can be observed as a constant evaluation loss and a stagnant training loss.
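For context (my addition, not part of the original report): a 10-class classifier whose outputs have collapsed to the uniform distribution produces a cross-entropy of ln(10) ≈ 2.30, so a constant evaluation loss near that value means the network is effectively guessing.

```python
import numpy as np

# Cross-entropy a 10-class classifier incurs when it assigns the
# uniform probability 1/10 to every class -- the "learned nothing"
# baseline loss for MNIST.
uniform_loss = -np.log(1.0 / 10)
print(round(uniform_loss, 4))  # 2.3026
```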

Describe the expected behavior

The estimator should learn the MNIST dataset at least as well from the model's features as from the raw data.

Code to reproduce the issue

import tensorflow as tf
import numpy as np

from tensorflow.keras.datasets import mnist


# switch to example 1/2/3
EXAMPLE_CASE = 1

# flag for initial weights loading of keras model
_W_INIT = True


def dense_net(features, labels, mode, params):
    # --- code to load a keras application ---

    # uncommenting this line causes a bump in the loss every time the
    # evaluation is run, indicating that Keras does not handle the
    # two sessions of the Estimator API well
    # tf.keras.backend.set_learning_phase(mode == tf.estimator.ModeKeys.TRAIN)
    global _W_INIT

    model = tf.keras.applications.MobileNet(
        input_tensor=features,
        input_shape=(128, 128, 3),
        include_top=False,
        pooling='avg',
        weights='imagenet' if _W_INIT else None)

    # only initialize weights once
    if _W_INIT:
        _W_INIT = False

    # switch cases
    if EXAMPLE_CASE == 1:
        # model.output is the same as model.layers[-1].output
        img = model.layers[-1].output
    elif EXAMPLE_CASE == 2:
        img = model(features)
    elif EXAMPLE_CASE == 3:
        # do not use keras features
        img = tf.keras.layers.Flatten()(features)
    else:
        raise NotImplementedError

    # --- regular code from here on ---
    for units in params['dense_layers']:
        img = tf.keras.layers.Dense(units=units, activation='relu')(img)

    # logits must be linear; a relu here can zero out gradients
    logits = tf.keras.layers.Dense(units=10, activation=None)(img)

    # compute predictions
    probs = tf.nn.softmax(logits)
    predicted_classes = tf.argmax(probs, 1)

    # compute loss
    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)

    acc = tf.metrics.accuracy(labels, predicted_classes)
    metrics = {'accuracy': acc}
    tf.summary.scalar('accuracy', acc[1])

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=metrics)

    # create training operation
    assert mode == tf.estimator.ModeKeys.TRAIN

    optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


def prepare_dataset(in_tuple, n):
    feats = in_tuple[0][:n, :, :]
    labels = in_tuple[1][:n]
    feats = feats.astype(np.float32)
    feats /= 255
    labels = labels.astype(np.int32)
    return (feats, labels)


def _parse_func(features, labels):
    feats = tf.expand_dims(features, -1)
    feats = tf.image.grayscale_to_rgb(feats)
    feats = tf.image.resize_images(feats, (128, 128))
    return (feats, labels)


def load_mnist(n_train=10000, n_test=3000):
    train, test = mnist.load_data()
    train = prepare_dataset(train, n_train)
    test = prepare_dataset(test, n_test)
    return train, test


def train_input_fn(imgs, labels, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((imgs, labels))
    dataset = dataset.map(_parse_func)
    dataset = dataset.shuffle(500)
    dataset = dataset.repeat().batch(batch_size)
    return dataset


def eval_input_fn(imgs, labels, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((imgs, labels))
    dataset = dataset.map(_parse_func)
    dataset = dataset.batch(batch_size)
    return dataset


def main(m_dir=None):
    # fetch data
    (x_train, y_train), (x_test, y_test) = load_mnist()

    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: train_input_fn(
            x_train, y_train, 30),
        max_steps=150)

    eval_spec = tf.estimator.EvalSpec(
        input_fn=lambda: eval_input_fn(
            x_test, y_test, 30),
        steps=100,
        start_delay_secs=0,
        throttle_secs=0)

    run_cfg = tf.estimator.RunConfig(
        model_dir=m_dir,
        tf_random_seed=2,
        save_summary_steps=2,
        save_checkpoints_steps=10,
        keep_checkpoint_max=1)

    # build network
    classifier = tf.estimator.Estimator(
        model_fn=dense_net,
        params={
            'dense_layers': [256]},
        config=run_cfg)

    # fit the model
    tf.estimator.train_and_evaluate(
        classifier,
        train_spec,
        eval_spec)


if __name__ == "__main__":
    main()

Other info / logs
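One note that may be relevant (my addition, not part of the original report): TF 1.12 also provides `tf.keras.estimator.model_to_estimator`, which wraps a compiled Keras model into an Estimator so that graph construction stays under the Estimator's control; whether it avoids the session issue for `keras.applications` models is untested here. A minimal sketch with a small stand-in model:

```python
import tensorflow as tf

# Small stand-in network; in the report, MobileNet would take its place.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Convert the compiled Keras model into a tf.estimator.Estimator;
# train_and_evaluate can then be used as in the script above.
estimator = tf.keras.estimator.model_to_estimator(keras_model=model)
```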

About this issue

  • Original URL
  • State: closed
  • Created 5 years ago
  • Comments: 22 (5 by maintainers)

Most upvoted comments

I think this should be treated as a bug: how can it be that you cannot train an Estimator using a Keras application? This is a super common use case. We’ve had to resort to saving the Keras application as a SavedModel and loading it as a predictor in the input_fn; to deploy, we’ve had to merge the two saved models together in a separate script.
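A rough sketch of the workaround described above (my addition), assuming TF 1.x, where `tf.contrib.predictor` still exists; the export path and helper names are hypothetical:

```python
import tensorflow as tf

EXPORT_DIR = './mobilenet_export'  # hypothetical path


def export_feature_extractor():
    # Save the Keras application once as a SavedModel.
    model = tf.keras.applications.MobileNet(
        input_shape=(128, 128, 3), include_top=False,
        pooling='avg', weights='imagenet')
    sess = tf.keras.backend.get_session()
    tf.saved_model.simple_save(
        sess, EXPORT_DIR,
        inputs={'image': model.input},
        outputs={'features': model.output})


def make_input_fn(dataset_fn):
    # Run the frozen extractor inside input_fn via a predictor,
    # so the Estimator only ever sees precomputed features.
    predict_fn = tf.contrib.predictor.from_saved_model(EXPORT_DIR)

    def input_fn():
        def _extract(imgs, labels):
            feats = tf.py_func(
                lambda x: predict_fn({'image': x})['features'],
                [imgs], tf.float32)
            return feats, labels
        return dataset_fn().map(_extract)

    return input_fn
```

Deployment would then require stitching the exported extractor and the trained Estimator graph back together, as described above.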

All this doesn’t make sense. If TF Hub supported variable image input sizes, all would be fine, but it doesn’t. That leaves Keras as the only source of easy-to-use, general-purpose pretrained models, and if you cannot fine-tune them using the Estimator, then TensorFlow's high-level API is broken.

This is not a Build/Installation or Bug/Performance issue. Please post support questions like this on Stack Overflow; there is a big community there to support you and learn from your questions. GitHub is mainly for addressing bugs in installation and performance. I have seen your clear problem description on Stack Overflow and I think the community will respond to your question soon. Thanks!