|
9 | 9 | from tensorrt import ITensor as TRTTensor |
10 | 10 | from torch.fx.node import Argument, Node, Target |
11 | 11 | from torch_tensorrt._features import needs_not_tensorrt_rtx |
12 | | -from torch_tensorrt._utils import is_tensorrt_version_supported |
| 12 | +from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor |
13 | 13 | from torch_tensorrt.dynamo._settings import CompilationSettings |
14 | 14 | from torch_tensorrt.dynamo._SourceIR import SourceIR |
15 | 15 | from torch_tensorrt.dynamo.conversion import impl |
|
24 | 24 | get_positive_dim, |
25 | 25 | is_only_operator_on_placeholder, |
26 | 26 | ) |
27 | | -from torch_tensorrt.dynamo.utils import is_thor |
| 27 | + |
| 28 | +from torch_tensorrt._utils import is_thor |
28 | 29 |
|
29 | 30 | _LOGGER: logging.Logger = logging.getLogger(__name__) |
30 | 31 |
|
@@ -425,9 +426,24 @@ def index_dtype_validator( |
425 | 426 | return True |
426 | 427 |
|
427 | 428 |
|
| 429 | +def index_nonbool_validator( |
| 430 | + node: Node, settings: Optional[CompilationSettings] = None |
| 431 | +) -> bool: |
| 432 | + # for thor, we don't support boolean indices |
| 433 | + if is_thor(): |
| 434 | + index = node.args[1] |
| 435 | + for ind in index: |
| 436 | + if ind is not None: |
| 437 | + val = ind.meta.get("val") |
| 438 | + if val is not None and val.dtype == torch.bool: |
| 439 | + return False |
| 440 | + return True |
| 441 | + |
| 442 | + |
428 | 443 | @dynamo_tensorrt_converter( |
429 | 444 | torch.ops.aten.index.Tensor, |
430 | | - capability_validator=index_dtype_validator, |
| 445 | + capability_validator=lambda node, settings: index_dtype_validator(node, settings) |
| 446 | + and index_nonbool_validator(node, settings), |
431 | 447 | supports_dynamic_shapes=True, |
432 | 448 | requires_output_allocator=True, |
433 | 449 | ) |
|
0 commit comments