@@ -12,7 +12,7 @@ use chrono::Utc;
1212use std:: cmp:: min;
1313use std:: io;
1414use std:: io:: { Error , ErrorKind } ;
15- use tokio_postgres:: { NoTls , Transaction } ;
15+ use tokio_postgres:: { error , NoTls , Transaction } ;
1616
1717pub ( crate ) struct VssDbRecord {
1818 pub ( crate ) user_token : String ,
@@ -27,6 +27,38 @@ const KEY_COLUMN: &str = "key";
2727const VALUE_COLUMN : & str = "value" ;
2828const VERSION_COLUMN : & str = "version" ;
2929
30+ const DB_VERSION_COLUMN : & str = "db_version" ;
31+
32+ const CHECK_DB_STMT : & str = "SELECT 1 FROM pg_database WHERE datname = $1" ;
33+ const INIT_DB_CMD : & str = "CREATE DATABASE" ;
34+ const GET_VERSION_STMT : & str = "SELECT db_version FROM vss_db_version;" ;
35+ const UPDATE_VERSION_STMT : & str = "UPDATE vss_db_version SET db_version=$1;" ;
36+ const LOG_MIGRATION_STMT : & str = "INSERT INTO vss_db_upgrades VALUES($1);" ;
37+
38+ // APPEND-ONLY list of migration statements
39+ //
40+ // Each statement MUST be applied in-order, and only once per database.
41+ //
42+ // We make an exception for the vss_db table creation statement, as users of VSS could have initialized the table
43+ // themselves.
44+ const MIGRATIONS : & [ & str ] = & [
45+ "CREATE TABLE vss_db_version (db_version INTEGER);" ,
46+ "INSERT INTO vss_db_version VALUES(1);" ,
47+ // A write-only log of all the migrations performed on this database, useful for debugging and testing
48+ "CREATE TABLE vss_db_upgrades (upgrade_from INTEGER);" ,
49+ // We do not complain if the table already exists, as users of VSS could have already created this table
50+ "CREATE TABLE IF NOT EXISTS vss_db (
51+ user_token character varying(120) NOT NULL CHECK (user_token <> ''),
52+ store_id character varying(120) NOT NULL CHECK (store_id <> ''),
53+ key character varying(600) NOT NULL,
54+ value bytea NULL,
55+ version bigint NOT NULL,
56+ created_at TIMESTAMP WITH TIME ZONE,
57+ last_updated_at TIMESTAMP WITH TIME ZONE,
58+ PRIMARY KEY (user_token, store_id, key)
59+ );" ,
60+ ] ;
61+
3062/// The maximum number of key versions that can be returned in a single page.
3163///
3264/// This constant helps control memory and bandwidth usage for list operations,
@@ -46,17 +78,149 @@ pub struct PostgresBackendImpl {
4678 pool : Pool < PostgresConnectionManager < NoTls > > ,
4779}
4880
81+ async fn initialize_vss_database ( postgres_endpoint : & str , db_name : & str ) -> Result < ( ) , Error > {
82+ let postgres_dsn = format ! ( "{}/{}" , postgres_endpoint, "postgres" ) ;
83+ let ( client, connection) = tokio_postgres:: connect ( & postgres_dsn, NoTls )
84+ . await
85+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Connection error: {}" , e) ) ) ?;
86+ // Connection must be driven on a separate task, and will resolve when the client is dropped
87+ tokio:: spawn ( async move {
88+ if let Err ( e) = connection. await {
89+ eprintln ! ( "Connection error: {}" , e) ;
90+ }
91+ } ) ;
92+
93+ let num_rows = client. execute ( CHECK_DB_STMT , & [ & db_name] ) . await . map_err ( |e| {
94+ Error :: new (
95+ ErrorKind :: Other ,
96+ format ! ( "Failed to check presence of database {}: {}" , db_name, e) ,
97+ )
98+ } ) ?;
99+
100+ if num_rows == 0 {
101+ let stmt = format ! ( "{} {};" , INIT_DB_CMD , db_name) ;
102+ client. execute ( & stmt, & [ ] ) . await . map_err ( |e| {
103+ Error :: new ( ErrorKind :: Other , format ! ( "Failed to create database {}: {}" , db_name, e) )
104+ } ) ?;
105+ println ! ( "Created database {}" , db_name) ;
106+ }
107+
108+ Ok ( ( ) )
109+ }
110+
49111impl PostgresBackendImpl {
50112 /// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
51- pub async fn new ( dsn : & str ) -> Result < Self , Error > {
52- let manager = PostgresConnectionManager :: new_from_stringlike ( dsn, NoTls ) . map_err ( |e| {
53- Error :: new ( ErrorKind :: Other , format ! ( "Connection manager error: {}" , e) )
54- } ) ?;
113+ pub async fn new ( postgres_endpoint : & str , db_name : & str ) -> Result < Self , Error > {
114+ initialize_vss_database ( postgres_endpoint, db_name) . await ?;
115+
116+ let vss_dsn = format ! ( "{}/{}" , postgres_endpoint, db_name) ;
117+ let manager =
118+ PostgresConnectionManager :: new_from_stringlike ( vss_dsn, NoTls ) . map_err ( |e| {
119+ Error :: new (
120+ ErrorKind :: Other ,
121+ format ! ( "Failed to create PostgresConnectionManager: {}" , e) ,
122+ )
123+ } ) ?;
124+ // By default, Pool maintains 0 long-running connections, so returning a pool
125+ // here is no guarantee that Pool established a connection to the database.
126+ //
127+ // See Builder::min_idle to increase the long-running connection count.
55128 let pool = Pool :: builder ( )
56129 . build ( manager)
57130 . await
58- . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Pool build error: {}" , e) ) ) ?;
59- Ok ( PostgresBackendImpl { pool } )
131+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Failed to build Pool: {}" , e) ) ) ?;
132+ let postgres_backend = PostgresBackendImpl { pool } ;
133+
134+ postgres_backend. migrate_vss_database ( ) . await ?;
135+
136+ Ok ( postgres_backend)
137+ }
138+
139+ async fn migrate_vss_database ( & self ) -> Result < ( ) , Error > {
140+ let mut conn = self . pool . get ( ) . await . map_err ( |e| {
141+ Error :: new (
142+ ErrorKind :: Other ,
143+ format ! ( "Failed to fetch a connection from Pool: {}" , e) ,
144+ )
145+ } ) ?;
146+
147+ // Get the next migration to be applied.
148+ let migration_start = match conn. query_one ( GET_VERSION_STMT , & [ ] ) . await {
149+ Ok ( row) => {
150+ let i: i32 = row. get ( DB_VERSION_COLUMN ) ;
151+ usize:: try_from ( i) . expect ( "The column should always contain unsigned integers" )
152+ } ,
153+ Err ( e) => {
154+ // If the table is not defined, start at migration 0
155+ if let Some ( & error:: SqlState :: UNDEFINED_TABLE ) = e. code ( ) {
156+ 0
157+ } else {
158+ return Err ( Error :: new (
159+ ErrorKind :: Other ,
160+ format ! ( "Failed to query the version of the database schema: {}" , e) ,
161+ ) ) ;
162+ }
163+ } ,
164+ } ;
165+
166+ let tx = conn
167+ . transaction ( )
168+ . await
169+ . map_err ( |e| Error :: new ( ErrorKind :: Other , format ! ( "Transaction start error: {}" , e) ) ) ?;
170+
171+ if migration_start == MIGRATIONS . len ( ) {
172+ // No migrations needed, we are done
173+ return Ok ( ( ) ) ;
174+ } else if migration_start > MIGRATIONS . len ( ) {
175+ panic ! ( "We do not allow downgrades" ) ;
176+ }
177+
178+ println ! ( "Applying migration(s) {} through {}" , migration_start, MIGRATIONS . len( ) - 1 ) ;
179+
180+ for ( idx, & stmt) in ( & MIGRATIONS [ migration_start..] ) . iter ( ) . enumerate ( ) {
181+ let _num_rows = tx. execute ( stmt, & [ ] ) . await . map_err ( |e| {
182+ Error :: new (
183+ ErrorKind :: Other ,
184+ format ! (
185+ "Database migration no {} with stmt {} failed: {}" ,
186+ migration_start + idx,
187+ stmt,
188+ e
189+ ) ,
190+ )
191+ } ) ?;
192+ }
193+
194+ let num_rows = tx
195+ . execute (
196+ LOG_MIGRATION_STMT ,
197+ & [ & i32:: try_from ( migration_start) . expect ( "Read from an i32 further above" ) ] ,
198+ )
199+ . await
200+ . map_err ( |e| {
201+ Error :: new ( ErrorKind :: Other , format ! ( "Failed to log database migration: {}" , e) )
202+ } ) ?;
203+ assert_eq ! ( num_rows, 1 , "LOG_MIGRATION_STMT should only add one row at a time" ) ;
204+
205+ let next_migration_start =
206+ i32:: try_from ( MIGRATIONS . len ( ) ) . expect ( "Length is definitely smaller than i32::MAX" ) ;
207+ let num_rows =
208+ tx. execute ( UPDATE_VERSION_STMT , & [ & next_migration_start] ) . await . map_err ( |e| {
209+ Error :: new (
210+ ErrorKind :: Other ,
211+ format ! ( "Failed to update the version of the schema: {}" , e) ,
212+ )
213+ } ) ?;
214+ assert_eq ! (
215+ num_rows, 1 ,
216+ "UPDATE_VERSION_STMT should only update the unique row in the version table"
217+ ) ;
218+
219+ tx. commit ( ) . await . map_err ( |e| {
220+ Error :: new ( ErrorKind :: Other , format ! ( "Transaction commit error: {}" , e) )
221+ } ) ?;
222+
223+ Ok ( ( ) )
60224 }
61225
62226 fn build_vss_record ( & self , user_token : String , store_id : String , kv : KeyValue ) -> VssDbRecord {
@@ -413,7 +577,7 @@ mod tests {
413577 define_kv_store_tests ! (
414578 PostgresKvStoreTest ,
415579 PostgresBackendImpl ,
416- PostgresBackendImpl :: new( "postgresql://postgres:postgres@localhost:5432/ postgres" )
580+ PostgresBackendImpl :: new( "postgresql://postgres:postgres@localhost:5432" , " postgres")
417581 . await
418582 . unwrap( )
419583 ) ;
0 commit comments