diff --git a/dbos/client.go b/dbos/client.go
index c04130b..68bd55d 100644
--- a/dbos/client.go
+++ b/dbos/client.go
@@ -147,6 +147,7 @@ func (c *client) Enqueue(queueName, workflowName string, input any, opts ...Enqu
 	if params.priority > uint(math.MaxInt) {
 		return nil, fmt.Errorf("priority %d exceeds maximum allowed value %d", params.priority, math.MaxInt)
 	}
+
 	status := WorkflowStatus{
 		Name:               params.workflowName,
 		ApplicationVersion: params.applicationVersion,
@@ -155,7 +156,7 @@ func (c *client) Enqueue(queueName, workflowName string, input any, opts ...Enqu
 		CreatedAt:          time.Now(),
 		Deadline:           deadline,
 		Timeout:            params.workflowTimeout,
-		Input:              params.workflowInput,
+		Input:              input,
 		QueueName:          queueName,
 		DeduplicationID:    params.deduplicationID,
 		Priority:           int(params.priority),
@@ -240,20 +241,15 @@ func Enqueue[P any, R any](c Client, queueName, workflowName string, input P, op
 		return nil, errors.New("client cannot be nil")
 	}
 
-	// Register the input and outputs for gob encoding
-	var logger *slog.Logger
-	if cl, ok := c.(*client); ok {
-		if ctx, ok := cl.dbosCtx.(*dbosContext); ok {
-			logger = ctx.logger
-		}
+	// Serialize input
+	serializer := newGobSerializer[P]()
+	encodedInput, err := serializer.Encode(input)
+	if err != nil {
+		return nil, fmt.Errorf("failed to serialize workflow input: %w", err)
 	}
-	var typedInput P
-	safeGobRegister(typedInput, logger)
-	var typedOutput R
-	safeGobRegister(typedOutput, logger)
 
 	// Call the interface method with the same signature
-	handle, err := c.Enqueue(queueName, workflowName, input, opts...)
+	handle, err := c.Enqueue(queueName, workflowName, &encodedInput, opts...)
 	if err != nil {
 		return nil, err
 	}
diff --git a/dbos/dbos.go b/dbos/dbos.go
index f9aa829..e8fc04d 100644
--- a/dbos/dbos.go
+++ b/dbos/dbos.go
@@ -185,6 +185,7 @@ func (c *dbosContext) Value(key any) any {
 	return c.ctx.Value(key)
 }
+
 // WithValue returns a copy of the DBOS context with the given key-value pair.
 // This is similar to context.WithValue but maintains DBOS context capabilities.
 // No-op if the provided context is not a concrete dbos.dbosContext.
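// Illustrative aside (not part of the patch): the Enqueue change above relies on
// the typed gob serializer added in dbos/serialization.go. A minimal sketch of the
// intended encode/decode round trip, assuming package-internal access to the
// unexported newGobSerializer; the helper name below is hypothetical.
func exampleEnqueueEncoding() error {
	type payload struct{ N int }

	s := newGobSerializer[payload]()

	// Encode gob-encodes the value (wrapped in gobValue) and base64-encodes the bytes;
	// this encoded string is what client.Enqueue now passes along as the workflow input.
	encoded, err := s.Encode(payload{N: 42})
	if err != nil {
		return err
	}

	// Decode reverses the process once the concrete target type is known.
	decoded, err := s.Decode(&encoded)
	if err != nil {
		return err
	}
	_ = decoded // decoded.N == 42
	return nil
}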
@@ -354,14 +355,6 @@ func NewDBOSContext(ctx context.Context, inputConfig Config) (DBOSContext, error
 	initExecutor.logger = config.Logger
 	initExecutor.logger.Info("Initializing DBOS context", "app_name", config.AppName, "dbos_version", getDBOSVersion())
 
-	// Register types we serialize with gob
-	var t time.Time
-	safeGobRegister(t, initExecutor.logger)
-	var ws []WorkflowStatus
-	safeGobRegister(ws, initExecutor.logger)
-	var si []StepInfo
-	safeGobRegister(si, initExecutor.logger)
-
 	// Initialize global variables from processed config (already handles env vars and defaults)
 	initExecutor.applicationVersion = config.ApplicationVersion
 	initExecutor.executorID = config.ExecutorID
diff --git a/dbos/queue.go b/dbos/queue.go
index f323887..d0b4dc6 100644
--- a/dbos/queue.go
+++ b/dbos/queue.go
@@ -1,10 +1,7 @@
 package dbos
 
 import (
-	"bytes"
 	"context"
-	"encoding/base64"
-	"encoding/gob"
 	"log/slog"
 	"math"
 	"math/rand"
@@ -227,23 +224,8 @@ func (qr *queueRunner) run(ctx *dbosContext) {
 				continue
 			}
 
-			// Deserialize input
-			var input any
-			if len(workflow.input) > 0 {
-				inputBytes, err := base64.StdEncoding.DecodeString(workflow.input)
-				if err != nil {
-					qr.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err)
-					continue
-				}
-				buf := bytes.NewBuffer(inputBytes)
-				dec := gob.NewDecoder(buf)
-				if err := dec.Decode(&input); err != nil {
-					qr.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err)
-					continue
-				}
-			}
-
-			_, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id))
+			// Pass the encoded input directly; decoding will happen in the workflow wrapper, where the target type is known
+			_, err = registeredWorkflow.wrappedFunction(ctx, workflow.input, WithWorkflowID(workflow.id))
 			if err != nil {
 				qr.logger.Error("Error running queued workflow", "error", err)
 			}
diff --git a/dbos/recovery.go b/dbos/recovery.go
index ba51a6d..4d1483a 100644
--- a/dbos/recovery.go
+++ b/dbos/recovery.go
@@ -1,9 +1,5 @@
 package dbos
 
-import (
-	"strings"
-)
-
 func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]WorkflowHandle[any], error) {
 	workflowHandles := make([]WorkflowHandle[any], 0)
 	// List pending workflows for the executors
@@ -18,13 +14,6 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow
 	}
 
 	for _, workflow := range pendingWorkflows {
-		if inputStr, ok := workflow.Input.(string); ok {
-			if strings.Contains(inputStr, "Failed to decode") {
-				ctx.logger.Warn("Skipping workflow recovery due to input decoding failure", "workflow_id", workflow.ID, "name", workflow.Name)
-				continue
-			}
-		}
-
 		if workflow.QueueName != "" {
 			cleared, err := ctx.systemDB.clearQueueAssignment(ctx, workflow.ID)
 			if err != nil {
@@ -59,6 +48,7 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow
 			WithWorkflowID(workflow.ID),
 		}
 		// Create a workflow context from the executor context
+		// Pass the encoded input directly; decoding will happen in the workflow wrapper, where the target type is known
 		handle, err := registeredWorkflow.wrappedFunction(ctx, workflow.Input, opts...)
 		if err != nil {
 			return nil, err
diff --git a/dbos/serialization.go b/dbos/serialization.go
index c2e8814..faf4913 100644
--- a/dbos/serialization.go
+++ b/dbos/serialization.go
@@ -5,48 +5,26 @@ import (
 	"encoding/base64"
 	"encoding/gob"
 	"fmt"
-	"log/slog"
+	"reflect"
 	"strings"
 )
 
-func serialize(data any) (string, error) {
-	var inputBytes []byte
-	if data != nil {
-		var buf bytes.Buffer
-		enc := gob.NewEncoder(&buf)
-		if err := enc.Encode(&data); err != nil {
-			return "", fmt.Errorf("failed to encode data: %w", err)
-		}
-		inputBytes = buf.Bytes()
-	}
-	return base64.StdEncoding.EncodeToString(inputBytes), nil
+type serializer[T any] interface {
+	Encode(data T) (string, error)
+	Decode(data *string) (T, error)
 }
 
-func deserialize(data *string) (any, error) {
-	if data == nil || *data == "" {
-		return nil, nil
-	}
-
-	dataBytes, err := base64.StdEncoding.DecodeString(*data)
-	if err != nil {
-		return nil, fmt.Errorf("failed to decode data: %w", err)
-	}
-
-	var result any
-	buf := bytes.NewBuffer(dataBytes)
-	dec := gob.NewDecoder(buf)
-	if err := dec.Decode(&result); err != nil {
-		return nil, fmt.Errorf("failed to decode data: %w", err)
-	}
-
-	return result, nil
+// gobValue is a wrapper type for gob encoding/decoding of any value.
+// It prevents encoding nil values directly and helps us differentiate nil values from empty strings.
+type gobValue struct {
+	Value any
 }
 
 // safeGobRegister attempts to register a type with gob, recovering only from
 // panics caused by duplicate type/name registrations (e.g., registering both T and *T).
-// These specific conflicts don't affect encoding/decoding correctness, so they're safe to ignore.
-// Other panics (like register `any`) are real errors and will propagate.
-func safeGobRegister(value any, logger *slog.Logger) {
+// These specific conflicts don't affect encoding/decoding correctness, so they aren't errors.
+// Other panics (like registering `any`) are real errors and will propagate.
+func safeGobRegister(value any) {
 	defer func() {
 		if r := recover(); r != nil {
 			if errStr, ok := r.(string); ok {
@@ -54,9 +32,6 @@ func safeGobRegister(value any, logger *slog.Logger) {
 				// See https://cs.opensource.google/go/go/+/refs/tags/go1.25.1:src/encoding/gob/type.go;l=832
 				if strings.Contains(errStr, "gob: registering duplicate types for") ||
 					strings.Contains(errStr, "gob: registering duplicate names for") {
-					if logger != nil {
-						logger.Debug("gob registration conflict", "type", fmt.Sprintf("%T", value), "error", r)
-					}
 					return
 				}
 			}
@@ -66,3 +41,142 @@ func safeGobRegister(value any, logger *slog.Logger) {
 	}()
 	gob.Register(value)
 }
+
+// init registers the gobValue wrapper type with gob for gobSerializer
+func init() {
+	// Register wrapper type - this is required for gob encoding/decoding to work
+	safeGobRegister(gobValue{})
+}
+
+type gobSerializer[T any] struct{}
+
+func newGobSerializer[T any]() serializer[T] {
+	return &gobSerializer[T]{}
+}
+
+func (g *gobSerializer[T]) Encode(data T) (string, error) {
+	if isNilValue(data) {
+		// For nil values, encode an empty byte slice directly to base64
+		return base64.StdEncoding.EncodeToString([]byte{}), nil
+	}
+
+	// Register the type before encoding
+	safeGobRegister(data)
+
+	var buf bytes.Buffer
+	encoder := gob.NewEncoder(&buf)
+	wrapper := gobValue{Value: data}
+	if err := encoder.Encode(wrapper); err != nil {
+		return "", fmt.Errorf("failed to encode data: %w", err)
+	}
+	return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
+}
+
+func (g *gobSerializer[T]) Decode(data *string) (T, error) {
+	zero := *new(T)
+
+	if data == nil || *data == "" {
+		return zero, nil
+	}
+
+	dataBytes, err := base64.StdEncoding.DecodeString(*data)
+	if err != nil {
+		return zero, fmt.Errorf("failed to decode base64 data: %w", err)
+	}
+
+	// If decoded data is empty, it represents a nil value
+	if len(dataBytes) == 0 {
+		return zero, nil
+	}
+
+	// Resolve the type of T
+	tType := reflect.TypeOf(zero)
+	if tType == nil {
+		// zero is nil, T is likely a pointer type or interface
+		// Get the type from a pointer to T's zero value
+		tType = reflect.TypeOf(&zero).Elem()
+	}
+
+	// Register type T before decoding
+	// This is required on the recovery path, where the recovering process may not have performed the original encoding and registration.
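+	// gob can only decode a concrete value carried in the gobValue.Value interface
+	// field if that concrete type is registered in the decoding process; a recovered
+	// workflow may run in a process that never called Encode for this type.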
+ // This will panic if T is an non-registered interface type (which is not supported) + if tType != nil && tType.Kind() != reflect.Interface { + safeGobRegister(zero) + } + + var wrapper gobValue + decoder := gob.NewDecoder(bytes.NewReader(dataBytes)) + if err := decoder.Decode(&wrapper); err != nil { + return zero, fmt.Errorf("failed to decode gob data: %w", err) + } + + decoded := wrapper.Value + + // Gob stores pointed values directly, so we need to reconstruct the pointer type + if tType != nil && tType.Kind() == reflect.Pointer { + elemType := tType.Elem() + decodedType := reflect.TypeOf(decoded) + + // Check if decoded value matches the element type (not the pointer type) + if decodedType != nil && decodedType == elemType { + // Create a new pointer to the decoded value + elemValue := reflect.New(elemType) + elemValue.Elem().Set(reflect.ValueOf(decoded)) + return elemValue.Interface().(T), nil + } + // If decoded is already a pointer of the correct type, try direct assertion + if decodedType != nil && decodedType == tType { + typedResult, ok := decoded.(T) + if ok { + return typedResult, nil + } + } + // If decoded is nil and T is a pointer type, return nil pointer + if decoded == nil { + return zero, nil + } + } + + // Not a pointer -- direct type assertion + typedResult, ok := decoded.(T) + if !ok { + return zero, fmt.Errorf("cannot convert decoded value of type %T to %T", decoded, zero) + } + return typedResult, nil +} + +// isNilValue checks if a value is nil (for pointer types, slice, map, etc.) +func isNilValue(v any) bool { + val := reflect.ValueOf(v) + if !val.IsValid() { + return true + } + switch val.Kind() { + case reflect.Pointer, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func: + return val.IsNil() + } + return false +} + +// IsNestedPointer checks if a type is a nested pointer (e.g., **int, ***int). +// It returns false for non-pointer types and single-level pointers (*int). +// It returns true for nested pointers with depth > 1. +func IsNestedPointer(t reflect.Type) bool { + if t == nil { + return false + } + + depth := 0 + currentType := t + + // Count pointer indirection levels, break early if depth > 1 + for currentType != nil && currentType.Kind() == reflect.Pointer { + depth++ + if depth > 1 { + return true + } + currentType = currentType.Elem() + } + + return false +} diff --git a/dbos/serialization_test.go b/dbos/serialization_test.go index 22418ef..2e8af98 100644 --- a/dbos/serialization_test.go +++ b/dbos/serialization_test.go @@ -2,8 +2,8 @@ package dbos import ( "context" - "errors" "fmt" + "reflect" "testing" "time" @@ -11,304 +11,1057 @@ import ( "github.com/stretchr/testify/require" ) -// Builtin types -func encodingStepBuiltinTypes(_ context.Context, input int) (int, error) { - return input, errors.New("step error") -} +// testAllSerializationPaths tests workflow recovery and verifies all read paths. +// This is the unified test function that exercises: +// 1. Workflow recovery: starts a workflow, blocks it, recovers it, then verifies completion +// 2. All read paths: HandleGetResult, GetWorkflowSteps, ListWorkflows, RetrieveWorkflow +// This ensures recovery paths exercise all encoding/decoding scenarios that normal workflows do. 
+func testAllSerializationPaths[T any]( + t *testing.T, + executor DBOSContext, + recoveryWorkflow Workflow[T, T], + input T, + workflowID string, +) { + t.Helper() + + isNilExpected := isNilValue(input) + + // Setup events for recovery + startEvent := NewEvent() + blockingEvent := NewEvent() + recoveryEventRegistry[workflowID] = struct { + startEvent *Event + blockingEvent *Event + }{startEvent, blockingEvent} + defer delete(recoveryEventRegistry, workflowID) + + // Start the blocking workflow + handle, err := RunWorkflow(executor, recoveryWorkflow, input, WithWorkflowID(workflowID)) + require.NoError(t, err, "failed to start blocking workflow") + + // Wait for the workflow to reach the blocking step + startEvent.Wait() + + // Recover the pending workflow + dbosCtx, ok := executor.(*dbosContext) + require.True(t, ok, "expected dbosContext") + recoveredHandles, err := recoverPendingWorkflows(dbosCtx, []string{"local"}) + require.NoError(t, err, "failed to recover pending workflows") + + // Find our workflow in the recovered handles + var recoveredHandle WorkflowHandle[any] + for _, h := range recoveredHandles { + if h.GetWorkflowID() == handle.GetWorkflowID() { + recoveredHandle = h + break + } + } + require.NotNil(t, recoveredHandle, "expected to find recovered handle") + + // Unblock the workflow + blockingEvent.Set() + + // Expected output - workflow returns input, so output equals input + expectedOutput := input + + // Test read paths after completion + t.Run("HandleGetResult", func(t *testing.T) { + gotAny, err := handle.GetResult() + require.NoError(t, err) + if isNilExpected { + assert.Nil(t, gotAny, "Nil result should be preserved") + } else { + assert.Equal(t, expectedOutput, gotAny) + } + }) + + t.Run("RetrieveWorkflow", func(t *testing.T) { + h2, err := RetrieveWorkflow[T](executor, handle.GetWorkflowID()) + require.NoError(t, err) + gotAny, err := h2.GetResult() + require.NoError(t, err) + if isNilExpected { + assert.Nil(t, gotAny, "Retrieved workflow result should be nil") + } else { + assert.Equal(t, expectedOutput, gotAny, "Retrieved workflow result should match expected output") + } + }) + + // Check the last step output (the workflow result) + t.Run("GetWorkflowSteps", func(t *testing.T) { + steps, err := GetWorkflowSteps(executor, handle.GetWorkflowID()) + require.NoError(t, err) + require.GreaterOrEqual(t, len(steps), 1, "Should have at least one step") + if len(steps) > 0 { + lastStep := steps[len(steps)-1] + if isNilExpected { + assert.Nil(t, lastStep.Output, "Step output should be nil") + } else { + require.NotNil(t, lastStep.Output) + // GetWorkflowSteps decodes pointer types as their underlying value (not as pointers) + // So if T is a pointer type, we need to compare against the dereferenced value + zero := *new(T) + tType := reflect.TypeOf(zero) + isPointerType := tType != nil && tType.Kind() == reflect.Pointer + + if isPointerType { + // GetWorkflowSteps returns the underlying value, not the pointer + // So we compare against the dereferenced expectedOutput + expectedValue := reflect.ValueOf(expectedOutput).Elem().Interface() + assert.Equal(t, expectedValue, lastStep.Output, "Step output should match dereferenced expected output") + } else { + assert.Equal(t, expectedOutput, lastStep.Output, "Step output should match expected output") + } + } + assert.Nil(t, lastStep.Error) + } + }) + + // Verify final state via ListWorkflows + t.Run("ListWorkflows", func(t *testing.T) { + wfs, err := ListWorkflows(executor, + WithWorkflowIDs([]string{handle.GetWorkflowID()}), + 
WithLoadInput(true), WithLoadOutput(true)) + require.NoError(t, err) + require.Len(t, wfs, 1) + wf := wfs[0] + if isNilExpected { + assert.Nil(t, wf.Input, "Workflow input should be nil") + assert.Nil(t, wf.Output, "Workflow output should be nil") + } else { + require.NotNil(t, wf.Input) + require.NotNil(t, wf.Output) + + // ListWorkflows decodes pointer types as their underlying value (not as pointers) + // So if T is a pointer type, we need to compare against the dereferenced value + zero := *new(T) + tType := reflect.TypeOf(zero) + isPointerType := tType != nil && tType.Kind() == reflect.Pointer -func encodingWorkflowBuiltinTypes(ctx DBOSContext, input string) (string, error) { - stepResult, err := RunAsStep(ctx, func(context context.Context) (int, error) { - return encodingStepBuiltinTypes(context, 123) + if isPointerType { + // ListWorkflows returns the underlying value, not the pointer + // So we compare against the dereferenced values + expectedInputValue := reflect.ValueOf(input).Elem().Interface() + expectedOutputValue := reflect.ValueOf(expectedOutput).Elem().Interface() + assert.Equal(t, expectedInputValue, wf.Input, "Workflow input should match dereferenced input") + assert.Equal(t, expectedOutputValue, wf.Output, "Workflow output should match dereferenced expected output") + } else { + assert.Equal(t, input, wf.Input) + assert.Equal(t, expectedOutput, wf.Output) + } + } }) - return fmt.Sprintf("%d", stepResult), fmt.Errorf("workflow error: %v", err) } -// Struct types -type StepOutputStruct struct { - A StepInputStruct - B string +// Helper function to test Send/Recv communication +func testSendRecv[T any]( + t *testing.T, + executor DBOSContext, + senderWorkflow Workflow[T, T], + receiverWorkflow Workflow[T, T], + input T, + senderID string, +) { + t.Helper() + + // Start receiver workflow first (it will wait for the message) + receiverHandle, err := RunWorkflow(executor, receiverWorkflow, input, WithWorkflowID(senderID+"-receiver")) + require.NoError(t, err, "Receiver workflow execution failed") + + // Start sender workflow (it will send the message) + senderHandle, err := RunWorkflow(executor, senderWorkflow, input, WithWorkflowID(senderID)) + require.NoError(t, err, "Sender workflow execution failed") + + // Get sender result + senderResult, err := senderHandle.GetResult() + require.NoError(t, err, "Sender workflow should complete") + + // Get receiver result + receiverResult, err := receiverHandle.GetResult() + require.NoError(t, err, "Receiver workflow should complete") + + // Verify the received data matches what was sent + assert.Equal(t, input, senderResult, "Sender result should match input") + assert.Equal(t, input, receiverResult, "Received data should match sent data") } -type StepInputStruct struct { - A SimpleStruct - B string +// Helper function to test SetEvent/GetEvent communication +func testSetGetEvent[T any]( + t *testing.T, + executor DBOSContext, + setEventWorkflow Workflow[T, T], + getEventWorkflow Workflow[string, T], + input T, + setEventID string, + getEventID string, +) { + t.Helper() + + // Start setEvent workflow + setEventHandle, err := RunWorkflow(executor, setEventWorkflow, input, WithWorkflowID(setEventID)) + require.NoError(t, err, "SetEvent workflow execution failed") + + // Wait for setEvent to complete + setResult, err := setEventHandle.GetResult() + require.NoError(t, err, "SetEvent workflow should complete") + + // Start getEvent workflow (will retrieve the event) + getEventHandle, err := RunWorkflow(executor, getEventWorkflow, 
setEventID, WithWorkflowID(getEventID)) + require.NoError(t, err, "GetEvent workflow execution failed") + + // Get the event result + getResult, err := getEventHandle.GetResult() + require.NoError(t, err, "GetEvent workflow should complete") + + // Verify the event data matches what was set + assert.Equal(t, input, setResult, "SetEvent result should match input") + assert.Equal(t, input, getResult, "GetEvent data should match what was set") } -type WorkflowInputStruct struct { - A SimpleStruct - B int +type MyInt int +type MyString string +type IntSliceSlice [][]int + +// Test data structures for DBOS integration testing +type TestData struct { + Message string + Value int + Active bool } -type SimpleStruct struct { - A string - B int +// NestedTestData is a nested struct type for testing slices and maps of structs +type NestedTestData struct { + Key string + Count int } -func encodingWorkflowStruct(ctx DBOSContext, input WorkflowInputStruct) (StepOutputStruct, error) { - return RunAsStep(ctx, func(context context.Context) (StepOutputStruct, error) { - return encodingStepStruct(context, StepInputStruct{ - A: input.A, - B: fmt.Sprintf("%d", input.B), - }) - }) +type TestWorkflowData struct { + ID string + Message string + Value int + Active bool + Data TestData + Metadata map[string]string + NestedSlice []NestedTestData + NestedMap map[NestedTestData]MyInt + StringPtr *string + StringPtrPtr **string } -func encodingStepStruct(_ context.Context, input StepInputStruct) (StepOutputStruct, error) { - return StepOutputStruct{ - A: input, - B: "processed by encodingStepStruct", - }, nil +// Test workflows and steps +func serializerTestStep(_ context.Context, input TestWorkflowData) (TestWorkflowData, error) { + return input, nil } -func TestWorkflowEncoding(t *testing.T) { - executor := setupDBOS(t, true, true) +func serializerWorkflow(ctx DBOSContext, input TestWorkflowData) (TestWorkflowData, error) { + return RunAsStep(ctx, func(context context.Context) (TestWorkflowData, error) { + return serializerTestStep(context, input) + }) +} + +func serializerPointerValueWorkflow(ctx DBOSContext, input *TestWorkflowData) (*TestWorkflowData, error) { + return RunAsStep(ctx, func(context context.Context) (*TestWorkflowData, error) { + return input, nil + }) +} - // Register workflows with executor - RegisterWorkflow(executor, encodingWorkflowBuiltinTypes) - RegisterWorkflow(executor, encodingWorkflowStruct) +// makeTestWorkflow creates a generic workflow that simply returns the input. 
+func makeTestWorkflow[T any]() Workflow[T, T] { + return func(ctx DBOSContext, input T) (T, error) { + return RunAsStep(ctx, func(context context.Context) (T, error) { + return input, nil + }) + } +} - err := Launch(executor) - require.NoError(t, err) +// Typed workflow functions for testing concrete signatures +// These are now generated using makeTestWorkflow to reduce boilerplate +var ( + serializerIntWorkflow = makeTestWorkflow[int]() + serializerIntPtrWorkflow = makeTestWorkflow[*int]() + serializerIntSliceWorkflow = makeTestWorkflow[[]int]() + serializerStringIntMapWorkflow = makeTestWorkflow[map[string]int]() + serializerMyIntWorkflow = makeTestWorkflow[MyInt]() + serializerMyStringWorkflow = makeTestWorkflow[MyString]() + serializerMyStringSliceWorkflow = makeTestWorkflow[[]MyString]() + serializerStringMyIntMapWorkflow = makeTestWorkflow[map[string]MyInt]() + serializerStringWorkflow = makeTestWorkflow[string]() + serializerBoolWorkflow = makeTestWorkflow[bool]() + serializerIntArrayWorkflow = makeTestWorkflow[[3]int]() + serializerByteSliceWorkflow = makeTestWorkflow[[]byte]() + // Recovery workflows for all types - these test encoding/decoding through recovery paths + recoveryIntWorkflow = makeRecoveryWorkflow[int]() + recoveryStringWorkflow = makeRecoveryWorkflow[string]() + recoveryIntPtrWorkflow = makeRecoveryWorkflow[*int]() + recoveryIntSliceWorkflow = makeRecoveryWorkflow[[]int]() + recoveryIntArrayWorkflow = makeRecoveryWorkflow[[3]int]() + recoveryByteSliceWorkflow = makeRecoveryWorkflow[[]byte]() + recoveryStringIntMapWorkflow = makeRecoveryWorkflow[map[string]int]() + recoveryMyIntWorkflow = makeRecoveryWorkflow[MyInt]() + recoveryMyStringWorkflow = makeRecoveryWorkflow[MyString]() + recoveryMyStringSliceWorkflow = makeRecoveryWorkflow[[]MyString]() + recoveryStringMyIntMapWorkflow = makeRecoveryWorkflow[map[string]MyInt]() + // Additional types: empty struct, nested collections, slices of pointers + recoveryEmptyStructWorkflow = makeRecoveryWorkflow[struct{}]() + recoveryIntSliceSliceWorkflow = makeRecoveryWorkflow[IntSliceSlice]() + recoveryNestedMapWorkflow = makeRecoveryWorkflow[map[string]map[string]int]() + recoveryIntPtrSliceWorkflow = makeRecoveryWorkflow[[]*int]() +) - t.Run("BuiltinTypes", func(t *testing.T) { - // Test a workflow that uses a built-in type (string) - directHandle, err := RunWorkflow(executor, encodingWorkflowBuiltinTypes, "test") - require.NoError(t, err) +// makeSenderWorkflow creates a generic sender workflow that sends a message to a receiver workflow. +func makeSenderWorkflow[T any]() Workflow[T, T] { + return func(ctx DBOSContext, input T) (T, error) { + receiverWorkflowID, err := GetWorkflowID(ctx) + if err != nil { + return *new(T), fmt.Errorf("failed to get workflow ID: %w", err) + } + destID := receiverWorkflowID + "-receiver" + err = Send(ctx, destID, input, "test-topic") + if err != nil { + return *new(T), fmt.Errorf("send failed: %w", err) + } + return input, nil + } +} - // Test result and error from direct handle - directHandleResult, err := directHandle.GetResult() - assert.Equal(t, "123", directHandleResult) - require.Error(t, err) - assert.Equal(t, "workflow error: step error", err.Error()) +// makeReceiverWorkflow creates a generic receiver workflow that receives a message. 
+func makeReceiverWorkflow[T any]() Workflow[T, T] { + return func(ctx DBOSContext, _ T) (T, error) { + received, err := Recv[T](ctx, "test-topic", 10*time.Second) + if err != nil { + return *new(T), fmt.Errorf("recv failed: %w", err) + } + return received, nil + } +} - // Test result from polling handle - retrieveHandler, err := RetrieveWorkflow[string](executor.(*dbosContext), directHandle.GetWorkflowID()) - require.NoError(t, err) - retrievedResult, err := retrieveHandler.GetResult() - assert.Equal(t, "123", retrievedResult) - require.Error(t, err) - assert.Equal(t, "workflow error: step error", err.Error()) - - // Test results from ListWorkflows - workflows, err := ListWorkflows( - executor, - WithWorkflowIDs([]string{directHandle.GetWorkflowID()}), - WithLoadInput(true), - WithLoadOutput(true), - ) - require.NoError(t, err) - require.Len(t, workflows, 1) - workflow := workflows[0] - require.NotNil(t, workflow.Input) - workflowInput, ok := workflow.Input.(string) - require.True(t, ok, "expected workflow input to be of type string, got %T", workflow.Input) - assert.Equal(t, "test", workflowInput) - require.NotNil(t, workflow.Output) - workflowOutput, ok := workflow.Output.(string) - require.True(t, ok, "expected workflow output to be of type string, got %T", workflow.Output) - assert.Equal(t, "123", workflowOutput) - require.NotNil(t, workflow.Error) - assert.Equal(t, "workflow error: step error", workflow.Error.Error()) - - // Test results from GetWorkflowSteps - steps, err := GetWorkflowSteps(executor, directHandle.GetWorkflowID()) - require.NoError(t, err) - require.Len(t, steps, 1) - step := steps[0] - require.NotNil(t, step.Output) - stepOutput, ok := step.Output.(int) - require.True(t, ok, "expected step output to be of type int, got %T", step.Output) - assert.Equal(t, 123, stepOutput) - require.NotNil(t, step.Error) - assert.Equal(t, "step error", step.Error.Error()) - }) +// makeSetEventWorkflow creates a generic workflow that sets an event. +func makeSetEventWorkflow[T any]() Workflow[T, T] { + return func(ctx DBOSContext, input T) (T, error) { + err := SetEvent(ctx, "test-key", input) + if err != nil { + return *new(T), fmt.Errorf("set event failed: %w", err) + } + return input, nil + } +} - t.Run("StructType", func(t *testing.T) { - // Test a workflow that calls a step with struct types to verify serialization/deserialization - input := WorkflowInputStruct{ - A: SimpleStruct{A: "test", B: 123}, - B: 456, +// makeGetEventWorkflow creates a generic workflow that gets an event. 
+func makeGetEventWorkflow[T any]() Workflow[string, T] { + return func(ctx DBOSContext, targetWorkflowID string) (T, error) { + event, err := GetEvent[T](ctx, targetWorkflowID, "test-key", 10*time.Second) + if err != nil { + return *new(T), fmt.Errorf("get event failed: %w", err) } + return event, nil + } +} - directHandle, err := RunWorkflow(executor, encodingWorkflowStruct, input) - require.NoError(t, err) +// Typed Send/Recv workflows for various types +var ( + serializerIntSenderWorkflow = makeSenderWorkflow[int]() + serializerIntReceiverWorkflow = makeReceiverWorkflow[int]() + serializerIntPtrSenderWorkflow = makeSenderWorkflow[*int]() + serializerIntPtrReceiverWorkflow = makeReceiverWorkflow[*int]() + serializerMyIntSenderWorkflow = makeSenderWorkflow[MyInt]() + serializerMyIntReceiverWorkflow = makeReceiverWorkflow[MyInt]() +) - // Test result from direct handle - directResult, err := directHandle.GetResult() - require.NoError(t, err) - assert.Equal(t, input.A.A, directResult.A.A.A) - assert.Equal(t, input.A.B, directResult.A.A.B) - assert.Equal(t, fmt.Sprintf("%d", input.B), directResult.A.B) - assert.Equal(t, "processed by encodingStepStruct", directResult.B) +// Typed SetEvent/GetEvent workflows for various types +var ( + serializerIntSetEventWorkflow = makeSetEventWorkflow[int]() + serializerIntGetEventWorkflow = makeGetEventWorkflow[int]() + serializerIntPtrSetEventWorkflow = makeSetEventWorkflow[*int]() + serializerIntPtrGetEventWorkflow = makeGetEventWorkflow[*int]() + serializerMyIntSetEventWorkflow = makeSetEventWorkflow[MyInt]() + serializerMyIntGetEventWorkflow = makeGetEventWorkflow[MyInt]() +) - // Test result from polling handle - retrieveHandler, err := RetrieveWorkflow[StepOutputStruct](executor.(*dbosContext), directHandle.GetWorkflowID()) - require.NoError(t, err) - retrievedResult, err := retrieveHandler.GetResult() - require.NoError(t, err) - assert.Equal(t, input.A.A, retrievedResult.A.A.A) - assert.Equal(t, input.A.B, retrievedResult.A.A.B) - assert.Equal(t, fmt.Sprintf("%d", input.B), retrievedResult.A.B) - assert.Equal(t, "processed by encodingStepStruct", retrievedResult.B) - - // Test results from ListWorkflows - workflows, err := ListWorkflows(executor, - WithWorkflowIDs([]string{directHandle.GetWorkflowID()}), - WithLoadInput(true), - WithLoadOutput(true), - ) - require.Len(t, workflows, 1) - require.NoError(t, err) - workflow := workflows[0] - require.NotNil(t, workflow.Input) - workflowInput, ok := workflow.Input.(WorkflowInputStruct) - require.True(t, ok, "expected workflow input to be of type WorkflowInputStruct, got %T", workflow.Input) - assert.Equal(t, input.A.A, workflowInput.A.A) - assert.Equal(t, input.A.B, workflowInput.A.B) - assert.Equal(t, input.B, workflowInput.B) - - workflowOutput, ok := workflow.Output.(StepOutputStruct) - require.True(t, ok, "expected workflow output to be of type StepOutputStruct, got %T", workflow.Output) - assert.Equal(t, input.A.A, workflowOutput.A.A.A) - assert.Equal(t, input.A.B, workflowOutput.A.A.B) - assert.Equal(t, fmt.Sprintf("%d", input.B), workflowOutput.A.B) - assert.Equal(t, "processed by encodingStepStruct", workflowOutput.B) - - // Test results from GetWorkflowSteps - steps, err := GetWorkflowSteps(executor, directHandle.GetWorkflowID()) - require.NoError(t, err) - require.Len(t, steps, 1) - step := steps[0] - require.NotNil(t, step.Output) - stepOutput, ok := step.Output.(StepOutputStruct) - require.True(t, ok, "expected step output to be of type StepOutputStruct, got %T", step.Output) - 
assert.Equal(t, input.A.A, stepOutput.A.A.A) - assert.Equal(t, input.A.B, stepOutput.A.A.B) - assert.Equal(t, fmt.Sprintf("%d", input.B), stepOutput.A.B) - assert.Equal(t, "processed by encodingStepStruct", stepOutput.B) - assert.Nil(t, step.Error) - }) +func serializerErrorStep(_ context.Context, _ TestWorkflowData) (TestWorkflowData, error) { + return TestWorkflowData{}, fmt.Errorf("step error") } -type UserDefinedEventData struct { - ID int `json:"id"` - Name string `json:"name"` - Details struct { - Description string `json:"description"` - Tags []string `json:"tags"` - } `json:"details"` +func serializerErrorWorkflow(ctx DBOSContext, input TestWorkflowData) (TestWorkflowData, error) { + return RunAsStep(ctx, func(context context.Context) (TestWorkflowData, error) { + return serializerErrorStep(context, input) + }) } -func setEventUserDefinedTypeWorkflow(ctx DBOSContext, input string) (string, error) { - eventData := UserDefinedEventData{ - ID: 42, - Name: "test-event", - Details: struct { - Description string `json:"description"` - Tags []string `json:"tags"` - }{ - Description: "This is a test event with user-defined data", - Tags: []string{"test", "user-defined", "serialization"}, - }, +// Workflows for testing Send/Recv with non-basic types +func serializerSenderWorkflow(ctx DBOSContext, input TestWorkflowData) (TestWorkflowData, error) { + receiverWorkflowID, err := GetWorkflowID(ctx) + if err != nil { + return TestWorkflowData{}, fmt.Errorf("failed to get workflow ID: %w", err) } + // Add a suffix to create receiver workflow ID + destID := receiverWorkflowID + "-receiver" - err := SetEvent(ctx, input, eventData) + err = Send(ctx, destID, input, "test-topic") if err != nil { - return "", err + return TestWorkflowData{}, fmt.Errorf("send failed: %w", err) } - return "user-defined-event-set", nil + return input, nil } -func TestSetEventSerialize(t *testing.T) { - executor := setupDBOS(t, true, true) - - // Register workflow with executor - RegisterWorkflow(executor, setEventUserDefinedTypeWorkflow) - - t.Run("SetEventUserDefinedType", func(t *testing.T) { - // Start a workflow that sets an event with a user-defined type - setHandle, err := RunWorkflow(executor, setEventUserDefinedTypeWorkflow, "user-defined-key") - require.NoError(t, err) +func serializerReceiverWorkflow(ctx DBOSContext, _ TestWorkflowData) (TestWorkflowData, error) { + // Receive a message with the expected type + received, err := Recv[TestWorkflowData](ctx, "test-topic", 10*time.Second) + if err != nil { + return TestWorkflowData{}, fmt.Errorf("recv failed: %w", err) + } + return received, nil +} - // Wait for the workflow to complete - result, err := setHandle.GetResult() - require.NoError(t, err) - assert.Equal(t, "user-defined-event-set", result) +// Workflows for testing SetEvent/GetEvent with non-basic types +func serializerSetEventWorkflow(ctx DBOSContext, input TestWorkflowData) (TestWorkflowData, error) { + err := SetEvent(ctx, "test-key", input) + if err != nil { + return TestWorkflowData{}, fmt.Errorf("set event failed: %w", err) + } + return input, nil +} - // Retrieve the event to verify it was properly serialized and can be deserialized - retrievedEvent, err := GetEvent[UserDefinedEventData](executor, setHandle.GetWorkflowID(), "user-defined-key", 3*time.Second) - require.NoError(t, err) +func serializerGetEventWorkflow(ctx DBOSContext, targetWorkflowID string) (TestWorkflowData, error) { + // Get the event with the expected type + event, err := GetEvent[TestWorkflowData](ctx, targetWorkflowID, 
"test-key", 10*time.Second) + if err != nil { + return TestWorkflowData{}, fmt.Errorf("get event failed: %w", err) + } + return event, nil +} - // Verify the retrieved data matches what we set - assert.Equal(t, 42, retrievedEvent.ID) - assert.Equal(t, "test-event", retrievedEvent.Name) - assert.Equal(t, "This is a test event with user-defined data", retrievedEvent.Details.Description) - require.Len(t, retrievedEvent.Details.Tags, 3) - expectedTags := []string{"test", "user-defined", "serialization"} - assert.Equal(t, expectedTags, retrievedEvent.Details.Tags) +// Workflow for testing interface signature with manual gob registration +var interfaceWorkflow = func(ctx DBOSContext, input TestDataProcessor) (TestDataProcessor, error) { + return RunAsStep(ctx, func(context context.Context) (TestDataProcessor, error) { + return input, nil }) } -func sendUserDefinedTypeWorkflow(ctx DBOSContext, destinationID string) (string, error) { - // Create an instance of our user-defined type inside the workflow - sendData := UserDefinedEventData{ - ID: 42, - Name: "test-send-message", - Details: struct { - Description string `json:"description"` - Tags []string `json:"tags"` - }{ - Description: "This is a test send message with user-defined data", - Tags: []string{"test", "user-defined", "serialization", "send"}, - }, +// recoveryEventRegistry stores events for recovery workflows by workflow ID +var recoveryEventRegistry = make(map[string]struct { + startEvent *Event + blockingEvent *Event +}) + +// makeRecoveryWorkflow creates a generic recovery workflow that has an initial step +// and then a blocking step that uses the output of the first step. +// This is used to test workflow recovery with various types. +// The workflow looks up events from recoveryEventRegistry using the workflow ID. +func makeRecoveryWorkflow[T any]() Workflow[T, T] { + return func(ctx DBOSContext, input T) (T, error) { + // First step: return the input (tests encoding/decoding of type T) + firstStepOutput, err := RunAsStep(ctx, func(context context.Context) (T, error) { + return input, nil + }, WithStepName("FirstStep")) + if err != nil { + fmt.Printf("makeRecoveryWorkflow: FirstStep error: %v\n", err) + return *new(T), err + } + + // Second step: blocking step that uses the first step's output + // This tests that the first step's output is correctly decoded + // If decoding fails or is incorrect, this step will fail + return RunAsStep(ctx, func(context context.Context) (T, error) { + workflowID, err := GetWorkflowID(ctx) + if err != nil { + return *new(T), fmt.Errorf("failed to get workflow ID: %w", err) + } + events, ok := recoveryEventRegistry[workflowID] + if !ok { + return *new(T), fmt.Errorf("no events registered for workflow ID: %s", workflowID) + } + events.startEvent.Set() + events.blockingEvent.Wait() + // Return the first step's output - this verifies correct decoding + // If the type was decoded incorrectly, this assignment/return will fail + return firstStepOutput, nil + }, WithStepName("BlockingStep")) } +} - // Send should automatically register this type with gob - err := Send(ctx, destinationID, sendData, "user-defined-topic") +// serializerRecoveryWorkflow is a recovery workflow for TestWorkflowData type. +// It uses the recoveryEventRegistry to look up events by workflow ID. 
+func serializerRecoveryWorkflow(ctx DBOSContext, input TestWorkflowData) (TestWorkflowData, error) { + // First step: return the input (tests encoding/decoding of TestWorkflowData) + firstStepOutput, err := RunAsStep(ctx, func(context context.Context) (TestWorkflowData, error) { + return input, nil + }, WithStepName("FirstStep")) if err != nil { - return "", err + return TestWorkflowData{}, err } - return "user-defined-message-sent", nil + + // Second step: blocking step that uses the first step's output + // This tests that the first step's output is correctly decoded + return RunAsStep(ctx, func(context context.Context) (TestWorkflowData, error) { + workflowID, err := GetWorkflowID(ctx) + if err != nil { + return TestWorkflowData{}, fmt.Errorf("failed to get workflow ID: %w", err) + } + events, ok := recoveryEventRegistry[workflowID] + if !ok { + return TestWorkflowData{}, fmt.Errorf("no events registered for workflow ID: %s", workflowID) + } + events.startEvent.Set() + events.blockingEvent.Wait() + // Return the first step's output - this verifies correct decoding + return firstStepOutput, nil + }, WithStepName("BlockingStep")) } -func recvUserDefinedTypeWorkflow(ctx DBOSContext, input string) (UserDefinedEventData, error) { - // Receive the user-defined type message - result, err := Recv[UserDefinedEventData](ctx, "user-defined-topic", 3*time.Second) - return result, err +// TestDataProcessor is an interface for testing workflows with interface signatures +type TestDataProcessor interface { + Process(data string) string } -func TestSendSerialize(t *testing.T) { - executor := setupDBOS(t, true, true) +// TestStringProcessor is a concrete implementation of TestDataProcessor +type TestStringProcessor struct { + Prefix string +} - // Register workflows with executor - RegisterWorkflow(executor, sendUserDefinedTypeWorkflow) - RegisterWorkflow(executor, recvUserDefinedTypeWorkflow) +// Process implements the TestDataProcessor interface +func (p *TestStringProcessor) Process(data string) string { + return p.Prefix + data +} - t.Run("SendUserDefinedType", func(t *testing.T) { - // Start a receiver workflow first - recvHandle, err := RunWorkflow(executor, recvUserDefinedTypeWorkflow, "recv-input") - require.NoError(t, err) +// TestSerializer tests that workflows use the configured serializer for input/output. +// +// This test suite uses recovery-based testing as the primary approach. All tests exercise +// workflow recovery paths because: +// 1. Recovery paths exercise all encoding/decoding scenarios that normal workflows do +// 2. Recovery paths additionally test decoding from persisted state (database) +// 3. This ensures that serialization works correctly even when workflows are recovered +// after a process restart or failure +// +// Each test: +// - Starts a workflow with a blocking step +// - Recovers the pending workflow from the database +// - Verifies all read paths: HandleGetResult, ListWorkflows, GetWorkflowSteps, RetrieveWorkflow +// - Ensures that both original and recovered handles produce correct results +// +// The suite covers: scalars, pointers (single level only, nested pointers not supported), +// slices, arrays, byte slices, maps, and custom types. It also tests Send/Recv and +// SetEvent/GetEvent communication patterns. 
+func TestSerializer(t *testing.T) { + t.Run("Gob", func(t *testing.T) { + executor := setupDBOS(t, true, true) - // Start a sender workflow that sends a message with a user-defined type - sendHandle, err := RunWorkflow(executor, sendUserDefinedTypeWorkflow, recvHandle.GetWorkflowID()) - require.NoError(t, err) + // Create a test queue for queued workflow tests + testQueue := NewWorkflowQueue(executor, "serializer-test-queue") - // Wait for the sender workflow to complete - sendResult, err := sendHandle.GetResult() - require.NoError(t, err) - assert.Equal(t, "user-defined-message-sent", sendResult) + // Register workflows + RegisterWorkflow(executor, serializerWorkflow) + RegisterWorkflow(executor, serializerPointerValueWorkflow) + RegisterWorkflow(executor, serializerErrorWorkflow) + RegisterWorkflow(executor, serializerSenderWorkflow) + RegisterWorkflow(executor, serializerReceiverWorkflow) + RegisterWorkflow(executor, serializerSetEventWorkflow) + RegisterWorkflow(executor, serializerGetEventWorkflow) + RegisterWorkflow(executor, serializerRecoveryWorkflow) + // Register typed workflows for concrete signatures + RegisterWorkflow(executor, serializerIntWorkflow) + RegisterWorkflow(executor, serializerIntPtrWorkflow) + RegisterWorkflow(executor, serializerIntSliceWorkflow) + RegisterWorkflow(executor, serializerStringIntMapWorkflow) + RegisterWorkflow(executor, serializerMyIntWorkflow) + RegisterWorkflow(executor, serializerMyStringWorkflow) + RegisterWorkflow(executor, serializerMyStringSliceWorkflow) + RegisterWorkflow(executor, serializerStringMyIntMapWorkflow) + RegisterWorkflow(executor, serializerStringWorkflow) + RegisterWorkflow(executor, serializerBoolWorkflow) + RegisterWorkflow(executor, serializerIntArrayWorkflow) + RegisterWorkflow(executor, serializerByteSliceWorkflow) + // Register recovery workflows for all types + RegisterWorkflow(executor, recoveryIntWorkflow) + RegisterWorkflow(executor, recoveryStringWorkflow) + RegisterWorkflow(executor, recoveryIntPtrWorkflow) + RegisterWorkflow(executor, recoveryIntSliceWorkflow) + RegisterWorkflow(executor, recoveryIntArrayWorkflow) + RegisterWorkflow(executor, recoveryByteSliceWorkflow) + RegisterWorkflow(executor, recoveryStringIntMapWorkflow) + RegisterWorkflow(executor, recoveryMyIntWorkflow) + RegisterWorkflow(executor, recoveryMyStringWorkflow) + RegisterWorkflow(executor, recoveryMyStringSliceWorkflow) + RegisterWorkflow(executor, recoveryStringMyIntMapWorkflow) + // Register additional recovery workflows + RegisterWorkflow(executor, recoveryEmptyStructWorkflow) + RegisterWorkflow(executor, recoveryIntSliceSliceWorkflow) + RegisterWorkflow(executor, recoveryNestedMapWorkflow) + RegisterWorkflow(executor, recoveryIntPtrSliceWorkflow) + // Register typed Send/Recv workflows + RegisterWorkflow(executor, serializerIntSenderWorkflow) + RegisterWorkflow(executor, serializerIntReceiverWorkflow) + RegisterWorkflow(executor, serializerIntPtrSenderWorkflow) + RegisterWorkflow(executor, serializerIntPtrReceiverWorkflow) + RegisterWorkflow(executor, serializerMyIntSenderWorkflow) + RegisterWorkflow(executor, serializerMyIntReceiverWorkflow) + // Register typed SetEvent/GetEvent workflows + RegisterWorkflow(executor, serializerIntSetEventWorkflow) + RegisterWorkflow(executor, serializerIntGetEventWorkflow) + RegisterWorkflow(executor, serializerIntPtrSetEventWorkflow) + RegisterWorkflow(executor, serializerIntPtrGetEventWorkflow) + RegisterWorkflow(executor, serializerMyIntSetEventWorkflow) + RegisterWorkflow(executor, 
serializerMyIntGetEventWorkflow) + + // Register workflow with interface signature for manual gob registration test + RegisterWorkflow(executor, interfaceWorkflow) + + // Register recovery workflow for *TestWorkflowData (used in NilPointer test) + recoveryPtrWorkflow := makeRecoveryWorkflow[*TestWorkflowData]() + RegisterWorkflow(executor, recoveryPtrWorkflow) + + // Define workflows for RunAsStep nested pointer validation tests + nestedPtrStepWorkflow := func(ctx DBOSContext, input int) (int, error) { + _, err := RunAsStep(ctx, func(context context.Context) (**int, error) { + return nil, nil + }) + if err != nil { + return 0, err + } + return input, nil + } + + tripleNestedPtrStepWorkflow := func(ctx DBOSContext, input int) (int, error) { + _, err := RunAsStep(ctx, func(context context.Context) (***int, error) { + return nil, nil + }) + if err != nil { + return 0, err + } + return input, nil + } + + // Register workflows for RunAsStep nested pointer validation + RegisterWorkflow(executor, nestedPtrStepWorkflow) + RegisterWorkflow(executor, tripleNestedPtrStepWorkflow) + + // Test nested pointer validation - these tests only check registration, not execution + // so they can run before Launch using the main executor + t.Run("NestedPointerValidation_Registration", func(t *testing.T) { + // Test RegisterWorkflow with nested pointer input type + t.Run("RegisterWorkflowWithNestedPointerInput", func(t *testing.T) { + nestedPtrInputWorkflow := func(ctx DBOSContext, input **int) (int, error) { + return 0, nil + } + require.Panics(t, func() { + RegisterWorkflow(executor, nestedPtrInputWorkflow) + }, "RegisterWorkflow should panic when input type is a nested pointer") + }) + + // Test RegisterWorkflow with nested pointer return type + t.Run("RegisterWorkflowWithNestedPointerReturn", func(t *testing.T) { + nestedPtrReturnWorkflow := func(ctx DBOSContext, input int) (**int, error) { + return nil, nil + } + require.Panics(t, func() { + RegisterWorkflow(executor, nestedPtrReturnWorkflow) + }, "RegisterWorkflow should panic when return type is a nested pointer") + }) + + // Test RegisterWorkflow with triple nested pointer + t.Run("RegisterWorkflowWithTripleNestedPointer", func(t *testing.T) { + tripleNestedWorkflow := func(ctx DBOSContext, input ***int) (int, error) { + return 0, nil + } + require.Panics(t, func() { + RegisterWorkflow(executor, tripleNestedWorkflow) + }, "RegisterWorkflow should panic when input type is a triple nested pointer") + }) + }) - // Wait for the receiver workflow to complete and get the message - receivedData, err := recvHandle.GetResult() + err := Launch(executor) require.NoError(t, err) + defer Shutdown(executor, 10*time.Second) + + // Test workflow with comprehensive data structure + t.Run("ComprehensiveValues", func(t *testing.T) { + strPtr := "pointer value" + strPtrPtr := &strPtr + input := TestWorkflowData{ + ID: "test-id", + Message: "test message", + Value: 42, + Active: true, + Data: TestData{Message: "embedded", Value: 123, Active: false}, + Metadata: map[string]string{"key": "value"}, + NestedSlice: []NestedTestData{ + {Key: "nested1", Count: 10}, + {Key: "nested2", Count: 20}, + }, + NestedMap: map[NestedTestData]MyInt{ + {Key: "map-key1", Count: 1}: MyInt(100), + {Key: "map-key2", Count: 2}: MyInt(200), + }, + StringPtr: &strPtr, + StringPtrPtr: &strPtrPtr, + } + + testAllSerializationPaths(t, executor, serializerRecoveryWorkflow, input, "comprehensive-values-wf") + }) + + // Test nil values with pointer type workflow + t.Run("NilPointer", func(t *testing.T) { + 
testAllSerializationPaths(t, executor, recoveryPtrWorkflow, (*TestWorkflowData)(nil), "nil-pointer-wf") + }) + + // Test error values + t.Run("ErrorValues", func(t *testing.T) { + input := TestWorkflowData{ + ID: "error-test-id", + Message: "error test", + Value: 123, + Active: true, + Data: TestData{Message: "error data", Value: 456, Active: false}, + Metadata: map[string]string{"type": "error"}, + NestedSlice: []NestedTestData{ + {Key: "error-nested", Count: 99}, + }, + NestedMap: map[NestedTestData]MyInt{ + {Key: "error-key", Count: 999}: MyInt(999), + }, + StringPtr: nil, + StringPtrPtr: nil, + } + + handle, err := RunWorkflow(executor, serializerErrorWorkflow, input) + require.NoError(t, err, "Error workflow execution failed") + + // 1. Test with handle.GetResult() + t.Run("HandleGetResult", func(t *testing.T) { + _, err := handle.GetResult() + require.Error(t, err, "Should get step error") + assert.Contains(t, err.Error(), "step error", "Error message should be preserved") + }) + + // 2. Test with GetWorkflowSteps + t.Run("GetWorkflowSteps", func(t *testing.T) { + steps, err := GetWorkflowSteps(executor, handle.GetWorkflowID()) + require.NoError(t, err, "Failed to get workflow steps") + require.Len(t, steps, 1, "Expected 1 step") + + step := steps[0] + require.NotNil(t, step.Error, "Step should have error") + assert.Contains(t, step.Error.Error(), "step error", "Step error should be preserved") + }) + }) + + // Test Send/Recv with non-basic types + t.Run("SendRecv", func(t *testing.T) { + strPtr := "sendrecv pointer" + strPtrPtr := &strPtr + input := TestWorkflowData{ + ID: "sendrecv-test-id", + Message: "test message", + Value: 99, + Active: true, + Data: TestData{Message: "nested", Value: 200, Active: true}, + Metadata: map[string]string{"comm": "sendrecv"}, + NestedSlice: []NestedTestData{ + {Key: "sendrecv-nested", Count: 50}, + }, + NestedMap: map[NestedTestData]MyInt{ + {Key: "sendrecv-key", Count: 5}: MyInt(500), + }, + StringPtr: &strPtr, + StringPtrPtr: &strPtrPtr, + } + + testSendRecv(t, executor, serializerSenderWorkflow, serializerReceiverWorkflow, input, "sender-wf") + }) + + // Test SetEvent/GetEvent with non-basic types + t.Run("SetGetEvent", func(t *testing.T) { + strPtr := "event pointer" + strPtrPtr := &strPtr + input := TestWorkflowData{ + ID: "event-test-id", + Message: "event message", + Value: 77, + Active: false, + Data: TestData{Message: "event nested", Value: 333, Active: true}, + Metadata: map[string]string{"type": "event"}, + NestedSlice: []NestedTestData{ + {Key: "event-nested1", Count: 30}, + {Key: "event-nested2", Count: 40}, + }, + NestedMap: map[NestedTestData]MyInt{ + {Key: "event-key1", Count: 3}: MyInt(300), + {Key: "event-key2", Count: 4}: MyInt(400), + }, + StringPtr: &strPtr, + StringPtrPtr: &strPtrPtr, + } + + testSetGetEvent(t, executor, serializerSetEventWorkflow, serializerGetEventWorkflow, input, "setevent-wf", "getevent-wf") + }) + + // Test typed Send/Recv and SetEvent/GetEvent with various types + t.Run("TypedSendRecvAndSetGetEvent", func(t *testing.T) { + // Test int (scalar type) + t.Run("Int", func(t *testing.T) { + input := 42 + testSendRecv(t, executor, serializerIntSenderWorkflow, serializerIntReceiverWorkflow, input, "typed-int-sender-wf") + testSetGetEvent(t, executor, serializerIntSetEventWorkflow, serializerIntGetEventWorkflow, input, "typed-int-setevent-wf", "typed-int-getevent-wf") + }) + + // Test MyInt (user defined type) + t.Run("MyInt", func(t *testing.T) { + input := MyInt(73) + testSendRecv(t, executor, 
serializerMyIntSenderWorkflow, serializerMyIntReceiverWorkflow, input, "typed-myint-sender-wf") + testSetGetEvent(t, executor, serializerMyIntSetEventWorkflow, serializerMyIntGetEventWorkflow, input, "typed-myint-setevent-wf", "typed-myint-getevent-wf") + }) + + // Test *int (pointer type, set) + t.Run("IntPtrSet", func(t *testing.T) { + v := 99 + input := &v + testSendRecv(t, executor, serializerIntPtrSenderWorkflow, serializerIntPtrReceiverWorkflow, input, "typed-intptr-set-sender-wf") + testSetGetEvent(t, executor, serializerIntPtrSetEventWorkflow, serializerIntPtrGetEventWorkflow, input, "typed-intptr-set-setevent-wf", "typed-intptr-set-getevent-wf") + }) + + // Test *int (pointer type, nil) + t.Run("IntPtrNil", func(t *testing.T) { + var input *int = nil + testSendRecv(t, executor, serializerIntPtrSenderWorkflow, serializerIntPtrReceiverWorkflow, input, "typed-intptr-nil-sender-wf") + testSetGetEvent(t, executor, serializerIntPtrSetEventWorkflow, serializerIntPtrGetEventWorkflow, input, "typed-intptr-nil-setevent-wf", "typed-intptr-nil-getevent-wf") + }) + }) - // Verify the received data matches what we sent - assert.Equal(t, 42, receivedData.ID) - assert.Equal(t, "test-send-message", receivedData.Name) - assert.Equal(t, "This is a test send message with user-defined data", receivedData.Details.Description) + // Test queued workflow with TestWorkflowData type + t.Run("QueuedWorkflow", func(t *testing.T) { + strPtr := "queued pointer" + strPtrPtr := &strPtr + input := TestWorkflowData{ + ID: "queued-test-id", + Message: "queued test message", + Value: 456, + Active: false, + Data: TestData{Message: "queued nested", Value: 789, Active: true}, + Metadata: map[string]string{"type": "queued"}, + NestedSlice: []NestedTestData{ + {Key: "queued-nested", Count: 222}, + }, + NestedMap: map[NestedTestData]MyInt{ + {Key: "queued-key", Count: 22}: MyInt(2222), + }, + StringPtr: &strPtr, + StringPtrPtr: &strPtrPtr, + } + + // Start workflow with queue option + handle, err := RunWorkflow(executor, serializerWorkflow, input, WithWorkflowID("serializer-queued-wf"), WithQueue(testQueue.Name)) + require.NoError(t, err, "failed to start queued workflow") + + // Get result from the handle + result, err := handle.GetResult() + require.NoError(t, err, "queued workflow should complete successfully") + assert.Equal(t, input, result, "queued workflow result should match input") + }) + + t.Run("Scalars", func(t *testing.T) { + testAllSerializationPaths(t, executor, recoveryIntWorkflow, 42, "recovery-int-wf") + }) + + t.Run("EmptyString", func(t *testing.T) { + testAllSerializationPaths(t, executor, recoveryStringWorkflow, "", "recovery-empty-string-wf") + }) + + // Pointer variants (single level only, nested pointers not supported) + t.Run("Pointers", func(t *testing.T) { + t.Run("NonNil", func(t *testing.T) { + v := 123 + input := &v + testAllSerializationPaths(t, executor, recoveryIntPtrWorkflow, input, "recovery-int-ptr-wf") + }) + + t.Run("Nil", func(t *testing.T) { + var input *int = nil + testAllSerializationPaths(t, executor, recoveryIntPtrWorkflow, input, "recovery-int-ptr-nil-wf") + }) + }) + + t.Run("SlicesAndArrays", func(t *testing.T) { + t.Run("NonEmptySlice", func(t *testing.T) { + input := []int{1, 2, 3} + testAllSerializationPaths(t, executor, recoveryIntSliceWorkflow, input, "recovery-int-slice-wf") + }) + + t.Run("NilSlice", func(t *testing.T) { + var input []int = nil + testAllSerializationPaths(t, executor, recoveryIntSliceWorkflow, input, "recovery-int-slice-nil-wf") + }) + + 
t.Run("Array", func(t *testing.T) { + input := [3]int{1, 2, 3} + testAllSerializationPaths(t, executor, recoveryIntArrayWorkflow, input, "recovery-int-array-wf") + }) + }) + + t.Run("ByteSlices", func(t *testing.T) { + t.Run("NonEmpty", func(t *testing.T) { + input := []byte{1, 2, 3, 4, 5} + testAllSerializationPaths(t, executor, recoveryByteSliceWorkflow, input, "recovery-byte-slice-wf") + }) + + t.Run("Nil", func(t *testing.T) { + var input []byte = nil + testAllSerializationPaths(t, executor, recoveryByteSliceWorkflow, input, "recovery-byte-slice-nil-wf") + }) + }) + + t.Run("Maps", func(t *testing.T) { + t.Run("NonEmptyMap", func(t *testing.T) { + input := map[string]int{"x": 1, "y": 2} + testAllSerializationPaths(t, executor, recoveryStringIntMapWorkflow, input, "recovery-string-int-map-wf") + }) + + t.Run("NilMap", func(t *testing.T) { + var input map[string]int = nil + testAllSerializationPaths(t, executor, recoveryStringIntMapWorkflow, input, "recovery-string-int-map-nil-wf") + }) + }) + + t.Run("CustomTypes", func(t *testing.T) { + t.Run("MyInt", func(t *testing.T) { + input := MyInt(7) + testAllSerializationPaths(t, executor, recoveryMyIntWorkflow, input, "recovery-myint-wf") + }) + + t.Run("MyString", func(t *testing.T) { + input := MyString("zeta") + testAllSerializationPaths(t, executor, recoveryMyStringWorkflow, input, "recovery-mystring-wf") + }) + + t.Run("MyStringSlice", func(t *testing.T) { + input := []MyString{"a", "b"} + testAllSerializationPaths(t, executor, recoveryMyStringSliceWorkflow, input, "recovery-mystring-slice-wf") + }) + + t.Run("StringMyIntMap", func(t *testing.T) { + input := map[string]MyInt{"k": 9} + testAllSerializationPaths(t, executor, recoveryStringMyIntMapWorkflow, input, "recovery-string-myint-map-wf") + }) + }) + + // Empty struct + t.Run("EmptyStruct", func(t *testing.T) { + input := struct{}{} + testAllSerializationPaths(t, executor, recoveryEmptyStructWorkflow, input, "recovery-empty-struct-wf") + }) + + // Nested collections + t.Run("NestedCollections", func(t *testing.T) { + t.Run("SliceOfSlices", func(t *testing.T) { + input := IntSliceSlice{{1, 2}, {3, 4, 5}} + testAllSerializationPaths(t, executor, recoveryIntSliceSliceWorkflow, input, "recovery-int-slice-slice-wf") + }) + + t.Run("NestedMap", func(t *testing.T) { + input := map[string]map[string]int{ + "outer1": {"inner1": 1, "inner2": 2}, + "outer2": {"inner3": 3}, + } + testAllSerializationPaths(t, executor, recoveryNestedMapWorkflow, input, "recovery-nested-map-wf") + }) + }) + + // Slices of pointers + t.Run("SliceOfPointers", func(t *testing.T) { + t.Run("NonNil", func(t *testing.T) { + v1 := 10 + v2 := 20 + v3 := 30 + input := []*int{&v1, &v2, &v3} + testAllSerializationPaths(t, executor, recoveryIntPtrSliceWorkflow, input, "recovery-int-ptr-slice-wf") + }) + + t.Run("NilSlice", func(t *testing.T) { + var input []*int = nil + testAllSerializationPaths(t, executor, recoveryIntPtrSliceWorkflow, input, "recovery-int-ptr-slice-nil-wf") + }) + }) + + // Test workflow with interface signature and manual gob registration + t.Run("InterfaceWithManualGobRegistration", func(t *testing.T) { + // Create an instance of the concrete implementation + processor := &TestStringProcessor{Prefix: "Processed: "} + + // Run the workflow with explicit type parameters (needed because processor is *TestStringProcessor but workflow expects TestDataProcessor interface) + handle, err := RunWorkflow[TestDataProcessor](executor, interfaceWorkflow, processor) + require.NoError(t, err, "Workflow execution 
failed") + + // Helper function to verify TestStringProcessor + verifyProcessor := func(t *testing.T, actual any, name string) { + t.Helper() + require.NotNil(t, actual, "%s should not be nil", name) + processor, ok := actual.(*TestStringProcessor) + require.True(t, ok, "%s should be *TestStringProcessor, got %T", name, actual) + assert.Equal(t, "Processed: ", processor.Prefix, "%s Prefix should match", name) + // Verify the interface method works + processed := processor.Process("test") + assert.Equal(t, "Processed: test", processed, "%s Process method should work", name) + } + + t.Run("HandleGetResult", func(t *testing.T) { + result, err := handle.GetResult() + require.NoError(t, err, "Failed to get workflow result") + verifyProcessor(t, result, "Result") + }) + + t.Run("ListWorkflows", func(t *testing.T) { + wfs, err := ListWorkflows(executor, + WithWorkflowIDs([]string{handle.GetWorkflowID()}), + WithLoadInput(true), WithLoadOutput(true)) + require.NoError(t, err) + require.Len(t, wfs, 1) + wf := wfs[0] + require.NotNil(t, wf.Input, "Workflow input should not be nil") + require.NotNil(t, wf.Output, "Workflow output should not be nil") + verifyProcessor(t, wf.Input, "Workflow input") + verifyProcessor(t, wf.Output, "Workflow output") + }) + + t.Run("GetWorkflowSteps", func(t *testing.T) { + steps, err := GetWorkflowSteps(executor, handle.GetWorkflowID()) + require.NoError(t, err) + require.Len(t, steps, 1) + step := steps[0] + require.NotNil(t, step.Output, "Step output should not be nil") + verifyProcessor(t, step.Output, "Step output") + assert.Nil(t, step.Error) + }) + + t.Run("RetrieveWorkflow", func(t *testing.T) { + h2, err := RetrieveWorkflow[TestDataProcessor](executor, handle.GetWorkflowID()) + require.NoError(t, err) + result, err := h2.GetResult() + require.NoError(t, err, "Failed to get retrieved workflow result") + verifyProcessor(t, result, "Retrieved workflow result") + }) + }) + + // Test nested pointer validation with RunAsStep + t.Run("NestedPointerValidation_RunAsStep", func(t *testing.T) { + // Test RunAsStep with nested pointer return type + t.Run("RunAsStepWithNestedPointerReturn", func(t *testing.T) { + handle, err := RunWorkflow(executor, nestedPtrStepWorkflow, 42) + require.NoError(t, err, "Workflow should start successfully") + + _, err = handle.GetResult() + require.Error(t, err, "Step execution should fail with nested pointer return type") + assert.Contains(t, err.Error(), "nested pointer types are not supported", "Error should mention nested pointer types") + }) + + // Test RunAsStep with triple nested pointer return type + t.Run("RunAsStepWithTripleNestedPointerReturn", func(t *testing.T) { + handle, err := RunWorkflow(executor, tripleNestedPtrStepWorkflow, 42) + require.NoError(t, err, "Workflow should start successfully") + + _, err = handle.GetResult() + require.Error(t, err, "Step execution should fail with triple nested pointer return type") + assert.Contains(t, err.Error(), "nested pointer types are not supported", "Error should mention nested pointer types") + }) + }) - // Verify tags - expectedTags := []string{"test", "user-defined", "serialization", "send"} - assert.Equal(t, expectedTags, receivedData.Details.Tags) }) } diff --git a/dbos/system_database.go b/dbos/system_database.go index ac1abc4..284e29e 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -37,7 +37,7 @@ type systemDatabase interface { insertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error) listWorkflows(ctx 
context.Context, input listWorkflowsDBInput) ([]WorkflowStatus, error) updateWorkflowOutcome(ctx context.Context, input updateWorkflowOutcomeDBInput) error - awaitWorkflowResult(ctx context.Context, workflowID string) (any, error) + awaitWorkflowResult(ctx context.Context, workflowID string) (*string, error) cancelWorkflow(ctx context.Context, workflowID string) error cancelAllBefore(ctx context.Context, cutoffTime time.Time) error resumeWorkflow(ctx context.Context, workflowID string) error @@ -51,13 +51,13 @@ type systemDatabase interface { // Steps recordOperationResult(ctx context.Context, input recordOperationResultDBInput) error checkOperationExecution(ctx context.Context, input checkOperationExecutionDBInput) (*recordedResult, error) - getWorkflowSteps(ctx context.Context, input getWorkflowStepsInput) ([]StepInfo, error) + getWorkflowSteps(ctx context.Context, input getWorkflowStepsInput) ([]stepInfo, error) // Communication (special steps) send(ctx context.Context, input WorkflowSendInput) error - recv(ctx context.Context, input recvInput) (any, error) + recv(ctx context.Context, input recvInput) (*string, error) setEvent(ctx context.Context, input WorkflowSetEventInput) error - getEvent(ctx context.Context, input getEventInput) (any, error) + getEvent(ctx context.Context, input getEventInput) (*string, error) // Timers (special steps) sleep(ctx context.Context, input sleepInput) (time.Duration, error) @@ -440,11 +440,6 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt timeoutMs = &millis } - inputString, err := serialize(input.status.Input) - if err != nil { - return nil, fmt.Errorf("failed to serialize input: %w", err) - } - // Our DB works with NULL values var applicationVersion *string if len(input.status.ApplicationVersion) > 0 { @@ -516,7 +511,7 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt updatedAt.UnixMilli(), timeoutMs, deadline, - inputString, + input.status.Input, // encoded input (already *string) deduplicationID, input.status.Priority, WorkflowStatusEnqueued, @@ -791,18 +786,13 @@ func (s *sysDB) listWorkflows(ctx context.Context, input listWorkflowsDBInput) ( wf.Error = errors.New(*errorStr) } - wf.Output, err = deserialize(outputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize output: %w", err) - } + // Return output as encoded *string + wf.Output = outputString } - // Handle input only if loadInput is true + // Return input as encoded *string if input.loadInput { - wf.Input, err = deserialize(inputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize input: %w", err) - } + wf.Input = inputString } workflows = append(workflows, wf) @@ -818,7 +808,7 @@ func (s *sysDB) listWorkflows(ctx context.Context, input listWorkflowsDBInput) ( type updateWorkflowOutcomeDBInput struct { workflowID string status WorkflowStatusType - output any + output *string err error tx pgx.Tx } @@ -830,20 +820,17 @@ func (s *sysDB) updateWorkflowOutcome(ctx context.Context, input updateWorkflowO SET status = $1, output = $2, error = $3, updated_at = $4, deduplication_id = NULL WHERE workflow_uuid = $5 AND NOT (status = $6 AND $1 in ($7, $8))`, pgx.Identifier{s.schema}.Sanitize()) - outputString, err := serialize(input.output) - if err != nil { - return fmt.Errorf("failed to serialize output: %w", err) - } - var errorStr string if input.err != nil { errorStr = input.err.Error() } + // input.output is already a *string from the database layer + var err error if input.tx != nil { - 
_, err = input.tx.Exec(ctx, query, input.status, outputString, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusSuccess, WorkflowStatusError) + _, err = input.tx.Exec(ctx, query, input.status, input.output, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusSuccess, WorkflowStatusError) } else { - _, err = s.pool.Exec(ctx, query, input.status, outputString, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusSuccess, WorkflowStatusError) + _, err = s.pool.Exec(ctx, query, input.status, input.output, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusSuccess, WorkflowStatusError) } if err != nil { @@ -1105,11 +1092,6 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st recovery_attempts ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)`, pgx.Identifier{s.schema}.Sanitize()) - inputString, err := serialize(originalWorkflow.Input) - if err != nil { - return "", fmt.Errorf("failed to serialize input: %w", err) - } - // Marshal authenticated roles (slice of strings) to JSON for TEXT column authenticatedRoles, err := json.Marshal(originalWorkflow.AuthenticatedRoles) @@ -1127,7 +1109,7 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st &appVersion, originalWorkflow.ApplicationID, _DBOS_INTERNAL_QUEUE_NAME, - inputString, + originalWorkflow.Input, // encoded time.Now().UnixMilli(), time.Now().UnixMilli(), 0) @@ -1157,7 +1139,7 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st return forkedWorkflowID, nil } -func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (any, error) { +func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (*string, error) { query := fmt.Sprintf(`SELECT status, output, error FROM %s.workflow_status WHERE workflow_uuid = $1`, pgx.Identifier{s.schema}.Sanitize()) var status WorkflowStatusType for { @@ -1179,20 +1161,14 @@ func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (any return nil, fmt.Errorf("failed to query workflow status: %w", err) } - // Deserialize output from TEXT to bytes then from bytes to R using gob - output, err := deserialize(outputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize output: %w", err) - } - switch status { case WorkflowStatusSuccess, WorkflowStatusError: if errorStr == nil || len(*errorStr) == 0 { - return output, nil + return outputString, nil } - return output, errors.New(*errorStr) + return outputString, errors.New(*errorStr) case WorkflowStatusCancelled: - return output, newAwaitedWorkflowCancelledError(workflowID) + return outputString, newAwaitedWorkflowCancelledError(workflowID) default: time.Sleep(_DB_RETRY_INTERVAL) } @@ -1203,7 +1179,7 @@ type recordOperationResultDBInput struct { workflowID string stepID int stepName string - output any + output *string err error tx pgx.Tx } @@ -1219,16 +1195,12 @@ func (s *sysDB) recordOperationResult(ctx context.Context, input recordOperation errorString = &e } - outputString, err := serialize(input.output) - if err != nil { - return fmt.Errorf("failed to serialize output: %w", err) - } - + var err error if input.tx != nil { _, err = input.tx.Exec(ctx, query, input.workflowID, input.stepID, - outputString, + input.output, errorString, input.stepName, ) @@ -1236,7 +1208,7 @@ func (s *sysDB) recordOperationResult(ctx context.Context, 
input recordOperation _, err = s.pool.Exec(ctx, query, input.workflowID, input.stepID, - outputString, + input.output, errorString, input.stepName, ) @@ -1326,7 +1298,7 @@ type recordChildGetResultDBInput struct { parentWorkflowID string childWorkflowID string stepID int - output string + output *string err error } @@ -1361,7 +1333,7 @@ func (s *sysDB) recordChildGetResult(ctx context.Context, input recordChildGetRe /*******************************/ type recordedResult struct { - output any + output *string err error } @@ -1431,29 +1403,24 @@ func (s *sysDB) checkOperationExecution(ctx context.Context, input checkOperatio return nil, newUnexpectedStepError(input.workflowID, input.stepID, input.stepName, recordedFunctionName) } - output, err := deserialize(outputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize output: %w", err) - } - var recordedError error if errorStr != nil && *errorStr != "" { recordedError = errors.New(*errorStr) } result := &recordedResult{ - output: output, + output: outputString, err: recordedError, } return result, nil } // StepInfo contains information about a workflow step execution. -type StepInfo struct { - StepID int // The sequential ID of the step within the workflow - StepName string // The name of the step function - Output any // The output returned by the step (if any) - Error error // The error returned by the step (if any) - ChildWorkflowID string // The ID of a child workflow spawned by this step (if applicable) +type stepInfo struct { + StepID int // The sequential ID of the step within the workflow + StepName string // The name of the step function + Output *string // The output returned by the step (if any) + Error error // The error returned by the step (if any) + ChildWorkflowID string // The ID of a child workflow spawned by this step (if applicable) } type getWorkflowStepsInput struct { @@ -1461,7 +1428,7 @@ type getWorkflowStepsInput struct { loadOutput bool } -func (s *sysDB) getWorkflowSteps(ctx context.Context, input getWorkflowStepsInput) ([]StepInfo, error) { +func (s *sysDB) getWorkflowSteps(ctx context.Context, input getWorkflowStepsInput) ([]stepInfo, error) { query := fmt.Sprintf(`SELECT function_id, function_name, output, error, child_workflow_id FROM %s.operation_outputs WHERE workflow_uuid = $1 @@ -1473,9 +1440,9 @@ func (s *sysDB) getWorkflowSteps(ctx context.Context, input getWorkflowStepsInpu } defer rows.Close() - var steps []StepInfo + var steps []stepInfo for rows.Next() { - var step StepInfo + var step stepInfo var outputString *string var errorString *string var childWorkflowID *string @@ -1485,13 +1452,9 @@ func (s *sysDB) getWorkflowSteps(ctx context.Context, input getWorkflowStepsInpu return nil, fmt.Errorf("failed to scan step row: %w", err) } - // Deserialize output if present and loadOutput is true - if input.loadOutput && outputString != nil { - output, err := deserialize(outputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize output: %w", err) - } - step.Output = output + // Return output as encoded string if loadOutput is true + if input.loadOutput { + step.Output = outputString } // Convert error string to error if present @@ -1564,12 +1527,13 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err return 0, fmt.Errorf("no recorded end time for recorded sleep operation") } - // The output should be a time.Time representing the end time - endTimeInterface, ok := recordedResult.output.(time.Time) - if !ok { - return 0, fmt.Errorf("recorded 
output is not a time.Time: %T", recordedResult.output) + // Decode the recorded end time directly into time.Time + // recordedResult.output is an encoded *string + serializer := newGobSerializer[time.Time]() + endTime, err = serializer.Decode(recordedResult.output) + if err != nil { + return 0, fmt.Errorf("failed to decode sleep end time: %w", err) } - endTime = endTimeInterface if recordedResult.err != nil { // This should never happen return 0, recordedResult.err @@ -1578,12 +1542,20 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err // First execution: calculate and record the end time endTime = time.Now().Add(input.duration) + // Serialize the end time before recording + serializer := newGobSerializer[time.Time]() + encodedEndTimeStr, serErr := serializer.Encode(endTime) + if serErr != nil { + return 0, fmt.Errorf("failed to serialize sleep end time: %w", serErr) + } + encodedEndTime := &encodedEndTimeStr + // Record the operation result with the calculated end time recordInput := recordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, - output: endTime, + output: encodedEndTime, err: nil, } @@ -1753,6 +1725,10 @@ func (s *sysDB) send(ctx context.Context, input WorkflowSendInput) error { stepID = wfState.nextStepID() } + if _, ok := input.Message.(*string); !ok { + return fmt.Errorf("message must be a pointer to a string") + } + tx, err := s.pool.Begin(ctx) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) @@ -1783,14 +1759,8 @@ func (s *sysDB) send(ctx context.Context, input WorkflowSendInput) error { topic = input.Topic } - // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.Message) - if err != nil { - return fmt.Errorf("failed to serialize message: %w", err) - } - insertQuery := fmt.Sprintf(`INSERT INTO %s.notifications (destination_uuid, topic, message) VALUES ($1, $2, $3)`, pgx.Identifier{s.schema}.Sanitize()) - _, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, messageString) + _, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, input.Message) if err != nil { // Check for foreign key violation (destination workflow doesn't exist) if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == _PG_ERROR_FOREIGN_KEY_VIOLATION { @@ -1825,7 +1795,7 @@ func (s *sysDB) send(ctx context.Context, input WorkflowSendInput) error { } // Recv is a special type of step that receives a message destined for a given workflow -func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { +func (s *sysDB) recv(ctx context.Context, input recvInput) (*string, error) { functionName := "DBOS.recv" // Get workflow state from context @@ -1885,7 +1855,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { err = s.pool.QueryRow(ctx, query, destinationID, topic).Scan(&exists) if err != nil { cond.L.Unlock() - return false, fmt.Errorf("failed to check message: %w", err) + return nil, fmt.Errorf("failed to check message: %w", err) } if !exists { done := make(chan struct{}) @@ -1939,29 +1909,17 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { var messageString *string err = tx.QueryRow(ctx, query, destinationID, topic).Scan(&messageString) if err != nil { - if err == pgx.ErrNoRows { - // No message found, record nil result - messageString = nil - } else { + if err != pgx.ErrNoRows { return nil, fmt.Errorf("failed to consume message: 
%w", err) } } - // Deserialize the message - var message any - if messageString != nil { // nil message can happen on the timeout path only - message, err = deserialize(messageString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize message: %w", err) - } - } - - // Record the operation result + // Record the operation result (with encoded message string) recordInput := recordOperationResultDBInput{ workflowID: destinationID, stepID: stepID, stepName: functionName, - output: message, + output: messageString, tx: tx, } err = s.recordOperationResult(ctx, recordInput) @@ -1973,7 +1931,8 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { return nil, fmt.Errorf("failed to commit transaction: %w", err) } - return message, nil + // Return the message string pointer + return messageString, nil } type WorkflowSetEventInput struct { @@ -1990,6 +1949,10 @@ func (s *sysDB) setEvent(ctx context.Context, input WorkflowSetEventInput) error return newStepExecutionError("", functionName, fmt.Errorf("workflow state not found in context: are you running this step within a workflow?")) } + if _, ok := input.Message.(*string); !ok { + return fmt.Errorf("message must be a pointer to a string") + } + if wfState.isWithinStep { return newStepExecutionError(wfState.workflowID, functionName, fmt.Errorf("cannot call SetEvent within a step")) } @@ -2018,19 +1981,14 @@ func (s *sysDB) setEvent(ctx context.Context, input WorkflowSetEventInput) error return nil } - // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.Message) - if err != nil { - return fmt.Errorf("failed to serialize message: %w", err) - } - + // input.Message is already encoded *string from the typed layer // Insert or update the event using UPSERT insertQuery := fmt.Sprintf(`INSERT INTO %s.workflow_events (workflow_uuid, key, value) VALUES ($1, $2, $3) ON CONFLICT (workflow_uuid, key) DO UPDATE SET value = EXCLUDED.value`, pgx.Identifier{s.schema}.Sanitize()) - _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.Key, messageString) + _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.Key, input.Message) if err != nil { return fmt.Errorf("failed to insert/update workflow event: %w", err) } @@ -2058,7 +2016,7 @@ func (s *sysDB) setEvent(ctx context.Context, input WorkflowSetEventInput) error return nil } -func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) { +func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (*string, error) { functionName := "DBOS.getEvent" // Get workflow state from context (optional for GetEvent as we can get an event from outside a workflow) @@ -2160,22 +2118,13 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) } } - // Deserialize the value if it exists - var value any - if valueString != nil { - value, err = deserialize(valueString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize event value: %w", err) - } - } - // Record the operation result if this is called within a workflow if isInWorkflow { recordInput := recordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, - output: value, + output: valueString, err: nil, } @@ -2185,7 +2134,8 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) } } - return value, nil + // Return the value string pointer + return valueString, nil } /*******************************/ @@ 
-2195,7 +2145,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) type dequeuedWorkflow struct { id string name string - input string + input *string } type dequeueWorkflowsInput struct { @@ -2397,21 +2347,16 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, input dequeueWorkflowsInpu WHERE workflow_uuid = $5 RETURNING name, inputs`, pgx.Identifier{s.schema}.Sanitize()) - var inputString *string err := tx.QueryRow(ctx, updateQuery, WorkflowStatusPending, input.applicationVersion, input.executorID, time.Now().UnixMilli(), - id).Scan(&retWorkflow.name, &inputString) + id).Scan(&retWorkflow.name, &retWorkflow.input) if err != nil { return nil, fmt.Errorf("failed to update workflow %s during dequeue: %w", id, err) } - if inputString != nil && len(*inputString) > 0 { - retWorkflow.input = *inputString - } - retWorkflows = append(retWorkflows, retWorkflow) } diff --git a/dbos/workflow.go b/dbos/workflow.go index 5117933..a9da430 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "log/slog" "math" "reflect" "runtime" @@ -74,8 +73,9 @@ func (ws *workflowState) nextStepID() int { // workflowOutcome holds the result and error from workflow execution type workflowOutcome[R any] struct { - result R - err error + result R + err error + needsDecoding bool // true if result came from awaitWorkflowResult (ID conflict path) and needs decoding } // WorkflowHandle provides methods to interact with a running or completed workflow. @@ -205,11 +205,16 @@ func (h *workflowHandle[R]) GetResult(opts ...GetResultOption) (R, error) { // processOutcome handles the common logic for processing workflow outcomes func (h *workflowHandle[R]) processOutcome(outcome workflowOutcome[R]) (R, error) { + decodedResult := outcome.result // If we are calling GetResult inside a workflow, record the result as a step result workflowState, ok := h.dbosContext.Value(workflowStateKey).(*workflowState) isWithinWorkflow := ok && workflowState != nil if isWithinWorkflow { - encodedOutput, encErr := serialize(outcome.result) + if _, ok := h.dbosContext.(*dbosContext); !ok { + return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("invalid DBOSContext: expected *dbosContext")) + } + serializer := newGobSerializer[R]() + encodedOutput, encErr := serializer.Encode(decodedResult) if encErr != nil { return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("serializing child workflow result: %w", encErr)) } @@ -217,7 +222,7 @@ func (h *workflowHandle[R]) processOutcome(outcome workflowOutcome[R]) (R, error parentWorkflowID: workflowState.workflowID, childWorkflowID: h.workflowID, stepID: workflowState.nextStepID(), - output: encodedOutput, + output: &encodedOutput, err: outcome.err, } recordResultErr := retry(h.dbosContext, func() error { @@ -228,7 +233,7 @@ func (h *workflowHandle[R]) processOutcome(outcome workflowOutcome[R]) (R, error return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("recording child workflow result: %w", recordResultErr)) } } - return outcome.result, outcome.err + return decodedResult, outcome.err } type workflowPollingHandle[R any] struct { @@ -249,27 +254,37 @@ func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error) defer cancel() } - result, err := retryWithResult(ctx, func() (any, error) { + encodedResult, err := retryWithResult(ctx, func() (any, error) { return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(ctx, 
h.workflowID) }, withRetrierLogger(h.dbosContext.(*dbosContext).logger)) - if result != nil { - typedResult, ok := result.(R) - if !ok { - return *new(R), newWorkflowUnexpectedResultType(h.workflowID, fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", result)) + + // Deserialize the result directly into the target type + var typedResult R + if encodedResult != nil { + encodedStr, ok := encodedResult.(*string) + if !ok { // Should never happen + return *new(R), newWorkflowUnexpectedResultType(h.workflowID, "string (encoded)", fmt.Sprintf("%T", encodedResult)) + } + serializer := newGobSerializer[R]() + var deserErr error + typedResult, deserErr = serializer.Decode(encodedStr) + if deserErr != nil { + return *new(R), fmt.Errorf("failed to deserialize workflow result: %w", deserErr) } + // If we are calling GetResult inside a workflow, record the result as a step result workflowState, ok := h.dbosContext.Value(workflowStateKey).(*workflowState) isWithinWorkflow := ok && workflowState != nil if isWithinWorkflow { - encodedOutput, encErr := serialize(typedResult) - if encErr != nil { - return *new(R), newWorkflowExecutionError(workflowState.workflowID, fmt.Errorf("serializing child workflow result: %w", encErr)) + encodedResultStr, ok := encodedResult.(*string) + if !ok { // Should never happen + return *new(R), newWorkflowUnexpectedResultType(h.workflowID, "string (encoded)", fmt.Sprintf("%T", encodedResult)) } recordGetResultInput := recordChildGetResultDBInput{ parentWorkflowID: workflowState.workflowID, childWorkflowID: h.workflowID, stepID: workflowState.nextStepID(), - output: encodedOutput, + output: encodedResultStr, err: err, } recordResultErr := retry(h.dbosContext, func() error { @@ -489,6 +504,28 @@ func RegisterWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], opts ... panic("workflow function cannot be nil") } + // Check for nested pointer types + var p P + var r R + pType := reflect.TypeOf(p) + rType := reflect.TypeOf(r) + + if IsNestedPointer(pType) { + // Log error if we have a concrete dbosContext + if c, ok := ctx.(*dbosContext); ok { + c.logger.Error("nested pointer types are not supported", "workflow_input_type", pType) + } + panic(fmt.Sprintf("nested pointer types are not supported: workflow input type %v is a nested pointer", pType)) + } + + if IsNestedPointer(rType) { + // Log error if we have a concrete dbosContext + if c, ok := ctx.(*dbosContext); ok { + c.logger.Error("nested pointer types are not supported", "workflow_return_type", rType) + } + panic(fmt.Sprintf("nested pointer types are not supported: workflow return type %v is a nested pointer", rType)) + } + registrationParams := workflowRegistrationOptions{ maxRetries: _DEFAULT_MAX_RECOVERY_ATTEMPTS, } @@ -499,21 +536,22 @@ func RegisterWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], opts ... 
fqn := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() - // Registry the input/output types for gob encoding - var logger *slog.Logger - if c, ok := ctx.(*dbosContext); ok { - logger = c.logger - } - var p P - var r R - safeGobRegister(p, logger) - safeGobRegister(r, logger) - - // Register a type-erased version of the durable workflow for recovery + // Register a type-erased version of the durable workflow for recovery and queue runner + // Input will always come from the database and encoded as *string, so we decode it into the target type (captured by this wrapped closure) typedErasedWorkflow := WorkflowFunc(func(ctx DBOSContext, input any) (any, error) { - typedInput, ok := input.(P) + workflowID, err := GetWorkflowID(ctx) + if err != nil { + return *new(R), newWorkflowExecutionError("", fmt.Errorf("getting workflow ID: %w", err)) + } + encodedInput, ok := input.(*string) if !ok { - return nil, newWorkflowUnexpectedInputType(fqn, fmt.Sprintf("%T", typedInput), fmt.Sprintf("%T", input)) + return *new(R), newWorkflowUnexpectedInputType(fqn, "*string (encoded)", fmt.Sprintf("%T", input)) + } + // Decode directly into the target type + serializer := newGobSerializer[P]() + typedInput, err := serializer.Decode(encodedInput) + if err != nil { + return *new(R), newWorkflowExecutionError(workflowID, err) } return fn(ctx, typedInput) }) @@ -688,11 +726,46 @@ func RunWorkflow[P any, R any](ctx DBOSContext, fn Workflow[P, R], input P, opts resultErr := outcome.err var typedResult R - if typedRes, ok := outcome.result.(R); ok { + + // Handle nil results - nil cannot be type-asserted to any interface + if outcome.result == nil { + typedOutcomeChan <- workflowOutcome[R]{ + result: typedResult, + err: resultErr, + } + return + } + + // Check if this is a mocked path + if _, ok := handle.dbosContext.(*dbosContext); !ok { + typedOutcomeChan <- workflowOutcome[R]{ + result: outcome.result.(R), + err: resultErr, + } + return + } + + // Convert result to expected type R + // Result can be either an encoded *string (from ID conflict path) or already decoded + if outcome.needsDecoding { + encodedResult, ok := outcome.result.(*string) + if !ok { // Should never happen + resultErr = errors.Join(resultErr, newWorkflowUnexpectedResultType(handle.workflowID, "string (encoded)", fmt.Sprintf("%T", outcome.result))) + } else { + // Result is encoded, decode directly into target type + serializer := newGobSerializer[R]() + var decodeErr error + typedResult, decodeErr = serializer.Decode(encodedResult) + if decodeErr != nil { + resultErr = errors.Join(resultErr, newWorkflowExecutionError(handle.workflowID, fmt.Errorf("decoding workflow result to type %T: %w", *new(R), decodeErr))) + } + } + } else if typedRes, ok := outcome.result.(R); ok { + // Normal path - result already has the correct type typedResult = typedRes - } else { // This should never happen - typedResult = *new(R) - typeErr := fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), outcome.result) + } else { + // Type assertion failed + typeErr := newWorkflowUnexpectedResultType(handle.workflowID, fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", outcome.result)) resultErr = errors.Join(resultErr, typeErr) } @@ -806,6 +879,14 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt if params.priority > uint(math.MaxInt) { return nil, fmt.Errorf("priority %d exceeds maximum allowed value %d", params.priority, math.MaxInt) } + + // Serialize input before storing in workflow status + serializer := 
newGobSerializer[any]() + encodedInput, serErr := serializer.Encode(input) + if serErr != nil { + return nil, newWorkflowExecutionError(workflowID, fmt.Errorf("failed to serialize workflow input: %w", serErr)) + } + workflowStatus := WorkflowStatus{ Name: params.workflowName, ApplicationVersion: params.applicationVersion, @@ -815,7 +896,7 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt CreatedAt: time.Now(), Deadline: deadline, Timeout: timeout, - Input: input, + Input: &encodedInput, ApplicationID: c.GetApplicationID(), QueueName: params.queueName, DeduplicationID: params.deduplicationID, @@ -937,9 +1018,14 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt // Handle DBOS ID conflict errors by waiting workflow result if errors.Is(err, &DBOSError{Code: ConflictingIDError}) { c.logger.Warn("Workflow ID conflict detected. Waiting for existing workflow to complete", "workflow_id", workflowID) - result, err = retryWithResult(c, func() (any, error) { + var encodedResult any + encodedResult, err = retryWithResult(c, func() (any, error) { return c.systemDB.awaitWorkflowResult(uncancellableCtx, workflowID) }, withRetrierLogger(c.logger)) + // Keep the encoded result - decoding will happen in RunWorkflow[P,R] when we know the target type + outcomeChan <- workflowOutcome[any]{result: encodedResult, err: err, needsDecoding: true} + close(outcomeChan) + return } else { status := WorkflowStatusSuccess @@ -955,12 +1041,22 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt status = WorkflowStatusCancelled } + // Serialize the output before recording + serializer := newGobSerializer[any]() + encodedOutput, serErr := serializer.Encode(result) + if serErr != nil { + c.logger.Error("Failed to serialize workflow output", "workflow_id", workflowID, "error", serErr) + outcomeChan <- workflowOutcome[any]{result: nil, err: fmt.Errorf("failed to serialize output: %w", serErr)} + close(outcomeChan) + return + } + recordErr := retry(c, func() error { return c.systemDB.updateWorkflowOutcome(uncancellableCtx, updateWorkflowOutcomeDBInput{ workflowID: workflowID, status: status, err: err, - output: result, + output: &encodedOutput, }) }, withRetrierLogger(c.logger)) if recordErr != nil { @@ -1104,13 +1200,17 @@ func RunAsStep[R any](ctx DBOSContext, fn Step[R], opts ...StepOption) (R, error return *new(R), newStepExecutionError("", "", fmt.Errorf("step function cannot be nil")) } - // Register the output type for gob encoding - var logger *slog.Logger - if c, ok := ctx.(*dbosContext); ok { - logger = c.logger - } + // Check for nested pointer types var r R - safeGobRegister(r, logger) + rType := reflect.TypeOf(r) + if IsNestedPointer(rType) { + workflowID, _ := GetWorkflowID(ctx) // Best effort to get workflow ID for error context + // Log error if we have a concrete dbosContext + if c, ok := ctx.(*dbosContext); ok { + c.logger.Error("nested pointer types are not supported", "step_return_type", rType, "workflow_id", workflowID) + } + return *new(R), newStepExecutionError(workflowID, "", fmt.Errorf("nested pointer types are not supported: step return type %v is a nested pointer", rType)) + } // Append WithStepName option to ensure the step name is set. 
This will not erase a user-provided step name stepName := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() @@ -1124,10 +1224,22 @@ func RunAsStep[R any](ctx DBOSContext, fn Step[R], opts ...StepOption) (R, error if result == nil { return *new(R), err } - // Otherwise type-check and cast the result - typedResult, ok := result.(R) - if !ok { - return *new(R), fmt.Errorf("unexpected result type: expected %T, got %T", *new(R), result) + var typedResult R + // When the step is executed, the result is already decoded and should be directly convertible + if typedRes, ok := result.(R); ok { + typedResult = typedRes + } else if encodedOutput, ok := result.(*string); ok { + // If not it should be an encoded *string + serializer := newGobSerializer[R]() + var decodeErr error + typedResult, decodeErr = serializer.Decode(encodedOutput) + if decodeErr != nil { + workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the error + return *new(R), newWorkflowExecutionError(workflowID, fmt.Errorf("decoding step result to expected type %T: %w", *new(R), decodeErr)) + } + } else { + workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the error + return *new(R), newWorkflowUnexpectedResultType(workflowID, fmt.Sprintf("%T", *new(R)), fmt.Sprintf("%T", result)) } return typedResult, err } @@ -1178,6 +1290,7 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) return nil, newStepExecutionError(stepState.workflowID, stepOpts.stepName, fmt.Errorf("checking operation execution: %w", err)) } if recordedOutput != nil { + // Return the encoded output - decoding will happen in RunAsStep[R] when we know the target type return recordedOutput.output, recordedOutput.err } @@ -1228,13 +1341,20 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) } } + // Serialize step output before recording + serializer := newGobSerializer[any]() + encodedStepOutput, serErr := serializer.Encode(stepOutput) + if serErr != nil { + return nil, newStepExecutionError(stepState.workflowID, stepOpts.stepName, fmt.Errorf("failed to serialize step output: %w", serErr)) + } + // Record the final result dbInput := recordOperationResultDBInput{ workflowID: stepState.workflowID, stepName: stepOpts.stepName, stepID: stepState.stepID, err: stepError, - output: stepOutput, + output: &encodedStepOutput, } recErr := retry(c, func() error { return c.systemDB.recordOperationResult(uncancellableCtx, dbInput) @@ -1251,10 +1371,16 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) /****************************************/ func (c *dbosContext) Send(_ DBOSContext, destinationID string, message any, topic string) error { + // Serialize the message before sending + serializer := newGobSerializer[any]() + encodedMessage, err := serializer.Encode(message) + if err != nil { + return fmt.Errorf("failed to serialize message: %w", err) + } return retry(c, func() error { return c.systemDB.send(c, WorkflowSendInput{ DestinationID: destinationID, - Message: message, + Message: &encodedMessage, Topic: topic, }) }, withRetrierLogger(c.logger)) @@ -1273,12 +1399,6 @@ func Send[P any](ctx DBOSContext, destinationID string, message P, topic string) if ctx == nil { return errors.New("ctx cannot be nil") } - var logger *slog.Logger - if c, ok := ctx.(*dbosContext); ok { - logger = c.logger - } - var typedMessage P - safeGobRegister(typedMessage, logger) return ctx.Send(ctx, destinationID, message, topic) } @@ -1292,7 +1412,7 @@ 
func (c *dbosContext) Recv(_ DBOSContext, topic string, timeout time.Duration) ( Topic: topic, Timeout: timeout, } - return retryWithResult(c, func() (any, error) { + return retryWithResult(c, func() (*string, error) { return c.systemDB.recv(c, input) }, withRetrierLogger(c.logger)) } @@ -1319,23 +1439,51 @@ func Recv[R any](ctx DBOSContext, topic string, timeout time.Duration) (R, error if err != nil { return *new(R), err } - // Type check + + // Handle nil message + if msg == nil { + return *new(R), nil + } + var typedMessage R - if msg != nil { + // Check if we're in a real DBOS context (not a mock) + if _, ok := ctx.(*dbosContext); ok { + encodedMsg, ok := msg.(*string) + if !ok { + workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the error + return *new(R), newWorkflowUnexpectedResultType(workflowID, "string (encoded)", fmt.Sprintf("%T", msg)) + } + serializer := newGobSerializer[R]() + var decodeErr error + typedMessage, decodeErr = serializer.Decode(encodedMsg) + if decodeErr != nil { + return *new(R), fmt.Errorf("decoding received message to type %T: %w", *new(R), decodeErr) + } + return typedMessage, nil + } else { + // Fallback for testing/mocking scenarios where serializer is nil var ok bool typedMessage, ok = msg.(R) if !ok { - return *new(R), newWorkflowUnexpectedResultType("", fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", msg)) + workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the error + return *new(R), newWorkflowUnexpectedResultType(workflowID, fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", msg)) } } return typedMessage, nil } func (c *dbosContext) SetEvent(_ DBOSContext, key string, message any) error { + // Serialize the event value before storing + serializer := newGobSerializer[any]() + encodedMessage, err := serializer.Encode(message) + if err != nil { + return fmt.Errorf("failed to serialize event value: %w", err) + } + return retry(c, func() error { return c.systemDB.setEvent(c, WorkflowSetEventInput{ Key: key, - Message: message, + Message: &encodedMessage, }) }, withRetrierLogger(c.logger)) } @@ -1354,12 +1502,6 @@ func SetEvent[P any](ctx DBOSContext, key string, message P) error { if ctx == nil { return errors.New("ctx cannot be nil") } - var logger *slog.Logger - if c, ok := ctx.(*dbosContext); ok { - logger = c.logger - } - var typedMessage P - safeGobRegister(typedMessage, logger) return ctx.SetEvent(ctx, key, message) } @@ -1405,10 +1547,30 @@ func GetEvent[R any](ctx DBOSContext, targetWorkflowID, key string, timeout time if value == nil { return *new(R), nil } - // Type check - typedValue, ok := value.(R) - if !ok { - return *new(R), newWorkflowUnexpectedResultType("", fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", value)) + + var typedValue R + // Check if we're in a real DBOS context (not a mock) + if _, ok := ctx.(*dbosContext); ok { + encodedValue, ok := value.(*string) + if !ok { + workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the error + return *new(R), newWorkflowUnexpectedResultType(workflowID, "string (encoded)", fmt.Sprintf("%T", value)) + } + + serializer := newGobSerializer[R]() + var decodeErr error + typedValue, decodeErr = serializer.Decode(encodedValue) + if decodeErr != nil { + return *new(R), fmt.Errorf("decoding event value to type %T: %w", *new(R), decodeErr) + } + return typedValue, nil + } else { + var ok bool + typedValue, ok = value.(R) + if !ok { + workflowID, _ := GetWorkflowID(ctx) // Must be within a workflow so we can ignore the 
error + return *new(R), newWorkflowUnexpectedResultType(workflowID, fmt.Sprintf("%T", new(R)), fmt.Sprintf("%T", value)) + } } return typedValue, nil } @@ -1551,14 +1713,6 @@ func RetrieveWorkflow[R any](ctx DBOSContext, workflowID string) (WorkflowHandle return nil, errors.New("dbosCtx cannot be nil") } - // Register the output for gob encoding - var logger *slog.Logger - if c, ok := ctx.(*dbosContext); ok { - logger = c.logger - } - var r R - safeGobRegister(r, logger) - // Call the interface method handle, err := ctx.RetrieveWorkflow(ctx, workflowID) if err != nil { @@ -1649,14 +1803,6 @@ func ResumeWorkflow[R any](ctx DBOSContext, workflowID string) (WorkflowHandle[R return nil, errors.New("ctx cannot be nil") } - // Register the output for gob encoding - var logger *slog.Logger - if c, ok := ctx.(*dbosContext); ok { - logger = c.logger - } - var r R - safeGobRegister(r, logger) - _, err := ctx.ResumeWorkflow(ctx, workflowID) if err != nil { return nil, err @@ -1745,14 +1891,6 @@ func ForkWorkflow[R any](ctx DBOSContext, input ForkWorkflowInput) (WorkflowHand return nil, errors.New("ctx cannot be nil") } - // Register the output for gob encoding - var logger *slog.Logger - if c, ok := ctx.(*dbosContext); ok { - logger = c.logger - } - var r R - safeGobRegister(r, logger) - handle, err := ctx.ForkWorkflow(ctx, input) if err != nil { return nil, err @@ -1940,15 +2078,51 @@ func (c *dbosContext) ListWorkflows(_ DBOSContext, opts ...ListWorkflowsOption) } // Call the context method to list workflows + var workflows []WorkflowStatus + var err error workflowState, ok := c.Value(workflowStateKey).(*workflowState) isWithinWorkflow := ok && workflowState != nil if isWithinWorkflow { - return RunAsStep(c, func(ctx context.Context) ([]WorkflowStatus, error) { + workflows, err = RunAsStep(c, func(ctx context.Context) ([]WorkflowStatus, error) { return c.systemDB.listWorkflows(ctx, dbInput) }, WithStepName("DBOS.listWorkflows")) } else { - return c.systemDB.listWorkflows(c, dbInput) + workflows, err = c.systemDB.listWorkflows(c, dbInput) } + if err != nil { + return nil, err + } + + // Deserialize Input and Output fields if they were loaded + if params.loadInput || params.loadOutput { + serializer := newGobSerializer[any]() + for i := range workflows { + if params.loadInput && workflows[i].Input != nil { + encodedInput, ok := workflows[i].Input.(*string) + if !ok { + return nil, fmt.Errorf("workflow input must be encoded string, got %T", workflows[i].Input) + } + decodedInput, err := serializer.Decode(encodedInput) + if err != nil { + return nil, fmt.Errorf("failed to deserialize workflow input for %s: %w", workflows[i].ID, err) + } + workflows[i].Input = decodedInput + } + if params.loadOutput && workflows[i].Output != nil { + encodedOutput, ok := workflows[i].Output.(*string) + if !ok { + return nil, fmt.Errorf("workflow output must be encoded string, got %T", workflows[i].Output) + } + decodedOutput, err := serializer.Decode(encodedOutput) + if err != nil { + return nil, fmt.Errorf("failed to deserialize workflow output for %s: %w", workflows[i].ID, err) + } + workflows[i].Output = decodedOutput + } + } + } + + return workflows, nil } // ListWorkflows retrieves a list of workflows based on the provided filters. @@ -2002,6 +2176,14 @@ func ListWorkflows(ctx DBOSContext, opts ...ListWorkflowsOption) ([]WorkflowStat return ctx.ListWorkflows(ctx, opts...) 
} +type StepInfo struct { + StepID int // The sequential ID of the step within the workflow + StepName string // The name of the step function + Output any // The output returned by the step (if any) + Error error // The error returned by the step (if any) + ChildWorkflowID string // The ID of a child workflow spawned by this step (if applicable) +} + func (c *dbosContext) GetWorkflowSteps(_ DBOSContext, workflowID string) ([]StepInfo, error) { var loadOutput bool if c.launched.Load() { @@ -2014,15 +2196,45 @@ func (c *dbosContext) GetWorkflowSteps(_ DBOSContext, workflowID string) ([]Step loadOutput: loadOutput, } + var steps []stepInfo + var err error workflowState, ok := c.Value(workflowStateKey).(*workflowState) isWithinWorkflow := ok && workflowState != nil if isWithinWorkflow { - return RunAsStep(c, func(ctx context.Context) ([]StepInfo, error) { + steps, err = RunAsStep(c, func(ctx context.Context) ([]stepInfo, error) { return c.systemDB.getWorkflowSteps(ctx, getWorkflowStepsInput) }, WithStepName("DBOS.getWorkflowSteps")) } else { - return c.systemDB.getWorkflowSteps(c, getWorkflowStepsInput) + steps, err = c.systemDB.getWorkflowSteps(c, getWorkflowStepsInput) + } + if err != nil { + return nil, err + } + stepInfos := make([]StepInfo, len(steps)) + for i, step := range steps { + stepInfos[i] = StepInfo{ + StepID: step.StepID, + StepName: step.StepName, + Output: step.Output, + Error: step.Error, + ChildWorkflowID: step.ChildWorkflowID, + } } + + // Deserialize outputs if asked to + if loadOutput { + serializer := newGobSerializer[any]() + for i := range steps { + encodedOutput := steps[i].Output + decodedOutput, err := serializer.Decode(encodedOutput) + if err != nil { + return nil, fmt.Errorf("failed to deserialize step output for step %d: %w", steps[i].StepID, err) + } + stepInfos[i].Output = decodedOutput + } + } + + return stepInfos, nil } // GetWorkflowSteps retrieves the execution steps of a workflow. 
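Everything above (workflow inputs and outputs, step outputs, messages, events, sleep deadlines) flows through `newGobSerializer[T]`, whose definition is not part of this excerpt. For readers following the data flow, here is a minimal sketch of the shape its call sites imply — `Encode` returning a TEXT-friendly `string` and `Decode` accepting a possibly-nil `*string` — assuming gob encoding wrapped in base64; the real implementation may differ:

```go
package dbos

import (
	"bytes"
	"encoding/base64"
	"encoding/gob"
	"fmt"
)

// gobSerializer is a sketch of the typed serializer assumed throughout this
// patch: gob-encode the value, then base64-wrap the bytes so the result can
// be stored in a TEXT column.
type gobSerializer[T any] struct{}

func newGobSerializer[T any]() gobSerializer[T] { return gobSerializer[T]{} }

func (gobSerializer[T]) Encode(value T) (string, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(&value); err != nil {
		return "", fmt.Errorf("gob encode: %w", err)
	}
	return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}

func (gobSerializer[T]) Decode(encoded *string) (T, error) {
	var value T
	if encoded == nil || *encoded == "" {
		return value, nil // nothing recorded: return the zero value
	}
	raw, err := base64.StdEncoding.DecodeString(*encoded)
	if err != nil {
		return value, fmt.Errorf("base64 decode: %w", err)
	}
	if err := gob.NewDecoder(bytes.NewReader(raw)).Decode(&value); err != nil {
		return value, fmt.Errorf("gob decode: %w", err)
	}
	return value, nil
}
```

Treating a nil or empty encoded string as the zero value is one way to keep the call sites simple when nothing was recorded (for example, a recv timeout stores a NULL message).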
diff --git a/dbos/workflows_test.go b/dbos/workflows_test.go index f5672ec..589638c 100644 --- a/dbos/workflows_test.go +++ b/dbos/workflows_test.go @@ -6,7 +6,6 @@ import ( "fmt" "reflect" "runtime" - "strings" "sync" "sync/atomic" "testing" @@ -363,177 +362,6 @@ func TestWorkflowsRegistration(t *testing.T) { }() RegisterWorkflow(freshCtx, simpleWorkflow) }) - - t.Run("SafeGobRegister", func(t *testing.T) { - // Create a fresh DBOS context for this test - freshCtx := setupDBOS(t, false, true) // Don't reset DB but do check for leaks - - // Test 1: Basic type vs pointer conflicts - type TestType struct { - Value string - } - - // Register workflows that use the same type to trigger potential gob conflicts - // The safeGobRegister calls within RegisterWorkflow should handle the conflicts - workflow1 := func(ctx DBOSContext, input TestType) (TestType, error) { - return input, nil - } - workflow2 := func(ctx DBOSContext, input *TestType) (*TestType, error) { - return input, nil - } - - // Both registrations should succeed despite using conflicting types (T and *T) - RegisterWorkflow(freshCtx, workflow1) - RegisterWorkflow(freshCtx, workflow2) - - // Test 2: Multiple workflows with the same types (duplicate registrations) - workflow3 := func(ctx DBOSContext, input TestType) (TestType, error) { - return TestType{Value: input.Value + "-modified"}, nil - } - workflow4 := func(ctx DBOSContext, input TestType) (TestType, error) { - return TestType{Value: input.Value + "-another"}, nil - } - - // These should succeed even though TestType is already registered - RegisterWorkflow(freshCtx, workflow3) - RegisterWorkflow(freshCtx, workflow4) - - // Test 3: Nested structs - type InnerType struct { - ID int - } - type OuterType struct { - Inner InnerType - Name string - } - - workflow5 := func(ctx DBOSContext, input OuterType) (OuterType, error) { - return input, nil - } - workflow6 := func(ctx DBOSContext, input *OuterType) (*OuterType, error) { - return input, nil - } - - RegisterWorkflow(freshCtx, workflow5) - RegisterWorkflow(freshCtx, workflow6) - - // Test 4: Slice and map types - workflow7 := func(ctx DBOSContext, input []TestType) ([]TestType, error) { - return input, nil - } - workflow8 := func(ctx DBOSContext, input []*TestType) ([]*TestType, error) { - return input, nil - } - workflow9 := func(ctx DBOSContext, input map[string]TestType) (map[string]TestType, error) { - return input, nil - } - workflow10 := func(ctx DBOSContext, input map[string]*TestType) (map[string]*TestType, error) { - return input, nil - } - - RegisterWorkflow(freshCtx, workflow7) - RegisterWorkflow(freshCtx, workflow8) - RegisterWorkflow(freshCtx, workflow9) - RegisterWorkflow(freshCtx, workflow10) - - // Launch and verify the system still works - err := Launch(freshCtx) - require.NoError(t, err, "failed to launch DBOS after gob conflict handling") - defer Shutdown(freshCtx, 10*time.Second) - - // Test all registered workflows to ensure they work correctly - - // Run workflow1 with value type - testValue := TestType{Value: "test"} - handle1, err := RunWorkflow(freshCtx, workflow1, testValue) - require.NoError(t, err, "failed to run workflow1") - result1, err := handle1.GetResult() - require.NoError(t, err, "failed to get result from workflow1") - assert.Equal(t, testValue, result1, "unexpected result from workflow1") - - // Run workflow2 with pointer type - testPointer := &TestType{Value: "pointer"} - handle2, err := RunWorkflow(freshCtx, workflow2, testPointer) - require.NoError(t, err, "failed to run workflow2") - 
result2, err := handle2.GetResult() - require.NoError(t, err, "failed to get result from workflow2") - assert.Equal(t, testPointer, result2, "unexpected result from workflow2") - - // Run workflow3 with modified output - handle3, err := RunWorkflow(freshCtx, workflow3, testValue) - require.NoError(t, err, "failed to run workflow3") - result3, err := handle3.GetResult() - require.NoError(t, err, "failed to get result from workflow3") - assert.Equal(t, TestType{Value: "test-modified"}, result3, "unexpected result from workflow3") - - // Run workflow5 with nested struct - testOuter := OuterType{Inner: InnerType{ID: 42}, Name: "test"} - handle5, err := RunWorkflow(freshCtx, workflow5, testOuter) - require.NoError(t, err, "failed to run workflow5") - result5, err := handle5.GetResult() - require.NoError(t, err, "failed to get result from workflow5") - assert.Equal(t, testOuter, result5, "unexpected result from workflow5") - - // Run workflow6 with nested struct pointer - testOuterPtr := &OuterType{Inner: InnerType{ID: 43}, Name: "test-ptr"} - handle6, err := RunWorkflow(freshCtx, workflow6, testOuterPtr) - require.NoError(t, err, "failed to run workflow6") - result6, err := handle6.GetResult() - require.NoError(t, err, "failed to get result from workflow6") - assert.Equal(t, testOuterPtr, result6, "unexpected result from workflow6") - - // Run workflow7 with slice type - testSlice := []TestType{{Value: "a"}, {Value: "b"}} - handle7, err := RunWorkflow(freshCtx, workflow7, testSlice) - require.NoError(t, err, "failed to run workflow7") - result7, err := handle7.GetResult() - require.NoError(t, err, "failed to get result from workflow7") - assert.Equal(t, testSlice, result7, "unexpected result from workflow7") - - // Run workflow8 with pointer slice type - testPtrSlice := []*TestType{{Value: "a"}, {Value: "b"}} - handle8, err := RunWorkflow(freshCtx, workflow8, testPtrSlice) - require.NoError(t, err, "failed to run workflow8") - result8, err := handle8.GetResult() - require.NoError(t, err, "failed to get result from workflow8") - assert.Equal(t, testPtrSlice, result8, "unexpected result from workflow8") - - // Run workflow9 with map type - testMap := map[string]TestType{"key1": {Value: "value1"}} - handle9, err := RunWorkflow(freshCtx, workflow9, testMap) - require.NoError(t, err, "failed to run workflow9") - result9, err := handle9.GetResult() - require.NoError(t, err, "failed to get result from workflow9") - assert.Equal(t, testMap, result9, "unexpected result from workflow9") - - // Run workflow10 with pointer map type - testPtrMap := map[string]*TestType{"key1": {Value: "value1"}} - handle10, err := RunWorkflow(freshCtx, workflow10, testPtrMap) - require.NoError(t, err, "failed to run workflow10") - result10, err := handle10.GetResult() - require.NoError(t, err, "failed to get result from workflow10") - assert.Equal(t, testPtrMap, result10, "unexpected result from workflow10") - - t.Run("validPanic", func(t *testing.T) { - // Verify that non-duplicate registration panics are still propagated - workflow11 := func(ctx DBOSContext, input any) (any, error) { - return input, nil - } - - // This should panic during registration because interface{} creates a nil value - // which gob.Register cannot handle - defer func() { - r := recover() - require.NotNil(t, r, "expected panic from interface{} registration but got none") - // Verify it's not a duplicate registration error (which would be caught) - if errStr, ok := r.(string); ok { - assert.False(t, strings.Contains(errStr, "gob: registering 
duplicate"), - "panic should not be a duplicate registration error, got: %v", r) - } - }() - RegisterWorkflow(freshCtx, workflow11) // This should panic - }) - }) } func stepWithinAStep(ctx context.Context) (string, error) { @@ -827,7 +655,6 @@ func TestSteps(t *testing.T) { }) t.Run("stepsOutputEncoding", func(t *testing.T) { - // Execute the workflow handle, err := RunWorkflow(dbosCtx, userObjectWorkflow, "TestObject") require.NoError(t, err, "failed to run workflow with user-defined objects") @@ -1136,7 +963,6 @@ func TestChildWorkflow(t *testing.T) { require.NoError(t, err, "failed to launch DBOS") t.Run("ChildWorkflowIDGeneration", func(t *testing.T) { - r := 3 h, err := RunWorkflow(dbosCtx, grandParentWf, r) require.NoError(t, err, "failed to execute grand parent workflow") @@ -1282,10 +1108,11 @@ func TestWorkflowRecovery(t *testing.T) { dbosCtx := setupDBOS(t, true, true) var ( - recoveryCounters []int64 - recoveryEvents []*Event - blockingEvents []*Event - secondStepErrors []error + recoveryCounters []int64 + recoveryEvents []*Event + blockingEvents []*Event + secondStepErrors []error + secondStepErrorsMu sync.Mutex ) recoveryWorkflow := func(dbosCtx DBOSContext, index int) (int64, error) { @@ -1307,7 +1134,9 @@ func TestWorkflowRecovery(t *testing.T) { return fmt.Sprintf("completed-%d", index), nil }, WithStepName(fmt.Sprintf("BlockingStep-%d", index))) if err != nil { + secondStepErrorsMu.Lock() secondStepErrors = append(secondStepErrors, err) + secondStepErrorsMu.Unlock() return 0, err } @@ -1431,8 +1260,15 @@ func TestWorkflowRecovery(t *testing.T) { // At least 5 of the 2nd steps should have errored due to execution race // Check they are DBOSErrors with StepExecutionError wrapping a ConflictingIDError - require.GreaterOrEqual(t, len(secondStepErrors), 5, "expected at least 5 errors from second steps due to recovery race, got %d", len(secondStepErrors)) - for _, err := range secondStepErrors { + var errorsCopy []error + require.Eventually(t, func() bool { + secondStepErrorsMu.Lock() + errorsCopy := make([]error, len(secondStepErrors)) + copy(errorsCopy, secondStepErrors) + secondStepErrorsMu.Unlock() + return len(errorsCopy) >= 5 + }, 10*time.Second, 100*time.Millisecond) + for _, err := range errorsCopy { dbosErr, ok := err.(*DBOSError) require.True(t, ok, "expected error to be of type *DBOSError, got %T", err) require.Equal(t, StepExecutionError, dbosErr.Code, "expected error code to be StepExecutionError, got %v", dbosErr.Code) @@ -1658,6 +1494,7 @@ var ( sendIdempotencyEvent = NewEvent() receiveIdempotencyStartEvent = NewEvent() receiveIdempotencyStopEvent = NewEvent() + sendRecvSyncEvent = NewEvent() // Event to synchronize send/recv in tests numConcurrentRecvWfs = 5 concurrentRecvReadyEvents = make([]*Event, numConcurrentRecvWfs) concurrentRecvStartEvent = NewEvent() @@ -1685,6 +1522,9 @@ func sendWorkflow(ctx DBOSContext, input sendWorkflowInput) (string, error) { } func receiveWorkflow(ctx DBOSContext, topic string) (string, error) { + // Wait for the test to signal it's ready + sendRecvSyncEvent.Wait() + msg1, err := Recv[string](ctx, topic, 2*time.Second) if err != nil { return "", err @@ -1818,27 +1658,34 @@ func TestSendRecv(t *testing.T) { Launch(dbosCtx) t.Run("SendRecvSuccess", func(t *testing.T) { - // Start the receive workflow + // Clear the sync event before starting + sendRecvSyncEvent.Clear() + + // Start the receive workflow - it will wait for sendRecvSyncEvent before calling Recv receiveHandle, err := RunWorkflow(dbosCtx, receiveWorkflow, "test-topic") 
 		require.NoError(t, err, "failed to start receive workflow")
-		time.Sleep(500 * time.Millisecond) // Ensure receive workflow is waiting so we don't miss the notification
-
-		// Send a message to the receive workflow
-		handle, err := RunWorkflow(dbosCtx, sendWorkflow, sendWorkflowInput{
+		// Send messages to the receive workflow
+		sendHandle, err := RunWorkflow(dbosCtx, sendWorkflow, sendWorkflowInput{
 			DestinationID: receiveHandle.GetWorkflowID(),
 			Topic:         "test-topic",
 		})
 		require.NoError(t, err, "failed to send message")
-		_, err = handle.GetResult()
+
+		// Wait for send workflow to complete
+		_, err = sendHandle.GetResult()
 		require.NoError(t, err, "failed to get result from send workflow")
 
+		// Now that the send workflow has completed, signal the receive workflow to proceed
+		sendRecvSyncEvent.Set()
+
+		// Wait for receive workflow to complete
 		result, err := receiveHandle.GetResult()
 		require.NoError(t, err, "failed to get result from receive workflow")
 		require.Equal(t, "message1-message2-message3", result)
 
 		// Verify step counting for send workflow (sendWorkflow calls Send 3 times)
-		sendSteps, err := GetWorkflowSteps(dbosCtx, handle.GetWorkflowID())
+		sendSteps, err := GetWorkflowSteps(dbosCtx, sendHandle.GetWorkflowID())
 		require.NoError(t, err, "failed to get workflow steps for send workflow")
 		require.Len(t, sendSteps, 3, "expected 3 steps in send workflow (3 Send calls), got %d", len(sendSteps))
 		for i, step := range sendSteps {
@@ -1849,12 +1696,10 @@ func TestSendRecv(t *testing.T) {
 		// Verify step counting for receive workflow (receiveWorkflow calls Recv 3 times)
 		receiveSteps, err := GetWorkflowSteps(dbosCtx, receiveHandle.GetWorkflowID())
 		require.NoError(t, err, "failed to get workflow steps for receive workflow")
-		require.Len(t, receiveSteps, 4, "expected 4 steps in receive workflow (3 Recv calls + 1 sleep call during the first recv), got %d", len(receiveSteps))
-		// Steps 0, 2 and 4 are recv
+		require.Len(t, receiveSteps, 3, "expected 3 steps in receive workflow (3 Recv calls), got %d", len(receiveSteps))
 		require.Equal(t, "DBOS.recv", receiveSteps[0].StepName, "expected step 0 to have StepName 'DBOS.recv'")
-		require.Equal(t, "DBOS.sleep", receiveSteps[1].StepName, "expected step 1 to have StepName 'DBOS.sleep'")
+		require.Equal(t, "DBOS.recv", receiveSteps[1].StepName, "expected step 1 to have StepName 'DBOS.recv'")
 		require.Equal(t, "DBOS.recv", receiveSteps[2].StepName, "expected step 2 to have StepName 'DBOS.recv'")
-		require.Equal(t, "DBOS.recv", receiveSteps[3].StepName, "expected step 3 to have StepName 'DBOS.recv'")
 	})
 
 	t.Run("SendRecvCustomStruct", func(t *testing.T) {
@@ -1921,12 +1766,25 @@ func TestSendRecv(t *testing.T) {
 	})
 
 	t.Run("RecvTimeout", func(t *testing.T) {
+		// Set the event so the receive workflow can proceed immediately
+		sendRecvSyncEvent.Set()
+
 		// Create a receive workflow that tries to receive a message but no send happens
 		receiveHandle, err := RunWorkflow(dbosCtx, receiveWorkflow, "timeout-test-topic")
 		require.NoError(t, err, "failed to start receive workflow")
 		result, err := receiveHandle.GetResult()
 		require.NoError(t, err, "expected no error on timeout")
 		assert.Equal(t, "--", result, "expected -- result on timeout")
+		// Check that six steps were recorded: recv, sleep, recv, sleep, recv, sleep
+		steps, err := GetWorkflowSteps(dbosCtx, receiveHandle.GetWorkflowID())
+		require.NoError(t, err, "failed to get workflow steps")
+		require.Len(t, steps, 6, "expected 6 steps in receive workflow, got %d", len(steps))
+		require.Equal(t, "DBOS.recv", steps[0].StepName, "expected step 0 to have StepName 'DBOS.recv'")
+		require.Equal(t, "DBOS.sleep", steps[1].StepName, "expected step 1 to have StepName 'DBOS.sleep'")
+		require.Equal(t, "DBOS.recv", steps[2].StepName, "expected step 2 to have StepName 'DBOS.recv'")
+		require.Equal(t, "DBOS.sleep", steps[3].StepName, "expected step 3 to have StepName 'DBOS.sleep'")
+		require.Equal(t, "DBOS.recv", steps[4].StepName, "expected step 4 to have StepName 'DBOS.recv'")
+		require.Equal(t, "DBOS.sleep", steps[5].StepName, "expected step 5 to have StepName 'DBOS.sleep'")
 	})
 
 	t.Run("RecvMustRunInsideWorkflows", func(t *testing.T) {
@@ -1945,6 +1803,9 @@ func TestSendRecv(t *testing.T) {
 	})
 
 	t.Run("SendOutsideWorkflow", func(t *testing.T) {
+		// Set the event so the receive workflow can proceed immediately
+		sendRecvSyncEvent.Set()
+
 		// Start a receive workflow to have a valid destination
 		receiveHandle, err := RunWorkflow(dbosCtx, receiveWorkflow, "outside-workflow-topic")
 		require.NoError(t, err, "failed to start receive workflow")
@@ -2019,6 +1880,9 @@ func TestSendRecv(t *testing.T) {
 	})
 
 	t.Run("SendCannotBeCalledWithinStep", func(t *testing.T) {
+		// Set the event so the receive workflow can proceed immediately
+		sendRecvSyncEvent.Set()
+
 		// Start a receive workflow to have a valid destination
 		receiveHandle, err := RunWorkflow(dbosCtx, receiveWorkflow, "send-within-step-topic")
 		require.NoError(t, err, "failed to start receive workflow")