gradient-checkpointing: Does Not Work with Keras

@yaroslavvb Would you please add Keras `model.fit_generator` to your test cases? I notice the Keras test case is a simple MNIST model that also does not use convolutional layers. For example, with TensorFlow 1.5 (GPU) and Keras 2.1.6 on 64-bit Python 3.5 under Windows 10, I cannot get the following to work (i.e., memory usage and time per epoch are the same with or without the memory_saving_gradients code):

# -*- coding: utf-8 -*-

##########
#LIBRARIES
##########

#Future
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import pandas as pd

# Silence pandas' SettingWithCopyWarning. With a chained assignment the
# outcome may vary depending on whether the data is a view of other data or
# a copy of it; the warning that flags this is switched off here.
pd.options.mode.chained_assignment = None

import cv2

import os
import time
import argparse
import h5py
import gc

import multiprocessing as mp

import tensorflow as tf
from tensorflow.python.keras._impl.keras import backend as K

from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.batching import map_and_batch

import memory_saving_gradients

# Short module-level alias for tf.data.Dataset (used by generator_from_df).
Dataset = tf.data.Dataset

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.models import Sequential, Model, load_model, model_from_yaml
from tensorflow.python.keras.callbacks import LearningRateScheduler, ModelCheckpoint, EarlyStopping, History, TensorBoard
from tensorflow.python.keras import regularizers, optimizers
from tensorflow.python.keras.layers import Conv2D, Dense, Flatten, Dropout, Input, Lambda, Activation

##################
#GLOBAL VARIABLES
##################

# Shape of a raw frame as (channels, height, width) -- channels first, to
# match img_format below.
img_shape_raw = (3, 160, 320)

batch_size = 32

num_epochs = 1

# Rows trimmed from the top / bottom of a raw frame; used here only to derive
# the model input shape img_shape_input.
crop_top = 70
crop_btm = 25

img_format = 'channels_first'
K.set_image_data_format(img_format)

img_shape_input = (img_shape_raw[0],
                   img_shape_raw[1] - crop_top - crop_btm,
                   img_shape_raw[2]) # (3, 65, 320)

# Leave one logical core free, but never use fewer than one process.
# (Author's machine: 4 physical / 8 logical cores.)
max_procs = max(mp.cpu_count() - 1, 1)
max_q_size = batch_size

root = '.'

# Paths are joined component by component: a literal backslash such as
# r'dat\raw' is only a separator on Windows; elsewhere it would name a single
# directory called "dat\raw". On Windows the resulting paths are unchanged.
fldr_img_raw = os.path.join( root, 'dat', 'raw' )
fldr_csv_raw = os.path.join( root, 'dat', 'raw' )

fldr_img_mod = os.path.join( root, 'dat', 'mod' )
fldr_csv_mod = os.path.join( root, 'dat', 'mod' )

train_csv = os.path.join(fldr_csv_mod, 'training_data.csv')
val_csv = os.path.join(fldr_csv_mod, 'validation_data.csv')
test_csv = os.path.join(fldr_csv_mod, 'test_data.csv')

pth_bins_fl = os.path.join( fldr_csv_mod, 'bins.txt' )

fldr_fig = os.path.join( root, 'fig' )

# Parallel hyper-parameter lists: one entry per training run.
lr = [1e-4, ]
run = [1, ]

hparam_str = ['1e-4', ]  # string form of lr, used in folder / file names

fldr_log = os.path.join( root, 'log', hparam_str[0], 'run_{:04d}'.format(run[0]))

fldr_arch = os.path.join( root, 'arch' )
fldr_wt = os.path.join( root, 'wt' )
fldr_ckpt = os.path.join( root, 'ckpt' )
fldr_mdl = os.path.join( root, 'mdl' )

fldr_summary = os.path.join( root, 'summary' )

# Only the run number is filled in now; {epoch} and {val_mean_squared_error}
# are left as placeholders for ModelCheckpoint to format at save time.
fl_fmt_wt_ckpt = os.path.join( fldr_ckpt,
                               'wt_ckpt-run_{run:04d}'.format(run=run[0]) + '_epoch_{epoch:04d}_val_mse_{val_mean_squared_error:.7f}.h5' )

################
#DATA GENERATOR
################

def get_data( keep_ptl = 75 ):
    '''Return the train, validation and test dataframes (plus bins).

    The real loader, which keeps a certain percentile ``keep_ptl`` of the
    original data, was omitted from this excerpt for space. main() unpacks
    four values from it: (df_train, df_val, df_test, bins). The dataframes
    must provide 'Image_Path' and 'Angle' columns (read by generator_from_df).

    Raises:
        NotImplementedError: always, until the loader is filled in. The
            previous empty stub returned None, which only failed later with
            an unrelated unpacking TypeError in main().
    '''
    raise NotImplementedError(
        'get_data() was omitted from this example; supply the real loader '
        '(keep_ptl={!r})'.format(keep_ptl))

