Skip to content

Commit 0afae2e

Browse files
authored
test dbos mocks (#115)
- Add a test exercising the mocking of the DBOSContext and WorkflowHandle interfaces - We need a wrapper handle for the mock Path in `RunWorkflow` (it does return a new handle from the `DBOSContext.RunWorkflow`). This will only be used on the mocking path.
1 parent 574621c commit 0afae2e

File tree

8 files changed

+263
-6
lines changed

8 files changed

+263
-6
lines changed

.github/workflows/tests.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ jobs:
6060
- name: Install gotestsum
6161
run: go install gotest.tools/gotestsum@latest
6262

63+
- name: Install mockery
64+
run: go install github.com/vektra/mockery/v3@latest
65+
66+
- name: Generate mocks
67+
run: go generate ./...
68+
working-directory: ./dbos
69+
6370
- name: Run tests
6471
run: go vet ./... && gotestsum --format github-action -- -race -v -count=1 ./...
6572
working-directory: ./dbos

cmd/dbos/init.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ var initCmd = &cobra.Command{
1616
RunE: runInit,
1717
}
1818

19-
var (
20-
configOnly bool
21-
)
22-
2319
type templateData struct {
2420
ProjectName string
2521
}

dbos/dbos.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ func processConfig(inputConfig *Config) (*Config, error) {
9393
return dbosConfig, nil
9494
}
9595

96+
//go:generate mockery --config=mocks-tests-config.yaml
97+
9698
// DBOSContext represents a DBOS execution context that provides workflow orchestration capabilities.
9799
// It extends the standard Go context.Context and adds methods for running workflows and steps,
98100
// inter-workflow communication, and state management.

dbos/mocks-tests-config.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
all: false
2+
dir: './mocks'
3+
filename: '{{.InterfaceName}}_mock.go'
4+
force-file-write: true
5+
formatter: goimports
6+
include-auto-generated: false
7+
log-level: info
8+
structname: 'Mock{{.InterfaceName}}'
9+
pkgname: 'mocks'
10+
recursive: false
11+
require-template-schema-exists: true
12+
template: testify
13+
template-schema: '{{.Template}}.schema.json'
14+
packages:
15+
github.com/dbos-inc/dbos-transact-golang/dbos:
16+
interfaces:
17+
DBOSContext:
18+
WorkflowHandle:

dbos/mocks_test.go

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
package dbos_test
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"testing"
7+
"time"
8+
9+
"github.com/dbos-inc/dbos-transact-golang/dbos"
10+
"github.com/dbos-inc/dbos-transact-golang/dbos/mocks"
11+
"github.com/stretchr/testify/mock"
12+
)
13+
14+
func step(ctx context.Context) (int, error) {
15+
return 1, nil
16+
}
17+
18+
func childWorkflow(ctx dbos.DBOSContext, i int) (int, error) {
19+
return i + 1, nil
20+
}
21+
22+
func workflow(ctx dbos.DBOSContext, i int) (int, error) {
23+
// Test RunAsStep
24+
a, err := dbos.RunAsStep(ctx, step)
25+
if err != nil {
26+
return 0, err
27+
}
28+
29+
// Child wf
30+
ch, err := dbos.RunWorkflow(ctx, childWorkflow, i)
31+
if err != nil {
32+
return 0, err
33+
}
34+
b, err := ch.GetResult()
35+
if err != nil {
36+
return 0, err
37+
}
38+
39+
// Test messaging operations
40+
c, err := dbos.Recv[int](ctx, "chan1", 1*time.Second)
41+
if err != nil {
42+
return 0, err
43+
}
44+
d, err := dbos.GetEvent[int](ctx, "tgw", "event1", 1*time.Second)
45+
if err != nil {
46+
return 0, err
47+
}
48+
err = dbos.Send(ctx, "dst", 1, "topic")
49+
if err != nil {
50+
return 0, err
51+
}
52+
53+
// Test SetEvent
54+
err = dbos.SetEvent(ctx, "test_key", "test_value")
55+
if err != nil {
56+
return 0, err
57+
}
58+
59+
// Test Sleep
60+
_, err = dbos.Sleep(ctx, 100*time.Millisecond)
61+
if err != nil {
62+
return 0, err
63+
}
64+
65+
// Test ID retrieval methods
66+
workflowID, err := ctx.GetWorkflowID()
67+
if err != nil {
68+
return 0, err
69+
}
70+
stepID, err := ctx.GetStepID()
71+
if err != nil {
72+
return 0, err
73+
}
74+
75+
// Test workflow management
76+
_, err = dbos.RetrieveWorkflow[int](ctx, workflowID)
77+
if err != nil {
78+
return 0, err
79+
}
80+
81+
_, err = dbos.Enqueue[int, int](ctx, "test_queue", "test_workflow", 42)
82+
if err != nil {
83+
return 0, err
84+
}
85+
86+
err = dbos.CancelWorkflow(ctx, workflowID)
87+
if err != nil {
88+
return 0, err
89+
}
90+
91+
_, err = dbos.ResumeWorkflow[int](ctx, workflowID)
92+
if err != nil {
93+
return 0, err
94+
}
95+
96+
forkInput := dbos.ForkWorkflowInput{
97+
OriginalWorkflowID: workflowID,
98+
StartStep: uint(stepID),
99+
}
100+
_, err = dbos.ForkWorkflow[int](ctx, forkInput)
101+
if err != nil {
102+
return 0, err
103+
}
104+
105+
_, err = dbos.ListWorkflows(ctx)
106+
if err != nil {
107+
return 0, err
108+
}
109+
110+
_, err = dbos.GetWorkflowSteps(ctx, workflowID)
111+
if err != nil {
112+
return 0, err
113+
}
114+
115+
// Test accessor methods
116+
appVersion := ctx.GetApplicationVersion()
117+
executorID := ctx.GetExecutorID()
118+
appID := ctx.GetApplicationID()
119+
120+
// Use some values to avoid compiler warnings
121+
_ = appVersion
122+
_ = executorID
123+
_ = appID
124+
125+
return a + b + c + d, nil
126+
}
127+
128+
func aRealProgramFunction(dbosCtx dbos.DBOSContext) error {
129+
130+
dbos.RegisterWorkflow(dbosCtx, workflow)
131+
132+
err := dbosCtx.Launch()
133+
if err != nil {
134+
return err
135+
}
136+
defer dbosCtx.Shutdown(1 * time.Second)
137+
138+
res, err := workflow(dbosCtx, 2)
139+
if err != nil {
140+
return err
141+
}
142+
if res != 4 {
143+
return fmt.Errorf("unexpected result: %v", res)
144+
}
145+
return nil
146+
}
147+
148+
func TestMocks(t *testing.T) {
149+
mockCtx := mocks.NewMockDBOSContext(t)
150+
151+
// Context lifecycle
152+
mockCtx.On("Launch").Return(nil)
153+
mockCtx.On("Shutdown", mock.Anything).Return()
154+
155+
// Basic workflow operations (existing)
156+
mockCtx.On("RunAsStep", mockCtx, mock.Anything, mock.Anything).Return(1, nil)
157+
158+
// Child workflow
159+
mockChildHandle := mocks.NewMockWorkflowHandle[any](t)
160+
mockCtx.On("RunWorkflow", mockCtx, mock.Anything, 2, mock.Anything).Return(mockChildHandle, nil).Once()
161+
mockChildHandle.On("GetResult").Return(1, nil)
162+
163+
// Messaging
164+
mockCtx.On("Recv", mockCtx, "chan1", 1*time.Second).Return(1, nil)
165+
mockCtx.On("GetEvent", mockCtx, "tgw", "event1", 1*time.Second).Return(1, nil)
166+
mockCtx.On("Send", mockCtx, "dst", 1, "topic").Return(nil)
167+
mockCtx.On("SetEvent", mockCtx, "test_key", "test_value").Return(nil)
168+
169+
mockCtx.On("Sleep", mockCtx, 100*time.Millisecond).Return(100*time.Millisecond, nil)
170+
171+
// ID retrieval methods
172+
mockCtx.On("GetWorkflowID").Return("test-workflow-id", nil)
173+
mockCtx.On("GetStepID").Return(1, nil)
174+
175+
// Workflow management
176+
mockGenericHandle := mocks.NewMockWorkflowHandle[any](t)
177+
mockGenericHandle.On("GetWorkflowID").Return("generic-workflow-id").Maybe()
178+
mockGenericHandle.On("GetResult").Return(42, nil).Maybe()
179+
mockGenericHandle.On("GetStatus").Return(dbos.WorkflowStatus{}, nil).Maybe()
180+
181+
mockCtx.On("RetrieveWorkflow", mockCtx, "test-workflow-id").Return(mockGenericHandle, nil)
182+
mockCtx.On("Enqueue", mockCtx, "test_queue", "test_workflow", 42).Return(mockGenericHandle, nil)
183+
mockCtx.On("CancelWorkflow", mockCtx, "test-workflow-id").Return(nil)
184+
mockCtx.On("ResumeWorkflow", mockCtx, "test-workflow-id").Return(mockGenericHandle, nil)
185+
mockCtx.On("ForkWorkflow", mockCtx, mock.Anything).Return(mockGenericHandle, nil)
186+
mockCtx.On("ListWorkflows", mockCtx).Return([]dbos.WorkflowStatus{}, nil)
187+
mockCtx.On("GetWorkflowSteps", mockCtx, "test-workflow-id").Return([]dbos.StepInfo{}, nil)
188+
189+
// Accessor methods
190+
mockCtx.On("GetApplicationVersion").Return("test-version")
191+
mockCtx.On("GetExecutorID").Return("test-executor")
192+
mockCtx.On("GetApplicationID").Return("test-app-id")
193+
194+
err := aRealProgramFunction(mockCtx)
195+
if err != nil {
196+
t.Fatal(err)
197+
}
198+
199+
mockCtx.AssertExpectations(t)
200+
mockChildHandle.AssertExpectations(t)
201+
// mockGenericHandle.AssertExpectations(t)
202+
}

