diff --git a/pkg/database/restore.go b/pkg/database/restore.go index 37a9c20..45df73e 100644 --- a/pkg/database/restore.go +++ b/pkg/database/restore.go @@ -7,11 +7,7 @@ import ( "fmt" "io" "regexp" -) - -const ( - // used to define default max buffer size Scanner, counter part of dump - defaultMaxAllowedPacket = 4194304 + "strings" ) var ( @@ -31,20 +27,33 @@ func Restore(ctx context.Context, dbconn *Connection, databasesMap map[string]st if err != nil { return fmt.Errorf("failed to restore database: %w", err) } - scanBuf := []byte{} - scanner := bufio.NewScanner(r) - // increase the buffer size - scanner.Buffer(scanBuf, defaultMaxAllowedPacket) //TODO should be a configurable option like with dump + reader := bufio.NewReader(r) var current string - for scanner.Scan() { - line := scanner.Text() + for { + line, err := reader.ReadString('\n') + if err != nil && err != io.EOF { + _ = tx.Rollback() + return fmt.Errorf("failed to restore database: %w", err) + } + // strip CRLF/newline + line = strings.TrimRight(line, "\r\n") if line == "" { + if err == io.EOF { + break + } continue } current += line + "\n" + + // if the line does not end with a semicolon, keep accumulating if line[len(line)-1] != ';' { + if err == io.EOF { + // EOF reached but statement not terminated; we'll try to execute below + break + } continue } + // if we have the line that sets the database, and we need to replace, replace it if createRegex.MatchString(current) { dbName := createRegex.FindStringSubmatch(current)[3] @@ -64,9 +73,17 @@ func Restore(ctx context.Context, dbconn *Connection, databasesMap map[string]st return fmt.Errorf("failed to restore database: %w", err) } current = "" + + if err == io.EOF { + break + } } - if err := scanner.Err(); err != nil { - return fmt.Errorf("failed to restore database: %w", err) + // if there's any leftover SQL (for example last statement without newline), execute it + if strings.TrimSpace(current) != "" { + if _, err := tx.Exec(current); err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to restore database: %w", err) + } } if err := tx.Commit(); err != nil { return fmt.Errorf("failed to restore database: %w", err)