def generator_from_df( df, batch_size, shuffle = True ):
    '''Yield batches of (images, angles) for ``df`` from a tf.data pipeline.

    Args:
        df: dataframe with an 'Image_Path' column (file paths) and an
            'Angle' column (regression targets).
        batch_size: number of examples per yielded batch. It is also the
            interleave cycle length and half the shuffle buffer.
        shuffle: shuffle while repeating when True; plain repeat otherwise.

    Yields:
        The evaluated (images, angles) batch from ``sess.run``. Images are
        transposed to channels-first.

    The generator stops after ``num_epochs`` passes over ``df``, when the
    iterator raises OutOfRangeError.

    Because this is a generator, none of the graph-building code runs until
    the first ``next()``. The pipeline is therefore created in whatever graph
    is the default at that moment, and evaluated with ``K.get_session()``.

    NOTE(review): no cropping is applied here, yet the model's input shape is
    img_shape_input (raw height minus crop_top and crop_btm). Confirm that the
    files at 'Image_Path' are already cropped.
    '''

    def _load_example( path, angle ):
        # Decode to (H, W, 3), then move channels to the front. The
        # 3-element permutation assumes a single-frame image (e.g. JPEG/PNG).
        encoded = tf.read_file( path )
        image = tf.image.decode_image( encoded, channels=3 )
        image = tf.transpose( image, [2, 0, 1] )
        return Dataset.from_tensors( (image, angle) )

    def _pass_through( image, angle ):
        # map_and_batch requires a map function; elements are already final.
        return image, angle

    paths_t = tf.convert_to_tensor( df['Image_Path'].values )
    angles_t = tf.convert_to_tensor( df['Angle'].values )

    dataset = Dataset.from_tensor_slices( (paths_t, angles_t) ).apply(
        tf.contrib.data.parallel_interleave( _load_example,
                                             cycle_length = batch_size,
                                             sloppy = True ) )

    dataset = ( dataset.apply( shuffle_and_repeat( buffer_size = 2*batch_size,
                                                   count = num_epochs ) )
                if shuffle else dataset.repeat( num_epochs ) )

    dataset = dataset.apply( map_and_batch( _pass_through,
                                            batch_size,
                                            num_parallel_batches = max_procs ) )
    dataset = dataset.prefetch( max_procs )

    one_shot = dataset.make_one_shot_iterator()
    session = K.get_session()
    batch_op = one_shot.get_next()

    while True:
        try:
            yield session.run( batch_op )
        except tf.errors.OutOfRangeError:
            return

###########
#GET MODEL
###########

