Skip to content

Commit 96cb381

Browse files
committed
vx_spawn_threads implementation
1 parent 8c5a783 commit 96cb381

File tree

5 files changed

+208
-152
lines changed

5 files changed

+208
-152
lines changed

kernel/include/vx_spawn.h

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,54 @@
1414
#ifndef __VX_SPAWN_H__
1515
#define __VX_SPAWN_H__
1616

17-
#include <VX_types.h>
17+
#include <vx_intrinsics.h>
1818
#include <stdint.h>
1919

2020
#ifdef __cplusplus
2121
extern "C" {
2222
#endif
2323

24-
typedef void (*vx_spawn_tasks_cb)(int task_id, const void *arg);
25-
26-
typedef void (*vx_spawn_task_groups_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, const void *arg);
24+
// Callback type accepted by vx_spawn_tasks(): it receives an unsigned 32-bit
// task id and an opaque, read-only argument pointer, and returns nothing.
typedef void (*vx_spawn_tasks_cb)(uint32_t task_id, const void *arg);
2725

2826
// Callback type accepted by vx_serial(): it receives only an opaque,
// read-only argument pointer and returns nothing.
typedef void (*vx_serial_cb)(const void *arg);
2927

30-
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, const void * arg);
28+
// Takes a task count, a per-task callback and an opaque argument pointer.
// NOTE(review): presumably invokes `callback` once for each task id in
// [0, num_tasks) and forwards `arg` unchanged; the definition is not part of
// this header - confirm against the implementation.
void vx_spawn_tasks(uint32_t num_tasks, vx_spawn_tasks_cb callback, const void * arg);
3129

32-
void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, const void * arg);
30+
// Takes a vx_serial_cb callback and an opaque argument pointer.
// NOTE(review): `arg` is a non-const `void *` here, while vx_serial_cb
// receives a `const void *`; confirm that dropping the const is intentional
// and that it matches the definition.
void vx_serial(vx_serial_cb callback, void * arg);
3331

34-
inline void* vx_local_malloc(int local_group_id, int size) {
35-
return (int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + local_group_id * size;
36-
}
32+
///////////////////////////////////////////////////////////////////////////////
33+
34+
// Three unsigned 32-bit components that can be accessed either by name
// (x, y, z) through the anonymous struct or by index through m[0..2].
// Both views overlay the same storage because this is a union.
typedef union {
35+
struct {
36+
uint32_t x;
37+
uint32_t y;
38+
uint32_t z;
39+
};
40+
// Indexed view of the same three components.
uint32_t m[3];
41+
} dim3_t;
42+
43+
// Declarations only (extern); the definitions live outside this header.
// blockIdx and threadIdx are declared with __thread, so each thread has its
// own instance. gridDim and blockDim are ordinary globals.
// NOTE(review): the names mirror the CUDA built-ins; presumably they are
// populated by vx_spawn_threads() before the kernel function runs - confirm.
extern __thread dim3_t blockIdx;
44+
extern __thread dim3_t threadIdx;
45+
extern dim3_t gridDim;
46+
extern dim3_t blockDim;
47+
48+
// State read by the __local_mem() and __syncthreads() macros in this header.
// __local_group_id is declared with __thread (one instance per thread); the
// other two are ordinary globals. All three are defined outside this header.
// NOTE(review): identifiers containing a double underscore are reserved for
// the implementation by the C standard - confirm that this is acceptable here.
extern __thread uint32_t __local_group_id;
49+
extern uint32_t __groups_per_core;
50+
extern uint32_t __warps_per_group;
51+
52+
// Kernel entry-point type accepted by vx_spawn_threads(): it receives an
// opaque, read-only argument pointer and returns nothing.
typedef void (*vx_kernel_func_cb)(const void *arg);
53+
54+
// Yields a void* equal to the address read from the VX_CSR_LOCAL_MEM_BASE CSR
// advanced by (__local_group_id * size) bytes.
//
// `size` and the whole expansion are parenthesized so that an argument such as
// `a + b` is scaled as a unit, and so that the result binds correctly when the
// macro is used inside a larger expression.
// NOTE(review): presumably `size` is the per-group byte count and must be the
// same for every group sharing the local memory - confirm with callers.
#define __local_mem(size) \
  ((void*)((int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + __local_group_id * (size)))
56+
57+
// Expands to a vx_barrier() call whose first argument is
// __COUNTER__ * __groups_per_core + __local_group_id and whose second
// argument is __warps_per_group.
// __COUNTER__ is a preprocessor extension that yields a new integer each time
// it is expanded, so each textual use of this macro gets a different multiple
// of __groups_per_core as its base.
// NOTE(review): presumably this keeps separate call sites, and separate groups
// on one core, from sharing a barrier id. __COUNTER__ is also advanced by any
// other use in the same translation unit, so ids are not guaranteed to be
// dense or to stay within the hardware barrier count - confirm.
#define __syncthreads() \
58+
vx_barrier(__COUNTER__ * __groups_per_core + __local_group_id, __warps_per_group)
3759

38-
void vx_serial(vx_serial_cb callback, const void * arg);
60+
// Takes a dimension count, grid and block dimension arrays, a kernel function
// and an opaque argument pointer.
// NOTE(review): the definition is not in this header. Presumably `grid_dim`
// and `block_dim` each point to `dimension` entries (at most 3, matching
// dim3_t), `kernel_func` is run once per thread with `arg` forwarded
// unchanged, and the int result is 0 on success - confirm all of this against
// the implementation.
int vx_spawn_threads(uint32_t dimension,
61+
const uint32_t* grid_dim,
62+
const uint32_t* block_dim,
63+
vx_kernel_func_cb kernel_func,
64+
const void* arg);
3965

4066
#ifdef __cplusplus
4167
}

0 commit comments

Comments
 (0)