Skip to content

Commit 346bdc0

Browse files
af-mdmaxdml
authored andcommitted
refactor: improve error handling in Go function and update tests for custom output types
1 parent c6f4f1d commit 346bdc0

File tree

2 files changed

+45
-19
lines changed

2 files changed

+45
-19
lines changed

dbos/workflow.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,18 +1306,29 @@ func Go[R any](ctx DBOSContext, fn Step[R], opts ...StepOption) (chan StepOutcom
13061306
return *new(chan StepOutcome[R]), err
13071307
}
13081308

1309-
// Otherwise type-check and cast the result
1309+
outcomeChan := make(chan StepOutcome[R], 1)
1310+
defer close(outcomeChan)
1311+
13101312
outcome := <-result
1313+
1314+
if outcome.err != nil {
1315+
outcomeChan <- StepOutcome[R]{
1316+
result: *new(R),
1317+
err: outcome.err,
1318+
}
1319+
return outcomeChan, nil
1320+
}
1321+
1322+
// Otherwise type-check and cast the result
13111323
typedResult, ok := outcome.result.(R)
13121324
if !ok {
13131325
return *new(chan StepOutcome[R]), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result)
13141326
}
1315-
outcomeChan := make(chan StepOutcome[R], 1)
1316-
defer close(outcomeChan)
13171327
outcomeChan <- StepOutcome[R]{
13181328
result: typedResult,
1319-
err: outcome.err,
1329+
err: nil,
13201330
}
1331+
13211332
return outcomeChan, nil
13221333
}
13231334

@@ -1341,6 +1352,8 @@ func (c *dbosContext) Go(ctx DBOSContext, fn StepFunc, opts ...StepOption) (chan
13411352
}
13421353
}()
13431354

1355+
// TODO: do I need to close the channel here?
1356+
13441357
return result, nil
13451358
}
13461359

dbos/workflows_test.go

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package dbos
22

33
import (
44
"context"
5+
"encoding/gob"
56
"errors"
67
"fmt"
78
"reflect"
@@ -47,11 +48,26 @@ func simpleStepError(_ context.Context) (string, error) {
4748
return "", fmt.Errorf("step failure")
4849
}
4950

51+
type stepWithSleepOutput struct {
52+
StepID int
53+
Result string
54+
Error error
55+
}
56+
5057
func stepWithSleep(_ context.Context, duration time.Duration) (string, error) {
5158
time.Sleep(duration)
5259
return fmt.Sprintf("from step that slept for %s", duration), nil
5360
}
5461

62+
func stepWithSleepCustomOutput(_ context.Context, duration time.Duration, stepID int) (stepWithSleepOutput, error) {
63+
time.Sleep(duration)
64+
return stepWithSleepOutput{
65+
StepID: stepID,
66+
Result: fmt.Sprintf("from step that slept for %s", duration),
67+
Error: nil,
68+
}, nil
69+
}
70+
5571
func simpleWorkflowWithStepError(dbosCtx DBOSContext, input string) (string, error) {
5672
return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
5773
return simpleStepError(ctx)
@@ -867,6 +883,10 @@ func TestSteps(t *testing.T) {
867883

868884
func TestGoRunningStepsInsideGoRoutines(t *testing.T) {
869885
dbosCtx := setupDBOS(t, true, true)
886+
887+
// Register custom types for Gob encoding
888+
var stepOutput stepWithSleepOutput
889+
gob.Register(stepOutput)
870890
t.Run("Go must run steps inside a workflow", func(t *testing.T) {
871891
_, err := Go(dbosCtx, func(ctx context.Context) (string, error) {
872892
return stepWithSleep(ctx, 1*time.Second)
@@ -907,12 +927,12 @@ func TestGoRunningStepsInsideGoRoutines(t *testing.T) {
907927
const numSteps = 100
908928
results := make(chan string, numSteps)
909929
errors := make(chan error, numSteps)
910-
var resultChans []<-chan StepOutcome[string]
930+
var resultChans []<-chan StepOutcome[stepWithSleepOutput]
911931

912932
goWorkflow := func(dbosCtx DBOSContext, input string) (string, error) {
913-
for range numSteps {
914-
resultChan, err := Go(dbosCtx, func(ctx context.Context) (string, error) {
915-
return stepWithSleep(ctx, 20*time.Millisecond)
933+
for i := 0; i < numSteps; i++ {
934+
resultChan, err := Go(dbosCtx, func(ctx context.Context) (stepWithSleepOutput, error) {
935+
return stepWithSleepCustomOutput(ctx, 20*time.Millisecond, i)
916936
})
917937

918938
if err != nil {
@@ -921,12 +941,13 @@ func TestGoRunningStepsInsideGoRoutines(t *testing.T) {
921941
resultChans = append(resultChans, resultChan)
922942
}
923943

924-
for _, resultChan := range resultChans {
944+
for i, resultChan := range resultChans {
925945
result1 := <-resultChan
926946
if result1.err != nil {
927-
errors <- result1.err
947+
errors <- result1.result.Error
928948
}
929-
results <- result1.result
949+
assert.Equal(t, i, result1.result.StepID, "expected step ID to be %d, got %d", i, result1.result.StepID)
950+
results <- result1.result.Result
930951
}
931952
return "", nil
932953
}
@@ -938,18 +959,10 @@ func TestGoRunningStepsInsideGoRoutines(t *testing.T) {
938959

939960
close(results)
940961
close(errors)
941-
942962
require.NoError(t, err, "failed to get result from go workflow")
943963
assert.Equal(t, numSteps, len(results), "expected %d results, got %d", numSteps, len(results))
944964
assert.Equal(t, 0, len(errors), "expected no errors, got %d", len(errors))
945965

946-
// Test step IDs are deterministic and in the order of execution
947-
steps, err := GetWorkflowSteps(dbosCtx, handle.GetWorkflowID())
948-
require.NoError(t, err, "failed to get workflow steps")
949-
require.Len(t, steps, numSteps, "expected %d steps, got %d", numSteps, len(steps))
950-
for i := 0; i < numSteps; i++ {
951-
assert.Equal(t, i, steps[i].StepID, "expected step ID to be %d, got %d", i, steps[i].StepID)
952-
}
953966
})
954967
}
955968

0 commit comments

Comments
 (0)