Skip to content

Commit aa19802

Browse files
committed
commit to add validator in nonzero for thor runs (currently fallback to torch) and make ATenAnyDimNegIndexConvertsCorrectly input deterministic on ORIN
1 parent dae1ead commit aa19802

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
get_positive_dim,
2525
is_only_operator_on_placeholder,
2626
)
27+
from torch_tensorrt.dynamo.utils import is_thor
2728

2829
_LOGGER: logging.Logger = logging.getLogger(__name__)
2930

@@ -3601,10 +3602,18 @@ def aten_ops_full(
36013602
)
36023603

36033604

3605+
def nonzero_validator(
3606+
node: Node, settings: Optional[CompilationSettings] = None
3607+
) -> bool:
3608+
return not is_thor()
3609+
3610+
36043611
# currently nonzero is not supported for tensorrt_rtx
36053612
# TODO: lan to add the nonzero support once tensorrt_rtx team has added the support
3613+
# TODO: apbose to remove the capability validator once thor bug resolve in NGC
36063614
@dynamo_tensorrt_converter(
36073615
torch.ops.aten.nonzero.default,
3616+
capability_validator=nonzero_validator,
36083617
supports_dynamic_shapes=True,
36093618
requires_output_allocator=True,
36103619
)

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,14 @@ TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
345345
%3 : bool = prim::Constant[value=1]()
346346
%5 : Tensor = aten::any(%0, %1, %3)
347347
return (%5))IR";
348-
auto in = at::randint(-2, 2, {2, 32}, at::kCUDA);
348+
std::vector<int> data(64, 0);
349+
for (int i = 0; i < 64; ++i) {
350+
if (i % 7 == 0)
351+
data[i] = 1; // some positives
352+
if (i % 13 == 0)
353+
data[i] = -1; // some negatives
354+
}
355+
auto in = at::tensor(data, at::TensorOptions().dtype(at::kInt).device(at::kCUDA)).reshape({2, 32}); // shape [2, 32]
349356
test_body(graph, in);
350357
}
351358

0 commit comments

Comments
 (0)