Skip to content

Commit 1b947fd

Browse files
authored
Merge pull request #55 from tankyleo/init
Create and initialize the database if it does not exist
2 parents 42655fa + 95f2009 commit 1b947fd

File tree

6 files changed

+330
-20
lines changed

6 files changed

+330
-20
lines changed

rust/impls/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ chrono = "0.4.38"
1010
tokio-postgres = { version = "0.7.12", features = ["with-chrono-0_4"] }
1111
bb8-postgres = "0.7"
1212
bytes = "1.4.0"
13+
tokio = { version = "1.38.0", default-features = false }
1314

1415
[dev-dependencies]
1516
tokio = { version = "1.38.0", default-features = false, features = ["rt-multi-thread", "macros"] }

rust/impls/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#![deny(rustdoc::private_intra_doc_links)]
1212
#![deny(missing_docs)]
1313

14+
mod migrations;
1415
/// Contains [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS.
1516
pub mod postgres_store;
1617

rust/impls/src/migrations.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
pub(crate) const DB_VERSION_COLUMN: &str = "db_version";
2+
#[cfg(test)]
3+
pub(crate) const MIGRATION_LOG_COLUMN: &str = "upgrade_from";
4+
5+
pub(crate) const CHECK_DB_STMT: &str = "SELECT 1 FROM pg_database WHERE datname = $1";
6+
pub(crate) const INIT_DB_CMD: &str = "CREATE DATABASE";
7+
#[cfg(test)]
8+
const DROP_DB_CMD: &str = "DROP DATABASE";
9+
pub(crate) const GET_VERSION_STMT: &str = "SELECT db_version FROM vss_db_version;";
10+
pub(crate) const UPDATE_VERSION_STMT: &str = "UPDATE vss_db_version SET db_version=$1;";
11+
pub(crate) const LOG_MIGRATION_STMT: &str = "INSERT INTO vss_db_upgrades VALUES($1);";
12+
#[cfg(test)]
13+
pub(crate) const GET_MIGRATION_LOG_STMT: &str = "SELECT upgrade_from FROM vss_db_upgrades;";
14+
15+
// APPEND-ONLY list of migration statements
16+
//
17+
// Each statement MUST be applied in-order, and only once per database.
18+
//
19+
// We make an exception for the vss_db table creation statement, as users of VSS could have initialized the table
20+
// themselves.
21+
pub(crate) const MIGRATIONS: &[&str] = &[
22+
"CREATE TABLE vss_db_version (db_version INTEGER);",
23+
"INSERT INTO vss_db_version VALUES(1);",
24+
// A write-only log of all the migrations performed on this database, useful for debugging and testing
25+
"CREATE TABLE vss_db_upgrades (upgrade_from INTEGER);",
26+
// We do not complain if the table already exists, as users of VSS could have already created this table
27+
"CREATE TABLE IF NOT EXISTS vss_db (
28+
user_token character varying(120) NOT NULL CHECK (user_token <> ''),
29+
store_id character varying(120) NOT NULL CHECK (store_id <> ''),
30+
key character varying(600) NOT NULL,
31+
value bytea NULL,
32+
version bigint NOT NULL,
33+
created_at TIMESTAMP WITH TIME ZONE,
34+
last_updated_at TIMESTAMP WITH TIME ZONE,
35+
PRIMARY KEY (user_token, store_id, key)
36+
);",
37+
];
38+
#[cfg(test)]
39+
const DUMMY_MIGRATION: &str = "SELECT 1 WHERE FALSE;";

rust/impls/src/postgres_store.rs

