Skip to content

Commit 6caa781

Browse files
committed
fix #6
1 parent 30ee071 commit 6caa781

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

jupyterlab_leetcode/utils/notebook_generator.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import ast
12
import copy
23
import json
34
import os
4-
import re
55
import sys
66
import typing
77

@@ -22,15 +22,6 @@ def __init__(self):
2222
with open(template_path, "rt") as f:
2323
self.__template = json.load(f)
2424

25-
self.typing_regex = re.compile(
26-
"|".join(
27-
# '|' is matched by order
28-
sorted(
29-
filter(lambda t: t[0].isupper(), dir(typing)), key=len, reverse=True
30-
)
31-
)
32-
)
33-
3425
def __populate_metadata(self, q):
3526
self.template["metadata"]["language_info"]["version"] = "{}.{}.{}".format(
3627
*sys.version_info[:3]
@@ -111,9 +102,8 @@ def __populate_test(self, q):
111102
test_cell["metadata"]["exampleTestcaseList"] = q["exampleTestcaseList"]
112103

113104
def __extract_type(self, code) -> list[str]:
114-
_, args = self.__parse_code(code)
115-
# FIXME: args: `root1: Optional[TreeNode], root2: Optional[TreeNode]` will extract type `Optional, T`
116-
return self.typing_regex.findall(args)
105+
_, args_types = self.__parse_code(code)
106+
return list(args_types.intersection((t for t in dir(typing) if t[0].isupper())))
117107

118108
def __populate_code(self, q):
119109
code_cell = first(
@@ -126,11 +116,11 @@ def __populate_code(self, q):
126116
if not code_snippet:
127117
return
128118

129-
snippet = code_snippet["code"]
119+
snippet = code_snippet["code"] + "pass"
130120
pre_solution_index = snippet.find("class Solution:")
131121
pre_solution = snippet[:pre_solution_index]
132122
snippet = snippet[pre_solution_index:]
133-
code_cell["source"] = [snippet + "pass"]
123+
code_cell["source"] = snippet
134124
code_cell["metadata"]["isSolutionCode"] = True
135125

136126
types = self.__extract_type(snippet)
@@ -161,11 +151,27 @@ def __populate_code(self, q):
161151

162152
return snippet
163153

164-
def __parse_code(self, code) -> tuple[str, str]:
165-
match = re.search(r"class Solution:\s+def (.*?)\(self,(.*)", code)
166-
if not match:
167-
return ("", "")
168-
return (match[1], match[2])
154+
"""
155+
return (function_name, argument_types)
156+
"""
157+
158+
def __parse_code(self, code) -> tuple[str, typing.Set[str]]:
159+
m = ast.parse(code)
160+
func_name = ""
161+
args_types = set()
162+
163+
for node in ast.walk(m):
164+
if isinstance(node, ast.FunctionDef):
165+
func_name = node.name
166+
for arg in node.args.args:
167+
if arg.annotation:
168+
if isinstance(arg.annotation, ast.Subscript):
169+
args_types.add(arg.annotation.value.id)
170+
args_types.add(arg.annotation.slice.id)
171+
elif isinstance(arg.annotation, ast.Name):
172+
args_types.add(arg.annotation.id)
173+
174+
return func_name, args_types
169175

170176
def __populate_run(self, q, snippet):
171177
run_cell_with_idx = first(

0 commit comments

Comments
 (0)