@@ -28,12 +28,18 @@ const VALUE_COLUMN: &str = "value";
2828const VERSION_COLUMN : & str = "version" ;
2929
3030const DB_VERSION_COLUMN : & str = "db_version" ;
31+ #[ cfg( test) ]
32+ const MIGRATION_LOG_COLUMN : & str = "upgrade_from" ;
3133
3234const CHECK_DB_STMT : & str = "SELECT 1 FROM pg_database WHERE datname = $1" ;
3335const INIT_DB_CMD : & str = "CREATE DATABASE" ;
36+ #[ cfg( test) ]
37+ const DROP_DB_CMD : & str = "DROP DATABASE" ;
3438const GET_VERSION_STMT : & str = "SELECT db_version FROM vss_db_version;" ;
3539const UPDATE_VERSION_STMT : & str = "UPDATE vss_db_version SET db_version=$1;" ;
3640const LOG_MIGRATION_STMT : & str = "INSERT INTO vss_db_upgrades VALUES($1);" ;
41+ #[ cfg( test) ]
42+ const GET_MIGRATION_LOG_STMT : & str = "SELECT upgrade_from FROM vss_db_upgrades;" ;
3743
3844// APPEND-ONLY list of migration statements
3945//
@@ -58,6 +64,8 @@ const MIGRATIONS: &[&str] = &[
5864 PRIMARY KEY (user_token, store_id, key)
5965 );" ,
6066] ;
67+ #[ cfg( test) ]
68+ const DUMMY_MIGRATION : & str = "SELECT 1 WHERE FALSE;" ;
6169
6270/// The maximum number of key versions that can be returned in a single page.
6371///
@@ -108,6 +116,31 @@ async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Resu
108116 Ok ( ( ) )
109117}
110118
119+ #[ cfg( test) ]
120+ async fn drop_database ( postgres_endpoint : & str , db_name : & str ) -> Result < ( ) , Error > {
121+ let postgres_dsn = format ! ( "{}/{}" , postgres_endpoint, "postgres" ) ;
122+ let ( client, connection) = tokio_postgres:: connect ( & postgres_dsn, NoTls )
123+ . await
124+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
125+ // Connection must be driven on a separate task, and will resolve when the client is dropped
126+ tokio:: spawn ( async move {
127+ if let Err ( e) = connection. await {
128+ eprintln ! ( "Connection error: {}" , e) ;
129+ }
130+ } ) ;
131+
132+ let drop_database_statement = format ! ( "{} {};" , DROP_DB_CMD , db_name) ;
133+ let num_rows = client. execute ( & drop_database_statement, & [ ] ) . await . map_err ( |e| {
134+ Error :: new (
135+ ErrorKind :: Other ,
136+ format ! ( "Failed to drop database {}: {}" , db_name, e) ,
137+ )
138+ } ) ?;
139+ assert_eq ! ( num_rows, 0 ) ;
140+
141+ Ok ( ( ) )
142+ }
143+
111144impl PostgresBackendImpl {
112145 /// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
113146 pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
@@ -131,12 +164,13 @@ impl PostgresBackendImpl {
131164 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Failed to build Pool: {}" , e) ) ) ?;
132165 let postgres_backend = PostgresBackendImpl { pool } ;
133166
134- postgres_backend. migrate_vss_database ( ) . await ?;
167+ #[ cfg( not( test) ) ]
168+ postgres_backend. migrate_vss_database ( MIGRATIONS ) . await ?;
135169
136170 Ok ( postgres_backend)
137171 }
138172
139- async fn migrate_vss_database ( & self ) -> Result < ( ) , Error > {
173+ async fn migrate_vss_database ( & self , migrations : & [ & str ] ) -> Result < ( usize , usize ) , Error > {
140174 let mut conn = self . pool . get ( ) . await . map_err ( |e| {
141175 Error :: new (
142176 ErrorKind :: Other ,
@@ -168,16 +202,16 @@ impl PostgresBackendImpl {
168202 . await
169203 . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Transaction start error: {}" , e) ) ) ?;
170204
171- if migration_start == MIGRATIONS . len ( ) {
205+ if migration_start == migrations . len ( ) {
172206 // No migrations needed, we are done
173- return Ok ( ( ) ) ;
174- } else if migration_start > MIGRATIONS . len ( ) {
207+ return Ok ( ( migration_start , migrations . len ( ) ) ) ;
208+ } else if migration_start > migrations . len ( ) {
175209 panic ! ( "We do not allow downgrades" ) ;
176210 }
177211
178- println ! ( "Applying migration(s) {} through {}" , migration_start, MIGRATIONS . len( ) - 1 ) ;
212+ println ! ( "Applying migration(s) {} through {}" , migration_start, migrations . len( ) - 1 ) ;
179213
180- for ( idx, & stmt) in ( & MIGRATIONS [ migration_start..] ) . iter ( ) . enumerate ( ) {
214+ for ( idx, & stmt) in ( & migrations [ migration_start..] ) . iter ( ) . enumerate ( ) {
181215 let _num_rows = tx. execute ( stmt, & [ ] ) . await . map_err ( |e| {
182216 Error :: new (
183217 ErrorKind :: Other ,
@@ -203,7 +237,7 @@ impl PostgresBackendImpl {
203237 assert_eq ! ( num_rows, 1 , "LOG_MIGRATION_STMT should only add one row at a time" ) ;
204238
205239 let next_migration_start =
206- i32:: try_from ( MIGRATIONS . len ( ) ) . expect ( "Length is definitely smaller than i32::MAX" ) ;
240+ i32:: try_from ( migrations . len ( ) ) . expect ( "Length is definitely smaller than i32::MAX" ) ;
207241 let num_rows =
208242 tx. execute ( UPDATE_VERSION_STMT , & [ & next_migration_start] ) . await . map_err ( |e| {
209243 Error :: new (
@@ -220,7 +254,21 @@ impl PostgresBackendImpl {
220254 Error :: new ( ErrorKind :: Other , format ! ( "Transaction commit error: {}" , e) )
221255 } ) ?;
222256
223- Ok ( ( ) )
257+ Ok ( ( migration_start, migrations. len ( ) ) )
258+ }
259+
260+ #[ cfg( test) ]
261+ async fn get_schema_version ( & self ) -> usize {
262+ let conn = self . pool . get ( ) . await . unwrap ( ) ;
263+ let row = conn. query_one ( GET_VERSION_STMT , & [ ] ) . await . unwrap ( ) ;
264+ usize:: try_from ( row. get :: < & str , i32 > ( DB_VERSION_COLUMN ) ) . unwrap ( )
265+ }
266+
267+ #[ cfg( test) ]
268+ async fn get_upgrades_list ( & self ) -> Vec < usize > {
269+ let conn = self . pool . get ( ) . await . unwrap ( ) ;
270+ let rows = conn. query ( GET_MIGRATION_LOG_STMT , & [ ] ) . await . unwrap ( ) ;
271+ rows. iter ( ) . map ( |row| usize:: try_from ( row. get :: < & str , i32 > ( MIGRATION_LOG_COLUMN ) ) . unwrap ( ) ) . collect ( )
224272 }
225273
226274 fn build_vss_record ( & self , user_token : String , store_id : String , kv : KeyValue ) -> VssDbRecord {
@@ -574,23 +622,104 @@ mod tests {
574622 use crate :: postgres_store:: PostgresBackendImpl ;
575623 use api:: define_kv_store_tests;
576624 use tokio:: sync:: OnceCell ;
625+ use super :: { MIGRATIONS , DUMMY_MIGRATION , drop_database} ;
626+
627+ const POSTGRES_ENDPOINT : & str = "postgresql://postgres:postgres@localhost:5432" ;
628+ const MIGRATIONS_START : usize = 0 ;
629+ const MIGRATIONS_END : usize = MIGRATIONS . len ( ) ;
577630
578631 static START : OnceCell < ( ) > = OnceCell :: const_new ( ) ;
579632
580633 define_kv_store_tests ! ( PostgresKvStoreTest , PostgresBackendImpl , {
634+ let db_name = "postgres_kv_store_tests" ;
581635 START
582636 . get_or_init( || async {
583- // Initialize the database once, and have other threads wait
584- PostgresBackendImpl :: new(
585- "postgresql://postgres:postgres@localhost:5432" ,
586- "postgres" ,
587- )
588- . await
589- . unwrap( ) ;
637+ let _ = drop_database( POSTGRES_ENDPOINT , db_name) . await ;
638+ let store = PostgresBackendImpl :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
639+ let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
640+ assert_eq!( start, MIGRATIONS_START ) ;
641+ assert_eq!( end, MIGRATIONS_END ) ;
590642 } )
591643 . await ;
592- PostgresBackendImpl :: new( "postgresql://postgres:postgres@localhost:5432" , "postgres" )
593- . await
594- . unwrap( )
644+ let store = PostgresBackendImpl :: new( POSTGRES_ENDPOINT , db_name) . await . unwrap( ) ;
645+ let ( start, end) = store. migrate_vss_database( MIGRATIONS ) . await . unwrap( ) ;
646+ assert_eq!( start, MIGRATIONS_END ) ;
647+ assert_eq!( end, MIGRATIONS_END ) ;
648+ assert_eq!( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
649+ assert_eq!( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
650+ store
595651 } ) ;
652+
653+ #[ tokio:: test]
654+ #[ should_panic( expected = "We do not allow downgrades" ) ]
655+ async fn panic_on_downgrade ( ) {
656+ let db_name = "panic_on_downgrade_test" ;
657+ let _ = drop_database ( POSTGRES_ENDPOINT , db_name) . await ;
658+ {
659+ let mut migrations = MIGRATIONS . to_vec ( ) ;
660+ migrations. push ( DUMMY_MIGRATION ) ;
661+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
662+ let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
663+ assert_eq ! ( start, MIGRATIONS_START ) ;
664+ assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
665+ } ;
666+ {
667+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
668+ let _ = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
669+ } ;
670+ }
671+
672+ #[ tokio:: test]
673+ async fn new_migrations_increments_upgrades ( ) {
674+ let db_name = "new_migrations_increments_upgrades_test" ;
675+ let _ = drop_database ( POSTGRES_ENDPOINT , db_name) . await ;
676+ {
677+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
678+ let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
679+ assert_eq ! ( start, MIGRATIONS_START ) ;
680+ assert_eq ! ( end, MIGRATIONS_END ) ;
681+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
682+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
683+ } ;
684+ {
685+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
686+ let ( start, end) = store. migrate_vss_database ( MIGRATIONS ) . await . unwrap ( ) ;
687+ assert_eq ! ( start, MIGRATIONS_END ) ;
688+ assert_eq ! ( end, MIGRATIONS_END ) ;
689+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START ] ) ;
690+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END ) ;
691+ } ;
692+
693+ let mut migrations = MIGRATIONS . to_vec ( ) ;
694+ migrations. push ( DUMMY_MIGRATION ) ;
695+ {
696+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
697+ let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
698+ assert_eq ! ( start, MIGRATIONS_END ) ;
699+ assert_eq ! ( end, MIGRATIONS_END + 1 ) ;
700+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START , MIGRATIONS_END ] ) ;
701+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END + 1 ) ;
702+ } ;
703+
704+ migrations. push ( DUMMY_MIGRATION ) ;
705+ migrations. push ( DUMMY_MIGRATION ) ;
706+ {
707+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
708+ let ( start, end) = store. migrate_vss_database ( & migrations) . await . unwrap ( ) ;
709+ assert_eq ! ( start, MIGRATIONS_END + 1 ) ;
710+ assert_eq ! ( end, MIGRATIONS_END + 3 ) ;
711+ assert_eq ! ( store. get_upgrades_list( ) . await , [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
712+ assert_eq ! ( store. get_schema_version( ) . await , MIGRATIONS_END + 3 ) ;
713+ } ;
714+
715+ {
716+ let store = PostgresBackendImpl :: new ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
717+ let list = store. get_upgrades_list ( ) . await ;
718+ assert_eq ! ( list, [ MIGRATIONS_START , MIGRATIONS_END , MIGRATIONS_END + 1 ] ) ;
719+ let version = store. get_schema_version ( ) . await ;
720+ assert_eq ! ( version, MIGRATIONS_END + 3 ) ;
721+ }
722+
723+ drop_database ( POSTGRES_ENDPOINT , db_name) . await . unwrap ( ) ;
724+ }
596725}
0 commit comments