from pathlib import Path
from ...entities import *
from typing import Optional
from ..analyzer import AbstractAnalyzer

from multilspy import SyncLanguageServer

import tree_sitter_kotlin as tskotlin
from tree_sitter import Language, Node

import logging
logger = logging.getLogger('code_graph')


class KotlinAnalyzer(AbstractAnalyzer):
    """Analyzer that extracts classes, interfaces, objects and functions from Kotlin.

    Entities are located with tree-sitter queries over the Kotlin grammar.
    Cross-file resolution relies on ``self.resolve``, ``self.find_parent`` and
    ``self.language``, which this class does not define itself (presumably
    they are provided by ``AbstractAnalyzer`` - verify against the base class).
    """

    # Node types that declare a type (interfaces are ``class_declaration`` too).
    _TYPE_DECLARATIONS = ('class_declaration', 'object_declaration')
    # Every node type that becomes an entity.
    _ENTITY_TYPES = _TYPE_DECLARATIONS + ('function_declaration',)
    # Directory names whose contents are treated as dependencies / build output.
    _DEPENDENCY_DIRS = frozenset(('build', '.gradle', 'cache'))

    # Supertype followed by constructor arguments, e.g. ``class A : Base()``.
    # NOTE(review): assumes the grammar wraps such a supertype in a
    # ``constructor_invocation`` node - confirm against the grammar's node types.
    _SUPERCLASS_QUERY = (
        "(delegation_specifier (constructor_invocation "
        "(user_type (type_identifier) @superclass)))"
    )
    # Supertype written as a bare type, e.g. ``class A : Task``.
    _INTERFACE_QUERY = "(delegation_specifier (user_type (type_identifier) @interface))"

    def __init__(self) -> None:
        """Initialise the base analyzer with the tree-sitter Kotlin language."""
        super().__init__(Language(tskotlin.language()))

    def add_dependencies(self, path: Path, files: list[Path]):
        """Do nothing: dependency resolution is not implemented for Kotlin yet.

        A future version could parse ``build.gradle`` or ``pom.xml``.
        """

    def get_entity_label(self, node: Node) -> str:
        """Return the graph label for an entity node.

        Returns:
            ``"Interface"`` or ``"Class"`` for a ``class_declaration``,
            ``"Object"`` for an ``object_declaration``, and ``"Method"`` or
            ``"Function"`` for a ``function_declaration``.

        Raises:
            ValueError: if ``node`` is not one of the supported entity types.
        """
        if node.type == 'class_declaration':
            # Interfaces share the class node type; only the ``interface``
            # keyword child tells them apart.
            if any(child.type == 'interface' for child in node.children):
                return "Interface"
            return "Class"
        if node.type == 'object_declaration':
            return "Object"
        if node.type == 'function_declaration':
            # A function directly inside a class body is a method.
            parent = node.parent
            if parent and parent.type == 'class_body':
                return "Method"
            return "Function"
        raise ValueError(f"Unknown entity type: {node.type}")

    def get_entity_name(self, node: Node) -> str:
        """Return the declared name of a class, object or function node.

        Raises:
            ValueError: if the node type is unsupported or the node has no
                identifier child of the expected kind.
        """
        if node.type in self._TYPE_DECLARATIONS:
            identifier_type = 'type_identifier'
        elif node.type == 'function_declaration':
            identifier_type = 'simple_identifier'
        else:
            raise ValueError(f"Cannot extract name from entity type: {node.type}")

        for child in node.children:
            if child.type == identifier_type:
                return child.text.decode('utf-8')
        raise ValueError(f"Cannot extract name from entity type: {node.type}")

    def get_entity_docstring(self, node: Node) -> Optional[str]:
        """Return the KDoc comment immediately preceding ``node``, if any.

        Only a ``multiline_comment`` sibling whose text starts with ``/**`` is
        accepted; any other comment yields ``None``.

        Raises:
            ValueError: if ``node`` is not one of the supported entity types.
        """
        if node.type not in self._ENTITY_TYPES:
            raise ValueError(f"Unknown entity type: {node.type}")

        previous = node.prev_sibling
        if previous and previous.type == "multiline_comment":
            comment_text = previous.text.decode('utf-8')
            if comment_text.startswith('/**'):
                return comment_text
        return None

    def get_entity_types(self) -> list[str]:
        """Return the tree-sitter node types that are turned into entities."""
        return list(self._ENTITY_TYPES)

    def _add_captures(self, entity: Entity, pattern: str, capture: str, key: str) -> None:
        """Run ``pattern`` on the entity's node and store each ``capture`` node under ``key``."""
        captures = self.language.query(pattern).captures(entity.node)
        for captured in captures.get(capture, ()):
            entity.add_symbol(key, captured)

    def add_symbols(self, entity: Entity) -> None:
        """Attach the symbols referenced by ``entity`` for later resolution.

        * class declarations: supertypes with a constructor call are stored as
          ``base_class``, bare supertypes as ``implement_interface`` (the two
          previously used one identical query, so each supertype was recorded
          under both keys);
        * object declarations: bare supertypes as ``implement_interface``;
        * function declarations: ``call``, ``parameters`` and ``return_type``.

        Other node types are left untouched.
        """
        node_type = entity.node.type
        if node_type == 'class_declaration':
            self._add_captures(entity, self._SUPERCLASS_QUERY, 'superclass', "base_class")
            self._add_captures(entity, self._INTERFACE_QUERY, 'interface', "implement_interface")
        elif node_type == 'object_declaration':
            # Objects can also have delegation specifiers.
            self._add_captures(entity, self._INTERFACE_QUERY, 'interface', "implement_interface")
        elif node_type == 'function_declaration':
            # Function calls made anywhere inside the function node.
            self._add_captures(entity, "(call_expression) @reference.call",
                               'reference.call', "call")
            # Parameter types.
            # NOTE(review): this and the next pattern rely on the grammar
            # exposing a ``type`` field - confirm against the installed grammar.
            self._add_captures(entity,
                               "(parameter type: (user_type (type_identifier) @parameter))",
                               'parameter', "parameters")
            # Declared return type.
            self._add_captures(entity,
                               "(function_declaration type: (user_type (type_identifier) @return_type))",
                               'return_type', "return_type")

    def is_dependency(self, file_path: str) -> bool:
        """Return True when ``file_path`` lies inside a build or cache directory.

        The check compares whole directory components (``build``, ``.gradle``,
        ``cache``) instead of substrings, so ``rebuild/Foo.kt`` is no longer
        mistaken for build output.
        """
        text = str(file_path)
        parts = Path(text).parts
        # The last component is the file itself unless the path names a directory.
        directories = parts if text.endswith('/') else parts[:-1]
        return any(part in self._DEPENDENCY_DIRS for part in directories)

    def resolve_path(self, file_path: str, path: Path) -> str:
        """Return ``file_path`` unchanged; Kotlin needs no path rewriting for now."""
        return file_path

    def resolve_type(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]:
        """Resolve a type reference to the class/object entities that declare it."""
        res = []
        for file, resolved_node in self.resolve(files, lsp, file_path, path, node):
            type_dec = self.find_parent(resolved_node, ['class_declaration', 'object_declaration'])
            if type_dec in file.entities:
                res.append(file.entities[type_dec])
        return res

    def resolve_method(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, node: Node) -> list[Entity]:
        """Resolve a ``call_expression`` to the function entities it targets.

        Only the first ``simple_identifier`` / ``navigation_expression`` child
        is resolved. Results whose nearest enclosing declaration is a class or
        object are skipped. Any other node type yields an empty list.
        """
        res = []
        if node.type != 'call_expression':
            return res

        for child in node.children:
            if child.type not in ('simple_identifier', 'navigation_expression'):
                continue
            # NOTE(review): for ``a.b()`` the whole navigation expression is
            # resolved; whether that lands on ``a`` or ``b`` depends on
            # ``self.resolve`` - verify qualified calls resolve as intended.
            for file, resolved_node in self.resolve(files, lsp, file_path, path, child):
                method_dec = self.find_parent(resolved_node, ['function_declaration', 'class_declaration', 'object_declaration'])
                if method_dec and method_dec.type in self._TYPE_DECLARATIONS:
                    continue
                if method_dec in file.entities:
                    res.append(file.entities[method_dec])
            # Only the first callee-like child is considered.
            break
        return res

    def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, key: str, symbol: Node) -> list[Entity]:
        """Dispatch symbol resolution by the key it was stored under.

        Returns:
            The list of entities produced by ``resolve_type`` (type-like keys)
            or ``resolve_method`` (``"call"``).

        Raises:
            ValueError: if ``key`` is not a known symbol key.
        """
        if key in ["implement_interface", "base_class", "parameters", "return_type"]:
            return self.resolve_type(files, lsp, file_path, path, symbol)
        elif key in ["call"]:
            return self.resolve_method(files, lsp, file_path, path, symbol)
        else:
            raise ValueError(f"Unknown key {key}")
