Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions pkg/database/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@ import (
"fmt"
"io"
"regexp"
)

const (
// used to define default max buffer size Scanner, counter part of dump
defaultMaxAllowedPacket = 4194304
"strings"
)

var (
Expand All @@ -31,20 +27,33 @@ func Restore(ctx context.Context, dbconn *Connection, databasesMap map[string]st
if err != nil {
return fmt.Errorf("failed to restore database: %w", err)
}
scanBuf := []byte{}
scanner := bufio.NewScanner(r)
// increase the buffer size
scanner.Buffer(scanBuf, defaultMaxAllowedPacket) //TODO should be a configurable option like with dump
reader := bufio.NewReader(r)
var current string
for scanner.Scan() {
line := scanner.Text()
for {
line, err := reader.ReadString('\n')
if err != nil && err != io.EOF {
_ = tx.Rollback()
return fmt.Errorf("failed to restore database: %w", err)
}
// strip CRLF/newline
line = strings.TrimRight(line, "\r\n")
if line == "" {
if err == io.EOF {
break
}
continue
}
current += line + "\n"

// if the line does not end with a semicolon, keep accumulating
if line[len(line)-1] != ';' {
if err == io.EOF {
// EOF reached but statement not terminated; we'll try to execute below
break
}
continue
}

// if we have the line that sets the database, and we need to replace, replace it
if createRegex.MatchString(current) {
dbName := createRegex.FindStringSubmatch(current)[3]
Expand All @@ -64,9 +73,17 @@ func Restore(ctx context.Context, dbconn *Connection, databasesMap map[string]st
return fmt.Errorf("failed to restore database: %w", err)
}
current = ""

if err == io.EOF {
break
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("failed to restore database: %w", err)
// if there's any leftover SQL (for example last statement without newline), execute it
if strings.TrimSpace(current) != "" {
if _, err := tx.Exec(current); err != nil {
_ = tx.Rollback()
return fmt.Errorf("failed to restore database: %w", err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to restore database: %w", err)
Expand Down