Skip to content

Commit 8efa6db

Browse files
committed
Add new kernel
1 parent 2606451 commit 8efa6db

File tree

4 files changed

+475
-0
lines changed

4 files changed

+475
-0
lines changed

tests/opencl/kernel4/Makefile

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Build description for the kernel4 OpenCL test.
# Repository root is three directories above this one; config.mk is loaded from it.
ROOT_DIR := $(realpath ../../..)
include $(ROOT_DIR)/config.mk

PROJECT := kernel4

# NOTE(review): VORTEX_HOME is not set in this file; presumably config.mk provides it.
SRC_DIR := $(VORTEX_HOME)/tests/opencl/$(PROJECT)

SRCS := $(SRC_DIR)/main.cc

# Files copied from SRC_DIR into the current directory by the rule below.
KERNEL_SRCS := kernel.cl common.h

# One static pattern rule covers every entry of KERNEL_SRCS:
# <name> is refreshed from $(SRC_DIR)/<name> with a plain copy.
$(KERNEL_SRCS): %: $(SRC_DIR)/%
	cp $< $@

# Left empty unless the caller or the environment supplies a value.
OPTS ?=

# NOTE(review): PROJECT, SRCS, KERNEL_SRCS and OPTS are presumably consumed by ../common.mk.
include ../common.mk

tests/opencl/kernel4/common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// Constants for the kernel4 test, protected against multiple inclusion.
#if !defined(COMMON_H)
#define COMMON_H

// NOTE(review): kernel.cl sets TS to 32 and WIDTH to 4 on its own and has no
// #include of this header, so the values below differ from the ones the kernel
// is compiled with by default - confirm which pair is intended.
#define TS 16
#define WIDTH 8

#endif // COMMON_H

tests/opencl/kernel4/kernel.cl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// Tile edge length, in scalar floats. 32 is the default; a definition supplied
// before this point (for example a -DTS=<n> build option) takes precedence.
#ifndef TS
#define TS 32
#endif

// Number of floats packed into one floatX element. 4 is the default; it can be
// overridden the same way as TS.
#ifndef WIDTH
#define WIDTH 4
#endif

// The local tiles are declared as [TS][TS/WIDTH], so TS has to divide evenly.
#if (TS % WIDTH) != 0
#error "TS must be a multiple of WIDTH"
#endif

// floatX is the WIDTH-wide float type used for every load, store and accumulator.
#if WIDTH == 1
    typedef float floatX;
#elif WIDTH == 2
    typedef float2 floatX;
#elif WIDTH == 4
    typedef float4 floatX;
#elif WIDTH == 8
    typedef float8 floatX;
#else
// Without this, an unsupported WIDTH would leave floatX undeclared and fail
// later with a much less helpful message.
#error "Unsupported WIDTH: expected 1, 2, 4 or 8"
#endif
13+
14+
/*
 * myGEMM4: tiled matrix multiply that moves data as WIDTH-wide floatX vectors.
 *
 * Each work-item accumulates one floatX of C. For every tile, each work-item
 * copies one floatX of A and one floatX of B into local memory, the work-group
 * synchronises, and the work-item then multiplies each of the TS scalars in
 * its own Bsub row with the matching floatX taken from Asub.
 *
 * Parameters:
 *   M, K - sizes used for addressing: M/WIDTH is the stride (in floatX units)
 *          applied to A and C, K/WIDTH the stride applied to B; K/TS is the
 *          number of tiles processed.
 *   N    - part of the interface but never read by the kernel body.
 *   A, B - inputs, read as floatX.
 *   C    - output, written as floatX (one element per work-item).
 *
 * K/TS, M/WIDTH and K/WIDTH are integer divisions, so remainders are dropped.
 * The local tiles are fully populated only when the local size is
 * (TS/WIDTH, TS). There is no bounds check on any global access.
 * NOTE(review): presumably the host launches with sizes that are multiples of
 * TS - confirm against the caller.
 */
__kernel void myGEMM4(const int M, const int N, const int K,
                      const __global floatX* A,
                      const __global floatX* B,
                      __global floatX* C) {

    // Work-item coordinates
    const int row = get_local_id(0);                        // Local row ID (max: TS/WIDTH)
    const int col = get_local_id(1);                        // Local col ID (max: TS)
    const int globalRow = (TS/WIDTH)*get_group_id(0) + row; // Row ID of C (0..M/WIDTH)
    const int globalCol = TS*get_group_id(1) + col;         // Col ID of C (0..N)

    // Local memory holding one TS*TS-float tile of A and of B
    __local floatX Asub[TS][TS/WIDTH];
    __local floatX Bsub[TS][TS/WIDTH];

    // Accumulator: every component starts at zero, whatever WIDTH is
    floatX acc = (floatX)(0.0f);

    // Walk over the tiles along K
    const int numTiles = K/TS;
    for (int tile = 0; tile < numTiles; tile++) {

        // Each work-item loads one floatX of A and one floatX of B
        const int tiledRow = (TS/WIDTH)*tile + row;
        const int tiledCol = TS*tile + col;
        Asub[col][row] = A[tiledCol*(M/WIDTH) + globalRow];
        Bsub[col][row] = B[globalCol*(K/WIDTH) + tiledRow];

        // The whole tile must be in local memory before anyone reads it
        barrier(CLK_LOCAL_MEM_FENCE);

        // Multiply-accumulate over the tile
        for (int k = 0; k < TS/WIDTH; k++) {
            const floatX vecB = Bsub[col][k];
            for (int w = 0; w < WIDTH; w++) {
                const floatX vecA = Asub[WIDTH*k + w][row];

                // Pick component w of vecB. Components cannot be selected with
                // a run-time index via the .x/.sN syntax, hence the switch.
                // valB is pre-set so it is never read uninitialised.
                float valB = 0.0f;
#if WIDTH == 1
                valB = vecB;
#elif WIDTH == 2
                switch (w) {
                    case 0: valB = vecB.x; break;
                    case 1: valB = vecB.y; break;
                    default: break;
                }
#elif WIDTH == 4
                switch (w) {
                    case 0: valB = vecB.x; break;
                    case 1: valB = vecB.y; break;
                    case 2: valB = vecB.z; break;
                    case 3: valB = vecB.w; break;
                    default: break;
                }
#elif WIDTH == 8
                switch (w) {
                    case 0: valB = vecB.s0; break;
                    case 1: valB = vecB.s1; break;
                    case 2: valB = vecB.s2; break;
                    case 3: valB = vecB.s3; break;
                    case 4: valB = vecB.s4; break;
                    case 5: valB = vecB.s5; break;
                    case 6: valB = vecB.s6; break;
                    case 7: valB = vecB.s7; break;
                    default: break;
                }
#endif
                // Vector-by-scalar multiply: every component of vecA is scaled
                // by valB and added to the matching component of acc.
                acc += vecA * valB;
            }
        }

        // Nobody may overwrite the tile while others are still reading it
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write this work-item's floatX of C
    C[globalCol*(M/WIDTH) + globalRow] = acc;
}
111+

0 commit comments

Comments
 (0)