1- use pgt_console:: {
2- fmt:: { Formatter , HTML } ,
3- markup,
4- } ;
5- use pgt_diagnostics:: PrintDiagnostic ;
6- use pgt_typecheck:: { TypecheckParams , check_sql} ;
1+ use pgt_console:: fmt:: { Formatter , HTML } ;
2+ use pgt_diagnostics:: Diagnostic ;
3+ use pgt_typecheck:: { IdentifierType , TypecheckParams , TypedIdentifier , check_sql} ;
74use sqlx:: { Executor , PgPool } ;
5+ use std:: fmt:: Write ;
86
9- async fn test ( name : & str , query : & str , setup : Option < & str > , test_db : & PgPool ) {
10- if let Some ( setup ) = setup {
11- test_db
12- . execute ( setup )
13- . await
14- . expect ( "Failed to setup test database" ) ;
15- }
7+ struct TestSetup < ' a > {
8+ name : & ' a str ,
9+ query : & ' a str ,
10+ setup : Option < & ' a str > ,
11+ test_db : & ' a PgPool ,
12+ typed_identifiers : Vec < TypedIdentifier > ,
13+ }
1614
17- let mut parser = tree_sitter:: Parser :: new ( ) ;
18- parser
19- . set_language ( & pgt_treesitter_grammar:: LANGUAGE . into ( ) )
20- . expect ( "Error loading sql language" ) ;
21-
22- let schema_cache = pgt_schema_cache:: SchemaCache :: load ( test_db)
23- . await
24- . expect ( "Failed to load Schema Cache" ) ;
25-
26- let root = pgt_query:: parse ( query)
27- . unwrap ( )
28- . into_root ( )
29- . expect ( "Failed to parse query" ) ;
30- let tree = parser. parse ( query, None ) . unwrap ( ) ;
31-
32- let conn = & test_db;
33- let result = check_sql ( TypecheckParams {
34- conn,
35- sql : query,
36- ast : & root,
37- tree : & tree,
38- schema_cache : & schema_cache,
39- search_path_patterns : vec ! [ ] ,
40- identifiers : vec ! [ ] ,
41- } )
42- . await ;
15+ impl TestSetup < ' _ > {
16+ async fn test ( self ) {
17+ if let Some ( setup) = self . setup {
18+ self . test_db
19+ . execute ( setup)
20+ . await
21+ . expect ( "Failed to setup test selfbase" ) ;
22+ }
4323
44- let mut content = vec ! [ ] ;
45- let mut writer = HTML :: new ( & mut content) ;
24+ let mut parser = tree_sitter:: Parser :: new ( ) ;
25+ parser
26+ . set_language ( & pgt_treesitter_grammar:: LANGUAGE . into ( ) )
27+ . expect ( "Error loading sql language" ) ;
4628
47- Formatter :: new ( & mut writer)
48- . write_markup ( markup ! {
49- { PrintDiagnostic :: simple( & result. unwrap( ) . unwrap( ) ) }
29+ let schema_cache = pgt_schema_cache:: SchemaCache :: load ( self . test_db )
30+ . await
31+ . expect ( "Failed to load Schema Cache" ) ;
32+
33+ let root = pgt_query:: parse ( self . query )
34+ . unwrap ( )
35+ . into_root ( )
36+ . expect ( "Failed to parse query" ) ;
37+ let tree = parser. parse ( self . query , None ) . unwrap ( ) ;
38+
39+ let result = check_sql ( TypecheckParams {
40+ conn : self . test_db ,
41+ sql : self . query ,
42+ ast : & root,
43+ tree : & tree,
44+ schema_cache : & schema_cache,
45+ identifiers : self . typed_identifiers ,
46+ search_path_patterns : vec ! [ ] ,
5047 } )
51- . unwrap ( ) ;
48+ . await ;
49+
50+ assert ! (
51+ result. is_ok( ) ,
52+ "Got Typechecking error: {}" ,
53+ result. unwrap_err( )
54+ ) ;
5255
53- let content = String :: from_utf8 ( content ) . unwrap ( ) ;
56+ let maybe_diagnostic = result . unwrap ( ) ;
5457
55- insta:: with_settings!( {
56- prepend_module_to_snapshot => false ,
57- } , {
58- insta:: assert_snapshot!( name, content) ;
59- } ) ;
58+ let content = match maybe_diagnostic {
59+ Some ( d) => {
60+ let mut result = String :: new ( ) ;
61+
62+ if let Some ( span) = d. location ( ) . span {
63+ for ( idx, c) in self . query . char_indices ( ) {
64+ if pgt_text_size:: TextSize :: new ( idx. try_into ( ) . unwrap ( ) ) == span. start ( ) {
65+ result. push_str ( "~~~" ) ;
66+ }
67+ if pgt_text_size:: TextSize :: new ( idx. try_into ( ) . unwrap ( ) ) == span. end ( ) {
68+ result. push_str ( "~~~" ) ;
69+ }
70+ result. push ( c) ;
71+ }
72+ } else {
73+ result. push_str ( "~~~" ) ;
74+ result. push_str ( self . query ) ;
75+ result. push_str ( "~~~" ) ;
76+ }
77+
78+ writeln ! ( & mut result) . unwrap ( ) ;
79+ writeln ! ( & mut result) . unwrap ( ) ;
80+
81+ let mut msg_content = vec ! [ ] ;
82+ let mut writer = HTML :: new ( & mut msg_content) ;
83+ let mut formatter = Formatter :: new ( & mut writer) ;
84+ d. message ( & mut formatter) . unwrap ( ) ;
85+
86+ result. push_str ( String :: from_utf8 ( msg_content) . unwrap ( ) . as_str ( ) ) ;
87+
88+ result
89+ }
90+ None => String :: from ( "No Diagnostic" ) ,
91+ } ;
92+
93+ insta:: with_settings!( {
94+ prepend_module_to_snapshot => false ,
95+ } , {
96+ insta:: assert_snapshot!( self . name, content) ;
97+
98+ } ) ;
99+ }
60100}
61101
62102#[ sqlx:: test( migrator = "pgt_test_utils::MIGRATIONS" ) ]
63- async fn invalid_column ( pool : PgPool ) {
64- test (
65- "invalid_column" ,
66- "select id, unknown from contacts;" ,
67- Some (
103+ async fn invalid_column ( test_db : PgPool ) {
104+ TestSetup {
105+ name : "invalid_column" ,
106+ query : "select id, unknown from contacts;" ,
107+ setup : Some (
68108 r#"
69109 create table public.contacts (
70110 id serial primary key,
@@ -74,7 +114,66 @@ async fn invalid_column(pool: PgPool) {
74114 );
75115 "# ,
76116 ) ,
77- & pool,
78- )
117+ test_db : & test_db,
118+ typed_identifiers : vec ! [ ] ,
119+ }
120+ . test ( )
121+ . await ;
122+ }
123+
124+ #[ sqlx:: test( migrator = "pgt_test_utils::MIGRATIONS" ) ]
125+ async fn invalid_type_in_function ( test_db : PgPool ) {
126+ // create or replace function clean_up(uid uuid)
127+ // returns void
128+ // language sql
129+ // as $$
130+ // delete from public.contacts where id = uid;
131+ // $$;
132+
133+ let setup = r#"
134+ create table public.contacts (
135+ id serial primary key,
136+ name text not null,
137+ is_vegetarian bool default false,
138+ middle_name varchar(255)
139+ );
140+ "# ;
141+
142+ /* NOTE: The replaced type default value is *longer* than the param name. */
143+ TestSetup {
144+ name : "invalid_type_in_function_longer_default" ,
145+ setup : Some ( setup) ,
146+ query : r#"delete from public.contacts where id = uid;"# ,
147+ test_db : & test_db,
148+ typed_identifiers : vec ! [ TypedIdentifier {
149+ path: "clean_up" . to_string( ) ,
150+ name: Some ( "uid" . to_string( ) ) ,
151+ type_: IdentifierType {
152+ schema: None ,
153+ name: "uuid" . to_string( ) ,
154+ is_array: false ,
155+ } ,
156+ } ] ,
157+ }
158+ . test ( )
159+ . await ;
160+
161+ /* NOTE: The replaced type default value is *shorter* than the param name. */
162+ TestSetup {
163+ name : "invalid_type_in_function_shorter_default" ,
164+ setup : None ,
165+ query : r#"delete from public.contacts where id = contact_name;"# ,
166+ test_db : & test_db,
167+ typed_identifiers : vec ! [ TypedIdentifier {
168+ path: "clean_up" . to_string( ) ,
169+ name: Some ( "contact_name" . to_string( ) ) ,
170+ type_: IdentifierType {
171+ schema: None ,
172+ name: "text" . to_string( ) ,
173+ is_array: false ,
174+ } ,
175+ } ] ,
176+ }
177+ . test ( )
79178 . await ;
80179}
0 commit comments