jax: [jax2tf] NotImplementedError: Call to scatter add cannot be converted with enable_xla=False

Conversions for XlaScatter are currently unsupported when using enable_xla=False. I’m wondering if support could be added?

Here’s the full error that I’m seeing:

NotImplementedError                       Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/impl_no_xla.py in op(*arg, **kwargs)
     52 
     53   def op(*arg, **kwargs):
---> 54     raise _xla_disabled_error(name)
     55 
     56   return op

NotImplementedError: Call to scatter add cannot be converted with enable_xla=False.
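For context, one common way a scatter add ends up in a gradient is as the transpose of a gather: differentiating an integer-array lookup produces it. The sketch below only illustrates that mechanism; `table`, `ids`, and `lookup_sum` are hypothetical names, and nothing here confirms that this is where the scatter in the CLIP model comes from.

```python
import jax
import jax.numpy as jnp

# Hypothetical 4x3 "embedding table" and token ids (illustration only).
table = jnp.arange(12.0).reshape(4, 3)
ids = jnp.array([0, 2, 2])

def lookup_sum(table):
    # Indexing with an integer array is expressed as a gather.
    return table[ids].sum()

# The gradient of a gather is a scatter add: each looked-up row receives
# the incoming cotangent, and repeated ids accumulate.
grad_jaxpr = jax.make_jaxpr(jax.grad(lookup_sum))(table)
has_scatter_add = "scatter-add" in str(grad_jaxpr)

grads = jax.grad(lookup_sum)(table)
print(has_scatter_add)
print(grads)
```

Row 2 is looked up twice, so its gradient is accumulated twice, which is why the backward pass needs an accumulating scatter rather than a plain write.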

Here is some code that reproduces this error:

!pip install --upgrade flax
!pip install git+https://github.com/josephrocca/transformers.git@patch-2
import jax
from jax.experimental import jax2tf
from jax import numpy as jnp

import numpy as np
import tensorflow as tf

from transformers import FlaxCLIPModel

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

def score(pixel_values, input_ids, attention_mask):
    pixel_values = jax.image.resize(pixel_values, (3, 224, 224), "nearest")
    inputs = {"pixel_values":jnp.array([pixel_values]), "input_ids":input_ids, "attention_mask":attention_mask}
    outputs = clip(**inputs)
    return outputs.logits_per_image[0][0][0]

score_tf = jax2tf.convert(jax.grad(score), enable_xla=False)

my_model = tf.Module()
my_model.f = tf.function(score_tf, autograph=False, jit_compile=True, input_signature=[
  tf.TensorSpec([3, 40, 40], tf.float32),
  tf.TensorSpec([1, 30], tf.int32),
  tf.TensorSpec([1, 30], tf.int32),
])

model_name = 'pixel_text_score_grad'
tf.saved_model.save(my_model, model_name, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

Here’s a public colab with that code: https://colab.research.google.com/drive/1HjpRXsa8Ue9KWiKbVVWUX6DlXoIYx2r8?usp=sharing You can click “Runtime > Run all” to see the error.

Thanks!

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 35 (24 by maintainers)

Most upvoted comments

One advantage, for now, of the jax2tf + TFLite converter path is that it is much more extensively used and tested than tflite.experimental_from_jax. There is significant usage of jax2tf for serving, and the TFLite converter is also used a lot. The corner cases where this path does not work are likely to be a subset of those where experimental_from_jax does not work either. This is because the difficulty in both cases is covering all the corner cases for complex operations such as gather, scatter, and convolution.

For the long term, it is quite possible that the direct HLO to TFLite path will evolve to be the simpler path.

Another update: I think I was actually wrong in saying that you should always convert JAX --> TFLite through from_concrete_functions. I think the main use case for going through a SavedModel is to support shape polymorphism. This doesn’t seem to work with from_concrete_functions because the signature looks quite odd, and I think it only supports a single shape. Moreover, TFLite themselves also recommend the SavedModel path. Our current MNIST example does not show how to go from a JAX model to a TFLite model with shape polymorphism.

I have filed another issue saying we should improve this (#10821).

Hello all! Given that @oliverdutton’s PR #10653 is merged and we now have initial support for the scatter op, I am closing this issue.

Two closing remarks:

@josephrocca

🤦 I didn’t realise that the TFLite converter had a from_concrete_functions method, thank you! I guess that skips the SavedModel signature issue. One problem though: the resulting vit.tflite file from the notebook you linked is only 588 KB and, according to Netron, has only two nodes: an input node and an output node. So something seems to be silently going wrong during the conversion? I did notice that there are a few (not-super-informative) warnings.

Hmm, you are right. Indeed, it seems that when we wrap the function in jax.grad the converted model doesn’t include the parameters. I think the reason is that we still have a bug in the conversion of jax.grad, which results in the TFLite model outputting only zeros; the TFLite converter then optimizes the model by throwing away the params (since it doesn’t need them). I have filed a separate issue for this (#10819).

@oliverdutton

Route 2. tflite.experimental_from_jax uses the HLO itself, and doesn’t save any intermediate files. While both are perfectly valid, the second seems to have fewer possible pain points.

Generally I would really recommend the jax2tf --> TFLite path because it has the most support and is actively being worked on, so you have the biggest chance that any issues you run into will actually be addressed.

Also, note you don’t have to store any intermediate files: jax2tf converts your JAX function to a TF function, which can be converted directly to a TFLite model using TFLite’s from_concrete_functions.
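The in-memory path described above can be sketched roughly as follows. This assumes a JAX/TensorFlow pairing from the period of this thread, in which jax2tf.convert still accepts enable_xla=False; `predict` is a hypothetical stand-in for a real model, and the input shape is arbitrary.

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def predict(x):
    # Hypothetical stand-in for a real model's forward pass.
    return jnp.tanh(x) * 2.0

# Wrap the converted function in a tf.function with a fixed input signature.
tf_predict = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    input_signature=[tf.TensorSpec([1, 4], tf.float32, name="x")],
    autograph=False,
)

# Convert straight from the concrete function; no SavedModel is written.
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict
)
tflite_model = converter.convert()  # serialized model, held in memory
print(len(tflite_model), "bytes")
```

The result can be written to a .tflite file or handed to a TFLite interpreter; for models that use ops outside the TFLite builtin set, the converter's supported-ops setting may also need adjusting.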

Also, why wouldn’t you recommend the path that uses jax2tf?

For converting to TFLite there are two routes.

Route 1. jax2tf maps the jaxpr of the program to TensorFlow ops, saves a SavedModel, then loads it and passes it to the TFLite converter. When enable_xla=False is used (which is not required for TFLite), this involves mapping to high-level TensorFlow ops.

Route 2. tflite.experimental_from_jax uses the HLO itself, and doesn’t save any intermediate files.

While both are perfectly valid, the second seems to have fewer possible pain points.

Converting to ONNX, however, is a case where jax2tf with enable_xla=False still has a specific use. (That’s the one I know of; there are likely many other cases I’m not aware of.)

CLIP grads are working correctly in this notebook, though I’ve used a matching image size from the start. Please experiment with it and see whether the resize steps are what’s causing the issues. Are you happy with not having the grads on the un-resized image?

Summary:

This issue ([jax2tf] NotImplementedError: Call to scatter add cannot be converted with enable_xla=False) began as a request to support XLA-free scatters.

XLA-free scatters are now supported, so this ticket is theoretically closed. However, the root issue was the ability to run TFLite models with the grad of LiT and CLIP.

The way the scatters ended up in these models prompted me to solve an issue I’ve had before, in #10682, so that the scatter simply doesn’t appear at all. So two things have happened: one pull request makes scatters XLA-free, and another avoids most scatters appearing to begin with.

So two tasks remain:

1 - grad of LiT in TFLite
2 - grad of CLIP in TFLite

LiT grad

For LiT, this notebook does build the TFLite model (which ends up huge: 893 MB, of which ~130 MB is the model and the rest is params), and it runs, matching the JAX output within a reasonable tolerance.

The model file being 130 MB (without params) makes sense to me, as some of the grads on the params are zeros. They’re materialised in the TFLite file so they can be read straight away, as they have no dependency on the input. So, in my view, TFLite is working exactly as intended, and that file size is optimal given TFLite’s design choice to materialise arrays.

Do you actually want the grads on the params, or do you want them on the image input?
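To make the distinction concrete, here is a small sketch with a hypothetical toy `score` function (not the LiT or CLIP model). It shows how `argnums` in jax.grad selects which argument is differentiated, and how a parameter that does not influence the output gets an all-zero gradient that does not depend on the input.

```python
import jax
import jax.numpy as jnp

def score(params, image):
    # Hypothetical toy model: only params["w"] influences the output.
    return jnp.sum(params["w"] * image)

params = {"w": jnp.ones((2, 2)), "unused": jnp.ones((3,))}
image = jnp.arange(4.0).reshape(2, 2)

# Gradients with respect to the parameters (argument 0).
param_grads = jax.grad(score, argnums=0)(params, image)

# Gradients with respect to the image input (argument 1).
image_grad = jax.grad(score, argnums=1)(params, image)

print(param_grads["unused"])  # zeros, whatever the image is
print(image_grad)
```

A zero gradient like the one for `params["unused"]` is a constant from a converter's point of view, which is consistent with the observation above that such arrays can end up materialised in the converted file.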

CLIP grad

CLIP is ongoing: zero grads are observed in the TFLite model, while they are non-zero in the JAX version.
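A check along these lines could separate the two situations described here: outputs that match within a tolerance (the LiT case) and converted gradients that are all zero (the CLIP case). This is only a sketch; `compare_grads` is a hypothetical helper, the tolerances are arbitrary assumptions, and the arrays stand in for real JAX and TFLite outputs.

```python
import numpy as np

def compare_grads(reference, candidate, rtol=1e-3, atol=1e-5):
    """Compare a reference gradient with a converted model's gradient."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    return {
        # True when the converted model returned nothing but zeros.
        "all_zero": bool(not np.any(candidate)),
        "max_abs_diff": float(np.max(np.abs(reference - candidate))),
        "close": bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),
    }

# Stand-in values; real inputs would be the JAX and TFLite gradients.
jax_grad = np.array([0.5, -1.0, 0.25])
print(compare_grads(jax_grad, jax_grad + 1e-7))         # small numerical noise
print(compare_grads(jax_grad, np.zeros_like(jax_grad)))  # silently zeroed grads
```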

Thank you for your work on making jax2tf! I am very excited about getting JAX models working in the browser, and I’m dreaming of the day that I can just do everything in JAX and then port it to the browser with a conversion process that “just works”. In my wildest dreams I run JAX models https://github.com/google/jax/issues/1472.

We thank you for your interest and patience.

This may be useful. I recently had to do this for a set of simple scatters; however, I found another route, so I closed the request. Please check it first, but it did function correctly for a subset of cases.

https://github.com/google/jax/pull/9289

Thanks for the details @josephrocca! I’m still working on providing better support for the gather op (#9572); once I’ve finished that, I will work on scatter.