TensorRT: 🐛 [Bug] Non-zero bias term for torch.nn.Linear causes engine building to fail using FX frontend (release/1.3)
Bug Description
One FX test is failing on the release/1.3 branch.
To Reproduce
Steps to reproduce the behavior:
- pytest test_fuse_permute_linear_trt.py
Expected behavior
Here is the debug process:
- I simplified the test to the version below.
- The test fails if self.linear2.bias.data.fill_() is called with a non-zero value; if the bias is filled with zero, the test passes.
- I further checked that self.linear1.bias.data.fill_(1) does not affect the test result.
- The bias value printed in the converter function is as expected, value = 1:
  ===bias= Parameter containing: tensor([1., 1.], requires_grad=True)
- Here is the output comparison: the TRT output is 5 while the reference is 6, i.e. the bias is missing from the computation.
  ==ref= tensor([[[6., 6.], [6., 6.]]])
  ==trt= tensor([[[5., 5.], [5., 5.]]], device='cuda:0')

Conclusion: this test passed in previous TRT versions. The root cause appears to be that the bias of linear2 is neglected; see the sketch after the test code below.
```python
def test_multi_fuse_permute_linear(self):
    """
    Fusion when permute output is shared by multiple linears
    """

    class TestModule(torch.nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.linear1 = torch.nn.Linear(in_features, out_features)
            self.linear2 = torch.nn.Linear(in_features, out_features)
            self.linear1.weight.data.fill_(1)
            self.linear1.bias.data.fill_(1)
            self.linear2.weight.data.fill_(1)
            # Test fails if this bias is non-zero; with 0 it passes.
            self.linear2.bias.data.fill_(1)

        def forward(self, x):
            y = x.permute(0, 2, 1)
            return self.linear1(y) + self.linear2(y)

    inputs = [torch.ones(1, 2, 2)]
    self.run_test(
        TestModule(2, 2),
        inputs,
        {trt_transposed_linear},
        apply_passes=[fuse_permute_linear],
        test_implicit_batch_dim=False,
    )
```
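To make the comparison above concrete, here is a minimal plain-PyTorch sketch, independent of TensorRT, that reproduces the expected value of 6 and shows that dropping the second linear's bias yields exactly the faulty value of 5. The transposed_linear helper is a hypothetical stand-in for what the fused trt_transposed_linear op is presumably expected to compute (a matmul on the transposed input plus the bias); it is not the actual converter code.

```python
import torch
import torch.nn.functional as F

# All-ones input of shape (1, 2, 2); permute(0, 2, 1) leaves it all ones.
x = torch.ones(1, 2, 2)
weight = torch.ones(2, 2)
bias = torch.ones(2)


def transposed_linear(x, weight, bias):
    # Hypothetical stand-in for the fused op: fold the permute into a
    # matmul on the transposed input, then add the bias.
    return torch.matmul(x.transpose(-1, -2), weight.t()) + bias


# Reference: each linear contributes 1*1 + 1*1 + 1 (bias) = 3 per element,
# so linear1(y) + linear2(y) = 6 everywhere.
y = x.permute(0, 2, 1)
ref = F.linear(y, weight, bias) + F.linear(y, weight, bias)
print(ref)  # tensor([[[6., 6.], [6., 6.]]])

# The fused form with both biases matches the reference exactly.
fused = transposed_linear(x, weight, bias) + transposed_linear(x, weight, bias)
print(torch.equal(fused, ref))  # True

# Emulate the suspected bug: the second linear's bias is dropped.
buggy = transposed_linear(x, weight, bias) + torch.matmul(
    x.transpose(-1, -2), weight.t()
)
print(buggy)  # tensor([[[5., 5.], [5., 5.]]])
```

If this reading is right, the fix would be to make sure the converter emits the bias add for every fused linear, not only the first one that shares the permuted input.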
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 8.5
- PyTorch Version (e.g. 1.0): 1.13
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Additional context
Hi folks,
I am re-opening this issue as we are seeing some numerical failures on 8.6 EA in test_fuse_permute_linear_trt.py.
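For triaging such numerical failures, a tolerance-based comparison of the engine output against the analytic expectation makes the failure mode obvious. A minimal sketch, using the reference and TRT values from the debug session above:

```python
import torch

# Values taken from the debug output earlier in this issue.
ref = torch.full((1, 2, 2), 6.0)
trt = torch.full((1, 2, 2), 5.0)

# assert_close reports element-wise mismatches, which distinguishes a
# systematic offset (a bias dropped everywhere) from ordinary float noise.
torch.testing.assert_close(trt, ref)  # raises AssertionError: every element off by 1
```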