Skip to content

Commit 633ebfa

Browse files
committed
Create and initialize the database if it does not exist
Also implement a versioning and migration scheme for future updates to the schema. It is an adaptation of the scheme used in CLN.
1 parent 42655fa commit 633ebfa

File tree

4 files changed

+181
-14
lines changed

4 files changed

+181
-14
lines changed

rust/impls/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ chrono = "0.4.38"
1010
tokio-postgres = { version = "0.7.12", features = ["with-chrono-0_4"] }
1111
bb8-postgres = "0.7"
1212
bytes = "1.4.0"
13+
tokio = { version = "1.38.0", default-features = false }
1314

1415
[dev-dependencies]
1516
tokio = { version = "1.38.0", default-features = false, features = ["rt-multi-thread", "macros"] }

rust/impls/src/postgres_store.rs

Lines changed: 172 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use chrono::Utc;
1212
use std::cmp::min;
1313
use std::io;
1414
use std::io::{Error, ErrorKind};
15-
use tokio_postgres::{NoTls, Transaction};
15+
use tokio_postgres::{error, NoTls, Transaction};
1616

1717
pub(crate) struct VssDbRecord {
1818
pub(crate) user_token: String,
@@ -27,6 +27,38 @@ const KEY_COLUMN: &str = "key";
2727
const VALUE_COLUMN: &str = "value";
2828
const VERSION_COLUMN: &str = "version";
2929

30+
const DB_VERSION_COLUMN: &str = "db_version";
31+
32+
const CHECK_DB_STMT: &str = "SELECT 1 FROM pg_database WHERE datname = $1";
33+
const INIT_DB_CMD: &str = "CREATE DATABASE";
34+
const GET_VERSION_STMT: &str = "SELECT db_version FROM vss_db_version;";
35+
const UPDATE_VERSION_STMT: &str = "UPDATE vss_db_version SET db_version=$1;";
36+
const LOG_MIGRATION_STMT: &str = "INSERT INTO vss_db_upgrades VALUES($1);";
37+
38+
// APPEND-ONLY list of migration statements
39+
//
40+
// Each statement MUST be applied in-order, and only once per database.
41+
//
42+
// We make an exception for the vss_db table creation statement, as users of VSS could have initialized the table
43+
// themselves.
44+
const MIGRATIONS: &[&str] = &[
45+
"CREATE TABLE vss_db_version (db_version INTEGER);",
46+
"INSERT INTO vss_db_version VALUES(1);",
47+
// A write-only log of all the migrations performed on this database, useful for debugging and testing
48+
"CREATE TABLE vss_db_upgrades (upgrade_from INTEGER);",
49+
// We do not complain if the table already exists, as users of VSS could have already created this table
50+
"CREATE TABLE IF NOT EXISTS vss_db (
51+
user_token character varying(120) NOT NULL CHECK (user_token <> ''),
52+
store_id character varying(120) NOT NULL CHECK (store_id <> ''),
53+
key character varying(600) NOT NULL,
54+
value bytea NULL,
55+
version bigint NOT NULL,
56+
created_at TIMESTAMP WITH TIME ZONE,
57+
last_updated_at TIMESTAMP WITH TIME ZONE,
58+
PRIMARY KEY (user_token, store_id, key)
59+
);",
60+
];
61+
3062
/// The maximum number of key versions that can be returned in a single page.
3163
///
3264
/// This constant helps control memory and bandwidth usage for list operations,
@@ -46,17 +78,149 @@ pub struct PostgresBackendImpl {
4678
pool: Pool<PostgresConnectionManager<NoTls>>,
4779
}
4880

