-from typing import Iterable, List, NamedTuple, Sequence, Tuple
+from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple
 
 import torch
 
 from .types import Shape, ShapeRange
 from .utils import get_dynamic_dims
 
 
+def generate_input_specs(
+    inputs, lower_setting, additional_inputs=None, fixed_shape=False
+):
+    # The AIT lower setting doesn't have an explicit_batch_dimension field,
+    # so we just return None.
+    if not hasattr(lower_setting, "explicit_batch_dimension"):
+        return None
+
+    if not lower_setting.explicit_batch_dimension or fixed_shape:
+        return InputTensorSpec.from_tensors(inputs)
+
+    # If we don't have additional inputs, we assume the first dimension
+    # is the dynamic batch dimension. Otherwise, we use the additional
+    # inputs to determine the batch dimension.
+    if additional_inputs is None:
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+        )
+    else:
+        batch_dims = []
+
+        for i, j in zip(inputs, additional_inputs):
+            found_batch_dim = False
+
+            for idx, values in enumerate(zip(i.shape, j.shape)):
+                if values[0] != values[1]:
+                    assert (
+                        found_batch_dim is False
+                    ), f"We've already found a batch dim, {i.shape}, {j.shape}."
+                    batch_dims.append(idx)
+                    found_batch_dim = True
+
+            if not found_batch_dim:
+                raise RuntimeError(
+                    f"Failed to find batch dimension because shapes are the same, {i.shape}"
+                )
+
+        return InputTensorSpec.from_tensors_with_dynamic_batch_size(
+            inputs,
+            (
+                0,
+                lower_setting.max_batch_size,
+                lower_setting.max_batch_size,
+            ),
+            lower_setting.opt_profile_replica,
+            batch_dims,
+        )
+
+
 class InputTensorSpec(NamedTuple):
     """
     This class contains the information of an input tensor.
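Not part of the diff: a minimal sketch of the batch-dim detection that the new `generate_input_specs` performs when `additional_inputs` is given. The tensor shapes below are made up for illustration; the real function asserts that exactly one dim differs between each pair of sample inputs and raises if none does.

```python
import torch

# Two sample inputs that agree everywhere except the batch dim (dim 1 here).
inputs = [torch.randn(8, 4, 16)]
additional_inputs = [torch.randn(8, 2, 16)]

batch_dims = []
for i, j in zip(inputs, additional_inputs):
    # The single dim whose sizes differ is treated as the batch dim; a second
    # mismatch would trip the assert in generate_input_specs, and no mismatch
    # raises a RuntimeError.
    mismatches = [idx for idx, (a, b) in enumerate(zip(i.shape, j.shape)) if a != b]
    assert len(mismatches) == 1, f"{i.shape}, {j.shape}"
    batch_dims.append(mismatches[0])

print(batch_dims)  # [1]
```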
@@ -70,6 +125,7 @@ def from_tensors_with_dynamic_batch_size(
         tensors: Sequence[torch.Tensor],
         batch_size_range: Tuple[int, int, int],
         opt_profile_replica: int = 1,
+        batch_dims: Optional[List[int]] = None,
     ) -> List["InputTensorSpec"]:
         """
         Produce a list of InputTensorSpec named tuples which would contain
@@ -83,20 +139,30 @@ def from_tensors_with_dynamic_batch_size(
                 the smallest batch size allowed. The second integer indicates
                 the batch size that we'll optimize for. The third integer indicates
                 the largest batch size allowed.
+            opt_profile_replica (int): If dynamic shape is enabled, each execution
+                context requires a different optimization profile. This arg determines
+                how many optimization profile replicas we want to produce.
+            batch_dims (Optional[List[int]]): The batch dim might not be the leading dim,
+                so this arg lets the user specify the batch dim for each tensor. By
+                default we treat dim 0 as the batch dim.
 
         Returns:
             A list of InputTensorSpec named tuples with dynamic ranges.
         """
+        if batch_dims is None:
+            batch_dims = [0] * len(tensors)
+
         input_specs = []
-        batch_size = tensors[0].size(0)
+        batch_size = tensors[0].size(batch_dims[0])
 
         for i, tensor in enumerate(tensors):
+            batch_dim = batch_dims[i]
             assert batch_size == tensor.size(
-                0
+                batch_dim
             ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
             shape = list(tensor.shape)
-            shape[0] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
+            shape[batch_dim] = -1
+            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
             input_specs.append(
                 cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
             )
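For context (not part of the diff): a sketch of how the new `batch_dims` argument changes the produced specs. The import path is an assumption and varies by packaging; the method simply marks the chosen dim as -1 and sweeps `batch_size_range` over that dim instead of dim 0.

```python
import torch
from fx2trt_oss.fx import InputTensorSpec  # import path is an assumption

# Batch dim is dim 1, not the leading dim.
tensors = [torch.randn(8, 4, 16)]
specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
    tensors,
    batch_size_range=(1, 4, 8),  # (min, opt, max) batch sizes
    opt_profile_replica=1,
    batch_dims=[1],  # new in this diff: one batch dim per input tensor
)

print(specs[0].shape)         # (8, -1, 16): the batch dim is marked dynamic
print(specs[0].shape_ranges)  # [((8, 1, 16), (8, 4, 16), (8, 8, 16))]
```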