Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions crates/postgresql-cst-parser/src/tree_sitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,24 @@ impl std::fmt::Display for Range {
}
}

impl Range {
pub fn extended_by(&self, other: &Self) -> Self {
Range {
start_byte: self.start_byte.min(other.start_byte),
end_byte: self.end_byte.max(other.end_byte),

start_position: Point {
row: self.start_position.row.min(other.start_position.row),
column: self.start_position.column.min(other.start_position.column),
},
end_position: Point {
row: self.end_position.row.max(other.end_position.row),
column: self.end_position.column.max(other.end_position.column),
},
}
}
}

impl<'a> Node<'a> {
pub fn walk(&self) -> TreeCursor<'a> {
TreeCursor {
Expand Down Expand Up @@ -144,6 +162,20 @@ impl<'a> Node<'a> {
}
}

pub fn children(&self) -> Vec<Node<'a>> {
if let Some(node) = self.node_or_token.as_node() {
node.children_with_tokens()
.map(|node| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: node,
})
.collect()
} else {
vec![]
}
}

pub fn next_sibling(&self) -> Option<Node<'a>> {
self.node_or_token
.next_sibling_or_token()
Expand All @@ -154,6 +186,16 @@ impl<'a> Node<'a> {
})
}

pub fn prev_sibling(&self) -> Option<Node<'a>> {
self.node_or_token
.prev_sibling_or_token()
.map(|sibling| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: sibling,
})
}

pub fn parent(&self) -> Option<Node<'a>> {
self.node_or_token.parent().map(|parent| Node {
input: self.input,
Expand All @@ -165,6 +207,59 @@ impl<'a> Node<'a> {
pub fn is_comment(&self) -> bool {
matches!(self.kind(), SyntaxKind::C_COMMENT | SyntaxKind::SQL_COMMENT)
}

/// Return the rightmost token in the subtree of this node
/// this is not tree-sitter's API
pub fn last_node(&self) -> Option<Node<'a>> {
match &self.node_or_token {
NodeOrToken::Node(node) => node.last_token().map(|token| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(token),
}),
NodeOrToken::Token(token) => Some(Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(token),
}),
}
}

/// Returns an iterator over all descendant nodes (not including tokens)
/// this is not tree-sitter's API
pub fn descendants(&self) -> impl Iterator<Item = Node<'a>> {
struct Descendants<'a> {
input: &'a str,
range_map: Rc<HashMap<TextRange, Range>>,
iter: Box<dyn Iterator<Item = &'a ResolvedNode> + 'a>,
}

impl<'a> Iterator for Descendants<'a> {
type Item = Node<'a>;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|node| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Node(node),
})
}
}

if let Some(node) = self.node_or_token.as_node() {
Descendants {
input: self.input,
range_map: Rc::clone(&self.range_map),
iter: Box::new(node.descendants()),
}
} else {
Descendants {
input: self.input,
range_map: Rc::clone(&self.range_map),
iter: Box::new(std::iter::empty()),
}
}
}
}

impl<'a> From<Node<'a>> for TreeCursor<'a> {
Expand Down Expand Up @@ -214,6 +309,15 @@ impl<'a> TreeCursor<'a> {
}
}

pub fn goto_prev_sibling(&mut self) -> bool {
if let Some(sibling) = self.node_or_token.prev_sibling_or_token() {
self.node_or_token = sibling;
true
} else {
false
}
}

pub fn is_comment(&self) -> bool {
matches!(
self.node_or_token.kind(),
Expand Down Expand Up @@ -462,4 +566,36 @@ from

assert_eq!(stmt_count, 2);
}

#[test]
fn test_last_node_returns_rightmost_node() {
let src = "SELECT u.*, (v).id, name;";
let tree = parse(src).unwrap();
let root = tree.root_node();

let target_list = root
.descendants()
.find(|node| node.kind() == SyntaxKind::target_list)
.expect("should find target_list");

// last node of the target_list is returned
let last_node = target_list.last_node().expect("should have last node");
assert_eq!(last_node.text(), "name");

let target_els = target_list
.children()
.into_iter()
.filter(|node| node.kind() == SyntaxKind::target_el)
.collect::<Vec<_>>();

let mut last_nodes = target_els
.iter()
.map(|node| node.last_node().expect("should have last node"));

// last node of each target_el is returned
assert_eq!(last_nodes.next().unwrap().text(), "*");
assert_eq!(last_nodes.next().unwrap().text(), "id");
assert_eq!(last_nodes.next().unwrap().text(), "name");
assert!(last_nodes.next().is_none());
}
}
Loading