dbos/workflow.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,35 @@ func (h *workflowPollingHandle[R]) GetResult() (R, error) {
216216
return *new(R), err
217217
}
218218

219+
// Wrapper handle -- useful for handling mocks in RunWorkflow
220+
type workflowHandleProxy[R any] struct {
221+
wrappedHandle WorkflowHandle[any]
222+
}
223+
224+
func (h *workflowHandleProxy[R]) GetResult() (R, error) {
225+
result, err := h.wrappedHandle.GetResult()
226+
if err != nil {
227+
var zero R
228+
return zero, err
229+
}
230+
231+
// Convert from any to R
232+
if typed, ok := result.(R); ok {
233+
return typed, nil
234+
}
235+
236+
var zero R
237+
return zero, fmt.Errorf("cannot convert result of type %T to %T", result, zero)
238+
}
239+
240+
func (h *workflowHandleProxy[R]) GetStatus() (WorkflowStatus, error) {
241+
return h.wrappedHandle.GetStatus()
242+
}
243+
244+
func (h *workflowHandleProxy[R]) GetWorkflowID() string {
245+
return h.wrappedHandle.GetWorkflowID()
246+
}
247+
219248
/**********************************/
220249
/******* WORKFLOW REGISTRY *******/
221250
/**********************************/
@@ -583,8 +612,8 @@ func RunWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], input P, opts
583612
return typedHandle, nil
584613
}
585614

586-
// Should never happen
587-
return nil, fmt.Errorf("unexpected workflow handle type: %T", handle)
615+
// Usually on a mocked path
616+
return &workflowHandleProxy[R]{wrappedHandle: handle}, nil
588617
}
589618

590619
func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opts ...WorkflowOption) (WorkflowHandle[any], error) {

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ require (
5050
github.com/spf13/afero v1.12.0 // indirect
5151
github.com/spf13/cast v1.7.1 // indirect
5252
github.com/spf13/pflag v1.0.6 // indirect
53+
github.com/stretchr/objx v0.5.2 // indirect
5354
github.com/subosito/gotenv v1.6.0 // indirect
5455
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
5556
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
105105
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
106106
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
107107
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
108+
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
109+
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
108110
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
109111
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
110112
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=

0 commit comments

Comments
 (0)