gradient-checkpointing: Does Not Work with Keras
@yaroslavvb Would you please add a Keras model.fit_generator case to your tests? I notice the existing Keras test case is a simple MNIST model that does not use convolutional layers. For example, on TensorFlow 1.5 (GPU) with Keras 2.1.6 and 64-bit Python 3.5 on a Windows 10 machine, I cannot get the following to work (i.e., memory use and time per epoch are the same with or without the memory_saving_gradients code):
# -*- coding: utf-8 -*-
##########
#LIBRARIES
##########
#Future
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import pandas as pd
pd.set_option('chained_assignment', None)  # Silences `SettingWithCopyWarning`. With a
                                           # chained assignment, the outcome may vary
                                           # depending on whether the data is a view of
                                           # other data or a copy of other data.
import cv2
import os
import time
import datetime  # needed for the timestamped filenames below
import argparse
import h5py
import gc
import multiprocessing as mp
import tensorflow as tf
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.batching import map_and_batch
import memory_saving_gradients
Dataset = tf.data.Dataset
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.models import Sequential, Model, load_model, model_from_yaml
from tensorflow.python.keras.callbacks import LearningRateScheduler, ModelCheckpoint, EarlyStopping, History, TensorBoard
from tensorflow.python.keras import regularizers, optimizers
from tensorflow.python.keras.layers import Conv2D, Dense, Flatten, Dropout, Input, Lambda, Activation
##################
#GLOBAL VARIABLES
##################
img_shape_raw = (3, 160, 320)
batch_size = 32
num_epochs = 1
crop_top = 70
crop_btm = 25
img_format = 'channels_first'
K.set_image_data_format(img_format)
img_shape_input = (img_shape_raw[0],
img_shape_raw[1] - crop_top - crop_btm,
img_shape_raw[2]) #(3, 65, 320)
max_procs = mp.cpu_count() - 1 or 1 # 4 physical cores, 8 logical cores
max_q_size = batch_size
root = r'.'
fldr_img_raw = os.path.join( root, r'dat\raw' )
fldr_csv_raw = os.path.join( root, r'dat\raw' )
fldr_img_mod = os.path.join( root, r'dat\mod' )
fldr_csv_mod = os.path.join( root, r'dat\mod' )
train_csv = os.path.join(fldr_csv_mod, 'training_data.csv')
val_csv = os.path.join(fldr_csv_mod, 'validation_data.csv')
test_csv = os.path.join(fldr_csv_mod, 'test_data.csv')
pth_bins_fl = os.path.join( fldr_csv_mod, 'bins.txt' )
fldr_fig = os.path.join( root, r'fig' )
lr = [1e-4, ]
run = [1, ]
hparam_str = ['1e-4', ]
fldr_log = os.path.join( root, r'log', hparam_str[0], 'run_{:04d}'.format(run[0]))
fldr_arch = os.path.join( root, r'arch' )
fldr_wt = os.path.join( root, r'wt' )
fldr_ckpt = os.path.join( root, r'ckpt' )
fldr_mdl = os.path.join( root, r'mdl' )
fldr_summary = os.path.join( root, r'summary' )
fl_fmt_wt_ckpt = os.path.join( fldr_ckpt,
r'wt_ckpt-run_{run:04d}'.format(run=run[0]) + '_epoch_{epoch:04d}_val_mse_{val_mean_squared_error:.7f}.h5' )
################
#DATA GENERATOR
################
def get_data( keep_ptl = 75 ):
'''This just returns the train, validation, and test dataframes
keeping a certain percentile of the original data. I'm not
including it here for space and since it doesn't seem pertinent.
'''
def generator_from_df( df, batch_size, shuffle = True ):
def read( img_pth, angle ):
im_fl = tf.read_file( img_pth )
im = tf.image.decode_image(im_fl, channels=3)
im = tf.transpose( im, [2, 0, 1] ) # Make image channels first
return Dataset.from_tensors( (im, angle) )
img_pths = tf.convert_to_tensor( df['Image_Path'].values )
angs = tf.convert_to_tensor( df['Angle'].values )
ds = Dataset.from_tensor_slices( (img_pths, angs) )
ds = ds.apply( tf.contrib.data.parallel_interleave( read, cycle_length = batch_size, sloppy = True ) )
if shuffle:
ds = ds.apply( shuffle_and_repeat( buffer_size = 2*batch_size, count = num_epochs ) )
else:
ds = ds.repeat( num_epochs )
ds = ds.apply( map_and_batch(
lambda img_pth, ang: (img_pth,ang),
batch_size,
num_parallel_batches = max_procs ) )
ds = ds.prefetch( max_procs )
iterator = ds.make_one_shot_iterator()
sess = K.get_session()
next_element = iterator.get_next()
while True:
try:
yield sess.run(next_element)
except tf.errors.OutOfRangeError:
break
###########
#GET MODEL
###########
def get_model( lr ):
keep_prob = 0.5
rate = keep_prob
l2 = regularizers.l2(0.001)
with tf.name_scope('Input'):
inputs = Input( shape=img_shape_input, name='input' )
x = Lambda(lambda x: x / 255. - 0.5,
input_shape=img_shape_input, name = 'norm_-0.5_to_0.5')(inputs)
with tf.name_scope('Hidden_Layers'):
with K.name_scope('ConvLayer_01'):
x = Conv2D(4, (5,5),
kernel_regularizer=l2,
bias_regularizer=l2,
padding='same',
name='conv01')(x)
with tf.name_scope('ConvLayer_02'):
x = Conv2D(12, (5,5),
kernel_regularizer=l2,
bias_regularizer=l2,
padding='same',
name='conv02')(x)
with tf.name_scope('ConvLayer_03'):
x = Conv2D(24, (5,5),
kernel_regularizer=l2,
bias_regularizer=l2,
padding='same',
name='conv03')(x)
with tf.name_scope('ConvLayer_04'):
x = Conv2D(24, (3,3),
kernel_regularizer=l2,
bias_regularizer=l2,
padding='same',
name='conv04')(x)
with tf.name_scope('ConvLayer_05'):
x = Conv2D(32, (3,3),
kernel_regularizer=l2,
bias_regularizer=l2,
padding='same',
name='conv05')(x)
with tf.name_scope('Flatten'):
x = Flatten(name='flatten')(x)
with tf.name_scope('FullyConnectedLayer_01'):
x = Dense(100,
kernel_regularizer=l2,
bias_regularizer=l2,
name='fc01')(x)
with tf.name_scope('FullyConnectedLayer_02'):
x = Dense(50,
kernel_regularizer=l2,
bias_regularizer=l2,
name='fc02')(x)
with tf.name_scope('FullyConnectedLayer_03'):
x = Dense(25,
kernel_regularizer=l2,
bias_regularizer=l2,
name='fc03')(x)
with tf.name_scope('FullyConnectedLayer_04'):
x = Dense(10,
kernel_regularizer=l2,
bias_regularizer=l2,
name='fc04')(x)
with tf.name_scope('Output'):
outputs = Dense(1,
name='output')(x)
# Create Model
model = Model( inputs = inputs, outputs = outputs )
    adam = optimizers.Adam( lr = lr, decay = 0.001 )
    # Memory-saving gradients: mark these layers as checkpoints, then point the
    # Keras backend's `gradients` at the collection-based version before compiling.
    layer_names = [ 'conv02', 'conv04', 'fc01', 'fc03' ]
    for l in layer_names:
        tf.add_to_collection( 'checkpoints', model.get_layer(l).get_output_at(0) )
    K.__dict__['gradients'] = memory_saving_gradients.gradients_collection
# Compile Model
model.compile(loss='mean_squared_error', optimizer=adam, metrics=['mse'])
return model
class CumulativeHistory( History ):
'''
    The stock History callback resets when training restarts; this subclass
    accumulates history across runs.
'''
def on_train_begin( self, logs=None ):
if not hasattr(self, 'epoch'):
super(CumulativeHistory, self).on_train_begin( logs )
def main(*args, **kargs):
""" Behavioral Cloning Project
"""
parser = argparse.ArgumentParser(description='Behavioral Cloning Project')
parser.add_argument('-c', '--checkpoint', type=str, help='Checkpoint (`.h5` file)')
parser.add_argument('-e', '--epoch', type=int, help='Initial epoch')
args = parser.parse_args()
model_type = 'new'
train_model = None
initial_epoch = 0
if args.checkpoint is not None:
train_model = load_model( args.checkpoint )
initial_epoch = args.epoch
model_type = 'loaded'
# Set Configuration
config = tf.ConfigProto( intra_op_parallelism_threads = max_procs,
inter_op_parallelism_threads = 0) # set automatically to number of logical cores
config.gpu_options.allow_growth = True
# Get Data
df_train, df_val, df_test, bins = get_data( keep_ptl = 60 )
ntrain, nval, ntest = df_train.shape[0], df_val.shape[0], df_test.shape[0]
# Training
train_graph = tf.Graph()
train_generator = generator_from_df( df_train, batch_size )
val_generator = generator_from_df( df_val, batch_size, shuffle=False )
nbatches_train = ntrain // batch_size
nbatches_val = nval // batch_size
history = CumulativeHistory()
early_stop = EarlyStopping( monitor='val_mean_squared_error',
min_delta=1e-4,
patience=50,
verbose=0,
mode='min')
model_ckpt = ModelCheckpoint( fl_fmt_wt_ckpt,
monitor='val_mean_squared_error',
verbose=0,
save_best_only=True,
save_weights_only=True,
period=1)
callbacks = [history, early_stop, model_ckpt]
for i in range(len(lr)):
train_sess = tf.Session( config = config, graph = train_graph )
K.set_session( train_sess )
if model_type == 'new':
with train_graph.as_default():
# Print model summary
                summary_fl_pth = os.path.join( fldr_summary, 'model_summary_run_{:04d}.txt'.format(run[0]) )
                train_model = get_model( lr[i] )
with open(summary_fl_pth, 'w') as summary_file:
train_model.summary( print_fn=lambda x: summary_file.write(x + '\n') )
with train_graph.as_default():
with train_sess.as_default():
if K.backend() == 'tensorflow':
board = TensorBoard( log_dir = fldr_log,
histogram_freq = 0,
write_graph = True,
write_images = True )
callbacks.append( board )
writer = tf.summary.FileWriter( fldr_log, train_graph )
ts = time.time()
ts = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d_%H-%M-%S')
arch_yaml = train_model.to_yaml()
arch_fl_pth = os.path.join( fldr_arch, 'arch_' + hparam_str[0] + '_run_{:04d}_'.format(run[0]) + ts + '.yaml' )
with open(arch_fl_pth, 'w') as arch_file:
arch_file.write( arch_yaml )
train_model.save( os.path.join( fldr_mdl,
'model_init_' + hparam_str[0] + '_run_{:04d}_'.format(run[0]) + ts + '.h5') )
train_model.save_weights( os.path.join( fldr_wt,
'weights_init_' + hparam_str[0] + '_run_{:04d}_'.format(run[0]) + ts + '.h5' ) )
train_model.fit_generator(
generator = train_generator,
steps_per_epoch = nbatches_train,
epochs = num_epochs,
max_queue_size = max_q_size,
validation_data = val_generator,
validation_steps = nbatches_val,
workers = 0,
callbacks = callbacks,
initial_epoch = initial_epoch)
ts = time.time()
ts = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d_%H-%M-%S')
train_model.save( os.path.join( fldr_mdl,
'model_final_' + hparam_str[0] + '_run_{:04d}_'.format(run[0]) + ts + '.h5') )
train_model.save_weights( os.path.join( fldr_wt,
'weights_final_' + hparam_str[0] + '_run_{:04d}_'.format(run[0]) + ts + '.h5' ) )
if K.backend() == 'tensorflow':
K.clear_session()
del train_model
gc.collect()
if __name__ == '__main__':
""" Entry point to the program
"""
main()
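For completeness: the project README recommends monkey-patching tf.gradients itself (with automatic checkpoint selection via gradients_memory) rather than only the Keras backend. Below is a minimal sketch of that variant; note it is an assumption on my part that Keras's TensorFlow backend resolves gradients at call time, so the patch has to happen before the model is built and compiled.

import tensorflow as tf
import memory_saving_gradients
from tensorflow.python.keras._impl.keras import backend as K

# Patch BEFORE any model is built or compiled, so whichever module looks up
# `gradients` at call time sees the memory-saving version. `gradients_memory`
# selects checkpoints automatically; `gradients_speed` and
# `gradients_collection` are the other strategies the library provides.
tf.__dict__['gradients'] = memory_saving_gradients.gradients_memory
K.__dict__['gradients'] = memory_saving_gradients.gradients_memory

model = get_model( lr[0] )  # build and compile only after patching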
@yaroslavvb @TimSalimans @christopherhesse @cberner @davidBelanger Can someone PLEASE help me? I'll try whatever you suggest.
@yaroslavvb @TimSalimans @christopherhesse @cberner @davidBelanger Please, if this doesn't work with Keras, update the README.md to say so. I would really like to get this working, but I need your help. I've provided a relatively straightforward Keras use case above and am happy to provide training data or anything else that's needed. I'm also more than happy to work through any suggestions you might have to get the code running.
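If it helps to reproduce my comparison: one way to check peak device memory in TF 1.x is tf.contrib.memory_stats, which lets a with/without-checkpointing run be compared directly. A minimal sketch (run in the same session and graph as training):

import tensorflow as tf
from tensorflow.python.keras._impl.keras import backend as K

# After training, report the peak number of bytes allocated on the device.
peak_bytes = K.get_session().run( tf.contrib.memory_stats.MaxBytesInUse() )
print( 'Peak device memory: {:.1f} MiB'.format( peak_bytes / 2**20 ) )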