2020
2121from itertools import groupby
2222from typing import List , Set , Dict
23-
23+ from tree_sitter import Language , Node , Parser , Query , Tree
2424import tree_sitter_java as tsjava
2525from tree_sitter import Language , Node , Parser , Query
2626
@@ -51,10 +51,49 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool:
5151 bool
5252 True if the method is in the class, False otherwise.
5353 """
54- methods_in_class = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @name)" , class_body )
54+ methods_in_class = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @name)" ,
55+ class_body )
5556
5657 return method_name not in {method .node .text .decode () for method in methods_in_class }
5758
59+ def is_parsable (self , code : str ) -> bool :
60+ """
61+ Check if the code is parsable
62+ Args:
63+ code: source code
64+
65+ Returns:
66+ True if the code is parsable, False otherwise
67+ """
68+
69+ def syntax_error (node ):
70+ if node .type == "ERROR" :
71+ return True
72+ try :
73+ for child in node .children :
74+ if syntax_error (child ):
75+ return True
76+ except RecursionError as err :
77+ return True
78+
79+ return False
80+
81+ tree = self .parser .parse (bytes (code , "utf-8" ))
82+ if tree is not None :
83+ return not syntax_error (tree .root_node )
84+ return False
85+
86+ def get_raw_ast (self , code : str ) -> Tree :
87+ """
88+ Get the raw AST
89+ Args:
90+ code: source code
91+
92+ Returns:
93+ Tree: the raw AST
94+ """
95+ return self .parser .parse (bytes (code , "utf-8" ))
96+
5897 def get_all_imports (self , source_code : str ) -> Set [str ]:
5998 """Get a list of all the imports in a class.
6099
@@ -64,7 +103,8 @@ def get_all_imports(self, source_code: str) -> Set[str]:
64103 Returns:
65104 Set[str]: A set of all the imports in the class.
66105 """
67- import_declerations : Captures = self .frame_query_and_capture_output (query = "(import_declaration (scoped_identifier) @name)" , code_to_process = source_code )
106+ import_declerations : Captures = self .frame_query_and_capture_output (
107+ query = "(import_declaration (scoped_identifier) @name)" , code_to_process = source_code )
68108 return {capture .node .text .decode () for capture in import_declerations }
69109
70110 def get_pacakge_name (self , source_code : str ) -> str :
@@ -76,7 +116,8 @@ def get_pacakge_name(self, source_code: str) -> str:
76116 Returns:
77117 str: The package name.
78118 """
79- package_name : Captures = self .frame_query_and_capture_output (query = "((package_declaration) @name)" , code_to_process = source_code )
119+ package_name : Captures = self .frame_query_and_capture_output (query = "((package_declaration) @name)" ,
120+ code_to_process = source_code )
80121 if package_name :
81122 return package_name [0 ].node .text .decode ().replace ("package " , "" ).replace (";" , "" )
82123 return None
@@ -102,7 +143,8 @@ def get_superclass(self, source_code: str) -> str:
102143 Returns:
103144 Set[str]: A set of all the superclasses in the class.
104145 """
105- superclass : Captures = self .frame_query_and_capture_output (query = "(class_declaration (superclass (type_identifier) @superclass))" , code_to_process = source_code )
146+ superclass : Captures = self .frame_query_and_capture_output (
147+ query = "(class_declaration (superclass (type_identifier) @superclass))" , code_to_process = source_code )
106148
107149 if len (superclass ) == 0 :
108150 return ""
@@ -119,7 +161,9 @@ def get_all_interfaces(self, source_code: str) -> Set[str]:
119161 Set[str]: A set of all the interfaces implemented by the class.
120162 """
121163
122- interfaces = self .frame_query_and_capture_output ("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))" , code_to_process = source_code )
164+ interfaces = self .frame_query_and_capture_output (
165+ "(class_declaration (super_interfaces (type_list (type_identifier) @interface)))" ,
166+ code_to_process = source_code )
123167 return {interface .node .text .decode () for interface in interfaces }
124168
125169 def frame_query_and_capture_output (self , query : str , code_to_process : str ) -> Captures :
@@ -138,7 +182,8 @@ def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Ca
138182
139183 def get_method_name_from_declaration (self , method_name_string : str ) -> str :
140184 """Get the method name from the method signature."""
141- captures : Captures = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @method_name)" , method_name_string )
185+ captures : Captures = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @method_name)" ,
186+ method_name_string )
142187
143188 return captures [0 ].node .text .decode ()
144189
@@ -147,7 +192,8 @@ def get_method_name_from_invocation(self, method_invocation: str) -> str:
147192 Using the tree-sitter query, extract the method name from the method invocation.
148193 """
149194
150- captures : Captures = self .frame_query_and_capture_output ("(method_invocation object: (identifier) @class_name name: (identifier) @method_name)" , method_invocation )
195+ captures : Captures = self .frame_query_and_capture_output (
196+ "(method_invocation object: (identifier) @class_name name: (identifier) @method_name)" , method_invocation )
151197 return captures [0 ].node .text .decode ()
152198
153199 def safe_ascend (self , node : Node , ascend_count : int ) -> Node :
@@ -352,7 +398,8 @@ def get_method_return_type(self, source_code: str) -> str:
352398 The return type of the method.
353399 """
354400
355- type_references : Captures = self .frame_query_and_capture_output ("(method_declaration type: ((type_identifier) @type_id))" , source_code )
401+ type_references : Captures = self .frame_query_and_capture_output (
402+ "(method_declaration type: ((type_identifier) @type_id))" , source_code )
356403
357404 return type_references [0 ].node .text .decode ()
358405
@@ -379,9 +426,9 @@ def collect_leaf_token_values(node):
379426 if len (node .children ) == 0 :
380427 if filter_by_node_type is not None :
381428 if node .type in filter_by_node_type :
382- lexical_tokens .append (code [node .start_byte : node .end_byte ])
429+ lexical_tokens .append (code [node .start_byte : node .end_byte ])
383430 else :
384- lexical_tokens .append (code [node .start_byte : node .end_byte ])
431+ lexical_tokens .append (code [node .start_byte : node .end_byte ])
385432 else :
386433 for child in node .children :
387434 collect_leaf_token_values (child )
@@ -415,9 +462,11 @@ def remove_all_comments(self, source_code: str) -> str:
415462 pruned_source_code = self .make_pruned_code_prettier (source_code )
416463
417464 # Remove all comment lines: the comment lines start with / (for // and /*) or * (for multiline comments).
418- comment_blocks : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" , code_to_process = source_code )
465+ comment_blocks : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" ,
466+ code_to_process = source_code )
419467
420- comment_lines : Captures = self .frame_query_and_capture_output (query = "((line_comment) @comment_line)" , code_to_process = source_code )
468+ comment_lines : Captures = self .frame_query_and_capture_output (query = "((line_comment) @comment_line)" ,
469+ code_to_process = source_code )
421470
422471 for capture in comment_blocks :
423472 pruned_source_code = pruned_source_code .replace (capture .node .text .decode (), "" )
@@ -441,7 +490,8 @@ def make_pruned_code_prettier(self, pruned_code: str) -> str:
441490 The prettified pruned code.
442491 """
443492 # First remove remaining block comments
444- block_comments : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" , code_to_process = pruned_code )
493+ block_comments : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" ,
494+ code_to_process = pruned_code )
445495
446496 for capture in block_comments :
447497 pruned_code = pruned_code .replace (capture .node .text .decode (), "" )
0 commit comments