Skip to content

Commit 76fe36b

Browse files
feat: able to run test
1 parent acb34ee commit 76fe36b

File tree

2 files changed

+131
-27
lines changed

2 files changed

+131
-27
lines changed

src/commands/test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,17 @@
66
auth_manager = Auth()
77
solution_manager = SolutionManager(auth_manager.get_session())
88

9+
map_lang = {
10+
"py": "python3",
11+
"java": "java",
12+
"js": "javascript",
13+
"c": "c",
14+
"cpp": "cpp"
15+
}
16+
917
def test(
1018
problem: str = typer.Argument(..., help="Problem slug (e.g., 'two-sum')"),
1119
file: Path = typer.Argument(..., help="Path to solution file"),
12-
lang: str = typer.Option("python3", help="Programming language")
1320
):
1421
"""Test a solution with LeetCode's test cases"""
1522
if not auth_manager.is_authenticated:
@@ -23,6 +30,11 @@ def test(
2330
with open(file, 'r') as f:
2431
code = f.read()
2532

33+
lang = map_lang.get(file.suffix[1:])
34+
if not lang:
35+
typer.echo(typer.style(f"❌ Unsupported file extension: {file.suffix}", fg=typer.colors.RED))
36+
raise typer.Exit(1)
37+
2638
typer.echo(typer.style("🧪 Testing solution with LeetCode test cases...", fg=typer.colors.YELLOW))
2739
result = solution_manager.test_solution(problem, code, lang)
2840

src/server/solution_manager.py

Lines changed: 118 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,38 @@
11
from typing import Dict, Any
22
import time
3+
import json
4+
import typer
35

46
class SolutionManager:
57
def __init__(self, session):
68
self.session = session
79
self.BASE_URL = "https://leetcode.com"
10+
self._clean_session_cookies()
11+
12+
def _clean_session_cookies(self):
13+
"""Clean up duplicate cookies"""
14+
seen_cookies = {}
15+
if not hasattr(self.session, 'cookies'):
16+
return
17+
18+
# Get all cookies
19+
all_cookies = list(self.session.cookies)
20+
21+
# Clear all cookies
22+
self.session.cookies.clear()
23+
24+
# Add back only the most recent cookie for each name
25+
for cookie in reversed(all_cookies):
26+
if cookie.name not in seen_cookies:
27+
seen_cookies[cookie.name] = True
28+
self.session.cookies.set_cookie(cookie)
29+
30+
def _get_csrf_token(self):
31+
"""Get CSRF token from cookies"""
32+
for cookie in self.session.cookies:
33+
if cookie.name == 'csrftoken':
34+
return cookie.value
35+
return "iIf1V3zVGiU5F0neqw63pFm0YlFtk9i531xUBoQe0hZ06pmDGPZ0uJW6vyhr8GEH"
836

937
def get_question_data(self, question_identifier: str) -> Dict[str, Any]:
1038
"""Get question details using GraphQL
@@ -55,6 +83,7 @@ def get_question_data(self, question_identifier: str) -> Dict[str, Any]:
5583

5684
def submit_solution(self, title_slug: str, code: str, lang: str = "python3") -> Dict[str, Any]:
5785
"""Submit a solution to LeetCode"""
86+
5887
# First get the question ID
5988
question_data = self.get_question_data(title_slug)
6089
question_id = question_data['data']['question']['questionId']
@@ -77,49 +106,112 @@ def submit_solution(self, title_slug: str, code: str, lang: str = "python3") ->
77106

78107
return self.get_submission_result(submission_id)
79108

80-
def test_solution(self, title_slug: str, code: str, lang: str = "python3") -> Dict[str, Any]:
109+
def test_solution(self, title_slug: str, code: str, lang: str = "python3", full: bool = False) -> Dict[str, Any]:
81110
"""Test a solution with LeetCode test cases"""
82111
try:
83-
# First get the question data
84-
question_data = self.get_question_data(title_slug)
85-
question = question_data['data']['question']
86-
question_id = question['questionId']
87-
test_cases = question['exampleTestcases']
112+
self._clean_session_cookies()
113+
114+
# Get question data first
115+
problem = self.get_question_data(title_slug)
116+
question_data = problem.get('data', {}).get('question')
117+
if not question_data:
118+
return {"success": False, "error": "Question data not found"}
119+
120+
question_id = question_data['questionId']
121+
test_cases = question_data['exampleTestcaseList']
122+
123+
endpoint = 'submit' if full else 'interpret_solution'
124+
sid_key = 'submission_id' if full else 'interpret_id'
88125

89-
test_url = f"{self.BASE_URL}/problems/{title_slug}/interpret_solution/"
126+
url = f"{self.BASE_URL}/problems/{title_slug}/{endpoint}/"
127+
128+
csrf_token = self.session.cookies.get('csrftoken', '')
129+
typer.echo(f"Using CSRF token: {csrf_token}")
90130

91131
headers = {
92-
'Accept': 'application/json',
93-
'Content-Type': 'application/json',
94-
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
95-
'Referer': f'{self.BASE_URL}/problems/{title_slug}/'
132+
'referer': f"{self.BASE_URL}/problems/{title_slug}/",
133+
'content-type': 'application/json',
134+
'x-csrftoken': csrf_token,
135+
'x-requested-with': 'XMLHttpRequest',
136+
'origin': self.BASE_URL
96137
}
97138

98139
data = {
99-
"lang": lang,
100-
"question_id": question_id,
101-
"typed_code": code,
102-
"data_input": test_cases,
103-
"judge_type": "large"
140+
'lang': lang,
141+
'question_id': str(question_id),
142+
'typed_code': code,
143+
'data_input': "\n".join(test_cases) if isinstance(test_cases, list) else test_cases,
144+
'test_mode': False,
145+
'judge_type': 'small'
104146
}
105147

106-
response = self.session.post(test_url, json=data, headers=headers, timeout=30)
148+
typer.echo(f"Sending request to {url}")
149+
response = self.session.post(url, json=data, headers=headers)
107150

108151
if response.status_code != 200:
109-
return {
110-
"success": False,
111-
"error": f"Test request failed. Make sure your code is valid {lang} code."
112-
}
113-
114-
interpret_id = response.json().get('interpret_id')
115-
if not interpret_id:
116-
return {"success": False, "error": "No interpret ID received"}
152+
typer.echo(f"Error response: {response.text}", err=True)
153+
return {"success": False, "error": f"Request failed with status {response.status_code}"}
117154

118-
return self.get_test_result(interpret_id)
155+
try:
156+
result_data = response.json()
157+
submission_id = result_data.get(sid_key)
158+
if submission_id:
159+
typer.echo(f"Got submission ID: {submission_id}")
160+
return self.get_result(submission_id)
161+
else:
162+
typer.echo("No submission ID in response", err=True)
163+
return {"success": False, "error": "No submission ID received"}
164+
except ValueError as e:
165+
typer.echo(f"Failed to parse response: {response.text}", err=True)
166+
return {"success": False, "error": f"Failed to parse response: {str(e)}"}
119167

120168
except Exception as e:
169+
typer.echo(f"Test error: {str(e)}", err=True)
121170
return {"success": False, "error": f"Test error: {str(e)}"}
122171

172+
def _format_output(self, output) -> str:
173+
"""Format output that could be string or list"""
174+
if isinstance(output, list):
175+
return '\n'.join(str(item) for item in output)
176+
if isinstance(output, str):
177+
return output.strip('[]"')
178+
return str(output)
179+
180+
def get_result(self, submission_id: str, timeout: int = 30) -> Dict[str, Any]:
181+
"""Poll for results with timeout"""
182+
url = f"{self.BASE_URL}/submissions/detail/{submission_id}/check/"
183+
typer.echo(f"Polling for results at {url}")
184+
185+
for i in range(timeout):
186+
try:
187+
time.sleep(1)
188+
typer.echo(f"Attempt {i+1}/{timeout}...")
189+
190+
response = self.session.get(url)
191+
if response.status_code != 200:
192+
continue
193+
194+
result = response.json()
195+
typer.echo(f"Got response: {json.dumps(result, indent=2)}")
196+
197+
if result.get('state') == 'SUCCESS':
198+
return {
199+
"success": True,
200+
"status": result.get('status_msg', 'Unknown'),
201+
"input": result.get('input', 'N/A'),
202+
"output": self._format_output(result.get('code_answer', [])),
203+
"expected": self._format_output(result.get('expected_code_answer', [])),
204+
"runtime": result.get('status_runtime', 'N/A'),
205+
"memory": result.get('status_memory', 'N/A'),
206+
"total_correct": result.get('total_correct', 0),
207+
"total_testcases": result.get('total_testcases', 0)
208+
}
209+
except Exception as e:
210+
typer.echo(f"Error checking result: {str(e)}", err=True)
211+
continue
212+
213+
return {"success": False, "error": "Timeout waiting for results"}
214+
123215
def get_submission_result(self, submission_id: str) -> Dict[str, Any]:
124216
"""Poll for submission results"""
125217
check_url = f"{self.BASE_URL}/submissions/detail/{submission_id}/check/"

0 commit comments

Comments
 (0)