Skip to content

Commit 050cfef

Browse files
authored
Fix comms durable sleeps (#116)
- Fix a bug where sleep step IDs were not always generated when `recv`/`getEvent` skipped sleeping (e.g., a value was already found, during recovery) - Fix a bug where `recv` would return an error if it recorded a timeout value - Handle context cancellation during `recv` (addresses #86)
1 parent 0afae2e commit 050cfef

File tree

2 files changed

+287
-76
lines changed

2 files changed

+287
-76
lines changed

dbos/system_database.go

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,7 +1440,8 @@ func (s *sysDB) getWorkflowSteps(ctx context.Context, workflowID string) ([]Step
14401440

14411441
type sleepInput struct {
14421442
duration time.Duration // Duration to sleep
1443-
skipSleep bool // If true, the function will not actually sleep (useful for testing)
1443+
skipSleep bool // If true, the function will not actually sleep and just return the remaining sleep duration
1444+
stepID *int // Optional step ID to use instead of generating a new one (for internal use)
14441445
}
14451446

14461447
// Sleep is a special type of step that sleeps for a specified duration
@@ -1461,7 +1462,13 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err
14611462
return 0, newStepExecutionError(wfState.workflowID, functionName, "cannot call Sleep within a step")
14621463
}
14631464

1464-
stepID := wfState.NextStepID()
1465+
// Determine step ID
1466+
var stepID int
1467+
if input.stepID != nil && *input.stepID >= 0 {
1468+
stepID = *input.stepID
1469+
} else {
1470+
stepID = wfState.NextStepID()
1471+
}
14651472

14661473
// Check if operation was already executed
14671474
checkInput := checkOperationExecutionDBInput{
@@ -1700,6 +1707,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
17001707
}
17011708

17021709
stepID := wfState.NextStepID()
1710+
sleepStepID := wfState.NextStepID() // We will use a sleep step to implement the timeout
17031711
destinationID := wfState.workflowID
17041712

17051713
// Set default topic if not provided
@@ -1719,10 +1727,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
17191727
return nil, err
17201728
}
17211729
if recordedResult != nil {
1722-
if recordedResult.output != nil {
1723-
return recordedResult.output, nil
1724-
}
1725-
return nil, fmt.Errorf("no output recorded in the last recv")
1730+
return recordedResult.output, nil
17261731
}
17271732

17281733
// First check if there's already a receiver for this workflow/topic to avoid unnecessary database load
@@ -1762,6 +1767,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
17621767
timeout, err := s.sleep(ctx, sleepInput{
17631768
duration: input.Timeout,
17641769
skipSleep: true,
1770+
stepID: &sleepStepID,
17651771
})
17661772
if err != nil {
17671773
return nil, fmt.Errorf("failed to sleep before recv timeout: %w", err)
@@ -1772,6 +1778,9 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
17721778
s.logger.Debug("Received notification on condition variable", "payload", payload)
17731779
case <-time.After(timeout):
17741780
s.logger.Warn("Recv() timeout reached", "payload", payload, "timeout", input.Timeout)
1781+
case <-ctx.Done():
1782+
s.logger.Warn("Recv() context cancelled", "payload", payload, "cause", context.Cause(ctx))
1783+
return nil, ctx.Err()
17751784
}
17761785
}
17771786

@@ -1808,7 +1817,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) {
18081817

18091818
// Deserialize the message
18101819
var message any
1811-
if messageString != nil { // nil message should never happen because they'd cause an error on the send() path
1820+
if messageString != nil { // nil message can happen on the timeout path only
18121821
message, err = deserialize(messageString)
18131822
if err != nil {
18141823
return nil, fmt.Errorf("failed to deserialize message: %w", err)
@@ -1923,6 +1932,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
19231932
// Get workflow state from context (optional for GetEvent as we can get an event from outside a workflow)
19241933
wfState, ok := ctx.Value(workflowStateKey).(*workflowState)
19251934
var stepID int
1935+
var sleepStepID int
19261936
var isInWorkflow bool
19271937

19281938
if ok && wfState != nil {
@@ -1931,6 +1941,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
19311941
return nil, newStepExecutionError(wfState.workflowID, functionName, "cannot call GetEvent within a step")
19321942
}
19331943
stepID = wfState.NextStepID()
1944+
sleepStepID = wfState.NextStepID() // We will use a sleep step to implement the timeout
19341945

19351946
// Check if operation was already executed (only if in workflow)
19361947
checkInput := checkOperationExecutionDBInput{
@@ -1967,7 +1978,6 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
19671978

19681979
// Check if the event already exists in the database
19691980
query := `SELECT value FROM dbos.workflow_events WHERE workflow_uuid = $1 AND key = $2`
1970-
var value any
19711981
var valueString *string
19721982

19731983
row := s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
@@ -1976,7 +1986,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
19761986
return nil, fmt.Errorf("failed to query workflow event: %w", err)
19771987
}
19781988

1979-
if err == pgx.ErrNoRows || valueString == nil { // valueString should never be `nil`
1989+
if err == pgx.ErrNoRows {
19801990
// Wait for notification with timeout using condition variable
19811991
done := make(chan struct{})
19821992
go func() {
@@ -1991,6 +2001,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
19912001
timeout, err = s.sleep(ctx, sleepInput{
19922002
duration: input.Timeout,
19932003
skipSleep: true,
2004+
stepID: &sleepStepID,
19942005
})
19952006
if err != nil {
19962007
return nil, fmt.Errorf("failed to sleep before getEvent timeout: %w", err)
@@ -2003,22 +2014,20 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error)
20032014
case <-time.After(timeout):
20042015
s.logger.Warn("GetEvent() timeout reached", "target_workflow_id", input.TargetWorkflowID, "key", input.Key, "timeout", input.Timeout)
20052016
case <-ctx.Done():
2006-
return nil, fmt.Errorf("context cancelled while waiting for event: %w", ctx.Err())
2017+
s.logger.Warn("GetEvent() context cancelled", "target_workflow_id", input.TargetWorkflowID, "key", input.Key, "cause", context.Cause(ctx))
2018+
return nil, ctx.Err()
20072019
}
20082020

20092021
// Query the database again after waiting
20102022
row = s.pool.QueryRow(ctx, query, input.TargetWorkflowID, input.Key)
20112023
err = row.Scan(&valueString)
2012-
if err != nil {
2013-
if err == pgx.ErrNoRows {
2014-
value = nil // Event still doesn't exist
2015-
} else {
2016-
return nil, fmt.Errorf("failed to query workflow event after wait: %w", err)
2017-
}
2024+
if err != nil && err != pgx.ErrNoRows {
2025+
return nil, fmt.Errorf("failed to query workflow event after wait: %w", err)
20182026
}
20192027
}
20202028

20212029
// Deserialize the value if it exists
2030+
var value any
20222031
if valueString != nil {
20232032
value, err = deserialize(valueString)
20242033
if err != nil {

0 commit comments

Comments
 (0)