@@ -3498,3 +3498,101 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
34983498 torch.bind_symbolic_shape %3 , [%0 ], affine_map <()[s0 ] -> (s0 )> : !torch.vtensor <[?],f32 >
34993499 return %3 : !torch.vtensor <[?],f32 >
35003500}
3501+
3502+ // -----
3503+
3504+ // CHECK-LABEL: func.func @ttorch.aten.ones$float_fold() -> !torch.vtensor<[2,3,4],f32> {
3505+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
3506+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
3507+ // CHECK: }
3508+ func.func @ttorch.aten.ones$float_fold () -> !torch.vtensor <[2 ,3 ,4 ],f32 > {
3509+ %int2 = torch.constant.int 2
3510+ %int3 = torch.constant.int 3
3511+ %int4 = torch.constant.int 4
3512+ %none = torch.constant.none
3513+ %0 = torch.prim.ListConstruct %int2 , %int3 , %int4 : (!torch.int , !torch.int , !torch.int ) -> !torch.list <int >
3514+ %1 = torch.aten.ones %0 , %none , %none , %none , %none : !torch.list <int >, !torch.none , !torch.none , !torch.none , !torch.none -> !torch.vtensor <[2 ,3 ,4 ],f32 >
3515+ return %1 : !torch.vtensor <[2 ,3 ,4 ],f32 >
3516+ }
3517+
3518+ // -----
3519+
3520+ // CHECK-LABEL: func.func @ttorch.aten.ones$int_fold() -> !torch.vtensor<[2,3,4],si64> {
3521+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
3522+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
3523+ // CHECK: }
3524+ func.func @ttorch.aten.ones$int_fold () -> !torch.vtensor <[2 ,3 ,4 ],si64 > {
3525+ %int2 = torch.constant.int 2
3526+ %int3 = torch.constant.int 3
3527+ %int4 = torch.constant.int 4
3528+ %none = torch.constant.none
3529+ %0 = torch.prim.ListConstruct %int2 , %int3 , %int4 : (!torch.int , !torch.int , !torch.int ) -> !torch.list <int >
3530+ %1 = torch.aten.ones %0 , %none , %none , %none , %none : !torch.list <int >, !torch.none , !torch.none , !torch.none , !torch.none -> !torch.vtensor <[2 ,3 ,4 ],si64 >
3531+ return %1 : !torch.vtensor <[2 ,3 ,4 ],si64 >
3532+ }
3533+
3534+ // -----
3535+
3536+ // CHECK-LABEL: func.func @test_aten_zeros$float_fold() -> !torch.vtensor<[2,3,4],f32> {
3537+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
3538+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
3539+ // CHECK: }
3540+ func.func @test_aten_zeros$float_fold () -> !torch.vtensor <[2 ,3 ,4 ],f32 > {
3541+ %int2 = torch.constant.int 2
3542+ %int3 = torch.constant.int 3
3543+ %int4 = torch.constant.int 4
3544+ %none = torch.constant.none
3545+ %0 = torch.prim.ListConstruct %int2 , %int3 , %int4 : (!torch.int , !torch.int , !torch.int ) -> !torch.list <int >
3546+ %1 = torch.aten.zeros %0 , %none , %none , %none , %none : !torch.list <int >, !torch.none , !torch.none , !torch.none , !torch.none -> !torch.vtensor <[2 ,3 ,4 ],f32 >
3547+ return %1 : !torch.vtensor <[2 ,3 ,4 ],f32 >
3548+ }
3549+
3550+ // -----
3551+
3552+ // CHECK-LABEL: func.func @test_aten_zeros$int_fold() -> !torch.vtensor<[2,3,4],si64> {
3553+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
3554+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
3555+ // CHECK: }
3556+ func.func @test_aten_zeros$int_fold () -> !torch.vtensor <[2 ,3 ,4 ],si64 > {
3557+ %int2 = torch.constant.int 2
3558+ %int3 = torch.constant.int 3
3559+ %int4 = torch.constant.int 4
3560+ %none = torch.constant.none
3561+ %0 = torch.prim.ListConstruct %int2 , %int3 , %int4 : (!torch.int , !torch.int , !torch.int ) -> !torch.list <int >
3562+ %1 = torch.aten.zeros %0 , %none , %none , %none , %none : !torch.list <int >, !torch.none , !torch.none , !torch.none , !torch.none -> !torch.vtensor <[2 ,3 ,4 ],si64 >
3563+ return %1 : !torch.vtensor <[2 ,3 ,4 ],si64 >
3564+ }
3565+
3566+ // -----
3567+
3568+ // CHECK-LABEL: func.func @torch.aten.full$float_fold() -> !torch.vtensor<[2,1,4],f32> {
3569+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0xFF800000> : tensor<2x1x4xf32>) : !torch.vtensor<[2,1,4],f32>
3570+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],f32>
3571+ // CHECK: }
3572+ func.func @torch.aten.full$float_fold () -> !torch.vtensor <[2 ,1 ,4 ],f32 > {
3573+ %float -Inf = torch.constant.float 0xFFF0000000000000
3574+ %int2 = torch.constant.int 2
3575+ %int1 = torch.constant.int 1
3576+ %int4 = torch.constant.int 4
3577+ %none = torch.constant.none
3578+ %0 = torch.prim.ListConstruct %int2 , %int1 , %int4 : (!torch.int , !torch.int , !torch.int ) -> !torch.list <int >
3579+ %1 = torch.aten.full %0 , %float -Inf , %none , %none , %none , %none : !torch.list <int >, !torch.float , !torch.none , !torch.none , !torch.none , !torch.none -> !torch.vtensor <[2 ,1 ,4 ],f32 >
3580+ return %1 : !torch.vtensor <[2 ,1 ,4 ],f32 >
3581+ }
3582+
3583+ // -----
3584+
3585+ // CHECK-LABEL: func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
3586+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x1x4xsi64>) : !torch.vtensor<[2,1,4],si64>
3587+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],si64>
3588+ // CHECK: }
3589+ func.func @torch.aten.full$int_fold () -> !torch.vtensor <[2 ,1 ,4 ],si64 > {
3590+ %int -Inf = torch.constant.int 0
3591+ %int2 = torch.constant.int 2
3592+ %int1 = torch.constant.int 1
3593+ %int4 = torch.constant.int 4
3594+ %none = torch.constant.none
3595+ %0 = torch.prim.ListConstruct %int2 , %int1 , %int4 : (!torch.int , !torch.int , !torch.int ) -> !torch.list <int >
3596+ %1 = torch.aten.full %0 , %int -Inf , %none , %none , %none , %none : !torch.list <int >, !torch.int , !torch.none , !torch.none , !torch.none , !torch.none -> !torch.vtensor <[2 ,1 ,4 ],si64 >
3597+ return %1 : !torch.vtensor <[2 ,1 ,4 ],si64 >
3598+ }
0 commit comments