|
1 | | -from typing import List, Optional, Sequence, cast |
| 1 | +from typing import List, Optional, Sequence |
2 | 2 |
|
3 | 3 | from torch.fx.node import Target |
4 | 4 | from torch_tensorrt.dynamo._SourceIR import SourceIR |
5 | 5 | from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext |
6 | 6 | from torch_tensorrt.dynamo.conversion.converter_utils import ( |
7 | | - get_positive_dim, |
8 | 7 | get_trt_tensor, |
| 8 | + set_layer_name, |
9 | 9 | ) |
10 | | -from torch_tensorrt.fx.converters.converter_utils import set_layer_name |
11 | | -from torch_tensorrt.fx.types import Shape, TRTTensor |
| 10 | +from torch_tensorrt.dynamo.types import TRTTensor |
12 | 11 |
|
13 | 12 |
|
14 | 13 | def unsqueeze( |
15 | 14 | ctx: ConversionContext, |
16 | 15 | target: Target, |
17 | 16 | source_ir: Optional[SourceIR], |
18 | 17 | name: str, |
19 | | - input_t: TRTTensor, |
20 | | - dim: Shape, |
| 18 | + input: TRTTensor, |
| 19 | + dim: int, |
21 | 20 | ) -> TRTTensor: |
22 | | - input_val = get_trt_tensor(ctx, input_t, f"{name}_input_t") |
23 | | - if not isinstance(input_val, TRTTensor): |
24 | | - raise RuntimeError( |
25 | | - f"unsqueeze received input {input_val} that is not part " |
26 | | - "of the TensorRT region!" |
27 | | - ) |
28 | | - |
29 | | - dim = cast(int, dim) |
30 | | - |
31 | | - input_shape_size = len(input_val.shape) |
32 | | - dim = get_positive_dim(dim, input_shape_size + 1) |
33 | | - |
34 | | - intermediate_dim = 0 |
35 | | - dynamic_shape_cnt = 0 |
36 | | - # if unsqueeze the last dimensions, we can directly append to the shape |
37 | | - if dim == input_shape_size: |
38 | | - intermediate_dim = dim |
39 | | - else: |
40 | | - # since maximum of one dimension is permitted to be specified as -1 |
41 | | - # find the intermediate_dim which has only 1 dynamic_shape_cnt |
42 | | - # and then we can add a transpose after reshape if it is not the final shape we want |
43 | | - for i, s in reversed(list(enumerate(input_val.shape))): |
44 | | - if i >= dim: |
45 | | - if s == -1: |
46 | | - dynamic_shape_cnt += 1 |
47 | | - if dynamic_shape_cnt > 1: |
48 | | - intermediate_dim = i + 1 |
49 | | - break |
50 | | - if i == dim: |
51 | | - intermediate_dim = i |
52 | | - break |
53 | | - # calculate the new_shape for the shuffle layer's reshape_dims |
54 | | - new_shape = list( |
55 | | - tuple(input_val.shape)[:intermediate_dim] |
56 | | - + (1,) |
57 | | - + tuple(input_val.shape)[intermediate_dim:] |
58 | | - ) |
59 | | - for i, s in enumerate(new_shape): |
60 | | - if i < intermediate_dim and s == -1: |
61 | | - new_shape[i] = 0 |
62 | | - layer = ctx.net.add_shuffle(input_val) |
63 | | - layer.reshape_dims = tuple(new_shape) |
64 | | - # if the intermediate_dim is not the final dim we want to unsqueeze, add a second_transpose after reshape |
65 | | - if intermediate_dim != dim: |
66 | | - # calculate the second_transpose for the shuffle layer |
67 | | - permutation = [*range(0, len(new_shape))] |
68 | | - # for example: if the reshape_dims is (3, 3, 5, 1, 5) and the final shape we want is (3, 1, 3, 5, 5) |
69 | | - # here intermediate_dim=3, dim=1, we need to move intermediate_dim before [dim: intermediate_dim) |
70 | | - new_permutation = ( |
71 | | - tuple(permutation[:dim]) |
72 | | - + (intermediate_dim,) |
73 | | - + tuple(permutation[dim:intermediate_dim]) |
74 | | - + tuple(permutation[intermediate_dim + 1 :]) |
75 | | - ) |
76 | | - layer.second_transpose = new_permutation |
| 21 | + axes = get_trt_tensor(ctx, dim, f"{name}_axes") |
| 22 | + layer = ctx.net.add_unsqueeze(input, axes) |
77 | 23 | set_layer_name(layer, target, name, source_ir) |
78 | 24 | return layer.get_output(0) |
79 | 25 |
|
|
0 commit comments