tensorflow: Adaptive optimizers do not work well with multi-head networks when some labels are missing

Use case: we have a multi-task network with many outputs. Each example in the dataset has only a subset of the labels; the loss components for missing labels are masked out. A single optimizer is used, because which labels are present in the example stream is not known statically.

Problem: masked loss components yield zero gradient estimates for the corresponding variables. This breaks adaptive optimizers (Momentum, Adam, …): a step is still taken, and the optimizer slots are corrupted by the incoming zero gradient estimates. The desired behaviour is to do nothing for variables (and their slots) that were effectively unused in the forward pass, as is done for embeddings using the IndexedSlices trick. Related problem: tf.global_norm and tf.clip_by_global_norm are also affected by these stray zero gradient estimates.

Minimal example/demonstration:

In [14]: feature = tf.constant([1.0])
In [15]: output = feature * 42.0
In [16]: loss = tf.squared_difference(output, [0.0])
In [17]: mask = tf.constant([False])
In [18]: masked_loss = tf.boolean_mask(loss, mask)
In [19]: masked_loss_grad = tf.gradients(masked_loss, feature)
In [20]: masked_loss_grad
Out[20]: [<tf.Tensor 'gradients/mul_3_grad/Reshape:0' shape=(1,) dtype=float32>]
In [21]: masked_loss_grad[0].eval(session=session)
Out[21]: array([ 0.], dtype=float32)

This is technically correct (the gradient is indeed zero), but if feature were a Variable, we would not want to apply this gradient in an adaptive optimizer, because the training example carried no useful information about the loss function.
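
To make the effect concrete, here is a minimal plain-Python sketch of a classic momentum update (accumulator = momentum * accumulator + gradient). It is not TensorFlow code; the function names `momentum_step` and `masked_momentum_step` and the scalar "variable" are assumptions made for illustration. It shows that feeding zero gradients still moves the variable and decays the accumulator, while skipping the update leaves both untouched.

```python
# Plain-Python sketch of a momentum update; not TensorFlow code.

def momentum_step(var, accum, grad, lr=0.1, momentum=0.9):
    """One classic momentum update; returns the new (var, accum)."""
    accum = momentum * accum + grad
    var = var - lr * accum
    return var, accum


def masked_momentum_step(var, accum, grad, lr=0.1, momentum=0.9):
    """Same update, but a zero gradient is treated as 'label missing'."""
    if grad == 0.0:
        return var, accum  # leave the variable and its slot untouched
    return momentum_step(var, accum, grad, lr, momentum)


# Warm up the accumulator with one real gradient.
var, accum = momentum_step(1.0, 0.0, grad=2.0)

# Feed three masked (all-zero) gradients to each variant.
naive_var, naive_accum = var, accum
skip_var, skip_accum = var, accum
for _ in range(3):
    naive_var, naive_accum = momentum_step(naive_var, naive_accum, grad=0.0)
    skip_var, skip_accum = masked_momentum_step(skip_var, skip_accum, grad=0.0)

print("naive:", naive_var, naive_accum)  # variable drifts, accumulator decays
print("skip: ", skip_var, skip_accum)    # both unchanged
```

The exact-zero check in the sketch mirrors the heuristic used in the workaround: it cannot tell a masked example apart from one whose true gradient happens to be zero.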

Feature request: a better way to deal with this (probably a special value for gradients that would signify “nonexistent” gradients; e.g. TensorFlow uses None when this information is available statically).

For now here is a simple and ugly workaround:

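import six
import tensorflow as tf


def _is_all_zeros(grad):
    # True when every entry of the gradient is zero, i.e. the loss was masked out.
    all_zeros = tf.equal(tf.count_nonzero(grad), 0)
    return all_zeros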

class HackedMomentumOptimizer(tf.train.MomentumOptimizer):
    """MomentumOptimizer that skips the update when the gradient is all zeros."""

    def __init__(self, *args, **kwargs):
        super(HackedMomentumOptimizer, self).__init__(*args, **kwargs)

    def _apply_dense(self, grad, var):
        all_zeros = _is_all_zeros(grad)
        return tf.cond(all_zeros,
                       lambda: tf.no_op(),
                       lambda: super(HackedMomentumOptimizer, self).
                       _apply_dense(grad, var))

    def _resource_apply_dense(self, grad, var):
        all_zeros = _is_all_zeros(grad)
        return tf.cond(all_zeros,
                       lambda: tf.no_op(),
                       lambda: super(HackedMomentumOptimizer, self).
                       _resource_apply_dense(grad, var))

    def _apply_sparse(self, grad, var):
        all_zeros = _is_all_zeros(grad)
        return tf.cond(all_zeros,
                       lambda: tf.no_op(),
                       lambda: super(HackedMomentumOptimizer, self).
                       _apply_sparse(grad, var))

    def _resource_apply_sparse(self, grad, var, indices):
        all_zeros = _is_all_zeros(grad)
        return tf.cond(all_zeros,
                       lambda: tf.no_op(),
                       lambda: super(HackedMomentumOptimizer, self).
                       _resource_apply_sparse(grad, var, indices))


def _clip_gradients(gradients_variables, clip_norm=20.):
    gradients, variables = six.moves.zip(*gradients_variables)

    def _replace_nonexisting_grad(grad):
        # Swap an all-zero gradient for a scalar zero before computing the norm.
        all_zeros = _is_all_zeros(grad)
        return tf.cond(all_zeros,
                       lambda: tf.zeros([], dtype=tf.float32),
                       lambda: grad)
    gradients_filtered = [_replace_nonexisting_grad(g) for g in gradients]
    fixed_global_norm = tf.global_norm(gradients_filtered)
    gradients, global_norm = tf.clip_by_global_norm(gradients, clip_norm,
                                                    use_norm=fixed_global_norm)
    tf.summary.scalar('global_norm', global_norm)
    return six.moves.zip(gradients, variables)
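
For reference, the scaling that clipping by a global norm performs can be sketched in plain Python. This only illustrates the arithmetic (each gradient is multiplied by clip_norm / max(global_norm, clip_norm)); it is not the TensorFlow implementation, and the function names and the nested-list representation of gradients are assumptions made for the sketch. The optional `use_norm` argument plays the same role as in the helper above: a precomputed norm is used instead of recomputing it from the gradients.

```python
import math


def global_norm(grads):
    """L2 norm over all entries of all gradients (each gradient is a list of floats)."""
    return math.sqrt(sum(g * g for grad in grads for g in grad))


def clip_by_global_norm(grads, clip_norm, use_norm=None):
    """Scale every gradient by clip_norm / max(norm, clip_norm)."""
    norm = global_norm(grads) if use_norm is None else use_norm
    scale = clip_norm / max(norm, clip_norm)
    return [[g * scale for g in grad] for grad in grads], norm


grads = [[3.0, 4.0], [0.0, 0.0]]

# Norm computed from the gradients themselves.
clipped, norm = clip_by_global_norm(grads, clip_norm=1.0)
print(norm, clipped)

# Norm supplied by the caller, as with use_norm in the helper above.
clipped_fixed, norm_fixed = clip_by_global_norm(grads, clip_norm=1.0, use_norm=10.0)
print(norm_fixed, clipped_fixed)
```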

About this issue

  • State: closed
  • Created 7 years ago
  • Reactions: 1
  • Comments: 19 (12 by maintainers)

Most upvoted comments

@Threynaud I haven’t tried it myself yet, but in my understanding it should work. To try it, you can simply replace the superclass (i.e. inherit from tf.train.AdamOptimizer instead of tf.train.MomentumOptimizer).
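
As a rough illustration of why the same skip logic is relevant for Adam, here is a plain-Python sketch of the standard Adam update on a scalar. It is not TensorFlow code and has not been checked against tf.train.AdamOptimizer; the names `adam_step` and `masked_adam_step` are hypothetical. The sketch keeps a per-variable step counter `t`, which is an assumption of the sketch: TensorFlow's AdamOptimizer may track its bias-correction state differently, so swapping the superclass might not skip that part.

```python
import math


def adam_step(var, m, v, t, grad, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One standard Adam update on a scalar; returns the new (var, m, v, t)."""
    t += 1
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    var = var - lr * m_hat / (math.sqrt(v_hat) + eps)
    return var, m, v, t


def masked_adam_step(var, m, v, t, grad, **kwargs):
    """Same update, but an exactly-zero gradient leaves all state untouched."""
    if grad == 0.0:
        return var, m, v, t
    return adam_step(var, m, v, t, grad, **kwargs)


# One real gradient, then one masked (zero) gradient.
state = adam_step(0.0, 0.0, 0.0, 0, grad=1.0)
naive_state = adam_step(*state, grad=0.0)
skip_state = masked_adam_step(*state, grad=0.0)

print("naive:", naive_state)  # variable moved again; m and v decayed
print("skip: ", skip_state)   # identical to the state after the real gradient
```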