.python.analyzer import PythonAnalyzer from multilspy import SyncLanguageServer @@ -24,7 +25,9 @@ # '.c': CAnalyzer(), # '.h': CAnalyzer(), '.py': PythonAnalyzer(), - '.java': JavaAnalyzer()} + '.java': JavaAnalyzer(), + '.kt': KotlinAnalyzer(), + '.kts': KotlinAnalyzer()} class NullLanguageServer: def start_server(self): @@ -136,7 +139,14 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None: lsps[".py"] = SyncLanguageServer.create(config, logger, str(path)) else: lsps[".py"] = NullLanguageServer() - with lsps[".java"].start_server(), lsps[".py"].start_server(): + if any(path.rglob('*.kt')) or any(path.rglob('*.kts')): + # For now, use NullLanguageServer for Kotlin as we need to set up kotlin-language-server + lsps[".kt"] = NullLanguageServer() + lsps[".kts"] = NullLanguageServer() + else: + lsps[".kt"] = NullLanguageServer() + lsps[".kts"] = NullLanguageServer() + with lsps[".java"].start_server(), lsps[".py"].start_server(), lsps[".kt"].start_server(), lsps[".kts"].start_server(): files_len = len(self.files) for i, file_path in enumerate(files): file = self.files[file_path] @@ -163,7 +173,7 @@ def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None: self.second_pass(graph, files, path) def analyze_sources(self, path: Path, ignore: list[str], graph: Graph) -> None: - files = list(path.rglob("*.java")) + list(path.rglob("*.py")) + files = list(path.rglob("*.java")) + list(path.rglob("*.py")) + list(path.rglob("*.kt")) + list(path.rglob("*.kts")) # First pass analysis of the source code self.first_pass(path, files, ignore, graph) diff --git a/pyproject.toml b/pyproject.toml index b817a1f..f145598 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ falkordb = "^1.0.10" tree-sitter-c = "^0.23.4" tree-sitter-python = "^0.23.6" tree-sitter-java = "^0.23.5" +tree-sitter-kotlin = "^1.1.0" flask = "^3.1.0" python-dotenv = "^1.0.1" multilspy = {git = "https://github.com/AviAvni/multilspy.git", rev = 
import unittest
from pathlib import Path

from api import SourceAnalyzer, Graph


class Test_Kotlin_Analyzer(unittest.TestCase):
    """Smoke test: analyzing the Kotlin fixture folder yields at least one entity."""

    def test_analyzer(self):
        """Analyze ``tests/source_files/kotlin`` into a graph and list what was found."""
        # The fixture folder sits next to this test module.
        fixture_dir = Path(__file__).parent.joinpath('source_files', 'kotlin')

        graph = Graph("kotlin")
        source_analyzer = SourceAnalyzer()
        source_analyzer.analyze_local_folder(str(fixture_dir), graph)

        # The folder must have produced at least one entity.
        self.assertGreater(len(graph.entities), 0)

        # Print the discovered entities to make failures easier to diagnose.
        print(f"Entities found: {len(graph.entities)}")
        for _entity_id, details in graph.entities.items():
            print(f" - {details.get('label')}: {details.get('name')}")


if __name__ == '__main__':
    unittest.main()