jax: Gradient of QR fails

There seems to be a major bug in the gradient of the QR decomposition: the forward-mode derivative is accurate for Q but very far off for R. There also seems to be no test that checks the accuracy of this derivative.

This reproduces the error:

import jax
import jax.numpy as jnp
import numpy as np

x = np.random.randn(3, 3)
dx = np.random.randn(3, 3)

# Forward-mode (JVP) derivative of the QR decomposition via JAX.
primals, tangents = jax.jvp(jnp.linalg.qr, (x,), (dx,))
q, r = primals
dq, dr = tangents

# Finite-difference approximation of the same directional derivative.
dt = 1e-6
dq_ = (np.linalg.qr(x + dt * dx)[0] - np.linalg.qr(x)[0]) / dt
dr_ = (np.linalg.qr(x + dt * dx)[1] - np.linalg.qr(x)[1]) / dt

assert jnp.allclose(x, q @ r, atol=1e-5, rtol=1e-5)  # passes: QR reconstructs x
assert jnp.allclose(dq, dq_, atol=1e-5, rtol=1e-5)  # passes: dq matches finite differences
assert jnp.allclose(dx, q @ dr_ + dq_ @ r, atol=1e-5, rtol=1e-5)  # passes: product rule holds for the finite-difference tangents
assert jnp.allclose(dr, dr_, atol=1e-5, rtol=1e-5)  # fails: dr does not match finite differences
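
For reference, and independently of finite differences, here is a minimal sketch of the standard forward-mode formula for a square, full-rank input (qr_jvp_reference is just an illustrative helper name, not anything in JAX). With A = QR, the matrix C = Q^T dA R^-1 splits into an antisymmetric part S = Q^T dQ plus an upper-triangular part dR R^-1, which gives dQ and dR directly:

import numpy as np

def qr_jvp_reference(a, da):
    # Sketch of a reference JVP for the QR decomposition of a square,
    # full-rank matrix (not JAX's implementation).
    q, r = np.linalg.qr(a)
    c = q.T @ da @ np.linalg.inv(r)  # C = Q^T dA R^-1 = Q^T dQ + dR R^-1
    c_lower = np.tril(c, -1)
    s = c_lower - c_lower.T          # antisymmetric part, equal to Q^T dQ
    dq = q @ s
    dr = (c - s) @ r                 # C - S is upper triangular, so dR is too
    return (q, r), (dq, dr)

For the x and dx above, the dr computed this way should agree with the finite-difference dr_ up to discretization error, unlike the dr returned by jax.jvp.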

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 23 (23 by maintainers)

Most upvoted comments

Thanks. Yes, I saw the derivation and checked it manually. The problem is perhaps deeper than the derivation itself: if the exact same computation gives different results when registered as the JVP rule for a primitive than when executed as a bare function, then something weird is happening in the back end (a sketch of such a comparison is at the end of this thread).

On Tue, Apr 28, 2020, 10:18 AM Jamie Townsend notifications@github.com wrote:

Interesting, I’ll try to take a look at this today. FYI the derivation I wrote is here https://j-towns.github.io/papers/qr-derivative.pdf.

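To make the bare-function-vs-JVP-rule comparison concrete, here is a hedged sketch of registering the reference formula from earlier in the thread as a custom JVP rule and comparing the traced result with a direct call to the same Python function. The names my_qr and my_qr_jvp are illustrative only, and this uses the public jax.custom_jvp mechanism rather than modifying the qr primitive's rule:

import jax
import jax.numpy as jnp
import numpy as np

@jax.custom_jvp
def my_qr(a):
    q, r = jnp.linalg.qr(a)
    return q, r

@my_qr.defjvp
def my_qr_jvp(primals, tangents):
    # Same reference formula as above, installed as the JVP rule for my_qr.
    (a,), (da,) = primals, tangents
    q, r = jnp.linalg.qr(a)
    c = q.T @ da @ jnp.linalg.inv(r)
    c_lower = jnp.tril(c, -1)
    s = c_lower - c_lower.T
    dq = q @ s
    dr = (c - s) @ r
    return (q, r), (dq, dr)

x = np.random.randn(3, 3)
dx = np.random.randn(3, 3)

# Executed through the tracer, i.e. as a JVP rule ...
_, (dq_rule, dr_rule) = jax.jvp(my_qr, (x,), (dx,))
# ... and executed as a bare Python function.
_, (dq_bare, dr_bare) = my_qr_jvp((x,), (dx,))

print(np.allclose(dq_rule, dq_bare, atol=1e-5),
      np.allclose(dr_rule, dr_bare, atol=1e-5))

If the two disagree, tracing really is doing something odd; if they agree, the discrepancy in the original report more likely comes from the formula used by the built-in qr rule.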