@@ -182,8 +182,8 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
182182 R"RES( int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
183183 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
184184 float32 (*C)[M] = reinterpret_cast<float32 (*)[M]>(pC);
185- float32 (*A)[M] = reinterpret_cast<float32 (*)[M]>(pA);
186- float32 (*B)[M] = reinterpret_cast<float32 (*)[M]>(pB);
185+ const float32 (*A)[M] = reinterpret_cast<const float32 (*)[M]>(pA);
186+ const float32 (*B)[M] = reinterpret_cast<const float32 (*)[M]>(pB);
187187 for (int c1 = 16 * b1; c1 < M; c1 += 4096) {
188188 if (M >= t1 + c1 + 1) {
189189 C[(t0 + 16 * b0)][(t1 + c1)] = (A[(t0 + 16 * b0)][(t1 + c1)] + B[(t0 + 16 * b0)][(t1 + c1)]);
@@ -219,10 +219,10 @@ def fun(float(N, N, N, N) A, float(N, N) B, float(N, N) C, float(N, N) D)
219219 float32 (*O1)[N] = reinterpret_cast<float32 (*)[N]>(pO1);
220220 float32 (*O2)[N] = reinterpret_cast<float32 (*)[N]>(pO2);
221221 float32 (*O3)[N] = reinterpret_cast<float32 (*)[N]>(pO3);
222- float32 (*A)[N][N][N] = reinterpret_cast<float32 (*)[N][N][N]>(pA);
223- float32 (*B)[N] = reinterpret_cast<float32 (*)[N]>(pB);
224- float32 (*C)[N] = reinterpret_cast<float32 (*)[N]>(pC);
225- float32 (*D)[N] = reinterpret_cast<float32 (*)[N]>(pD);
222+ const float32 (*A)[N][N][N] = reinterpret_cast<const float32 (*)[N][N][N]>(pA);
223+ const float32 (*B)[N] = reinterpret_cast<const float32 (*)[N]>(pB);
224+ const float32 (*C)[N] = reinterpret_cast<const float32 (*)[N]>(pC);
225+ const float32 (*D)[N] = reinterpret_cast<const float32 (*)[N]>(pD);
226226 for (int c0 = 0; c0 < N; c0 += 1) {
227227 for (int c1 = 0; c1 < N; c1 += 1) {
228228 O1[c0][c1] = 0.000000f;
@@ -261,11 +261,11 @@ def fun(float(N, N) A) -> (O)
261261 auto res = std::get<0 >(mscop->codegen (specializedName));
262262
263263 string expected (
264- R"RES( __global__ void kernel_anon(int32 N, float32* pO, float32* pA) {
264+ R"RES( __global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {
265265 int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
266266 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
267267 float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
268- float32 (*A)[N] = reinterpret_cast<float32 (*)[N]>(pA);
268+ const float32 (*A)[N] = reinterpret_cast<const float32 (*)[N]>(pA);
269269 for (int c0 = 0; c0 < N; c0 += 1) {
270270 for (int c1 = 0; c1 < N; c1 += 1) {
271271 O[c0][c1] = (((A[c0][c1] + float32(c0)) + float32(c1)) + float32(N));
@@ -290,13 +290,13 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
290290 auto res = std::get<0 >(mscop->codegen (specializedName));
291291
292292 string expected =
293- R"RES( __global__ void kernel_anon(int32 N, float32* pO, float32* pA, float32* pB, float32* pC) {
293+ R"RES( __global__ void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
294294 int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
295295 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
296296 float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);
297- float32 (*A)[512] = reinterpret_cast<float32 (*)[512]>(pA);
298- float32 (*B)[512] = reinterpret_cast<float32 (*)[512]>(pB);
299- float32 (*C) = reinterpret_cast<float32 (*)>(pC);
297+ const float32 (*A)[512] = reinterpret_cast<const float32 (*)[512]>(pA);
298+ const float32 (*B)[512] = reinterpret_cast<const float32 (*)[512]>(pB);
299+ const float32 (*C) = reinterpret_cast<const float32 (*)>(pC);
300300 for (int c0 = 0; c0 <= 511; c0 += 1) {
301301 for (int c1 = 0; c1 <= 511; c1 += 1) {
302302 O[c0][c1] = (nextafter(C[c0], exp(A[c0][c1])) + log(B[c1][c0]));
@@ -312,8 +312,8 @@ constexpr auto kExpectedMatmul_64_64_64 =
312312 R"CUDA( int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
313313 int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
314314 float32 (*O)[64] = reinterpret_cast<float32 (*)[64]>(pO);
315- float32 (*A)[64] = reinterpret_cast<float32 (*)[64]>(pA);
316- float32 (*B)[64] = reinterpret_cast<float32 (*)[64]>(pB);
315+ const float32 (*A)[64] = reinterpret_cast<const float32 (*)[64]>(pA);
316+ const float32 (*B)[64] = reinterpret_cast<const float32 (*)[64]>(pB);
317317 for (int c0 = 0; c0 <= 63; c0 += 16) {
318318 for (int c1 = 0; c1 <= 63; c1 += 16) {
319319 for (int c2 = t1; c2 <= 15; c2 += 8) {
0 commit comments