jax: expm does not support ndarray

Description

Hi, I tried to use jax.scipy.linalg.expm with a 3-dimensional array but got an error. The docs say this should work when the last two dimensions have the same length, but it fails. Can you please check?

https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.expm.html

import jax
import jax.numpy as jnp
A = jnp.arange(2*11*11).reshape((2,11,11))
jax.scipy.linalg.expm(A)

Traceback (most recent call last):
  File "<ipython-input-5-d2ac0f4d7ead>", line 1, in <module>
    jax.scipy.linalg.expm(A)
  [... JAX-internal tracing frames omitted ...]
  File "/Users/yongha/miniforge3_m1/envs/rcwa/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 436, in expm
    P, Q, n_squarings = _calc_P_Q(A)
  File "/Users/yongha/miniforge3_m1/envs/rcwa/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 455, in _calc_P_Q
    raise ValueError('expected A to be a square matrix')
ValueError: expected A to be a square matrix

What jax/jaxlib version are you using?

jax 0.4.1, jaxlib 0.4.1

Which accelerator(s) are you using?

CPU

Additional system info

M1

NVIDIA GPU info

No response

About this issue

  • State: closed
  • Created a year ago
  • Comments: 19 (1 by maintainers)

Most upvoted comments

Just to mention that JAX's results typically trade some accuracy for performance here, because its squaring-number logic is rather crude. We discovered this during an investigation: https://github.com/scipy/scipy/issues/17919#issuecomment-1428897893. In fact, choosing the number of squarings is the most annoying part of contemporary expm calculations.

It's not alarming, but inaccuracies can occur due to the reduced number of squarings chosen.
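For context on where that trade-off lives: expm implementations generally use scaling and squaring — scale A down by 2**s so a series approximation is accurate, then square the result s times. JAX and SciPy use Padé approximants for the series step; the sketch below uses a truncated Taylor series purely for illustration, and the function name and its crude choice of s are hypothetical, not the actual JAX logic:

```python
import numpy as np

def expm_scaling_squaring(A, taylor_terms=16):
    """Naive scaling-and-squaring: expm(A) = expm(A / 2**s) ** (2**s)."""
    norm = np.linalg.norm(A, 1)
    # Crude choice of s: scale until the 1-norm is at most 1. Real
    # implementations pick s (and the Pade degree) much more carefully;
    # this choice is exactly the kind of logic the accuracy depends on.
    s = max(0, int(np.ceil(np.log2(norm)))) if norm > 1 else 0
    As = A / 2.0**s

    # Truncated Taylor series exp(As) = sum_k As**k / k!
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for k in range(1, taylor_terms):
        term = term @ As / k
        result = result + term

    # Undo the scaling by repeated squaring.
    for _ in range(s):
        result = result @ result
    return result
```

Too few squarings leaves the scaled norm large and the series inaccurate; too many amplify rounding error in the squaring phase.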

I found via test failures that batched expm was only added in recent scipy versions, which explains why it wasn’t part of the original JAX implementation 😁 (https://github.com/scipy/scipy/pull/15079)

Thanks for the report, that’s definitely a bug!

Until we get it fixed, you can batch it with jax.vmap, as in jax.vmap(jax.scipy.linalg.expm)(A). (It also doesn't accept int32 input, which seems to be another bug, so you'll need to cast explicitly: jax.vmap(jax.scipy.linalg.expm)(A.astype('float32')).)
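A minimal runnable version of that workaround. It uses a stack of zero matrices rather than the arange array from the report, since expm of a zero matrix is the identity, which keeps the numbers tame and easy to check:

```python
import jax
import jax.numpy as jnp

# Stack of two 3x3 zero matrices; expm of each 3x3 slice is the identity.
A = jnp.zeros((2, 3, 3), dtype=jnp.int32)

# vmap maps expm over the leading batch axis; the explicit float32 cast
# works around expm's rejection of integer inputs.
batched_expm = jax.vmap(jax.scipy.linalg.expm)
out = batched_expm(A.astype('float32'))
print(out.shape)  # (2, 3, 3)
```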