Lines changed: 281 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::migrations::*;
2+
13
use api::error::VssError;
24
use api::kv_store::{KvStore, GLOBAL_VERSION_KEY, INITIAL_RECORD_VERSION};
35
use api::types::{
@@ -12,7 +14,7 @@ use chrono::Utc;
1214
use std::cmp::min;
1315
use std::io;
1416
use std::io::{Error, ErrorKind};
15-
use tokio_postgres::{NoTls, Transaction};
17+
use tokio_postgres::{error, NoTls, Transaction};
1618

1719
pub(crate) struct VssDbRecord {
1820
pub(crate) user_token: String,
@@ -46,17 +48,189 @@ pub struct PostgresBackendImpl {
4648
pool: Pool<PostgresConnectionManager<NoTls>>,
4749
}
4850

51+
async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
52+
let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
53+
let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
54+
.await
55+
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
56+
// Connection must be driven on a separate task, and will resolve when the client is dropped
57+
tokio::spawn(async move {
58+
if let Err(e) = connection.await {
59+
eprintln!("Connection error: {}", e);
60+
}
61+
});
62+
63+
let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
64+
Error::new(
65+
ErrorKind::Other,
66+
format!("Failed to check presence of database {}: {}", db_name, e),
67+
)
68+
})?;
69+
70+
if num_rows == 0 {
71+
let stmt = format!("{} {};", INIT_DB_CMD, db_name);
72+
client.execute(&stmt, &[]).await.map_err(|e| {
73+
Error::new(ErrorKind::Other, format!("Failed to create database {}: {}", db_name, e))
74+
})?;
75+
println!("Created database {}", db_name);
76+
}
77+
78+
Ok(())
79+
}
80+
81+
#[cfg(test)]
82+
async fn drop_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
83+
let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
84+
let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
85+
.await
86+
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
87+
// Connection must be driven on a separate task, and will resolve when the client is dropped
88+
tokio::spawn(async move {
89+
if let Err(e) = connection.await {
90+
eprintln!("Connection error: {}", e);
91+
}
92+
});
93+
94+
let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name);
95+
let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| {
96+
Error::new(
97+
ErrorKind::Other,
98+
format!("Failed to drop database {}: {}", db_name, e),
99+
)
100+
})?;
101+
assert_eq!(num_rows, 0);
102+
103+
Ok(())
104+
}
105+
49106
impl PostgresBackendImpl {
50107
/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
51-
pub async fn new(dsn: &str) -> Result<Self, Error> {
52-
let manager = PostgresConnectionManager::new_from_stringlike(dsn, NoTls).map_err(|e| {
53-
Error::new(ErrorKind::Other, format!("Connection manager error: {}", e))
54-
})?;
108+
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
109+
initialize_vss_database(postgres_endpoint, db_name).await?;
110+
111+
let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
112+
let manager =
113+
PostgresConnectionManager::new_from_stringlike(vss_dsn, NoTls).map_err(|e| {
114+
Error::new(
115+
ErrorKind::Other,
116+
format!("Failed to create PostgresConnectionManager: {}", e),
117+
)
118+
})?;
119+
// By default, Pool maintains 0 long-running connections, so returning a pool
120+
// here is no guarantee that Pool established a connection to the database.
121+
//
122+
// See Builder::min_idle to increase the long-running connection count.
55123
let pool = Pool::builder()
56124
.build(manager)
57125
.await
58-
.map_err(|e| Error::new(ErrorKind::Other, format!("Pool build error: {}", e)))?;
59-
Ok(PostgresBackendImpl { pool })
126+
.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
127+
let postgres_backend = PostgresBackendImpl { pool };
128+
129+
#[cfg(not(test))]
130+
postgres_backend.migrate_vss_database(MIGRATIONS).await?;
131+
132+
Ok(postgres_backend)
133+
}
134+
135+
async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> {
136+
let mut conn = self.pool.get().await.map_err(|e| {
137+
Error::new(
138+
ErrorKind::Other,
139+
format!("Failed to fetch a connection from Pool: {}", e),
140+
)
141+
})?;
142+
143+
// Get the next migration to be applied.
144+
let migration_start = match conn.query_one(GET_VERSION_STMT, &[]).await {
145+
Ok(row) => {
146+
let i: i32 = row.get(DB_VERSION_COLUMN);
147+
usize::try_from(i).expect("The column should always contain unsigned integers")
148+
},
149+
Err(e) => {
150+
// If the table is not defined, start at migration 0
151+
if let Some(&error::SqlState::UNDEFINED_TABLE) = e.code() {
152+
0
153+
} else {
154+
return Err(Error::new(
155+
ErrorKind::Other,
156+
format!("Failed to query the version of the database schema: {}", e),
157+
));
158+
}
159+
},
160+
};
161+
162+
let tx = conn
163+
.transaction()
164+
.await
165+
.map_err(|e| Error::new(ErrorKind::Other, format!("Transaction start error: {}", e)))?;
166+
167+
if migration_start == migrations.len() {
168+
// No migrations needed, we are done
169+
return Ok((migration_start, migrations.len()));
170+
} else if migration_start > migrations.len() {
171+
panic!("We do not allow downgrades");
172+
}
173+
174+
println!("Applying migration(s) {} through {}", migration_start, migrations.len() - 1);
175+
176+
for (idx, &stmt) in (&migrations[migration_start..]).iter().enumerate() {
177+
let _num_rows = tx.execute(stmt, &[]).await.map_err(|e| {
178+
Error::new(
179+
ErrorKind::Other,
180+
format!(
181+
"Database migration no {} with stmt {} failed: {}",
182+
migration_start + idx,
183+
stmt,
184+
e
185+
),
186+
)
187+
})?;
188+
}
189+
190+
let num_rows = tx
191+
.execute(
192+
LOG_MIGRATION_STMT,
193+
&[&i32::try_from(migration_start).expect("Read from an i32 further above")],
194+
)
195+
.await
196+
.map_err(|e| {
197+
Error::new(ErrorKind::Other, format!("Failed to log database migration: {}", e))
198+
})?;
199+
assert_eq!(num_rows, 1, "LOG_MIGRATION_STMT should only add one row at a time");
200+
201+
let next_migration_start =
202+
i32::try_from(migrations.len()).expect("Length is definitely smaller than i32::MAX");
203+
let num_rows =
204+
tx.execute(UPDATE_VERSION_STMT, &[&next_migration_start]).await.map_err(|e| {
205+
Error::new(
206+
ErrorKind::Other,
207+
format!("Failed to update the version of the schema: {}", e),
208+
)
209+
})?;
210+
assert_eq!(
211+
num_rows, 1,
212+
"UPDATE_VERSION_STMT should only update the unique row in the version table"
213+
);
214+
215+
tx.commit().await.map_err(|e| {
216+
Error::new(ErrorKind::Other, format!("Transaction commit error: {}", e))
217+
})?;
218+
219+
Ok((migration_start, migrations.len()))
220+
}
221+
222+
#[cfg(test)]
223+
async fn get_schema_version(&self) -> usize {
224+
let conn = self.pool.get().await.unwrap();
225+
let row = conn.query_one(GET_VERSION_STMT, &[]).await.unwrap();
226+
usize::try_from(row.get::<&str, i32>(DB_VERSION_COLUMN)).unwrap()
227+
}
228+
229+
#[cfg(test)]
230+
async fn get_upgrades_list(&self) -> Vec<usize> {
231+
let conn = self.pool.get().await.unwrap();
232+
let rows = conn.query(GET_MIGRATION_LOG_STMT, &[]).await.unwrap();
233+
rows.iter().map(|row| usize::try_from(row.get::<&str, i32>(MIGRATION_LOG_COLUMN)).unwrap()).collect()
60234
}
61235

62236
fn build_vss_record(&self, user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord {
@@ -409,12 +583,105 @@ impl KvStore for PostgresBackendImpl {
409583
mod tests {
410584
use crate::postgres_store::PostgresBackendImpl;
411585
use api::define_kv_store_tests;
586+
use tokio::sync::OnceCell;
587+
use super::{MIGRATIONS, DUMMY_MIGRATION, drop_database};
588+
589+
const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432";
590+
const MIGRATIONS_START: usize = 0;
591+
const MIGRATIONS_END: usize = MIGRATIONS.len();
592+
593+
static START: OnceCell<()> = OnceCell::const_new();
594+
595+
define_kv_store_tests!(PostgresKvStoreTest, PostgresBackendImpl, {
596+
let db_name = "postgres_kv_store_tests";
597+
START
598+
.get_or_init(|| async {
599+
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
600+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
601+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
602+
assert_eq!(start, MIGRATIONS_START);
603+
assert_eq!(end, MIGRATIONS_END);
604+
})
605+
.await;
606+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
607+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
608+
assert_eq!(start, MIGRATIONS_END);
609+
assert_eq!(end, MIGRATIONS_END);
610+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
611+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
612+
store
613+
});
614+
615+
#[tokio::test]
616+
#[should_panic(expected = "We do not allow downgrades")]
617+
async fn panic_on_downgrade() {
618+
let db_name = "panic_on_downgrade_test";
619+
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
620+
{
621+
let mut migrations = MIGRATIONS.to_vec();
622+
migrations.push(DUMMY_MIGRATION);
623+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
624+
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
625+
assert_eq!(start, MIGRATIONS_START);
626+
assert_eq!(end, MIGRATIONS_END + 1);
627+
};
628+
{
629+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
630+
let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap();
631+
};
632+
}
412633

413-
define_kv_store_tests!(
414-
PostgresKvStoreTest,
415-
PostgresBackendImpl,
416-
PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432/postgres")
417-
.await
418-
.unwrap()
419-
);
634+
#[tokio::test]
635+
async fn new_migrations_increments_upgrades() {
636+
let db_name = "new_migrations_increments_upgrades_test";
637+
let _ = drop_database(POSTGRES_ENDPOINT, db_name).await;
638+
{
639+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
640+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
641+
assert_eq!(start, MIGRATIONS_START);
642+
assert_eq!(end, MIGRATIONS_END);
643+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
644+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
645+
};
646+
{
647+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
648+
let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap();
649+
assert_eq!(start, MIGRATIONS_END);
650+
assert_eq!(end, MIGRATIONS_END);
651+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START]);
652+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END);
653+
};
654+
655+
let mut migrations = MIGRATIONS.to_vec();
656+
migrations.push(DUMMY_MIGRATION);
657+
{
658+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
659+
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
660+
assert_eq!(start, MIGRATIONS_END);
661+
assert_eq!(end, MIGRATIONS_END + 1);
662+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END]);
663+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 1);
664+
};
665+
666+
migrations.push(DUMMY_MIGRATION);
667+
migrations.push(DUMMY_MIGRATION);
668+
{
669+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
670+
let (start, end) = store.migrate_vss_database(&migrations).await.unwrap();
671+
assert_eq!(start, MIGRATIONS_END + 1);
672+
assert_eq!(end, MIGRATIONS_END + 3);
673+
assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
674+
assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 3);
675+
};
676+
677+
{
678+
let store = PostgresBackendImpl::new(POSTGRES_ENDPOINT, db_name).await.unwrap();
679+
let list = store.get_upgrades_list().await;
680+
assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]);
681+
let version = store.get_schema_version().await;
682+
assert_eq!(version, MIGRATIONS_END + 3);
683+
}
684+
685+
drop_database(POSTGRES_ENDPOINT, db_name).await.unwrap();
686+
}
420687
}

rust/server/src/main.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,18 @@ fn main() {
6767
},
6868
};
6969
let authorizer = Arc::new(NoopAuthorizer {});
70+
let postgresql_config = config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.");
71+
let endpoint = postgresql_config.to_postgresql_endpoint();
72+
let db_name = postgresql_config.database;
7073
let store = Arc::new(
71-
PostgresBackendImpl::new(&config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.").to_connection_string())
74+
PostgresBackendImpl::new(&endpoint, &db_name)
7275
.await
7376
.unwrap(),
7477
);
78+
println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name);
7579
let rest_svc_listener =
7680
TcpListener::bind(&addr).await.expect("Failed to bind listening port");
81+
println!("Listening for incoming connections on {}", addr);
7782
loop {
7883
tokio::select! {
7984
res = rest_svc_listener.accept() => {

0 commit comments

Comments
 (0)