+use crate::migrations::*;
+
 use api::error::VssError;
 use api::kv_store::{KvStore, GLOBAL_VERSION_KEY, INITIAL_RECORD_VERSION};
 use api::types::{
@@ -12,7 +14,7 @@ use chrono::Utc;
 use std::cmp::min;
 use std::io;
 use std::io::{Error, ErrorKind};
-use tokio_postgres::{NoTls, Transaction};
+use tokio_postgres::{error, NoTls, Transaction};
 
 pub(crate) struct VssDbRecord {
 	pub(crate) user_token: String,
@@ -46,17 +48,189 @@ pub struct PostgresBackendImpl {
 	pool: Pool<PostgresConnectionManager<NoTls>>,
 }
 
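+// Connects to the default `postgres` maintenance database at `postgres_endpoint` and
+// creates `db_name` there if it does not already exist.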
+async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
+	let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
+	let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
+		.await
+		.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
+	// Connection must be driven on a separate task, and will resolve when the client is dropped.
+	tokio::spawn(async move {
+		if let Err(e) = connection.await {
+			eprintln!("Connection error: {}", e);
+		}
+	});
+
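+	// CHECK_DB_STMT (pulled in via `crate::migrations`) looks up `db_name`; zero returned
+	// rows means the database does not exist yet and is created below.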
+	let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
+		Error::new(
+			ErrorKind::Other,
+			format!("Failed to check presence of database {}: {}", db_name, e),
+		)
+	})?;
+
+	if num_rows == 0 {
+		let stmt = format!("{} {};", INIT_DB_CMD, db_name);
+		client.execute(&stmt, &[]).await.map_err(|e| {
+			Error::new(ErrorKind::Other, format!("Failed to create database {}: {}", db_name, e))
+		})?;
+		println!("Created database {}", db_name);
+	}
+
+	Ok(())
+}
+
+#[cfg(test)]
+async fn drop_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
+	let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
+	let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
+		.await
+		.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
+	// Connection must be driven on a separate task, and will resolve when the client is dropped.
+	tokio::spawn(async move {
+		if let Err(e) = connection.await {
+			eprintln!("Connection error: {}", e);
+		}
+	});
+
+	let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name);
+	let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
+		Error::new(
+			ErrorKind::Other,
+			format!("Failed to drop database {}: {}", db_name, e),
+		)
+	})?;
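+	// Dropping a database modifies no rows, hence the zero row count asserted below.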
+	assert_eq!(num_rows, 0);
+
+	Ok(())
+}
+
 impl PostgresBackendImpl {
-	/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
+	/// Constructs a [`PostgresBackendImpl`] using `postgres_endpoint` and `db_name` for
+	/// PostgreSQL connection information.
-	pub async fn new(dsn: &str) -> Result<Self, Error> {
-		let manager = PostgresConnectionManager::new_from_stringlike(dsn, NoTls).map_err(|e| {
-			Error::new(ErrorKind::Other, format!("Connection manager error: {}", e))
-		})?;
+	pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
+		initialize_vss_database(postgres_endpoint, db_name).await?;
+
+		let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
+		let manager =
+			PostgresConnectionManager::new_from_stringlike(vss_dsn, NoTls).map_err(|e| {
+				Error::new(
+					ErrorKind::Other,
+					format!("Failed to create PostgresConnectionManager: {}", e),
+				)
+			})?;
+		// By default, Pool maintains 0 long-running connections, so returning a pool
+		// here is no guarantee that Pool established a connection to the database.
+		//
+		// See Builder::min_idle to increase the long-running connection count.
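+		// Illustrative only (not part of this change): a floor of warm connections could be
+		// requested with, e.g., `Pool::builder().min_idle(Some(1)).build(manager)`.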
 		let pool = Pool::builder()
 			.build(manager)
 			.await
-			.map_err(|e| Error::new(ErrorKind::Other, format!("Pool build error: {}", e)))?;
-		Ok(PostgresBackendImpl { pool })
+			.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
+		let postgres_backend = PostgresBackendImpl { pool };
+
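+		// Under `cfg(test)` migrations are applied explicitly by the tests below, so that the
+		// migration logic itself can be exercised; otherwise they run automatically here.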
+		#[cfg(not(test))]
+		postgres_backend.migrate_vss_database(MIGRATIONS).await?;
+
+		Ok(postgres_backend)
+	}
+
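+	// Applies any not-yet-applied entries of `migrations` within a single transaction and
+	// returns `(first_applied_index, new_schema_version)`; when the schema is already up to
+	// date, both values equal `migrations.len()`.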
+	async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> {
+		let mut conn = self.pool.get().await.map_err(|e| {
+			Error::new(
+				ErrorKind::Other,
+				format!("Failed to fetch a connection from Pool: {}", e),
+			)
+		})?;
+
+		// Get the next migration to be applied.
+		let migration_start = match conn.query_one(GET_VERSION_STMT, &[]).await {
+			Ok(row) => {
+				let i: i32 = row.get(DB_VERSION_COLUMN);
+				usize::try_from(i).expect("The column should always contain unsigned integers")
+			},
+			Err(e) => {
+				// If the table is not defined, start at migration 0.
+				if let Some(&error::SqlState::UNDEFINED_TABLE) = e.code() {
+					0
+				} else {
+					return Err(Error::new(
+						ErrorKind::Other,
+						format!("Failed to query the version of the database schema: {}", e),
+					));
+				}
+			},
+		};
+
+		let tx = conn
+			.transaction()
+			.await
+			.map_err(|e| Error::new(ErrorKind::Other, format!("Transaction start error: {}", e)))?;
+
+		if migration_start == migrations.len() {
+			// No migrations needed, we are done.
+			return Ok((migration_start, migrations.len()));
+		} else if migration_start > migrations.len() {
+			panic!("We do not allow downgrades");
+		}
+
+		println!("Applying migration(s) {} through {}", migration_start, migrations.len() - 1);
+
+		for (idx, &stmt) in (&migrations[migration_start..]).iter().enumerate() {
+			let _num_rows = tx.execute(stmt, &[]).await.map_err(|e| {
+				Error::new(
+					ErrorKind::Other,
+					format!(
+						"Database migration no {} with stmt {} failed: {}",
+						migration_start + idx,
+						stmt,
+						e
+					),
+				)
+			})?;
+		}
+
+		let num_rows = tx
+			.execute(
+				LOG_MIGRATION_STMT,
+				&[&i32::try_from(migration_start).expect("Read from an i32 further above")],
+			)
+			.await
+			.map_err(|e| {
+				Error::new(ErrorKind::Other, format!("Failed to log database migration: {}", e))
+			})?;
+		assert_eq!(num_rows, 1, "LOG_MIGRATION_STMT should only add one row at a time");
+
+		let next_migration_start =
+			i32::try_from(migrations.len()).expect("Length is definitely smaller than i32::MAX");
+		let num_rows =
+			tx.execute(UPDATE_VERSION_STMT, &[&next_migration_start]).await.map_err(|e| {
+				Error::new(
+					ErrorKind::Other,
+					format!("Failed to update the version of the schema: {}", e),
+				)
+			})?;
+		assert_eq!(
+			num_rows, 1,
+			"UPDATE_VERSION_STMT should only update the unique row in the version table"
+		);
+
+		tx.commit().await.map_err(|e| {
+			Error::new(ErrorKind::Other, format!("Transaction commit error: {}", e))
+		})?;
+
+		Ok((migration_start, migrations.len()))
+	}
+
+	#[cfg(test)]
+	async fn get_schema_version(&self) -> usize {
+		let conn = self.pool.get().await.unwrap();
+		let row = conn.query_one(GET_VERSION_STMT, &[]).await.unwrap();
+		usize::try_from(row.get::<&str, i32>(DB_VERSION_COLUMN)).unwrap()
+	}
+
+	#[cfg(test)]
+	async fn get_upgrades_list(&self) -> Vec<usize> {
+		let conn = self.pool.get().await.unwrap();
+		let rows = conn.query(GET_MIGRATION_LOG_STMT, &[]).await.unwrap();
+		rows.iter()
+			.map(|row| usize::try_from(row.get::<&str, i32>(MIGRATION_LOG_COLUMN)).unwrap())
+			.collect()
 	}
 
 	fn build_vss_record(&self, user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord {
@@ -409,12 +583,105 @@ impl KvStore for PostgresBackendImpl {
 mod tests {
 	use crate::postgres_store::PostgresBackendImpl;
 	use api::define_kv_store_tests;
+	use tokio::sync::OnceCell;
+	use super::{MIGRATIONS, DUMMY_MIGRATION, drop_database};
+
+	const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
+	const MIGRATIONS_START: usize = 0;
+	const MIGRATIONS_END: usize = MIGRATIONS.len();
+
+	static START: OnceCell<()> = OnceCell::const_new();
+
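+	// The generated kv-store tests all share one database; `START` ensures it is dropped
+	// and migrated from scratch exactly once before any of them runs.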
+	define_kv_store_tests!(PostgresKvStoreTest, PostgresBackendImpl, {
+		let db_name = "postgres_kv_store_tests";
+		START
+			.get_or_init(|| async {
+				let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
+				let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+				let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+				assert_eq!(start, MIGRATIONS_START);
+				assert_eq!(end, MIGRATIONS_END);
+			})
+			.await;
+		let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+		let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+		assert_eq!(start, MIGRATIONS_END);
+		assert_eq!(end, MIGRATIONS_END);
+		assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
+		assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
+		store
+	});
+
+	#[tokio::test]
+	#[should_panic(expected = "We do not allow downgrades")]
+	async fn panic_on_downgrade() {
+		let db_name = "panic_on_downgrade_test";
+		let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
+		{
+			let mut migrations = MIGRATIONS.to_vec();
+			migrations.push(DUMMY_MIGRATION);
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
+			assert_eq!(start, MIGRATIONS_START);
+			assert_eq!(end, MIGRATIONS_END + 1);
+		};
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+		};
+	}
 
-	define_kv_store_tests!(
-		PostgresKvStoreTest,
-		PostgresBackendImpl,
-		PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432/postgres")
-			.await
-			.unwrap()
-	);
+	#[tokio::test]
+	async fn new_migrations_increments_upgrades() {
+		let db_name = "new_migrations_increments_upgrades_test";
+		let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+			assert_eq!(start, MIGRATIONS_START);
+			assert_eq!(end, MIGRATIONS_END);
+			assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
+			assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
+		};
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
+			assert_eq!(start, MIGRATIONS_END);
+			assert_eq!(end, MIGRATIONS_END);
+			assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
+			assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
+		};
+
+		let mut migrations = MIGRATIONS.to_vec();
+		migrations.push(DUMMY_MIGRATION);
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
+			assert_eq!(start, MIGRATIONS_END);
+			assert_eq!(end, MIGRATIONS_END + 1);
+			assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END]);
+			assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 1);
+		};
+
+		migrations.push(DUMMY_MIGRATION);
+		migrations.push(DUMMY_MIGRATION);
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
+			assert_eq!(start, MIGRATIONS_END + 1);
+			assert_eq!(end, MIGRATIONS_END + 3);
+			assert_eq!(
+				store.get_upgrades_list().await,
+				[MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]
+			);
+			assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 3);
+		};
+
+		{
+			let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
+			let list = store.get_upgrades_list().await;
+			assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
+			let version = store.get_schema_version().await;
+			assert_eq!(version, MIGRATIONS_END + 3);
+		}
+
+		drop_database(POSTGRES_ENDPOINT, db_name).await.unwrap();
+	}
 }