Skip to content

Commit 627046f

Browse files
committed
Calculate load average of AI algorithm
Introduce the calculation of load average for mcts algorithm, taking the active nodes in the current mcts exeuction context as calculation units. The function "mcts_calc_load()" takes reference from the method of counting cpu load average in linux kernel, the concept is EWMA. Utilize EWMA to calculate the load average of MCTS algorithm so we can know how much recources the algorithm context is consuming.
1 parent c2cf2f6 commit 627046f

File tree

4 files changed

+48
-9
lines changed

4 files changed

+48
-9
lines changed

kmldrv-user.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,16 @@ static void listen_keyboard_handler(void)
6565
case 16:
6666
read(attr_fd, buf, 6);
6767
buf[0] = (buf[0] - '0') ? '0' : '1';
68-
read_attr = !read_attr;
68+
read_attr ^= 1;
6969
write(attr_fd, buf, 6);
70+
printf("Stopping to display the chess board...\n");
7071
break;
7172
case 17:
7273
read(attr_fd, buf, 6);
7374
buf[4] = '1';
75+
read_attr = false;
7476
write(attr_fd, buf, 6);
77+
printf("Stopping the kernel space tic-tac-toe game...\n");
7578
break;
7679
}
7780
}

mcts.c

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
#include <linux/sched/loadavg.h>
12
#include <linux/slab.h>
23
#include <linux/string.h>
34

45
#include "game.h"
56
#include "mcts.h"
67
#include "util.h"
78
// #include "wyhash.h"
8-
#include "xoroshiro.h"
99

1010
struct node {
1111
int move;
@@ -16,7 +16,7 @@ struct node {
1616
struct node *children[N_GRIDS];
1717
};
1818

19-
static struct state_array xoro_obj;
19+
static struct mcts_info mcts_obj;
2020

2121
static struct node *new_node(int move, char player, struct node *parent)
2222
{
@@ -125,7 +125,7 @@ static fixed_point_t simulate(char *table, char player)
125125
char current_player = player;
126126
char temp_table[N_GRIDS];
127127
memcpy(temp_table, table, N_GRIDS);
128-
xoro_jump(&xoro_obj);
128+
xoro_jump(&(mcts_obj.xoro_obj));
129129
while (1) {
130130
int *moves = available_moves(temp_table);
131131
if (moves[0] == -1) {
@@ -136,7 +136,7 @@ static fixed_point_t simulate(char *table, char player)
136136
while (n_moves < N_GRIDS && moves[n_moves] != -1)
137137
++n_moves;
138138
// int move = moves[wyhash64() % n_moves];
139-
int move = moves[xoro_next(&xoro_obj) % n_moves];
139+
int move = moves[xoro_next(&(mcts_obj.xoro_obj)) % n_moves];
140140
kfree(moves);
141141
temp_table[move] = current_player;
142142
char win;
@@ -157,7 +157,7 @@ static void backpropagate(struct node *node, fixed_point_t score)
157157
}
158158
}
159159

160-
static void expand(struct node *node, char *table)
160+
static int expand(struct node *node, char *table)
161161
{
162162
int *moves = available_moves(table);
163163
int n_moves = 0;
@@ -167,12 +167,14 @@ static void expand(struct node *node, char *table)
167167
node->children[i] = new_node(moves[i], node->player ^ 'O' ^ 'X', node);
168168
}
169169
kfree(moves);
170+
return n_moves;
170171
}
171172

172173
int mcts(char *table, char player)
173174
{
174175
char win;
175176
struct node *root = new_node(-1, player, NULL);
177+
mcts_obj.nr_active_nodes = 1;
176178
for (int i = 0; i < ITERATIONS; i++) {
177179
struct node *node = root;
178180
char temp_table[N_GRIDS];
@@ -190,7 +192,7 @@ int mcts(char *table, char player)
190192
break;
191193
}
192194
if (node->children[0] == NULL)
193-
expand(node, temp_table);
195+
mcts_obj.nr_active_nodes += expand(node, temp_table);
194196
node = select_move(node);
195197
if (!node)
196198
return -1;
@@ -210,7 +212,13 @@ int mcts(char *table, char player)
210212
return best_move;
211213
}
212214

215+
unsigned long count_active_nodes(void)
216+
{
217+
return mcts_obj.nr_active_nodes;
218+
}
219+
213220
void mcts_init(void)
214221
{
215-
xoro_init(&xoro_obj);
222+
xoro_init(&(mcts_obj.xoro_obj));
223+
mcts_obj.nr_active_nodes = 0;
216224
}

mcts.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
#pragma once
22

3+
#include "xoroshiro.h"
4+
35
#define ITERATIONS 100000
46
#define EXPLORATION_FACTOR fixed_sqrt(1U << (FIXED_SCALE_BITS + 1))
57

8+
struct mcts_info {
9+
struct state_array xoro_obj;
10+
int nr_active_nodes;
11+
};
12+
13+
unsigned long count_active_nodes(void);
614
int mcts(char *table, char player);
715
void mcts_init(void);

simrupt.c

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <linux/interrupt.h>
88
#include <linux/kfifo.h>
99
#include <linux/module.h>
10+
#include <linux/sched/loadavg.h>
1011
#include <linux/slab.h>
1112
#include <linux/sysfs.h>
1213
#include <linux/version.h>
@@ -117,7 +118,7 @@ static DEFINE_MUTEX(consumer_lock);
117118
*/
118119
static struct circ_buf fast_buf;
119120

120-
121+
static unsigned long mcts_avennode[3];
121122
static char table[N_GRIDS];
122123

123124
/* Draw the board into draw_buffer */
@@ -191,6 +192,23 @@ static void drawboard_work_func(struct work_struct *w)
191192
wake_up_interruptible(&rx_wait);
192193
}
193194

