From 383211b6219345b91dbb02f1f3bab4bdfd5e675a Mon Sep 17 00:00:00 2001 From: maxdml Date: Wed, 5 Nov 2025 16:31:44 -0800 Subject: [PATCH 1/3] mask password in kv urls and add some tests --- cmd/dbos/utils.go | 84 +++++++++++++++++++++++++++--------- dbos/dbos_test.go | 94 +++++++++++++++++++++++++++++++++++++++++ dbos/system_database.go | 75 ++++++++++++++++++++++++-------- 3 files changed, 214 insertions(+), 39 deletions(-) diff --git a/cmd/dbos/utils.go b/cmd/dbos/utils.go index 19571ff..0d590e0 100644 --- a/cmd/dbos/utils.go +++ b/cmd/dbos/utils.go @@ -7,38 +7,76 @@ import ( "log/slog" "net/url" "os" + "strings" "github.com/dbos-inc/dbos-transact-golang/dbos" "github.com/spf13/viper" ) // maskPassword replaces the password in a database URL with asterisks -func maskPassword(dbURL string) string { +func maskPassword(dbURL string) (string, error) { parsedURL, err := url.Parse(dbURL) - if err != nil { - // If we can't parse it, return the original (shouldn't happen with valid URLs) - logger.Warn("Failed to parse database URL", "error", err) - return dbURL - } + if err == nil && parsedURL.Scheme != "" { - // Check if there is user info with a password - if parsedURL.User != nil { - username := parsedURL.User.Username() - _, hasPassword := parsedURL.User.Password() - if hasPassword { - // Manually construct the URL with masked password to avoid encoding - maskedURL := parsedURL.Scheme + "://" + username + ":********@" + parsedURL.Host + parsedURL.Path - if parsedURL.RawQuery != "" { - maskedURL += "?" + parsedURL.RawQuery - } - if parsedURL.Fragment != "" { - maskedURL += "#" + parsedURL.Fragment + // Check if there is user info with a password + if parsedURL.User != nil { + username := parsedURL.User.Username() + _, hasPassword := parsedURL.User.Password() + if hasPassword { + // Manually construct the URL with masked password to avoid encoding + maskedURL := parsedURL.Scheme + "://" + username + ":***@" + parsedURL.Host + parsedURL.Path + if parsedURL.RawQuery != "" { + maskedURL += "?" + parsedURL.RawQuery + } + if parsedURL.Fragment != "" { + maskedURL += "#" + parsedURL.Fragment + } + return maskedURL, nil } - return maskedURL } + + return parsedURL.String(), nil } - return parsedURL.String() + // If URL parsing failed or no scheme, try key-value format (libpq connection string) + return maskPasswordInKeyValueFormat(dbURL), nil +} + +// maskPasswordInKeyValueFormat masks password in libpq-style key-value connection strings +// Format: "user=foo password=bar database=db host=localhost" +// Supports all spacing variations: password=value, password =value, password= value, password = value +func maskPasswordInKeyValueFormat(connStr string) string { + // Find "password" key (case insensitive) + lowerStr := strings.ToLower(connStr) + passwordKey := "password" + passwordIdx := strings.Index(lowerStr, passwordKey) + if passwordIdx == -1 { + return connStr // No password found + } + + // Find the = sign after "password" (skip optional spaces before =) + afterKey := passwordIdx + len(passwordKey) + for afterKey < len(connStr) && connStr[afterKey] == ' ' { + afterKey++ + } + if afterKey >= len(connStr) || connStr[afterKey] != '=' { + return connStr // No = sign found + } + + // Find the start of the password value (skip = and optional spaces after =) + valueStart := afterKey + 1 + for valueStart < len(connStr) && connStr[valueStart] == ' ' { + valueStart++ + } + + // Find the end of the password value (next space or end of string) + valueEnd := valueStart + for valueEnd < len(connStr) && connStr[valueEnd] != ' ' { + valueEnd++ + } + + // Replace password value with *** + return connStr[:valueStart] + "***" + connStr[valueEnd:] } // getDBURL resolves the database URL from flag, config, or environment variable @@ -63,7 +101,11 @@ func getDBURL() (string, error) { } // Log the database URL in verbose mode with masked password - maskedURL := maskPassword(resolvedURL) + maskedURL, err := maskPassword(resolvedURL) + if err != nil { + logger.Warn("Failed to mask database URL", "error", err) + maskedURL = resolvedURL + } logger.Debug("Using database URL", "source", source, "url", maskedURL) return resolvedURL, nil diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index 43dbcee..f964044 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "strings" "testing" "time" @@ -278,6 +279,99 @@ func TestConfig(t *testing.T) { require.NotNil(t, ctx2) }) + + t.Run("KeyValueFormatConnectionString", func(t *testing.T) { + t.Setenv("DBOS__APPVERSION", "v1.0.0") + t.Setenv("DBOS__APPID", "test-keyvalue-format") + t.Setenv("DBOS__VMID", "test-executor-id") + + // Get base connection parameters + originalURL := databaseURL + parsedURL, err := pgxpool.ParseConfig(originalURL) + require.NoError(t, err) + + user := parsedURL.ConnConfig.User + database := parsedURL.ConnConfig.Database + host := parsedURL.ConnConfig.Host + port := parsedURL.ConnConfig.Port + + // Use a unique test password that won't match other connection parameters + testPassword := "TEST_PASSWORD_UNIQUE_12345!@#$%" + + // Test password masking with various spacing formats + maskingTestCases := []struct { + name string + connStr string + }{ + {"NoSpaces", fmt.Sprintf("user=%s password=%s database=%s host=%s", user, testPassword, database, host)}, + {"SpaceBeforeEquals", fmt.Sprintf("user=%s password =%s database=%s host=%s", user, testPassword, database, host)}, + {"SpaceAfterEquals", fmt.Sprintf("user=%s password= %s database=%s host=%s", user, testPassword, database, host)}, + {"SpacesBothSides", fmt.Sprintf("user=%s password = %s database=%s host=%s", user, testPassword, database, host)}, + {"UppercaseKey", fmt.Sprintf("user=%s PASSWORD=%s database=%s host=%s", user, testPassword, database, host)}, + {"MixedCaseKey", fmt.Sprintf("user=%s Password=%s database=%s host=%s", user, testPassword, database, host)}, + } + + // Add port and sslmode if needed + portSSL := "" + if port != 0 { + portSSL += fmt.Sprintf(" port=%d", port) + } + if strings.Contains(originalURL, "sslmode=disable") { + portSSL += " sslmode=disable" + } + for i := range maskingTestCases { + maskingTestCases[i].connStr += portSSL + } + + for _, tc := range maskingTestCases { + t.Run("Masking_"+tc.name, func(t *testing.T) { + masked, err := maskPassword(tc.connStr) + require.NoError(t, err) + assert.Contains(t, masked, "***", "password should be masked") + passwordPattern := fmt.Sprintf("password=%s", testPassword) + assert.NotContains(t, strings.ToLower(masked), strings.ToLower(passwordPattern), "password should not appear in plaintext") + }) + } + + // Integration test: verify DBOS context works with key-value format + t.Run("DBOSContextCreation", func(t *testing.T) { + // Use the actual password from config for integration test + actualPassword := parsedURL.ConnConfig.Password + keyValueConnStr := fmt.Sprintf("user=%s password=%s database=%s host=%s%s", user, actualPassword, database, host, portSSL) + + ctx, err := NewDBOSContext(context.Background(), Config{ + DatabaseURL: keyValueConnStr, + AppName: "test-keyvalue-format", + }) + require.NoError(t, err) + defer func() { + if ctx != nil { + Shutdown(ctx, 1*time.Minute) + } + }() + + require.NotNil(t, ctx) + + // Verify system DB is functional + dbosCtx, ok := ctx.(*dbosContext) + require.True(t, ok) + sysDB, ok := dbosCtx.systemDB.(*sysDB) + require.True(t, ok) + + var exists bool + err = sysDB.pool.QueryRow(context.Background(), "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'dbos' AND table_name = 'workflow_status')").Scan(&exists) + require.NoError(t, err) + assert.True(t, exists) + + // Verify masking works + poolConnStr := sysDB.pool.Config().ConnString() + maskedConnStr, err := maskPassword(poolConnStr) + require.NoError(t, err) + assert.Contains(t, maskedConnStr, "password=***") + assert.NotContains(t, maskedConnStr, fmt.Sprintf("password=%s", actualPassword)) + }) + }) + } func TestCustomSystemDBSchema(t *testing.T) { diff --git a/dbos/system_database.go b/dbos/system_database.go index 284e29e..35d73d0 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -2508,28 +2508,67 @@ func backoffWithJitter(retryAttempt int) time.Duration { // maskPassword replaces the password in a database URL with asterisks func maskPassword(dbURL string) (string, error) { parsedURL, err := url.Parse(dbURL) - if err != nil { - return "", err - } - - // Check if there is user info with a password - if parsedURL.User != nil { - username := parsedURL.User.Username() - _, hasPassword := parsedURL.User.Password() - if hasPassword { - // Manually construct the URL with masked password to avoid encoding - maskedURL := parsedURL.Scheme + "://" + username + ":***@" + parsedURL.Host + parsedURL.Path - if parsedURL.RawQuery != "" { - maskedURL += "?" + parsedURL.RawQuery - } - if parsedURL.Fragment != "" { - maskedURL += "#" + parsedURL.Fragment + if err == nil && parsedURL.Scheme != "" { + + // Check if there is user info with a password + if parsedURL.User != nil { + username := parsedURL.User.Username() + _, hasPassword := parsedURL.User.Password() + if hasPassword { + // Manually construct the URL with masked password to avoid encoding + maskedURL := parsedURL.Scheme + "://" + username + ":***@" + parsedURL.Host + parsedURL.Path + if parsedURL.RawQuery != "" { + maskedURL += "?" + parsedURL.RawQuery + } + if parsedURL.Fragment != "" { + maskedURL += "#" + parsedURL.Fragment + } + return maskedURL, nil } - return maskedURL, nil } + + return parsedURL.String(), nil + } + + // If URL parsing failed or no scheme, try key-value format (libpq connection string) + return maskPasswordInKeyValueFormat(dbURL), nil +} + +// maskPasswordInKeyValueFormat masks password in libpq-style key-value connection strings +// Format: "user=foo password=bar database=db host=localhost" +// Supports all spacing variations: password=value, password =value, password= value, password = value +func maskPasswordInKeyValueFormat(connStr string) string { + // Find "password" key (case insensitive) + lowerStr := strings.ToLower(connStr) + passwordKey := "password" + passwordIdx := strings.Index(lowerStr, passwordKey) + if passwordIdx == -1 { + return connStr // No password found + } + + // Find the = sign after "password" (skip optional spaces before =) + afterKey := passwordIdx + len(passwordKey) + for afterKey < len(connStr) && connStr[afterKey] == ' ' { + afterKey++ + } + if afterKey >= len(connStr) || connStr[afterKey] != '=' { + return connStr // No = sign found + } + + // Find the start of the password value (skip = and optional spaces after =) + valueStart := afterKey + 1 + for valueStart < len(connStr) && connStr[valueStart] == ' ' { + valueStart++ + } + + // Find the end of the password value (next space or end of string) + valueEnd := valueStart + for valueEnd < len(connStr) && connStr[valueEnd] != ' ' { + valueEnd++ } - return parsedURL.String(), nil + // Replace password value with *** + return connStr[:valueStart] + "***" + connStr[valueEnd:] } /*******************************/ From 17597b1288e24c36e6bbd7b329332d9fe57ee2b0 Mon Sep 17 00:00:00 2001 From: maxdml Date: Wed, 5 Nov 2025 16:36:27 -0800 Subject: [PATCH 2/3] handle k/v urls format in the CLI --- cmd/dbos/reset.go | 47 ++++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/cmd/dbos/reset.go b/cmd/dbos/reset.go index a57be6f..cf43e26 100644 --- a/cmd/dbos/reset.go +++ b/cmd/dbos/reset.go @@ -1,11 +1,11 @@ package main import ( - "database/sql" + "context" "fmt" - "net/url" - _ "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/spf13/cobra" ) @@ -39,42 +39,43 @@ func runReset(cmd *cobra.Command, args []string) error { return err } - // Parse the URL to get database name - parsedURL, err := url.Parse(dbURL) - if err != nil { - return fmt.Errorf("invalid database URL: %w", err) - } + ctx := context.Background() - // Extract database name from path - dbName := parsedURL.Path - if len(dbName) > 0 && dbName[0] == '/' { - dbName = dbName[1:] // Remove leading slash + // Parse the connection string using pgxpool.ParseConfig which handles both URL and key-value formats + config, err := pgxpool.ParseConfig(dbURL) + if err != nil { + return fmt.Errorf("failed to parse database URL: %w", err) } + // Get the database name from the config + dbName := config.ConnConfig.Database if dbName == "" { - return fmt.Errorf("database name is required in URL") + return fmt.Errorf("database name not found in connection string") } - // Connect to postgres database to drop and recreate the system database - parsedURL.Path = "/postgres" - postgresURL := parsedURL.String() + // Create a connection configuration pointing to the postgres database + postgresConfig := config.ConnConfig.Copy() + postgresConfig.Database = "postgres" - db, err := sql.Open("pgx", postgresURL) + // Connect to the postgres database + conn, err := pgx.ConnectConfig(ctx, postgresConfig) if err != nil { - return fmt.Errorf("failed to connect to postgres database: %w", err) + return fmt.Errorf("failed to connect to PostgreSQL server: %w", err) } - defer db.Close() + defer conn.Close(ctx) // Drop the system database if it exists logger.Info("Resetting system database", "database", dbName) - dropQuery := fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName) - if _, err := db.Exec(dropQuery); err != nil { + dropSQL := fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", pgx.Identifier{dbName}.Sanitize()) + _, err = conn.Exec(ctx, dropSQL) + if err != nil { return fmt.Errorf("failed to drop system database: %w", err) } // Create the database - createQuery := fmt.Sprintf("CREATE DATABASE %s", dbName) - if _, err := db.Exec(createQuery); err != nil { + createSQL := fmt.Sprintf("CREATE DATABASE %s", pgx.Identifier{dbName}.Sanitize()) + _, err = conn.Exec(ctx, createSQL) + if err != nil { return fmt.Errorf("failed to create system database: %w", err) } From 73217a5b65c7b38201f5866414210fd79c67fa44 Mon Sep 17 00:00:00 2001 From: maxdml Date: Wed, 5 Nov 2025 16:45:52 -0800 Subject: [PATCH 3/3] test --- cmd/dbos/cli_integration_test.go | 62 ++++++++++++++++++++++++++++++++ cmd/dbos/utils.go | 2 +- dbos/dbos_test.go | 2 +- 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/cmd/dbos/cli_integration_test.go b/cmd/dbos/cli_integration_test.go index a72b21d..1441fb9 100644 --- a/cmd/dbos/cli_integration_test.go +++ b/cmd/dbos/cli_integration_test.go @@ -55,6 +55,16 @@ func getDatabaseURL(dbRole string) string { return dsn.String() } +// getDatabaseURLKeyValue returns the same connection string in libpq key-value format +// This is useful for testing that commands handle both URL and key-value formats +func getDatabaseURLKeyValue(dbRole string) string { + password := os.Getenv("PGPASSWORD") + if password == "" { + password = "dbos" + } + return fmt.Sprintf("user='%s' password='%s' database=dbos host=localhost port=5432 sslmode=disable", dbRole, password) +} + // TestCLIWorkflow provides comprehensive integration testing of the DBOS CLI func TestCLIWorkflow(t *testing.T) { defer goleak.VerifyNone(t, @@ -147,6 +157,27 @@ func TestCLIWorkflow(t *testing.T) { } }) + t.Run("ResetDatabaseWithKeyValueFormat", func(t *testing.T) { + // Test reset command with key-value format connection string + args := append([]string{"reset", "-y", "--db-url", getDatabaseURLKeyValue("postgres")}, config.args...) + cmd := exec.Command(cliPath, args...) + + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Reset database command with key-value format failed: %s", string(output)) + + assert.Contains(t, string(output), "System database has been reset successfully", "Output should confirm database reset") + + // Verify the database was reset by checking schema doesn't exist + db, err := sql.Open("pgx", getDatabaseURL("postgres")) + require.NoError(t, err) + defer db.Close() + + var exists bool + err = db.QueryRow("SELECT EXISTS(SELECT 1 FROM information_schema.schemata WHERE schema_name = $1)", config.schemaName).Scan(&exists) + require.NoError(t, err) + assert.False(t, exists, fmt.Sprintf("Schema %s should not exist after reset", config.schemaName)) + }) + t.Run("ProjectInitialization", func(t *testing.T) { testProjectInitialization(t, cliPath) }) @@ -507,6 +538,22 @@ func testListWorkflows(t *testing.T, cliPath string, baseArgs []string, dbRole s } }) } + + // Test list command with key-value format connection string + t.Run("ListWithKeyValueFormat", func(t *testing.T) { + args := append([]string{"workflow", "list", "--db-url", getDatabaseURLKeyValue(dbRole)}, baseArgs...) + fmt.Println(args) + cmd := exec.Command(cliPath, args...) + + output, err := cmd.CombinedOutput() + require.NoError(t, err, "List command with key-value format failed: %s", string(output)) + + // Parse JSON output + var workflows []dbos.WorkflowStatus + err = json.Unmarshal(output, &workflows) + require.NoError(t, err, "JSON output should be valid") + assert.Greater(t, len(workflows), 0, "Should have workflows when using key-value format") + }) } // testGetWorkflow tests retrieving individual workflow details @@ -560,6 +607,21 @@ func testGetWorkflow(t *testing.T, cliPath string, baseArgs []string, dbRole str assert.NotEmpty(t, status2.Status, "Should have workflow status") assert.NotEmpty(t, status2.Name, "Should have workflow name") + // Test with key-value format connection string (libpq format) + argsKeyValue := append([]string{"workflow", "get", workflowID, "--db-url", getDatabaseURLKeyValue(dbRole)}, baseArgs...) + cmdKeyValue := exec.Command(cliPath, argsKeyValue...) + + outputKeyValue, errKeyValue := cmdKeyValue.CombinedOutput() + require.NoError(t, errKeyValue, "Get workflow JSON command with key-value format failed: %s", string(outputKeyValue)) + + // Verify valid JSON + var statusKeyValue dbos.WorkflowStatus + err = json.Unmarshal(outputKeyValue, &statusKeyValue) + require.NoError(t, err, "JSON output should be valid") + assert.Equal(t, workflowID, statusKeyValue.ID, "JSON should contain correct workflow ID") + assert.NotEmpty(t, statusKeyValue.Status, "Should have workflow status") + assert.NotEmpty(t, statusKeyValue.Name, "Should have workflow name") + // Test with config file containing environment variable configPath := "dbos-config.yaml" diff --git a/cmd/dbos/utils.go b/cmd/dbos/utils.go index 0d590e0..4711816 100644 --- a/cmd/dbos/utils.go +++ b/cmd/dbos/utils.go @@ -103,7 +103,7 @@ func getDBURL() (string, error) { // Log the database URL in verbose mode with masked password maskedURL, err := maskPassword(resolvedURL) if err != nil { - logger.Warn("Failed to mask database URL", "error", err) + logger.Debug("Failed to mask database URL", "error", err) maskedURL = resolvedURL } logger.Debug("Using database URL", "source", source, "url", maskedURL) diff --git a/dbos/dbos_test.go b/dbos/dbos_test.go index f964044..ee9b605 100644 --- a/dbos/dbos_test.go +++ b/dbos/dbos_test.go @@ -337,7 +337,7 @@ func TestConfig(t *testing.T) { t.Run("DBOSContextCreation", func(t *testing.T) { // Use the actual password from config for integration test actualPassword := parsedURL.ConnConfig.Password - keyValueConnStr := fmt.Sprintf("user=%s password=%s database=%s host=%s%s", user, actualPassword, database, host, portSSL) + keyValueConnStr := fmt.Sprintf("user='%s' password='%s' database=%s host=%s%s", user, actualPassword, database, host, portSSL) ctx, err := NewDBOSContext(context.Background(), Config{ DatabaseURL: keyValueConnStr,