kornia: [Bug] convert_points_from_homogeneous - NaN gradients in backward pass
I just experienced a NaN-gradient problem while doing a backward pass here: https://github.com/kornia/kornia/blob/4b0ae70f7f806c5eff5ab87f8b1f2d9ab4ff1e45/kornia/geometry/conversions.py#L99
`torch.where` works absolutely fine in the forward pass, but if you have zero divisions you find yourself with NaN gradients for sure 💩
Here is a toy example:
```python
import torch

eps = 1e-8
z_vec: torch.Tensor = torch.tensor([4., 6., 0., -3., 1e-9], requires_grad=True)
scale: torch.Tensor = torch.where(
    torch.abs(z_vec) > eps,
    torch.tensor(1.) / z_vec,
    torch.ones_like(z_vec)
)
scale.backward(torch.ones_like(scale))
```
And these are the `z_vec` gradients:
```
tensor([-0.0625, -0.0278, nan, -0.1111, -0.0000])
```
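To see where the NaN comes from, here is a minimal sketch (my illustration, not from the issue): `torch.where`'s backward zeroes out the unselected branch by multiplying its gradient with a 0/1 mask, but the branch's own gradient is computed anyway. For `1/z` that gradient is `-1/z**2`, which is `-inf` at `z = 0`, and `0 * inf` is `nan`, so the NaN leaks through the mask.

```python
import torch

# Minimal reproduction: the unselected branch 1/z still has its gradient
# (-1/z**2) computed during backward; at z = 0 that is -inf, and the
# where-mask multiplication 0 * (-inf) produces nan instead of 0.
z = torch.tensor([4., 0.], requires_grad=True)
out = torch.where(z.abs() > 1e-8, 1.0 / z, torch.ones_like(z))
out.backward(torch.ones_like(out))
print(z.grad)  # first entry is -0.0625, second entry is nan
```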
For now my little hack is:
```python
...
# we check for points at infinity
z_vec: torch.Tensor = points[..., -1:]
if z_vec.requires_grad:
    def z_vec_backward_hook(grad: torch.Tensor) -> torch.Tensor:
        grad[grad != grad] = 0.  # NaN != NaN, so this zeroes NaN entries
        return grad
    z_vec.register_hook(z_vec_backward_hook)
...
```
But I'm not sure whether it's good enough.
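One caveat with the hook approach: it mutates `grad` in place, whereas a gradient hook is generally expected to return a new tensor rather than modify its input. A minimal non-mutating variant (my sketch, not from the issue) could look like:

```python
import torch

# Non-mutating variant of the NaN-zeroing hook: build a new gradient
# tensor with torch.where instead of assigning into grad in place.
def zero_nan_grads_hook(grad: torch.Tensor) -> torch.Tensor:
    return torch.where(torch.isnan(grad), torch.zeros_like(grad), grad)

z_vec = torch.tensor([4., 0.], requires_grad=True)
z_vec.register_hook(zero_nan_grads_hook)
out = torch.where(z_vec.abs() > 1e-8, 1.0 / z_vec, torch.ones_like(z_vec))
out.backward(torch.ones_like(out))
print(z_vec.grad)  # the NaN entry has been replaced by 0
```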
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Reactions: 10
- Comments: 19 (12 by maintainers)
Yes, please open a separate issue. I will close this one once we merge #369. Regarding the new issue, I would give this code from TF Graphics a try, since it's at least cleaner: https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py
https://discuss.pytorch.org/t/gradients-of-torch-where/26835/2
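The thread linked above describes the usual "double `torch.where`" fix: sanitize the denominator *before* dividing, so the division's backward never sees a zero. The outer `where` still selects the fallback value for those entries, but its gradient is now finite. A sketch, reusing the names from the toy example (the fix itself is my reconstruction, not code from the issue):

```python
import torch

eps = 1e-8
z_vec = torch.tensor([4., 6., 0., -3., 1e-9], requires_grad=True)

mask = torch.abs(z_vec) > eps
# Replace near-zero denominators with 1 before dividing -- the division
# now never produces inf, so its backward never produces nan.
safe_z = torch.where(mask, z_vec, torch.ones_like(z_vec))
scale = torch.where(mask, 1.0 / safe_z, torch.ones_like(z_vec))

scale.backward(torch.ones_like(scale))
print(z_vec.grad)  # finite everywhere; masked entries get gradient 0
```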
my current workaround is:
Thank you, @edgarriba.
Tell me if you need more information. I will try to give you a real example instead of this toy one.
Even in this toy example? In my case it happened during homography regression. I initialized my homography module with some "strong" transformation, and after several iterations the grads became NaN (inside `kornia.HomographyWarper`). As I told you, there are no problems during the forward pass (I tested all homographies manually); things only break in the backward pass, when one of the `z_vec` values becomes `inf` after division.
@poxyu thanks for reporting, I will investigate it