@@ -43,7 +43,7 @@ TEST(DivisorsAndPowers, Default) {
4343}
4444
4545std::vector<CudaMappingOptions> restoreCandidates (
46- const std::string& kernelName ,
46+ const std::string& tc ,
4747 std::vector<at::Tensor>& inputs,
4848 std::vector<at::Tensor>& outputs) {
4949 auto inputsPair = toConstDlpackTensors (inputs);
@@ -54,13 +54,23 @@ std::vector<CudaMappingOptions> restoreCandidates(
5454 });
5555
5656 return tc::autotune::restoreCandidates (
57- kernelName , inputsPair.first , outputsPair.first );
57+ tc , inputsPair.first , outputsPair.first );
5858}
5959
6060TEST (RestoreCandidates, NoCache) {
6161 std::vector<at::Tensor> inputs{at::CUDA (at::kFloat ).rand ({10 , 16 }),
6262 at::CUDA (at::kFloat ).rand ({16 , 20 })};
63- ASSERT_THROW (restoreCandidates (" bla" , inputs, inputs), std::runtime_error);
63+ static constexpr auto tc = R"(
64+ def tc2(float(M,N) A, float(N,K) B) -> (output) {
65+ output(m, k) +=! A(m, nn) * B(nn, k) + 1
66+ })" ;
67+ ASSERT_THROW (restoreCandidates (tc, inputs, inputs), std::runtime_error);
68+ }
69+
70+ TEST (RestoreCandidates, NotATCid) {
71+ std::vector<at::Tensor> inputs{at::CUDA (at::kFloat ).rand ({10 , 16 }),
72+ at::CUDA (at::kFloat ).rand ({16 , 20 })};
73+ ASSERT_THROW (restoreCandidates (" bla" , inputs, inputs), lang::ErrorReport);
6474}
6575
6676static constexpr auto tc_ = R"(
@@ -89,7 +99,7 @@ TEST(RestoreCandidates, NoRuntimeRecorded) {
8999 atCompl.run (" matmul" , inputs, outputs_, handle);
90100
91101 FLAGS_tuner_gen_restore_number = 1 ;
92- ASSERT_EQ (restoreCandidates (" matmul " , inputs, outputs_).size (), 0 );
102+ ASSERT_EQ (restoreCandidates (tc_ , inputs, outputs_).size (), 0 );
93103}
94104
95105TEST (RestoreCandidates, Hit) {
@@ -110,11 +120,11 @@ TEST(RestoreCandidates, Hit) {
110120 atCompl.run (" matmul" , inputs, outputs_, handle, true );
111121
112122 FLAGS_tuner_gen_restore_number = 2 ;
113- auto restored = restoreCandidates (" matmul " , inputs, outputs_);
123+ auto restored = restoreCandidates (tc_ , inputs, outputs_);
114124 ASSERT_EQ (restored.size (), 2 );
115125
116126 FLAGS_tuner_gen_restore_number = 1 ;
117- restored = restoreCandidates (" matmul " , inputs, outputs_);
127+ restored = restoreCandidates (tc_ , inputs, outputs_);
118128 ASSERT_EQ (restored.size (), 1 );
119129}
120130
0 commit comments