Commit 0edba55
[fx2trt] Handle shapes like [batch_size] and scalars for binary ops properly (#74)
Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/74
## Root cause
We have code like:
```
x = ... # result shape is [batch_size, N]
y = mean(y, dim=1, keepdim=False) # result shape is [batch_size]
z = y + 0.5 # result shape is [batch_size]
```
For TRT with implicit batch dimension it should look like:
```
x = ... # result shape is [N]
y = mean(y, dim=1, keepdim=False) # result shape is []
z = y + 0.5 # result shape is []
```
However, because we convert scalar to `TRTTensor` and don't do dimensions squeeze for it, the resulting tensor `z` would have shape `[1]`, and this is gonna break the rest of the net.
## Solution
Convert the scalar value to `torch.Tensor`, because we have dimensions squeeze logic implemented for them.
## P.S.:
Also added support for `sqrt` tracing.
Reviewed By: yinghai, houseroad
Differential Revision: D36336816
fbshipit-source-id: 412e44e99f25ab3549df540a87bd005e6b3fe08a1 parent 5d80f41 commit 0edba55
File tree
2 files changed
+16
-0
lines changed- fx/converters
- tracer/acc_tracer
2 files changed
+16
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
452 | 452 | | |
453 | 453 | | |
454 | 454 | | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
455 | 470 | | |
456 | 471 | | |
457 | 472 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1482 | 1482 | | |
1483 | 1483 | | |
1484 | 1484 | | |
| 1485 | + | |
1485 | 1486 | | |
1486 | 1487 | | |
1487 | 1488 | | |
| |||
0 commit comments