Skip to content

Commit 9182ffb

Browse files
alrighty
1 parent c8c6b1b commit 9182ffb

File tree

5 files changed

+46
-78
lines changed

5 files changed

+46
-78
lines changed

crates/pgt_treesitter/src/queries/helper.rs

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1+
use std::sync::LazyLock;
2+
13
use tree_sitter::Node;
4+
use tree_sitter::StreamingIterator;
25

36
pub(crate) static OBJECT_REFERENCE_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
47
static QUERY_STR: &str = r#"
5-
(object_reference
6-
object_reference_1of1: (any_identifier) @tail
7-
)
8-
(object_reference
9-
object_reference_1of2: (any_identifier) @head
10-
object_reference_2of2: (any_identifier) @tail
11-
)
12-
(object_reference
13-
object_reference_1of1: (any_identifier) @head
14-
object_reference_1of2: (any_identifier) @middle
15-
object_reference_1of3: (any_identifier) @tail
16-
)
8+
(object_reference
9+
object_reference_1of1: (any_identifier) @tail
10+
)
11+
(object_reference
12+
object_reference_1of2: (any_identifier) @head
13+
object_reference_2of2: (any_identifier) @tail
14+
)
15+
(object_reference
16+
object_reference_1of3: (any_identifier) @head
17+
object_reference_2of3: (any_identifier) @middle
18+
object_reference_3of3: (any_identifier) @tail
1719
)
1820
"#;
1921
tree_sitter::Query::new(&pgt_treesitter_grammar::LANGUAGE.into(), QUERY_STR)
@@ -25,25 +27,22 @@ pub(crate) fn object_reference_query<'a>(
2527
stmt: &'a str,
2628
) -> Option<(Option<Node<'a>>, Option<Node<'a>>, Node<'a>)> {
2729
let mut cursor = tree_sitter::QueryCursor::new();
28-
let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes());
30+
let mut matches = cursor.matches(&OBJECT_REFERENCE_QUERY, node, stmt.as_bytes());
2931

30-
assert!(
31-
matches.len() <= 1,
32-
"Please pass a single `object_reference` node into the `object_reference_query`!"
33-
);
34-
35-
if matches[0].len() == 0 {
36-
None
37-
} else if matches[0].captures.len() == 1 {
38-
Some((None, None, m.captures[0].node))
39-
} else if matches[0].captures.len() == 2 {
40-
Some((None, m.captures[0].node, m.captures[1].node))
41-
} else if matches[0].captures.len() == 3 {
42-
Some((
43-
Some(m.captures[0].node),
44-
Some(m.captures[1].node),
45-
m.captures[2].node,
46-
))
32+
if let Some(next) = matches.next() {
33+
if next.captures.len() == 1 {
34+
Some((None, None, next.captures[0].node))
35+
} else if next.captures.len() == 2 {
36+
Some((None, Some(next.captures[0].node), next.captures[1].node))
37+
} else if next.captures.len() == 3 {
38+
Some((
39+
Some(next.captures[0].node),
40+
Some(next.captures[1].node),
41+
next.captures[2].node,
42+
))
43+
} else {
44+
None
45+
}
4746
} else {
4847
None
4948
}

crates/pgt_treesitter/src/queries/relations.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ impl<'a> Query<'a> for RelationMatch<'a> {
7777
matches.for_each(|m| {
7878
m.captures.iter().for_each(|capture| {
7979
if let Some((_, schema, table)) = object_reference_query(capture.node, stmt) {
80-
to_return.push(RelationMatch { schema, table })
80+
to_return.push(QueryResult::Relation(RelationMatch { schema, table }));
8181
}
8282
});
8383
});

crates/pgt_treesitter/src/queries/select_columns.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ impl<'a> Query<'a> for SelectColumnMatch<'a> {
7373
matches.for_each(|m| {
7474
m.captures.iter().for_each(|capture| {
7575
if let Some((schema, alias, column)) = object_reference_query(capture.node, stmt) {
76-
to_return.push(SelectColumnMatch {
76+
to_return.push(QueryResult::SelectClauseColumns(SelectColumnMatch {
7777
schema,
7878
alias,
7979
column,
80-
});
80+
}));
8181
}
8282
});
8383
});

crates/pgt_treesitter/src/queries/table_aliases.rs

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::LazyLock;
22

3-
use crate::queries::{Query, QueryResult};
3+
use crate::queries::{Query, QueryResult, helper::object_reference_query};
44
use tree_sitter::StreamingIterator;
55

66
use super::QueryTryFrom;
@@ -74,17 +74,6 @@ impl<'a> Query<'a> for TableAliasMatch<'a> {
7474
let mut to_return = vec![];
7575

7676
matches.for_each(|m| {
77-
if m.captures.len() == 1 {
78-
let obj_ref = m.captures[0].node;
79-
if let Some((_, schema, table)) = object_reference_query(obj_ref, stmt) {
80-
to_return.push(QueryResult::TableAliases(TableAliasMatch {
81-
schema,
82-
table,
83-
alias: None,
84-
}));
85-
}
86-
}
87-
8877
if m.captures.len() == 2 {
8978
let obj_ref = m.captures[0].node;
9079
let alias = m.captures[1].node;

crates/pgt_treesitter/src/queries/where_columns.rs

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::LazyLock;
22

3-
use crate::queries::{Query, QueryResult};
3+
use crate::queries::{Query, QueryResult, helper::object_reference_query};
44

55
use tree_sitter::StreamingIterator;
66

@@ -14,14 +14,14 @@ static WHERE_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
1414
.expect("Invalid TS Query")
1515
});
1616

17+
/**
18+
* The binary expressions can be nested inside a @where clause, e.g. (where user_id = 1 or (email = 2 and user_id = 3));
19+
* We'll need a separate query to find all nested binary expressions.
20+
*/
1721
static BINARY_EXPR_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
1822
static QUERY_STR: &str = r#"
1923
(binary_expression
20-
binary_expr_left: (object_reference
21-
object_reference_first: (any_identifier) @first
22-
object_reference_second: (any_identifier)? @second
23-
object_reference_third: (any_identifier)? @third
24-
)
24+
binary_expr_left: (object_reference) @ref
2525
)
2626
"#;
2727
tree_sitter::Query::new(&pgt_treesitter_grammar::LANGUAGE.into(), QUERY_STR)
@@ -92,34 +92,14 @@ impl<'a> Query<'a> for WhereColumnMatch<'a> {
9292
binary_expr_matches.for_each(|m| {
9393
if m.captures.len() == 1 {
9494
let capture = m.captures[0].node;
95-
to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch {
96-
schema: None,
97-
alias: None,
98-
column: capture,
99-
}));
100-
}
101-
102-
if m.captures.len() == 2 {
103-
let alias = m.captures[0].node;
104-
let column = m.captures[1].node;
105-
106-
to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch {
107-
schema: None,
108-
alias: Some(alias),
109-
column,
110-
}));
111-
}
112-
113-
if m.captures.len() == 3 {
114-
let schema = m.captures[0].node;
115-
let alias = m.captures[1].node;
116-
let column = m.captures[2].node;
11795

118-
to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch {
119-
schema: Some(schema),
120-
alias: Some(alias),
121-
column,
122-
}));
96+
if let Some((schema, alias, column)) = object_reference_query(capture, stmt) {
97+
to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch {
98+
schema,
99+
alias,
100+
column,
101+
}));
102+
}
123103
}
124104
})
125105
});

0 commit comments

Comments
 (0)