array-api: Eager functions in API appear in conflict with lazy implementation
We are looking at adapting this API for a lazy array library built on top of ONNX (and Spox). It seems to be an excellent fit for the most part. However, certain functions in the specification appear to be in conflict with a lazy implementation. One instance is `__bool__`, which is defined as:
```
array.__bool__() → bool
```
The output is documented as "a Python bool object representing the single element of the array". This is problematic for lazy implementations, since the value is not available at this point. How should a standard-compliant lazy implementation deal with this apparent conflict?
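For illustration, here is the kind of (hypothetical) array-consuming code that hits this implicitly, where `xp` stands for any conforming array namespace:

```python
def check_positive(x, xp):
    # `xp.any` returns a zero-dimensional array; the `if` statement then
    # implicitly calls its __bool__. A lazy implementation has no concrete
    # value to return at this point.
    if xp.any(x < 0):
        raise ValueError("negative values are not allowed")
    return x
```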
From https://pytorch.org/docs/master/func.ux_limitations.html#data-dependent-python-control-flow:

> JAX supports transforming over data-dependent control flow using special control flow operators (e.g. `jax.lax.cond`, `jax.lax.while_loop`). We're investigating adding equivalents of those to PyTorch.

In fact there now is a `torch.cond`, but it seems so new that it is not yet reflected in the docs.

I believe the majority of tracing use cases will work; using Python control flow based on values is one of the very few things that won't work. And such code isn't a good fit for tracing anyway. So I think:

- it is fine to leave the standard as-is around `__bool__`/`__int__`/`__float__` & co, but implementations that must be fully lazy will have to raise an exception here
- it is too early to standardize something like `cond` for this, or to have scikit-learn & co worry about this as a problem right now.

Just to be completely clear, it's physically impossible for `__bool__` to return anything other than `bool`. Returning anything other than True or False from `__bool__` results in an error:
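For example:

```python
>>> class A:
...     def __bool__(self):
...         return 1
...
>>> bool(A())
Traceback (most recent call last):
  ...
TypeError: __bool__ should return bool, returned int
```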
(The same is true for `__int__`, `__float__`, and `__complex__`.) And even if this limitation didn't exist, it wouldn't really help, because 90% of the time `bool(x)` is called implicitly as the result of something like `if x` or `assert x`.

I think the main takeaways from this discussion are:
- We should have a page discussing delayed evaluation, noting that it is supported, that certain operations may not work with it, and that the function to perform an actual compute is not standardized, because it should typically be called by end users on a per-library basis.
- We should add a note to `__bool__`, `__int__`, `__float__`, and `__complex__` that they may not be supported by such libraries.

I agree with @rgommers's summary as a pragmatic stance for the short term.
I think we need to wait a year or two to see how tensor libraries with JIT compiler support evolve before we start thinking about how to standardize an API for data-dependent control flow operators (and maybe even a standard JIT compiler decorator).
However, those compiler-related aspects are an important evolution of the numerical / data-science Python ecosystem, and I think we should keep them in mind to later consider an Array API spec extension (similar to what is done for the `xp.linalg` submodule).

If we didn't make it so that the dask/jax behaviour becomes what the standard says should be done, wouldn't you still end up in trouble for tracing? In scikit-learn we'd have to trigger the computation explicitly (instead of implicitly via `__bool__` or `__float__`), and then we'd be back to square one: there are no concrete values.

I'm still not sure I fully understand how Spox (I assume this is the library you are thinking about) does its thing, but it feels like tracing is not a use case for the array API? Like, it is a neat trick, but you wouldn't change the design to make tracing easier if more mainstream uses got harder. The work PyTorch has done is pretty exciting (as Olivier already said); in particular, I think TorchDynamo is the bit that does the tracing (or is it TorchInductor?). Maybe worth investigating how they do it.
Thanks for this great discussion! I think it might be useful to reiterate the following point: a lazy library (such as the one we are building on top of ONNX) may have an eager mode on top of it for debugging purposes, but those eager values must never influence the lazy computational graph that we are building. We are essentially trying to compile a sklearn-like pipeline ahead of time into an ONNX graph. We don't have any (meaningful) values available when executing the Python code that produces our lazy arrays. We have no choice but to throw an exception if an eager value is requested. It would, however, be a pity if that fact stopped a lazy array implementation from being standard-compliant. Hence this issue, to clarify whether it would be OK by the standard to raise in those cases.
On the topic of control flow: the ONNX standard does offer lazy control flow operators (`If`, `Loop`, `Scan`). Rather than using Python's syntactic sugar for `if`/`else` statements and `for` loops, those operators are more akin to the built-in `map` function (see the sketch below for what this looks like with JAX's existing operators). It would be necessary to offer similar control flow functions through the array API if a use case like the above `iterative_solver` were to be supported lazily.

FWIW, it's the same in cuNumeric too. When a Python scalar is needed, the expression is force-evaluated in a blocking manner.
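To make the "control flow as functions" idea concrete, here is a minimal sketch using JAX's existing `jax.lax.while_loop` (illustrative only; the array API does not currently standardize such an operator):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def double_until(x):
    # The loop condition and body are passed in as functions, so the loop
    # can be traced and compiled even though the trip count depends on the
    # data -- much like ONNX's Loop operator.
    def cond(val):
        return val < 100.0  # data-dependent condition, evaluated lazily

    def body(val):
        return val * 2.0

    return lax.while_loop(cond, body, x)

print(double_until(jnp.array(1.0)))  # 128.0
```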
Not sure what question you’re referring to, Aditya?
Regarding Aaron's question about when to call `compute()`, I have a strong opinion that it should be the end user's responsibility, not any array-consuming library's. Suppose we have three libraries A, B, C, where A provides a lazy array, B calls A's APIs, C calls B's APIs, and the end user creates A arrays and calls C's APIs. For simplicity we could assume A is Dask, B is SciPy, and C is scikit-learn. If B or C calls `compute()` on behalf of the user (of C), the graph created by the user would be disconnected/materialized at the library boundary, not across the three libraries. This also saves all of the array-consuming libraries' lives, by staying lazy/eager-agnostic.

It's impossible for `__bool__` to be lazy. `bool(x)` will call `x.__bool__()` and convert the result into a boolean. I guess it's up to how you want to design your library, but you really only have two choices: either make `bool()` implicitly perform a compute on the lazy graph, or make it raise an exception.
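A minimal sketch of those two choices (the `LazyArray` class and `compute()` method are hypothetical names):

```python
class LazyArray:
    """Hypothetical lazy array wrapping a node in a computation graph."""

    def __bool__(self):
        # Choice 1: implicitly materialize the graph (Dask-style):
        #     return bool(self.compute())
        # Choice 2: stay fully lazy and refuse (required for ahead-of-time
        # compilers such as an ONNX-based library, which has no concrete
        # values at trace time):
        raise TypeError(
            "a LazyArray has no concrete value; __bool__ cannot be "
            "answered without materializing the graph"
        )
```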
Note that the standard already discusses this in the context of data-dependent output shapes (https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html), but perhaps we also need to identify other functions, such as `__bool__`, that are problematic for such libraries.

One question I have is: for a lazy evaluation library, how should this work for array-consuming libraries (e.g. SciPy) that are written against a generic array API expecting both eager and non-eager array backends? Is it the case that a SciPy function should just take in a lazy array object and return a computation-graph object, and fail if it tries to call a function that cannot be performed lazily? Or do these libraries need to know how to call `compute()` (or whatever) at the end? Should the function to perform an actual compute be standardized?

There's also a question of how we can support libraries like this in the test suite. The test suite calls things like `__bool__` all the time because it compares actual values in the tests. So it would need to support calling `compute()` or whatever to get actual array values. But this is orthogonal to the discussions here for the actual standard (feel free to open an issue in the test suite repo if you want to discuss this further).

`torch.compile` can generate a symbolic graph from a Python program with data-dependent control flow (presumably via static analysis). ~~`jax.jit` does not complain with the above `iterative_solver` function either, but I am not sure what it does under the hood in case of data-dependent control flow.~~ EDIT: decorating the `iterative_solver` function with `jax.jit` makes JAX complain on the first call to `float()`. See also: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
and later in that same document:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#structured-control-flow-primitives
As noted previously, this function runs fine without the `jax.jit` decorator.

I updated the gist with torch and jax if you are interested in reproducing the above.
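A minimal reproduction of that failure mode (independent of the gist):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Under jit, `x.sum() > 0` is an abstract tracer; the implicit
    # __bool__ call made by `if` cannot be answered, and JAX raises a
    # concretization error (TracerBoolConversionError in recent versions).
    if x.sum() > 0:
        return x
    return -x

f(jnp.ones(3))  # raises at trace time
```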
I gave it a try, and indeed Dask is smart enough to avoid any recomputation while triggering computation as needed as part of the control flow:
https://gist.github.com/ogrisel/6a4304e1831051203a98118875ead2d4
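For example (a minimal sketch in the spirit of the gist, not the gist itself):

```python
import dask.array as da

x = da.arange(10, chunks=5)
total = (x ** 2).sum()  # a lazy, zero-dimensional dask array

# `if` implicitly calls __bool__ on the comparison result; dask answers
# it by triggering a blocking compute() under the hood.
if total > 100:
    print("large")
```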
I am not sure if we can expect all the other lazy Array API implementations to follow the same semantics without a more explicit API though.
I think it’s actively non-idiomatic to do so. You want to write code that does not care whether evaluation is triggered, but rather expresses only the logic and leaves execution semantics to the array library.
Isn’t that just a quality-of-implementation issue? A good lazy library should automatically cache/persist calculated values that it can know have to be reused again. In this particular case though, the problem may be that Dask won’t see the line
for iter_idx in range(maxiter):? If you’d replacerangewithda.arange, it will be able to tell. Although also, perhaps the example is misleading - ifdataandparamsdon’t change inside the for-loop, you can move it out. And if they do change, then there’s nothing to persist.The question is here remains if the standard allows for raising when
bool()is used to warn users about evaluation being triggered implicitly (and maybe__float__, etc.).That would still mean that sklearn would have to call some form of
compute()to avoid that error.Thanks for clarifying @asmeurer, that’s the thing I was missing in my answer.
The array-consuming library doesn't have to: as pointed out by Leo and Aaron, the Python language already forces the evaluation. And this is then what Dask does:
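Roughly (a paraphrase of Dask's `Array.__bool__`, not the verbatim source):

```python
def __bool__(self):
    if self.size > 1:
        raise ValueError(
            "The truth value of an array with more than one element is "
            "ambiguous. Use a.any() or a.all()."
        )
    return bool(self.compute())  # blocking evaluation of the graph
```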
So there is no need for any `.compute()` or similar call within the consuming library.