 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tmm.h"

#include <iostream>
#include <string>
#include <vector>
@@ -42,21 +44,22 @@ DEFINE_uint32(M, 32, "M dimension in C(m, n) += A(m, kk) * B(n, kk)");
4244DEFINE_uint32 (K, 256 , " K dimension in C(m, n) += A(m, kk) * B(n, kk)" );
4345
4446class TransposedMatMul : public Benchmark {
47+ protected:
48+ uint32_t M, N, K;
49+
4550 public:
46- void runTransposedMatMul (
47- uint32_t N,
48- uint32_t M,
49- uint32_t K,
50- const tc::CudaMappingOptions& options,
51- bool use_flags = false );
51+ void Init (uint32_t m, uint32_t n, uint32_t k) {
52+ M = m;
53+ N = n;
54+ K = k;
55+ }
56+ void runTransposedMatMul (const tc::CudaMappingOptions& options);
57+ void runATenTransposedMatMul ();
58+ void runCaffe2TransposedMatMul ();
5259};
5360
5461void TransposedMatMul::runTransposedMatMul (
55- uint32_t N,
56- uint32_t M,
57- uint32_t K,
58- const tc::CudaMappingOptions& options,
59- bool use_flags) {
62+ const tc::CudaMappingOptions& options) {
6063 at::Tensor A = at::CUDA (at::kFloat ).rand ({M, K});
6164 at::Tensor B = at::CUDA (at::kFloat ).rand ({N, K});
6265
@@ -82,134 +85,30 @@ def tmm(float(M,K) A, float(N,K) B) -> (C) {
8285 std::string suffix = std::string (" _M_" ) + std::to_string (FLAGS_M) +
8386 std::string (" _N_" ) + std::to_string (FLAGS_N) + std::string (" _K_" ) +
8487 std::to_string (FLAGS_K);
85- if (use_flags && FLAGS_validate_proto) {
86- validateProto (
88+ std::vector<tc::CudaMappingOptions> bestOptions{options};
89+ if (FLAGS_autotune) {
90+ bestOptions = autotune (
8791 FLAGS_save_tuner_proto_prefix + std::string (" /tmm_cache" ) + suffix,
92+ FLAGS_save_tuner_proto_prefix + std::string (" /tmm_best" ) + suffix,
8893 tc,
8994 " tmm" ,
9095 inputs,
96+ options,
9197 check_fun);
92- } else {
93- Check (tc, " tmm" , options, inputs, check_fun);
94- if (use_flags) {
95- autotune (
96- FLAGS_save_tuner_proto_prefix + std::string (" /tmm_cache" ) + suffix,
97- FLAGS_save_tuner_proto_prefix + std::string (" /tmm_best" ) + suffix,
98- tc,
99- " tmm" ,
100- inputs,
101- options,
102- check_fun);
103- }
98+ CHECK_GE (bestOptions.size (), 1u );
10499 }
100+ Check (tc, " tmm" , bestOptions[0 ], inputs, check_fun);
105101}
106102
107- TEST_F (TransposedMatMul, TransposedMatMul) {
108- auto N = FLAGS_N;
109- auto M = FLAGS_M;
110- auto K = FLAGS_K;
111- auto options = tc::CudaMappingOptions::makeNaiveMappingOptions ()
112- .fixParametersBeforeScheduling (true )
113- .tile (32 , 32 , 32 )
114- .mapToThreads ({32 , 32 })
115- .mapToBlocks ({M / 32 , N / 32 })
116- .useSharedMemory (true )
117- .usePrivateMemory (true )
118- .unroll (256 );
119- runTransposedMatMul (N, M, K, options, true );
120- }
121-
122- TEST_F (TransposedMatMul, TransposedMatMul_P100_autotuned_M_128_N_1024_K_1024) {
123- uint32_t M = 128 ;
124- uint32_t N = 1024 ;
125- uint32_t K = 1024 ;
126- auto options =
127- tc::CudaMappingOptions::makeNaiveMappingOptions ()
128- .outerScheduleFusionStrategy (tc::FusionStrategy::Preserve3Coincident)
129- .outerScheduleAllowSkewing (false )
130- .outerSchedulePositiveOrthant (true )
131- .intraTileScheduleFusionStrategy (
132- tc::FusionStrategy::Preserve3Coincident)
133- .intraTileScheduleAllowSkewing (false )
134- .intraTileSchedulePositiveOrthant (true )
135- .tile (1 , 32 )
136- .mapToThreads (64 , 4 )
137- .mapToBlocks (256 , 32 )
138- .unroll (256 )
139- .tileImperfectlyNested (false )
140- .useSharedMemory (true )
141- .usePrivateMemory (false )
142- .unrollCopyShared (true )
143- .matchLibraryCalls (true );
144- runTransposedMatMul (N, M, K, options);
145- }
146-
147- TEST_F (TransposedMatMul, TransposedMatMul_P100_autotuned_M_128_N_256_K_32) {
148- uint32_t M = 128 ;
149- uint32_t N = 256 ;
150- uint32_t K = 32 ;
151- auto options =
152- tc::CudaMappingOptions::makeNaiveMappingOptions ()
153- .outerScheduleFusionStrategy (tc::FusionStrategy::Preserve3Coincident)
154- .outerScheduleAllowSkewing (false )
155- .outerSchedulePositiveOrthant (true )
156- .intraTileScheduleFusionStrategy (
157- tc::FusionStrategy::Preserve3Coincident)
158- .intraTileScheduleAllowSkewing (false )
159- .intraTileSchedulePositiveOrthant (true )
160- .tile (8 , 32 )
161- .mapToThreads (64 )
162- .mapToBlocks (64 , 32 , 64 )
163- .unroll (64 )
164- .tileImperfectlyNested (false )
165- .useSharedMemory (true )
166- .usePrivateMemory (true )
167- .unrollCopyShared (false )
168- .matchLibraryCalls (false );
169- runTransposedMatMul (N, M, K, options);
170- }
171-
172- TEST_F (TransposedMatMul, TransposedMatMul_P100_autotuned_M_128_N_16384_K_4096) {
173- uint32_t M = 128 ;
174- uint32_t N = 16384 ;
175- uint32_t K = 4096 ;
176- auto options =
177- tc::CudaMappingOptions::makeNaiveMappingOptions ()
178- .outerScheduleFusionStrategy (tc::FusionStrategy::Preserve3Coincident)
179- .outerScheduleAllowSkewing (false )
180- .outerSchedulePositiveOrthant (true )
181- .intraTileScheduleFusionStrategy (
182- tc::FusionStrategy::Preserve3Coincident)
183- .intraTileScheduleAllowSkewing (false )
184- .intraTileSchedulePositiveOrthant (true )
185- .tile (32 , 32 , 2 )
186- .mapToThreads (32 )
187- .mapToBlocks (4 , 128 )
188- .unroll (8 )
189- .tileImperfectlyNested (false )
190- .useSharedMemory (true )
191- .usePrivateMemory (true )
192- .unrollCopyShared (false )
193- .matchLibraryCalls (false );
194- runTransposedMatMul (N, M, K, options);
195- }
196-
197- TEST_F (TransposedMatMul, ATenTransposedMatMulReference) {
198- auto N = FLAGS_N;
199- auto M = FLAGS_M;
200- auto K = FLAGS_K;
103+ void TransposedMatMul::runATenTransposedMatMul () {
201104 at::Tensor A = at::CUDA (at::kFloat ).rand ({M, K});
202105 at::Tensor B = at::CUDA (at::kFloat ).rand ({N, K});
203106 Reference (
204107 [&]() { return at::mm (A, B.t ()); },
205108 [&](at::Tensor& res) { at::mm_out (res, A, B.t ()); });
206109}
207110
208- TEST_F (TransposedMatMul, C2TransposedMatMulReference) {
209- auto N = FLAGS_N;
210- auto M = FLAGS_M;
211- auto K = FLAGS_K;
212-
111+ void TransposedMatMul::runCaffe2TransposedMatMul () {
213112 auto ws_init_func = [&](Workspace& w) {
214113 auto AddInput = AddDeterministicallyRandomInput<caffe2::CUDABackend, float >;
215114 AddInput (w, {M, K}, " I" );
@@ -220,11 +119,118 @@ TEST_F(TransposedMatMul, C2TransposedMatMulReference) {
220119 float precision = 0.0 ;
221120 std::unique_ptr<OpTester> reference (new OpTester (op_def, precision));
222121 reference->InitializeReference (ws_init_func, {{" trans_b" , 1 }});
223-
224122 Reference (
225123 [&]() { return true ; }, [&](bool flag) { reference->RunReference (); });
226124}
227125
126+ // Generic
127+ TEST_F (TransposedMatMul, TransposedMatMul) {
128+ Init (FLAGS_M, FLAGS_N, FLAGS_K);
129+ runTransposedMatMul (tc::CudaMappingOptions::makeNaiveMappingOptions ());
130+ }
131+
132+ // P100 TC
133+ TEST_F (TransposedMatMul, TransposedMatMul_P100_autotuned_M_128_N_256_K_32) {
134+ Init (128 , 256 , 32 );
135+ runTransposedMatMul (
136+ tc::options_TransposedMatMul_P100_autotuned_M_128_N_256_K_32);
137+ }
138+
139+ TEST_F (TransposedMatMul, TransposedMatMul_P100_autotuned_M_128_N_1024_K_1024) {
140+ Init (128 , 1024 , 1024 );
141+ runTransposedMatMul (
142+ tc::options_TransposedMatMul_P100_autotuned_M_128_N_1024_K_1024);
143+ }
144+
145+ TEST_F (TransposedMatMul, TransposedMatMul_P100_autotuned_M_128_N_16384_K_4096) {
146+ Init (128 , 16384 , 4096 );
147+ runTransposedMatMul (
148+ tc::options_TransposedMatMul_P100_autotuned_M_128_N_16384_K_4096);
149+ }
150+
151+ // P100 ATen
152+ TEST_F (TransposedMatMul, TransposedMatMul_ATen_P100_M_128_N_256_K_32) {
153+ Init (128 , 256 , 32 );
154+ runATenTransposedMatMul ();
155+ }
156+
157+ TEST_F (TransposedMatMul, TransposedMatMul_ATen_P100_M_128_N_1024_K_1024) {
158+ Init (128 , 1024 , 1024 );
159+ runATenTransposedMatMul ();
160+ }
161+
162+ TEST_F (TransposedMatMul, TransposedMatMul_ATen_P100_M_128_N_16384_K_4096) {
163+ Init (128 , 16384 , 4096 );
164+ runATenTransposedMatMul ();
165+ }
166+
167+ // P100 Caffe2
168+ TEST_F (TransposedMatMul, TransposedMatMul_Caffe2_P100_M_128_N_256_K_32) {
169+ Init (128 , 256 , 32 );
170+ runCaffe2TransposedMatMul ();
171+ }
172+
173+ TEST_F (TransposedMatMul, TransposedMatMul_Caffe2_P100_M_128_N_1024_K_1024) {
174+ Init (128 , 1024 , 1024 );
175+ runCaffe2TransposedMatMul ();
176+ }
177+
178+ TEST_F (TransposedMatMul, TransposedMatMul_Caffe2_P100_M_128_N_16384_K_4096) {
179+ Init (128 , 16384 , 4096 );
180+ runCaffe2TransposedMatMul ();
181+ }
182+
183+ // V100 TC
184+ TEST_F (TransposedMatMul, TransposedMatMul_V100_autotuned_M_128_N_256_K_32) {
185+ Init (128 , 256 , 32 );
186+ runTransposedMatMul (
187+ tc::options_TransposedMatMul_V100_autotuned_M_128_N_256_K_32);
188+ }
189+
190+ TEST_F (TransposedMatMul, TransposedMatMul_V100_autotuned_M_128_N_1024_K_1024) {
191+ Init (128 , 1024 , 1024 );
192+ runTransposedMatMul (
193+ tc::options_TransposedMatMul_V100_autotuned_M_128_N_1024_K_1024);
194+ }
195+
196+ TEST_F (TransposedMatMul, TransposedMatMul_V100_autotuned_M_128_N_16384_K_4096) {
197+ Init (128 , 16384 , 4096 );
198+ runTransposedMatMul (
199+ tc::options_TransposedMatMul_V100_autotuned_M_128_N_16384_K_4096);
200+ }
201+
202+ // V100 ATen
203+ TEST_F (TransposedMatMul, TransposedMatMul_ATen_V100_M_128_N_256_K_32) {
204+ Init (128 , 256 , 32 );
205+ runATenTransposedMatMul ();
206+ }
207+
208+ TEST_F (TransposedMatMul, TransposedMatMul_ATen_V100_M_128_N_1024_K_1024) {
209+ Init (128 , 1024 , 1024 );
210+ runATenTransposedMatMul ();
211+ }
212+
213+ TEST_F (TransposedMatMul, TransposedMatMul_ATen_V100_M_128_N_16384_K_4096) {
214+ Init (128 , 16384 , 4096 );
215+ runATenTransposedMatMul ();
216+ }
217+
218+ // V100 Caffe2
219+ TEST_F (TransposedMatMul, TransposedMatMul_Caffe2_V100_M_128_N_256_K_32) {
220+ Init (128 , 256 , 32 );
221+ runCaffe2TransposedMatMul ();
222+ }
223+
224+ TEST_F (TransposedMatMul, TransposedMatMul_Caffe2_V100_M_128_N_1024_K_1024) {
225+ Init (128 , 1024 , 1024 );
226+ runCaffe2TransposedMatMul ();
227+ }
228+
229+ TEST_F (TransposedMatMul, TransposedMatMul_Caffe2_V100_M_128_N_16384_K_4096) {
230+ Init (128 , 16384 , 4096 );
231+ runCaffe2TransposedMatMul ();
232+ }
233+
228234int main (int argc, char ** argv) {
229235 ::testing::InitGoogleTest (&argc, argv);
230236 ::gflags::ParseCommandLineFlags (&argc, &argv, true );
0 commit comments