|
14 | 14 | #ifndef __VX_SPAWN_H__ |
15 | 15 | #define __VX_SPAWN_H__ |
16 | 16 |
|
17 | | -#include <VX_types.h> |
| 17 | +#include <vx_intrinsics.h> |
18 | 18 | #include <stdint.h> |
19 | 19 |
|
20 | 20 | #ifdef __cplusplus |
21 | 21 | extern "C" { |
22 | 22 | #endif |
23 | 23 |
|
24 | | -typedef void (*vx_spawn_tasks_cb)(int task_id, const void *arg); |
25 | | - |
26 | | -typedef void (*vx_spawn_task_groups_cb)(int local_task_id, int group_id, int local_group_id, int warps_per_group, const void *arg); |
| 24 | +typedef void (*vx_spawn_tasks_cb)(uint32_t task_id, const void *arg); // per-task callback: receives an unsigned 32-bit task id and an opaque read-only argument pointer; returns nothing
27 | 25 |
|
28 | 26 | typedef void (*vx_serial_cb)(const void *arg); // callback type accepted by vx_serial: takes one opaque read-only argument pointer; returns nothing
29 | 27 |
|
30 | | -void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, const void * arg); |
| 28 | +void vx_spawn_tasks(uint32_t num_tasks, vx_spawn_tasks_cb callback, const void * arg); // declaration only (definition not in this header); presumably runs callback(task_id, arg) for each task id below num_tasks - TODO confirm against the implementation
31 | 29 |
|
32 | | -void vx_spawn_task_groups(int num_groups, int group_size, vx_spawn_task_groups_cb callback, const void * arg); |
| 30 | +void vx_serial(vx_serial_cb callback, void * arg); // declaration only; NOTE(review): arg is a non-const void* here although vx_serial_cb receives const void* - confirm the missing const is intentional
33 | 31 |
|
34 | | -inline void* vx_local_malloc(int local_group_id, int size) { |
35 | | - return (int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + local_group_id * size; |
36 | | -} |
| 32 | +/////////////////////////////////////////////////////////////////////////////// |
| 33 | + |
| 34 | +typedef union { // three uint32_t components, reachable either by name (x/y/z) or by index (m[i])
| 35 | + struct { // anonymous struct: its members are accessed directly on the union (C11 feature)
| 36 | + uint32_t x;
| 37 | + uint32_t y;
| 38 | + uint32_t z;
| 39 | + };
| 40 | + uint32_t m[3]; // array view sharing storage with the struct above: m[0]/m[1]/m[2] overlay x/y/z
| 41 | +} dim3_t;
| 42 | + |
| 43 | +extern __thread dim3_t blockIdx; // thread-local (__thread): every thread has its own copy; defined outside this header
| 44 | +extern __thread dim3_t threadIdx; // thread-local as well; presumably filled in by the runtime before a kernel callback runs - TODO confirm
| 45 | +extern dim3_t gridDim; // not thread-local: a single shared object, defined outside this header
| 46 | +extern dim3_t blockDim; // not thread-local: a single shared object, defined outside this header
| 47 | + |
| 48 | +extern __thread uint32_t __local_group_id; // thread-local group index; read by the __local_mem and __syncthreads macros in this header
| 49 | +extern uint32_t __groups_per_core; // shared global; used as the multiplier for the barrier id in __syncthreads
| 50 | +extern uint32_t __warps_per_group; // shared global; passed as the second argument to vx_barrier by __syncthreads
| 51 | + |
| 52 | +typedef void (*vx_kernel_func_cb)(const void *arg); // kernel entry-point type passed to vx_spawn_threads: takes one opaque read-only argument pointer; returns nothing
| 53 | + |
| 54 | +#define __local_mem(size) \
| 55 | + ((void*)((int8_t*)csr_read(VX_CSR_LOCAL_MEM_BASE) + __local_group_id * (size))) /* byte address = value read from CSR VX_CSR_LOCAL_MEM_BASE + __local_group_id * size; size is parenthesized so an argument such as (a + b) is multiplied as a whole, and the expansion is wrapped so it behaves as a single operand */
| 56 | + |
| 57 | +#define __syncthreads() \
| 58 | + vx_barrier(__COUNTER__ * __groups_per_core + __local_group_id, __warps_per_group) /* calls vx_barrier with id = __COUNTER__ * __groups_per_core + __local_group_id and count = __warps_per_group; __COUNTER__ is a compiler extension that yields a new integer on every expansion, so each textual use of this macro gets a different id. NOTE(review): any other use of __COUNTER__ in the same translation unit also advances the value - confirm the resulting ids stay within what vx_barrier accepts */
37 | 59 |
|
38 | | -void vx_serial(vx_serial_cb callback, const void * arg); |
| 60 | +int vx_spawn_threads(uint32_t dimension, // declaration only; presumably the number of entries used in grid_dim/block_dim - TODO confirm
| 61 | + const uint32_t* grid_dim, // read-only array of grid extents
| 62 | + const uint32_t* block_dim, // read-only array of block extents
| 63 | + vx_kernel_func_cb kernel_func, // kernel callback, invoked with arg
| 64 | + const void* arg); // opaque read-only pointer forwarded to kernel_func; meaning of the int return value is not visible in this header
39 | 65 |
|
40 | 66 | #ifdef __cplusplus |
41 | 67 | } |
|
0 commit comments