1717"""
1818JavaSitter module
1919"""
20-
2120from itertools import groupby
2221from typing import List , Set , Dict
2322from tree_sitter import Language , Node , Parser , Query , Tree
2625
2726from cldk .models .treesitter import Captures
2827
28+ import logging
29+
30+ logger = logging .getLogger (__name__ )
31+
2932
3033class JavaSitter :
3134 """
@@ -51,8 +54,7 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool:
5154 bool
5255 True if the method is not in the class, False otherwise.
5356 """
54- methods_in_class = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @name)" ,
55- class_body )
57+ methods_in_class = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @name)" , class_body )
5658
5759 return method_name not in {method .node .text .decode () for method in methods_in_class }
5860
@@ -103,8 +105,7 @@ def get_all_imports(self, source_code: str) -> Set[str]:
103105 Returns:
104106 Set[str]: A set of all the imports in the class.
105107 """
106- import_declerations : Captures = self .frame_query_and_capture_output (
107- query = "(import_declaration (scoped_identifier) @name)" , code_to_process = source_code )
108+ import_declerations : Captures = self .frame_query_and_capture_output (query = "(import_declaration (scoped_identifier) @name)" , code_to_process = source_code )
108109 return {capture .node .text .decode () for capture in import_declerations }
109110
110111 def get_pacakge_name (self , source_code : str ) -> str :
@@ -116,8 +117,7 @@ def get_pacakge_name(self, source_code: str) -> str:
116117 Returns:
117118 str: The package name.
118119 """
119- package_name : Captures = self .frame_query_and_capture_output (query = "((package_declaration) @name)" ,
120- code_to_process = source_code )
120+ package_name : Captures = self .frame_query_and_capture_output (query = "((package_declaration) @name)" , code_to_process = source_code )
121121 if package_name :
122122 return package_name [0 ].node .text .decode ().replace ("package " , "" ).replace (";" , "" )
123123 return None
@@ -143,8 +143,7 @@ def get_superclass(self, source_code: str) -> str:
143143 Returns:
144144 str: The superclass of the class; an empty string when no superclass is found.
145145 """
146- superclass : Captures = self .frame_query_and_capture_output (
147- query = "(class_declaration (superclass (type_identifier) @superclass))" , code_to_process = source_code )
146+ superclass : Captures = self .frame_query_and_capture_output (query = "(class_declaration (superclass (type_identifier) @superclass))" , code_to_process = source_code )
148147
149148 if len (superclass ) == 0 :
150149 return ""
@@ -161,9 +160,7 @@ def get_all_interfaces(self, source_code: str) -> Set[str]:
161160 Set[str]: A set of all the interfaces implemented by the class.
162161 """
163162
164- interfaces = self .frame_query_and_capture_output (
165- "(class_declaration (super_interfaces (type_list (type_identifier) @interface)))" ,
166- code_to_process = source_code )
163+ interfaces = self .frame_query_and_capture_output ("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))" , code_to_process = source_code )
167164 return {interface .node .text .decode () for interface in interfaces }
168165
169166 def frame_query_and_capture_output (self , query : str , code_to_process : str ) -> Captures :
@@ -182,8 +179,7 @@ def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Ca
182179
183180 def get_method_name_from_declaration (self , method_name_string : str ) -> str :
184181 """Get the method name from the method signature."""
185- captures : Captures = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @method_name)" ,
186- method_name_string )
182+ captures : Captures = self .frame_query_and_capture_output ("(method_declaration name: (identifier) @method_name)" , method_name_string )
187183
188184 return captures [0 ].node .text .decode ()
189185
@@ -192,8 +188,12 @@ def get_method_name_from_invocation(self, method_invocation: str) -> str:
192188 Using the tree-sitter query, extract the method name from the method invocation.
193189 """
194190
195- captures : Captures = self .frame_query_and_capture_output (
196- "(method_invocation object: (identifier) @class_name name: (identifier) @method_name)" , method_invocation )
191+ captures : Captures = self .frame_query_and_capture_output ("(method_invocation name: (identifier) @method_name)" , method_invocation )
192+ return captures [0 ].node .text .decode ()
193+
194+ def get_identifier_from_arbitrary_statement (self , statement : str ) -> str :
195+ """Get the identifier from an arbitrary statement."""
196+ captures : Captures = self .frame_query_and_capture_output ("(identifier) @identifier" , statement )
197197 return captures [0 ].node .text .decode ()
198198
199199 def safe_ascend (self , node : Node , ascend_count : int ) -> Node :
@@ -260,7 +260,7 @@ def get_call_targets(self, method_body: str, declared_methods: dict) -> Set[str]
260260 )
261261 return call_targets
262262
263- def get_calling_lines (self , source_method_code : str , target_method_name : str ) -> List [int ]:
263+ def get_calling_lines (self , source_method_code : str , target_method_name : str , is_target_method_a_constructor : bool ) -> List [int ]:
264264 """
265265 Returns a list of line numbers in source method where target method is called.
266266
@@ -272,26 +272,34 @@ def get_calling_lines(self, source_method_code: str, target_method_name: str) ->
272272 target_method_code : str
273273 target method code
274274
275+ is_target_method_a_constructor : bool
276+ True if target method is a constructor, False otherwise.
277+
275278 Returns:
276279 --------
277280 List[int]
278281 List of line numbers within in source method code block.
279282 """
280- query = "(method_invocation name: (identifier) @method_name)"
283+ if not source_method_code :
284+ return []
285+ query = "(object_creation_expression (type_identifier) @object_name) (object_creation_expression type: (scoped_type_identifier (type_identifier) @type_name)) (method_invocation name: (identifier) @method_name)"
286+
281287 # if target_method_name is a method signature, get the method name
282288 # if it is not a signature, we will just keep the passed method name
289+
290+ target_method_name = target_method_name .split ("(" )[0 ] # remove the arguments from the constructor name
283291 try :
284- target_method_name = self .get_method_name_from_declaration ( target_method_name )
285- except Exception :
286- pass
287-
288- captures : Captures = self . frame_query_and_capture_output ( query , source_method_code )
289- # Find the line numbers where target method calls happen in source method
290- target_call_lines = []
291- for c in captures :
292- method_name = c . node . text . decode ( )
293- if method_name == target_method_name :
294- target_call_lines . append ( c . node . start_point [ 0 ])
292+ captures : Captures = self .frame_query_and_capture_output ( query , source_method_code )
293+ # Find the line numbers where target method calls happen in source method
294+ target_call_lines = []
295+ for c in captures :
296+ method_name = c . node . text . decode ( )
297+ if method_name == target_method_name :
298+ target_call_lines . append ( c . node . start_point [ 0 ])
299+ except :
300+ logger . warning ( f"Unable to get calling lines for { target_method_name } in { source_method_code } ." )
301+ return []
302+
295303 return target_call_lines
296304
297305 def get_test_methods (self , source_class_code : str ) -> Dict [str , str ]:
@@ -398,8 +406,7 @@ def get_method_return_type(self, source_code: str) -> str:
398406 The return type of the method.
399407 """
400408
401- type_references : Captures = self .frame_query_and_capture_output (
402- "(method_declaration type: ((type_identifier) @type_id))" , source_code )
409+ type_references : Captures = self .frame_query_and_capture_output ("(method_declaration type: ((type_identifier) @type_id))" , source_code )
403410
404411 return type_references [0 ].node .text .decode ()
405412
@@ -426,9 +433,9 @@ def collect_leaf_token_values(node):
426433 if len (node .children ) == 0 :
427434 if filter_by_node_type is not None :
428435 if node .type in filter_by_node_type :
429- lexical_tokens .append (code [node .start_byte : node .end_byte ])
436+ lexical_tokens .append (code [node .start_byte : node .end_byte ])
430437 else :
431- lexical_tokens .append (code [node .start_byte : node .end_byte ])
438+ lexical_tokens .append (code [node .start_byte : node .end_byte ])
432439 else :
433440 for child in node .children :
434441 collect_leaf_token_values (child )
@@ -462,11 +469,9 @@ def remove_all_comments(self, source_code: str) -> str:
462469 pruned_source_code = self .make_pruned_code_prettier (source_code )
463470
464471 # Remove all comment lines: the comment lines start with / (for // and /*) or * (for multiline comments).
465- comment_blocks : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" ,
466- code_to_process = source_code )
472+ comment_blocks : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" , code_to_process = source_code )
467473
468- comment_lines : Captures = self .frame_query_and_capture_output (query = "((line_comment) @comment_line)" ,
469- code_to_process = source_code )
474+ comment_lines : Captures = self .frame_query_and_capture_output (query = "((line_comment) @comment_line)" , code_to_process = source_code )
470475
471476 for capture in comment_blocks :
472477 pruned_source_code = pruned_source_code .replace (capture .node .text .decode (), "" )
@@ -490,8 +495,7 @@ def make_pruned_code_prettier(self, pruned_code: str) -> str:
490495 The prettified pruned code.
491496 """
492497 # First remove remaining block comments
493- block_comments : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" ,
494- code_to_process = pruned_code )
498+ block_comments : Captures = self .frame_query_and_capture_output (query = "((block_comment) @comment_block)" , code_to_process = pruned_code )
495499
496500 for capture in block_comments :
497501 pruned_code = pruned_code .replace (capture .node .text .decode (), "" )
0 commit comments