flax: Inconsistent network behaviour when using different batch sizes for `model.apply` on CPU

Hey, thanks for the great work!

I’m using BatchNorm in my network, but I have set the `use_running_average` parameter of the BatchNorm layers to `True`, which means they do not compute running means/variances from the data passing through the network and instead use the pre-computed statistics. As a result, the network’s behaviour should not change from batch to batch (ideally, at least, that is what I expect to hold).
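
For clarity, here is a minimal sketch of how the BatchNorm layers are configured; `TinyModel` is a hypothetical stand-in for the actual WideResNet, not the model from the Colab:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyModel(nn.Module):  # hypothetical stand-in for the WideResNet
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(16)(x)
    # use_running_average=True: normalise with the stored batch_stats,
    # never with the statistics of the current batch.
    x = nn.BatchNorm(use_running_average=True)(x)
    return nn.Dense(1)(x)

model = TinyModel()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
# No mutable collections are needed, since the running statistics stay frozen.
y = model.apply(variables, jnp.ones((4, 8)))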

I’ve provided a simple Colab notebook that reproduces the example. The Colab needs two files to run properly:

psd_data.pkl is the pickled version of a dict containing three things:

  • data: The train and test data used for training the model.
  • params: The trained parameters of the WideResNet module we’re using; with them the model achieves 1.0 train accuracy and 0.89 test accuracy.
  • labels: The labels of the datapoints in data, to double check the accuracies.

The problem that I have is:

ys = []
for i in range(10):
  ys.append(apply_fn(params, X_train[i:i+1]))  # one example at a time
ys = jnp.stack(ys).squeeze()
vs = apply_fn(params, X_train[:10])  # the same ten examples as one batch
np.allclose(ys, vs)
# Outputs False!

which shows that the network’s behaviour varies with the batch size. I expected this to output True, since the parameters are fixed and the BatchNorm layers only use their stored statistics. Am I doing something wrong?
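
To quantify how large the mismatch actually is (rather than just the boolean allclose result), a small check along these lines can be used; ys and vs are the arrays from the snippet above:

diff = np.abs(np.asarray(ys) - np.asarray(vs))
rel = diff / (np.abs(np.asarray(vs)) + 1e-12)  # avoid division by zero
print(diff.max(), rel.max())
# The same comparison with a looser, float32-appropriate tolerance:
print(np.allclose(ys, vs, rtol=1e-4, atol=1e-6))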

https://colab.research.google.com/drive/1a_SheAt9RH9tPRJ1DC60yccsbYaFssDx?usp=sharing

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 22 (6 by maintainers)

Most upvoted comments

Closing because dtype behavior is now consistent since dropping the default float32 dtype.
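
As a minimal sketch (not part of the original thread) of how one could check that the remaining batch-vs-single discrepancy is just floating-point rounding: enable 64-bit mode before creating any arrays and rerun the comparison from the first post.

import jax
# Must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
# ...then rerun the per-example vs batched apply_fn comparison from above;
# at float64 the differences should shrink by many orders of magnitude.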

Here is a slightly simpler example that I think reproduces your issue:

import numpy as np
from jax import random
from jax import nn
import jax

init_fn = nn.initializers.lecun_normal()
lhs = random.normal(random.PRNGKey(0), (128, 256))

def conv(x):
  # Ten stacked matmuls with fixed lecun-normal weights, computed by XLA.
  for i in range(10):
    rhs = init_fn(random.PRNGKey(i), (256, 256))
    x = x @ rhs
  return jax.device_get(x)

def np_conv(x):
  # The same computation carried out entirely in NumPy.
  x = jax.device_get(x)
  for i in range(10):
    rhs = jax.device_get(init_fn(random.PRNGKey(i), (256, 256)))
    x = x @ rhs
  return x

logits = conv(lhs)        # full batch at once
logits_np = np_conv(lhs)  # NumPy reference

logits_loop = np.zeros_like(logits)
for i in range(128):
  logits_loop[i] = conv(lhs[i:i+1])  # one row at a time, shape (1, 256)

print(np.allclose(logits_loop, logits))  # outputs False

The previous example uses stdev=1 weights for the kernels, which gives you infinities/NaNs if you stack a bunch of them. In this example you get a relative error of approximately 10^-6; stacking more layers makes the errors larger.

I added a check against NumPy as well, which gives RMS errors of roughly 10^-6 for all pairs among logits, logits_np, and logits_loop.

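The original comment only showed a screenshot of the output; below is a sketch of what such a pairwise RMS check could look like, reusing logits, logits_np, and logits_loop from the example above:

def rms_rel_err(a, b):
  # Root-mean-square difference, normalised by the RMS magnitude of b.
  a, b = np.asarray(a), np.asarray(b)
  return np.sqrt(np.mean((a - b) ** 2)) / np.sqrt(np.mean(b ** 2))

print(rms_rel_err(logits, logits_np))       # jax (batched) vs numpy
print(rms_rel_err(logits_loop, logits_np))  # jax (per-row loop) vs numpy
print(rms_rel_err(logits_loop, logits))     # per-row loop vs batched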

Argh, you are right; it is also quite late here. I will go to bed now and investigate it in more detail tomorrow. I’ve re-opened this issue.

Oh yes, I’m sorry! Indeed, your code is correct.

@mohamad-amin I recommend copying your last message in the JAX issue.

Thanks for filing the issue in the JAX repo! Closing this one since there doesn’t seem to be anything Flax related (but feel free to re-open if you think I am wrong).

Thanks for your analysis!

This difference originates from XLA, and it is probably due to the fact that the batched and single-instance implementations of some of the JAX primitives used in the CNN module differ. Below is a minimal example that only uses conv_general_dilated from JAX but still shows the discrepancy you observed.

Can you please file an issue in the JAX repo containing the example below? Perhaps it is expected, but it would be good to know why this happens. Please let me know once you have filed the issue, so I can close this one.

import numpy as np
from jax import lax, random

rng1, rng2 = random.split(random.PRNGKey(0))

lhs = random.normal(rng1, shape=(2, 1, 28, 28))  # batch of two NCHW inputs
rhs = random.normal(rng2, shape=(32, 1, 3, 3))   # 32 OIHW conv kernels

# Convolve the full batch of two inputs.
logits = lax.conv_general_dilated(
    lhs=lhs,
    rhs=rhs,
    window_strides=(1, 1),
    padding="SAME")

# Convolve only the first input.
logit = lax.conv_general_dilated(
    lhs=lhs[0:1],
    rhs=rhs,
    window_strides=(1, 1),
    padding="SAME")

print(np.allclose(logits, logit))  # outputs False
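
A follow-up sketch (not part of the original comment) that compares only the first batch element with its single-example counterpart and prints the size of the deviation, which is what the eventual JAX issue is about:

first = np.asarray(logits[0:1])
single = np.asarray(logit)
print(np.abs(first - single).max())                         # absolute deviation
print(np.abs(first - single).max() / np.abs(single).max())  # relative deviation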