@@ -9,7 +9,8 @@ TEST(Converters, ATenEinsumConvertsMatMulCorrectly) {
99 graph(%x.1 : Tensor, %x.2 : Tensor):
1010 %0 : str = prim::Constant[value="ij,jk->ik"]()
1111 %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
12- %4 : Tensor = aten::einsum(%0, %3)
12+ %none : NoneType = prim::Constant()
13+ %4 : Tensor = aten::einsum(%0, %3, %none)
1314 return (%4))IR" ;
1415
1516 auto g = std::make_shared<torch::jit::Graph>();
@@ -34,7 +35,8 @@ TEST(Converters, ATenEinsumConvertsElementwiseProdCorrectly) {
3435 graph(%x.1 : Tensor, %x.2 : Tensor):
3536 %0 : str = prim::Constant[value="abcd,abcd->abcd"]()
3637 %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
37- %4 : Tensor = aten::einsum(%0, %3)
38+ %none : NoneType = prim::Constant()
39+ %4 : Tensor = aten::einsum(%0, %3, %none)
3840 return (%4))IR" ;
3941
4042 auto g = std::make_shared<torch::jit::Graph>();
@@ -59,7 +61,8 @@ TEST(Converters, ATenEinsumConvertsTransposeCorrectly) {
5961 graph(%x.1 : Tensor):
6062 %0 : str = prim::Constant[value="jk->kj"]()
6163 %3 : Tensor[] = prim::ListConstruct(%x.1)
62- %4 : Tensor = aten::einsum(%0, %3)
64+ %none : NoneType = prim::Constant()
65+ %4 : Tensor = aten::einsum(%0, %3, %none)
6366 return (%4))IR" ;
6467
6568 auto g = std::make_shared<torch::jit::Graph>();
@@ -83,7 +86,8 @@ TEST(Converters, ATenEinsumConvertsVectorsCorrectly) {
8386 graph(%x.1 : Tensor, %x.2 : Tensor):
8487 %0 : str = prim::Constant[value="a,b->ab"]()
8588 %3 : Tensor[] = prim::ListConstruct(%x.1, %x.2)
86- %4 : Tensor = aten::einsum(%0, %3)
89+ %none : NoneType = prim::Constant()
90+ %4 : Tensor = aten::einsum(%0, %3, %none)
8791 return (%4))IR" ;
8892
8993 auto g = std::make_shared<torch::jit::Graph>();
0 commit comments