195+
static void mcts_calc_load(struct work_struct *w)
196+
{
197+
unsigned long active_nodes;
198+
199+
active_nodes = count_active_nodes() * FIXED_1;
200+
mcts_avennode[0] = calc_load(mcts_avennode[0], EXP_1, active_nodes);
201+
mcts_avennode[1] = calc_load(mcts_avennode[1], EXP_5, active_nodes);
202+
mcts_avennode[2] = calc_load(mcts_avennode[2], EXP_15, active_nodes);
203+
204+
int a = mcts_avennode[0] + (FIXED_1 / 200);
205+
int b = mcts_avennode[1] + (FIXED_1 / 200);
206+
int c = mcts_avennode[2] + (FIXED_1 / 200);
207+
208+
pr_info("kmldrv: [MCTS LoadAvg] %d.%02d %d.%02d %d.%02d\n", LOAD_INT(a),
209+
LOAD_FRAC(a), LOAD_INT(b), LOAD_FRAC(b), LOAD_INT(c), LOAD_FRAC(c));
210+
}
211+
194212
static char turn;
195213
static int finish;
196214

@@ -271,6 +289,7 @@ static struct workqueue_struct *kmldrv_workqueue;
271289
static DECLARE_WORK(drawboard_work, drawboard_work_func);
272290
static DECLARE_WORK(ai_one_work, ai_one_work_func);
273291
static DECLARE_WORK(ai_two_work, ai_two_work_func);
292+
static DECLARE_WORK(mcts_calc_load_work, mcts_calc_load);
274293

275294
/* Tasklet handler.
276295
*
@@ -301,6 +320,7 @@ static void game_tasklet_func(unsigned long __data)
301320
smp_wmb();
302321
queue_work(kmldrv_workqueue, &ai_two_work);
303322
}
323+
queue_work(kmldrv_workqueue, &mcts_calc_load_work);
304324
queue_work(kmldrv_workqueue, &drawboard_work);
305325
tv_end = ktime_get();
306326

0 commit comments

Comments
 (0)