2222#include " tc/aten/aten_compiler.h"
2323#include " tc/core/cuda/cuda.h"
2424#include " tc/core/cuda/cuda_tc_executor.h"
25+ #include " tc/core/exceptions.h"
2526#include " tc/core/scope_guard.h"
2627#include " tc/lang/canonicalize.h"
2728#include " tc/lang/sema.h"
@@ -306,10 +307,10 @@ TEST_F(TcCudaMapperTest, TensorAddStrided) {
306307 M = 64 ;
307308 at::Tensor I0 = at::CUDA (at::kFloat ).rand ({N, M});
308309 at::Tensor I0_view =
309- I0.type ().tensor ().set_ (*I0.storage (), 0 , {N, M}, {1 , 16 });
310+ I0.type ().tensor ().set_ (*I0.storage (), 0 , {N, M}, {128 , 1 });
310311 at::Tensor I1 = at::CUDA (at::kFloat ).rand ({N, M});
311312 at::Tensor I1_view =
312- I1.type ().tensor ().set_ (*I1.storage (), 0 , {N, M}, {1 , 16 });
313+ I1.type ().tensor ().set_ (*I1.storage (), 0 , {N, M}, {128 , 1 });
313314 std::vector<at::Tensor> inputs = {I0_view, I1_view};
314315
315316 static constexpr auto TC = R"TC(
@@ -327,12 +328,41 @@ def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
327328 std::string expected =
328329 " const float32 (*I0_view)[64] = "
329330 " reinterpret_cast<const float32 (*)[64]>(pI0_view)" ;
330-
331331 ASSERT_NE (std::string::npos, res.second .find (expected))
332332 << " In resulting code:\n "
333333 << res.second << " \n found unexpected: " << expected;
334334}
335335
336+ // /////////////////////////////////////////////////////////////////////////////
337+ // TensorAddInvalidStrides
338+ // O(n, m) += I0_view(n, m) * I1_view(n, m)
339+ // /////////////////////////////////////////////////////////////////////////////
340+ TEST_F (TcCudaMapperTest, TensorAddInvalidStrides) {
341+ N = 64 ;
342+ M = 64 ;
343+ at::Tensor I0 = at::CUDA (at::kFloat ).rand ({N, M});
344+ at::Tensor I0_view =
345+ I0.type ().tensor ().set_ (*I0.storage (), 0 , {N, M}, {16 , 1 });
346+ at::Tensor I1 = at::CUDA (at::kFloat ).rand ({N, M});
347+ at::Tensor I1_view =
348+ I1.type ().tensor ().set_ (*I1.storage (), 0 , {N, M}, {16 , 1 });
349+ std::vector<at::Tensor> inputs = {I0_view, I1_view};
350+
351+ static constexpr auto TC = R"TC(
352+ def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
353+ O(n, m) += I0_view(n, m) + I1_view(n, m)
354+ }
355+ )TC" ;
356+
357+ auto checkFun = [](const std::vector<at::Tensor>& ins,
358+ std::vector<at::Tensor>& outs) { return true ; };
359+ auto options = tc::CudaMappingOptions::makeNaiveMappingOptions ();
360+ auto name = " tensoraddstrided" ;
361+
362+ EXPECT_THROW (
363+ Check (TC, name, options, inputs, checkFun), tc::InvalidStrideException);
364+ }
365+
336366// /////////////////////////////////////////////////////////////////////////////
337367// Lookup Table
338368// O(b, n) +=! LUT(I(b, n), r_r)
0 commit comments