Skip to content

Commit db31b64

Browse files
authored
Merge pull request #133 from kibitzing/update_tutorial
Update OpenEnv tutorial
2 parents dd10e47 + 7c262f1 commit db31b64

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

examples/OpenEnv_Tutorial.ipynb

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -691,9 +691,14 @@
691691
"\n",
692692
"```\n",
693693
"⬜ ⬜ 🔴 ⬜ ⬜\n",
694+
"⬜ ⬜ ⬜ ⬜ ⬜\n",
694695
"⬜ ⬜ ⬜ ⬜ ⬜ Ball\n",
696+
"⬜ ⬜ ⬜ ⬜ ⬜\n",
695697
"⬜ ⬜ ⬜ ⬜ ⬜ falls\n",
698+
"⬜ ⬜ ⬜ ⬜ ⬜\n",
696699
"⬜ ⬜ ⬜ ⬜ ⬜ down\n",
700+
"⬜ ⬜ ⬜ ⬜ ⬜\n",
701+
"⬜ ⬜ ⬜ ⬜ ⬜\n",
697702
"⬜ ⬜ 🏓 ⬜ ⬜\n",
698703
" Paddle\n",
699704
"```\n",
@@ -702,7 +707,7 @@
702707
"<td width=\"60%\">\n",
703708
"\n",
704709
"**Rules:**\n",
705-
"- 5×5 grid\n",
710+
"- 10×5 grid\n",
706711
"- Ball falls from random column\n",
707712
"- Move paddle left/right to catch it\n",
708713
"\n",
@@ -817,8 +822,8 @@
817822
" \"OPENSPIEL_GAME\": \"catch\",\n",
818823
" \"OPENSPIEL_AGENT_PLAYER\": \"0\",\n",
819824
" \"OPENSPIEL_OPPONENT_POLICY\": \"random\"},\n",
820-
" stdout=subprocess.PIPE,\n",
821-
" stderr=subprocess.PIPE,\n",
825+
" stdout=subprocess.DEVNULL,\n",
826+
" stderr=subprocess.DEVNULL,\n",
822827
" text=True,\n",
823828
" cwd=work_dir\n",
824829
")\n",
@@ -895,6 +900,7 @@
895900
"\n",
896901
"print(\"📥 Received OpenSpielObservation:\")\n",
897902
"print(f\" • info_state: {result.observation.info_state[:10]}... (first 10 values)\")\n",
903+
"print(f\" • number of info_state: {len(result.observation.info_state)}\")\n",
898904
"print(f\" • legal_actions: {result.observation.legal_actions}\")\n",
899905
"print(f\" • game_phase: {result.observation.game_phase}\")\n",
900906
"print(f\" • done: {result.done}\")\n",
@@ -1006,23 +1012,25 @@
10061012
"\n",
10071013
" def select_action(self, obs: OpenSpielObservation) -> int:\n",
10081014
" # Parse OpenSpiel observation\n",
1009-
" # For Catch: info_state is a flattened 5x5 grid\n",
1015+
" # For Catch: info_state is a flattened 10x5 grid\n",
10101016
" # Ball position and paddle position encoded in the vector\n",
10111017
" info_state = obs.info_state\n",
10121018
"\n",
10131019
" # Find ball and paddle positions from info_state\n",
1014-
" # Catch uses a 5x5 grid, so 25 values\n",
1020+
" # Catch uses a 10x5 grid, so 50 values\n",
10151021
" grid_size = 5\n",
10161022
"\n",
1017-
" # Find positions (ball = 1.0, paddle = 0.5 in the flattened grid)\n",
1023+
" # Find positions (ball = 1.0 in the flattened grid, paddle = 1.0 in the last row of the flattened grid)\n",
10181024
" ball_col = None\n",
10191025
" paddle_col = None\n",
10201026
"\n",
10211027
" for idx, val in enumerate(info_state):\n",
10221028
" if abs(val - 1.0) < 0.01: # Ball\n",
10231029
" ball_col = idx % grid_size\n",
1024-
" elif abs(val - 0.5) < 0.01: # Paddle\n",
1025-
" paddle_col = idx % grid_size\n",
1030+
" break\n",
1031+
"\n",
1032+
" last_row = info_state[-grid_size:]\n",
1033+
" paddle_col = last_row.index(1.0) # Paddle\n",
10261034
"\n",
10271035
" if ball_col is not None and paddle_col is not None:\n",
10281036
" if paddle_col < ball_col:\n",

0 commit comments

Comments
 (0)