@@ -753,6 +753,39 @@ def perforatedConvolution(float(N, C, H, W) input, float(M, C, KH, KW) weights,
753753 Prepare (tc);
754754}
755755
756+ TEST_F (PolyhedralMapperTest, ModulusConstantRHS) {
757+ string tc = R"TC(
758+ def fun(float(N) a) -> (b) { b(i) = a(i % 2) where i in 0:N }
759+ )TC" ;
760+ // This triggers tc2halide conversion and should not throw.
761+ auto scop = Prepare (tc);
762+ for (auto r : scop->reads .range ().get_set_list ()) {
763+ // skip irrelevant reads, if any
764+ if (r.get_tuple_name () != std::string (" a" )) {
765+ continue ;
766+ }
767+ // std::cout << "Got stride: " << r.get_stride() << std::endl;
768+ // EXPECT_EQ(r.get_stride(), 2);
769+ }
770+ }
771+
772+ TEST_F (PolyhedralMapperTest, ModulusVariableRHS) {
773+ string tc = R"TC(
774+ def local_sparse_convolution(float(N, C, H, W) I, float(O, KC, KH, KW) W1) -> (O1) {
775+ O1(n, o, h, w) +=! I(n, kc % c, h + kh, w + kw) * W1(o, kc, kh, kw) where c in 1:C
776+ }
777+ )TC" ;
778+ // This triggers tc2halide conversion and should not throw.
779+ auto scop = Prepare (tc);
780+ for (auto r : scop->reads .range ().get_set_list ()) {
781+ // skip irrelevant reads, if any
782+ if (r.get_tuple_name () != std::string (" I" )) {
783+ continue ;
784+ }
785+ EXPECT_TRUE (r.plain_is_universe ());
786+ }
787+ }
788+
756789int main (int argc, char ** argv) {
757790 ::testing::InitGoogleTest (&argc, argv);
758791 ::gflags::ParseCommandLineFlags (&argc, &argv, true );
0 commit comments