1+ mod policy_parser;
2+
13use std:: collections:: { HashMap , HashSet } ;
24
35use pgt_schema_cache:: SchemaCache ;
6+ use pgt_text_size:: TextRange ;
47use pgt_treesitter_queries:: {
58 TreeSitterQueriesExecutor ,
69 queries:: { self , QueryResult } ,
710} ;
811
9- use crate :: sanitization:: SanitizedCompletionParams ;
12+ use crate :: {
13+ NodeText ,
14+ context:: policy_parser:: { PolicyParser , PolicyStmtKind } ,
15+ sanitization:: SanitizedCompletionParams ,
16+ } ;
1017
1118#[ derive( Debug , PartialEq , Eq , Hash ) ]
1219pub enum WrappingClause < ' a > {
@@ -18,12 +25,8 @@ pub enum WrappingClause<'a> {
1825 } ,
1926 Update ,
2027 Delete ,
21- }
22-
23- #[ derive( PartialEq , Eq , Debug ) ]
24- pub ( crate ) enum NodeText < ' a > {
25- Replaced ,
26- Original ( & ' a str ) ,
28+ PolicyName ,
29+ ToRoleAssignment ,
2730}
2831
2932#[ derive( PartialEq , Eq , Hash , Debug ) ]
@@ -47,6 +50,45 @@ pub enum WrappingNode {
4750 Assignment ,
4851}
4952
53+ #[ derive( Debug ) ]
54+ pub ( crate ) enum NodeUnderCursor < ' a > {
55+ TsNode ( tree_sitter:: Node < ' a > ) ,
56+ CustomNode {
57+ text : NodeText ,
58+ range : TextRange ,
59+ kind : String ,
60+ } ,
61+ }
62+
63+ impl NodeUnderCursor < ' _ > {
64+ pub fn start_byte ( & self ) -> usize {
65+ match self {
66+ NodeUnderCursor :: TsNode ( node) => node. start_byte ( ) ,
67+ NodeUnderCursor :: CustomNode { range, .. } => range. start ( ) . into ( ) ,
68+ }
69+ }
70+
71+ pub fn end_byte ( & self ) -> usize {
72+ match self {
73+ NodeUnderCursor :: TsNode ( node) => node. end_byte ( ) ,
74+ NodeUnderCursor :: CustomNode { range, .. } => range. end ( ) . into ( ) ,
75+ }
76+ }
77+
78+ pub fn kind ( & self ) -> & str {
79+ match self {
80+ NodeUnderCursor :: TsNode ( node) => node. kind ( ) ,
81+ NodeUnderCursor :: CustomNode { kind, .. } => kind. as_str ( ) ,
82+ }
83+ }
84+ }
85+
86+ impl < ' a > From < tree_sitter:: Node < ' a > > for NodeUnderCursor < ' a > {
87+ fn from ( node : tree_sitter:: Node < ' a > ) -> Self {
88+ NodeUnderCursor :: TsNode ( node)
89+ }
90+ }
91+
5092impl TryFrom < & str > for WrappingNode {
5193 type Error = String ;
5294
@@ -77,7 +119,7 @@ impl TryFrom<String> for WrappingNode {
77119}
78120
79121pub ( crate ) struct CompletionContext < ' a > {
80- pub node_under_cursor : Option < tree_sitter :: Node < ' a > > ,
122+ pub node_under_cursor : Option < NodeUnderCursor < ' a > > ,
81123
82124 pub tree : & ' a tree_sitter:: Tree ,
83125 pub text : & ' a str ,
@@ -137,12 +179,49 @@ impl<'a> CompletionContext<'a> {
137179 is_in_error_node : false ,
138180 } ;
139181
140- ctx. gather_tree_context ( ) ;
141- ctx. gather_info_from_ts_queries ( ) ;
182+ // policy handling is important to Supabase, but they are a PostgreSQL specific extension,
183+ // so the tree_sitter_sql language does not support it.
184+ // We infer the context manually.
185+ if PolicyParser :: looks_like_policy_stmt ( & params. text ) {
186+ ctx. gather_policy_context ( ) ;
187+ } else {
188+ ctx. gather_tree_context ( ) ;
189+ ctx. gather_info_from_ts_queries ( ) ;
190+ }
142191
143192 ctx
144193 }
145194
195+ fn gather_policy_context ( & mut self ) {
196+ let policy_context = PolicyParser :: get_context ( self . text , self . position ) ;
197+
198+ self . node_under_cursor = Some ( NodeUnderCursor :: CustomNode {
199+ text : policy_context. node_text . into ( ) ,
200+ range : policy_context. node_range ,
201+ kind : policy_context. node_kind . clone ( ) ,
202+ } ) ;
203+
204+ if policy_context. node_kind == "policy_table" {
205+ self . schema_or_alias_name = policy_context. schema_name . clone ( ) ;
206+ }
207+
208+ if policy_context. table_name . is_some ( ) {
209+ let mut new = HashSet :: new ( ) ;
210+ new. insert ( policy_context. table_name . unwrap ( ) ) ;
211+ self . mentioned_relations
212+ . insert ( policy_context. schema_name , new) ;
213+ }
214+
215+ self . wrapping_clause_type = match policy_context. node_kind . as_str ( ) {
216+ "policy_name" if policy_context. statement_kind != PolicyStmtKind :: Create => {
217+ Some ( WrappingClause :: PolicyName )
218+ }
219+ "policy_role" => Some ( WrappingClause :: ToRoleAssignment ) ,
220+ "policy_table" => Some ( WrappingClause :: From ) ,
221+ _ => None ,
222+ } ;
223+ }
224+
146225 fn gather_info_from_ts_queries ( & mut self ) {
147226 let stmt_range = self . wrapping_statement_range . as_ref ( ) ;
148227 let sql = self . text ;
@@ -195,24 +274,30 @@ impl<'a> CompletionContext<'a> {
195274 }
196275 }
197276
198- pub fn get_ts_node_content ( & self , ts_node : tree_sitter:: Node < ' a > ) -> Option < NodeText < ' a > > {
277+ fn get_ts_node_content ( & self , ts_node : & tree_sitter:: Node < ' a > ) -> Option < NodeText > {
199278 let source = self . text ;
200279 ts_node. utf8_text ( source. as_bytes ( ) ) . ok ( ) . map ( |txt| {
201280 if SanitizedCompletionParams :: is_sanitized_token ( txt) {
202281 NodeText :: Replaced
203282 } else {
204- NodeText :: Original ( txt)
283+ NodeText :: Original ( txt. into ( ) )
205284 }
206285 } )
207286 }
208287
209288 pub fn get_node_under_cursor_content ( & self ) -> Option < String > {
210- self . node_under_cursor
211- . and_then ( |n| self . get_ts_node_content ( n) )
212- . and_then ( |txt| match txt {
289+ match self . node_under_cursor . as_ref ( ) ? {
290+ NodeUnderCursor :: TsNode ( node) => {
291+ self . get_ts_node_content ( node) . and_then ( |nt| match nt {
292+ NodeText :: Replaced => None ,
293+ NodeText :: Original ( c) => Some ( c. to_string ( ) ) ,
294+ } )
295+ }
296+ NodeUnderCursor :: CustomNode { text, .. } => match text {
213297 NodeText :: Replaced => None ,
214298 NodeText :: Original ( c) => Some ( c. to_string ( ) ) ,
215- } )
299+ } ,
300+ }
216301 }
217302
218303 fn gather_tree_context ( & mut self ) {
@@ -250,7 +335,7 @@ impl<'a> CompletionContext<'a> {
250335
251336 // prevent infinite recursion – this can happen if we only have a PROGRAM node
252337 if current_node_kind == parent_node_kind {
253- self . node_under_cursor = Some ( current_node) ;
338+ self . node_under_cursor = Some ( NodeUnderCursor :: from ( current_node) ) ;
254339 return ;
255340 }
256341
@@ -289,7 +374,7 @@ impl<'a> CompletionContext<'a> {
289374
290375 match current_node_kind {
291376 "object_reference" | "field" => {
292- let content = self . get_ts_node_content ( current_node) ;
377+ let content = self . get_ts_node_content ( & current_node) ;
293378 if let Some ( node_txt) = content {
294379 match node_txt {
295380 NodeText :: Original ( txt) => {
@@ -321,7 +406,7 @@ impl<'a> CompletionContext<'a> {
321406
322407 // We have arrived at the leaf node
323408 if current_node. child_count ( ) == 0 {
324- self . node_under_cursor = Some ( current_node) ;
409+ self . node_under_cursor = Some ( NodeUnderCursor :: from ( current_node) ) ;
325410 return ;
326411 }
327412
@@ -334,11 +419,11 @@ impl<'a> CompletionContext<'a> {
334419 node : tree_sitter:: Node < ' a > ,
335420 ) -> Option < WrappingClause < ' a > > {
336421 if node. kind ( ) . starts_with ( "keyword_" ) {
337- if let Some ( txt) = self . get_ts_node_content ( node) . and_then ( |txt| match txt {
422+ if let Some ( txt) = self . get_ts_node_content ( & node) . and_then ( |txt| match txt {
338423 NodeText :: Original ( txt) => Some ( txt) ,
339424 NodeText :: Replaced => None ,
340425 } ) {
341- match txt {
426+ match txt. as_str ( ) {
342427 "where" => return Some ( WrappingClause :: Where ) ,
343428 "update" => return Some ( WrappingClause :: Update ) ,
344429 "select" => return Some ( WrappingClause :: Select ) ,
@@ -388,11 +473,14 @@ impl<'a> CompletionContext<'a> {
388473#[ cfg( test) ]
389474mod tests {
390475 use crate :: {
391- context:: { CompletionContext , NodeText , WrappingClause } ,
476+ NodeText ,
477+ context:: { CompletionContext , WrappingClause } ,
392478 sanitization:: SanitizedCompletionParams ,
393479 test_helper:: { CURSOR_POS , get_text_and_position} ,
394480 } ;
395481
482+ use super :: NodeUnderCursor ;
483+
396484 fn get_tree ( input : & str ) -> tree_sitter:: Tree {
397485 let mut parser = tree_sitter:: Parser :: new ( ) ;
398486 parser
@@ -551,17 +639,22 @@ mod tests {
551639
552640 let ctx = CompletionContext :: new ( & params) ;
553641
554- let node = ctx. node_under_cursor . unwrap ( ) ;
642+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
555643
556- assert_eq ! (
557- ctx. get_ts_node_content( node) ,
558- Some ( NodeText :: Original ( "select" ) )
559- ) ;
644+ match node {
645+ NodeUnderCursor :: TsNode ( node) => {
646+ assert_eq ! (
647+ ctx. get_ts_node_content( node) ,
648+ Some ( NodeText :: Original ( "select" . into( ) ) )
649+ ) ;
560650
561- assert_eq ! (
562- ctx. wrapping_clause_type,
563- Some ( crate :: context:: WrappingClause :: Select )
564- ) ;
651+ assert_eq ! (
652+ ctx. wrapping_clause_type,
653+ Some ( crate :: context:: WrappingClause :: Select )
654+ ) ;
655+ }
656+ _ => unreachable ! ( ) ,
657+ }
565658 }
566659 }
567660
@@ -582,12 +675,17 @@ mod tests {
582675
583676 let ctx = CompletionContext :: new ( & params) ;
584677
585- let node = ctx. node_under_cursor . unwrap ( ) ;
678+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
586679
587- assert_eq ! (
588- ctx. get_ts_node_content( node) ,
589- Some ( NodeText :: Original ( "from" ) )
590- ) ;
680+ match node {
681+ NodeUnderCursor :: TsNode ( node) => {
682+ assert_eq ! (
683+ ctx. get_ts_node_content( node) ,
684+ Some ( NodeText :: Original ( "from" . into( ) ) )
685+ ) ;
686+ }
687+ _ => unreachable ! ( ) ,
688+ }
591689 }
592690
593691 #[ test]
@@ -607,10 +705,18 @@ mod tests {
607705
608706 let ctx = CompletionContext :: new ( & params) ;
609707
610- let node = ctx. node_under_cursor . unwrap ( ) ;
708+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
611709
612- assert_eq ! ( ctx. get_ts_node_content( node) , Some ( NodeText :: Original ( "" ) ) ) ;
613- assert_eq ! ( ctx. wrapping_clause_type, None ) ;
710+ match node {
711+ NodeUnderCursor :: TsNode ( node) => {
712+ assert_eq ! (
713+ ctx. get_ts_node_content( node) ,
714+ Some ( NodeText :: Original ( "" . into( ) ) )
715+ ) ;
716+ assert_eq ! ( ctx. wrapping_clause_type, None ) ;
717+ }
718+ _ => unreachable ! ( ) ,
719+ }
614720 }
615721
616722 #[ test]
@@ -632,12 +738,17 @@ mod tests {
632738
633739 let ctx = CompletionContext :: new ( & params) ;
634740
635- let node = ctx. node_under_cursor . unwrap ( ) ;
741+ let node = ctx. node_under_cursor . as_ref ( ) . unwrap ( ) ;
636742
637- assert_eq ! (
638- ctx. get_ts_node_content( node) ,
639- Some ( NodeText :: Original ( "fro" ) )
640- ) ;
641- assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
743+ match node {
744+ NodeUnderCursor :: TsNode ( node) => {
745+ assert_eq ! (
746+ ctx. get_ts_node_content( node) ,
747+ Some ( NodeText :: Original ( "fro" . into( ) ) )
748+ ) ;
749+ assert_eq ! ( ctx. wrapping_clause_type, Some ( WrappingClause :: Select ) ) ;
750+ }
751+ _ => unreachable ! ( ) ,
752+ }
642753 }
643754}
0 commit comments