Skip to content

Commit c1f41a6

Browse files
authored
Cancellable get results (#169)
Resolves #85
1 parent 381f713 commit c1f41a6

File tree

2 files changed

+136
-13
lines changed

2 files changed

+136
-13
lines changed

dbos/workflow.go

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,32 @@ type workflowOutcome[R any] struct {
8282
// The type parameter R represents the expected return type of the workflow.
8383
// Handles can be used to wait for workflow completion, check status, and retrieve results.
8484
type WorkflowHandle[R any] interface {
85-
GetResult() (R, error) // Wait for workflow completion and return the result
86-
GetStatus() (WorkflowStatus, error) // Get current workflow status without waiting
87-
GetWorkflowID() string // Get the unique workflow identifier
85+
GetResult(opts ...GetResultOption) (R, error) // Wait for workflow completion and return the result
86+
GetStatus() (WorkflowStatus, error) // Get current workflow status without waiting
87+
GetWorkflowID() string // Get the unique workflow identifier
8888
}
8989

9090
type baseWorkflowHandle struct {
9191
workflowID string
9292
dbosContext DBOSContext
9393
}
9494

95+
// GetResultOption is a functional option for configuring GetResult behavior.
96+
type GetResultOption func(*getResultOptions)
97+
98+
// getResultOptions holds the configuration for GetResult execution.
99+
type getResultOptions struct {
100+
timeout time.Duration
101+
}
102+
103+
// WithHandleTimeout sets a timeout for the GetResult operation.
104+
// If the timeout is reached before the workflow completes, GetResult will return a timeout error.
105+
func WithHandleTimeout(timeout time.Duration) GetResultOption {
106+
return func(opts *getResultOptions) {
107+
opts.timeout = timeout
108+
}
109+
}
110+
95111
// GetStatus returns the current status of the workflow from the database
96112
// If the DBOSContext is running in client mode, do not load input and outputs
97113
func (h *baseWorkflowHandle) GetStatus() (WorkflowStatus, error) {
@@ -162,12 +178,33 @@ type workflowHandle[R any] struct {
162178
outcomeChan chan workflowOutcome[R]
163179
}
164180

165-
func (h *workflowHandle[R]) GetResult() (R, error) {
166-
outcome, ok := <-h.outcomeChan // Blocking read
167-
if !ok {
168-
// Return an error if the channel was closed. In normal operations this would happen if GetResul() is called twice on a handler. The first call should get the buffered result, the second call find zero values (channel is empty and closed).
169-
return *new(R), errors.New("workflow result channel is already closed. Did you call GetResult() twice on the same workflow handle?")
181+
func (h *workflowHandle[R]) GetResult(opts ...GetResultOption) (R, error) {
182+
options := &getResultOptions{}
183+
for _, opt := range opts {
184+
opt(options)
170185
}
186+
187+
var timeoutChan <-chan time.Time
188+
if options.timeout > 0 {
189+
timeoutChan = time.After(options.timeout)
190+
}
191+
192+
select {
193+
case outcome, ok := <-h.outcomeChan:
194+
if !ok {
195+
// Return error if channel closed (happens when GetResult() called twice)
196+
return *new(R), errors.New("workflow result channel is already closed. Did you call GetResult() twice on the same workflow handle?")
197+
}
198+
return h.processOutcome(outcome)
199+
case <-h.dbosContext.Done():
200+
return *new(R), context.Cause(h.dbosContext)
201+
case <-timeoutChan:
202+
return *new(R), fmt.Errorf("workflow result timeout after %v: %w", options.timeout, context.DeadlineExceeded)
203+
}
204+
}
205+
206+
// processOutcome handles the common logic for processing workflow outcomes
207+
func (h *workflowHandle[R]) processOutcome(outcome workflowOutcome[R]) (R, error) {
171208
// If we are calling GetResult inside a workflow, record the result as a step result
172209
workflowState, ok := h.dbosContext.Value(workflowStateKey).(*workflowState)
173210
isWithinWorkflow := ok && workflowState != nil
@@ -198,9 +235,22 @@ type workflowPollingHandle[R any] struct {
198235
baseWorkflowHandle
199236
}
200237

201-
func (h *workflowPollingHandle[R]) GetResult() (R, error) {
202-
result, err := retryWithResult(h.dbosContext, func() (any, error) {
203-
return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(h.dbosContext, h.workflowID)
238+
func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error) {
239+
options := &getResultOptions{}
240+
for _, opt := range opts {
241+
opt(options)
242+
}
243+
244+
// Use timeout if specified, otherwise use DBOS context directly
245+
ctx := h.dbosContext
246+
var cancel context.CancelFunc
247+
if options.timeout > 0 {
248+
ctx, cancel = WithTimeout(h.dbosContext, options.timeout)
249+
defer cancel()
250+
}
251+
252+
result, err := retryWithResult(ctx, func() (any, error) {
253+
return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(ctx, h.workflowID)
204254
}, withRetrierLogger(h.dbosContext.(*dbosContext).logger))
205255
if result != nil {
206256
typedResult, ok := result.(R)
@@ -240,8 +290,8 @@ type workflowHandleProxy[R any] struct {
240290
wrappedHandle WorkflowHandle[any]
241291
}
242292

243-
func (h *workflowHandleProxy[R]) GetResult() (R, error) {
244-
result, err := h.wrappedHandle.GetResult()
293+
func (h *workflowHandleProxy[R]) GetResult(opts ...GetResultOption) (R, error) {
294+
result, err := h.wrappedHandle.GetResult(opts...)
245295
if err != nil {
246296
var zero R
247297
return zero, err

dbos/workflows_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ func simpleWorkflowWithStep(dbosCtx DBOSContext, input string) (string, error) {
3434
})
3535
}
3636

37+
func slowWorkflow(dbosCtx DBOSContext, sleepTime time.Duration) (string, error) {
38+
Sleep(dbosCtx, sleepTime)
39+
return "done", nil
40+
}
41+
3742
func simpleStep(_ context.Context) (string, error) {
3843
return "from step", nil
3944
}
@@ -4523,3 +4528,71 @@ func TestWorkflowIdentity(t *testing.T) {
45234528
assert.Equal(t, []string{"reader", "writer"}, status.AuthenticatedRoles)
45244529
})
45254530
}
4531+
4532+
func TestWorkflowHandleTimeout(t *testing.T) {
4533+
dbosCtx := setupDBOS(t, true, true)
4534+
RegisterWorkflow(dbosCtx, slowWorkflow)
4535+
4536+
t.Run("WorkflowHandleTimeout", func(t *testing.T) {
4537+
handle, err := RunWorkflow(dbosCtx, slowWorkflow, 10*time.Second)
4538+
require.NoError(t, err, "failed to start workflow")
4539+
4540+
start := time.Now()
4541+
_, err = handle.GetResult(WithHandleTimeout(10 * time.Millisecond))
4542+
duration := time.Since(start)
4543+
4544+
require.Error(t, err, "expected timeout error")
4545+
assert.Contains(t, err.Error(), "workflow result timeout")
4546+
assert.True(t, duration < 100*time.Millisecond, "timeout should occur quickly")
4547+
assert.True(t, errors.Is(err, context.DeadlineExceeded),
4548+
"expected error to be detectable as context.DeadlineExceeded, got: %v", err)
4549+
})
4550+
4551+
t.Run("WorkflowPollingHandleTimeout", func(t *testing.T) {
4552+
// Start a workflow that will block on the first signal
4553+
originalHandle, err := RunWorkflow(dbosCtx, slowWorkflow, 10*time.Second)
4554+
require.NoError(t, err, "failed to start workflow")
4555+
4556+
pollingHandle, err := RetrieveWorkflow[string](dbosCtx, originalHandle.GetWorkflowID())
4557+
require.NoError(t, err, "failed to retrieve workflow")
4558+
4559+
_, ok := pollingHandle.(*workflowPollingHandle[string])
4560+
require.True(t, ok, "expected polling handle, got %T", pollingHandle)
4561+
4562+
_, err = pollingHandle.GetResult(WithHandleTimeout(10 * time.Millisecond))
4563+
4564+
require.Error(t, err, "expected timeout error")
4565+
assert.True(t, errors.Is(err, context.DeadlineExceeded),
4566+
"expected error to be detectable as context.DeadlineExceeded, got: %v", err)
4567+
})
4568+
}
4569+
4570+
func TestWorkflowHandleContextCancel(t *testing.T) {
4571+
dbosCtx := setupDBOS(t, true, true)
4572+
RegisterWorkflow(dbosCtx, getEventWorkflow)
4573+
4574+
t.Run("WorkflowHandleContextCancel", func(t *testing.T) {
4575+
getEventWorkflowStartedSignal.Clear()
4576+
handle, err := RunWorkflow(dbosCtx, getEventWorkflow, getEventWorkflowInput{
4577+
TargetWorkflowID: "test-workflow-id",
4578+
Key: "test-key",
4579+
})
4580+
require.NoError(t, err, "failed to start workflow")
4581+
4582+
resultChan := make(chan error)
4583+
go func() {
4584+
_, err := handle.GetResult()
4585+
resultChan <- err
4586+
}()
4587+
4588+
getEventWorkflowStartedSignal.Wait()
4589+
getEventWorkflowStartedSignal.Clear()
4590+
4591+
dbosCtx.Shutdown(1 * time.Second)
4592+
4593+
err = <-resultChan
4594+
require.Error(t, err, "expected error from cancelled context")
4595+
assert.True(t, errors.Is(err, context.Canceled),
4596+
"expected error to be detectable as context.Canceled, got: %v", err)
4597+
})
4598+
}

0 commit comments

Comments
 (0)