
Commit 09ad324

add inference example with hf and gpt oss
1 parent 3b72bce commit 09ad324

File tree: examples/wordle_inference.py

1 file changed: +174 −0 lines

@@ -0,0 +1,174 @@
#!/usr/bin/env python3
"""Play TextArena Wordle with a hosted LLM via Hugging Face Inference Providers.

This script mirrors the structure of the Kuhn Poker inference sample but targets
the Wordle environment. We deploy the generic TextArena server (wrapped in
OpenEnv) inside a local Docker container and query a single hosted model using
the OpenAI-compatible API provided by Hugging Face's router.

Prerequisites
-------------
1. Build the TextArena Docker image::

       docker build -f src/envs/textarena_env/server/Dockerfile -t textarena-env:latest .

2. Set your Hugging Face token::

       export HF_TOKEN=your_token_here

3. Run this script::

       python examples/wordle_inference.py

By default we ask the ``openai/gpt-oss-120b`` model to play ``Wordle-v0``. Adjust
the ``MODEL`` constant if you'd like to experiment with another provider-compatible
model.
"""

from __future__ import annotations

import os
import re
from typing import Iterable, List

from openai import OpenAI

from envs.textarena_env import TextArenaAction, TextArenaEnv
from envs.textarena_env.models import TextArenaMessage

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

API_BASE_URL = "https://router.huggingface.co/v1"
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")

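# The "<model>:<provider>" suffix below pins the request to a specific Inference
# Provider (here Novita) behind the Hugging Face router; any other
# provider-compatible model id can be swapped in, as noted in the docstring.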
MODEL = "openai/gpt-oss-120b:novita"
MAX_TURNS = 8
VERBOSE = True

SYSTEM_PROMPT = (
    "You are an expert Wordle solver."
    " Always respond with a single guess inside square brackets, e.g. [crane]."
    " Use lowercase letters, exactly one five-letter word per reply."
    " Reason about prior feedback before choosing the next guess."
    " Words must be 5 letters long and real English words."
    " Do not include any other text in your response."
    " Do not repeat the same guess twice."
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def format_history(messages: Iterable[TextArenaMessage]) -> str:
    """Convert TextArena message history into plain text for the model."""

    lines: List[str] = []
    for message in messages:
        tag = message.category or "MESSAGE"
        lines.append(f"[{tag}] {message.content}")
    return "\n".join(lines)


def extract_guess(text: str) -> str:
    """Return the first Wordle-style guess enclosed in square brackets."""

    match = re.search(r"\[[A-Za-z]{5}\]", text)
    if match:
        return match.group(0).lower()
    # Fallback: strip non-letter characters, lowercase, and wrap the first five letters
    cleaned = re.sub(r"[^a-zA-Z]", "", text).lower()
    if len(cleaned) >= 5:
        return f"[{cleaned[:5]}]"
    return "[dunno]"


def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str:
    """Combine the TextArena prompt and feedback history for the model."""

    history = format_history(messages)
    return (
        f"Current prompt:\n{prompt_text}\n\n"
        f"Conversation so far:\n{history}\n\n"
        "Reply with your next guess enclosed in square brackets."
    )


# ---------------------------------------------------------------------------
# Gameplay
# ---------------------------------------------------------------------------

def play_wordle(env: TextArenaEnv, client: OpenAI) -> None:
    result = env.reset()
    observation = result.observation

    if VERBOSE:
        print("📜 Initial Prompt:\n" + observation.prompt)

    for turn in range(1, MAX_TURNS + 1):
        if result.done:
            break

        user_prompt = make_user_prompt(observation.prompt, observation.messages)

        response = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=2048,
            temperature=0.7,
        )

        raw_output = response.choices[0].message.content.strip()
        guess = extract_guess(raw_output)

        if VERBOSE:
            print(f"\n🎯 Turn {turn}: model replied with -> {raw_output}")
            print(f" Parsed guess: {guess}")

        result = env.step(TextArenaAction(message=guess))
        observation = result.observation

        if VERBOSE:
            print(" Feedback messages:")
            for message in observation.messages:
                print(f" [{message.category}] {message.content}")

    print("\n✅ Game finished")
    print(f" Reward: {result.reward}")
    print(f" Done: {result.done}")


# ---------------------------------------------------------------------------
# Entrypoint
# ---------------------------------------------------------------------------

def main() -> None:
    if not API_KEY:
        raise SystemExit("HF_TOKEN (or API_KEY) must be set to query the model.")

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    env = TextArenaEnv.from_docker_image(
        "textarena-env:latest",
        env_vars={
            "TEXTARENA_ENV_ID": "Wordle-v0",
            "TEXTARENA_NUM_PLAYERS": "1",
        },
        ports={8000: 8000},
    )

    try:
        play_wordle(env, client)
    finally:
        env.close()


if __name__ == "__main__":
    main()
