@@ -9,11 +9,13 @@ use pgt_treesitter_queries::{
99use crate :: sanitization:: SanitizedCompletionParams ;
1010
1111#[ derive( Debug , PartialEq , Eq ) ]
12- pub enum ClauseType {
12+ pub enum WrappingClause < ' a > {
1313 Select ,
1414 Where ,
1515 From ,
16- Join ,
16+ Join {
17+ on_node : Option < tree_sitter:: Node < ' a > > ,
18+ } ,
1719 Update ,
1820 Delete ,
1921}
@@ -24,38 +26,6 @@ pub(crate) enum NodeText<'a> {
2426 Original ( & ' a str ) ,
2527}
2628
27- impl TryFrom < & str > for ClauseType {
28- type Error = String ;
29-
30- fn try_from ( value : & str ) -> Result < Self , Self :: Error > {
31- match value {
32- "select" => Ok ( Self :: Select ) ,
33- "where" => Ok ( Self :: Where ) ,
34- "from" => Ok ( Self :: From ) ,
35- "update" => Ok ( Self :: Update ) ,
36- "delete" => Ok ( Self :: Delete ) ,
37- "join" => Ok ( Self :: Join ) ,
38- _ => {
39- let message = format ! ( "Unimplemented ClauseType: {}" , value) ;
40-
41- // Err on tests, so we notice that we're lacking an implementation immediately.
42- if cfg ! ( test) {
43- panic ! ( "{}" , message) ;
44- }
45-
46- Err ( message)
47- }
48- }
49- }
50- }
51-
52- impl TryFrom < String > for ClauseType {
53- type Error = String ;
54- fn try_from ( value : String ) -> Result < Self , Self :: Error > {
55- Self :: try_from ( value. as_str ( ) )
56- }
57- }
58-
5929/// We can map a few nodes, such as the "update" node, to actual SQL clauses.
6030/// That gives us a lot of insight for completions.
6131/// Other nodes, such as the "relation" node, gives us less but still
@@ -127,7 +97,7 @@ pub(crate) struct CompletionContext<'a> {
12797 /// on u.id = i.user_id;
12898 /// ```
12999 pub schema_or_alias_name : Option < String > ,
130- pub wrapping_clause_type : Option < ClauseType > ,
100+ pub wrapping_clause_type : Option < WrappingClause < ' a > > ,
131101
132102 pub wrapping_node_kind : Option < WrappingNode > ,
133103
@@ -266,7 +236,9 @@ impl<'a> CompletionContext<'a> {
266236
267237 match parent_node_kind {
268238 "statement" | "subquery" => {
269- self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ;
239+ self . wrapping_clause_type =
240+ self . get_wrapping_clause_from_current_node ( current_node, & mut cursor) ;
241+
270242 self . wrapping_statement_range = Some ( parent_node. range ( ) ) ;
271243 }
272244 "invocation" => self . is_invocation = true ,
@@ -277,39 +249,21 @@ impl<'a> CompletionContext<'a> {
277249 if self . is_in_error_node {
278250 let mut next_sibling = current_node. next_named_sibling ( ) ;
279251 while let Some ( n) = next_sibling {
280- if n. kind ( ) . starts_with ( "keyword_" ) {
281- if let Some ( txt) = self . get_ts_node_content ( n) . and_then ( |txt| match txt {
282- NodeText :: Original ( txt) => Some ( txt) ,
283- NodeText :: Replaced => None ,
284- } ) {
285- match txt {
286- "where" | "update" | "select" | "delete" | "from" | "join" => {
287- self . wrapping_clause_type = txt. try_into ( ) . ok ( ) ;
288- break ;
289- }
290- _ => { }
291- }
292- } ;
252+ if let Some ( clause_type) = self . get_wrapping_clause_from_keyword_node ( n) {
253+ self . wrapping_clause_type = Some ( clause_type) ;
254+ break ;
255+ } else {
256+ next_sibling = n. next_named_sibling ( ) ;
293257 }
294- next_sibling = n. next_named_sibling ( ) ;
295258 }
296259 let mut prev_sibling = current_node. prev_named_sibling ( ) ;
297260 while let Some ( n) = prev_sibling {
298- if n. kind ( ) . starts_with ( "keyword_" ) {
299- if let Some ( txt) = self . get_ts_node_content ( n) . and_then ( |txt| match txt {
300- NodeText :: Original ( txt) => Some ( txt) ,
301- NodeText :: Replaced => None ,
302- } ) {
303- match txt {
304- "where" | "update" | "select" | "delete" | "from" | "join" => {
305- self . wrapping_clause_type = txt. try_into ( ) . ok ( ) ;
306- break ;
307- }
308- _ => { }
309- }
310- } ;
261+ if let Some ( clause_type) = self . get_wrapping_clause_from_keyword_node ( n) {
262+ self . wrapping_clause_type = Some ( clause_type) ;
263+ break ;
264+ } else {
265+ prev_sibling = n. prev_named_sibling ( ) ;
311266 }
312- prev_sibling = n. prev_named_sibling ( ) ;
313267 }
314268 }
315269
@@ -330,7 +284,8 @@ impl<'a> CompletionContext<'a> {
330284 }
331285
332286 "where" | "update" | "select" | "delete" | "from" | "join" => {
333- self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ;
287+ self . wrapping_clause_type =
288+ self . get_wrapping_clause_from_current_node ( current_node, & mut cursor) ;
334289 }
335290
336291 "relation" | "binary_expression" | "assignment" => {
@@ -353,12 +308,67 @@ impl<'a> CompletionContext<'a> {
353308 cursor. goto_first_child_for_byte ( self . position ) ;
354309 self . gather_context_from_node ( cursor, current_node) ;
355310 }
311+
312+ fn get_wrapping_clause_from_keyword_node (
313+ & self ,
314+ node : tree_sitter:: Node < ' a > ,
315+ ) -> Option < WrappingClause < ' a > > {
316+ if node. kind ( ) . starts_with ( "keyword_" ) {
317+ if let Some ( txt) = self . get_ts_node_content ( node) . and_then ( |txt| match txt {
318+ NodeText :: Original ( txt) => Some ( txt) ,
319+ NodeText :: Replaced => None ,
320+ } ) {
321+ match txt {
322+ "where" => return Some ( WrappingClause :: Where ) ,
323+ "update" => return Some ( WrappingClause :: Update ) ,
324+ "select" => return Some ( WrappingClause :: Select ) ,
325+ "delete" => return Some ( WrappingClause :: Delete ) ,
326+ "from" => return Some ( WrappingClause :: From ) ,
327+ "join" => {
328+ // TODO: not sure if we can infer it here.
329+ return Some ( WrappingClause :: Join { on_node : None } ) ;
330+ }
331+ _ => { }
332+ }
333+ } ;
334+ }
335+
336+ None
337+ }
338+
339+ fn get_wrapping_clause_from_current_node (
340+ & self ,
341+ node : tree_sitter:: Node < ' a > ,
342+ cursor : & mut tree_sitter:: TreeCursor < ' a > ,
343+ ) -> Option < WrappingClause < ' a > > {
344+ match node. kind ( ) {
345+ "where" => Some ( WrappingClause :: Where ) ,
346+ "update" => Some ( WrappingClause :: Update ) ,
347+ "select" => Some ( WrappingClause :: Select ) ,
348+ "delete" => Some ( WrappingClause :: Delete ) ,
349+ "from" => Some ( WrappingClause :: From ) ,
350+ "join" => {
351+ // sadly, we need to manually iterate over the children –
352+ // `node.child_by_field_id(..)` does not work as expected
353+ let mut on_node = None ;
354+ for child in node. children ( cursor) {
355+ // 28 is the id for "keyword_on"
356+ if child. kind_id ( ) == 28 {
357+ on_node = Some ( child) ;
358+ }
359+ }
360+ cursor. goto_parent ( ) ;
361+ Some ( WrappingClause :: Join { on_node } )
362+ }
363+ _ => None ,
364+ }
365+ }
356366}
357367
358368#[ cfg( test) ]
359369mod tests {
360370 use crate :: {
361- context:: { ClauseType , CompletionContext , NodeText } ,
371+ context:: { CompletionContext , NodeText , WrappingClause } ,
362372 sanitization:: SanitizedCompletionParams ,
363373 test_helper:: { CURSOR_POS , get_text_and_position} ,
364374 } ;
@@ -375,29 +385,41 @@ mod tests {
375385 #[ test]
376386 fn identifies_clauses ( ) {
377387 let test_cases = vec ! [
378- ( format!( "Select {}* from users;" , CURSOR_POS ) , "select" ) ,
379- ( format!( "Select * from u{};" , CURSOR_POS ) , "from" ) ,
388+ (
389+ format!( "Select {}* from users;" , CURSOR_POS ) ,
390+ WrappingClause :: Select ,
391+ ) ,
392+ (
393+ format!( "Select * from u{};" , CURSOR_POS ) ,
394+ WrappingClause :: From ,
395+ ) ,
380396 (
381397 format!( "Select {}* from users where n = 1;" , CURSOR_POS ) ,
382- "select" ,
398+ WrappingClause :: Select ,
383399 ) ,
384400 (
385401 format!( "Select * from users where {}n = 1;" , CURSOR_POS ) ,
386- "where" ,
402+ WrappingClause :: Where ,
387403 ) ,
388404 (
389405 format!( "update users set u{} = 1 where n = 2;" , CURSOR_POS ) ,
390- "update" ,
406+ WrappingClause :: Update ,
391407 ) ,
392408 (
393409 format!( "update users set u = 1 where n{} = 2;" , CURSOR_POS ) ,
394- "where" ,
410+ WrappingClause :: Where ,
411+ ) ,
412+ (
413+ format!( "delete{} from users;" , CURSOR_POS ) ,
414+ WrappingClause :: Delete ,
415+ ) ,
416+ (
417+ format!( "delete from {}users;" , CURSOR_POS ) ,
418+ WrappingClause :: From ,
395419 ) ,
396- ( format!( "delete{} from users;" , CURSOR_POS ) , "delete" ) ,
397- ( format!( "delete from {}users;" , CURSOR_POS ) , "from" ) ,
398420 (
399421 format!( "select name, age, location from public.u{}sers" , CURSOR_POS ) ,
400- "from" ,
422+ WrappingClause :: From ,
401423 ) ,
402424 ] ;
403425
@@ -415,7 +437,7 @@ mod tests {
415437
416438 let ctx = CompletionContext :: new ( & params) ;
417439
418- assert_eq ! ( ctx. wrapping_clause_type, expected_clause . try_into ( ) . ok ( ) ) ;
440+ assert_eq ! ( ctx. wrapping_clause_type, Some ( expected_clause ) ) ;
419441 }
420442 }
421443
@@ -518,7 +540,7 @@ mod tests {
518540
519541 assert_eq ! (
520542 ctx. wrapping_clause_type,
521- Some ( crate :: context:: ClauseType :: Select )
543+ Some ( crate :: context:: WrappingClause :: Select )
522544 ) ;
523545 }
524546 }
@@ -596,6 +618,6 @@ mod tests {
596618 ctx. get_ts_node_content( node) ,
597619 Some ( NodeText :: Original ( "fro" ) )
598620 ) ;
599- assert_eq ! ( ctx. wrapping_clause_type, Some ( ClauseType :: Select ) ) ;
621+ assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
600622 }
601623}
0 commit comments