diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 10cb2e1ea..53d6dd638 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -20,7 +20,6 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/discovery" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server" - "github.com/stacklok/toolhive/pkg/workloads" ) var rootCmd = &cobra.Command{ @@ -227,32 +226,22 @@ func discoverBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, } // Initialize managers for backend discovery - logger.Info("Initializing workload and group managers") - workloadsManager, err := workloads.NewManager(ctx) + logger.Info("Initializing group manager") + groupsManager, err := groups.NewManager() if err != nil { - logger.Warnf("Failed to create workloads manager (expected in Kubernetes): %v", err) - logger.Warnf("Backend discovery will be skipped - continuing with empty backend list") - return []vmcp.Backend{}, backendClient, nil + return nil, nil, fmt.Errorf("failed to create groups manager: %w", err) } - groupsManager, err := groups.NewManager() + // Create backend discoverer based on runtime environment + discoverer, err := aggregator.NewBackendDiscoverer(ctx, groupsManager, cfg.OutgoingAuth) if err != nil { - logger.Warnf("Failed to create groups manager (expected in Kubernetes): %v", err) - logger.Warnf("Backend discovery will be skipped - continuing with empty backend list") - return []vmcp.Backend{}, backendClient, nil + return nil, nil, fmt.Errorf("failed to create backend discoverer: %w", err) } - // Create backend discoverer and discover backends - discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth) - logger.Infof("Discovering backends in group: %s", cfg.Group) backends, err := discoverer.Discover(ctx, cfg.Group) if err != nil { - // Handle discovery errors gracefully - this is expected in Kubernetes - logger.Warnf("CLI backend discovery failed (likely running in Kubernetes): %v", err) - logger.Warnf("Kubernetes backend discovery is not yet implemented - continuing with empty backend list") - logger.Warnf("The vmcp server will start but won't proxy any backends until this feature is implemented") - return []vmcp.Backend{}, backendClient, nil + return nil, nil, fmt.Errorf("failed to discover backends: %w", err) } if len(backends) == 0 { diff --git a/pkg/vmcp/aggregator/cli_discoverer.go b/pkg/vmcp/aggregator/cli_discoverer.go deleted file mode 100644 index c577a93be..000000000 --- a/pkg/vmcp/aggregator/cli_discoverer.go +++ /dev/null @@ -1,153 +0,0 @@ -package aggregator - -import ( - "context" - "fmt" - - rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/pkg/workloads" -) - -// cliBackendDiscoverer discovers backend MCP servers from Docker/Podman workloads in a group. -// This is the CLI version of BackendDiscoverer that uses the workloads.Manager. -type cliBackendDiscoverer struct { - workloadsManager workloads.Manager - groupsManager groups.Manager - authConfig *config.OutgoingAuthConfig -} - -// NewCLIBackendDiscoverer creates a new CLI-based backend discoverer. -// It discovers workloads from Docker/Podman containers managed by ToolHive. -// -// The authConfig parameter configures authentication for discovered backends. 
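The net effect of the commands.go hunk above is easier to see without the interleaved removals: discovery failures now propagate to the caller instead of being downgraded to an empty backend list. The sketch below is illustrative only and not part of this change set; it assumes the import paths implied by the diff (pkg/vmcp/config for Config, plus the groups and aggregator packages) and omits the backend client that the real discoverBackends also returns.

package app // illustrative sketch only; not part of this change set

import (
	"context"
	"fmt"

	"github.com/stacklok/toolhive/pkg/groups"
	"github.com/stacklok/toolhive/pkg/vmcp"
	"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
	"github.com/stacklok/toolhive/pkg/vmcp/config"
)

// discoverGroupBackends condenses the new fail-fast flow of discoverBackends:
// create the groups manager, build the environment-appropriate discoverer,
// then discover backends in the configured group, returning every error.
func discoverGroupBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, error) {
	groupsManager, err := groups.NewManager()
	if err != nil {
		return nil, fmt.Errorf("failed to create groups manager: %w", err)
	}

	// Selects the CLI (Docker/Podman) or Kubernetes workload discoverer internally.
	discoverer, err := aggregator.NewBackendDiscoverer(ctx, groupsManager, cfg.OutgoingAuth)
	if err != nil {
		return nil, fmt.Errorf("failed to create backend discoverer: %w", err)
	}

	return discoverer.Discover(ctx, cfg.Group)
}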
-// If nil, backends will have no authentication configured. -func NewCLIBackendDiscoverer( - workloadsManager workloads.Manager, - groupsManager groups.Manager, - authConfig *config.OutgoingAuthConfig, -) BackendDiscoverer { - return &cliBackendDiscoverer{ - workloadsManager: workloadsManager, - groupsManager: groupsManager, - authConfig: authConfig, - } -} - -// Discover finds all backend workloads in the specified group. -// Returns all accessible backends with their health status marked based on workload status. -// The groupRef is the group name (e.g., "engineering-team"). -func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { - logger.Infof("Discovering backends in group %s", groupRef) - - // Verify that the group exists - exists, err := d.groupsManager.Exists(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to check if group exists: %w", err) - } - if !exists { - return nil, fmt.Errorf("group %s not found", groupRef) - } - - // Get all workload names in the group - workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to list workloads in group: %w", err) - } - - if len(workloadNames) == 0 { - logger.Infof("No workloads found in group %s", groupRef) - return []vmcp.Backend{}, nil - } - - logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) - - // Query each workload and convert to backend - var backends []vmcp.Backend - for _, name := range workloadNames { - workload, err := d.workloadsManager.GetWorkload(ctx, name) - if err != nil { - logger.Warnf("Failed to get workload %s: %v, skipping", name, err) - continue - } - - // Skip workloads without a URL (not accessible) - if workload.URL == "" { - logger.Debugf("Skipping workload %s without URL", name) - continue - } - - // Map workload status to backend health status - healthStatus := mapWorkloadStatusToHealth(workload.Status) - - // Convert core.Workload to vmcp.Backend - // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. - // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. - // ProxyMode tells us which transport the vmcp client should use. 
- transportType := workload.ProxyMode - if transportType == "" { - // Fallback to TransportType if ProxyMode is not set (for direct transports) - transportType = workload.TransportType.String() - } - - backend := vmcp.Backend{ - ID: name, - Name: name, - BaseURL: workload.URL, - TransportType: transportType, - HealthStatus: healthStatus, - Metadata: make(map[string]string), - } - - // Apply authentication configuration if provided - authStrategy, authMetadata := d.authConfig.ResolveForBackend(name) - backend.AuthStrategy = authStrategy - backend.AuthMetadata = authMetadata - if authStrategy != "" { - logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) - } - - // Copy user labels to metadata first - for k, v := range workload.Labels { - backend.Metadata[k] = v - } - - // Set system metadata (these override user labels to prevent conflicts) - backend.Metadata["group"] = groupRef - backend.Metadata["tool_type"] = workload.ToolType - backend.Metadata["workload_status"] = string(workload.Status) - - backends = append(backends, backend) - logger.Debugf("Discovered backend %s: %s (%s) with health status %s", - backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) - } - - if len(backends) == 0 { - logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) - return []vmcp.Backend{}, nil - } - - logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) - return backends, nil -} - -// mapWorkloadStatusToHealth converts a workload status to a backend health status. -func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { - switch status { - case rt.WorkloadStatusRunning: - return vmcp.BackendHealthy - case rt.WorkloadStatusUnhealthy: - return vmcp.BackendUnhealthy - case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStopping, rt.WorkloadStatusRemoving: - return vmcp.BackendUnhealthy - case rt.WorkloadStatusStarting, rt.WorkloadStatusUnknown: - return vmcp.BackendUnknown - case rt.WorkloadStatusUnauthenticated: - return vmcp.BackendUnauthenticated - default: - return vmcp.BackendUnknown - } -} diff --git a/pkg/vmcp/aggregator/cli_discoverer_test.go b/pkg/vmcp/aggregator/cli_discoverer_test.go deleted file mode 100644 index 9c3402fad..000000000 --- a/pkg/vmcp/aggregator/cli_discoverer_test.go +++ /dev/null @@ -1,250 +0,0 @@ -package aggregator - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/groups/mocks" - "github.com/stacklok/toolhive/pkg/transport/types" - "github.com/stacklok/toolhive/pkg/vmcp" - workloadmocks "github.com/stacklok/toolhive/pkg/workloads/mocks" -) - -const testGroupName = "test-group" - -func TestCLIBackendDiscoverer_Discover(t *testing.T) { - t.Parallel() - - t.Run("successful discovery with multiple backends", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workload1 := newTestWorkload("workload1", - withToolType("github"), - withLabels(map[string]string{"env": "prod"})) - - workload2 := newTestWorkload("workload2", - withURL("http://localhost:8081/mcp"), - withTransport(types.TransportTypeSSE), - withToolType("jira")) - - 
mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"workload1", "workload2"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 2) - assert.Equal(t, "workload1", backends[0].ID) - assert.Equal(t, "http://localhost:8080/mcp", backends[0].BaseURL) - assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) - assert.Equal(t, "github", backends[0].Metadata["tool_type"]) - assert.Equal(t, "prod", backends[0].Metadata["env"]) - assert.Equal(t, "workload2", backends[1].ID) - assert.Equal(t, "sse", backends[1].TransportType) - }) - - t.Run("discovers workloads with different statuses", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - runningWorkload := newTestWorkload("running-workload") - stoppedWorkload := newTestWorkload("stopped-workload", - withStatus(runtime.WorkloadStatusStopped), - withURL("http://localhost:8081/mcp"), - withTransport(types.TransportTypeSSE)) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"running-workload", "stopped-workload"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 2) - assert.Equal(t, "running-workload", backends[0].ID) - assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) - assert.Equal(t, "stopped-workload", backends[1].ID) - assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) - assert.Equal(t, "stopped", backends[1].Metadata["workload_status"]) - }) - - t.Run("filters out workloads without URL", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workloadWithURL := newTestWorkload("workload1") - workloadWithoutURL := newTestWorkload("workload2", withURL("")) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return([]string{"workload1", "workload2"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 1) - assert.Equal(t, "workload1", backends[0].ID) - }) - - t.Run("returns empty list when all workloads lack URLs", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workload1 := newTestWorkload("workload1", withURL("")) - workload2 := newTestWorkload("workload2", withStatus(runtime.WorkloadStatusStopped), withURL("")) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"workload1", "workload2"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - assert.Empty(t, backends) - }) - - t.Run("returns error when group does not exist", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), "nonexistent-group") - - require.Error(t, err) - assert.Nil(t, backends) - assert.Contains(t, err.Error(), "not found") - }) - - t.Run("returns error when group check fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.Error(t, err) - assert.Nil(t, backends) - assert.Contains(t, err.Error(), "failed to check if group exists") - }) - - t.Run("returns empty list when group is empty", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), "empty-group") - - require.NoError(t, err) - assert.Empty(t, backends) - }) - - t.Run("discovers all workloads regardless of health status", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - 
mockGroups := mocks.NewMockManager(ctrl) - - stoppedWorkload := newTestWorkload("stopped1", withStatus(runtime.WorkloadStatusStopped)) - errorWorkload := newTestWorkload("error1", - withStatus(runtime.WorkloadStatusError), - withURL("http://localhost:8081/mcp"), - withTransport(types.TransportTypeSSE)) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"stopped1", "error1"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 2) - assert.Equal(t, vmcp.BackendUnhealthy, backends[0].HealthStatus) - assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) - }) - - t.Run("gracefully handles workload get failures", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - goodWorkload := newTestWorkload("good-workload") - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"good-workload", "failing-workload"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "good-workload").Return(goodWorkload, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). - Return(core.Workload{}, errors.New("workload query failed")) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 1) - assert.Equal(t, "good-workload", backends[0].ID) - }) -} diff --git a/pkg/vmcp/aggregator/discoverer.go b/pkg/vmcp/aggregator/discoverer.go index b4cff44d4..4286eae0b 100644 --- a/pkg/vmcp/aggregator/discoverer.go +++ b/pkg/vmcp/aggregator/discoverer.go @@ -1,8 +1,167 @@ // Package aggregator provides platform-specific backend discovery implementations. // -// This file serves as a navigation reference for backend discovery implementations: -// - CLI (Docker/Podman): see cli_discoverer.go -// - Kubernetes: see k8s_discoverer.go +// This file contains: +// - Unified backend discoverer implementation (works with both CLI and Kubernetes) +// - Factory function to create BackendDiscoverer based on runtime environment +// - WorkloadDiscoverer interface and implementations are in pkg/vmcp/workloads // // The BackendDiscoverer interface is defined in aggregator.go. package aggregator + +import ( + "context" + "fmt" + + rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/groups" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/vmcp/workloads" + workloadsmgr "github.com/stacklok/toolhive/pkg/workloads" +) + +// backendDiscoverer discovers backend MCP servers using a WorkloadDiscoverer. +// This is a unified discoverer that works with both CLI and Kubernetes workloads. 
+type backendDiscoverer struct { + workloadsManager workloads.Discoverer + groupsManager groups.Manager + authConfig *config.OutgoingAuthConfig +} + +// NewUnifiedBackendDiscoverer creates a unified backend discoverer that works with both +// CLI and Kubernetes workloads through the WorkloadDiscoverer interface. +// +// The authConfig parameter configures authentication for discovered backends. +// If nil, backends will have no authentication configured. +func NewUnifiedBackendDiscoverer( + workloadsManager workloads.Discoverer, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { + return &backendDiscoverer{ + workloadsManager: workloadsManager, + groupsManager: groupsManager, + authConfig: authConfig, + } +} + +// NewBackendDiscoverer creates a unified BackendDiscoverer based on the runtime environment. +// It automatically detects whether to use CLI (Docker/Podman) or Kubernetes workloads +// and creates the appropriate WorkloadDiscoverer implementation. +// +// Parameters: +// - ctx: Context for creating managers +// - groupsManager: Manager for group operations (must already be initialized) +// - authConfig: Outgoing authentication configuration for discovered backends +// +// Returns: +// - BackendDiscoverer: A unified discoverer that works with both CLI and Kubernetes workloads +// - error: If manager creation fails +func NewBackendDiscoverer( + ctx context.Context, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) (BackendDiscoverer, error) { + var workloadDiscoverer workloads.Discoverer + + if rt.IsKubernetesRuntime() { + k8sDiscoverer, err := workloads.NewK8SDiscoverer() + if err != nil { + return nil, fmt.Errorf("failed to create Kubernetes workload discoverer: %w", err) + } + workloadDiscoverer = k8sDiscoverer + } else { + manager, err := workloadsmgr.NewManager(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create workload manager: %w", err) + } + workloadDiscoverer = manager + } + return NewUnifiedBackendDiscoverer(workloadDiscoverer, groupsManager, authConfig), nil +} + +// NewBackendDiscovererWithManager creates a unified BackendDiscoverer with a pre-configured +// WorkloadDiscoverer. This is useful for testing or when you already have a workload manager. +func NewBackendDiscovererWithManager( + workloadManager workloads.Discoverer, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { + return NewUnifiedBackendDiscoverer(workloadManager, groupsManager, authConfig) +} + +// Discover finds all backend workloads in the specified group. +// Returns all accessible backends with their health status marked based on workload status. +// The groupRef is the group name (e.g., "engineering-team"). 
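Besides the environment-detecting factory, the diff keeps NewUnifiedBackendDiscoverer and NewBackendDiscovererWithManager for callers that already hold a workload manager. A hedged usage sketch, assuming only the constructors and types declared in this change set (the example function itself is hypothetical):

package example // hypothetical package; for illustration only

import (
	"context"

	"github.com/stacklok/toolhive/pkg/groups"
	"github.com/stacklok/toolhive/pkg/vmcp"
	"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
	"github.com/stacklok/toolhive/pkg/vmcp/config"
	workloadsmgr "github.com/stacklok/toolhive/pkg/workloads"
)

// discoverWithExistingManager reuses an already-constructed DefaultManager
// (which, per this change set, satisfies the vmcp workloads.Discoverer
// interface) instead of letting NewBackendDiscoverer build a second one.
func discoverWithExistingManager(
	ctx context.Context,
	mgr *workloadsmgr.DefaultManager,
	grp groups.Manager,
	auth *config.OutgoingAuthConfig,
	group string,
) ([]vmcp.Backend, error) {
	discoverer := aggregator.NewBackendDiscovererWithManager(mgr, grp, auth)
	return discoverer.Discover(ctx, group)
}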
+func (d *backendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { + logger.Infof("Discovering backends in group %s", groupRef) + + // Verify that the group exists + exists, err := d.groupsManager.Exists(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to check if group exists: %w", err) + } + if !exists { + return nil, fmt.Errorf("group %s not found", groupRef) + } + + // Get all workload names in the group + workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to list workloads in group: %w", err) + } + + if len(workloadNames) == 0 { + logger.Infof("No workloads found in group %s", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) + + // Query each workload and convert to backend + var backends []vmcp.Backend + for _, name := range workloadNames { + backend, err := d.workloadsManager.GetWorkloadAsVMCPBackend(ctx, name) + if err != nil { + logger.Warnf("Failed to get workload %s: %v, skipping", name, err) + continue + } + + // Skip workloads that are not accessible (GetWorkload returns nil) + if backend == nil { + continue + } + + // Apply authentication configuration if provided + var authStrategy string + var authMetadata map[string]any + if d.authConfig != nil { + authStrategy, authMetadata = d.authConfig.ResolveForBackend(name) + backend.AuthStrategy = authStrategy + backend.AuthMetadata = authMetadata + if authStrategy != "" { + logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) + } + } + + // Set group metadata (override user labels to prevent conflicts) + if backend.Metadata == nil { + backend.Metadata = make(map[string]string) + } + backend.Metadata["group"] = groupRef + + logger.Debugf("Discovered backend %s: %s (%s) with health status %s", + backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) + + backends = append(backends, *backend) + } + + if len(backends) == 0 { + logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) + return backends, nil +} diff --git a/pkg/vmcp/aggregator/discoverer_test.go b/pkg/vmcp/aggregator/discoverer_test.go new file mode 100644 index 000000000..70311fa2a --- /dev/null +++ b/pkg/vmcp/aggregator/discoverer_test.go @@ -0,0 +1,410 @@ +package aggregator + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/groups/mocks" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" + discoverermocks "github.com/stacklok/toolhive/pkg/vmcp/workloads/mocks" +) + +const testGroupName = "test-group" + +func TestBackendDiscoverer_Discover(t *testing.T) { + t.Parallel() + + t.Run("successful discovery with multiple backends", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + backend1 := &vmcp.Backend{ + ID: "workload1", + Name: "workload1", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{ + "tool_type": "github", + 
"workload_status": "running", + "env": "prod", + }, + } + backend2 := &vmcp.Backend{ + ID: "workload2", + Name: "workload2", + BaseURL: "http://localhost:8081/mcp", + TransportType: "sse", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{ + "tool_type": "jira", + "workload_status": "running", + }, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"workload1", "workload2"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(backend1, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload2").Return(backend2, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, "workload1", backends[0].ID) + assert.Equal(t, "http://localhost:8080/mcp", backends[0].BaseURL) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "github", backends[0].Metadata["tool_type"]) + assert.Equal(t, "prod", backends[0].Metadata["env"]) + assert.Equal(t, "workload2", backends[1].ID) + assert.Equal(t, "sse", backends[1].TransportType) + }) + + t.Run("discovers workloads with different health statuses", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + healthyBackend := &vmcp.Backend{ + ID: "healthy-workload", + Name: "healthy-workload", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{"workload_status": "running"}, + } + unhealthyBackend := &vmcp.Backend{ + ID: "unhealthy-workload", + Name: "unhealthy-workload", + BaseURL: "http://localhost:8081/mcp", + TransportType: "sse", + HealthStatus: vmcp.BackendUnhealthy, + Metadata: map[string]string{"workload_status": "stopped"}, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"healthy-workload", "unhealthy-workload"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "healthy-workload").Return(healthyBackend, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "unhealthy-workload").Return(unhealthyBackend, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, "healthy-workload", backends[0].ID) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "unhealthy-workload", backends[1].ID) + assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) + assert.Equal(t, "stopped", backends[1].Metadata["workload_status"]) + }) + + t.Run("filters out workloads without URL", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + backendWithURL := &vmcp.Backend{ + ID: "workload1", + Name: "workload1", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{}, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"workload1", "workload2"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(backendWithURL, nil) + // workload2 has no URL, so GetWorkload returns nil + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload2").Return(nil, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "workload1", backends[0].ID) + }) + + t.Run("returns empty list when all workloads lack URLs", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"workload1", "workload2"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(nil, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload2").Return(nil, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + assert.Empty(t, backends) + }) + + t.Run("returns error when group does not exist", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), "nonexistent-group") + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("returns error when group check fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "failed to check if group exists") + }) + + t.Run("returns empty list when group is empty", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), "empty-group") + + require.NoError(t, err) + assert.Empty(t, backends) + }) + + t.Run("gracefully handles workload get failures", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + goodBackend := &vmcp.Backend{ + ID: "good-workload", + Name: "good-workload", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{}, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"good-workload", "failing-workload"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "good-workload").Return(goodBackend, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "failing-workload"). 
+ Return(nil, errors.New("workload query failed")) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "good-workload", backends[0].ID) + }) + + t.Run("returns error when list workloads fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return(nil, errors.New("failed to list workloads")) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "failed to list workloads in group") + }) + + t.Run("applies authentication configuration", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + backend := &vmcp.Backend{ + ID: "workload1", + Name: "workload1", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{}, + } + + authConfig := &config.OutgoingAuthConfig{ + Backends: map[string]*config.BackendAuthStrategy{ + "workload1": { + Type: "bearer", + Metadata: map[string]any{ + "token": "test-token", + }, + }, + }, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"workload1"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(backend, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, authConfig) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "bearer", backends[0].AuthStrategy) + assert.Equal(t, "test-token", backends[0].AuthMetadata["token"]) + }) +} + +// TestCLIWorkloadDiscoverer tests the CLI workload discoverer implementation +// to ensure it correctly converts CLI workloads to backends. +func TestCLIWorkloadDiscoverer(t *testing.T) { + t.Parallel() + + t.Run("converts CLI workload to backend correctly", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockManager := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + backend := &vmcp.Backend{ + ID: "workload1", + Name: "workload1", + BaseURL: "http://localhost:8080/mcp", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{"tool_type": "github", "env": "prod"}, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockManager.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"workload1"}, nil) + mockManager.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(backend, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockManager, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "workload1", backends[0].ID) + assert.Equal(t, "http://localhost:8080/mcp", backends[0].BaseURL) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "github", backends[0].Metadata["tool_type"]) + assert.Equal(t, "prod", backends[0].Metadata["env"]) + }) + + t.Run("maps CLI workload statuses to health correctly", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + runningBackend := &vmcp.Backend{ + ID: "running-workload", + Name: "running-workload", + BaseURL: "http://localhost:8080/mcp", + HealthStatus: vmcp.BackendHealthy, + } + stoppedBackend := &vmcp.Backend{ + ID: "stopped-workload", + Name: "stopped-workload", + BaseURL: "http://localhost:8081/mcp", + HealthStatus: vmcp.BackendUnhealthy, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"running-workload", "stopped-workload"}, nil) + // The discoverer iterates through all workloads in order + mockDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "running-workload").Return(runningBackend, nil) + mockDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "stopped-workload").Return(stoppedBackend, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + // Sort backends by name to ensure consistent ordering for assertions + if backends[0].Name > backends[1].Name { + backends[0], backends[1] = backends[1], backends[0] + } + // Find the correct backend by name + var running, stopped *vmcp.Backend + for i := range backends { + if backends[i].Name == "running-workload" { + running = &backends[i] + } + if backends[i].Name == "stopped-workload" { + stopped = &backends[i] + } + } + require.NotNil(t, running, "running-workload should be found") + require.NotNil(t, stopped, "stopped-workload should be found") + assert.Equal(t, vmcp.BackendHealthy, running.HealthStatus) + assert.Equal(t, vmcp.BackendUnhealthy, stopped.HealthStatus) + }) +} diff --git a/pkg/vmcp/aggregator/k8s_discoverer.go b/pkg/vmcp/aggregator/k8s_discoverer.go deleted file mode 100644 index b9f61fbc0..000000000 --- a/pkg/vmcp/aggregator/k8s_discoverer.go +++ /dev/null @@ -1,33 +0,0 @@ -package aggregator - -import ( - "context" - "fmt" - - "github.com/stacklok/toolhive/pkg/vmcp" -) - -// k8sBackendDiscoverer discovers backend MCP servers from Kubernetes pods/services in a group. -// This is the Kubernetes version of BackendDiscoverer (not implemented yet). -type k8sBackendDiscoverer struct { - // TODO: Add Kubernetes client and group CRD interfaces -} - -// NewK8sBackendDiscoverer creates a new Kubernetes-based backend discoverer. -// It discovers workloads from Kubernetes MCPServer resources managed by the operator. 
-func NewK8sBackendDiscoverer() BackendDiscoverer { - return &k8sBackendDiscoverer{} -} - -// Discover finds all backend workloads in the specified Kubernetes group. -// The groupRef is the MCPGroup name. -func (*k8sBackendDiscoverer) Discover(_ context.Context, _ string) ([]vmcp.Backend, error) { - // TODO: Implement Kubernetes backend discovery - // 1. Query MCPGroup CRD by name - // 2. List MCPServer resources with matching group label - // 3. Filter for ready/running MCPServers - // 4. Build service URLs (http://service-name.namespace.svc.cluster.local:port) - // 5. Extract transport type from MCPServer spec - // 6. Return vmcp.Backend list - return nil, fmt.Errorf("kubernetes backend discovery not yet implemented") -} diff --git a/pkg/vmcp/aggregator/testhelpers_test.go b/pkg/vmcp/aggregator/testhelpers_test.go index 0b766c508..6aeafd327 100644 --- a/pkg/vmcp/aggregator/testhelpers_test.go +++ b/pkg/vmcp/aggregator/testhelpers_test.go @@ -1,58 +1,9 @@ package aggregator import ( - "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/vmcp" ) -// Test fixture builders to reduce verbosity in tests - -func newTestWorkload(name string, opts ...func(*core.Workload)) core.Workload { - w := core.Workload{ - Name: name, - Status: runtime.WorkloadStatusRunning, - URL: "http://localhost:8080/mcp", - TransportType: types.TransportTypeStreamableHTTP, - Group: testGroupName, - } - for _, opt := range opts { - opt(&w) - } - return w -} - -func withStatus(status runtime.WorkloadStatus) func(*core.Workload) { - return func(w *core.Workload) { - w.Status = status - } -} - -func withURL(url string) func(*core.Workload) { - return func(w *core.Workload) { - w.URL = url - } -} - -func withTransport(transport types.TransportType) func(*core.Workload) { - return func(w *core.Workload) { - w.TransportType = transport - } -} - -func withToolType(toolType string) func(*core.Workload) { - return func(w *core.Workload) { - w.ToolType = toolType - } -} - -func withLabels(labels map[string]string) func(*core.Workload) { - return func(w *core.Workload) { - w.Labels = labels - } -} - func newTestBackend(id string, opts ...func(*vmcp.Backend)) vmcp.Backend { b := vmcp.Backend{ ID: id, diff --git a/pkg/vmcp/workloads/discoverer.go b/pkg/vmcp/workloads/discoverer.go new file mode 100644 index 000000000..fb11f947e --- /dev/null +++ b/pkg/vmcp/workloads/discoverer.go @@ -0,0 +1,25 @@ +// Package workloads provides the WorkloadDiscoverer interface for discovering +// backend workloads in both CLI and Kubernetes environments. +package workloads + +import ( + "context" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// Discoverer is the interface for workload managers used by vmcp. +// This interface contains only the methods needed for backend discovery, +// allowing both CLI and Kubernetes managers to implement it. +// +//go:generate mockgen -destination=mocks/mock_discoverer.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/workloads Discoverer +type Discoverer interface { + // ListWorkloadsInGroup returns all workload names that belong to the specified group + ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) + + // GetWorkloadAsVMCPBackend retrieves workload details by name and converts it to a vmcp.Backend. 
+ // The returned Backend should have all fields populated except AuthStrategy and AuthMetadata, + // which will be set by the discoverer based on the auth configuration. + // Returns nil if the workload exists but is not accessible (e.g., no URL). + GetWorkloadAsVMCPBackend(ctx context.Context, workloadName string) (*vmcp.Backend, error) +} diff --git a/pkg/vmcp/workloads/k8s.go b/pkg/vmcp/workloads/k8s.go new file mode 100644 index 000000000..90dd2b5c0 --- /dev/null +++ b/pkg/vmcp/workloads/k8s.go @@ -0,0 +1,220 @@ +package workloads + +import ( + "context" + "fmt" + "strings" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/k8s" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/transport" + transporttypes "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// k8sDiscoverer is a direct implementation of Discoverer for Kubernetes workloads. +// It uses the Kubernetes client directly to query MCPServer CRDs instead of going through k8s.Manager. +type k8sDiscoverer struct { + k8sClient client.Client + namespace string +} + +// NewK8SDiscoverer creates a new Kubernetes workload discoverer that directly uses +// the Kubernetes client to discover MCPServer CRDs. +func NewK8SDiscoverer() (Discoverer, error) { + // Create a scheme for controller-runtime client + scheme := runtime.NewScheme() + utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(mcpv1alpha1.AddToScheme(scheme)) + + // Create controller-runtime client + k8sClient, err := k8s.NewControllerRuntimeClient(scheme) + if err != nil { + return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) + } + + // Detect namespace + namespace := k8s.GetCurrentNamespace() + + return &k8sDiscoverer{ + k8sClient: k8sClient, + namespace: namespace, + }, nil +} + +// ListWorkloadsInGroup returns all workload names that belong to the specified group. +func (d *k8sDiscoverer) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { + mcpServerList := &mcpv1alpha1.MCPServerList{} + listOpts := []client.ListOption{ + client.InNamespace(d.namespace), + } + + if err := d.k8sClient.List(ctx, mcpServerList, listOpts...); err != nil { + return nil, fmt.Errorf("failed to list MCPServers: %w", err) + } + + var groupWorkloads []string + for i := range mcpServerList.Items { + mcpServer := &mcpServerList.Items[i] + if mcpServer.Spec.GroupRef == groupName { + groupWorkloads = append(groupWorkloads, mcpServer.Name) + } + } + + return groupWorkloads, nil +} + +// GetWorkloadAsVMCPBackend retrieves workload details by name and converts it to a vmcp.Backend. 
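Since the Discoverer interface has only two methods, other platforms or test fixtures can satisfy it with very little code. The following in-memory implementation is hypothetical and exists only to illustrate the contract, in particular the (nil, nil) return for a workload that exists but has no routable URL:

package staticdiscovery // hypothetical package; for illustration only

import (
	"context"
	"fmt"

	"github.com/stacklok/toolhive/pkg/vmcp"
	vmcpworkloads "github.com/stacklok/toolhive/pkg/vmcp/workloads"
)

// StaticDiscoverer serves a fixed set of backends keyed by group name.
type StaticDiscoverer struct {
	Groups map[string][]vmcp.Backend
}

// Compile-time check that the sketch matches the Discoverer contract.
var _ vmcpworkloads.Discoverer = (*StaticDiscoverer)(nil)

// ListWorkloadsInGroup returns the names of all backends registered for the group.
func (s *StaticDiscoverer) ListWorkloadsInGroup(_ context.Context, groupName string) ([]string, error) {
	names := make([]string, 0, len(s.Groups[groupName]))
	for _, b := range s.Groups[groupName] {
		names = append(names, b.Name)
	}
	return names, nil
}

// GetWorkloadAsVMCPBackend returns the backend by name, nil (without error)
// when the workload exists but has no URL, and an error when it is unknown.
func (s *StaticDiscoverer) GetWorkloadAsVMCPBackend(_ context.Context, workloadName string) (*vmcp.Backend, error) {
	for _, backends := range s.Groups {
		for i := range backends {
			if backends[i].Name != workloadName {
				continue
			}
			if backends[i].BaseURL == "" {
				return nil, nil
			}
			return &backends[i], nil
		}
	}
	return nil, fmt.Errorf("workload %s not found", workloadName)
}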
+func (d *k8sDiscoverer) GetWorkloadAsVMCPBackend(ctx context.Context, workloadName string) (*vmcp.Backend, error) { + mcpServer := &mcpv1alpha1.MCPServer{} + key := client.ObjectKey{Name: workloadName, Namespace: d.namespace} + if err := d.k8sClient.Get(ctx, key, mcpServer); err != nil { + if errors.IsNotFound(err) { + return nil, fmt.Errorf("MCPServer %s not found", workloadName) + } + return nil, fmt.Errorf("failed to get MCPServer: %w", err) + } + + // Convert MCPServer to Backend + backend := d.mcpServerToBackend(mcpServer) + + // Skip workloads without a URL (not accessible) + if backend.BaseURL == "" { + logger.Debugf("Skipping workload %s without URL", workloadName) + return nil, nil + } + + return backend, nil +} + +// mcpServerToBackend converts an MCPServer CRD to a vmcp.Backend. +func (*k8sDiscoverer) mcpServerToBackend(mcpServer *mcpv1alpha1.MCPServer) *vmcp.Backend { + // Parse transport type + transportType, err := transporttypes.ParseTransportType(mcpServer.Spec.Transport) + if err != nil { + logger.Warnf("Failed to parse transport type %s for MCPServer %s: %v", mcpServer.Spec.Transport, mcpServer.Name, err) + transportType = transporttypes.TransportTypeStreamableHTTP + } + + // Calculate effective proxy mode + effectiveProxyMode := getEffectiveProxyMode(transportType, mcpServer.Spec.ProxyMode) + + // Generate URL from status or reconstruct from spec + url := mcpServer.Status.URL + if url == "" { + port := int(mcpServer.Spec.ProxyPort) + if port == 0 { + port = int(mcpServer.Spec.Port) // Fallback to deprecated Port field + } + if port > 0 { + url = transport.GenerateMCPServerURL(mcpServer.Spec.Transport, transport.LocalhostIPv4, port, mcpServer.Name, "") + } + } + + // Map workload phase to backend health status + healthStatus := mapK8SWorkloadPhaseToHealth(mcpServer.Status.Phase) + + // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. + // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. + // ProxyMode tells us which transport the vmcp client should use. + transportTypeStr := effectiveProxyMode + if transportTypeStr == "" { + // Fallback to TransportType if ProxyMode is not set (for direct transports) + transportTypeStr = transportType.String() + if transportTypeStr == "" { + transportTypeStr = "unknown" + } + } + + // Extract user labels from annotations (Kubernetes doesn't have container labels like Docker) + userLabels := make(map[string]string) + if mcpServer.Annotations != nil { + // Filter out standard Kubernetes annotations + for key, value := range mcpServer.Annotations { + if !isStandardK8sAnnotation(key) { + userLabels[key] = value + } + } + } + + backend := &vmcp.Backend{ + ID: mcpServer.Name, + Name: mcpServer.Name, + BaseURL: url, + TransportType: transportTypeStr, + HealthStatus: healthStatus, + Metadata: make(map[string]string), + } + + // Copy user labels to metadata first + for k, v := range userLabels { + backend.Metadata[k] = v + } + + // Set system metadata (these override user labels to prevent conflicts) + backend.Metadata["tool_type"] = "mcp" + backend.Metadata["workload_status"] = string(mcpServer.Status.Phase) + if mcpServer.Namespace != "" { + backend.Metadata["namespace"] = mcpServer.Namespace + } + + return backend +} + +// mapK8SWorkloadPhaseToHealth converts a MCPServerPhase to a backend health status. 
+func mapK8SWorkloadPhaseToHealth(phase mcpv1alpha1.MCPServerPhase) vmcp.BackendHealthStatus { + switch phase { + case mcpv1alpha1.MCPServerPhaseRunning: + return vmcp.BackendHealthy + case mcpv1alpha1.MCPServerPhaseFailed: + return vmcp.BackendUnhealthy + case mcpv1alpha1.MCPServerPhaseTerminating: + return vmcp.BackendUnhealthy + case mcpv1alpha1.MCPServerPhasePending: + return vmcp.BackendUnknown + default: + return vmcp.BackendUnknown + } +} + +// getEffectiveProxyMode calculates the effective proxy mode based on transport type and configured proxy mode. +// This replicates the logic from pkg/workloads/types/proxy_mode.go +func getEffectiveProxyMode(transportType transporttypes.TransportType, configuredProxyMode string) string { + // If proxy mode is explicitly configured, use it + if configuredProxyMode != "" { + return configuredProxyMode + } + + // For stdio transports, default to streamable-http proxy mode + if transportType == transporttypes.TransportTypeStdio { + return transporttypes.ProxyModeStreamableHTTP.String() + } + + // For direct transports (SSE, streamable-http), use the transport type as proxy mode + return transportType.String() +} + +// isStandardK8sAnnotation checks if an annotation key is a standard Kubernetes annotation. +func isStandardK8sAnnotation(key string) bool { + // Common Kubernetes annotation prefixes + standardPrefixes := []string{ + "kubectl.kubernetes.io/", + "kubernetes.io/", + "deployment.kubernetes.io/", + "k8s.io/", + } + + for _, prefix := range standardPrefixes { + if strings.HasPrefix(key, prefix) { + return true + } + } + return false +} diff --git a/pkg/vmcp/workloads/mocks/mock_discoverer.go b/pkg/vmcp/workloads/mocks/mock_discoverer.go new file mode 100644 index 000000000..daed3e8d9 --- /dev/null +++ b/pkg/vmcp/workloads/mocks/mock_discoverer.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/stacklok/toolhive/pkg/vmcp/workloads (interfaces: Discoverer) +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_discoverer.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/workloads Discoverer +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + vmcp "github.com/stacklok/toolhive/pkg/vmcp" + gomock "go.uber.org/mock/gomock" +) + +// MockDiscoverer is a mock of Discoverer interface. +type MockDiscoverer struct { + ctrl *gomock.Controller + recorder *MockDiscovererMockRecorder + isgomock struct{} +} + +// MockDiscovererMockRecorder is the mock recorder for MockDiscoverer. +type MockDiscovererMockRecorder struct { + mock *MockDiscoverer +} + +// NewMockDiscoverer creates a new mock instance. +func NewMockDiscoverer(ctrl *gomock.Controller) *MockDiscoverer { + mock := &MockDiscoverer{ctrl: ctrl} + mock.recorder = &MockDiscovererMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDiscoverer) EXPECT() *MockDiscovererMockRecorder { + return m.recorder +} + +// GetWorkloadAsVMCPBackend mocks base method. +func (m *MockDiscoverer) GetWorkloadAsVMCPBackend(ctx context.Context, workloadName string) (*vmcp.Backend, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWorkloadAsVMCPBackend", ctx, workloadName) + ret0, _ := ret[0].(*vmcp.Backend) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetWorkloadAsVMCPBackend indicates an expected call of GetWorkloadAsVMCPBackend. 
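The two k8s.go helpers introduced above, mapK8SWorkloadPhaseToHealth and getEffectiveProxyMode, are pure functions and easy to pin down with a table test. A hypothetical k8s_test.go sketch for pkg/vmcp/workloads, not part of this change set, using only phase and transport constants referenced elsewhere in the diff:

package workloads // hypothetical k8s_test.go sketch for pkg/vmcp/workloads; not part of this change set

import (
	"testing"

	mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
	transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
	"github.com/stacklok/toolhive/pkg/vmcp"
)

// TestMapK8SWorkloadPhaseToHealth pins the phase-to-health mapping defined above.
func TestMapK8SWorkloadPhaseToHealth(t *testing.T) {
	cases := map[mcpv1alpha1.MCPServerPhase]vmcp.BackendHealthStatus{
		mcpv1alpha1.MCPServerPhaseRunning:     vmcp.BackendHealthy,
		mcpv1alpha1.MCPServerPhaseFailed:      vmcp.BackendUnhealthy,
		mcpv1alpha1.MCPServerPhaseTerminating: vmcp.BackendUnhealthy,
		mcpv1alpha1.MCPServerPhasePending:     vmcp.BackendUnknown,
	}
	for phase, want := range cases {
		if got := mapK8SWorkloadPhaseToHealth(phase); got != want {
			t.Errorf("phase %v: got %v, want %v", phase, got, want)
		}
	}
}

// TestGetEffectiveProxyMode covers the three branches: explicit proxy mode wins,
// stdio falls back to streamable-http, and direct transports use their own name.
func TestGetEffectiveProxyMode(t *testing.T) {
	if got := getEffectiveProxyMode(transporttypes.TransportTypeStdio, "sse"); got != "sse" {
		t.Errorf("explicit mode: got %v, want sse", got)
	}
	want := transporttypes.ProxyModeStreamableHTTP.String()
	if got := getEffectiveProxyMode(transporttypes.TransportTypeStdio, ""); got != want {
		t.Errorf("stdio fallback: got %v, want %v", got, want)
	}
	want = transporttypes.TransportTypeSSE.String()
	if got := getEffectiveProxyMode(transporttypes.TransportTypeSSE, ""); got != want {
		t.Errorf("direct transport: got %v, want %v", got, want)
	}
}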
+func (mr *MockDiscovererMockRecorder) GetWorkloadAsVMCPBackend(ctx, workloadName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkloadAsVMCPBackend", reflect.TypeOf((*MockDiscoverer)(nil).GetWorkloadAsVMCPBackend), ctx, workloadName) +} + +// ListWorkloadsInGroup mocks base method. +func (m *MockDiscoverer) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListWorkloadsInGroup", ctx, groupName) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListWorkloadsInGroup indicates an expected call of ListWorkloadsInGroup. +func (mr *MockDiscovererMockRecorder) ListWorkloadsInGroup(ctx, groupName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloadsInGroup", reflect.TypeOf((*MockDiscoverer)(nil).ListWorkloadsInGroup), ctx, groupName) +} diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index 7aac9e5d6..6389a037f 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -27,6 +27,7 @@ import ( "github.com/stacklok/toolhive/pkg/secrets" "github.com/stacklok/toolhive/pkg/state" "github.com/stacklok/toolhive/pkg/transport" + "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/workloads/statuses" "github.com/stacklok/toolhive/pkg/workloads/types" ) @@ -71,7 +72,8 @@ type Manager interface { DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) } -type defaultManager struct { +// DefaultManager is the default implementation of the Manager interface. +type DefaultManager struct { runtime rt.Runtime statuses statuses.StatusManager configProvider config.Provider @@ -86,7 +88,7 @@ const ( ) // NewManager creates a new container manager instance. -func NewManager(ctx context.Context) (Manager, error) { +func NewManager(ctx context.Context) (*DefaultManager, error) { runtime, err := ct.NewFactory().Create(ctx) if err != nil { return nil, err @@ -97,7 +99,7 @@ func NewManager(ctx context.Context) (Manager, error) { return nil, fmt.Errorf("failed to create status manager: %w", err) } - return &defaultManager{ + return &DefaultManager{ runtime: runtime, statuses: statusManager, configProvider: config.NewDefaultProvider(), @@ -116,7 +118,7 @@ func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) return nil, fmt.Errorf("failed to create status manager: %w", err) } - return &defaultManager{ + return &DefaultManager{ runtime: runtime, statuses: statusManager, configProvider: configProvider, @@ -130,7 +132,7 @@ func NewManagerFromRuntime(runtime rt.Runtime) (Manager, error) { return nil, fmt.Errorf("failed to create status manager: %w", err) } - return &defaultManager{ + return &DefaultManager{ runtime: runtime, statuses: statusManager, configProvider: config.NewDefaultProvider(), @@ -145,20 +147,90 @@ func NewManagerFromRuntimeWithProvider(runtime rt.Runtime, configProvider config return nil, fmt.Errorf("failed to create status manager: %w", err) } - return &defaultManager{ + return &DefaultManager{ runtime: runtime, statuses: statusManager, configProvider: configProvider, }, nil } -func (d *defaultManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { +// GetWorkload retrieves details of the named workload including its status. 
+func (d *DefaultManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { // For the sake of minimizing changes, delegate to the status manager. // Whether this method should still belong to the workload manager is TBD. return d.statuses.GetWorkload(ctx, workloadName) } -func (d *defaultManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { +// GetWorkloadAsVMCPBackend retrieves a workload and converts it to a vmcp.Backend. +// This method eliminates indirection by directly returning the vmcp.Backend type +// needed by vmcp workload discovery, avoiding the need for callers to convert +// from core.Workload to vmcp.Backend. +// Returns nil if the workload exists but is not accessible (e.g., no URL). +func (d *DefaultManager) GetWorkloadAsVMCPBackend(ctx context.Context, workloadName string) (*vmcp.Backend, error) { + workload, err := d.statuses.GetWorkload(ctx, workloadName) + if err != nil { + return nil, err + } + + // Skip workloads without a URL (not accessible) + if workload.URL == "" { + logger.Debugf("Skipping workload %s without URL", workloadName) + return nil, nil + } + + // Map workload status to backend health status + healthStatus := mapWorkloadStatusToVMCPHealth(workload.Status) + + // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. + // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. + // ProxyMode tells us which transport the vmcp client should use. + transportType := workload.ProxyMode + if transportType == "" { + // Fallback to TransportType if ProxyMode is not set (for direct transports) + transportType = workload.TransportType.String() + } + + backend := &vmcp.Backend{ + ID: workload.Name, + Name: workload.Name, + BaseURL: workload.URL, + TransportType: transportType, + HealthStatus: healthStatus, + Metadata: make(map[string]string), + } + + // Copy user labels to metadata first + for k, v := range workload.Labels { + backend.Metadata[k] = v + } + + // Set system metadata (these override user labels to prevent conflicts) + backend.Metadata["tool_type"] = workload.ToolType + backend.Metadata["workload_status"] = string(workload.Status) + + return backend, nil +} + +// mapWorkloadStatusToVMCPHealth converts a WorkloadStatus to a vmcp BackendHealthStatus. +func mapWorkloadStatusToVMCPHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { + switch status { + case rt.WorkloadStatusRunning: + return vmcp.BackendHealthy + case rt.WorkloadStatusUnhealthy: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStopping, rt.WorkloadStatusRemoving: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStarting, rt.WorkloadStatusUnknown: + return vmcp.BackendUnknown + case rt.WorkloadStatusUnauthenticated: + return vmcp.BackendUnauthenticated + default: + return vmcp.BackendUnknown + } +} + +// DoesWorkloadExist checks if a workload with the given name exists. +func (d *DefaultManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { // check if workload exists by trying to get it workload, err := d.statuses.GetWorkload(ctx, workloadName) if err != nil { @@ -175,7 +247,8 @@ func (d *defaultManager) DoesWorkloadExist(ctx context.Context, workloadName str return true, nil } -func (d *defaultManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { +// ListWorkloads retrieves the states of all workloads. 
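The (nil, nil) return deserves a note: it signals "workload exists but is not routable", so callers should treat it as a skip rather than an error. A hedged sketch of a discovery loop built on that contract; collectBackends, its package, and its signature are hypothetical:

package discovery // hypothetical package for the sketch

import (
	"context"
	"fmt"

	"github.com/stacklok/toolhive/pkg/vmcp"
	"github.com/stacklok/toolhive/pkg/workloads"
)

// collectBackends resolves each workload name to a vmcp.Backend, skipping
// workloads that exist but expose no URL (GetWorkloadAsVMCPBackend returns nil, nil).
func collectBackends(ctx context.Context, mgr *workloads.DefaultManager, names []string) ([]vmcp.Backend, error) {
	backends := make([]vmcp.Backend, 0, len(names))
	for _, name := range names {
		backend, err := mgr.GetWorkloadAsVMCPBackend(ctx, name)
		if err != nil {
			// Hard failure: the workload could not be looked up at all.
			return nil, fmt.Errorf("failed to resolve workload %s: %w", name, err)
		}
		if backend == nil {
			// Exists but has no URL; skip it rather than failing discovery.
			continue
		}
		backends = append(backends, *backend)
	}
	return backends, nil
}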
+func (d *DefaultManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { // For the sake of minimizing changes, delegate to the status manager. // Whether this method should still belong to the workload manager is TBD. containerWorkloads, err := d.statuses.ListWorkloads(ctx, listAll, labelFilters) @@ -196,7 +269,8 @@ func (d *defaultManager) ListWorkloads(ctx context.Context, listAll bool, labelF return containerWorkloads, nil } -func (d *defaultManager) StopWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { +// StopWorkloads stops the specified workloads by name. +func (d *DefaultManager) StopWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { // Validate all workload names to prevent path traversal attacks for _, name := range names { if err := types.ValidateWorkloadName(name); err != nil { @@ -220,7 +294,7 @@ func (d *defaultManager) StopWorkloads(_ context.Context, names []string) (*errg } // stopSingleWorkload stops a single workload (container or remote) -func (d *defaultManager) stopSingleWorkload(name string) error { +func (d *DefaultManager) stopSingleWorkload(name string) error { // Create a child context with a longer timeout childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -243,7 +317,7 @@ func (d *defaultManager) stopSingleWorkload(name string) error { } // stopRemoteWorkload stops a remote workload -func (d *defaultManager) stopRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { +func (d *DefaultManager) stopRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { logger.Infof("Stopping remote workload %s...", name) // Check if the workload is running by checking its status @@ -289,7 +363,7 @@ func (d *defaultManager) stopRemoteWorkload(ctx context.Context, name string, ru } // stopContainerWorkload stops a container-based workload -func (d *defaultManager) stopContainerWorkload(ctx context.Context, name string) error { +func (d *DefaultManager) stopContainerWorkload(ctx context.Context, name string) error { container, err := d.runtime.GetWorkloadInfo(ctx, name) if err != nil { if errors.Is(err, rt.ErrWorkloadNotFound) { @@ -316,7 +390,8 @@ func (d *defaultManager) stopContainerWorkload(ctx context.Context, name string) return d.stopSingleContainerWorkload(ctx, &container) } -func (d *defaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error { +// RunWorkload runs a workload in the foreground. +func (d *DefaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error { // Ensure that the workload has a status entry before starting the process. if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil { // Failure to create the initial state is a fatal error. @@ -334,7 +409,8 @@ func (d *defaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunC return err } -func (d *defaultManager) validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { +// validateSecretParameters validates the secret parameters for a workload. 
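StopWorkloads splits its failure modes: invalid names are rejected synchronously, while per-workload stop errors surface from the returned errgroup. A minimal caller sketch under that contract; stopAndWait and its package are hypothetical helpers:

package cli // illustrative caller of the workloads manager

import (
	"context"

	"github.com/stacklok/toolhive/pkg/workloads"
)

func stopAndWait(ctx context.Context, mgr *workloads.DefaultManager, names []string) error {
	group, err := mgr.StopWorkloads(ctx, names)
	if err != nil {
		// Synchronous failure, e.g. an invalid workload name.
		return err
	}
	// Block until every scheduled stop has completed or one has failed.
	return group.Wait()
}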
+func (d *DefaultManager) validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { // If there are run secrets, validate them hasRegularSecrets := len(runConfig.Secrets) > 0 @@ -361,7 +437,8 @@ func (d *defaultManager) validateSecretParameters(ctx context.Context, runConfig return nil } -func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error { +// RunWorkloadDetached runs a workload in the background. +func (d *DefaultManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error { // before running, validate the parameters for the workload err := d.validateSecretParameters(ctx, runConfig) if err != nil { @@ -455,7 +532,8 @@ func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *run return nil } -func (d *defaultManager) GetLogs(ctx context.Context, workloadName string, follow bool) (string, error) { +// GetLogs retrieves the logs of a container. +func (d *DefaultManager) GetLogs(ctx context.Context, workloadName string, follow bool) (string, error) { // Get the logs from the runtime logs, err := d.runtime.GetWorkloadLogs(ctx, workloadName, follow) if err != nil { @@ -470,7 +548,7 @@ func (d *defaultManager) GetLogs(ctx context.Context, workloadName string, follo } // GetProxyLogs retrieves proxy logs from the filesystem -func (*defaultManager) GetProxyLogs(_ context.Context, workloadName string) (string, error) { +func (*DefaultManager) GetProxyLogs(_ context.Context, workloadName string) (string, error) { // Get the proxy log file path logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", workloadName)) if err != nil { @@ -495,7 +573,7 @@ func (*defaultManager) GetProxyLogs(_ context.Context, workloadName string) (str } // deleteWorkload handles deletion of a single workload -func (d *defaultManager) deleteWorkload(name string) error { +func (d *DefaultManager) deleteWorkload(name string) error { // Create a child context with a longer timeout childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -518,7 +596,7 @@ func (d *defaultManager) deleteWorkload(name string) error { } // deleteRemoteWorkload handles deletion of a remote workload -func (d *defaultManager) deleteRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { +func (d *DefaultManager) deleteRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { logger.Infof("Removing remote workload %s...", name) // Set status to removing @@ -545,7 +623,7 @@ func (d *defaultManager) deleteRemoteWorkload(ctx context.Context, name string, } // deleteContainerWorkload handles deletion of a container-based workload (existing logic) -func (d *defaultManager) deleteContainerWorkload(ctx context.Context, name string) error { +func (d *DefaultManager) deleteContainerWorkload(ctx context.Context, name string) error { // Find and validate the container container, err := d.getWorkloadContainer(ctx, name) @@ -590,7 +668,7 @@ func (d *defaultManager) deleteContainerWorkload(ctx context.Context, name strin } // getWorkloadContainer retrieves workload container info with error handling -func (d *defaultManager) getWorkloadContainer(ctx context.Context, name string) (*rt.ContainerInfo, error) { +func (d *DefaultManager) getWorkloadContainer(ctx context.Context, name string) (*rt.ContainerInfo, error) { container, err := d.runtime.GetWorkloadInfo(ctx, name) if err != nil { if errors.Is(err, rt.ErrWorkloadNotFound) { @@ 
-612,7 +690,7 @@ func (d *defaultManager) getWorkloadContainer(ctx context.Context, name string) // - If the supervisor exits cleanly, it cleans up the PID // - If killed unexpectedly, the PID remains but stopProcess will handle it gracefully // - The main issue we're preventing is accumulating zombie supervisors from repeated restarts -func (d *defaultManager) isSupervisorProcessAlive(ctx context.Context, name string) bool { +func (d *DefaultManager) isSupervisorProcessAlive(ctx context.Context, name string) bool { if name == "" { return false } @@ -629,7 +707,7 @@ func (d *defaultManager) isSupervisorProcessAlive(ctx context.Context, name stri } // stopProcess stops the proxy process associated with the container -func (d *defaultManager) stopProcess(ctx context.Context, name string) { +func (d *DefaultManager) stopProcess(ctx context.Context, name string) { if name == "" { logger.Warnf("Warning: Could not find base container name in labels") return @@ -657,7 +735,7 @@ func (d *defaultManager) stopProcess(ctx context.Context, name string) { } // stopProxyIfNeeded stops the proxy process if the workload has a base name -func (d *defaultManager) stopProxyIfNeeded(ctx context.Context, name, baseName string) { +func (d *DefaultManager) stopProxyIfNeeded(ctx context.Context, name, baseName string) { logger.Infof("Removing proxy process for %s...", name) if baseName != "" { d.stopProcess(ctx, baseName) @@ -665,7 +743,7 @@ func (d *defaultManager) stopProxyIfNeeded(ctx context.Context, name, baseName s } // removeContainer removes the container from the runtime -func (d *defaultManager) removeContainer(ctx context.Context, name string) error { +func (d *DefaultManager) removeContainer(ctx context.Context, name string) error { logger.Infof("Removing container %s...", name) if err := d.runtime.RemoveWorkload(ctx, name); err != nil { if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { @@ -677,7 +755,7 @@ func (d *defaultManager) removeContainer(ctx context.Context, name string) error } // cleanupWorkloadResources cleans up all resources associated with a workload -func (d *defaultManager) cleanupWorkloadResources(ctx context.Context, name, baseName string, isAuxiliary bool) { +func (d *DefaultManager) cleanupWorkloadResources(ctx context.Context, name, baseName string, isAuxiliary bool) { if baseName == "" { return } @@ -708,7 +786,8 @@ func (d *defaultManager) cleanupWorkloadResources(ctx context.Context, name, bas logger.Infof("Container %s removed", name) } -func (d *defaultManager) DeleteWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { +// DeleteWorkloads deletes the specified workloads by name. +func (d *DefaultManager) DeleteWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { // Validate all workload names to prevent path traversal attacks for _, name := range names { if err := types.ValidateWorkloadName(name); err != nil { @@ -728,7 +807,7 @@ func (d *defaultManager) DeleteWorkloads(_ context.Context, names []string) (*er } // RestartWorkloads restarts the specified workloads by name. 
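The supervisor-liveness reasoning above boils down to "probe the recorded PID without actually signalling the process". A generic Unix sketch of that kind of check, purely for illustration; it is not the toolhive implementation, which (judging by the tests further down) resolves the PID through the status manager:

package procutil // generic illustration, not part of toolhive

import (
	"os"
	"syscall"
)

// pidAlive reports whether a process with the given PID currently exists.
// Sending signal 0 performs the existence/permission check without delivering
// a real signal. This is Unix-specific; Windows needs a different probe.
func pidAlive(pid int) bool {
	if pid <= 0 {
		return false
	}
	proc, err := os.FindProcess(pid)
	if err != nil {
		return false
	}
	// On Unix, FindProcess always succeeds, so the real check is the signal.
	return proc.Signal(syscall.Signal(0)) == nil
}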
-func (d *defaultManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) { +func (d *DefaultManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) { // Validate all workload names to prevent path traversal attacks for _, name := range names { if err := types.ValidateWorkloadName(name); err != nil { @@ -748,7 +827,7 @@ func (d *defaultManager) RestartWorkloads(_ context.Context, names []string, for } // UpdateWorkload updates a workload by stopping, deleting, and recreating it -func (d *defaultManager) UpdateWorkload(_ context.Context, workloadName string, newConfig *runner.RunConfig) (*errgroup.Group, error) { //nolint:lll +func (d *DefaultManager) UpdateWorkload(_ context.Context, workloadName string, newConfig *runner.RunConfig) (*errgroup.Group, error) { //nolint:lll // Validate workload name if err := types.ValidateWorkloadName(workloadName); err != nil { return nil, fmt.Errorf("invalid workload name '%s': %w", workloadName, err) @@ -762,7 +841,7 @@ func (d *defaultManager) UpdateWorkload(_ context.Context, workloadName string, } // updateSingleWorkload handles the update logic for a single workload -func (d *defaultManager) updateSingleWorkload(workloadName string, newConfig *runner.RunConfig) error { +func (d *DefaultManager) updateSingleWorkload(workloadName string, newConfig *runner.RunConfig) error { // Create a child context with a longer timeout childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -799,7 +878,7 @@ func (d *defaultManager) updateSingleWorkload(workloadName string, newConfig *ru } // restartSingleWorkload handles the restart logic for a single workload -func (d *defaultManager) restartSingleWorkload(name string, foreground bool) error { +func (d *DefaultManager) restartSingleWorkload(name string, foreground bool) error { // Create a child context with a longer timeout childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -822,7 +901,7 @@ func (d *defaultManager) restartSingleWorkload(name string, foreground bool) err } // restartRemoteWorkload handles restarting a remote workload -func (d *defaultManager) restartRemoteWorkload( +func (d *DefaultManager) restartRemoteWorkload( ctx context.Context, name string, runConfig *runner.RunConfig, @@ -889,7 +968,7 @@ func (d *defaultManager) restartRemoteWorkload( // restartContainerWorkload handles restarting a container-based workload // //nolint:gocyclo // Complexity is justified - handles multiple restart scenarios and edge cases -func (d *defaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error { +func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error { // Get container info to resolve partial names and extract proper workload name var containerName string var workloadName string @@ -999,7 +1078,7 @@ func (d *defaultManager) restartContainerWorkload(ctx context.Context, name stri } // startWorkload starts the workload in either foreground or background mode -func (d *defaultManager) startWorkload(ctx context.Context, name string, mcpRunner *runner.Runner, foreground bool) error { +func (d *DefaultManager) startWorkload(ctx context.Context, name string, mcpRunner *runner.Runner, foreground bool) error { logger.Infof("Starting tooling server %s...", name) var err error @@ -1044,7 +1123,7 @@ func removeClientConfigurations(containerName 
string, isAuxiliary bool) error { } // loadRunnerFromState attempts to load a Runner from the state store -func (d *defaultManager) loadRunnerFromState(ctx context.Context, baseName string) (*runner.Runner, error) { +func (d *DefaultManager) loadRunnerFromState(ctx context.Context, baseName string) (*runner.Runner, error) { // Load the run config from the state store runConfig, err := runner.LoadState(ctx, baseName) if err != nil { @@ -1063,7 +1142,7 @@ func (d *defaultManager) loadRunnerFromState(ctx context.Context, baseName strin return runner.NewRunner(runConfig, d.statuses), nil } -func (d *defaultManager) needSecretsPassword(secretOptions []string) bool { +func (d *DefaultManager) needSecretsPassword(secretOptions []string) bool { // If the user did not ask for any secrets, then don't attempt to instantiate // the secrets manager. if len(secretOptions) == 0 { @@ -1075,7 +1154,7 @@ func (d *defaultManager) needSecretsPassword(secretOptions []string) bool { } // cleanupTempPermissionProfile cleans up temporary permission profile files for a given base name -func (*defaultManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error { +func (*DefaultManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error { // Try to load the saved configuration to get the permission profile path runConfig, err := runner.LoadState(ctx, baseName) if err != nil { @@ -1095,7 +1174,7 @@ func (*defaultManager) cleanupTempPermissionProfile(ctx context.Context, baseNam } // stopSingleContainerWorkload stops a single container workload -func (d *defaultManager) stopSingleContainerWorkload(ctx context.Context, workload *rt.ContainerInfo) error { +func (d *DefaultManager) stopSingleContainerWorkload(ctx context.Context, workload *rt.ContainerInfo) error { childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -1136,7 +1215,7 @@ func (d *defaultManager) stopSingleContainerWorkload(ctx context.Context, worklo } // MoveToGroup moves the specified workloads from one group to another by updating their runconfig. 
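With DefaultManager exported and NewManager returning the concrete type, callers no longer need a type assertion to reach the newer helpers. An end-to-end sketch; restartGroup, its package, and the background-restart choice are illustrative:

package cli // illustrative caller

import (
	"context"
	"fmt"

	"github.com/stacklok/toolhive/pkg/workloads"
)

func restartGroup(ctx context.Context, groupName string) error {
	mgr, err := workloads.NewManager(ctx) // now returns *workloads.DefaultManager
	if err != nil {
		return fmt.Errorf("failed to create workloads manager: %w", err)
	}

	names, err := mgr.ListWorkloadsInGroup(ctx, groupName)
	if err != nil {
		return err
	}
	if len(names) == 0 {
		return nil // nothing to restart
	}

	// foreground=false: restart the workloads detached, then wait on the errgroup.
	group, err := mgr.RestartWorkloads(ctx, names, false)
	if err != nil {
		return err
	}
	return group.Wait()
}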
-func (*defaultManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { +func (*DefaultManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { for _, workloadName := range workloadNames { // Validate workload name if err := types.ValidateWorkloadName(workloadName); err != nil { @@ -1171,7 +1250,7 @@ func (*defaultManager) MoveToGroup(ctx context.Context, workloadNames []string, } // ListWorkloadsInGroup returns all workload names that belong to the specified group -func (d *defaultManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { +func (d *DefaultManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { workloads, err := d.ListWorkloads(ctx, true) // listAll=true to include stopped workloads if err != nil { return nil, fmt.Errorf("failed to list workloads: %w", err) @@ -1189,7 +1268,7 @@ func (d *defaultManager) ListWorkloadsInGroup(ctx context.Context, groupName str } // getRemoteWorkloadsFromState retrieves remote servers from the state store -func (d *defaultManager) getRemoteWorkloadsFromState( +func (d *DefaultManager) getRemoteWorkloadsFromState( ctx context.Context, listAll bool, labelFilters []string, diff --git a/pkg/workloads/manager_test.go b/pkg/workloads/manager_test.go index ea971127e..a49603de5 100644 --- a/pkg/workloads/manager_test.go +++ b/pkg/workloads/manager_test.go @@ -160,7 +160,7 @@ func TestDefaultManager_ListWorkloadsInGroup(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupStatusMgr(mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ runtime: nil, // Not needed for this test statuses: mockStatusMgr, } @@ -196,7 +196,7 @@ func TestNewManagerFromRuntime(t *testing.T) { require.NotNil(t, manager) // Verify it's a defaultManager with the runtime set - defaultMgr, ok := manager.(*defaultManager) + defaultMgr, ok := manager.(*DefaultManager) require.True(t, ok) assert.Equal(t, mockRuntime, defaultMgr.runtime) assert.NotNil(t, defaultMgr.statuses) @@ -217,7 +217,7 @@ func TestNewManagerFromRuntimeWithProvider(t *testing.T) { require.NoError(t, err) require.NotNil(t, manager) - defaultMgr, ok := manager.(*defaultManager) + defaultMgr, ok := manager.(*DefaultManager) require.True(t, ok) assert.Equal(t, mockRuntime, defaultMgr.runtime) assert.Equal(t, mockConfigProvider, defaultMgr.configProvider) @@ -288,7 +288,7 @@ func TestDefaultManager_DoesWorkloadExist(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, } @@ -320,7 +320,7 @@ func TestDefaultManager_GetWorkload(t *testing.T) { mockStatusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(expectedWorkload, nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, } @@ -387,7 +387,7 @@ func TestDefaultManager_GetLogs(t *testing.T) { mockRuntime := runtimeMocks.NewMockRuntime(ctrl) tt.setupMocks(mockRuntime) - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, } @@ -437,7 +437,7 @@ func TestDefaultManager_StopWorkloads(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - manager := &defaultManager{} + manager := &DefaultManager{} ctx := context.Background() group, err := manager.StopWorkloads(ctx, tt.workloadNames) @@ -487,7 +487,7 @@ func TestDefaultManager_DeleteWorkloads(t 
*testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - manager := &defaultManager{} + manager := &DefaultManager{} ctx := context.Background() group, err := manager.DeleteWorkloads(ctx, tt.workloadNames) @@ -534,7 +534,7 @@ func TestDefaultManager_RestartWorkloads(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - manager := &defaultManager{} + manager := &DefaultManager{} ctx := context.Background() group, err := manager.RestartWorkloads(ctx, tt.workloadNames, tt.foreground) @@ -635,7 +635,7 @@ func TestDefaultManager_restartRemoteWorkload(t *testing.T) { statusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(statusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, } @@ -750,7 +750,7 @@ func TestDefaultManager_restartContainerWorkload(t *testing.T) { tt.setupMocks(statusMgr, runtimeMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, runtime: runtimeMgr, } @@ -788,7 +788,7 @@ func TestDefaultManager_restartLogicConsistency(t *testing.T) { // Check if supervisor is alive - return valid PID (healthy) statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(12345, nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, } @@ -826,7 +826,7 @@ func TestDefaultManager_restartLogicConsistency(t *testing.T) { statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, } @@ -866,7 +866,7 @@ func TestDefaultManager_restartLogicConsistency(t *testing.T) { // Check if supervisor is alive - return valid PID (healthy) statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(12345, nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, runtime: runtimeMgr, } @@ -911,7 +911,7 @@ func TestDefaultManager_restartLogicConsistency(t *testing.T) { statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, runtime: runtimeMgr, } @@ -968,7 +968,7 @@ func TestDefaultManager_RunWorkload(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, } @@ -1029,7 +1029,7 @@ func TestDefaultManager_validateSecretParameters(t *testing.T) { mockConfigProvider := configMocks.NewMockProvider(ctrl) tt.setupMocks(mockConfigProvider) - manager := &defaultManager{ + manager := &DefaultManager{ configProvider: mockConfigProvider, } @@ -1106,7 +1106,7 @@ func TestDefaultManager_getWorkloadContainer(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockRuntime, mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, statuses: mockStatusMgr, } @@ -1170,7 +1170,7 @@ func TestDefaultManager_removeContainer(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockRuntime, mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, statuses: mockStatusMgr, } @@ -1224,7 +1224,7 @@ func 
TestDefaultManager_needSecretsPassword(t *testing.T) { mockConfigProvider := configMocks.NewMockProvider(ctrl) tt.setupMocks(mockConfigProvider) - manager := &defaultManager{ + manager := &DefaultManager{ configProvider: mockConfigProvider, } @@ -1272,7 +1272,7 @@ func TestDefaultManager_RunWorkloadDetached(t *testing.T) { mockConfigProvider := configMocks.NewMockProvider(ctrl) tt.setupMocks(mockStatusMgr, mockConfigProvider) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, configProvider: mockConfigProvider, } @@ -1307,7 +1307,7 @@ func TestDefaultManager_RunWorkloadDetached_PIDManagement(t *testing.T) { // we verify the PID management integration exists by checking the method signature // and code structure rather than running the full integration. - manager := &defaultManager{} + manager := &DefaultManager{} assert.NotNil(t, manager, "defaultManager should be instantiable") // Verify the method exists with the correct signature @@ -1386,7 +1386,7 @@ func TestDefaultManager_ListWorkloads(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, } @@ -1488,7 +1488,7 @@ func TestDefaultManager_UpdateWorkload(t *testing.T) { tt.setupMocks(mockRuntime, mockStatusManager) } - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, statuses: mockStatusManager, configProvider: mockConfigProvider, @@ -1639,7 +1639,7 @@ func TestDefaultManager_updateSingleWorkload(t *testing.T) { tt.setupMocks(mockRuntime, mockStatusManager) } - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, statuses: mockStatusManager, configProvider: mockConfigProvider,