@@ -3400,3 +3400,30 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
34003400 torch.bind_symbolic_shape %3 , [%0 ], affine_map <()[s0 ] -> (s0 )> : !torch.vtensor <[?],f32 >
34013401 return %3 : !torch.vtensor <[?],f32 >
34023402}
3403+
3404+ // -----
3405+ // CHECK-LABEL: func.func @torch.aten.avg_pool2d.single_int_tuple(
3406+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,20,20],f32>) -> !torch.vtensor<[2,4,9,9],f32> {
3407+ // CHECK: %[[NONE:.*]] = torch.constant.none
3408+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
3409+ // CHECK: %[[C_6:.*]] = torch.constant.int 6
3410+ // CHECK: %[[C_1:.*]] = torch.constant.int 1
3411+ // CHECK: %[[C_2:.*]] = torch.constant.int 2
3412+ // CHECK: %[[KERNEL:.*]] = torch.prim.ListConstruct %[[C_6]], %[[C_6]] : (!torch.int, !torch.int) -> !torch.list<int>
3413+ // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C_1]], %[[C_1]] : (!torch.int, !torch.int) -> !torch.list<int>
3414+ // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[C_2]], %[[C_2]] : (!torch.int, !torch.int) -> !torch.list<int>
3415+ // CHECK: %[[POOL:.*]] = torch.aten.avg_pool2d %[[ARG0]], %[[KERNEL]], %[[PAD]], %[[STRIDE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[2,4,20,20],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,4,9,9],f32>
3416+ // CHECK: return %[[POOL]]
3417+ func.func @torch.aten.avg_pool2d.single_int_tuple (%arg0: !torch.vtensor <[2 ,4 ,20 ,20 ],f32 >) -> !torch.vtensor <[2 ,4 ,9 ,9 ],f32 > {
3418+ %int6 = torch.constant.int 6
3419+ %0 = torch.prim.ListConstruct %int6 : (!torch.int ) -> !torch.list <int >
3420+ %int2 = torch.constant.int 2
3421+ %1 = torch.prim.ListConstruct %int2 : (!torch.int ) -> !torch.list <int >
3422+ %int1 = torch.constant.int 1
3423+ %2 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
3424+ %false = torch.constant.bool false
3425+ %false_0 = torch.constant.bool false
3426+ %none = torch.constant.none
3427+ %3 = torch.aten.avg_pool2d %arg0 , %0 , %1 , %2 , %false , %false_0 , %none : !torch.vtensor <[2 ,4 ,20 ,20 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool , !torch.none -> !torch.vtensor <[2 ,4 ,9 ,9 ],f32 >
3428+ return %3 : !torch.vtensor <[2 ,4 ,9 ,9 ],f32 >
3429+ }
0 commit comments