Skip to content

Commit d28b3c6

Browse files
committed
refactor: improve error handling in Go function and update tests for custom output types
1 parent 74deb3a commit d28b3c6

File tree

2 files changed

+45
-19
lines changed

2 files changed

+45
-19
lines changed

dbos/workflow.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,18 +1257,29 @@ func Go[R any](ctx DBOSContext, fn Step[R], opts ...StepOption) (chan StepOutcom
12571257
return *new(chan StepOutcome[R]), err
12581258
}
12591259

1260-
// Otherwise type-check and cast the result
1260+
outcomeChan := make(chan StepOutcome[R], 1)
1261+
defer close(outcomeChan)
1262+
12611263
outcome := <-result
1264+
1265+
if outcome.err != nil {
1266+
outcomeChan <- StepOutcome[R]{
1267+
result: *new(R),
1268+
err: outcome.err,
1269+
}
1270+
return outcomeChan, nil
1271+
}
1272+
1273+
// Otherwise type-check and cast the result
12621274
typedResult, ok := outcome.result.(R)
12631275
if !ok {
12641276
return *new(chan StepOutcome[R]), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result)
12651277
}
1266-
outcomeChan := make(chan StepOutcome[R], 1)
1267-
defer close(outcomeChan)
12681278
outcomeChan <- StepOutcome[R]{
12691279
result: typedResult,
1270-
err: outcome.err,
1280+
err: nil,
12711281
}
1282+
12721283
return outcomeChan, nil
12731284
}
12741285

@@ -1292,6 +1303,8 @@ func (c *dbosContext) Go(ctx DBOSContext, fn StepFunc, opts ...StepOption) (chan
12921303
}
12931304
}()
12941305

1306+
// TODO: do I need to close the channel here?
1307+
12951308
return result, nil
12961309
}
12971310

dbos/workflows_test.go

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package dbos
22

33
import (
44
"context"
5+
"encoding/gob"
56
"errors"
67
"fmt"
78
"reflect"
@@ -42,11 +43,26 @@ func simpleStepError(_ context.Context) (string, error) {
4243
return "", fmt.Errorf("step failure")
4344
}
4445

46+
type stepWithSleepOutput struct {
47+
StepID int
48+
Result string
49+
Error error
50+
}
51+
4552
func stepWithSleep(_ context.Context, duration time.Duration) (string, error) {
4653
time.Sleep(duration)
4754
return fmt.Sprintf("from step that slept for %s", duration), nil
4855
}
4956

57+
func stepWithSleepCustomOutput(_ context.Context, duration time.Duration, stepID int) (stepWithSleepOutput, error) {
58+
time.Sleep(duration)
59+
return stepWithSleepOutput{
60+
StepID: stepID,
61+
Result: fmt.Sprintf("from step that slept for %s", duration),
62+
Error: nil,
63+
}, nil
64+
}
65+
5066
func simpleWorkflowWithStepError(dbosCtx DBOSContext, input string) (string, error) {
5167
return RunAsStep(dbosCtx, func(ctx context.Context) (string, error) {
5268
return simpleStepError(ctx)
@@ -862,6 +878,10 @@ func TestSteps(t *testing.T) {
862878

863879
func TestGoRunningStepsInsideGoRoutines(t *testing.T) {
864880
dbosCtx := setupDBOS(t, true, true)
881+
882+
// Register custom types for Gob encoding
883+
var stepOutput stepWithSleepOutput
884+
gob.Register(stepOutput)
865885
t.Run("Go must run steps inside a workflow", func(t *testing.T) {
866886
_, err := Go(dbosCtx, func(ctx context.Context) (string, error) {
867887
return stepWithSleep(ctx, 1*time.Second)
@@ -902,12 +922,12 @@ func TestGoRunningStepsInsideGoRoutines(t *testing.T) {
902922
const numSteps = 100
903923
results := make(chan string, numSteps)
904924
errors := make(chan error, numSteps)
905-
var resultChans []<-chan StepOutcome[string]
925+
var resultChans []<-chan StepOutcome[stepWithSleepOutput]
906926

907927
goWorkflow := func(dbosCtx DBOSContext, input string) (string, error) {
908-
for range numSteps {
909-
resultChan, err := Go(dbosCtx, func(ctx context.Context) (string, error) {
910-
return stepWithSleep(ctx, 20*time.Millisecond)
928+
for i := 0; i < numSteps; i++ {
929+
resultChan, err := Go(dbosCtx, func(ctx context.Context) (stepWithSleepOutput, error) {
930+
return stepWithSleepCustomOutput(ctx, 20*time.Millisecond, i)
911931
})
912932

913933
if err != nil {
@@ -916,12 +936,13 @@ func TestGoRunningStepsInsideGoRoutines(t *testing.T) {
916936
resultChans = append(resultChans, resultChan)
917937
}
918938

919-
for _, resultChan := range resultChans {
939+
for i, resultChan := range resultChans {
920940
result1 := <-resultChan
921941
if result1.err != nil {
922-
errors <- result1.err
942+
errors <- result1.result.Error
923943
}
924-
results <- result1.result
944+
assert.Equal(t, i, result1.result.StepID, "expected step ID to be %d, got %d", i, result1.result.StepID)
945+
results <- result1.result.Result
925946
}
926947
return "", nil
927948
}
@@ -933,18 +954,10 @@ func TestGoRunningStepsInsideGoRoutines(t *testing.T) {
933954

934955
close(results)
935956
close(errors)
936-
937957
require.NoError(t, err, "failed to get result from go workflow")
938958
assert.Equal(t, numSteps, len(results), "expected %d results, got %d", numSteps, len(results))
939959
assert.Equal(t, 0, len(errors), "expected no errors, got %d", len(errors))
940960

941-
// Test step IDs are deterministic and in the order of execution
942-
steps, err := GetWorkflowSteps(dbosCtx, handle.GetWorkflowID())
943-
require.NoError(t, err, "failed to get workflow steps")
944-
require.Len(t, steps, numSteps, "expected %d steps, got %d", numSteps, len(steps))
945-
for i := 0; i < numSteps; i++ {
946-
assert.Equal(t, i, steps[i].StepID, "expected step ID to be %d, got %d", i, steps[i].StepID)
947-
}
948961
})
949962
}
950963

0 commit comments

Comments
 (0)