jax: Gradient of QR fails
There seems to be a major bug in the gradient of the QR decomposition. The forward-mode derivative is accurate for Q but very far off for R. There also seems to be no check in the tests making sure the gradient is accurate.
This reproduces the error:
import jax
import jax.numpy as jnp
import numpy as np

x = np.random.randn(3, 3)
dx = np.random.randn(3, 3)

# Forward-mode (JVP) derivative of QR at x in the direction dx.
primals, tangents = jax.jvp(jnp.linalg.qr, (x,), (dx,))
q, r = primals
dq, dr = tangents

# Finite-difference reference for the same directional derivative, via NumPy.
dt = 1e-6
q0, r0 = np.linalg.qr(x)
q1, r1 = np.linalg.qr(x + dt * dx)
dq_ = (q1 - q0) / dt
dr_ = (r1 - r0) / dt

assert jnp.allclose(x, q @ r, atol=1e-5, rtol=1e-5) # passes
assert jnp.allclose(dq, dq_, atol=1e-5, rtol=1e-5) # passes
# Product rule check on the finite-difference reference itself.
assert jnp.allclose(dx, q @ dr_ + dq_ @ r, atol=1e-5, rtol=1e-5) # passes
assert jnp.allclose(dr, dr_, atol=1e-5, rtol=1e-5) # fails
About this issue
- State: closed
- Created 4 years ago
- Comments: 23 (23 by maintainers)
Commits related to this issue
- Fix qr_jvp (fixes #2863) — committed to j-towns/jax by deleted user 4 years ago
- Fix definition of qr primitive to return only the upper triangular part of r. Issue #2863. — committed to hawkinsp/jax by hawkinsp 4 years ago
- Fix definition of qr primitive to return only the upper triangular part of r. (#2870) Issue #2863. — committed to google/jax by hawkinsp 4 years ago
- Improve JAX test PRNG APIs to fix correlations between test cases. In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because ... — committed to hawkinsp/jax by hawkinsp 4 years ago
- Improve JAX test PRNG APIs to fix correlations between test cases. (#2957) * Improve JAX test PRNG APIs to fix correlations between test cases. In #2863, we observed that we were missing gradient ... — committed to google/jax by hawkinsp 4 years ago
- Print doc (#3032) * Implement jax.ops.index_mul. (#2696) * Implement jax.ops.index_mul. * Add index_mul to documentation. * Fix RHS JVP rule for scatter_mul, fix test bug that meant it was n... — committed to google/jax by gnecula 4 years ago
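The "upper triangular part of r" wording in the commits above is easier to picture with a concrete factorization. The sketch below uses NumPy's `mode="raw"` output as a stand-in for an unmasked factorization result. That is an assumption made for illustration; it is not JAX's `qr` primitive, and the matrix `u` is hypothetical. LAPACK-style routines pack Householder data below the diagonal of the same array that holds R, so any formula that multiplies by that array without masking it picks up the extra entries.

```python
import numpy as np

x = np.array([[2.0, 1.0, 0.5],
              [1.0, 3.0, 1.0],
              [0.5, 1.0, 4.0]])

# Reference factorization: r is exactly upper triangular.
q, r = np.linalg.qr(x)

# mode="raw" returns the packed LAPACK output (transposed) and the tau scalars.
h, tau = np.linalg.qr(x, mode="raw")
packed = h.T  # R on and above the diagonal, Householder data below it

# A hypothetical upper triangular factor, standing in for whatever a
# derivative formula would multiply r by.
u = np.triu(np.arange(1.0, 10.0).reshape(3, 3))

wrong = u @ packed           # sub-diagonal entries leak into the product
right = u @ np.triu(packed)  # mask first, then multiply

print("sub-diagonal entries of packed:", packed[np.tril_indices(3, -1)])
print("max error without masking:", np.abs(wrong - u @ r).max())
print("max error with masking:   ", np.abs(right - u @ r).max())
```

A triangular solve that reads only the upper triangle would be unaffected by the packed entries, while a plain matrix product would not. That would be consistent with the `dq` check passing and the `dr` check failing in the report, although the thread itself is the authority on the actual cause.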
Thanks. Yes, I saw the derivation and checked it by hand. The problem may run deeper than the derivation itself: something odd seems to be happening in the back end, because the exact same function gives different results when implemented as a JVP rule for a primitive than when executed as a bare function.
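One way to test that suspicion is to run the derivation as a bare function, outside any primitive, and compare it with finite differences. The sketch below is a NumPy transcription of the standard QR derivative for real, full-rank inputs. The names `qr_jvp_reference` and `qr_finite_difference` are hypothetical, the test matrices are arbitrary, and this is not claimed to be the code in JAX's rule.

```python
import numpy as np


def qr_jvp_reference(x, dx):
    """Directional derivative of the reduced QR factorization (real inputs)."""
    q, r = np.linalg.qr(x)
    # dx @ inv(r), obtained by solving r.T @ y = dx.T and transposing back.
    dx_rinv = np.linalg.solve(r.T, dx.T).T
    qt_dx_rinv = q.T @ dx_rinv
    # q.T @ dq is skew-symmetric and dr @ inv(r) is upper triangular, so the
    # strictly lower part of qt_dx_rinv fixes the skew-symmetric term.
    lower = np.tril(qt_dx_rinv, -1)
    omega = lower - lower.T
    dq = q @ (omega - qt_dx_rinv) + dx_rinv
    dr = (qt_dx_rinv - omega) @ r
    return (q, r), (dq, dr)


def qr_finite_difference(x, dx, eps=1e-6):
    """Central-difference estimate of the same directional derivative."""
    q_plus, r_plus = np.linalg.qr(x + eps * dx)
    q_minus, r_minus = np.linalg.qr(x - eps * dx)
    return (q_plus - q_minus) / (2 * eps), (r_plus - r_minus) / (2 * eps)


# Fixed, well-conditioned inputs so the comparison is reproducible.
x = np.array([[2.0, 1.0, 0.5],
              [1.0, 3.0, 1.0],
              [0.5, 1.0, 4.0]])
dx = np.array([[0.5, -1.0, 0.3],
               [0.2, 0.1, -0.7],
               [-0.4, 0.8, 0.6]])

(q, r), (dq, dr) = qr_jvp_reference(x, dx)
dq_fd, dr_fd = qr_finite_difference(x, dx)

print("max |dq - dq_fd|:", np.abs(dq - dq_fd).max())
print("max |dr - dr_fd|:", np.abs(dr - dr_fd).max())
```

If the bare function agrees with finite differences while the primitive's rule does not, that would point at what the rule receives as inputs rather than at the formula.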
On Tue, Apr 28, 2020, 10:18 AM Jamie Townsend notifications@github.com wrote: