|
| 1 | +use std::sync::LazyLock; |
| 2 | + |
| 3 | +use crate::{Query, QueryResult}; |
| 4 | + |
| 5 | +use super::QueryTryFrom; |
| 6 | + |
| 7 | +static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| { |
| 8 | + static QUERY_STR: &str = r#" |
| 9 | +[ |
| 10 | + (field |
| 11 | + (identifier)) @reference |
| 12 | + (field |
| 13 | + (object_reference) |
| 14 | + "." (identifier)) @reference |
| 15 | + (parameter) @parameter |
| 16 | +] |
| 17 | +"#; |
| 18 | + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") |
| 19 | +}); |
| 20 | + |
| 21 | +#[derive(Debug)] |
| 22 | +pub struct ParameterMatch<'a> { |
| 23 | + pub(crate) node: tree_sitter::Node<'a>, |
| 24 | +} |
| 25 | + |
| 26 | +impl ParameterMatch<'_> { |
| 27 | + pub fn get_path(&self, sql: &str) -> String { |
| 28 | + self.node |
| 29 | + .utf8_text(sql.as_bytes()) |
| 30 | + .expect("Failed to get path from ParameterMatch") |
| 31 | + .to_string() |
| 32 | + } |
| 33 | + |
| 34 | + pub fn get_range(&self) -> tree_sitter::Range { |
| 35 | + self.node.range() |
| 36 | + } |
| 37 | + |
| 38 | + pub fn get_byte_range(&self) -> std::ops::Range<usize> { |
| 39 | + let range = self.node.range(); |
| 40 | + range.start_byte..range.end_byte |
| 41 | + } |
| 42 | +} |
| 43 | + |
| 44 | +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a ParameterMatch<'a> { |
| 45 | + type Error = String; |
| 46 | + |
| 47 | + fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> { |
| 48 | + match q { |
| 49 | + QueryResult::Parameter(r) => Ok(r), |
| 50 | + |
| 51 | + #[allow(unreachable_patterns)] |
| 52 | + _ => Err("Invalid QueryResult type".into()), |
| 53 | + } |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +impl<'a> QueryTryFrom<'a> for ParameterMatch<'a> { |
| 58 | + type Ref = &'a ParameterMatch<'a>; |
| 59 | +} |
| 60 | + |
| 61 | +impl<'a> Query<'a> for ParameterMatch<'a> { |
| 62 | + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> { |
| 63 | + let mut cursor = tree_sitter::QueryCursor::new(); |
| 64 | + |
| 65 | + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); |
| 66 | + |
| 67 | + matches |
| 68 | + .filter_map(|m| { |
| 69 | + let captures = m.captures; |
| 70 | + |
| 71 | + // We expect exactly one capture for a parameter |
| 72 | + if captures.len() != 1 { |
| 73 | + return None; |
| 74 | + } |
| 75 | + |
| 76 | + Some(QueryResult::Parameter(ParameterMatch { |
| 77 | + node: captures[0].node, |
| 78 | + })) |
| 79 | + }) |
| 80 | + .collect() |
| 81 | + } |
| 82 | +} |
0 commit comments