@@ -659,6 +659,48 @@ func.func @test_multinomial_dtype_double_samplenum_4(%arg0: !torch.vtensor<[3,5]
659659
660660// -----
661661
662+ // CHECK-LABEL: func.func @test_maxpool_1d_indices_default
663+ func.func @test_maxpool_1d_indices_default (%arg0: !torch.vtensor <[1 ,3 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,31 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 } {
664+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
665+ // CHECK: %[[VAL_0:.*]] = torch.constant.int 2
666+ // CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.int) -> !torch.list<int>
667+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 0
668+ // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
669+ // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
670+ // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
671+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
672+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]] : (!torch.int) -> !torch.list<int>
673+ // CHECK: %[[VAL_8:.*]] = torch.constant.bool false
674+ // CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = torch.aten.max_pool1d_with_indices %[[ARG0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_7]], %[[VAL_8]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,31],f32>, !torch.vtensor<[93],ui64>
675+ // CHECK: return %[[VAL_9]] : !torch.vtensor<[1,3,31],f32>
676+ // CHECK: }
677+ %0:2 = torch.operator " onnx.MaxPool" (%arg0 ) {torch.onnx.kernel_shape = [2 : si64 ]} : (!torch.vtensor <[1 ,3 ,32 ],f32 >) -> (!torch.vtensor <[1 ,3 ,31 ],f32 >, !torch.vtensor <[93 ], ui64 >)
678+ return %0#0 : !torch.vtensor <[1 ,3 ,31 ],f32 >
679+ }
680+
681+ // -----
682+
683+ // CHECK-LABEL: func.func @test_maxpool_1d_indices_ceil_pad_stride(
684+ func.func @test_maxpool_1d_indices_ceil_pad_stride (%arg0: !torch.vtensor <[1 ,3 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,16 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 } {
685+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,16],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
686+ // CHECK: %[[VAL_0:.*]] = torch.constant.int 5
687+ // CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[VAL_0]] : (!torch.int) -> !torch.list<int>
688+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 2
689+ // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
690+ // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
691+ // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
692+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
693+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]] : (!torch.int) -> !torch.list<int>
694+ // CHECK: %[[VAL_8:.*]] = torch.constant.bool true
695+ // CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = torch.aten.max_pool1d_with_indices %[[ARG0]], %[[VAL_1]], %[[VAL_5]], %[[VAL_3]], %[[VAL_7]], %[[VAL_8]] : !torch.vtensor<[1,3,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,16],f32>, !torch.vtensor<[48],ui64>
696+ // CHECK: return %[[VAL_9]] : !torch.vtensor<[1,3,16],f32>
697+ // CHECK: }
698+ %0:2 = torch.operator " onnx.MaxPool" (%arg0 ) {torch.onnx.ceil_mode = 1 : si64 , torch.onnx.kernel_shape = [5 : si64 ], torch.onnx.pads = [2 : si64 , 2 : si64 ], torch.onnx.strides = [2 : si64 ]} : (!torch.vtensor <[1 ,3 ,32 ],f32 >) -> (!torch.vtensor <[1 ,3 ,16 ],f32 >, !torch.vtensor <[48 ], ui64 >)
699+ return %0#0 : !torch.vtensor <[1 ,3 ,16 ],f32 >
700+ }
701+
702+ // -----
703+
662704// CHECK-LABEL: func.func @test_maxpool_2d_default
663705func.func @test_maxpool_2d_default (%arg0: !torch.vtensor <[1 ,3 ,32 ,32 ],f32 >) -> !torch.vtensor <[1 ,3 ,31 ,31 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 12 : si64 } {
664706 // CHECK: %[[I2:.*]] = torch.constant.int 2
0 commit comments