def get_model( lr, is_training = True ):
    '''Build and compile the single-output regression model.

    Args:
        lr: learning rate for the Adam optimizer.
        is_training: accepted so that callers such as main() can pass it.
            It is currently unused, because no layer here (no Dropout or
            BatchNormalization) behaves differently at training time.

    Returns:
        A compiled tf.keras Model. It maps an input of shape img_shape_input
        to one linear output unit and uses mean-squared-error loss with an
        'mse' metric.

    Side effects:
        * Adds the outputs of layers conv02, conv04, fc01 and fc03 to the
          'checkpoints' collection of the current default graph.
        * Replaces ``K.gradients`` process-wide with
          ``memory_saving_gradients.gradients_collection``.
    '''

    l2 = regularizers.l2(0.001)

    # (name scope, layer name, filters, kernel size)
    conv_specs = [ ('ConvLayer_01', 'conv01',  4, (5,5)),
                   ('ConvLayer_02', 'conv02', 12, (5,5)),
                   ('ConvLayer_03', 'conv03', 24, (5,5)),
                   ('ConvLayer_04', 'conv04', 24, (3,3)),
                   ('ConvLayer_05', 'conv05', 32, (3,3)) ]

    # (name scope, layer name, units)
    dense_specs = [ ('FullyConnectedLayer_01', 'fc01', 100),
                    ('FullyConnectedLayer_02', 'fc02',  50),
                    ('FullyConnectedLayer_03', 'fc03',  25),
                    ('FullyConnectedLayer_04', 'fc04',  10) ]

    with tf.name_scope('Input'):
        inputs = Input( shape=img_shape_input, name='input' )

        # Scale pixel values from [0, 255] to [-0.5, 0.5].
        x = Lambda(lambda x: x / 255. - 0.5,
                   input_shape=img_shape_input, name = 'norm_-0.5_to_0.5')(inputs)

    with tf.name_scope('Hidden_Layers'):

        # NOTE(review): no `activation=` is passed to any Conv2D/Dense layer.
        # Keras defaults to a linear activation, so the whole stack is an
        # affine map. Confirm this is intended.
        for scope, layer_name, filters, kernel_size in conv_specs:
            with tf.name_scope(scope):
                x = Conv2D(filters, kernel_size,
                           kernel_regularizer=l2,
                           bias_regularizer=l2,
                           padding='same',
                           name=layer_name)(x)

        with tf.name_scope('Flatten'):
            # NOTE(review): the convolutions use stride 1 with 'same' padding,
            # so the feature map stays 65x320. fc01 then sees 32*65*320 =
            # 665,600 features, and its kernel alone has ~66.6M weights
            # (presumably plus Adam's per-weight slot variables). Gradient
            # checkpointing only reduces stored activations, not parameter or
            # optimizer memory. This may be why memory looks unchanged with
            # memory_saving_gradients -- confirm by profiling.
            x = Flatten(name='flatten')(x)

        for scope, layer_name, units in dense_specs:
            with tf.name_scope(scope):
                x = Dense(units,
                          kernel_regularizer=l2,
                          bias_regularizer=l2,
                          name=layer_name)(x)

    with tf.name_scope('Output'):

        outputs = Dense(1,
                        name='output')(x)

    # Create Model

    model = Model( inputs = inputs, outputs = outputs )

    # Constant base learning rate with Adam's built-in time-based decay
    # (no LearningRateScheduler callback is used anywhere).
    adam = optimizers.Adam( lr = lr, decay = 0.001 )

    # Memory Saving Gradients
    #
    # Register the checkpoint tensors. gradients_collection presumably reads
    # them from the 'checkpoints' collection of the current default graph
    # (per memory_saving_gradients' documentation) -- confirm.
    for layer_name in ( 'conv02', 'conv04', 'fc01', 'fc03' ):
        tf.add_to_collection('checkpoints',
                             model.get_layer(layer_name).get_output_at(0))

    # Patch the backend module so that gradient computations made through
    # K.gradients (presumably including the optimizer's update construction
    # during training) use gradient checkpointing. This is set before
    # compile and affects the whole process.
    K.gradients = memory_saving_gradients.gradients_collection

    # Compile Model

    model.compile(loss='mean_squared_error', optimizer=adam, metrics=['mse'])

    return model

class CumulativeHistory( History ):
    '''History callback whose record survives repeated training runs.

    The stock History callback does not allow resuming its history. This
    subclass performs the parent's start-of-training initialisation only the
    first time. Later runs reuse the existing record, recognised by an
    ``epoch`` attribute already being present.
    '''

    def on_train_begin( self, logs=None ):
        # An existing ``epoch`` attribute means an earlier run already set the
        # record up; keep it instead of re-initialising.
        if hasattr(self, 'epoch'):
            return
        super(CumulativeHistory, self).on_train_begin( logs )

