1+ import ast
12import copy
23import json
34import os
4- import re
55import sys
66import typing
77
@@ -22,15 +22,6 @@ def __init__(self):
2222 with open (template_path , "rt" ) as f :
2323 self .__template = json .load (f )
2424
25- self .typing_regex = re .compile (
26- "|" .join (
27- # '|' is matched by order
28- sorted (
29- filter (lambda t : t [0 ].isupper (), dir (typing )), key = len , reverse = True
30- )
31- )
32- )
33-
3425 def __populate_metadata (self , q ):
3526 self .template ["metadata" ]["language_info" ]["version" ] = "{}.{}.{}" .format (
3627 * sys .version_info [:3 ]
@@ -111,9 +102,8 @@ def __populate_test(self, q):
111102 test_cell ["metadata" ]["exampleTestcaseList" ] = q ["exampleTestcaseList" ]
112103
113104 def __extract_type (self , code ) -> list [str ]:
114- _ , args = self .__parse_code (code )
115- # FIXME: args: `root1: Optional[TreeNode], root2: Optional[TreeNode]` will extract type `Optional, T`
116- return self .typing_regex .findall (args )
105+ _ , args_types = self .__parse_code (code )
106+ return list (args_types .intersection ((t for t in dir (typing ) if t [0 ].isupper ())))
117107
118108 def __populate_code (self , q ):
119109 code_cell = first (
@@ -126,11 +116,11 @@ def __populate_code(self, q):
126116 if not code_snippet :
127117 return
128118
129- snippet = code_snippet ["code" ]
119+ snippet = code_snippet ["code" ] + "pass"
130120 pre_solution_index = snippet .find ("class Solution:" )
131121 pre_solution = snippet [:pre_solution_index ]
132122 snippet = snippet [pre_solution_index :]
133- code_cell ["source" ] = [ snippet + "pass" ]
123+ code_cell ["source" ] = snippet
134124 code_cell ["metadata" ]["isSolutionCode" ] = True
135125
136126 types = self .__extract_type (snippet )
@@ -161,11 +151,27 @@ def __populate_code(self, q):
161151
162152 return snippet
163153
164- def __parse_code (self , code ) -> tuple [str , str ]:
165- match = re .search (r"class Solution:\s+def (.*?)\(self,(.*)" , code )
166- if not match :
167- return ("" , "" )
168- return (match [1 ], match [2 ])
154+ """
155+ return (function_name, argument_types)
156+ """
157+
158+ def __parse_code (self , code ) -> tuple [str , typing .Set [str ]]:
159+ m = ast .parse (code )
160+ func_name = ""
161+ args_types = set ()
162+
163+ for node in ast .walk (m ):
164+ if isinstance (node , ast .FunctionDef ):
165+ func_name = node .name
166+ for arg in node .args .args :
167+ if arg .annotation :
168+ if isinstance (arg .annotation , ast .Subscript ):
169+ args_types .add (arg .annotation .value .id )
170+ args_types .add (arg .annotation .slice .id )
171+ elif isinstance (arg .annotation , ast .Name ):
172+ args_types .add (arg .annotation .id )
173+
174+ return func_name , args_types
169175
170176 def __populate_run (self , q , snippet ):
171177 run_cell_with_idx = first (
0 commit comments