@@ -54,3 +54,105 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
   }
   return %1 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
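+// insert_slice(expand_shape) where the expand_shape only adds unit
+// dimensions: the reshape folds away and the source is inserted directly.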
+// CHECK-LABEL: func @insert_of_padding_expand_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK: return %[[insert]]
+func.func @insert_of_padding_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+  -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+      : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
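+// Negative case: the expand_shape also introduces a non-unit (dynamic %sz)
+// dimension, so it is not a pure unit-dim expansion and is preserved.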
+// CHECK-LABEL: func @insert_of_non_padding_expand_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK: return %[[insert]]
+func.func @insert_of_non_padding_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index,
+    %sz: index)
+  -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
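+// Same unit-dim folding, but through tensor.parallel_insert_slice inside
+// an scf.forall terminator.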
+// CHECK-LABEL: func @parallel_insert_of_padding_expand_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+func.func @parallel_insert_of_padding_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+  -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+        : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
+    }
+  }
+  return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
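+// Negative parallel_insert_slice case: the non-unit %sz expansion blocks
+// the fold, so the expand_shape is preserved.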
+// CHECK-LABEL: func @parallel_insert_of_non_padding_expand_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK: tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+func.func @parallel_insert_of_non_padding_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index,
+    %sz: index)
+  -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+        : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+    }
+  }
+  return %1 : tensor<?x?x?x?xf32>
+}