@@ -1169,6 +1169,50 @@ func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !
 return %0 : !torch.vtensor<[3,5],si64>
 }
 
+// -----
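+// Float -> bool converts via a comparison with zero (tosa.equal followed by
+// tosa.logical_not, i.e. x != 0.0) rather than a direct tosa.cast.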
+// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToBool(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 11
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.none
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<f32>, !tosa.shape<2>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.equal %[[VAL_1]], %[[VAL_7]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_9:.*]] = tosa.logical_not %[[VAL_8]] : (tensor<3x5xi1>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
+// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,5],i1>
+// CHECK: }
+func.func @torch.aten.to.dtype$floatToBool(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> {
+  %int11 = torch.constant.int 11
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
+  return %0 : !torch.vtensor<[3,5],i1>
+}
+
+// -----
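+// Bool -> float casts in two steps, i1 -> i8 -> f32, using chained tosa.cast ops.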
+// CHECK-LABEL: func.func @torch.aten.to.dtype$boolToFloat(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],i1> -> tensor<3x4xi1>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 6
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.none
+// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi1>) -> tensor<3x4xi8>
+// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<3x4xi8>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32>
+// CHECK: }
+func.func @torch.aten.to.dtype$boolToFloat(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> {
+  %int6 = torch.constant.int 6
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[3,4],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.gather(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,