def main(*args, **kargs):
    """ Behavioral Cloning Project: train one model per learning rate in ``lr``.

    Command-line options:
        -c/--checkpoint  path of a saved model (``.h5``) passed to
                         ``load_model`` to resume training. Without it, a new
                         model is built with ``get_model``.
        -e/--epoch       epoch to resume from. Only used together with
                         --checkpoint; defaults to 0.

    For every learning rate:
        * A fresh graph and session are created.
        * The architecture (YAML) and the initial and final models/weights are
          written to disk.
        * The model is trained with ``fit_generator``.

    ``*args`` and ``**kargs`` are accepted for compatibility and ignored.
    """

    parser = argparse.ArgumentParser(description='Behavioral Cloning Project')

    parser.add_argument('-c', '--checkpoint', type=str, help='Checkpoint (`.h5` file)')
    parser.add_argument('-e', '--epoch', type=int, default=0, help='Initial epoch')

    opts = parser.parse_args()

    resume = opts.checkpoint is not None
    # Previously, a checkpoint given without --epoch left initial_epoch as
    # None, which fit_generator cannot start from; it now defaults to 0.
    initial_epoch = opts.epoch if resume else 0

    # Set Configuration

    config = tf.ConfigProto( intra_op_parallelism_threads = max_procs,
                             inter_op_parallelism_threads = 0) # 0: let TensorFlow choose

    config.gpu_options.allow_growth = True

    # Get Data

    df_train, df_val, _df_test, _bins = get_data( keep_ptl = 60 )

    nbatches_train = df_train.shape[0] // batch_size
    nbatches_val   = df_val.shape[0] // batch_size

    # Folders written to below; open() and save() fail if they are missing.
    for fldr in ( fldr_summary, fldr_arch, fldr_mdl, fldr_wt, fldr_ckpt ):
        if not os.path.isdir( fldr ):
            os.makedirs( fldr )

    # Callbacks shared by every run (CumulativeHistory keeps accumulating).
    history = CumulativeHistory()

    early_stop = EarlyStopping( monitor='val_mean_squared_error',
                                min_delta=1e-4,
                                patience=50,
                                verbose=0,
                                mode='min')

    model_ckpt = ModelCheckpoint( fl_fmt_wt_ckpt,
                                  monitor='val_mean_squared_error',
                                  verbose=0,
                                  save_best_only=True,
                                  save_weights_only=True,
                                  period=1)

    shared_callbacks = [history, early_stop, model_ckpt]

    # e.g. '1e-4_run_0001_' -- shared by the arch / model / weights file names.
    run_tag = hparam_str[0] + '_run_{:04d}_'.format(run[0])

    for learning_rate in lr:

        # A fresh graph per run keeps one run's ops and 'checkpoints'
        # collection out of the next. Entering the session makes train_graph
        # the default graph and closes the session on exit (it was never
        # closed before).
        train_graph = tf.Graph()

        with tf.Session( config = config, graph = train_graph ) as train_sess:

            K.set_session( train_sess )

            if resume:
                # Loaded inside this graph/session. Loading before the graph
                # existed put the model in a different graph from the one
                # train_sess runs.
                # NOTE(review): load_model needs a full-model file such as
                # the model_*.h5 files saved below; ModelCheckpoint only saves
                # weights. get_model is not called on this path, so
                # K.gradients is not patched for memory-saving gradients.
                train_model = load_model( opts.checkpoint )
            else:
                train_model = get_model( learning_rate )

                # Print model summary
                summary_fl_pth = os.path.join( fldr_summary, 'model_summary_run_{:04d}_'.format(run[0]) + r'.txt' )

                with open(summary_fl_pth, 'w') as summary_file:
                    train_model.summary( print_fn=lambda line: summary_file.write(line + '\n') )

            # Generators are lazy. Each one builds its tf.data pipeline in the
            # current default graph (train_graph) the first time fit_generator
            # pulls from it, so they are created per run.
            train_generator = generator_from_df( df_train, batch_size )
            val_generator   = generator_from_df( df_val,   batch_size, shuffle=False )

            board = TensorBoard( log_dir = fldr_log,
                                 histogram_freq = 0,
                                 write_graph = True,
                                 write_images = True )

            writer = tf.summary.FileWriter( fldr_log, train_graph )
            try:
                ts = _timestamp()

                arch_fl_pth = os.path.join( fldr_arch, 'arch_' + run_tag + ts + '.yaml' )
                with open(arch_fl_pth, 'w') as arch_file:
                    arch_file.write( train_model.to_yaml() )

                _save_model_and_weights( train_model, 'init', run_tag, ts )

                train_model.fit_generator(
                    generator = train_generator,
                    steps_per_epoch = nbatches_train,
                    epochs = num_epochs,
                    max_queue_size = max_q_size,
                    validation_data = val_generator,
                    validation_steps = nbatches_val,
                    workers = 0,
                    # A new list per run, so TensorBoard callbacks do not
                    # pile up in the shared list across runs.
                    callbacks = shared_callbacks + [board],
                    initial_epoch = initial_epoch)

                _save_model_and_weights( train_model, 'final', run_tag, _timestamp() )
            finally:
                writer.close()

        # Called after the session/graph context has been left, as before.
        K.clear_session()

        del train_model
        gc.collect()


def _timestamp():
    """Return the current local time as 'YYYY-mm-dd_HH-MM-SS' for file names."""
    import datetime  # was used by main() but never imported (NameError)
    return datetime.datetime.fromtimestamp( time.time() ).strftime('%Y-%m-%d_%H-%M-%S')


def _save_model_and_weights( model, stage, run_tag, ts ):
    """Save ``model`` to fldr_mdl and its weights to fldr_wt.

    ``stage`` ('init' or 'final') becomes part of the file names, e.g.
    'model_init_<run_tag><ts>.h5' and 'weights_init_<run_tag><ts>.h5'.
    """
    model.save( os.path.join( fldr_mdl, 'model_' + stage + '_' + run_tag + ts + '.h5' ) )
    model.save_weights( os.path.join( fldr_wt, 'weights_' + stage + '_' + run_tag + ts + '.h5' ) )

if __name__ == '__main__':
    # Script entry point: run training only when executed directly,
    # not when this module is imported.
    main()

About this issue

  • Original URL
  • State: open
  • Created 5 years ago
  • Reactions: 1
  • Comments: 66 (1 by maintainers)

Most upvoted comments

@yaroslavvb @TimSalimans @christopherhesse @cberner @davidBelanger Can someone please help me? I'll try any suggestions you have.

@yaroslavvb @TimSalimans @christopherhesse @cberner @davidBelanger Please, if this doesn't work with Keras, update README.md to say so. I would really like to get this to work, but I need your help. I've provided a relatively straightforward Keras use case above and am happy to provide training data or anything else that is needed. I'm also more than happy to work through any suggestions you might have for getting the code to work.