Skip to content

Commit cb75a49

Browse files
committed
validator for index non zero case- else error- Could not find any implementation for node [NON_ZERO]-[aten_ops.index.Tensor]-[index_tensor_bool_nonzero_1]
1 parent aa19802 commit cb75a49

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tensorrt import ITensor as TRTTensor
1010
from torch.fx.node import Argument, Node, Target
1111
from torch_tensorrt._features import needs_not_tensorrt_rtx
12-
from torch_tensorrt._utils import is_tensorrt_version_supported
12+
from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor
1313
from torch_tensorrt.dynamo._settings import CompilationSettings
1414
from torch_tensorrt.dynamo._SourceIR import SourceIR
1515
from torch_tensorrt.dynamo.conversion import impl
@@ -24,7 +24,8 @@
2424
get_positive_dim,
2525
is_only_operator_on_placeholder,
2626
)
27-
from torch_tensorrt.dynamo.utils import is_thor
27+
28+
from torch_tensorrt._utils import is_thor
2829

2930
_LOGGER: logging.Logger = logging.getLogger(__name__)
3031

@@ -425,9 +426,24 @@ def index_dtype_validator(
425426
return True
426427

427428

429+
def index_nonbool_validator(
430+
node: Node, settings: Optional[CompilationSettings] = None
431+
) -> bool:
432+
# for thor, we don't support boolean indices
433+
if is_thor():
434+
index = node.args[1]
435+
for ind in index:
436+
if ind is not None:
437+
val = ind.meta.get("val")
438+
if val is not None and val.dtype == torch.bool:
439+
return False
440+
return True
441+
442+
428443
@dynamo_tensorrt_converter(
429444
torch.ops.aten.index.Tensor,
430-
capability_validator=index_dtype_validator,
445+
capability_validator=lambda node, settings: index_dtype_validator(node, settings)
446+
and index_nonbool_validator(node, settings),
431447
supports_dynamic_shapes=True,
432448
requires_output_allocator=True,
433449
)

0 commit comments

Comments
 (0)