81+
async fn initialize_vss_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Error> {
82+
let postgres_dsn = format!("{}/{}", postgres_endpoint, "postgres");
83+
let (client, connection) = tokio_postgres::connect(&postgres_dsn, NoTls)
84+
.await
85+
.map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?;
86+
// Connection must be driven on a separate task, and will resolve when the client is dropped
87+
tokio::spawn(async move {
88+
if let Err(e) = connection.await {
89+
eprintln!("Connection error: {}", e);
90+
}
91+
});
92+
93+
let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| {
94+
Error::new(
95+
ErrorKind::Other,
96+
format!("Failed to check presence of database {}: {}", db_name, e),
97+
)
98+
})?;
99+
100+
if num_rows == 0 {
101+
let stmt = format!("{} {};", INIT_DB_CMD, db_name);
102+
client.execute(&stmt, &[]).await.map_err(|e| {
103+
Error::new(ErrorKind::Other, format!("Failed to create database {}: {}", db_name, e))
104+
})?;
105+
println!("Created database {}", db_name);
106+
}
107+
108+
Ok(())
109+
}
110+
49111
impl PostgresBackendImpl {
50112
/// Constructs a [`PostgresBackendImpl`] using `dsn` for PostgreSQL connection information.
51-
pub async fn new(dsn: &str) -> Result<Self, Error> {
52-
let manager = PostgresConnectionManager::new_from_stringlike(dsn, NoTls).map_err(|e| {
53-
Error::new(ErrorKind::Other, format!("Connection manager error: {}", e))
54-
})?;
113+
pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result<Self, Error> {
114+
initialize_vss_database(postgres_endpoint, db_name).await?;
115+
116+
let vss_dsn = format!("{}/{}", postgres_endpoint, db_name);
117+
let manager =
118+
PostgresConnectionManager::new_from_stringlike(vss_dsn, NoTls).map_err(|e| {
119+
Error::new(
120+
ErrorKind::Other,
121+
format!("Failed to create PostgresConnectionManager: {}", e),
122+
)
123+
})?;
124+
// By default, Pool maintains 0 long-running connections, so returning a pool
125+
// here is no guarantee that Pool established a connection to the database.
126+
//
127+
// See Builder::min_idle to increase the long-running connection count.
55128
let pool = Pool::builder()
56129
.build(manager)
57130
.await
58-
.map_err(|e| Error::new(ErrorKind::Other, format!("Pool build error: {}", e)))?;
59-
Ok(PostgresBackendImpl { pool })
131+
.map_err(|e| Error::new(ErrorKind::Other, format!("Failed to build Pool: {}", e)))?;
132+
let postgres_backend = PostgresBackendImpl { pool };
133+
134+
postgres_backend.migrate_vss_database().await?;
135+
136+
Ok(postgres_backend)
137+
}
138+
139+
async fn migrate_vss_database(&self) -> Result<(), Error> {
140+
let mut conn = self.pool.get().await.map_err(|e| {
141+
Error::new(
142+
ErrorKind::Other,
143+
format!("Failed to fetch a connection from Pool: {}", e),
144+
)
145+
})?;
146+
147+
// Get the next migration to be applied.
148+
let migration_start = match conn.query_one(GET_VERSION_STMT, &[]).await {
149+
Ok(row) => {
150+
let i: i32 = row.get(DB_VERSION_COLUMN);
151+
usize::try_from(i).expect("The column should always contain unsigned integers")
152+
},
153+
Err(e) => {
154+
// If the table is not defined, start at migration 0
155+
if let Some(&error::SqlState::UNDEFINED_TABLE) = e.code() {
156+
0
157+
} else {
158+
return Err(Error::new(
159+
ErrorKind::Other,
160+
format!("Failed to query the version of the database schema: {}", e),
161+
));
162+
}
163+
},
164+
};
165+
166+
let tx = conn
167+
.transaction()
168+
.await
169+
.map_err(|e| Error::new(ErrorKind::Other, format!("Transaction start error: {}", e)))?;
170+
171+
if migration_start == MIGRATIONS.len() {
172+
// No migrations needed, we are done
173+
return Ok(());
174+
} else if migration_start > MIGRATIONS.len() {
175+
panic!("We do not allow downgrades");
176+
}
177+
178+
println!("Applying migration(s) {} through {}", migration_start, MIGRATIONS.len() - 1);
179+
180+
for (idx, &stmt) in (&MIGRATIONS[migration_start..]).iter().enumerate() {
181+
let _num_rows = tx.execute(stmt, &[]).await.map_err(|e| {
182+
Error::new(
183+
ErrorKind::Other,
184+
format!(
185+
"Database migration no {} with stmt {} failed: {}",
186+
migration_start + idx,
187+
stmt,
188+
e
189+
),
190+
)
191+
})?;
192+
}
193+
194+
let num_rows = tx
195+
.execute(
196+
LOG_MIGRATION_STMT,
197+
&[&i32::try_from(migration_start).expect("Read from an i32 further above")],
198+
)
199+
.await
200+
.map_err(|e| {
201+
Error::new(ErrorKind::Other, format!("Failed to log database migration: {}", e))
202+
})?;
203+
assert_eq!(num_rows, 1, "LOG_MIGRATION_STMT should only add one row at a time");
204+
205+
let next_migration_start =
206+
i32::try_from(MIGRATIONS.len()).expect("Length is definitely smaller than i32::MAX");
207+
let num_rows =
208+
tx.execute(UPDATE_VERSION_STMT, &[&next_migration_start]).await.map_err(|e| {
209+
Error::new(
210+
ErrorKind::Other,
211+
format!("Failed to update the version of the schema: {}", e),
212+
)
213+
})?;
214+
assert_eq!(
215+
num_rows, 1,
216+
"UPDATE_VERSION_STMT should only update the unique row in the version table"
217+
);
218+
219+
tx.commit().await.map_err(|e| {
220+
Error::new(ErrorKind::Other, format!("Transaction commit error: {}", e))
221+
})?;
222+
223+
Ok(())
60224
}
61225

62226
fn build_vss_record(&self, user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord {
@@ -413,7 +577,7 @@ mod tests {
413577
define_kv_store_tests!(
414578
PostgresKvStoreTest,
415579
PostgresBackendImpl,
416-
PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432/postgres")
580+
PostgresBackendImpl::new("postgresql://postgres:postgres@localhost:5432", "postgres")
417581
.await
418582
.unwrap()
419583
);

rust/server/src/main.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,18 @@ fn main() {
6767
},
6868
};
6969
let authorizer = Arc::new(NoopAuthorizer {});
70+
let postgresql_config = config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.");
71+
let endpoint = postgresql_config.to_postgresql_endpoint();
72+
let db_name = postgresql_config.database;
7073
let store = Arc::new(
71-
PostgresBackendImpl::new(&config.postgresql_config.expect("PostgreSQLConfig must be defined in config file.").to_connection_string())
74+
PostgresBackendImpl::new(&endpoint, &db_name)
7275
.await
7376
.unwrap(),
7477
);
78+
println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name);
7579
let rest_svc_listener =
7680
TcpListener::bind(&addr).await.expect("Failed to bind listening port");
81+
println!("Listening for incoming connections on {}", addr);
7782
loop {
7883
tokio::select! {
7984
res = rest_svc_listener.accept() => {

rust/server/src/util/config.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pub(crate) struct PostgreSQLConfig {
2222
}
2323

2424
impl PostgreSQLConfig {
25-
pub(crate) fn to_connection_string(&self) -> String {
25+
pub(crate) fn to_postgresql_endpoint(&self) -> String {
2626
let username_env = std::env::var("VSS_POSTGRESQL_USERNAME");
2727
let username = username_env.as_ref()
2828
.ok()
@@ -34,10 +34,7 @@ impl PostgreSQLConfig {
3434
.or_else(|| self.password.as_ref())
3535
.expect("PostgreSQL database password must be provided in config or env var VSS_POSTGRESQL_PASSWORD must be set.");
3636

37-
format!(
38-
"postgresql://{}:{}@{}:{}/{}",
39-
username, password, self.host, self.port, self.database
40-
)
37+
format!("postgresql://{}:{}@{}:{}", username, password, self.host, self.port)
4138
}
4239
}
4340

0 commit comments

Comments
 (0)