diff --git a/crates/postgresql-cst-parser/src/tree_sitter.rs b/crates/postgresql-cst-parser/src/tree_sitter.rs index 85adab5..94af1a3 100644 --- a/crates/postgresql-cst-parser/src/tree_sitter.rs +++ b/crates/postgresql-cst-parser/src/tree_sitter.rs @@ -98,6 +98,24 @@ impl std::fmt::Display for Range { } } +impl Range { + pub fn extended_by(&self, other: &Self) -> Self { + Range { + start_byte: self.start_byte.min(other.start_byte), + end_byte: self.end_byte.max(other.end_byte), + + start_position: Point { + row: self.start_position.row.min(other.start_position.row), + column: self.start_position.column.min(other.start_position.column), + }, + end_position: Point { + row: self.end_position.row.max(other.end_position.row), + column: self.end_position.column.max(other.end_position.column), + }, + } + } +} + impl<'a> Node<'a> { pub fn walk(&self) -> TreeCursor<'a> { TreeCursor { @@ -144,6 +162,20 @@ impl<'a> Node<'a> { } } + pub fn children(&self) -> Vec> { + if let Some(node) = self.node_or_token.as_node() { + node.children_with_tokens() + .map(|node| Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: node, + }) + .collect() + } else { + vec![] + } + } + pub fn next_sibling(&self) -> Option> { self.node_or_token .next_sibling_or_token() @@ -154,6 +186,16 @@ impl<'a> Node<'a> { }) } + pub fn prev_sibling(&self) -> Option> { + self.node_or_token + .prev_sibling_or_token() + .map(|sibling| Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: sibling, + }) + } + pub fn parent(&self) -> Option> { self.node_or_token.parent().map(|parent| Node { input: self.input, @@ -165,6 +207,59 @@ impl<'a> Node<'a> { pub fn is_comment(&self) -> bool { matches!(self.kind(), SyntaxKind::C_COMMENT | SyntaxKind::SQL_COMMENT) } + + /// Return the rightmost token in the subtree of this node + /// this is not tree-sitter's API + pub fn last_node(&self) -> Option> { + match &self.node_or_token { + NodeOrToken::Node(node) => node.last_token().map(|token| Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: NodeOrToken::Token(token), + }), + NodeOrToken::Token(token) => Some(Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: NodeOrToken::Token(token), + }), + } + } + + /// Returns an iterator over all descendant nodes (not including tokens) + /// this is not tree-sitter's API + pub fn descendants(&self) -> impl Iterator> { + struct Descendants<'a> { + input: &'a str, + range_map: Rc>, + iter: Box + 'a>, + } + + impl<'a> Iterator for Descendants<'a> { + type Item = Node<'a>; + + fn next(&mut self) -> Option { + self.iter.next().map(|node| Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: NodeOrToken::Node(node), + }) + } + } + + if let Some(node) = self.node_or_token.as_node() { + Descendants { + input: self.input, + range_map: Rc::clone(&self.range_map), + iter: Box::new(node.descendants()), + } + } else { + Descendants { + input: self.input, + range_map: Rc::clone(&self.range_map), + iter: Box::new(std::iter::empty()), + } + } + } } impl<'a> From> for TreeCursor<'a> { @@ -214,6 +309,15 @@ impl<'a> TreeCursor<'a> { } } + pub fn goto_prev_sibling(&mut self) -> bool { + if let Some(sibling) = self.node_or_token.prev_sibling_or_token() { + self.node_or_token = sibling; + true + } else { + false + } + } + pub fn is_comment(&self) -> bool { matches!( self.node_or_token.kind(), @@ -462,4 +566,36 @@ from assert_eq!(stmt_count, 2); } + + #[test] + fn test_last_node_returns_rightmost_node() { + let src = "SELECT u.*, (v).id, name;"; + let tree = parse(src).unwrap(); + let root = tree.root_node(); + + let target_list = root + .descendants() + .find(|node| node.kind() == SyntaxKind::target_list) + .expect("should find target_list"); + + // last node of the target_list is returned + let last_node = target_list.last_node().expect("should have last node"); + assert_eq!(last_node.text(), "name"); + + let target_els = target_list + .children() + .into_iter() + .filter(|node| node.kind() == SyntaxKind::target_el) + .collect::>(); + + let mut last_nodes = target_els + .iter() + .map(|node| node.last_node().expect("should have last node")); + + // last node of each target_el is returned + assert_eq!(last_nodes.next().unwrap().text(), "*"); + assert_eq!(last_nodes.next().unwrap().text(), "id"); + assert_eq!(last_nodes.next().unwrap().text(), "name"); + assert!(last_nodes.next().is_none()); + } }