@@ -7504,6 +7504,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
75047504" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
75057505" return %1 : !torch.list<int>\n"
75067506" }\n"
7507+ " func.func @\"__torch_mlir_shape_fn.aten.any.dims\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool) -> !torch.list<int> {\n"
7508+ " %none = torch.constant.none\n"
7509+ " %0 = torch.derefine %none : !torch.none to !torch.any\n"
7510+ " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
7511+ " return %1 : !torch.list<int>\n"
7512+ " }\n"
75077513" func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
75087514" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
75097515" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
@@ -15420,6 +15426,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1542015426" }\n"
1542115427" return %2 : !torch.int\n"
1542215428" }\n"
15429+ " func.func @\"__torch_mlir_dtype_fn.aten.any.dims\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool) -> !torch.int {\n"
15430+ " %int11 = torch.constant.int 11\n"
15431+ " %int0 = torch.constant.int 0\n"
15432+ " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
15433+ " %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
15434+ " %2 = torch.prim.If %1 -> (!torch.int) {\n"
15435+ " torch.prim.If.yield %0#1 : !torch.int\n"
15436+ " } else {\n"
15437+ " torch.prim.If.yield %int11 : !torch.int\n"
15438+ " }\n"
15439+ " return %2 : !torch.int\n"
15440+ " }\n"
1542315441" func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
1542415442" %int11 = torch.constant.int 11\n"
1542515443" %int0 = torch.constant.int 0\n"
0 commit comments