From fb159b1c6e2d79a22b2cf0fc5936151152f98edd Mon Sep 17 00:00:00 2001 From: amirejaz Date: Thu, 6 Nov 2025 16:00:25 +0000 Subject: [PATCH 01/16] unify workload management across CLI and Kubernetes --- cmd/thv/__debug_bin112695300 | 0 cmd/vmcp/app/commands.go | 16 +- pkg/vmcp/aggregator/cli_discoverer.go | 151 -- pkg/vmcp/aggregator/discoverer.go | 159 +- ..._discoverer_test.go => discoverer_test.go} | 20 +- pkg/vmcp/aggregator/k8s_discoverer.go | 33 - pkg/workloads/cli_manager.go | 1205 ++++++++++++ pkg/workloads/cli_manager_test.go | 1616 ++++++++++++++++ pkg/workloads/k8s_manager.go | 351 ++++ pkg/workloads/k8s_manager_test.go | 777 ++++++++ pkg/workloads/manager.go | 1247 +------------ pkg/workloads/manager_test.go | 1622 +---------------- pkg/workloads/mocks/mock_storage_driver.go | 150 ++ 13 files changed, 4322 insertions(+), 3025 deletions(-) create mode 100644 cmd/thv/__debug_bin112695300 delete mode 100644 pkg/vmcp/aggregator/cli_discoverer.go rename pkg/vmcp/aggregator/{cli_discoverer_test.go => discoverer_test.go} (92%) delete mode 100644 pkg/vmcp/aggregator/k8s_discoverer.go create mode 100644 pkg/workloads/cli_manager.go create mode 100644 pkg/workloads/cli_manager_test.go create mode 100644 pkg/workloads/k8s_manager.go create mode 100644 pkg/workloads/k8s_manager_test.go create mode 100644 pkg/workloads/mocks/mock_storage_driver.go diff --git a/cmd/thv/__debug_bin112695300 b/cmd/thv/__debug_bin112695300 new file mode 100644 index 000000000..e69de29bb diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 42bf7dbde..7d9b00ee2 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -222,29 +222,21 @@ func discoverBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, logger.Info("Initializing workload and group managers") workloadsManager, err := workloads.NewManager(ctx) if err != nil { - logger.Warnf("Failed to create workloads manager (expected in Kubernetes): %v", err) - logger.Warnf("Backend discovery 
will be skipped - continuing with empty backend list") - return []vmcp.Backend{}, backendClient, nil + return nil, nil, fmt.Errorf("failed to create workloads manager: %w", err) } groupsManager, err := groups.NewManager() if err != nil { - logger.Warnf("Failed to create groups manager (expected in Kubernetes): %v", err) - logger.Warnf("Backend discovery will be skipped - continuing with empty backend list") - return []vmcp.Backend{}, backendClient, nil + return nil, nil, fmt.Errorf("failed to create groups manager: %w", err) } // Create backend discoverer and discover backends - discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth) + discoverer := aggregator.NewBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth) logger.Infof("Discovering backends in group: %s", cfg.Group) backends, err := discoverer.Discover(ctx, cfg.Group) if err != nil { - // Handle discovery errors gracefully - this is expected in Kubernetes - logger.Warnf("CLI backend discovery failed (likely running in Kubernetes): %v", err) - logger.Warnf("Kubernetes backend discovery is not yet implemented - continuing with empty backend list") - logger.Warnf("The vmcp server will start but won't proxy any backends until this feature is implemented") - return []vmcp.Backend{}, backendClient, nil + return nil, nil, fmt.Errorf("failed to discover backends: %w", err) } if len(backends) == 0 { diff --git a/pkg/vmcp/aggregator/cli_discoverer.go b/pkg/vmcp/aggregator/cli_discoverer.go deleted file mode 100644 index f2c5d5c9b..000000000 --- a/pkg/vmcp/aggregator/cli_discoverer.go +++ /dev/null @@ -1,151 +0,0 @@ -package aggregator - -import ( - "context" - "fmt" - - rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/pkg/workloads" -) - -// 
cliBackendDiscoverer discovers backend MCP servers from Docker/Podman workloads in a group. -// This is the CLI version of BackendDiscoverer that uses the workloads.Manager. -type cliBackendDiscoverer struct { - workloadsManager workloads.Manager - groupsManager groups.Manager - authConfig *config.OutgoingAuthConfig -} - -// NewCLIBackendDiscoverer creates a new CLI-based backend discoverer. -// It discovers workloads from Docker/Podman containers managed by ToolHive. -// -// The authConfig parameter configures authentication for discovered backends. -// If nil, backends will have no authentication configured. -func NewCLIBackendDiscoverer( - workloadsManager workloads.Manager, - groupsManager groups.Manager, - authConfig *config.OutgoingAuthConfig, -) BackendDiscoverer { - return &cliBackendDiscoverer{ - workloadsManager: workloadsManager, - groupsManager: groupsManager, - authConfig: authConfig, - } -} - -// Discover finds all backend workloads in the specified group. -// Returns all accessible backends with their health status marked based on workload status. -// The groupRef is the group name (e.g., "engineering-team"). 
-func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { - logger.Infof("Discovering backends in group %s", groupRef) - - // Verify that the group exists - exists, err := d.groupsManager.Exists(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to check if group exists: %w", err) - } - if !exists { - return nil, fmt.Errorf("group %s not found", groupRef) - } - - // Get all workload names in the group - workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to list workloads in group: %w", err) - } - - if len(workloadNames) == 0 { - logger.Infof("No workloads found in group %s", groupRef) - return []vmcp.Backend{}, nil - } - - logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) - - // Query each workload and convert to backend - var backends []vmcp.Backend - for _, name := range workloadNames { - workload, err := d.workloadsManager.GetWorkload(ctx, name) - if err != nil { - logger.Warnf("Failed to get workload %s: %v, skipping", name, err) - continue - } - - // Skip workloads without a URL (not accessible) - if workload.URL == "" { - logger.Debugf("Skipping workload %s without URL", name) - continue - } - - // Map workload status to backend health status - healthStatus := mapWorkloadStatusToHealth(workload.Status) - - // Convert core.Workload to vmcp.Backend - // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. - // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. - // ProxyMode tells us which transport the vmcp client should use. 
- transportType := workload.ProxyMode - if transportType == "" { - // Fallback to TransportType if ProxyMode is not set (for direct transports) - transportType = workload.TransportType.String() - } - - backend := vmcp.Backend{ - ID: name, - Name: name, - BaseURL: workload.URL, - TransportType: transportType, - HealthStatus: healthStatus, - Metadata: make(map[string]string), - } - - // Apply authentication configuration if provided - authStrategy, authMetadata := d.authConfig.ResolveForBackend(name) - backend.AuthStrategy = authStrategy - backend.AuthMetadata = authMetadata - if authStrategy != "" { - logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) - } - - // Copy user labels to metadata first - for k, v := range workload.Labels { - backend.Metadata[k] = v - } - - // Set system metadata (these override user labels to prevent conflicts) - backend.Metadata["group"] = groupRef - backend.Metadata["tool_type"] = workload.ToolType - backend.Metadata["workload_status"] = string(workload.Status) - - backends = append(backends, backend) - logger.Debugf("Discovered backend %s: %s (%s) with health status %s", - backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) - } - - if len(backends) == 0 { - logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) - return []vmcp.Backend{}, nil - } - - logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) - return backends, nil -} - -// mapWorkloadStatusToHealth converts a workload status to a backend health status. 
-func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { - switch status { - case rt.WorkloadStatusRunning: - return vmcp.BackendHealthy - case rt.WorkloadStatusUnhealthy: - return vmcp.BackendUnhealthy - case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStopping, rt.WorkloadStatusRemoving: - return vmcp.BackendUnhealthy - case rt.WorkloadStatusStarting, rt.WorkloadStatusUnknown: - return vmcp.BackendUnknown - default: - return vmcp.BackendUnknown - } -} diff --git a/pkg/vmcp/aggregator/discoverer.go b/pkg/vmcp/aggregator/discoverer.go index b4cff44d4..ffceaf9d5 100644 --- a/pkg/vmcp/aggregator/discoverer.go +++ b/pkg/vmcp/aggregator/discoverer.go @@ -1,8 +1,157 @@ -// Package aggregator provides platform-specific backend discovery implementations. -// -// This file serves as a navigation reference for backend discovery implementations: -// - CLI (Docker/Podman): see cli_discoverer.go -// - Kubernetes: see k8s_discoverer.go +// Package aggregator provides platform-agnostic backend discovery. // // The BackendDiscoverer interface is defined in aggregator.go. +// The unified implementation (works for both CLI and Kubernetes) is in this file. package aggregator + +import ( + "context" + "fmt" + + rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/groups" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/workloads" +) + +// backendDiscoverer discovers backend MCP servers from workloads in a group. +// It works with both CLI (Docker/Podman) and Kubernetes environments via the unified workloads.Manager interface. +// This is a platform-agnostic implementation that automatically adapts to the runtime environment. 
+type backendDiscoverer struct { + workloadsManager workloads.Manager + groupsManager groups.Manager + authConfig *config.OutgoingAuthConfig +} + +// NewBackendDiscoverer creates a new backend discoverer. +// It discovers workloads from containers (CLI) or MCPServer CRDs (Kubernetes) managed by ToolHive. +// The workloads.Manager automatically selects the appropriate storage driver based on the runtime environment. +// +// The authConfig parameter configures authentication for discovered backends. +// If nil, backends will have no authentication configured. +func NewBackendDiscoverer( + workloadsManager workloads.Manager, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { + return &backendDiscoverer{ + workloadsManager: workloadsManager, + groupsManager: groupsManager, + authConfig: authConfig, + } +} + +// Discover finds all backend workloads in the specified group. +// Returns all accessible backends with their health status marked based on workload status. +// The groupRef is the group name (e.g., "engineering-team"). 
+func (d *backendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { + logger.Infof("Discovering backends in group %s", groupRef) + + // Verify that the group exists + exists, err := d.groupsManager.Exists(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to check if group exists: %w", err) + } + if !exists { + return nil, fmt.Errorf("group %s not found", groupRef) + } + + // Get all workload names in the group + workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to list workloads in group: %w", err) + } + + if len(workloadNames) == 0 { + logger.Infof("No workloads found in group %s", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) + + // Query each workload and convert to backend + var backends []vmcp.Backend + for _, name := range workloadNames { + workload, err := d.workloadsManager.GetWorkload(ctx, name) + if err != nil { + logger.Warnf("Failed to get workload %s: %v, skipping", name, err) + continue + } + + // Skip workloads without a URL (not accessible) + if workload.URL == "" { + logger.Debugf("Skipping workload %s without URL", name) + continue + } + + // Map workload status to backend health status + healthStatus := mapWorkloadStatusToHealth(workload.Status) + + // Convert core.Workload to vmcp.Backend + // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. + // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. + // ProxyMode tells us which transport the vmcp client should use. 
+ transportType := workload.ProxyMode + if transportType == "" { + // Fallback to TransportType if ProxyMode is not set (for direct transports) + transportType = workload.TransportType.String() + } + + backend := vmcp.Backend{ + ID: name, + Name: name, + BaseURL: workload.URL, + TransportType: transportType, + HealthStatus: healthStatus, + Metadata: make(map[string]string), + } + + // Apply authentication configuration if provided + authStrategy, authMetadata := d.authConfig.ResolveForBackend(name) + backend.AuthStrategy = authStrategy + backend.AuthMetadata = authMetadata + if authStrategy != "" { + logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) + } + + // Copy user labels to metadata first + for k, v := range workload.Labels { + backend.Metadata[k] = v + } + + // Set system metadata (these override user labels to prevent conflicts) + backend.Metadata["group"] = groupRef + backend.Metadata["tool_type"] = workload.ToolType + backend.Metadata["workload_status"] = string(workload.Status) + + backends = append(backends, backend) + logger.Debugf("Discovered backend %s: %s (%s) with health status %s", + backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) + } + + if len(backends) == 0 { + logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) + return backends, nil +} + +// mapWorkloadStatusToHealth converts a workload status to a backend health status. 
+func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { + switch status { + case rt.WorkloadStatusRunning: + return vmcp.BackendHealthy + case rt.WorkloadStatusUnhealthy: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStopping, rt.WorkloadStatusRemoving: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStarting, rt.WorkloadStatusUnknown: + return vmcp.BackendUnknown + default: + return vmcp.BackendUnknown + } +} diff --git a/pkg/vmcp/aggregator/cli_discoverer_test.go b/pkg/vmcp/aggregator/discoverer_test.go similarity index 92% rename from pkg/vmcp/aggregator/cli_discoverer_test.go rename to pkg/vmcp/aggregator/discoverer_test.go index 9c3402fad..a8fe4e550 100644 --- a/pkg/vmcp/aggregator/cli_discoverer_test.go +++ b/pkg/vmcp/aggregator/discoverer_test.go @@ -19,7 +19,7 @@ import ( const testGroupName = "test-group" -func TestCLIBackendDiscoverer_Discover(t *testing.T) { +func TestBackendDiscoverer_Discover(t *testing.T) { t.Parallel() t.Run("successful discovery with multiple backends", func(t *testing.T) { @@ -45,7 +45,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -79,7 +79,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer := NewBackendDiscoverer(mockWorkloads, 
mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -108,7 +108,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -133,7 +133,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -150,7 +150,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), "nonexistent-group") require.Error(t, err) @@ -168,7 +168,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.Error(t, err) @@ -187,7 +187,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { 
mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), "empty-group") require.NoError(t, err) @@ -214,7 +214,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -240,7 +240,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). Return(core.Workload{}, errors.New("workload query failed")) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) diff --git a/pkg/vmcp/aggregator/k8s_discoverer.go b/pkg/vmcp/aggregator/k8s_discoverer.go deleted file mode 100644 index b9f61fbc0..000000000 --- a/pkg/vmcp/aggregator/k8s_discoverer.go +++ /dev/null @@ -1,33 +0,0 @@ -package aggregator - -import ( - "context" - "fmt" - - "github.com/stacklok/toolhive/pkg/vmcp" -) - -// k8sBackendDiscoverer discovers backend MCP servers from Kubernetes pods/services in a group. -// This is the Kubernetes version of BackendDiscoverer (not implemented yet). 
-type k8sBackendDiscoverer struct { - // TODO: Add Kubernetes client and group CRD interfaces -} - -// NewK8sBackendDiscoverer creates a new Kubernetes-based backend discoverer. -// It discovers workloads from Kubernetes MCPServer resources managed by the operator. -func NewK8sBackendDiscoverer() BackendDiscoverer { - return &k8sBackendDiscoverer{} -} - -// Discover finds all backend workloads in the specified Kubernetes group. -// The groupRef is the MCPGroup name. -func (*k8sBackendDiscoverer) Discover(_ context.Context, _ string) ([]vmcp.Backend, error) { - // TODO: Implement Kubernetes backend discovery - // 1. Query MCPGroup CRD by name - // 2. List MCPServer resources with matching group label - // 3. Filter for ready/running MCPServers - // 4. Build service URLs (http://service-name.namespace.svc.cluster.local:port) - // 5. Extract transport type from MCPServer spec - // 6. Return vmcp.Backend list - return nil, fmt.Errorf("kubernetes backend discovery not yet implemented") -} diff --git a/pkg/workloads/cli_manager.go b/pkg/workloads/cli_manager.go new file mode 100644 index 000000000..1e2775e46 --- /dev/null +++ b/pkg/workloads/cli_manager.go @@ -0,0 +1,1205 @@ +// Package workloads provides a CLI-based implementation of the Manager interface. +// This file contains the CLI (Docker/Podman) implementation for local environments. 
+package workloads + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/adrg/xdg" + "golang.org/x/sync/errgroup" + + "github.com/stacklok/toolhive/pkg/client" + "github.com/stacklok/toolhive/pkg/config" + ct "github.com/stacklok/toolhive/pkg/container" + rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/core" + "github.com/stacklok/toolhive/pkg/labels" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/process" + "github.com/stacklok/toolhive/pkg/runner" + "github.com/stacklok/toolhive/pkg/secrets" + "github.com/stacklok/toolhive/pkg/state" + "github.com/stacklok/toolhive/pkg/workloads/statuses" + "github.com/stacklok/toolhive/pkg/workloads/types" +) + +// AsyncOperationTimeout is the timeout for async workload operations +const AsyncOperationTimeout = 5 * time.Minute + +// removeClientConfigurations removes client configuration files for a workload. +// TODO: Move to dedicated config management interface. 
+func removeClientConfigurations(containerName string, isAuxiliary bool) error { + // Get the workload's group by loading its run config + // Note: This is a standalone function, so we use runner.LoadState directly + // In the future, this should be refactored to use the driver + runConfig, err := runner.LoadState(context.Background(), containerName) + var group string + if err != nil { + // Only warn for non-auxiliary workloads since auxiliary workloads don't have run configs + if !isAuxiliary { + logger.Warnf("Warning: Failed to load run config for %s, will use backward compatible behavior: %v", containerName, err) + } + // Continue with empty group (backward compatibility) + } else { + group = runConfig.Group + } + + clientManager, err := client.NewManager(context.Background()) + if err != nil { + logger.Warnf("Warning: Failed to create client manager for %s, skipping client config removal: %v", containerName, err) + return nil + } + + return clientManager.RemoveServerFromClients(context.Background(), containerName, group) +} + +// cliManager implements the Manager interface for CLI (Docker/Podman) environments. +type cliManager struct { + runtime rt.Runtime + statuses statuses.StatusManager + configProvider config.Provider +} + +// NewCLIManager creates a new CLI-based workload manager. +func NewCLIManager(ctx context.Context) (Manager, error) { + return NewCLIManagerWithProvider(ctx, config.NewDefaultProvider()) +} + +// NewCLIManagerWithProvider creates a new CLI-based workload manager with a custom config provider. 
+func NewCLIManagerWithProvider(ctx context.Context, configProvider config.Provider) (Manager, error) { + runtime, err := ct.NewFactory().Create(ctx) + if err != nil { + return nil, err + } + + statusManager, err := statuses.NewStatusManager(runtime) + if err != nil { + return nil, fmt.Errorf("failed to create status manager: %w", err) + } + + return &cliManager{ + runtime: runtime, + statuses: statusManager, + configProvider: configProvider, + }, nil +} + +// NewCLIManagerFromRuntime creates a new CLI-based workload manager from an existing runtime. +func NewCLIManagerFromRuntime(runtime rt.Runtime) (Manager, error) { + return NewCLIManagerFromRuntimeWithProvider(runtime, config.NewDefaultProvider()) +} + +// NewCLIManagerFromRuntimeWithProvider creates a new CLI-based workload manager +// from an existing runtime with a custom config provider. +func NewCLIManagerFromRuntimeWithProvider(runtime rt.Runtime, configProvider config.Provider) (Manager, error) { + statusManager, err := statuses.NewStatusManager(runtime) + if err != nil { + return nil, fmt.Errorf("failed to create status manager: %w", err) + } + + return &cliManager{ + runtime: runtime, + statuses: statusManager, + configProvider: configProvider, + }, nil +} + +func (d *cliManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { + return d.statuses.GetWorkload(ctx, workloadName) +} + +func (d *cliManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { + // check if workload exists by trying to get it + workload, err := d.statuses.GetWorkload(ctx, workloadName) + if err != nil { + if errors.Is(err, rt.ErrWorkloadNotFound) { + return false, nil + } + return false, fmt.Errorf("failed to check if workload exists: %w", err) + } + + // now check if the workload is not in error + if workload.Status == rt.WorkloadStatusError { + return false, nil + } + return true, nil +} + +func (d *cliManager) ListWorkloads(ctx context.Context, listAll bool, 
labelFilters ...string) ([]core.Workload, error) { + // Get container workloads from status manager + containerWorkloads, err := d.statuses.ListWorkloads(ctx, listAll, labelFilters) + if err != nil { + return nil, err + } + + // Get remote workloads from the state store + remoteWorkloads, err := d.getRemoteWorkloadsFromState(ctx, listAll, labelFilters) + if err != nil { + logger.Warnf("Failed to get remote workloads from state: %v", err) + // Continue with container workloads only + } else { + // Combine container and remote workloads + containerWorkloads = append(containerWorkloads, remoteWorkloads...) + } + + return containerWorkloads, nil +} + +func (d *cliManager) StopWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { + // Validate all workload names to prevent path traversal attacks + for _, name := range names { + if err := types.ValidateWorkloadName(name); err != nil { + return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) + } + // Ensure workload name does not contain path traversal or separators + if strings.Contains(name, "..") || strings.ContainsAny(name, "/\\") { + return nil, fmt.Errorf("invalid workload name '%s': contains forbidden characters", name) + } + } + + group := &errgroup.Group{} + // Process each workload + for _, name := range names { + group.Go(func() error { + return d.stopSingleWorkload(name) + }) + } + + return group, nil +} + +// stopSingleWorkload stops a single workload (container or remote) +func (d *cliManager) stopSingleWorkload(name string) error { + // Create a child context with a longer timeout + childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) + defer cancel() + + // First, try to load the run configuration to check if it's a remote workload + runConfig, err := runner.LoadState(childCtx, name) + if err != nil { + // If we can't load the state, it might be a container workload or the workload doesn't exist + // Try to stop it as a container workload + 
return d.stopContainerWorkload(childCtx, name)
+	}
+
+	// Check if this is a remote workload
+	if runConfig.RemoteURL != "" {
+		return d.stopRemoteWorkload(childCtx, name, runConfig)
+	}
+
+	// This is a container-based workload
+	return d.stopContainerWorkload(childCtx, name)
+}
+
+// stopRemoteWorkload stops a remote workload.
+//
+// A workload that is unknown to the status store, or that is not in the
+// running state, is logged and treated as success so that a batch stop is not
+// aborted. Otherwise the proxy process is stopped (when a base name is known),
+// client configurations are removed and the status is moved to stopped. The
+// saved run configuration is left untouched so the workload can be restarted.
+func (d *cliManager) stopRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error {
+	logger.Infof("Stopping remote workload %s...", name)
+
+	// Check if the workload is running by checking its status
+	workload, err := d.statuses.GetWorkload(ctx, name)
+	if err != nil {
+		if errors.Is(err, rt.ErrWorkloadNotFound) {
+			// Log but don't fail the entire operation for not found workload
+			logger.Warnf("Warning: Failed to stop workload %s: %v", name, err)
+			return nil
+		}
+		return fmt.Errorf("failed to find workload %s: %w", name, err)
+	}
+
+	if workload.Status != rt.WorkloadStatusRunning {
+		logger.Warnf("Warning: Failed to stop workload %s: %v", name, ErrWorkloadNotRunning)
+		return nil
+	}
+
+	// Set status to stopping (best effort: a failure here must not block the stop)
+	if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil {
+		logger.Debugf("Failed to set workload %s status to stopping: %v", name, err)
+	}
+
+	// Stop proxy if running
+	if runConfig.BaseName != "" {
+		d.stopProxyIfNeeded(ctx, name, runConfig.BaseName)
+	}
+
+	// For remote workloads, we only need to clean up client configurations
+	// The saved state should be preserved for restart capability
+	if err := removeClientConfigurations(name, false); err != nil {
+		logger.Warnf("Warning: Failed to remove client configurations: %v", err)
+	} else {
+		logger.Infof("Client configurations for %s removed", name)
+	}
+
+	// Set status to stopped
+	if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopped, ""); err != nil {
+		logger.Debugf("Failed to set workload %s status to stopped: %v", name, err)
+	}
+	logger.Infof("Remote workload %s stopped successfully", name)
+	return nil
+}
+
+// stopContainerWorkload stops a container-based workload.
+//
+// A workload the runtime does not know, or one that is not running, is logged
+// and treated as success. Otherwise the status is moved to stopping and the
+// actual stop is delegated to stopSingleContainerWorkload.
+func (d *cliManager) stopContainerWorkload(ctx context.Context, name string) error {
+	container, err := d.runtime.GetWorkloadInfo(ctx, name)
+	if err != nil {
+		if errors.Is(err, rt.ErrWorkloadNotFound) {
+			// Log but don't fail the entire operation for not found containers
+			logger.Warnf("Warning: Failed to stop workload %s: %v", name, err)
+			return nil
+		}
+		return fmt.Errorf("failed to find workload %s: %w", name, err)
+	}
+
+	if !container.IsRunning() {
+		// Log but don't fail the entire operation for not running containers
+		logger.Warnf("Warning: Failed to stop workload %s: %v", name, ErrWorkloadNotRunning)
+		return nil
+	}
+
+	// Transition workload to `stopping` state.
+	if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil {
+		logger.Debugf("Failed to set workload %s status to stopping: %v", name, err)
+	}
+
+	// Delegate to the shared single-container stop routine
+	return d.stopSingleContainerWorkload(ctx, &container)
+}
+
+// RunWorkload runs the workload described by runConfig in the calling process.
+//
+// A `starting` status entry is written first; failing to write it is fatal.
+// The call returns whatever the runner's Run returns, and on failure the
+// status is set to error with the failure message as context.
+func (d *cliManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error {
+	// Ensure that the workload has a status entry before starting the process.
+	if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil {
+		// Failure to create the initial state is a fatal error.
+		return fmt.Errorf("failed to create workload status: %w", err)
+	}
+
+	mcpRunner := runner.NewRunner(runConfig, d.statuses)
+	err := mcpRunner.Run(ctx)
+	if err != nil {
+		// If the run failed, we should set the status to error.
+		if statusErr := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, err.Error()); statusErr != nil {
+			logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr)
+		}
+	}
+	return err
+}
+
+// validateSecretParameters validates the secrets referenced by runConfig.
+//
+// It is a no-op when the config has neither regular secrets nor a remote-auth
+// client secret. Otherwise the configured secrets provider is instantiated and
+// runConfig.ValidateSecrets is run against it.
+func (d *cliManager) validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error {
+	hasRegularSecrets := len(runConfig.Secrets) > 0
+	hasRemoteAuthSecret := runConfig.RemoteAuthConfig != nil && runConfig.RemoteAuthConfig.ClientSecret != ""
+
+	// Nothing to validate: avoid instantiating the secrets provider at all.
+	if !hasRegularSecrets && !hasRemoteAuthSecret {
+		return nil
+	}
+
+	cfg := d.configProvider.GetConfig()
+
+	providerType, err := cfg.Secrets.GetProviderType()
+	if err != nil {
+		return fmt.Errorf("error determining secrets provider type: %w", err)
+	}
+
+	secretManager, err := secrets.CreateSecretProvider(providerType)
+	if err != nil {
+		return fmt.Errorf("error instantiating secret manager: %w", err)
+	}
+
+	if err := runConfig.ValidateSecrets(ctx, secretManager); err != nil {
+		return fmt.Errorf("error processing secrets: %w", err)
+	}
+	return nil
+}
+
+// RunWorkloadDetached starts the workload in a separate background process.
+//
+// The current executable is re-invoked as `restart <base name> --foreground`
+// with the detached-process marker in its environment, and its output is
+// appended to the workload's log file when that file can be opened. A
+// `starting` status entry is written before the process is started, and the
+// child's PID is recorded afterwards. The call returns as soon as the child
+// has been started; it does not wait for it.
+func (d *cliManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error {
+	// before running, validate the parameters for the workload
+	err := d.validateSecretParameters(ctx, runConfig)
+	if err != nil {
+		return fmt.Errorf("failed to validate workload parameters: %w", err)
+	}
+
+	// Get the current executable path
+	execPath, err := os.Executable()
+	if err != nil {
+		return fmt.Errorf("failed to get executable path: %w", err)
+	}
+
+	// Create a log file for the detached process
+	logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", runConfig.BaseName))
+	if err != nil {
+		return fmt.Errorf("failed to create log file path: %w", err)
+	}
+	// #nosec G304 - This is safe as baseName is generated by the application
+	logFile, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
+	if err != nil {
+		// Not fatal: the workload still starts, its output is just not captured.
+		logger.Warnf("Warning: Failed to create log file: %v", err)
+	} else {
+		defer logFile.Close()
+		logger.Infof("Logging to: %s", logFilePath)
+	}
+
+	// Use the restart command to start the detached process
+	// The config has already been saved to disk, so restart can load it
+	detachedArgs := []string{"restart", runConfig.BaseName, "--foreground"}
+
+	// Create a new command
+	// #nosec G204 - This is safe as execPath is the path to the current binary
+	detachedCmd := exec.Command(execPath, detachedArgs...)
+
+	// Set environment variables for the detached process
+	detachedCmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", process.ToolHiveDetachedEnv, process.ToolHiveDetachedValue))
+
+	// If we need the decrypt password, set it as an environment variable in the detached process.
+	// NOTE: This breaks the abstraction slightly since this is only relevant for the CLI, but there
+	// are checks inside `GetSecretsPassword` to ensure this does not get called in a detached process.
+	// This will be addressed in a future re-think of the secrets manager interface.
+	if d.needSecretsPassword(runConfig.Secrets) {
+		password, err := secrets.GetSecretsPassword("")
+		if err != nil {
+			return fmt.Errorf("failed to get secrets password: %w", err)
+		}
+		detachedCmd.Env = append(detachedCmd.Env, fmt.Sprintf("%s=%s", secrets.PasswordEnvVar, password))
+	}
+
+	// Redirect stdout and stderr to the log file if it was created successfully.
+	// When left nil, os/exec connects them to the null device, i.e. the output
+	// is discarded. The explicit nil check also avoids storing a nil *os.File
+	// in the io.Writer fields.
+	if logFile != nil {
+		detachedCmd.Stdout = logFile
+		detachedCmd.Stderr = logFile
+	}
+
+	// Detach the process from the terminal
+	detachedCmd.Stdin = nil
+	detachedCmd.SysProcAttr = getSysProcAttr()
+
+	// Ensure that the workload has a status entry before starting the process.
+	if err = d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil {
+		// Failure to create the initial state is a fatal error.
+		return fmt.Errorf("failed to create workload status: %w", err)
+	}
+
+	// Start the detached process
+	if err := detachedCmd.Start(); err != nil {
+		// If the start failed, record the reason in the status before returning.
+		errMsg := err.Error()
+		if statusErr := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, errMsg); statusErr != nil {
+			logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr)
+		}
+		return fmt.Errorf("failed to start detached process: %w", err)
+	}
+
+	// Write the PID to a file so the stop command can kill the process
+	// TODO: Stop writing to PID file once we migrate over to statuses fully.
+	if err := process.WritePIDFile(runConfig.BaseName, detachedCmd.Process.Pid); err != nil {
+		logger.Warnf("Warning: Failed to write PID file: %v", err)
+	}
+	if err := d.statuses.SetWorkloadPID(ctx, runConfig.BaseName, detachedCmd.Process.Pid); err != nil {
+		logger.Warnf("Failed to set workload %s PID: %v", runConfig.BaseName, err)
+	}
+
+	logger.Infof("MCP server is running in the background (PID: %d)", detachedCmd.Process.Pid)
+	logger.Infof("Use 'thv stop %s' to stop the server", runConfig.ContainerName)
+
+	return nil
+}
+
+// GetLogs returns the logs of the named workload as reported by the runtime.
+// When the runtime does not know the workload, the returned error wraps
+// rt.ErrWorkloadNotFound.
+func (d *cliManager) GetLogs(ctx context.Context, workloadName string, follow bool) (string, error) {
+	// Get the logs from the runtime
+	logs, err := d.runtime.GetWorkloadLogs(ctx, workloadName, follow)
+	if err != nil {
+		// Propagate the error if the container is not found
+		if errors.Is(err, rt.ErrWorkloadNotFound) {
+			return "", fmt.Errorf("%w: %s", rt.ErrWorkloadNotFound, workloadName)
+		}
+		return "", fmt.Errorf("failed to get container logs %s: %w", workloadName, err)
+	}
+
+	return logs, nil
+}
+
+// GetProxyLogs retrieves proxy logs from the filesystem.
+//
+// The whole log file of the workload is read and returned as a string. An
+// error is returned when the name is invalid, when no log file exists for the
+// workload, or when the file cannot be read.
+func (*cliManager) GetProxyLogs(_ context.Context, workloadName string) (string, error) {
+	// The name is interpolated into a filesystem path. filepath.Clean only
+	// normalises a path, it does not stop "../" from escaping the logs
+	// directory, so reject invalid names up front.
+	if err := types.ValidateWorkloadName(workloadName); err != nil {
+		return "", fmt.Errorf("invalid workload name '%s': %w", workloadName, err)
+	}
+
+	// Get the proxy log file path
+	logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", workloadName))
+	if err != nil {
+		return "", fmt.Errorf("failed to get proxy log file path for workload %s: %w", workloadName, err)
+	}
+
+	// Normalise the path before handing it to the filesystem
+	cleanLogFilePath := filepath.Clean(logFilePath)
+
+	// Read the entire log file; a missing file is reported as "not found"
+	// without a separate Stat call, so there is no check-then-use race.
+	content, err := os.ReadFile(cleanLogFilePath)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return "", fmt.Errorf("proxy logs not found for workload %s", workloadName)
+		}
+		return "", fmt.Errorf("failed to read proxy log for workload %s: %w", workloadName, err)
+	}
+
+	return string(content), nil
+}
+
+// deleteWorkload handles deletion of a single workload.
+//
+// It runs under its own context bounded by AsyncOperationTimeout. The saved
+// run configuration decides the path: workloads with a RemoteURL are deleted
+// as remote workloads, everything else (including workloads whose state
+// cannot be loaded) as container workloads.
+func (d *cliManager) deleteWorkload(name string) error {
+	// Create a child context with a longer timeout
+	childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout)
+	defer cancel()
+
+	// First, check if this is a remote workload by trying to load its run configuration
+	runConfig, err := runner.LoadState(childCtx, name)
+	if err != nil {
+		// If we can't load the state, it might be a container workload or the workload doesn't exist
+		// Continue with the container-based deletion logic
+		return d.deleteContainerWorkload(childCtx, name)
+	}
+
+	// If this is a remote workload (has RemoteURL), handle it differently
+	if runConfig.RemoteURL != "" {
+		return d.deleteRemoteWorkload(childCtx, name, runConfig)
+	}
+
+	// This is a container-based workload, use the existing logic
+	return d.deleteContainerWorkload(childCtx, name)
+}
+
+// deleteRemoteWorkload handles deletion of a remote workload: it stops the
+// proxy, cleans up the associated resources (including the saved state) and
+// removes the status entry. Failing to record the `removing` status aborts
+// the deletion.
+func (d *cliManager) deleteRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error {
+	logger.Infof("Removing remote workload %s...", name)
+
+	// Set status to removing
+	if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusRemoving, ""); err != nil {
+		logger.Warnf("Failed to set workload %s status to removing: %v", name, err)
+		return err
+	}
+
+	// Stop proxy if running
+	if runConfig.BaseName != "" {
+		d.stopProxyIfNeeded(ctx, name, runConfig.BaseName)
+	}
+
+	// Clean up associated resources (remote workloads are not auxiliary)
+	d.cleanupWorkloadResources(ctx, name, runConfig.BaseName, false)
+
+	// Remove the workload status from the status store
+	if err := d.statuses.DeleteWorkloadStatus(ctx, name); err != nil {
+		logger.Warnf("failed to delete workload status for %s: %v", name, err)
+	}
+
+	logger.Infof("Remote workload %s removed successfully", name)
+	return nil
+}
+
+// deleteContainerWorkload handles deletion of a container-based workload.
+//
+// When the runtime still knows the container, its proxy is stopped (unless it
+// is an auxiliary workload), the container is removed and the associated
+// resources are cleaned up. The status entry is deleted in every case, so a
+// stale entry without a container is cleaned up as well.
+func (d *cliManager) deleteContainerWorkload(ctx context.Context, name string) error {
+	// Find and validate the container
+	container, err := d.getWorkloadContainer(ctx, name)
+	if err != nil {
+		return err
+	}
+
+	// Set status to removing
+	if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusRemoving, ""); err != nil {
+		logger.Warnf("Failed to set workload %s status to removing: %v", name, err)
+	}
+
+	// container is nil when the runtime does not know the workload
+	if container != nil {
+		containerLabels := container.Labels
+		baseName := labels.GetContainerBaseName(containerLabels)
+		isAuxiliary := labels.IsAuxiliaryWorkload(containerLabels)
+
+		// Stop proxy if running (skip for auxiliary workloads like inspector)
+		if container.IsRunning() {
+			// Skip proxy stopping for auxiliary workloads that don't use proxy processes
+			if isAuxiliary {
+				logger.Debugf("Skipping proxy stop for auxiliary workload %s", name)
+			} else {
+				d.stopProxyIfNeeded(ctx, name, baseName)
+			}
+		}
+
+		// Remove the container
+		if err := d.removeContainer(ctx, name); err != nil {
+			return err
+		}
+
+		// Clean up associated resources
+		d.cleanupWorkloadResources(ctx, name, baseName, isAuxiliary)
+	}
+
+	// Remove the workload status from the status store
+	if err := d.statuses.DeleteWorkloadStatus(ctx, name); err != nil {
+		logger.Warnf("failed to delete workload status for %s: %v", name, err)
+	}
+
+	return nil
+}
+
+// getWorkloadContainer retrieves workload container info with error handling.
+//
+// It returns (nil, nil) when the runtime reports rt.ErrWorkloadNotFound, so
+// callers must check the returned pointer. Any other lookup failure is
+// recorded as an error status and returned.
+func (d *cliManager) getWorkloadContainer(ctx context.Context, name string) (*rt.ContainerInfo, error) {
+	container, err := d.runtime.GetWorkloadInfo(ctx, name)
+	if err != nil {
+		if errors.Is(err, rt.ErrWorkloadNotFound) {
+			// Log but don't fail the entire operation for not found containers
+			logger.Warnf("Warning: Failed to get workload %s: %v", name, err)
+			return nil, nil
+		}
+		if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil {
+			logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr)
+		}
+		return nil, fmt.Errorf("failed to find workload %s: %w", name, err)
+	}
+	return &container, nil
+}
+
+// isSupervisorProcessAlive checks if the supervisor process for a workload is alive
+// by checking if a PID exists. If a PID exists, we assume the supervisor is running.
+// This is a reasonable assumption because:
+//   - If the supervisor exits cleanly, it cleans up the PID
+//   - If killed unexpectedly, the PID remains but stopProcess will handle it gracefully
+//   - The main issue we're preventing is accumulating zombie supervisors from repeated restarts
+func (d *cliManager) isSupervisorProcessAlive(ctx context.Context, name string) bool {
+	if name == "" {
+		return false
+	}
+
+	// A readable PID is taken as "supervisor running"; any lookup error as "not running".
+	_, err := d.statuses.GetWorkloadPID(ctx, name)
+	return err == nil
+}
+
+// stopProcess stops the proxy process associated with the container.
+// It is best effort: every failure is logged and none is returned.
+func (d *cliManager) stopProcess(ctx context.Context, name string) {
+	if name == "" {
+		logger.Warnf("Warning: Could not find base container name in labels")
+		return
+	}
+
+	// Try to read the PID and kill the process
+	pid, err := d.statuses.GetWorkloadPID(ctx, name)
+	if err != nil {
+		logger.Errorf("No PID file found for %s, proxy may not be running in detached mode", name)
+		return
+	}
+
+	// PID file found, try to kill the process
+	logger.Infof("Stopping proxy process (PID: %d)...", pid)
+	if err := process.KillProcess(pid); err != nil {
+		logger.Warnf("Warning: Failed to kill proxy process: %v", err)
+	} else {
+		logger.Info("Proxy process stopped")
+	}
+
+	// Clean up the PID file whether or not the kill succeeded, so a stale PID
+	// is not left behind.
+	if err := process.RemovePIDFile(name); err != nil {
+		logger.Warnf("Warning: Failed to remove PID file: %v", err)
+	}
+}
+
+// stopProxyIfNeeded stops the proxy process if the workload has a base name
+func (d *cliManager) stopProxyIfNeeded(ctx context.Context, name, baseName string) {
+	logger.Infof("Removing proxy process for %s...", name)
+	if baseName != "" {
+		d.stopProcess(ctx, baseName)
+	}
+}
+
+// removeContainer removes the container from the runtime. On failure the
+// workload status is set to error and the runtime error is returned wrapped.
+func (d *cliManager) removeContainer(ctx context.Context, name string) error {
+	logger.Infof("Removing container %s...", name)
+	if err := d.runtime.RemoveWorkload(ctx, name); err != nil {
+		if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil {
+			logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr)
+		}
+		return fmt.Errorf("failed to remove container: %w", err)
+	}
+	return nil
+}
+
+// cleanupWorkloadResources cleans up all resources associated with a workload:
+// the temporary permission profile, the client configurations and, for
+// non-auxiliary workloads, the saved run configuration. It does nothing when
+// baseName is empty. Failures are logged, never returned.
+func (d *cliManager) cleanupWorkloadResources(ctx context.Context, name, baseName string, isAuxiliary bool) {
+	if baseName == "" {
+		return
+	}
+
+	// Clean up temporary permission profile. This reads the saved state, so it
+	// has to happen before the state is deleted below.
+	if err := d.cleanupTempPermissionProfile(ctx, baseName); err != nil {
+		logger.Warnf("Warning: Failed to cleanup temporary permission profile: %v", err)
+	}
+
+	// Remove client configurations
+	if err := removeClientConfigurations(name, isAuxiliary); err != nil {
+		logger.Warnf("Warning: Failed to remove client configurations: %v", err)
+	} else {
+		logger.Infof("Client configurations for %s removed", name)
+	}
+
+	// Delete the saved state last (skip for auxiliary workloads that don't have run configs)
+	if !isAuxiliary {
+		if err := state.DeleteSavedRunConfig(ctx, baseName); err != nil {
+			logger.Warnf("Warning: Failed to delete saved state: %v", err)
+		} else {
+			logger.Infof("Saved state for %s removed", baseName)
+		}
+	} else {
+		logger.Debugf("Skipping saved state deletion for auxiliary workload %s", name)
+	}
+
+	logger.Infof("Container %s removed", name)
+}
+
+// DeleteWorkloads deletes the specified workloads by name.
+//
+// All names are validated up front; an invalid name aborts the call before any
+// deletion starts. Each deletion then runs in its own goroutine of the
+// returned errgroup, which the caller is expected to Wait on.
+func (d *cliManager) DeleteWorkloads(_ context.Context, names []string) (*errgroup.Group, error) {
+	// Validate all workload names to prevent path traversal attacks
+	for _, name := range names {
+		if err := types.ValidateWorkloadName(name); err != nil {
+			return nil, fmt.Errorf("invalid workload name '%s': %w", name, err)
+		}
+	}
+
+	group := &errgroup.Group{}
+
+	for _, name := range names {
+		group.Go(func() error {
+			return d.deleteWorkload(name)
+		})
+	}
+
+	return group, nil
+}
+
+// RestartWorkloads restarts the specified workloads by name.
+func (d *cliManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) {
+	// Validate all workload names to prevent path traversal attacks
+	for _, name := range names {
+		if err := types.ValidateWorkloadName(name); err != nil {
+			return nil, fmt.Errorf("invalid workload name '%s': %w", name, err)
+		}
+	}
+
+	group := &errgroup.Group{}
+
+	// One goroutine per workload; the caller waits on the returned group.
+	for _, name := range names {
+		group.Go(func() error {
+			return d.restartSingleWorkload(name, foreground)
+		})
+	}
+
+	return group, nil
+}
+
+// UpdateWorkload updates a workload by stopping, deleting, and recreating it.
+// The work runs asynchronously in the returned errgroup; only the name
+// validation happens before the call returns.
+func (d *cliManager) UpdateWorkload(_ context.Context, workloadName string, newConfig *runner.RunConfig) (*errgroup.Group, error) { //nolint:lll
+	// Validate workload name
+	if err := types.ValidateWorkloadName(workloadName); err != nil {
+		return nil, fmt.Errorf("invalid workload name '%s': %w", workloadName, err)
+	}
+
+	group := &errgroup.Group{}
+	group.Go(func() error {
+		return d.updateSingleWorkload(workloadName, newConfig)
+	})
+	return group, nil
+}
+
+// updateSingleWorkload handles the update logic for a single workload:
+// stop, delete, save the new configuration, then start it detached. The first
+// failing step aborts the update and its error is returned wrapped.
+func (d *cliManager) updateSingleWorkload(workloadName string, newConfig *runner.RunConfig) error {
+	// Create a child context with a longer timeout
+	childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout)
+	defer cancel()
+
+	logger.Infof("Starting update for workload %s", workloadName)
+
+	// Stop the existing workload
+	if err := d.stopSingleWorkload(workloadName); err != nil {
+		return fmt.Errorf("failed to stop workload: %w", err)
+	}
+	logger.Infof("Successfully stopped workload %s", workloadName)
+
+	// Delete the existing workload
+	if err := d.deleteWorkload(workloadName); err != nil {
+		return fmt.Errorf("failed to delete workload: %w", err)
+	}
+	logger.Infof("Successfully deleted workload %s", workloadName)
+
+	// Save the new workload configuration state
+	if err := newConfig.SaveState(childCtx); err != nil {
+		logger.Errorf("Failed to save workload config: %v", err)
+		return fmt.Errorf("failed to save workload config: %w", err)
+	}
+
+	// Start the new workload
+	// TODO: This currently just handles detached processes and wouldn't work for
+	// foreground CLI executions. Should be refactored to support both modes.
+	if err := d.RunWorkloadDetached(childCtx, newConfig); err != nil {
+		return fmt.Errorf("failed to start new workload: %w", err)
+	}
+
+	logger.Infof("Successfully completed update for workload %s", workloadName)
+	return nil
+}
+
+// restartSingleWorkload handles the restart logic for a single workload.
+//
+// The saved run configuration decides the path: workloads with a RemoteURL are
+// restarted as remote workloads, everything else (including workloads whose
+// state cannot be loaded) as container workloads.
+func (d *cliManager) restartSingleWorkload(name string, foreground bool) error {
+	// Create a child context with a longer timeout
+	childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout)
+	defer cancel()
+
+	// First, try to load the run configuration to check if it's a remote workload
+	runConfig, err := runner.LoadState(childCtx, name)
+	if err != nil {
+		// If we can't load the state, it might be a container workload or the workload doesn't exist
+		// Try to restart it as a container workload
+		return d.restartContainerWorkload(childCtx, name, foreground)
+	}
+
+	// Check if this is a remote workload
+	if runConfig.RemoteURL != "" {
+		return d.restartRemoteWorkload(childCtx, name, runConfig, foreground)
+	}
+
+	// This is a container-based workload
+	return d.restartContainerWorkload(childCtx, name, foreground)
+}
+
+// restartRemoteWorkload handles restarting a remote workload.
+//
+// A workload whose status is running and whose supervisor PID is still
+// recorded is left alone. A running status without a supervisor is cleaned up
+// first (proxy stopped, client configurations removed, status set to stopped).
+// The runner is then reloaded from the saved state and started.
+func (d *cliManager) restartRemoteWorkload(
+	ctx context.Context,
+	name string,
+	runConfig *runner.RunConfig,
+	foreground bool,
+) error {
+	// Get workload status using the status manager
+	workload, err := d.statuses.GetWorkload(ctx, name)
+	if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) {
+		return err
+	}
+
+	// If workload is already running, check if the supervisor process is healthy
+	if err == nil && workload.Status == rt.WorkloadStatusRunning {
+		// Check if the supervisor process is actually alive
+		if d.isSupervisorProcessAlive(ctx, runConfig.BaseName) {
+			// Workload is running and healthy - preserve old behavior (no-op)
+			logger.Infof("Remote workload %s is already running", name)
+			return nil
+		}
+
+		// Supervisor is dead/missing - we need to clean up and restart to fix the damaged state
+		logger.Infof("Remote workload %s is running but supervisor is dead, cleaning up before restart", name)
+
+		// Set status to stopping
+		if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil {
+			logger.Debugf("Failed to set workload %s status to stopping: %v", name, err)
+		}
+
+		// Stop the supervisor process (proxy) if it exists (may already be dead)
+		// This ensures we clean up any orphaned supervisor processes
+		d.stopProxyIfNeeded(ctx, name, runConfig.BaseName)
+
+		// Clean up client configurations
+		if err := removeClientConfigurations(name, false); err != nil {
+			logger.Warnf("Warning: Failed to remove client configurations: %v", err)
+		}
+
+		// Set status to stopped after cleanup is complete
+		if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopped, ""); err != nil {
+			logger.Debugf("Failed to set workload %s status to stopped: %v", name, err)
+		}
+	}
+
+	// Load runner configuration from state
+	mcpRunner, err := d.loadRunnerFromState(ctx, runConfig.BaseName)
+	if err != nil {
+		return fmt.Errorf("failed to load state for %s: %w", runConfig.BaseName, err)
+	}
+
+	// Set status to starting
+	if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStarting, ""); err != nil {
+		logger.Warnf("Failed to set workload %s status to starting: %v", name, err)
+	}
+
+	logger.Infof("Loaded configuration from state for %s", runConfig.BaseName)
+
+	// Start the remote workload using the loaded runner
+	// Use background context to avoid timeout cancellation - same reasoning as container workloads
+	return d.startWorkload(context.Background(), name, mcpRunner, foreground)
+}
+
+// restartContainerWorkload handles restarting a container-based workload.
+//
+// A workload whose status is running and whose supervisor PID is still
+// recorded is left alone. Otherwise, if the status says running or the
+// container itself is running, the supervisor and the container are stopped
+// and client configurations removed before the runner is reloaded from the
+// saved state and started again.
+//
+//nolint:gocyclo // Complexity is justified - handles multiple restart scenarios and edge cases
+func (d *cliManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error {
+	// Get container info to resolve partial names and extract proper workload name
+	var containerName string
+	var workloadName string
+
+	// On lookup failure container stays the zero value; the calls on it below
+	// are then made against that zero value.
+	container, err := d.runtime.GetWorkloadInfo(ctx, name)
+	if err == nil {
+		// If we found the container, use its actual container name for runtime operations
+		containerName = container.Name
+		// Extract the workload name (base name) from container labels for status operations
+		workloadName = labels.GetContainerBaseName(container.Labels)
+		if workloadName == "" {
+			// Fallback to the provided name if base name is not available
+			workloadName = name
+		}
+	} else {
+		// If container not found, use the provided name as both container and workload name
+		containerName = name
+		workloadName = name
+	}
+
+	// Get workload status using the status manager
+	// NOTE(review): this lookup uses the caller-supplied name while every other
+	// status operation in this function uses workloadName - confirm intended.
+	workload, err := d.statuses.GetWorkload(ctx, name)
+	if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) {
+		return err
+	}
+
+	// Check if workload is running and healthy (including supervisor process)
+	statusRunning := err == nil && workload.Status == rt.WorkloadStatusRunning
+	if statusRunning {
+		// Check if the supervisor process is actually alive
+		if d.isSupervisorProcessAlive(ctx, workloadName) {
+			// Workload is running and healthy - preserve old behavior (no-op)
+			logger.Infof("Container %s is already running", containerName)
+			return nil
+		}
+
+		// Supervisor is dead/missing - we need to clean up and restart to fix the damaged state
+		logger.Infof("Container %s is running but supervisor is dead, cleaning up before restart", containerName)
+	}
+
+	// Check if we need to stop the workload before restarting
+	// This happens when: 1) the status shows running (supervisor is dead, otherwise
+	// we would have returned above), or 2) the container is running although the
+	// status is not (inconsistent state)
+	shouldStop := statusRunning || container.IsRunning()
+
+	// If we need to stop, do it now (including cleanup of any remaining supervisor process)
+	if shouldStop {
+		logger.Infof("Stopping container %s before restart", containerName)
+		isAuxiliary := labels.IsAuxiliaryWorkload(container.Labels)
+
+		// Set status to stopping
+		if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStopping, ""); err != nil {
+			logger.Debugf("Failed to set workload %s status to stopping: %v", workloadName, err)
+		}
+
+		// Stop the supervisor process (proxy) if it exists (may already be dead)
+		// This ensures we clean up any orphaned supervisor processes
+		if !isAuxiliary {
+			d.stopProcess(ctx, workloadName)
+		}
+
+		// Now stop the container if it's running
+		if container.IsRunning() {
+			if err := d.runtime.StopWorkload(ctx, containerName); err != nil {
+				if statusErr := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusError, err.Error()); statusErr != nil {
+					logger.Warnf("Failed to set workload %s status to error: %v", workloadName, statusErr)
+				}
+				return fmt.Errorf("failed to stop container %s: %w", containerName, err)
+			}
+			logger.Infof("Container %s stopped", containerName)
+		}
+
+		// Clean up client configurations
+		if err := removeClientConfigurations(workloadName, isAuxiliary); err != nil {
+			logger.Warnf("Warning: Failed to remove client configurations: %v", err)
+		}
+
+		// Set status to stopped after cleanup is complete
+		if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStopped, ""); err != nil {
+			logger.Debugf("Failed to set workload %s status to stopped: %v", workloadName, err)
+		}
+	}
+
+	// Load runner configuration from state
+	mcpRunner, err := d.loadRunnerFromState(ctx, workloadName)
+	if err != nil {
+		return fmt.Errorf("failed to load state for %s: %w", workloadName, err)
+	}
+
+	// Set workload status to starting - use the workload name for status operations
+	if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStarting, ""); err != nil {
+		logger.Warnf("Failed to set workload %s status to starting: %v", workloadName, err)
+	}
+	logger.Infof("Loaded configuration from state for %s", workloadName)
+
+	// Start the workload with background context to avoid timeout cancellation
+	// The ctx with AsyncOperationTimeout is only for the restart setup operations,
+	// but the actual workload should run indefinitely with its own lifecycle management
+	// Use workload name for user-facing operations
+	return d.startWorkload(context.Background(), workloadName, mcpRunner, foreground)
+}
+
+// startWorkload starts the workload in either foreground or background mode.
+// On failure the status is set to error, with the failure message as its
+// context, and the start error is returned unchanged.
+func (d *cliManager) startWorkload(ctx context.Context, name string, mcpRunner *runner.Runner, foreground bool) error {
+	logger.Infof("Starting tooling server %s...", name)
+
+	var err error
+	if foreground {
+		err = d.RunWorkload(ctx, mcpRunner.Config)
+	} else {
+		err = d.RunWorkloadDetached(ctx, mcpRunner.Config)
+	}
+
+	if err != nil {
+		// If we could not start the workload, set the status to error before returning.
+		// Pass the message along: an empty context here would overwrite the reason
+		// RunWorkload has already recorded for the same failure.
+		if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil {
+			logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr)
+		}
+	}
+	return err
+}
+
+// loadRunnerFromState attempts to load a Runner from the state store.
+// Remote workloads (RemoteURL set) get no deployer; all others get the
+// manager's runtime as deployer.
+func (d *cliManager) loadRunnerFromState(ctx context.Context, baseName string) (*runner.Runner, error) {
+	// Load the run config from state
+	runConfig, err := runner.LoadState(ctx, baseName)
+	if err != nil {
+		return nil, err
+	}
+
+	if runConfig.RemoteURL != "" {
+		// For remote workloads, we don't need a deployer
+		runConfig.Deployer = nil
+	} else {
+		// Update the runtime in the loaded configuration
+		runConfig.Deployer = d.runtime
+	}
+
+	// Create a new runner with the loaded configuration
+	return runner.NewRunner(runConfig, d.statuses), nil
+}
+
+// needSecretsPassword reports whether the secrets password has to be handed
+// to a detached process: only when secrets were requested and the configured
+// provider is the encrypted one.
+func (d *cliManager) needSecretsPassword(secretOptions []string) bool {
+	// If the user did not ask for any secrets, then don't attempt to instantiate
+	// the secrets manager.
+	if len(secretOptions) == 0 {
+		return false
+	}
+	// Ignore err - if the flag is not set, it's not needed.
+	providerType, _ := d.configProvider.GetConfig().Secrets.GetProviderType()
+	return providerType == secrets.EncryptedType
+}
+
+// cleanupTempPermissionProfile cleans up temporary permission profile files for a given base name.
+// A saved state that cannot be loaded is not an error: there is nothing to clean up then.
+func (*cliManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error {
+	// Try to load the saved configuration to get the permission profile path
+	runConfig, err := runner.LoadState(ctx, baseName)
+	if err != nil {
+		// If we can't load the state, there's nothing to clean up
+		logger.Debugf("Could not load state for %s, skipping permission profile cleanup: %v", baseName, err)
+		return nil
+	}
+
+	// Clean up the temporary permission profile if it exists
+	if runConfig.PermissionProfileNameOrPath != "" {
+		if err := runner.CleanupTempPermissionProfile(runConfig.PermissionProfileNameOrPath); err != nil {
+			return fmt.Errorf("failed to cleanup temporary permission profile: %w", err)
+		}
+	}
+
+	return nil
+}
+
+// stopSingleContainerWorkload stops a single container workload: the proxy
+// process (unless the workload is auxiliary), then the container itself, then
+// the client configurations, and finally records the stopped status.
+//
+// Every step runs under one context that is detached from the caller's
+// cancellation and bounded by AsyncOperationTimeout, so a stop that has begun
+// is not left half-done when the caller goes away.
+func (d *cliManager) stopSingleContainerWorkload(ctx context.Context, workload *rt.ContainerInfo) error {
+	// WithoutCancel keeps the values carried by ctx but drops its cancellation and deadline.
+	childCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), AsyncOperationTimeout)
+	defer cancel()
+
+	name := labels.GetContainerBaseName(workload.Labels)
+	isAuxiliary := labels.IsAuxiliaryWorkload(workload.Labels)
+
+	// Stop the proxy process (skip for auxiliary workloads like inspector)
+	if isAuxiliary {
+		logger.Debugf("Skipping proxy stop for auxiliary workload %s", name)
+	} else {
+		d.stopProcess(childCtx, name)
+	}
+
+	// TODO: refactor the StopProcess function to stop dealing explicitly with PID files.
+	// Note that this is not a blocker for k8s since this code path is not called there.
+	if err := d.statuses.ResetWorkloadPID(childCtx, name); err != nil {
+		logger.Warnf("Warning: Failed to reset workload %s PID: %v", name, err)
+	}
+
+	logger.Infof("Stopping containers for %s...", name)
+	// Stop the container
+	if err := d.runtime.StopWorkload(childCtx, workload.Name); err != nil {
+		if statusErr := d.statuses.SetWorkloadStatus(childCtx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil {
+			logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr)
+		}
+		return fmt.Errorf("failed to stop container: %w", err)
+	}
+
+	if err := removeClientConfigurations(name, isAuxiliary); err != nil {
+		logger.Warnf("Warning: Failed to remove client configurations: %v", err)
+	} else {
+		logger.Infof("Client configurations for %s removed", name)
+	}
+
+	if err := d.statuses.SetWorkloadStatus(childCtx, name, rt.WorkloadStatusStopped, ""); err != nil {
+		logger.Warnf("Failed to set workload %s status to stopped: %v", name, err)
+	}
+	logger.Infof("Successfully stopped %s...", name)
+	return nil
+}
+
+// MoveToGroup moves the specified workloads from one group to another by updating their runconfig.
+func (*cliManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error {
+	for _, workloadName := range workloadNames {
+		// Reject malformed names before touching the state store.
+		if validationErr := types.ValidateWorkloadName(workloadName); validationErr != nil {
+			return fmt.Errorf("invalid workload name %s: %w", workloadName, validationErr)
+		}
+
+		// The group membership lives in the saved run configuration.
+		savedConfig, loadErr := runner.LoadState(ctx, workloadName)
+		if loadErr != nil {
+			return fmt.Errorf("failed to load runner state for workload %s: %w", workloadName, loadErr)
+		}
+
+		// Workloads that are not in the source group are left untouched.
+		if currentGroup := savedConfig.Group; currentGroup != groupFrom {
+			logger.Debugf("Workload %s is not in group %s (current group: %s), skipping",
+				workloadName, groupFrom, currentGroup)
+			continue
+		}
+
+		// Re-home the workload and persist the change.
+		savedConfig.Group = groupTo
+		if saveErr := savedConfig.SaveState(ctx); saveErr != nil {
+			return fmt.Errorf("failed to save updated configuration for workload %s: %w", workloadName, saveErr)
+		}
+
+		logger.Infof("Moved workload %s from group %s to %s", workloadName, groupFrom, groupTo)
+	}
+
+	return nil
+}
+
+// ListWorkloadsInGroup returns the names of all workloads, stopped ones
+// included, whose group equals groupName.
+func (d *cliManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) {
+	// Ask for every workload (listAll=true) so stopped ones are considered too.
+	allWorkloads, err := d.ListWorkloads(ctx, true)
+	if err != nil {
+		return nil, fmt.Errorf("failed to list workloads: %w", err)
+	}
+
+	// Keep only the names of the members of the requested group.
+	var memberNames []string
+	for _, candidate := range allWorkloads {
+		if candidate.Group != groupName {
+			continue
+		}
+		memberNames = append(memberNames, candidate.Name)
+	}
+
+	return memberNames, nil
+}
+
+// getRemoteWorkloadsFromState retrieves remote servers from the state store
+func (d *cliManager) getRemoteWorkloadsFromState(
+	ctx context.Context,
+	listAll bool,
+	labelFilters []string,
+) ([]core.Workload, error) {
+	//
Create a state store + store, err := state.NewRunConfigStore(state.DefaultAppName) + if err != nil { + return nil, fmt.Errorf("failed to create state store: %w", err) + } + + // List all configurations + configNames, err := store.List(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list configurations: %w", err) + } + + // Parse the filters into a format we can use for matching + parsedFilters, err := types.ParseLabelFilters(labelFilters) + if err != nil { + return nil, fmt.Errorf("failed to parse label filters: %v", err) + } + + var remoteWorkloads []core.Workload + + for _, name := range configNames { + // Load the run configuration + runConfig, err := runner.LoadState(ctx, name) + if err != nil { + logger.Warnf("failed to load state for %s: %v", name, err) + continue + } + + // Only include remote servers (those with RemoteURL set) + if runConfig.RemoteURL == "" { + continue + } + + // Check the status from the status file + workloadStatus, err := d.statuses.GetWorkload(ctx, name) + if err != nil { + if errors.Is(err, rt.ErrWorkloadNotFound) { + // If status not found, assume stopped + workloadStatus = core.Workload{ + Status: rt.WorkloadStatusStopped, + } + } else { + logger.Warnf("failed to get workload status for %s: %v", name, err) + continue + } + } + + // If not listing all, only include running workloads + if !listAll && workloadStatus.Status != rt.WorkloadStatusRunning { + continue + } + + // Map to core.Workload + workload := core.Workload{ + Name: name, + Package: runConfig.RemoteURL, + URL: runConfig.RemoteURL, + ToolType: "mcp", + TransportType: runConfig.Transport, + ProxyMode: runConfig.ProxyMode.String(), + Status: workloadStatus.Status, + StatusContext: workloadStatus.StatusContext, + CreatedAt: workloadStatus.CreatedAt, + Port: int(runConfig.Port), + Labels: runConfig.ContainerLabels, + Group: runConfig.Group, + ToolsFilter: runConfig.ToolsFilter, + Remote: true, + } + + // If label filters are provided, check if the workload matches 
them. + if types.MatchesLabelFilters(workload.Labels, parsedFilters) { + remoteWorkloads = append(remoteWorkloads, workload) + } + } + + return remoteWorkloads, nil +} diff --git a/pkg/workloads/cli_manager_test.go b/pkg/workloads/cli_manager_test.go new file mode 100644 index 000000000..e12fd2a87 --- /dev/null +++ b/pkg/workloads/cli_manager_test.go @@ -0,0 +1,1616 @@ +package workloads + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/sync/errgroup" + + "github.com/stacklok/toolhive/pkg/config" + configMocks "github.com/stacklok/toolhive/pkg/config/mocks" + "github.com/stacklok/toolhive/pkg/container/runtime" + runtimeMocks "github.com/stacklok/toolhive/pkg/container/runtime/mocks" + "github.com/stacklok/toolhive/pkg/core" + "github.com/stacklok/toolhive/pkg/runner" + statusMocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" +) + +func TestCLIManager_ListWorkloadsInGroup(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + groupName string + mockWorkloads []core.Workload + expectedNames []string + expectError bool + setupStatusMgr func(*statusMocks.MockStatusManager) + }{ + { + name: "non existent group returns empty list", + groupName: "non-group", + mockWorkloads: []core.Workload{ + {Name: "workload1", Group: "other-group"}, + {Name: "workload2", Group: "another-group"}, + }, + expectedNames: []string{}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ + {Name: "workload1", Group: "other-group"}, + {Name: "workload2", Group: "another-group"}, + }, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + { + name: "multiple workloads in group", + 
groupName: "test-group", + mockWorkloads: []core.Workload{ + {Name: "workload1", Group: "test-group"}, + {Name: "workload2", Group: "other-group"}, + {Name: "workload3", Group: "test-group"}, + {Name: "workload4", Group: "test-group"}, + }, + expectedNames: []string{"workload1", "workload3", "workload4"}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ + {Name: "workload1", Group: "test-group"}, + {Name: "workload2", Group: "other-group"}, + {Name: "workload3", Group: "test-group"}, + {Name: "workload4", Group: "test-group"}, + }, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + { + name: "workloads with empty group names", + groupName: "", + mockWorkloads: []core.Workload{ + {Name: "workload1", Group: ""}, + {Name: "workload2", Group: "test-group"}, + {Name: "workload3", Group: ""}, + }, + expectedNames: []string{"workload1", "workload3"}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ + {Name: "workload1", Group: ""}, + {Name: "workload2", Group: "test-group"}, + {Name: "workload3", Group: ""}, + }, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + { + name: "includes stopped workloads", + groupName: "test-group", + mockWorkloads: []core.Workload{ + {Name: "running-workload", Group: "test-group", Status: runtime.WorkloadStatusRunning}, + {Name: "stopped-workload", Group: "test-group", Status: runtime.WorkloadStatusStopped}, + {Name: "other-group-workload", Group: "other-group", Status: runtime.WorkloadStatusRunning}, + }, + expectedNames: 
[]string{"running-workload", "stopped-workload"}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ + {Name: "running-workload", Group: "test-group", Status: runtime.WorkloadStatusRunning}, + {Name: "stopped-workload", Group: "test-group", Status: runtime.WorkloadStatusStopped}, + {Name: "other-group-workload", Group: "other-group", Status: runtime.WorkloadStatusRunning}, + }, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + { + name: "error from ListWorkloads propagated", + groupName: "test-group", + expectedNames: nil, + expectError: true, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return(nil, assert.AnError) + }, + }, + { + name: "no workloads", + groupName: "test-group", + mockWorkloads: []core.Workload{}, + expectedNames: []string{}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{}, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupStatusMgr(mockStatusMgr) + + manager := &cliManager{ + runtime: nil, // Not needed for this test + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.ListWorkloadsInGroup(ctx, tt.groupName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to list workloads") + return + } + 
+ require.NoError(t, err) + assert.ElementsMatch(t, tt.expectedNames, result) + }) + } +} + +func TestCLIManager_DoesWorkloadExist(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + setupMocks func(*statusMocks.MockStatusManager) + expected bool + expectError bool + }{ + { + name: "workload exists and running", + workloadName: "test-workload", + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + }, + expected: true, + expectError: false, + }, + { + name: "workload exists but in error state", + workloadName: "error-workload", + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "error-workload").Return(core.Workload{ + Name: "error-workload", + Status: runtime.WorkloadStatusError, + }, nil) + }, + expected: false, + expectError: false, + }, + { + name: "workload not found", + workloadName: "missing-workload", + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "missing-workload").Return(core.Workload{}, runtime.ErrWorkloadNotFound) + }, + expected: false, + expectError: false, + }, + { + name: "error getting workload", + workloadName: "problematic-workload", + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "problematic-workload").Return(core.Workload{}, errors.New("database error")) + }, + expected: false, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockStatusMgr) + + manager := &cliManager{ + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.DoesWorkloadExist(ctx, tt.workloadName) + + if 
tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to check if workload exists") + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestCLIManager_GetWorkload(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + expectedWorkload := core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + } + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + mockStatusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(expectedWorkload, nil) + + manager := &cliManager{ + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.GetWorkload(ctx, "test-workload") + + require.NoError(t, err) + assert.Equal(t, expectedWorkload, result) +} + +func TestCLIManager_GetLogs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + follow bool + setupMocks func(*runtimeMocks.MockRuntime) + expectedLogs string + expectError bool + errorMsg string + }{ + { + name: "successful log retrieval", + workloadName: "test-workload", + follow: false, + setupMocks: func(rt *runtimeMocks.MockRuntime) { + rt.EXPECT().GetWorkloadLogs(gomock.Any(), "test-workload", false).Return("test log content", nil) + }, + expectedLogs: "test log content", + expectError: false, + }, + { + name: "workload not found", + workloadName: "missing-workload", + follow: false, + setupMocks: func(rt *runtimeMocks.MockRuntime) { + rt.EXPECT().GetWorkloadLogs(gomock.Any(), "missing-workload", false).Return("", runtime.ErrWorkloadNotFound) + }, + expectedLogs: "", + expectError: true, + errorMsg: "workload not found", + }, + { + name: "runtime error", + workloadName: "error-workload", + follow: true, + setupMocks: func(rt *runtimeMocks.MockRuntime) { + rt.EXPECT().GetWorkloadLogs(gomock.Any(), "error-workload", true).Return("", errors.New("runtime failure")) + }, + expectedLogs: "", + expectError: true, + 
errorMsg: "failed to get container logs", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + tt.setupMocks(mockRuntime) + + manager := &cliManager{ + runtime: mockRuntime, + } + + ctx := context.Background() + logs, err := manager.GetLogs(ctx, tt.workloadName, tt.follow) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedLogs, logs) + } + }) + } +} + +func TestCLIManager_StopWorkloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadNames []string + expectError bool + errorMsg string + }{ + { + name: "invalid workload name with path traversal", + workloadNames: []string{"../etc/passwd"}, + expectError: true, + errorMsg: "path traversal", + }, + { + name: "invalid workload name with slash", + workloadNames: []string{"workload/name"}, + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "empty workload name list", + workloadNames: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + manager := &cliManager{} + + ctx := context.Background() + group, err := manager.StopWorkloads(ctx, tt.workloadNames) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, group) + } else { + require.NoError(t, err) + assert.NotNil(t, group) + assert.IsType(t, &errgroup.Group{}, group) + } + }) + } +} + +func TestCLIManager_DeleteWorkloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadNames []string + expectError bool + errorMsg string + }{ + { + name: "invalid workload name", + workloadNames: []string{"../../../etc/passwd"}, + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "mixed valid 
and invalid names", + workloadNames: []string{"valid-name", "invalid../name"}, + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "empty workload name list", + workloadNames: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + manager := &cliManager{} + + ctx := context.Background() + group, err := manager.DeleteWorkloads(ctx, tt.workloadNames) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, group) + } else { + require.NoError(t, err) + assert.NotNil(t, group) + assert.IsType(t, &errgroup.Group{}, group) + } + }) + } +} + +func TestCLIManager_RestartWorkloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadNames []string + foreground bool + expectError bool + errorMsg string + }{ + { + name: "invalid workload name", + workloadNames: []string{"invalid/name"}, + foreground: false, + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "empty workload name list", + workloadNames: []string{}, + foreground: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + manager := &cliManager{} + + ctx := context.Background() + group, err := manager.RestartWorkloads(ctx, tt.workloadNames, tt.foreground) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, group) + } else { + require.NoError(t, err) + assert.NotNil(t, group) + assert.IsType(t, &errgroup.Group{}, group) + } + }) + } +} + +func TestCLIManager_restartRemoteWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + runConfig *runner.RunConfig + foreground bool + setupMocks func(*statusMocks.MockStatusManager) + expectError bool + errorMsg string + }{ + { + name: "remote workload already running with healthy supervisor", + workloadName: 
"remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-base", + RemoteURL: "http://example.com", + }, + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + // Check if supervisor is alive - return valid PID (supervisor is healthy) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(12345, nil) + }, + // With healthy supervisor, restart should return early (no-op) + expectError: false, + }, + { + name: "remote workload already running with dead supervisor", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-base", + RemoteURL: "http://example.com", + }, + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + // Check if supervisor is alive - return error (supervisor is dead) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(0, errors.New("no PID found")) + // With dead supervisor, restart proceeds with cleanup and restart + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(nil) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(0, errors.New("no PID found")) + // Allow any subsequent status updates + sm.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + sm.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + }, + // Restart now proceeds to load state which fails in tests (can't mock runner.LoadState easily) + expectError: true, + errorMsg: "failed to load state", + }, + { + name: "status manager error", + workloadName: "remote-workload", + runConfig: 
&runner.RunConfig{ + BaseName: "remote-base", + RemoteURL: "http://example.com", + }, + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{}, errors.New("status manager error")) + }, + expectError: true, + errorMsg: "status manager error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(statusMgr) + + manager := &cliManager{ + statuses: statusMgr, + } + + err := manager.restartRemoteWorkload(context.Background(), tt.workloadName, tt.runConfig, tt.foreground) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCLIManager_restartContainerWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + foreground bool + setupMocks func(*statusMocks.MockStatusManager, *runtimeMocks.MockRuntime) + expectError bool + errorMsg string + }{ + { + name: "container workload already running with healthy supervisor", + workloadName: "container-workload", + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { + // Mock container info + rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ + Name: "container-workload", + State: runtime.WorkloadStatusRunning, + Labels: map[string]string{ + "toolhive.base-name": "container-workload", + }, + }, nil) + sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{ + Name: "container-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + // Check if supervisor is alive - return valid PID (supervisor is healthy) + sm.EXPECT().GetWorkloadPID(gomock.Any(), 
"container-workload").Return(12345, nil) + }, + // With healthy supervisor, restart should return early (no-op) + expectError: false, + }, + { + name: "container workload already running with dead supervisor", + workloadName: "container-workload", + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { + // Mock container info + rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ + Name: "container-workload", + State: runtime.WorkloadStatusRunning, + Labels: map[string]string{ + "toolhive.base-name": "container-workload", + }, + }, nil) + sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{ + Name: "container-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + // Check if supervisor is alive - return error (supervisor is dead) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(0, errors.New("no PID found")) + // With dead supervisor, restart proceeds with cleanup and restart + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "container-workload", runtime.WorkloadStatusStopping, "").Return(nil) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(0, errors.New("no PID found")) + rm.EXPECT().StopWorkload(gomock.Any(), "container-workload").Return(nil) + // Allow any subsequent status updates + sm.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + sm.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + }, + // Restart now proceeds to load state which fails in tests (can't mock runner.LoadState easily) + expectError: true, + errorMsg: "failed to load state", + }, + { + name: "status manager error", + workloadName: "container-workload", + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { + // Mock container info + rm.EXPECT().GetWorkloadInfo(gomock.Any(), 
"container-workload").Return(runtime.ContainerInfo{ + Name: "container-workload", + State: "running", + Labels: map[string]string{ + "toolhive.base-name": "container-workload", + }, + }, nil) + sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{}, errors.New("status manager error")) + }, + expectError: true, + errorMsg: "status manager error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) + + tt.setupMocks(statusMgr, runtimeMgr) + + manager := &cliManager{ + statuses: statusMgr, + runtime: runtimeMgr, + } + + err := manager.restartContainerWorkload(context.Background(), tt.workloadName, tt.foreground) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestCLIManager_restartLogicConsistency tests restart behavior with healthy vs dead supervisor +func TestCLIManager_restartLogicConsistency(t *testing.T) { + t.Parallel() + + t.Run("remote_workload_healthy_supervisor_no_restart", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + + statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + + // Check if supervisor is alive - return valid PID (healthy) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(12345, nil) + + manager := &cliManager{ + statuses: statusMgr, + } + + runConfig := &runner.RunConfig{ + BaseName: "test-base", + RemoteURL: "http://example.com", + } + + err := manager.restartRemoteWorkload(context.Background(), "test-workload", runConfig, false) + + // With healthy 
supervisor, restart should return successfully without doing anything + require.NoError(t, err) + }) + + t.Run("remote_workload_dead_supervisor_calls_stop", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + + statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + + // Check if supervisor is alive - return error (dead supervisor) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(0, errors.New("no PID found")) + + // When supervisor is dead, expect stop logic to be called + statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(0, errors.New("no PID found")) + + // Allow any subsequent status updates - we don't care about the exact sequence + statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + + manager := &cliManager{ + statuses: statusMgr, + } + + runConfig := &runner.RunConfig{ + BaseName: "test-base", + RemoteURL: "http://example.com", + } + + _ = manager.restartRemoteWorkload(context.Background(), "test-workload", runConfig, false) + + // The important part is that the stop methods were called (verified by mock expectations) + // We don't care if the restart ultimately succeeds or fails + }) + + t.Run("container_workload_healthy_supervisor_no_restart", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) + + containerInfo := runtime.ContainerInfo{ + Name: "test-workload", + State: 
runtime.WorkloadStatusRunning, + Labels: map[string]string{ + "toolhive.base-name": "test-workload", + }, + } + runtimeMgr.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(containerInfo, nil) + + statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + + // Check if supervisor is alive - return valid PID (healthy) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(12345, nil) + + manager := &cliManager{ + statuses: statusMgr, + runtime: runtimeMgr, + } + + err := manager.restartContainerWorkload(context.Background(), "test-workload", false) + + // With healthy supervisor, restart should return successfully without doing anything + require.NoError(t, err) + }) + + t.Run("container_workload_dead_supervisor_calls_stop", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) + + containerInfo := runtime.ContainerInfo{ + Name: "test-workload", + State: runtime.WorkloadStatusRunning, + Labels: map[string]string{ + "toolhive.base-name": "test-workload", + }, + } + runtimeMgr.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(containerInfo, nil) + + statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + + // Check if supervisor is alive - return error (dead supervisor) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(0, errors.New("no PID found")) + + // When supervisor is dead, expect stop logic to be called + statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(0, errors.New("no PID found")) + 
runtimeMgr.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) + + // Allow any subsequent status updates (starting, error, etc.) - we don't care about the exact sequence + statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + + manager := &cliManager{ + statuses: statusMgr, + runtime: runtimeMgr, + } + + _ = manager.restartContainerWorkload(context.Background(), "test-workload", false) + + // The important part is that the stop methods were called (verified by mock expectations) + // We don't care if the restart ultimately succeeds or fails + }) +} + +func TestCLIManager_RunWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + runConfig *runner.RunConfig + setupMocks func(*statusMocks.MockStatusManager) + expectError bool + errorMsg string + }{ + { + name: "successful run - status creation", + runConfig: &runner.RunConfig{ + BaseName: "test-workload", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + // Expect starting status first, then error status when the runner fails + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStarting, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, gomock.Any()).Return(nil) + }, + expectError: true, // The runner will fail without proper setup + }, + { + name: "status creation failure", + runConfig: &runner.RunConfig{ + BaseName: "failing-workload", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", runtime.WorkloadStatusStarting, "").Return(errors.New("status creation failed")) + }, + expectError: true, + errorMsg: "failed to create workload status", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := 
gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockStatusMgr) + + manager := &cliManager{ + statuses: mockStatusMgr, + } + + ctx := context.Background() + err := manager.RunWorkload(ctx, tt.runConfig) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCLIManager_validateSecretParameters(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + runConfig *runner.RunConfig + setupMocks func(*configMocks.MockProvider) + expectError bool + errorMsg string + }{ + { + name: "no secrets - should pass", + runConfig: &runner.RunConfig{ + Secrets: []string{}, + }, + setupMocks: func(*configMocks.MockProvider) {}, // No expectations + expectError: false, + }, + { + name: "config error", + runConfig: &runner.RunConfig{ + Secrets: []string{"secret1"}, + }, + setupMocks: func(cp *configMocks.MockProvider) { + mockConfig := &config.Config{} + cp.EXPECT().GetConfig().Return(mockConfig) + }, + expectError: true, + errorMsg: "error determining secrets provider type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConfigProvider := configMocks.NewMockProvider(ctrl) + tt.setupMocks(mockConfigProvider) + + manager := &cliManager{ + configProvider: mockConfigProvider, + } + + ctx := context.Background() + err := manager.validateSecretParameters(ctx, tt.runConfig) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCLIManager_getWorkloadContainer(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) + expected *runtime.ContainerInfo + 
expectError bool + errorMsg string + }{ + { + name: "successful retrieval", + workloadName: "test-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { + expectedContainer := runtime.ContainerInfo{ + Name: "test-workload", + State: runtime.WorkloadStatusRunning, + } + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(expectedContainer, nil) + }, + expected: &runtime.ContainerInfo{ + Name: "test-workload", + State: runtime.WorkloadStatusRunning, + }, + expectError: false, + }, + { + name: "workload not found", + workloadName: "missing-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "missing-workload").Return(runtime.ContainerInfo{}, runtime.ErrWorkloadNotFound) + }, + expected: nil, + expectError: false, // getWorkloadContainer returns nil for not found, not error + }, + { + name: "runtime error", + workloadName: "error-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "error-workload").Return(runtime.ContainerInfo{}, errors.New("runtime failure")) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "error-workload", runtime.WorkloadStatusError, "runtime failure").Return(nil) + }, + expected: nil, + expectError: true, + errorMsg: "failed to find workload", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockRuntime, mockStatusMgr) + + manager := &cliManager{ + runtime: mockRuntime, + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.getWorkloadContainer(ctx, tt.workloadName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } 
else { + require.NoError(t, err) + if tt.expected == nil { + assert.Nil(t, result) + } else { + assert.Equal(t, tt.expected, result) + } + } + }) + } +} + +func TestCLIManager_removeContainer(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) + expectError bool + errorMsg string + }{ + { + name: "successful removal", + workloadName: "test-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { + rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(nil) + }, + expectError: false, + }, + { + name: "removal failure", + workloadName: "failing-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + rt.EXPECT().RemoveWorkload(gomock.Any(), "failing-workload").Return(errors.New("removal failed")) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", runtime.WorkloadStatusError, "removal failed").Return(nil) + }, + expectError: true, + errorMsg: "failed to remove container", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockRuntime, mockStatusMgr) + + manager := &cliManager{ + runtime: mockRuntime, + statuses: mockStatusMgr, + } + + ctx := context.Background() + err := manager.removeContainer(ctx, tt.workloadName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCLIManager_needSecretsPassword(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + secretOptions []string + setupMocks func(*configMocks.MockProvider) + expected bool + }{ + { + name: "no secrets", + secretOptions: []string{}, + setupMocks: 
func(*configMocks.MockProvider) {}, // No expectations + expected: false, + }, + { + name: "has secrets but config access fails", + secretOptions: []string{"secret1"}, + setupMocks: func(cp *configMocks.MockProvider) { + mockConfig := &config.Config{} + cp.EXPECT().GetConfig().Return(mockConfig) + }, + expected: false, // Returns false when provider type detection fails + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConfigProvider := configMocks.NewMockProvider(ctrl) + tt.setupMocks(mockConfigProvider) + + manager := &cliManager{ + configProvider: mockConfigProvider, + } + + result := manager.needSecretsPassword(tt.secretOptions) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCLIManager_RunWorkloadDetached(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + runConfig *runner.RunConfig + setupMocks func(*statusMocks.MockStatusManager, *configMocks.MockProvider) + expectError bool + errorMsg string + }{ + { + name: "validation failure should not reach PID management", + runConfig: &runner.RunConfig{ + BaseName: "test-workload", + Secrets: []string{"invalid-secret"}, + }, + setupMocks: func(_ *statusMocks.MockStatusManager, cp *configMocks.MockProvider) { + // Mock config provider to cause validation failure + mockConfig := &config.Config{} + cp.EXPECT().GetConfig().Return(mockConfig) + // No SetWorkloadPID expectation since validation should fail first + }, + expectError: true, + errorMsg: "failed to validate workload parameters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + mockConfigProvider := configMocks.NewMockProvider(ctrl) + tt.setupMocks(mockStatusMgr, mockConfigProvider) + + manager := &cliManager{ + statuses: mockStatusMgr, + configProvider: 
mockConfigProvider, + } + + ctx := context.Background() + err := manager.RunWorkloadDetached(ctx, tt.runConfig) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestCLIManager_RunWorkloadDetached_PIDManagement tests that PID management +// happens in the later stages of RunWorkloadDetached when the process actually starts. +// This is tested indirectly by verifying the behavior exists in the code flow. +func TestCLIManager_RunWorkloadDetached_PIDManagement(t *testing.T) { + t.Parallel() + + // This test documents the expected behavior: + // 1. RunWorkloadDetached calls SetWorkloadPID after starting the detached process + // 2. The PID management happens after validation and process creation + // 3. SetWorkloadPID failures are logged as warnings but don't fail the operation + + // Since RunWorkloadDetached involves spawning actual processes and complex setup, + // we verify the PID management integration exists by checking the method signature + // and code structure rather than running the full integration. 
+ + manager := &cliManager{} + assert.NotNil(t, manager, "cliManager should be instantiable") + + // Verify the method exists with the correct signature + var runWorkloadDetachedFunc interface{} = manager.RunWorkloadDetached + assert.NotNil(t, runWorkloadDetachedFunc, "RunWorkloadDetached method should exist") +} + +func TestAsyncOperationTimeout(t *testing.T) { + t.Parallel() + + // Test that the timeout constant is properly defined + assert.Equal(t, 5*time.Minute, AsyncOperationTimeout) +} + +func TestErrWorkloadNotRunning(t *testing.T) { + t.Parallel() + + // Test that the error is properly defined + assert.Error(t, ErrWorkloadNotRunning) + assert.Contains(t, ErrWorkloadNotRunning.Error(), "workload not running") +} + +func TestCLIManager_ListWorkloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + listAll bool + labelFilters []string + setupMocks func(*statusMocks.MockStatusManager) + expected []core.Workload + expectError bool + errorMsg string + }{ + { + name: "successful listing without filters", + listAll: true, + labelFilters: []string{}, + setupMocks: func(sm *statusMocks.MockStatusManager) { + workloads := []core.Workload{ + {Name: "workload1", Status: runtime.WorkloadStatusRunning}, + {Name: "workload2", Status: runtime.WorkloadStatusStopped}, + } + sm.EXPECT().ListWorkloads(gomock.Any(), true, []string{}).Return(workloads, nil) + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + expected: []core.Workload{ + {Name: "workload1", Status: runtime.WorkloadStatusRunning}, + {Name: "workload2", Status: runtime.WorkloadStatusStopped}, + }, + expectError: false, + }, + { + name: "error from status manager", + listAll: false, + labelFilters: []string{"env=prod"}, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), false, []string{"env=prod"}).Return(nil, errors.New("database 
error")) + }, + expected: nil, + expectError: true, + errorMsg: "database error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockStatusMgr) + + manager := &cliManager{ + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.ListWorkloads(ctx, tt.listAll, tt.labelFilters...) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + // We expect this to succeed but might include remote workloads + // Since getRemoteWorkloadsFromState will likely fail in unit tests, + // we mainly verify the container workloads are returned + require.NoError(t, err) + assert.GreaterOrEqual(t, len(result), len(tt.expected)) + // Verify at least our expected container workloads are present + for _, expectedWorkload := range tt.expected { + found := false + for _, actualWorkload := range result { + if actualWorkload.Name == expectedWorkload.Name { + found = true + break + } + } + assert.True(t, found, fmt.Sprintf("Expected workload %s not found in result", expectedWorkload.Name)) + } + } + }) + } +} + +func TestCLIManager_UpdateWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + expectError bool + errorMsg string + setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) + }{ + { + name: "invalid workload name with slash", + workloadName: "invalid/name", + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "invalid workload name with backslash", + workloadName: "invalid\\name", + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "invalid workload name with path traversal", + workloadName: "../invalid", + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "valid workload name returns errgroup 
immediately", + workloadName: "valid-workload", + expectError: false, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // Mock calls that will happen in the background goroutine + // We don't care about the success/failure, just that it doesn't panic + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "valid-workload"). + Return(runtime.ContainerInfo{}, errors.New("not found")).AnyTimes() + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "valid-workload", gomock.Any(), gomock.Any()). + Return(nil).AnyTimes() + }, + }, + { + name: "UpdateWorkload returns errgroup even if async operation will fail", + workloadName: "failing-workload", + expectError: false, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // The async operation will fail, but UpdateWorkload itself should succeed + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "failing-workload"). + Return(runtime.ContainerInfo{}, errors.New("container lookup failed")).AnyTimes() + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", gomock.Any(), gomock.Any()). 
+ Return(nil).AnyTimes() + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + mockStatusManager := statusMocks.NewMockStatusManager(ctrl) + mockConfigProvider := configMocks.NewMockProvider(ctrl) + + if tt.setupMocks != nil { + tt.setupMocks(mockRuntime, mockStatusManager) + } + + manager := &cliManager{ + runtime: mockRuntime, + statuses: mockStatusManager, + configProvider: mockConfigProvider, + } + + // Create a dummy RunConfig for testing + runConfig := &runner.RunConfig{ + ContainerName: tt.workloadName, + BaseName: tt.workloadName, + } + + ctx := context.Background() + group, err := manager.UpdateWorkload(ctx, tt.workloadName, runConfig) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + assert.Nil(t, group) + } else { + assert.NoError(t, err) + assert.NotNil(t, group) + // For valid cases, we get an errgroup but don't wait for completion + // The async operations inside are tested separately + } + }) + } +} + +func TestCLIManager_updateSingleWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + runConfig *runner.RunConfig + setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) + expectError bool + errorMsg string + }{ + { + name: "stop operation fails", + workloadName: "test-workload", + runConfig: &runner.RunConfig{ + ContainerName: "test-workload", + BaseName: "test-workload", + Group: "default", + }, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // Mock the stop operation - return error for GetWorkloadInfo + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
+ Return(runtime.ContainerInfo{}, errors.New("container lookup failed")).AnyTimes() + // Still expect status updates to be attempted + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil).AnyTimes() + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, "").Return(nil).AnyTimes() + }, + expectError: true, + errorMsg: "failed to stop workload", + }, + { + name: "successful stop and delete operations complete correctly", + workloadName: "test-workload", + runConfig: &runner.RunConfig{ + ContainerName: "test-workload", + BaseName: "test-workload", + Group: "default", + }, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // Mock stop operation - workload exists and can be stopped + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). + Return(runtime.ContainerInfo{ + Name: "test-workload", + State: "running", + Labels: map[string]string{"toolhive-basename": "test-workload"}, + }, nil) + // Mock GetWorkloadPID call from stopProcess + sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil) + rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) + sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil) + + // Mock delete operation - workload exists and can be deleted + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
+ Return(runtime.ContainerInfo{Name: "test-workload"}, nil) + rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(nil) + + // Mock status updates for stop and delete phases + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopped, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusRemoving, "").Return(nil) + sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), "test-workload").Return(nil) + + // Mock RunWorkloadDetached calls - expect the ones that will be called + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStarting, "").Return(nil) + sm.EXPECT().SetWorkloadPID(gomock.Any(), "test-workload", gomock.Any()).Return(nil) + }, + expectError: false, // Test passes - update process completes successfully + }, + { + name: "delete operation fails after successful stop", + workloadName: "test-workload", + runConfig: &runner.RunConfig{ + ContainerName: "test-workload", + BaseName: "test-workload", + Group: "default", + }, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // Mock successful stop + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). + Return(runtime.ContainerInfo{ + Name: "test-workload", + State: "running", + Labels: map[string]string{"toolhive-basename": "test-workload"}, + }, nil) + // Mock GetWorkloadPID call from stopProcess + sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil) + rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) + sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil) + + // Mock failed delete + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
+ Return(runtime.ContainerInfo{Name: "test-workload"}, nil) + rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(errors.New("delete failed")) + + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopped, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusRemoving, "").Return(nil) + // RemoveWorkload fails, so error status is set + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, "delete failed").Return(nil) + }, + expectError: true, + errorMsg: "failed to delete workload", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + mockStatusManager := statusMocks.NewMockStatusManager(ctrl) + mockConfigProvider := configMocks.NewMockProvider(ctrl) + + if tt.setupMocks != nil { + tt.setupMocks(mockRuntime, mockStatusManager) + } + + manager := &cliManager{ + runtime: mockRuntime, + statuses: mockStatusManager, + configProvider: mockConfigProvider, + } + + err := manager.updateSingleWorkload(tt.workloadName, tt.runConfig) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/workloads/k8s_manager.go b/pkg/workloads/k8s_manager.go new file mode 100644 index 000000000..0c880c755 --- /dev/null +++ b/pkg/workloads/k8s_manager.go @@ -0,0 +1,351 @@ +// Package workloads provides a Kubernetes-based implementation of the Manager interface. +// This file contains the Kubernetes implementation for operator environments. 
package workloads

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
	"k8s.io/apimachinery/pkg/api/errors"
	"k8s.io/apimachinery/pkg/labels"
	"k8s.io/apimachinery/pkg/selection"
	"k8s.io/apimachinery/pkg/types"
	"sigs.k8s.io/controller-runtime/pkg/client"

	mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
	rt "github.com/stacklok/toolhive/pkg/container/runtime"
	"github.com/stacklok/toolhive/pkg/core"
	"github.com/stacklok/toolhive/pkg/logger"
	"github.com/stacklok/toolhive/pkg/runner"
	"github.com/stacklok/toolhive/pkg/transport"
	transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
	workloadtypes "github.com/stacklok/toolhive/pkg/workloads/types"
)

// k8sManager implements the Manager interface for Kubernetes environments.
// In Kubernetes, the operator manages workload lifecycle via MCPServer CRDs,
// so the lifecycle methods on this type (run, stop, delete, restart, update)
// do not act on workloads: they either log a warning or return an error.
// The remaining methods read MCPServer resources through k8sClient; the only
// write performed is MoveToGroup, which updates Spec.GroupRef.
type k8sManager struct {
	// k8sClient is the controller-runtime client used for every Get, List and
	// Update call issued against MCPServer resources.
	k8sClient client.Client
	// namespace scopes every lookup and list performed by the manager.
	namespace string
}

// NewK8SManager creates a new Kubernetes-based workload manager.
+func NewK8SManager(k8sClient client.Client, namespace string) (Manager, error) { + return &k8sManager{ + k8sClient: k8sClient, + namespace: namespace, + }, nil +} + +func (k *k8sManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { + mcpServer := &mcpv1alpha1.MCPServer{} + key := types.NamespacedName{Name: workloadName, Namespace: k.namespace} + if err := k.k8sClient.Get(ctx, key, mcpServer); err != nil { + if errors.IsNotFound(err) { + return core.Workload{}, fmt.Errorf("%w: %s", rt.ErrWorkloadNotFound, workloadName) + } + return core.Workload{}, fmt.Errorf("failed to get MCPServer: %w", err) + } + + return k.mcpServerToWorkload(mcpServer) +} + +func (k *k8sManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { + mcpServer := &mcpv1alpha1.MCPServer{} + key := types.NamespacedName{Name: workloadName, Namespace: k.namespace} + if err := k.k8sClient.Get(ctx, key, mcpServer); err != nil { + if errors.IsNotFound(err) { + return false, nil + } + return false, fmt.Errorf("failed to check if workload exists: %w", err) + } + return true, nil +} + +func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { + mcpServerList := &mcpv1alpha1.MCPServerList{} + listOpts := []client.ListOption{ + client.InNamespace(k.namespace), + } + + // Parse label filters if provided + if len(labelFilters) > 0 { + parsedFilters, err := workloadtypes.ParseLabelFilters(labelFilters) + if err != nil { + return nil, fmt.Errorf("failed to parse label filters: %w", err) + } + + // Build label selector from filters (equality matching) + labelSelector := labels.NewSelector() + for key, value := range parsedFilters { + requirement, err := labels.NewRequirement(key, selection.Equals, []string{value}) + if err != nil { + return nil, fmt.Errorf("failed to create label requirement: %w", err) + } + labelSelector = labelSelector.Add(*requirement) + } + listOpts = append(listOpts, 
client.MatchingLabelsSelector{Selector: labelSelector}) + } + + if err := k.k8sClient.List(ctx, mcpServerList, listOpts...); err != nil { + return nil, fmt.Errorf("failed to list MCPServers: %w", err) + } + + var workloads []core.Workload + for i := range mcpServerList.Items { + mcpServer := &mcpServerList.Items[i] + + // Filter by status if listAll is false + if !listAll { + phase := mcpServer.Status.Phase + if phase != mcpv1alpha1.MCPServerPhaseRunning { + continue + } + } + + workload, err := k.mcpServerToWorkload(mcpServer) + if err != nil { + logger.Warnf("Failed to convert MCPServer %s to workload: %v", mcpServer.Name, err) + continue + } + + workloads = append(workloads, workload) + } + + return workloads, nil +} + +// StopWorkloads is a no-op in Kubernetes mode. +// The operator manages workload lifecycle via MCPServer CRDs. +func (*k8sManager) StopWorkloads(_ context.Context, _ []string) (*errgroup.Group, error) { + logger.Warnf("StopWorkloads is not supported in Kubernetes mode. Use kubectl to manage MCPServer CRDs.") + group := &errgroup.Group{} + // Return empty group - no operations to perform + return group, nil +} + +// RunWorkload is a no-op in Kubernetes mode. +// Workloads are created via MCPServer CRDs managed by the operator. +func (*k8sManager) RunWorkload(_ context.Context, _ *runner.RunConfig) error { + return fmt.Errorf("RunWorkload is not supported in Kubernetes mode. Create MCPServer CRD instead") +} + +// RunWorkloadDetached is a no-op in Kubernetes mode. +// Workloads are created via MCPServer CRDs managed by the operator. +func (*k8sManager) RunWorkloadDetached(_ context.Context, _ *runner.RunConfig) error { + return fmt.Errorf("RunWorkloadDetached is not supported in Kubernetes mode. Create MCPServer CRD instead") +} + +// DeleteWorkloads is a no-op in Kubernetes mode. +// The operator manages workload lifecycle via MCPServer CRDs. 
+func (*k8sManager) DeleteWorkloads(_ context.Context, _ []string) (*errgroup.Group, error) { + logger.Warnf("DeleteWorkloads is not supported in Kubernetes mode. Use kubectl to delete MCPServer CRDs.") + group := &errgroup.Group{} + // Return empty group - no operations to perform + return group, nil +} + +// RestartWorkloads is a no-op in Kubernetes mode. +// The operator manages workload lifecycle via MCPServer CRDs. +func (*k8sManager) RestartWorkloads(_ context.Context, _ []string, _ bool) (*errgroup.Group, error) { + logger.Warnf("RestartWorkloads is not supported in Kubernetes mode. Use kubectl to restart MCPServer CRDs.") + group := &errgroup.Group{} + // Return empty group - no operations to perform + return group, nil +} + +// UpdateWorkload is a no-op in Kubernetes mode. +// The operator manages workload lifecycle via MCPServer CRDs. +func (*k8sManager) UpdateWorkload(_ context.Context, _ string, _ *runner.RunConfig) (*errgroup.Group, error) { + logger.Warnf("UpdateWorkload is not supported in Kubernetes mode. Update MCPServer CRD instead.") + group := &errgroup.Group{} + // Return empty group - no operations to perform + return group, nil +} + +// GetLogs retrieves logs from the pod associated with the MCPServer. +// Note: This requires a Kubernetes clientset for log streaming. +// For now, this returns an error indicating logs should be retrieved via kubectl. +// TODO: Implement proper log retrieval using clientset or REST client. +func (k *k8sManager) GetLogs(_ context.Context, _ string, follow bool) (string, error) { + if follow { + return "", fmt.Errorf("follow mode is not supported. Use 'kubectl logs -f -n %s' to stream logs", k.namespace) + } + return "", fmt.Errorf( + "GetLogs is not fully implemented in Kubernetes mode. Use 'kubectl logs -n %s' to retrieve logs", + k.namespace) +} + +// GetProxyLogs retrieves logs from the proxy container in the pod associated with the MCPServer. +// Note: This requires a Kubernetes clientset for log streaming. 
+// For now, this returns an error indicating logs should be retrieved via kubectl. +// TODO: Implement proper log retrieval using clientset or REST client. +func (k *k8sManager) GetProxyLogs(_ context.Context, _ string) (string, error) { + return "", fmt.Errorf( + "GetProxyLogs is not fully implemented in Kubernetes mode. Use 'kubectl logs -c proxy -n %s' to retrieve proxy logs", + k.namespace) +} + +// MoveToGroup moves the specified workloads from one group to another. +func (k *k8sManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { + for _, name := range workloadNames { + mcpServer := &mcpv1alpha1.MCPServer{} + key := types.NamespacedName{Name: name, Namespace: k.namespace} + if err := k.k8sClient.Get(ctx, key, mcpServer); err != nil { + if errors.IsNotFound(err) { + return fmt.Errorf("MCPServer %s not found", name) + } + return fmt.Errorf("failed to get MCPServer: %w", err) + } + + // Verify the workload is in the expected group + if mcpServer.Spec.GroupRef != groupFrom { + return fmt.Errorf("workload %s is not in group %s (current group: %s)", name, groupFrom, mcpServer.Spec.GroupRef) + } + + // Update the group + mcpServer.Spec.GroupRef = groupTo + + // Update the MCPServer + if err := k.k8sClient.Update(ctx, mcpServer); err != nil { + return fmt.Errorf("failed to update MCPServer %s: %w", name, err) + } + } + + return nil +} + +// ListWorkloadsInGroup returns all workload names that belong to the specified group. 
+func (k *k8sManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { + mcpServerList := &mcpv1alpha1.MCPServerList{} + listOpts := []client.ListOption{ + client.InNamespace(k.namespace), + } + + if err := k.k8sClient.List(ctx, mcpServerList, listOpts...); err != nil { + return nil, fmt.Errorf("failed to list MCPServers: %w", err) + } + + var groupWorkloads []string + for i := range mcpServerList.Items { + mcpServer := &mcpServerList.Items[i] + if mcpServer.Spec.GroupRef == groupName { + groupWorkloads = append(groupWorkloads, mcpServer.Name) + } + } + + return groupWorkloads, nil +} + +// mcpServerToWorkload converts an MCPServer CRD to a core.Workload. +func (k *k8sManager) mcpServerToWorkload(mcpServer *mcpv1alpha1.MCPServer) (core.Workload, error) { + // Map MCPServerPhase to runtime.WorkloadStatus + status := k.mcpServerPhaseToWorkloadStatus(mcpServer.Status.Phase) + + // Parse transport type + transportType, err := transporttypes.ParseTransportType(mcpServer.Spec.Transport) + if err != nil { + logger.Warnf("Failed to parse transport type %s for MCPServer %s: %v", mcpServer.Spec.Transport, mcpServer.Name, err) + transportType = transporttypes.TransportTypeSSE + } + + // Calculate effective proxy mode + effectiveProxyMode := workloadtypes.GetEffectiveProxyMode(transportType, mcpServer.Spec.ProxyMode) + + // Generate URL from status or reconstruct from spec + url := mcpServer.Status.URL + if url == "" { + port := int(mcpServer.Spec.ProxyPort) + if port == 0 { + port = int(mcpServer.Spec.Port) // Fallback to deprecated Port field + } + if port > 0 { + url = transport.GenerateMCPServerURL(mcpServer.Spec.Transport, transport.LocalhostIPv4, port, mcpServer.Name, "") + } + } + + port := int(mcpServer.Spec.ProxyPort) + if port == 0 { + port = int(mcpServer.Spec.Port) // Fallback to deprecated Port field + } + + // Get tools filter from spec + toolsFilter := mcpServer.Spec.ToolsFilter + if mcpServer.Spec.ToolConfigRef != nil { + // If 
ToolConfigRef is set, we can't reconstruct the tools filter here + // The tools filter would be resolved by the operator + toolsFilter = []string{} + } + + // Extract user labels from annotations (Kubernetes doesn't have container labels like Docker) + userLabels := make(map[string]string) + if mcpServer.Annotations != nil { + // Filter out standard Kubernetes annotations + for key, value := range mcpServer.Annotations { + if !k.isStandardK8sAnnotation(key) { + userLabels[key] = value + } + } + } + + // Get creation timestamp + createdAt := mcpServer.CreationTimestamp.Time + if createdAt.IsZero() { + createdAt = time.Now() + } + + return core.Workload{ + Name: mcpServer.Name, + Package: mcpServer.Spec.Image, + URL: url, + ToolType: "mcp", + TransportType: transportType, + ProxyMode: effectiveProxyMode, + Status: status, + StatusContext: mcpServer.Status.Message, + CreatedAt: createdAt, + Port: port, + Labels: userLabels, + Group: mcpServer.Spec.GroupRef, + ToolsFilter: toolsFilter, + Remote: false, // MCPServers are always container workloads in Kubernetes + }, nil +} + +// mcpServerPhaseToWorkloadStatus maps MCPServerPhase to runtime.WorkloadStatus. +func (*k8sManager) mcpServerPhaseToWorkloadStatus(phase mcpv1alpha1.MCPServerPhase) rt.WorkloadStatus { + switch phase { + case mcpv1alpha1.MCPServerPhaseRunning: + return rt.WorkloadStatusRunning + case mcpv1alpha1.MCPServerPhasePending: + return rt.WorkloadStatusStarting + case mcpv1alpha1.MCPServerPhaseFailed: + return rt.WorkloadStatusError + case mcpv1alpha1.MCPServerPhaseTerminating: + return rt.WorkloadStatusStopping + default: + return rt.WorkloadStatusUnknown + } +} + +// isStandardK8sAnnotation checks if an annotation key is a standard Kubernetes annotation. 
+func (*k8sManager) isStandardK8sAnnotation(key string) bool { + // Common Kubernetes annotation prefixes + standardPrefixes := []string{ + "kubectl.kubernetes.io/", + "kubernetes.io/", + "deployment.kubernetes.io/", + "k8s.io/", + } + + for _, prefix := range standardPrefixes { + if len(key) >= len(prefix) && key[:len(prefix)] == prefix { + return true + } + } + + return false +} diff --git a/pkg/workloads/k8s_manager_test.go b/pkg/workloads/k8s_manager_test.go new file mode 100644 index 000000000..516c92697 --- /dev/null +++ b/pkg/workloads/k8s_manager_test.go @@ -0,0 +1,777 @@ +package workloads + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "sigs.k8s.io/controller-runtime/pkg/client" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/core" + "github.com/stacklok/toolhive/pkg/runner" +) + +const ( + defaultNamespace = "default" + testWorkload1 = "workload1" +) + +// mockClient is a mock implementation of client.Client for testing +type mockClient struct { + client.Client + getFunc func(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error + listFunc func(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error + updateFunc func(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error +} + +func (m *mockClient) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + if m.getFunc != nil { + return m.getFunc(ctx, key, obj, opts...) 
+ } + return k8serrors.NewNotFound(schema.GroupResource{}, key.Name) +} + +func (m *mockClient) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + if m.listFunc != nil { + return m.listFunc(ctx, list, opts...) + } + return nil +} + +func (m *mockClient) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { + if m.updateFunc != nil { + return m.updateFunc(ctx, obj, opts...) + } + return nil +} + +func TestNewK8SManager(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + k8sClient client.Client + namespace string + wantError bool + }{ + { + name: "successful creation", + k8sClient: &mockClient{}, + namespace: defaultNamespace, + wantError: false, + }, + { + name: "empty namespace", + k8sClient: &mockClient{}, + namespace: "", + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + manager, err := NewK8SManager(tt.k8sClient, tt.namespace) + + if tt.wantError { + require.Error(t, err) + assert.Nil(t, manager) + } else { + require.NoError(t, err) + require.NotNil(t, manager) + + k8sMgr, ok := manager.(*k8sManager) + require.True(t, ok) + assert.Equal(t, tt.k8sClient, k8sMgr.k8sClient) + assert.Equal(t, tt.namespace, k8sMgr.namespace) + } + }) + } +} + +func TestK8SManager_GetWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + setupMock func(*mockClient) + wantError bool + errorMsg string + expected core.Workload + }{ + { + name: "successful get", + workloadName: "test-workload", + setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { + if mcpServer, ok := obj.(*mcpv1alpha1.MCPServer); ok { + mcpServer.Name = "test-workload" + mcpServer.Namespace = defaultNamespace + mcpServer.Status.Phase = mcpv1alpha1.MCPServerPhaseRunning + mcpServer.Spec.Transport = "streamable-http" + 
mcpServer.Spec.ProxyPort = 8080 + mcpServer.Labels = map[string]string{ + "group": "test-group", + } + } + return nil + } + }, + wantError: false, + expected: core.Workload{ + Name: "test-workload", + Status: rt.WorkloadStatusRunning, + URL: "http://127.0.0.1:8080/mcp", // URL is generated from spec + Labels: map[string]string{ + "group": "test-group", + }, + }, + }, + { + name: "workload not found", + workloadName: "non-existent", + setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, key client.ObjectKey, _ client.Object, _ ...client.GetOption) error { + return k8serrors.NewNotFound(schema.GroupResource{Resource: "mcpservers"}, key.Name) + } + }, + wantError: true, + errorMsg: "workload not found", + }, + { + name: "get error", + workloadName: "error-workload", + setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, _ client.ObjectKey, _ client.Object, _ ...client.GetOption) error { + return k8serrors.NewInternalError(errors.New("internal error")) + } + }, + wantError: true, + errorMsg: "failed to get MCPServer", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockClient := &mockClient{} + tt.setupMock(mockClient) + + manager := &k8sManager{ + k8sClient: mockClient, + namespace: defaultNamespace, + } + + ctx := context.Background() + result, err := manager.GetWorkload(ctx, tt.workloadName) + + if tt.wantError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected.Name, result.Name) + assert.Equal(t, tt.expected.Status, result.Status) + assert.Equal(t, tt.expected.URL, result.URL) + } + }) + } +} + +func TestK8SManager_DoesWorkloadExist(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + setupMock func(*mockClient) + expected bool + wantError bool + }{ + { + name: "workload exists", + workloadName: "existing-workload", + 
setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { + if mcpServer, ok := obj.(*mcpv1alpha1.MCPServer); ok { + mcpServer.Name = "existing-workload" + } + return nil + } + }, + expected: true, + wantError: false, + }, + { + name: "workload does not exist", + workloadName: "non-existent", + setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, key client.ObjectKey, _ client.Object, _ ...client.GetOption) error { + return k8serrors.NewNotFound(schema.GroupResource{Resource: "mcpservers"}, key.Name) + } + }, + expected: false, + wantError: false, + }, + { + name: "get error", + workloadName: "error-workload", + setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, _ client.ObjectKey, _ client.Object, _ ...client.GetOption) error { + return k8serrors.NewInternalError(errors.New("internal error")) + } + }, + expected: false, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockClient := &mockClient{} + tt.setupMock(mockClient) + + manager := &k8sManager{ + k8sClient: mockClient, + namespace: defaultNamespace, + } + + ctx := context.Background() + result, err := manager.DoesWorkloadExist(ctx, tt.workloadName) + + if tt.wantError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestK8SManager_ListWorkloadsInGroup(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + groupName string + setupMock func(*mockClient) + expected []string + wantError bool + errorMsg string + }{ + { + name: "successful list", + groupName: "test-group", + setupMock: func(mc *mockClient) { + mc.listFunc = func(_ context.Context, list client.ObjectList, _ ...client.ListOption) error { + if mcpServerList, ok := list.(*mcpv1alpha1.MCPServerList); ok { + mcpServerList.Items = []mcpv1alpha1.MCPServer{ + { + ObjectMeta: 
metav1.ObjectMeta{ + Name: testWorkload1, + }, + Spec: mcpv1alpha1.MCPServerSpec{ + GroupRef: "test-group", + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "workload2", + }, + Spec: mcpv1alpha1.MCPServerSpec{ + GroupRef: "test-group", + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "workload3", + }, + Spec: mcpv1alpha1.MCPServerSpec{ + GroupRef: "other-group", + }, + }, + } + } + return nil + } + }, + expected: []string{testWorkload1, "workload2"}, + wantError: false, + }, + { + name: "empty group", + groupName: "empty-group", + setupMock: func(mc *mockClient) { + mc.listFunc = func(_ context.Context, list client.ObjectList, _ ...client.ListOption) error { + if mcpServerList, ok := list.(*mcpv1alpha1.MCPServerList); ok { + mcpServerList.Items = []mcpv1alpha1.MCPServer{} + } + return nil + } + }, + expected: []string{}, + wantError: false, + }, + { + name: "list error", + groupName: "test-group", + setupMock: func(mc *mockClient) { + mc.listFunc = func(_ context.Context, _ client.ObjectList, _ ...client.ListOption) error { + return k8serrors.NewInternalError(errors.New("internal error")) + } + }, + expected: nil, + wantError: true, + errorMsg: "failed to list MCPServers", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockClient := &mockClient{} + tt.setupMock(mockClient) + + manager := &k8sManager{ + k8sClient: mockClient, + namespace: defaultNamespace, + } + + ctx := context.Background() + result, err := manager.ListWorkloadsInGroup(ctx, tt.groupName) + + if tt.wantError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + assert.ElementsMatch(t, tt.expected, result) + } + }) + } +} + +func TestK8SManager_ListWorkloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + listAll bool + labelFilters []string + setupMock func(*mockClient) + expected int + wantError bool + errorMsg string + }{ + { + 
name: "successful list all", + listAll: true, + setupMock: func(mc *mockClient) { + mc.listFunc = func(_ context.Context, list client.ObjectList, _ ...client.ListOption) error { + if mcpServerList, ok := list.(*mcpv1alpha1.MCPServerList); ok { + mcpServerList.Items = []mcpv1alpha1.MCPServer{ + { + ObjectMeta: metav1.ObjectMeta{Name: testWorkload1}, + Status: mcpv1alpha1.MCPServerStatus{Phase: mcpv1alpha1.MCPServerPhaseRunning}, + }, + { + ObjectMeta: metav1.ObjectMeta{Name: "workload2"}, + Status: mcpv1alpha1.MCPServerStatus{Phase: mcpv1alpha1.MCPServerPhaseTerminating}, + }, + } + } + return nil + } + }, + expected: 2, + wantError: false, + }, + { + name: "list with label filters", + listAll: true, + labelFilters: []string{"env=prod"}, + setupMock: func(mc *mockClient) { + mc.listFunc = func(_ context.Context, list client.ObjectList, _ ...client.ListOption) error { + if mcpServerList, ok := list.(*mcpv1alpha1.MCPServerList); ok { + mcpServerList.Items = []mcpv1alpha1.MCPServer{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: testWorkload1, + Labels: map[string]string{"env": "prod"}, + }, + Status: mcpv1alpha1.MCPServerStatus{Phase: mcpv1alpha1.MCPServerPhaseRunning}, + }, + } + } + return nil + } + }, + expected: 1, + wantError: false, + }, + { + name: "invalid label filter", + listAll: true, + labelFilters: []string{"invalid-filter"}, + setupMock: func(*mockClient) { + // No list call expected due to filter parsing error + }, + expected: 0, + wantError: true, + errorMsg: "failed to parse label filters", + }, + { + name: "list error", + listAll: true, + setupMock: func(mc *mockClient) { + mc.listFunc = func(_ context.Context, _ client.ObjectList, _ ...client.ListOption) error { + return k8serrors.NewInternalError(errors.New("internal error")) + } + }, + expected: 0, + wantError: true, + errorMsg: "failed to list MCPServers", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockClient := &mockClient{} + 
tt.setupMock(mockClient) + + manager := &k8sManager{ + k8sClient: mockClient, + namespace: defaultNamespace, + } + + ctx := context.Background() + result, err := manager.ListWorkloads(ctx, tt.listAll, tt.labelFilters...) + + if tt.wantError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + assert.Len(t, result, tt.expected) + } + }) + } +} + +func TestK8SManager_MoveToGroup(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadNames []string + groupFrom string + groupTo string + setupMock func(*mockClient) + wantError bool + errorMsg string + }{ + { + name: "successful move", + workloadNames: []string{testWorkload1}, + groupFrom: "old-group", + groupTo: "new-group", + setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { + if mcpServer, ok := obj.(*mcpv1alpha1.MCPServer); ok { + mcpServer.Name = testWorkload1 + mcpServer.Namespace = defaultNamespace + mcpServer.Spec.GroupRef = "old-group" + } + return nil + } + mc.updateFunc = func(_ context.Context, _ client.Object, _ ...client.UpdateOption) error { + return nil + } + }, + wantError: false, + }, + { + name: "workload not found", + workloadNames: []string{"non-existent"}, + groupFrom: "old-group", + groupTo: "new-group", + setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, key client.ObjectKey, _ client.Object, _ ...client.GetOption) error { + return k8serrors.NewNotFound(schema.GroupResource{Resource: "mcpservers"}, key.Name) + } + }, + wantError: true, + errorMsg: "MCPServer", + }, + { + name: "workload in different group", + workloadNames: []string{testWorkload1}, + groupFrom: "old-group", + groupTo: "new-group", + setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { + if mcpServer, ok := 
obj.(*mcpv1alpha1.MCPServer); ok { + mcpServer.Name = testWorkload1 + mcpServer.Namespace = defaultNamespace + mcpServer.Spec.GroupRef = "different-group" + } + return nil + } + }, + wantError: true, // Returns error when group doesn't match + errorMsg: "is not in group", + }, + { + name: "update error", + workloadNames: []string{testWorkload1}, + groupFrom: "old-group", + groupTo: "new-group", + setupMock: func(mc *mockClient) { + mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { + if mcpServer, ok := obj.(*mcpv1alpha1.MCPServer); ok { + mcpServer.Name = testWorkload1 + mcpServer.Namespace = defaultNamespace + mcpServer.Spec.GroupRef = "old-group" + } + return nil + } + mc.updateFunc = func(_ context.Context, _ client.Object, _ ...client.UpdateOption) error { + return k8serrors.NewInternalError(errors.New("update failed")) + } + }, + wantError: true, + errorMsg: "failed to update MCPServer", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockClient := &mockClient{} + tt.setupMock(mockClient) + + manager := &k8sManager{ + k8sClient: mockClient, + namespace: defaultNamespace, + } + + ctx := context.Background() + err := manager.MoveToGroup(ctx, tt.workloadNames, tt.groupFrom, tt.groupTo) + + if tt.wantError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestK8SManager_NoOpMethods(t *testing.T) { + t.Parallel() + + mockClient := &mockClient{} + manager := &k8sManager{ + k8sClient: mockClient, + namespace: defaultNamespace, + } + + ctx := context.Background() + + t.Run("StopWorkloads returns empty group", func(t *testing.T) { + t.Parallel() + group, err := manager.StopWorkloads(ctx, []string{testWorkload1}) + require.NoError(t, err) + require.NotNil(t, group) + }) + + t.Run("RunWorkload returns error", func(t *testing.T) { + t.Parallel() + err := 
manager.RunWorkload(ctx, &runner.RunConfig{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not supported in Kubernetes mode") + }) + + t.Run("RunWorkloadDetached returns error", func(t *testing.T) { + t.Parallel() + err := manager.RunWorkloadDetached(ctx, &runner.RunConfig{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not supported in Kubernetes mode") + }) + + t.Run("DeleteWorkloads returns empty group", func(t *testing.T) { + t.Parallel() + group, err := manager.DeleteWorkloads(ctx, []string{testWorkload1}) + require.NoError(t, err) + require.NotNil(t, group) + }) + + t.Run("RestartWorkloads returns empty group", func(t *testing.T) { + t.Parallel() + group, err := manager.RestartWorkloads(ctx, []string{testWorkload1}, false) + require.NoError(t, err) + require.NotNil(t, group) + }) + + t.Run("UpdateWorkload returns empty group", func(t *testing.T) { + t.Parallel() + group, err := manager.UpdateWorkload(ctx, testWorkload1, &runner.RunConfig{}) + require.NoError(t, err) + require.NotNil(t, group) + }) + + t.Run("GetLogs returns error", func(t *testing.T) { + t.Parallel() + logs, err := manager.GetLogs(ctx, testWorkload1, false) + require.Error(t, err) + assert.Empty(t, logs) + assert.Contains(t, err.Error(), "not fully implemented") + }) + + t.Run("GetProxyLogs returns error", func(t *testing.T) { + t.Parallel() + logs, err := manager.GetProxyLogs(ctx, testWorkload1) + require.Error(t, err) + assert.Empty(t, logs) + assert.Contains(t, err.Error(), "not fully implemented") + }) +} + +func TestK8SManager_mcpServerToWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mcpServer *mcpv1alpha1.MCPServer + expected core.Workload + }{ + { + name: "running workload with HTTP transport", + mcpServer: &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-workload", + Annotations: map[string]string{ + "group": "test-group", + "env": "prod", + }, + }, + Status: mcpv1alpha1.MCPServerStatus{ + Phase: 
mcpv1alpha1.MCPServerPhaseRunning, + URL: "http://localhost:8080", + }, + Spec: mcpv1alpha1.MCPServerSpec{ + Transport: "streamable-http", + ProxyPort: 8080, + }, + }, + expected: core.Workload{ + Name: "test-workload", + Status: rt.WorkloadStatusRunning, + URL: "http://localhost:8080", + Labels: map[string]string{"group": "test-group", "env": "prod"}, + }, + }, + { + name: "terminating workload", + mcpServer: &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "terminating-workload", + }, + Status: mcpv1alpha1.MCPServerStatus{ + Phase: mcpv1alpha1.MCPServerPhaseTerminating, + }, + }, + expected: core.Workload{ + Name: "terminating-workload", + Status: rt.WorkloadStatusStopping, + }, + }, + { + name: "failed workload", + mcpServer: &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "failed-workload", + }, + Status: mcpv1alpha1.MCPServerStatus{ + Phase: mcpv1alpha1.MCPServerPhaseFailed, + }, + }, + expected: core.Workload{ + Name: "failed-workload", + Status: rt.WorkloadStatusError, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + manager := &k8sManager{ + namespace: defaultNamespace, + } + + result, err := manager.mcpServerToWorkload(tt.mcpServer) + require.NoError(t, err) + + assert.Equal(t, tt.expected.Name, result.Name) + assert.Equal(t, tt.expected.Status, result.Status) + assert.Equal(t, tt.expected.URL, result.URL) + if tt.expected.Labels != nil { + assert.Equal(t, tt.expected.Labels, result.Labels) + } + }) + } +} + +func TestK8SManager_mcpServerPhaseToWorkloadStatus(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + phase mcpv1alpha1.MCPServerPhase + expected rt.WorkloadStatus + }{ + {"running", mcpv1alpha1.MCPServerPhaseRunning, rt.WorkloadStatusRunning}, + {"pending", mcpv1alpha1.MCPServerPhasePending, rt.WorkloadStatusStarting}, + {"failed", mcpv1alpha1.MCPServerPhaseFailed, rt.WorkloadStatusError}, + {"terminating", mcpv1alpha1.MCPServerPhaseTerminating, 
rt.WorkloadStatusStopping}, + {"unknown", mcpv1alpha1.MCPServerPhase(""), rt.WorkloadStatusUnknown}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + manager := &k8sManager{} + result := manager.mcpServerPhaseToWorkloadStatus(tt.phase) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index 7aac9e5d6..3490f0594 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -4,31 +4,21 @@ package workloads import ( "context" - "errors" "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "time" - "github.com/adrg/xdg" "golang.org/x/sync/errgroup" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" - "github.com/stacklok/toolhive/pkg/client" + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" "github.com/stacklok/toolhive/pkg/config" - ct "github.com/stacklok/toolhive/pkg/container" + "github.com/stacklok/toolhive/pkg/container/kubernetes" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/process" "github.com/stacklok/toolhive/pkg/runner" - "github.com/stacklok/toolhive/pkg/secrets" - "github.com/stacklok/toolhive/pkg/state" - "github.com/stacklok/toolhive/pkg/transport" - "github.com/stacklok/toolhive/pkg/workloads/statuses" - "github.com/stacklok/toolhive/pkg/workloads/types" ) // Manager is responsible for managing the state of ToolHive-managed containers. 
@@ -71,1213 +61,68 @@ type Manager interface { DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) } -type defaultManager struct { - runtime rt.Runtime - statuses statuses.StatusManager - configProvider config.Provider -} - // ErrWorkloadNotRunning is returned when a container cannot be found by name. var ErrWorkloadNotRunning = fmt.Errorf("workload not running") -const ( - // AsyncOperationTimeout is the timeout for async workload operations - AsyncOperationTimeout = 5 * time.Minute -) - -// NewManager creates a new container manager instance. +// NewManager creates a new workload manager based on the runtime environment: +// - In Kubernetes mode: returns a CRD-based manager that uses MCPServer CRDs +// - In local mode: returns a CLI/filesystem-based manager func NewManager(ctx context.Context) (Manager, error) { - runtime, err := ct.NewFactory().Create(ctx) - if err != nil { - return nil, err - } - - statusManager, err := statuses.NewStatusManager(runtime) - if err != nil { - return nil, fmt.Errorf("failed to create status manager: %w", err) + if rt.IsKubernetesRuntime() { + return newK8SManager(ctx) } - - return &defaultManager{ - runtime: runtime, - statuses: statusManager, - configProvider: config.NewDefaultProvider(), - }, nil + return NewCLIManager(ctx) } -// NewManagerWithProvider creates a new container manager instance with a custom config provider. +// NewManagerWithProvider creates a new workload manager with a custom config provider. 
func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) (Manager, error) { - runtime, err := ct.NewFactory().Create(ctx) - if err != nil { - return nil, err - } - - statusManager, err := statuses.NewStatusManager(runtime) - if err != nil { - return nil, fmt.Errorf("failed to create status manager: %w", err) - } - - return &defaultManager{ - runtime: runtime, - statuses: statusManager, - configProvider: configProvider, - }, nil -} - -// NewManagerFromRuntime creates a new container manager instance from an existing runtime. -func NewManagerFromRuntime(runtime rt.Runtime) (Manager, error) { - statusManager, err := statuses.NewStatusManager(runtime) - if err != nil { - return nil, fmt.Errorf("failed to create status manager: %w", err) - } - - return &defaultManager{ - runtime: runtime, - statuses: statusManager, - configProvider: config.NewDefaultProvider(), - }, nil -} - -// NewManagerFromRuntimeWithProvider creates a new container manager instance from an existing runtime with a -// custom config provider. -func NewManagerFromRuntimeWithProvider(runtime rt.Runtime, configProvider config.Provider) (Manager, error) { - statusManager, err := statuses.NewStatusManager(runtime) - if err != nil { - return nil, fmt.Errorf("failed to create status manager: %w", err) - } - - return &defaultManager{ - runtime: runtime, - statuses: statusManager, - configProvider: configProvider, - }, nil -} - -func (d *defaultManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { - // For the sake of minimizing changes, delegate to the status manager. - // Whether this method should still belong to the workload manager is TBD. 
- return d.statuses.GetWorkload(ctx, workloadName) -} - -func (d *defaultManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { - // check if workload exists by trying to get it - workload, err := d.statuses.GetWorkload(ctx, workloadName) - if err != nil { - if errors.Is(err, rt.ErrWorkloadNotFound) { - return false, nil - } - return false, fmt.Errorf("failed to check if workload exists: %w", err) - } - - // now check if the workload is not in error - if workload.Status == rt.WorkloadStatusError { - return false, nil - } - return true, nil -} - -func (d *defaultManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { - // For the sake of minimizing changes, delegate to the status manager. - // Whether this method should still belong to the workload manager is TBD. - containerWorkloads, err := d.statuses.ListWorkloads(ctx, listAll, labelFilters) - if err != nil { - return nil, err - } - - // Get remote workloads from the state store - remoteWorkloads, err := d.getRemoteWorkloadsFromState(ctx, listAll, labelFilters) - if err != nil { - logger.Warnf("Failed to get remote workloads from state: %v", err) - // Continue with container workloads only - } else { - // Combine container and remote workloads - containerWorkloads = append(containerWorkloads, remoteWorkloads...) 
- } - - return containerWorkloads, nil -} - -func (d *defaultManager) StopWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { - // Validate all workload names to prevent path traversal attacks - for _, name := range names { - if err := types.ValidateWorkloadName(name); err != nil { - return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) - } - // Ensure workload name does not contain path traversal or separators - if strings.Contains(name, "..") || strings.ContainsAny(name, "/\\") { - return nil, fmt.Errorf("invalid workload name '%s': contains forbidden characters", name) - } - } - - group := &errgroup.Group{} - // Process each workload - for _, name := range names { - group.Go(func() error { - return d.stopSingleWorkload(name) - }) - } - - return group, nil -} - -// stopSingleWorkload stops a single workload (container or remote) -func (d *defaultManager) stopSingleWorkload(name string) error { - // Create a child context with a longer timeout - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer cancel() - - // First, try to load the run configuration to check if it's a remote workload - runConfig, err := runner.LoadState(childCtx, name) - if err != nil { - // If we can't load the state, it might be a container workload or the workload doesn't exist - // Try to stop it as a container workload - return d.stopContainerWorkload(childCtx, name) - } - - // Check if this is a remote workload - if runConfig.RemoteURL != "" { - return d.stopRemoteWorkload(childCtx, name, runConfig) - } - - // This is a container-based workload - return d.stopContainerWorkload(childCtx, name) -} - -// stopRemoteWorkload stops a remote workload -func (d *defaultManager) stopRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { - logger.Infof("Stopping remote workload %s...", name) - - // Check if the workload is running by checking its status - workload, err := 
d.statuses.GetWorkload(ctx, name) - if err != nil { - if errors.Is(err, rt.ErrWorkloadNotFound) { - // Log but don't fail the entire operation for not found workload - logger.Warnf("Warning: Failed to stop workload %s: %v", name, err) - return nil - } - return fmt.Errorf("failed to find workload %s: %v", name, err) - } - - if workload.Status != rt.WorkloadStatusRunning { - logger.Warnf("Warning: Failed to stop workload %s: %v", name, ErrWorkloadNotRunning) - return nil - } - - // Set status to stopping - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopping: %v", name, err) - } - - // Stop proxy if running - if runConfig.BaseName != "" { - d.stopProxyIfNeeded(ctx, name, runConfig.BaseName) - } - - // For remote workloads, we only need to clean up client configurations - // The saved state should be preserved for restart capability - if err := removeClientConfigurations(name, false); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } else { - logger.Infof("Client configurations for %s removed", name) - } - - // Set status to stopped - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopped, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopped: %v", name, err) - } - logger.Infof("Remote workload %s stopped successfully", name) - return nil -} - -// stopContainerWorkload stops a container-based workload -func (d *defaultManager) stopContainerWorkload(ctx context.Context, name string) error { - container, err := d.runtime.GetWorkloadInfo(ctx, name) - if err != nil { - if errors.Is(err, rt.ErrWorkloadNotFound) { - // Log but don't fail the entire operation for not found containers - logger.Warnf("Warning: Failed to stop workload %s: %v", name, err) - return nil - } - return fmt.Errorf("failed to find workload %s: %v", name, err) - } - - running := container.IsRunning() - if !running { 
- // Log but don't fail the entire operation for not running containers - logger.Warnf("Warning: Failed to stop workload %s: %v", name, ErrWorkloadNotRunning) - return nil - } - - // Transition workload to `stopping` state. - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopping: %v", name, err) - } - - // Use the existing stopWorkloads method for container workloads - return d.stopSingleContainerWorkload(ctx, &container) -} - -func (d *defaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error { - // Ensure that the workload has a status entry before starting the process. - if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil { - // Failure to create the initial state is a fatal error. - return fmt.Errorf("failed to create workload status: %v", err) - } - - mcpRunner := runner.NewRunner(runConfig, d.statuses) - err := mcpRunner.Run(ctx) - if err != nil { - // If the run failed, we should set the status to error. 
- if statusErr := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr) - } - } - return err -} - -func (d *defaultManager) validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { - // If there are run secrets, validate them - - hasRegularSecrets := len(runConfig.Secrets) > 0 - hasRemoteAuthSecret := runConfig.RemoteAuthConfig != nil && runConfig.RemoteAuthConfig.ClientSecret != "" - - if hasRegularSecrets || hasRemoteAuthSecret { - cfg := d.configProvider.GetConfig() - - providerType, err := cfg.Secrets.GetProviderType() - if err != nil { - return fmt.Errorf("error determining secrets provider type: %w", err) - } - - secretManager, err := secrets.CreateSecretProvider(providerType) - if err != nil { - return fmt.Errorf("error instantiating secret manager: %w", err) - } - - err = runConfig.ValidateSecrets(ctx, secretManager) - if err != nil { - return fmt.Errorf("error processing secrets: %w", err) - } + if rt.IsKubernetesRuntime() { + return newK8SManager(ctx) } - return nil + return NewCLIManagerWithProvider(ctx, configProvider) } -func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error { - // before running, validate the parameters for the workload - err := d.validateSecretParameters(ctx, runConfig) - if err != nil { - return fmt.Errorf("failed to validate workload parameters: %w", err) - } - - // Get the current executable path - execPath, err := os.Executable() - if err != nil { - return fmt.Errorf("failed to get executable path: %v", err) - } - - // Create a log file for the detached process - logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", runConfig.BaseName)) - if err != nil { - return fmt.Errorf("failed to create log file path: %v", err) - } - // #nosec G304 - This is safe as baseName is generated by the application 
- logFile, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) - if err != nil { - logger.Warnf("Warning: Failed to create log file: %v", err) - } else { - defer logFile.Close() - logger.Infof("Logging to: %s", logFilePath) +// NewManagerFromRuntime creates a new workload manager from an existing runtime. +func NewManagerFromRuntime(rtRuntime rt.Runtime) (Manager, error) { + if rt.IsKubernetesRuntime() { + // In Kubernetes mode, we need a k8s client, not a runtime + ctx := context.Background() + return newK8SManager(ctx) } - - // Use the restart command to start the detached process - // The config has already been saved to disk, so restart can load it - detachedArgs := []string{"restart", runConfig.BaseName, "--foreground"} - - // Create a new command - // #nosec G204 - This is safe as execPath is the path to the current binary - detachedCmd := exec.Command(execPath, detachedArgs...) - - // Set environment variables for the detached process - detachedCmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", process.ToolHiveDetachedEnv, process.ToolHiveDetachedValue)) - - // If we need the decrypt password, set it as an environment variable in the detached process. - // NOTE: This breaks the abstraction slightly since this is only relevant for the CLI, but there - // are checks inside `GetSecretsPassword` to ensure this does not get called in a detached process. - // This will be addressed in a future re-think of the secrets manager interface. 
- if d.needSecretsPassword(runConfig.Secrets) { - password, err := secrets.GetSecretsPassword("") - if err != nil { - return fmt.Errorf("failed to get secrets password: %v", err) - } - detachedCmd.Env = append(detachedCmd.Env, fmt.Sprintf("%s=%s", secrets.PasswordEnvVar, password)) - } - - // Redirect stdout and stderr to the log file if it was created successfully - if logFile != nil { - detachedCmd.Stdout = logFile - detachedCmd.Stderr = logFile - } else { - // Otherwise, discard the output - detachedCmd.Stdout = nil - detachedCmd.Stderr = nil - } - - // Detach the process from the terminal - detachedCmd.Stdin = nil - detachedCmd.SysProcAttr = getSysProcAttr() - - // Ensure that the workload has a status entry before starting the process. - if err = d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil { - // Failure to create the initial state is a fatal error. - return fmt.Errorf("failed to create workload status: %v", err) - } - - // Start the detached process - if err := detachedCmd.Start(); err != nil { - // If the start failed, we need to set the status to error before returning. - if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, ""); err != nil { - logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, err) - } - return fmt.Errorf("failed to start detached process: %v", err) - } - - // Write the PID to a file so the stop command can kill the process - // TODO: Stop writing to PID file once we migrate over to statuses fully. 
- if err := process.WritePIDFile(runConfig.BaseName, detachedCmd.Process.Pid); err != nil { - logger.Warnf("Warning: Failed to write PID file: %v", err) - } - if err := d.statuses.SetWorkloadPID(ctx, runConfig.BaseName, detachedCmd.Process.Pid); err != nil { - logger.Warnf("Failed to set workload %s PID: %v", runConfig.BaseName, err) - } - - logger.Infof("MCP server is running in the background (PID: %d)", detachedCmd.Process.Pid) - logger.Infof("Use 'thv stop %s' to stop the server", runConfig.ContainerName) - - return nil + return NewCLIManagerFromRuntime(rtRuntime) } -func (d *defaultManager) GetLogs(ctx context.Context, workloadName string, follow bool) (string, error) { - // Get the logs from the runtime - logs, err := d.runtime.GetWorkloadLogs(ctx, workloadName, follow) - if err != nil { - // Propagate the error if the container is not found - if errors.Is(err, rt.ErrWorkloadNotFound) { - return "", fmt.Errorf("%w: %s", rt.ErrWorkloadNotFound, workloadName) - } - return "", fmt.Errorf("failed to get container logs %s: %v", workloadName, err) +// NewManagerFromRuntimeWithProvider creates a new workload manager from an existing runtime with a custom config provider. 
+func NewManagerFromRuntimeWithProvider(rtRuntime rt.Runtime, configProvider config.Provider) (Manager, error) { + if rt.IsKubernetesRuntime() { + // In Kubernetes mode, we need a k8s client, not a runtime + ctx := context.Background() + return newK8SManager(ctx) } - - return logs, nil + return NewCLIManagerFromRuntimeWithProvider(rtRuntime, configProvider) } -// GetProxyLogs retrieves proxy logs from the filesystem -func (*defaultManager) GetProxyLogs(_ context.Context, workloadName string) (string, error) { - // Get the proxy log file path - logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", workloadName)) - if err != nil { - return "", fmt.Errorf("failed to get proxy log file path for workload %s: %w", workloadName, err) - } - - // Clean the file path to prevent path traversal - cleanLogFilePath := filepath.Clean(logFilePath) +// newK8SManager creates a Kubernetes-based workload manager for Kubernetes environments +func newK8SManager(context.Context) (Manager, error) { + // Create a scheme for controller-runtime client + scheme := runtime.NewScheme() + utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(mcpv1alpha1.AddToScheme(scheme)) - // Check if the log file exists - if _, err := os.Stat(cleanLogFilePath); os.IsNotExist(err) { - return "", fmt.Errorf("proxy logs not found for workload %s", workloadName) - } - - // Read and return the entire log file - content, err := os.ReadFile(cleanLogFilePath) + // Get Kubernetes config + cfg, err := ctrl.GetConfig() if err != nil { - return "", fmt.Errorf("failed to read proxy log for workload %s: %w", workloadName, err) + return nil, fmt.Errorf("failed to get Kubernetes config: %w", err) } - return string(content), nil -} - -// deleteWorkload handles deletion of a single workload -func (d *defaultManager) deleteWorkload(name string) error { - // Create a child context with a longer timeout - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer 
cancel() - - // First, check if this is a remote workload by trying to load its run configuration - runConfig, err := runner.LoadState(childCtx, name) + // Create controller-runtime client + k8sClient, err := client.New(cfg, client.Options{Scheme: scheme}) if err != nil { - // If we can't load the state, it might be a container workload or the workload doesn't exist - // Continue with the container-based deletion logic - return d.deleteContainerWorkload(childCtx, name) + return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) } - // If this is a remote workload (has RemoteURL), handle it differently - if runConfig.RemoteURL != "" { - return d.deleteRemoteWorkload(childCtx, name, runConfig) - } - - // This is a container-based workload, use the existing logic - return d.deleteContainerWorkload(childCtx, name) -} - -// deleteRemoteWorkload handles deletion of a remote workload -func (d *defaultManager) deleteRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { - logger.Infof("Removing remote workload %s...", name) - - // Set status to removing - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusRemoving, ""); err != nil { - logger.Warnf("Failed to set workload %s status to removing: %v", name, err) - return err - } - - // Stop proxy if running - if runConfig.BaseName != "" { - d.stopProxyIfNeeded(ctx, name, runConfig.BaseName) - } - - // Clean up associated resources (remote workloads are not auxiliary) - d.cleanupWorkloadResources(ctx, name, runConfig.BaseName, false) - - // Remove the workload status from the status store - if err := d.statuses.DeleteWorkloadStatus(ctx, name); err != nil { - logger.Warnf("failed to delete workload status for %s: %v", name, err) - } - - logger.Infof("Remote workload %s removed successfully", name) - return nil -} - -// deleteContainerWorkload handles deletion of a container-based workload (existing logic) -func (d *defaultManager) deleteContainerWorkload(ctx 
context.Context, name string) error { - - // Find and validate the container - container, err := d.getWorkloadContainer(ctx, name) - if err != nil { - return err - } - - // Set status to removing - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusRemoving, ""); err != nil { - logger.Warnf("Failed to set workload %s status to removing: %v", name, err) - } - - if container != nil { - containerLabels := container.Labels - baseName := labels.GetContainerBaseName(containerLabels) - - // Stop proxy if running (skip for auxiliary workloads like inspector) - if container.IsRunning() { - // Skip proxy stopping for auxiliary workloads that don't use proxy processes - if labels.IsAuxiliaryWorkload(containerLabels) { - logger.Debugf("Skipping proxy stop for auxiliary workload %s", name) - } else { - d.stopProxyIfNeeded(ctx, name, baseName) - } - } - - // Remove the container - if err := d.removeContainer(ctx, name); err != nil { - return err - } - - // Clean up associated resources - d.cleanupWorkloadResources(ctx, name, baseName, labels.IsAuxiliaryWorkload(containerLabels)) - } - - // Remove the workload status from the status store - if err := d.statuses.DeleteWorkloadStatus(ctx, name); err != nil { - logger.Warnf("failed to delete workload status for %s: %v", name, err) - } - - return nil -} - -// getWorkloadContainer retrieves workload container info with error handling -func (d *defaultManager) getWorkloadContainer(ctx context.Context, name string) (*rt.ContainerInfo, error) { - container, err := d.runtime.GetWorkloadInfo(ctx, name) - if err != nil { - if errors.Is(err, rt.ErrWorkloadNotFound) { - // Log but don't fail the entire operation for not found containers - logger.Warnf("Warning: Failed to get workload %s: %v", name, err) - return nil, nil - } - if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) - } - 
return nil, fmt.Errorf("failed to find workload %s: %v", name, err) - } - return &container, nil -} - -// isSupervisorProcessAlive checks if the supervisor process for a workload is alive -// by checking if a PID exists. If a PID exists, we assume the supervisor is running. -// This is a reasonable assumption because: -// - If the supervisor exits cleanly, it cleans up the PID -// - If killed unexpectedly, the PID remains but stopProcess will handle it gracefully -// - The main issue we're preventing is accumulating zombie supervisors from repeated restarts -func (d *defaultManager) isSupervisorProcessAlive(ctx context.Context, name string) bool { - if name == "" { - return false - } - - // Try to read the PID - if it exists, assume supervisor is running - _, err := d.statuses.GetWorkloadPID(ctx, name) - if err != nil { - // No PID found, supervisor is not running - return false - } - - // PID exists, assume supervisor is alive - return true -} - -// stopProcess stops the proxy process associated with the container -func (d *defaultManager) stopProcess(ctx context.Context, name string) { - if name == "" { - logger.Warnf("Warning: Could not find base container name in labels") - return - } - - // Try to read the PID and kill the process - pid, err := d.statuses.GetWorkloadPID(ctx, name) - if err != nil { - logger.Errorf("No PID file found for %s, proxy may not be running in detached mode", name) - return - } - - // PID file found, try to kill the process - logger.Infof("Stopping proxy process (PID: %d)...", pid) - if err := process.KillProcess(pid); err != nil { - logger.Warnf("Warning: Failed to kill proxy process: %v", err) - } else { - logger.Info("Proxy process stopped") - } - - // Clean up PID file after successful kill - if err := process.RemovePIDFile(name); err != nil { - logger.Warnf("Warning: Failed to remove PID file: %v", err) - } -} - -// stopProxyIfNeeded stops the proxy process if the workload has a base name -func (d *defaultManager) 
stopProxyIfNeeded(ctx context.Context, name, baseName string) { - logger.Infof("Removing proxy process for %s...", name) - if baseName != "" { - d.stopProcess(ctx, baseName) - } -} - -// removeContainer removes the container from the runtime -func (d *defaultManager) removeContainer(ctx context.Context, name string) error { - logger.Infof("Removing container %s...", name) - if err := d.runtime.RemoveWorkload(ctx, name); err != nil { - if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) - } - return fmt.Errorf("failed to remove container: %v", err) - } - return nil -} - -// cleanupWorkloadResources cleans up all resources associated with a workload -func (d *defaultManager) cleanupWorkloadResources(ctx context.Context, name, baseName string, isAuxiliary bool) { - if baseName == "" { - return - } - - // Clean up temporary permission profile - if err := d.cleanupTempPermissionProfile(ctx, baseName); err != nil { - logger.Warnf("Warning: Failed to cleanup temporary permission profile: %v", err) - } - - // Remove client configurations - if err := removeClientConfigurations(name, isAuxiliary); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } else { - logger.Infof("Client configurations for %s removed", name) - } - - // Delete the saved state last (skip for auxiliary workloads that don't have run configs) - if !isAuxiliary { - if err := state.DeleteSavedRunConfig(ctx, baseName); err != nil { - logger.Warnf("Warning: Failed to delete saved state: %v", err) - } else { - logger.Infof("Saved state for %s removed", baseName) - } - } else { - logger.Debugf("Skipping saved state deletion for auxiliary workload %s", name) - } - - logger.Infof("Container %s removed", name) -} - -func (d *defaultManager) DeleteWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { - // Validate all 
workload names to prevent path traversal attacks - for _, name := range names { - if err := types.ValidateWorkloadName(name); err != nil { - return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) - } - } - - group := &errgroup.Group{} - - for _, name := range names { - group.Go(func() error { - return d.deleteWorkload(name) - }) - } - - return group, nil -} - -// RestartWorkloads restarts the specified workloads by name. -func (d *defaultManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) { - // Validate all workload names to prevent path traversal attacks - for _, name := range names { - if err := types.ValidateWorkloadName(name); err != nil { - return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) - } - } - - group := &errgroup.Group{} - - for _, name := range names { - group.Go(func() error { - return d.restartSingleWorkload(name, foreground) - }) - } - - return group, nil -} - -// UpdateWorkload updates a workload by stopping, deleting, and recreating it -func (d *defaultManager) UpdateWorkload(_ context.Context, workloadName string, newConfig *runner.RunConfig) (*errgroup.Group, error) { //nolint:lll - // Validate workload name - if err := types.ValidateWorkloadName(workloadName); err != nil { - return nil, fmt.Errorf("invalid workload name '%s': %w", workloadName, err) - } - - group := &errgroup.Group{} - group.Go(func() error { - return d.updateSingleWorkload(workloadName, newConfig) - }) - return group, nil -} - -// updateSingleWorkload handles the update logic for a single workload -func (d *defaultManager) updateSingleWorkload(workloadName string, newConfig *runner.RunConfig) error { - // Create a child context with a longer timeout - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer cancel() - - logger.Infof("Starting update for workload %s", workloadName) - - // Stop the existing workload - if err := 
d.stopSingleWorkload(workloadName); err != nil { - return fmt.Errorf("failed to stop workload: %w", err) - } - logger.Infof("Successfully stopped workload %s", workloadName) - - // Delete the existing workload - if err := d.deleteWorkload(workloadName); err != nil { - return fmt.Errorf("failed to delete workload: %w", err) - } - logger.Infof("Successfully deleted workload %s", workloadName) - - // Save the new workload configuration state - if err := newConfig.SaveState(childCtx); err != nil { - logger.Errorf("Failed to save workload config: %v", err) - return fmt.Errorf("failed to save workload config: %w", err) - } - - // Step 3: Start the new workload - // TODO: This currently just handles detached processes and wouldn't work for - // foreground CLI executions. Should be refactored to support both modes. - if err := d.RunWorkloadDetached(childCtx, newConfig); err != nil { - return fmt.Errorf("failed to start new workload: %w", err) - } - - logger.Infof("Successfully completed update for workload %s", workloadName) - return nil -} - -// restartSingleWorkload handles the restart logic for a single workload -func (d *defaultManager) restartSingleWorkload(name string, foreground bool) error { - // Create a child context with a longer timeout - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer cancel() - - // First, try to load the run configuration to check if it's a remote workload - runConfig, err := runner.LoadState(childCtx, name) - if err != nil { - // If we can't load the state, it might be a container workload or the workload doesn't exist - // Try to restart it as a container workload - return d.restartContainerWorkload(childCtx, name, foreground) - } - - // Check if this is a remote workload - if runConfig.RemoteURL != "" { - return d.restartRemoteWorkload(childCtx, name, runConfig, foreground) - } - - // This is a container-based workload - return d.restartContainerWorkload(childCtx, name, foreground) -} - -// 
restartRemoteWorkload handles restarting a remote workload -func (d *defaultManager) restartRemoteWorkload( - ctx context.Context, - name string, - runConfig *runner.RunConfig, - foreground bool, -) error { - // Get workload status using the status manager - workload, err := d.statuses.GetWorkload(ctx, name) - if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) { - return err - } - - // If workload is already running, check if the supervisor process is healthy - if err == nil && workload.Status == rt.WorkloadStatusRunning { - // Check if the supervisor process is actually alive - supervisorAlive := d.isSupervisorProcessAlive(ctx, runConfig.BaseName) - - if supervisorAlive { - // Workload is running and healthy - preserve old behavior (no-op) - logger.Infof("Remote workload %s is already running", name) - return nil - } - - // Supervisor is dead/missing - we need to clean up and restart to fix the damaged state - logger.Infof("Remote workload %s is running but supervisor is dead, cleaning up before restart", name) - - // Set status to stopping - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopping: %v", name, err) - } - - // Stop the supervisor process (proxy) if it exists (may already be dead) - // This ensures we clean up any orphaned supervisor processes - d.stopProxyIfNeeded(ctx, name, runConfig.BaseName) - - // Clean up client configurations - if err := removeClientConfigurations(name, false); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } - - // Set status to stopped after cleanup is complete - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopped, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopped: %v", name, err) - } - } - - // Load runner configuration from state - mcpRunner, err := d.loadRunnerFromState(ctx, runConfig.BaseName) - if err != nil { - return 
fmt.Errorf("failed to load state for %s: %v", runConfig.BaseName, err) - } - - // Set status to starting - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStarting, ""); err != nil { - logger.Warnf("Failed to set workload %s status to starting: %v", name, err) - } - - logger.Infof("Loaded configuration from state for %s", runConfig.BaseName) - - // Start the remote workload using the loaded runner - // Use background context to avoid timeout cancellation - same reasoning as container workloads - return d.startWorkload(context.Background(), name, mcpRunner, foreground) -} - -// restartContainerWorkload handles restarting a container-based workload -// -//nolint:gocyclo // Complexity is justified - handles multiple restart scenarios and edge cases -func (d *defaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error { - // Get container info to resolve partial names and extract proper workload name - var containerName string - var workloadName string - - container, err := d.runtime.GetWorkloadInfo(ctx, name) - if err == nil { - // If we found the container, use its actual container name for runtime operations - containerName = container.Name - // Extract the workload name (base name) from container labels for status operations - workloadName = labels.GetContainerBaseName(container.Labels) - if workloadName == "" { - // Fallback to the provided name if base name is not available - workloadName = name - } - } else { - // If container not found, use the provided name as both container and workload name - containerName = name - workloadName = name - } - - // Get workload status using the status manager - workload, err := d.statuses.GetWorkload(ctx, name) - if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) { - return err - } - - // Check if workload is running and healthy (including supervisor process) - if err == nil && workload.Status == rt.WorkloadStatusRunning { - // Check if the supervisor process is 
actually alive - supervisorAlive := d.isSupervisorProcessAlive(ctx, workloadName) - - if supervisorAlive { - // Workload is running and healthy - preserve old behavior (no-op) - logger.Infof("Container %s is already running", containerName) - return nil - } - - // Supervisor is dead/missing - we need to clean up and restart to fix the damaged state - logger.Infof("Container %s is running but supervisor is dead, cleaning up before restart", containerName) - } - - // Check if we need to stop the workload before restarting - // This happens when: 1) container is running, or 2) inconsistent state - shouldStop := false - if err == nil && workload.Status == rt.WorkloadStatusRunning { - // Workload status shows running (and supervisor is dead, otherwise we would have returned above) - shouldStop = true - } else if container.IsRunning() { - // Container is running but status is not running (inconsistent state) - shouldStop = true - } - - // If we need to stop, do it now (including cleanup of any remaining supervisor process) - if shouldStop { - logger.Infof("Stopping container %s before restart", containerName) - - // Set status to stopping - if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStopping, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopping: %v", workloadName, err) - } - - // Stop the supervisor process (proxy) if it exists (may already be dead) - // This ensures we clean up any orphaned supervisor processes - if !labels.IsAuxiliaryWorkload(container.Labels) { - d.stopProcess(ctx, workloadName) - } - - // Now stop the container if it's running - if container.IsRunning() { - if err := d.runtime.StopWorkload(ctx, containerName); err != nil { - if statusErr := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", workloadName, statusErr) - } - return fmt.Errorf("failed to stop container %s: %v", 
containerName, err) - } - logger.Infof("Container %s stopped", containerName) - } - - // Clean up client configurations - if err := removeClientConfigurations(workloadName, labels.IsAuxiliaryWorkload(container.Labels)); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } - - // Set status to stopped after cleanup is complete - if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStopped, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopped: %v", workloadName, err) - } - } - - // Load runner configuration from state - mcpRunner, err := d.loadRunnerFromState(ctx, workloadName) - if err != nil { - return fmt.Errorf("failed to load state for %s: %v", workloadName, err) - } - - // Set workload status to starting - use the workload name for status operations - if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStarting, ""); err != nil { - logger.Warnf("Failed to set workload %s status to starting: %v", workloadName, err) - } - logger.Infof("Loaded configuration from state for %s", workloadName) - - // Start the workload with background context to avoid timeout cancellation - // The ctx with AsyncOperationTimeout is only for the restart setup operations, - // but the actual workload should run indefinitely with its own lifecycle management - // Use workload name for user-facing operations - return d.startWorkload(context.Background(), workloadName, mcpRunner, foreground) -} - -// startWorkload starts the workload in either foreground or background mode -func (d *defaultManager) startWorkload(ctx context.Context, name string, mcpRunner *runner.Runner, foreground bool) error { - logger.Infof("Starting tooling server %s...", name) - - var err error - if foreground { - err = d.RunWorkload(ctx, mcpRunner.Config) - } else { - err = d.RunWorkloadDetached(ctx, mcpRunner.Config) - } - - if err != nil { - // If we could not start the workload, set the status to error 
before returning - if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, ""); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) - } - } - return err -} - -// TODO: Move to dedicated config management interface. -// updateClientConfigurations updates client configuration files with the MCP server URL -func removeClientConfigurations(containerName string, isAuxiliary bool) error { - // Get the workload's group by loading its run config - runConfig, err := runner.LoadState(context.Background(), containerName) - var group string - if err != nil { - // Only warn for non-auxiliary workloads since auxiliary workloads don't have run configs - if !isAuxiliary { - logger.Warnf("Warning: Failed to load run config for %s, will use backward compatible behavior: %v", containerName, err) - } - // Continue with empty group (backward compatibility) - } else { - group = runConfig.Group - } - - clientManager, err := client.NewManager(context.Background()) - if err != nil { - logger.Warnf("Warning: Failed to create client manager for %s, skipping client config removal: %v", containerName, err) - return nil - } - - return clientManager.RemoveServerFromClients(context.Background(), containerName, group) -} - -// loadRunnerFromState attempts to load a Runner from the state store -func (d *defaultManager) loadRunnerFromState(ctx context.Context, baseName string) (*runner.Runner, error) { - // Load the run config from the state store - runConfig, err := runner.LoadState(ctx, baseName) - if err != nil { - return nil, err - } - - if runConfig.RemoteURL != "" { - // For remote workloads, we don't need a deployer - runConfig.Deployer = nil - } else { - // Update the runtime in the loaded configuration - runConfig.Deployer = d.runtime - } - - // Create a new runner with the loaded configuration - return runner.NewRunner(runConfig, d.statuses), nil -} - -func (d *defaultManager) needSecretsPassword(secretOptions 
[]string) bool { - // If the user did not ask for any secrets, then don't attempt to instantiate - // the secrets manager. - if len(secretOptions) == 0 { - return false - } - // Ignore err - if the flag is not set, it's not needed. - providerType, _ := d.configProvider.GetConfig().Secrets.GetProviderType() - return providerType == secrets.EncryptedType -} - -// cleanupTempPermissionProfile cleans up temporary permission profile files for a given base name -func (*defaultManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error { - // Try to load the saved configuration to get the permission profile path - runConfig, err := runner.LoadState(ctx, baseName) - if err != nil { - // If we can't load the state, there's nothing to clean up - logger.Debugf("Could not load state for %s, skipping permission profile cleanup: %v", baseName, err) - return nil - } - - // Clean up the temporary permission profile if it exists - if runConfig.PermissionProfileNameOrPath != "" { - if err := runner.CleanupTempPermissionProfile(runConfig.PermissionProfileNameOrPath); err != nil { - return fmt.Errorf("failed to cleanup temporary permission profile: %v", err) - } - } - - return nil -} - -// stopSingleContainerWorkload stops a single container workload -func (d *defaultManager) stopSingleContainerWorkload(ctx context.Context, workload *rt.ContainerInfo) error { - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer cancel() - - name := labels.GetContainerBaseName(workload.Labels) - // Stop the proxy process (skip for auxiliary workloads like inspector) - if labels.IsAuxiliaryWorkload(workload.Labels) { - logger.Debugf("Skipping proxy stop for auxiliary workload %s", name) - } else { - d.stopProcess(ctx, name) - } - - // TODO: refactor the StopProcess function to stop dealing explicitly with PID files. - // Note that this is not a blocker for k8s since this code path is not called there. 
- if err := d.statuses.ResetWorkloadPID(ctx, name); err != nil { - logger.Warnf("Warning: Failed to reset workload %s PID: %v", name, err) - } - - logger.Infof("Stopping containers for %s...", name) - // Stop the container - if err := d.runtime.StopWorkload(childCtx, workload.Name); err != nil { - if statusErr := d.statuses.SetWorkloadStatus(childCtx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) - } - return fmt.Errorf("failed to stop container: %w", err) - } - - if err := removeClientConfigurations(name, labels.IsAuxiliaryWorkload(workload.Labels)); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } else { - logger.Infof("Client configurations for %s removed", name) - } - - if err := d.statuses.SetWorkloadStatus(childCtx, name, rt.WorkloadStatusStopped, ""); err != nil { - logger.Warnf("Failed to set workload %s status to stopped: %v", name, err) - } - logger.Infof("Successfully stopped %s...", name) - return nil -} - -// MoveToGroup moves the specified workloads from one group to another by updating their runconfig. 
-func (*defaultManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { - for _, workloadName := range workloadNames { - // Validate workload name - if err := types.ValidateWorkloadName(workloadName); err != nil { - return fmt.Errorf("invalid workload name %s: %w", workloadName, err) - } - - // Load the runner state to check and update the configuration - runnerConfig, err := runner.LoadState(ctx, workloadName) - if err != nil { - return fmt.Errorf("failed to load runner state for workload %s: %w", workloadName, err) - } - - // Check if the workload is actually in the specified group - if runnerConfig.Group != groupFrom { - logger.Debugf("Workload %s is not in group %s (current group: %s), skipping", - workloadName, groupFrom, runnerConfig.Group) - continue - } - - // Move the workload to the default group - runnerConfig.Group = groupTo - - // Save the updated configuration - if err = runnerConfig.SaveState(ctx); err != nil { - return fmt.Errorf("failed to save updated configuration for workload %s: %w", workloadName, err) - } - - logger.Infof("Moved workload %s to default group", workloadName) - } - - return nil -} - -// ListWorkloadsInGroup returns all workload names that belong to the specified group -func (d *defaultManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { - workloads, err := d.ListWorkloads(ctx, true) // listAll=true to include stopped workloads - if err != nil { - return nil, fmt.Errorf("failed to list workloads: %w", err) - } - - // Filter workloads that belong to the specified group - var groupWorkloads []string - for _, workload := range workloads { - if workload.Group == groupName { - groupWorkloads = append(groupWorkloads, workload.Name) - } - } - - return groupWorkloads, nil -} - -// getRemoteWorkloadsFromState retrieves remote servers from the state store -func (d *defaultManager) getRemoteWorkloadsFromState( - ctx context.Context, - listAll bool, - 
labelFilters []string, -) ([]core.Workload, error) { - // Create a state store - store, err := state.NewRunConfigStore(state.DefaultAppName) - if err != nil { - return nil, fmt.Errorf("failed to create state store: %w", err) - } - - // List all configurations - configNames, err := store.List(ctx) - if err != nil { - return nil, fmt.Errorf("failed to list configurations: %w", err) - } - - // Parse the filters into a format we can use for matching - parsedFilters, err := types.ParseLabelFilters(labelFilters) - if err != nil { - return nil, fmt.Errorf("failed to parse label filters: %v", err) - } - - var remoteWorkloads []core.Workload - - for _, name := range configNames { - // Load the run configuration - runConfig, err := runner.LoadState(ctx, name) - if err != nil { - logger.Warnf("failed to load state for %s: %v", name, err) - continue - } - - // Only include remote servers (those with RemoteURL set) - if runConfig.RemoteURL == "" { - continue - } - - // Check the status from the status file - workloadStatus, err := d.statuses.GetWorkload(ctx, name) - if err != nil { - logger.Warnf("failed to get status for remote workload %s: %v", name, err) - continue - } - - // Apply listAll filter - only include running workloads unless listAll is true - if !listAll && workloadStatus.Status != rt.WorkloadStatusRunning { - continue - } - - // Use the transport type directly since it's already parsed - transportType := runConfig.Transport - - // Generate the local proxy URL (not the remote server URL) - proxyURL := "" - if runConfig.Port > 0 { - proxyURL = transport.GenerateMCPServerURL( - transportType.String(), - transport.LocalhostIPv4, - runConfig.Port, - name, - runConfig.RemoteURL, // Pass remote URL to preserve path - ) - } - - // Calculate the effective proxy mode that clients should use - effectiveProxyMode := types.GetEffectiveProxyMode(transportType, string(runConfig.ProxyMode)) - - // Create a workload from the run configuration - workload := core.Workload{ - Name: 
name, - Package: "remote", - Status: workloadStatus.Status, - URL: proxyURL, - Port: runConfig.Port, - TransportType: transportType, - ProxyMode: effectiveProxyMode, - ToolType: "remote", - Group: runConfig.Group, - CreatedAt: workloadStatus.CreatedAt, - Labels: runConfig.ContainerLabels, - Remote: true, - } - - // Apply label filtering - if types.MatchesLabelFilters(workload.Labels, parsedFilters) { - remoteWorkloads = append(remoteWorkloads, workload) - } - } + // Detect namespace + namespace := kubernetes.GetCurrentNamespace() - return remoteWorkloads, nil + return NewK8SManager(k8sClient, namespace) } diff --git a/pkg/workloads/manager_test.go b/pkg/workloads/manager_test.go index ea971127e..d1421e63f 100644 --- a/pkg/workloads/manager_test.go +++ b/pkg/workloads/manager_test.go @@ -1,185 +1,16 @@ package workloads import ( - "context" - "errors" - "fmt" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "golang.org/x/sync/errgroup" - "github.com/stacklok/toolhive/pkg/config" configMocks "github.com/stacklok/toolhive/pkg/config/mocks" - "github.com/stacklok/toolhive/pkg/container/runtime" runtimeMocks "github.com/stacklok/toolhive/pkg/container/runtime/mocks" - "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/runner" - statusMocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" ) -func TestDefaultManager_ListWorkloadsInGroup(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - groupName string - mockWorkloads []core.Workload - expectedNames []string - expectError bool - setupStatusMgr func(*statusMocks.MockStatusManager) - }{ - { - name: "non existent group returns empty list", - groupName: "non-group", - mockWorkloads: []core.Workload{ - {Name: "workload1", Group: "other-group"}, - {Name: "workload2", Group: "another-group"}, - }, - expectedNames: []string{}, - expectError: false, - setupStatusMgr: func(sm 
*statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ - {Name: "workload1", Group: "other-group"}, - {Name: "workload2", Group: "another-group"}, - }, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - }, - { - name: "multiple workloads in group", - groupName: "test-group", - mockWorkloads: []core.Workload{ - {Name: "workload1", Group: "test-group"}, - {Name: "workload2", Group: "other-group"}, - {Name: "workload3", Group: "test-group"}, - {Name: "workload4", Group: "test-group"}, - }, - expectedNames: []string{"workload1", "workload3", "workload4"}, - expectError: false, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ - {Name: "workload1", Group: "test-group"}, - {Name: "workload2", Group: "other-group"}, - {Name: "workload3", Group: "test-group"}, - {Name: "workload4", Group: "test-group"}, - }, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - }, - { - name: "workloads with empty group names", - groupName: "", - mockWorkloads: []core.Workload{ - {Name: "workload1", Group: ""}, - {Name: "workload2", Group: "test-group"}, - {Name: "workload3", Group: ""}, - }, - expectedNames: []string{"workload1", "workload3"}, - expectError: false, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ - {Name: "workload1", Group: ""}, - {Name: "workload2", Group: "test-group"}, - {Name: "workload3", Group: ""}, - }, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, 
nil).AnyTimes() - }, - }, - { - name: "includes stopped workloads", - groupName: "test-group", - mockWorkloads: []core.Workload{ - {Name: "running-workload", Group: "test-group", Status: runtime.WorkloadStatusRunning}, - {Name: "stopped-workload", Group: "test-group", Status: runtime.WorkloadStatusStopped}, - {Name: "other-group-workload", Group: "other-group", Status: runtime.WorkloadStatusRunning}, - }, - expectedNames: []string{"running-workload", "stopped-workload"}, - expectError: false, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ - {Name: "running-workload", Group: "test-group", Status: runtime.WorkloadStatusRunning}, - {Name: "stopped-workload", Group: "test-group", Status: runtime.WorkloadStatusStopped}, - {Name: "other-group-workload", Group: "other-group", Status: runtime.WorkloadStatusRunning}, - }, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - }, - { - name: "error from ListWorkloads propagated", - groupName: "test-group", - expectedNames: nil, - expectError: true, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return(nil, assert.AnError) - }, - }, - { - name: "no workloads", - groupName: "test-group", - mockWorkloads: []core.Workload{}, - expectedNames: []string{}, - expectError: false, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{}, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer 
ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupStatusMgr(mockStatusMgr) - - manager := &defaultManager{ - runtime: nil, // Not needed for this test - statuses: mockStatusMgr, - } - - ctx := context.Background() - result, err := manager.ListWorkloadsInGroup(ctx, tt.groupName) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to list workloads") - return - } - - require.NoError(t, err) - assert.ElementsMatch(t, tt.expectedNames, result) - }) - } -} - func TestNewManagerFromRuntime(t *testing.T) { t.Parallel() @@ -195,12 +26,12 @@ func TestNewManagerFromRuntime(t *testing.T) { require.NoError(t, err) require.NotNil(t, manager) - // Verify it's a defaultManager with the runtime set - defaultMgr, ok := manager.(*defaultManager) + // Verify it's a cliManager with the runtime set + cliMgr, ok := manager.(*cliManager) require.True(t, ok) - assert.Equal(t, mockRuntime, defaultMgr.runtime) - assert.NotNil(t, defaultMgr.statuses) - assert.NotNil(t, defaultMgr.configProvider) + assert.Equal(t, mockRuntime, cliMgr.runtime) + assert.NotNil(t, cliMgr.statuses) + assert.NotNil(t, cliMgr.configProvider) } func TestNewManagerFromRuntimeWithProvider(t *testing.T) { @@ -217,1444 +48,9 @@ func TestNewManagerFromRuntimeWithProvider(t *testing.T) { require.NoError(t, err) require.NotNil(t, manager) - defaultMgr, ok := manager.(*defaultManager) + cliMgr, ok := manager.(*cliManager) require.True(t, ok) - assert.Equal(t, mockRuntime, defaultMgr.runtime) - assert.Equal(t, mockConfigProvider, defaultMgr.configProvider) - assert.NotNil(t, defaultMgr.statuses) -} - -func TestDefaultManager_DoesWorkloadExist(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - setupMocks func(*statusMocks.MockStatusManager) - expected bool - expectError bool - }{ - { - name: "workload exists and running", - workloadName: "test-workload", - setupMocks: func(sm *statusMocks.MockStatusManager) { - 
sm.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - }, - expected: true, - expectError: false, - }, - { - name: "workload exists but in error state", - workloadName: "error-workload", - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "error-workload").Return(core.Workload{ - Name: "error-workload", - Status: runtime.WorkloadStatusError, - }, nil) - }, - expected: false, - expectError: false, - }, - { - name: "workload not found", - workloadName: "missing-workload", - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "missing-workload").Return(core.Workload{}, runtime.ErrWorkloadNotFound) - }, - expected: false, - expectError: false, - }, - { - name: "error getting workload", - workloadName: "problematic-workload", - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "problematic-workload").Return(core.Workload{}, errors.New("database error")) - }, - expected: false, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockStatusMgr) - - manager := &defaultManager{ - statuses: mockStatusMgr, - } - - ctx := context.Background() - result, err := manager.DoesWorkloadExist(ctx, tt.workloadName) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to check if workload exists") - } else { - require.NoError(t, err) - assert.Equal(t, tt.expected, result) - } - }) - } -} - -func TestDefaultManager_GetWorkload(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - expectedWorkload := core.Workload{ - Name: "test-workload", 
- Status: runtime.WorkloadStatusRunning, - } - - mockStatusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(expectedWorkload, nil) - - manager := &defaultManager{ - statuses: mockStatusMgr, - } - - ctx := context.Background() - result, err := manager.GetWorkload(ctx, "test-workload") - - require.NoError(t, err) - assert.Equal(t, expectedWorkload, result) -} - -func TestDefaultManager_GetLogs(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - follow bool - setupMocks func(*runtimeMocks.MockRuntime) - expectedLogs string - expectError bool - errorMsg string - }{ - { - name: "successful log retrieval", - workloadName: "test-workload", - follow: false, - setupMocks: func(rt *runtimeMocks.MockRuntime) { - rt.EXPECT().GetWorkloadLogs(gomock.Any(), "test-workload", false).Return("test log content", nil) - }, - expectedLogs: "test log content", - expectError: false, - }, - { - name: "workload not found", - workloadName: "missing-workload", - follow: false, - setupMocks: func(rt *runtimeMocks.MockRuntime) { - rt.EXPECT().GetWorkloadLogs(gomock.Any(), "missing-workload", false).Return("", runtime.ErrWorkloadNotFound) - }, - expectedLogs: "", - expectError: true, - errorMsg: "workload not found", - }, - { - name: "runtime error", - workloadName: "error-workload", - follow: true, - setupMocks: func(rt *runtimeMocks.MockRuntime) { - rt.EXPECT().GetWorkloadLogs(gomock.Any(), "error-workload", true).Return("", errors.New("runtime failure")) - }, - expectedLogs: "", - expectError: true, - errorMsg: "failed to get container logs", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - tt.setupMocks(mockRuntime) - - manager := &defaultManager{ - runtime: mockRuntime, - } - - ctx := context.Background() - logs, err := manager.GetLogs(ctx, tt.workloadName, tt.follow) - - if 
tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - } else { - require.NoError(t, err) - assert.Equal(t, tt.expectedLogs, logs) - } - }) - } -} - -func TestDefaultManager_StopWorkloads(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadNames []string - expectError bool - errorMsg string - }{ - { - name: "invalid workload name with path traversal", - workloadNames: []string{"../etc/passwd"}, - expectError: true, - errorMsg: "path traversal", - }, - { - name: "invalid workload name with slash", - workloadNames: []string{"workload/name"}, - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "empty workload name list", - workloadNames: []string{}, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - manager := &defaultManager{} - - ctx := context.Background() - group, err := manager.StopWorkloads(ctx, tt.workloadNames) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - assert.Nil(t, group) - } else { - require.NoError(t, err) - assert.NotNil(t, group) - assert.IsType(t, &errgroup.Group{}, group) - } - }) - } -} - -func TestDefaultManager_DeleteWorkloads(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadNames []string - expectError bool - errorMsg string - }{ - { - name: "invalid workload name", - workloadNames: []string{"../../../etc/passwd"}, - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "mixed valid and invalid names", - workloadNames: []string{"valid-name", "invalid../name"}, - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "empty workload name list", - workloadNames: []string{}, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - manager := &defaultManager{} - - ctx := context.Background() - group, err := 
manager.DeleteWorkloads(ctx, tt.workloadNames) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - assert.Nil(t, group) - } else { - require.NoError(t, err) - assert.NotNil(t, group) - assert.IsType(t, &errgroup.Group{}, group) - } - }) - } -} - -func TestDefaultManager_RestartWorkloads(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadNames []string - foreground bool - expectError bool - errorMsg string - }{ - { - name: "invalid workload name", - workloadNames: []string{"invalid/name"}, - foreground: false, - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "empty workload name list", - workloadNames: []string{}, - foreground: false, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - manager := &defaultManager{} - - ctx := context.Background() - group, err := manager.RestartWorkloads(ctx, tt.workloadNames, tt.foreground) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - assert.Nil(t, group) - } else { - require.NoError(t, err) - assert.NotNil(t, group) - assert.IsType(t, &errgroup.Group{}, group) - } - }) - } -} - -func TestDefaultManager_restartRemoteWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - runConfig *runner.RunConfig - foreground bool - setupMocks func(*statusMocks.MockStatusManager) - expectError bool - errorMsg string - }{ - { - name: "remote workload already running with healthy supervisor", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-base", - RemoteURL: "http://example.com", - }, - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - // Check if supervisor is alive - 
return valid PID (supervisor is healthy) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(12345, nil) - }, - // With healthy supervisor, restart should return early (no-op) - expectError: false, - }, - { - name: "remote workload already running with dead supervisor", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-base", - RemoteURL: "http://example.com", - }, - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - // Check if supervisor is alive - return error (supervisor is dead) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(0, errors.New("no PID found")) - // With dead supervisor, restart proceeds with cleanup and restart - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(nil) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(0, errors.New("no PID found")) - // Allow any subsequent status updates - sm.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - sm.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - }, - // Restart now proceeds to load state which fails in tests (can't mock runner.LoadState easily) - expectError: true, - errorMsg: "failed to load state", - }, - { - name: "status manager error", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-base", - RemoteURL: "http://example.com", - }, - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{}, errors.New("status manager error")) - }, - expectError: true, - errorMsg: "status manager error", - }, - } - - for _, tt := range tests { - t.Run(tt.name, 
func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(statusMgr) - - manager := &defaultManager{ - statuses: statusMgr, - } - - err := manager.restartRemoteWorkload(context.Background(), tt.workloadName, tt.runConfig, tt.foreground) - - if tt.expectError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -func TestDefaultManager_restartContainerWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - foreground bool - setupMocks func(*statusMocks.MockStatusManager, *runtimeMocks.MockRuntime) - expectError bool - errorMsg string - }{ - { - name: "container workload already running with healthy supervisor", - workloadName: "container-workload", - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { - // Mock container info - rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ - Name: "container-workload", - State: runtime.WorkloadStatusRunning, - Labels: map[string]string{ - "toolhive.base-name": "container-workload", - }, - }, nil) - sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{ - Name: "container-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - // Check if supervisor is alive - return valid PID (supervisor is healthy) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(12345, nil) - }, - // With healthy supervisor, restart should return early (no-op) - expectError: false, - }, - { - name: "container workload already running with dead supervisor", - workloadName: "container-workload", - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { - // Mock container info - 
rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ - Name: "container-workload", - State: runtime.WorkloadStatusRunning, - Labels: map[string]string{ - "toolhive.base-name": "container-workload", - }, - }, nil) - sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{ - Name: "container-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - // Check if supervisor is alive - return error (supervisor is dead) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(0, errors.New("no PID found")) - // With dead supervisor, restart proceeds with cleanup and restart - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "container-workload", runtime.WorkloadStatusStopping, "").Return(nil) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(0, errors.New("no PID found")) - rm.EXPECT().StopWorkload(gomock.Any(), "container-workload").Return(nil) - // Allow any subsequent status updates - sm.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - sm.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - }, - // Restart now proceeds to load state which fails in tests (can't mock runner.LoadState easily) - expectError: true, - errorMsg: "failed to load state", - }, - { - name: "status manager error", - workloadName: "container-workload", - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { - // Mock container info - rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ - Name: "container-workload", - State: "running", - Labels: map[string]string{ - "toolhive.base-name": "container-workload", - }, - }, nil) - sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{}, errors.New("status manager error")) - }, - expectError: true, - errorMsg: "status manager error", - }, - 
} - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) - - tt.setupMocks(statusMgr, runtimeMgr) - - manager := &defaultManager{ - statuses: statusMgr, - runtime: runtimeMgr, - } - - err := manager.restartContainerWorkload(context.Background(), tt.workloadName, tt.foreground) - - if tt.expectError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -// TestDefaultManager_restartLogicConsistency tests restart behavior with healthy vs dead supervisor -func TestDefaultManager_restartLogicConsistency(t *testing.T) { - t.Parallel() - - t.Run("remote_workload_healthy_supervisor_no_restart", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - - statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - - // Check if supervisor is alive - return valid PID (healthy) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(12345, nil) - - manager := &defaultManager{ - statuses: statusMgr, - } - - runConfig := &runner.RunConfig{ - BaseName: "test-base", - RemoteURL: "http://example.com", - } - - err := manager.restartRemoteWorkload(context.Background(), "test-workload", runConfig, false) - - // With healthy supervisor, restart should return successfully without doing anything - require.NoError(t, err) - }) - - t.Run("remote_workload_dead_supervisor_calls_stop", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - - statusMgr.EXPECT().GetWorkload(gomock.Any(), 
"test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - - // Check if supervisor is alive - return error (dead supervisor) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(0, errors.New("no PID found")) - - // When supervisor is dead, expect stop logic to be called - statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(0, errors.New("no PID found")) - - // Allow any subsequent status updates - we don't care about the exact sequence - statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - - manager := &defaultManager{ - statuses: statusMgr, - } - - runConfig := &runner.RunConfig{ - BaseName: "test-base", - RemoteURL: "http://example.com", - } - - _ = manager.restartRemoteWorkload(context.Background(), "test-workload", runConfig, false) - - // The important part is that the stop methods were called (verified by mock expectations) - // We don't care if the restart ultimately succeeds or fails - }) - - t.Run("container_workload_healthy_supervisor_no_restart", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) - - containerInfo := runtime.ContainerInfo{ - Name: "test-workload", - State: runtime.WorkloadStatusRunning, - Labels: map[string]string{ - "toolhive.base-name": "test-workload", - }, - } - runtimeMgr.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(containerInfo, nil) - - statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - - 
// Check if supervisor is alive - return valid PID (healthy) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(12345, nil) - - manager := &defaultManager{ - statuses: statusMgr, - runtime: runtimeMgr, - } - - err := manager.restartContainerWorkload(context.Background(), "test-workload", false) - - // With healthy supervisor, restart should return successfully without doing anything - require.NoError(t, err) - }) - - t.Run("container_workload_dead_supervisor_calls_stop", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) - - containerInfo := runtime.ContainerInfo{ - Name: "test-workload", - State: runtime.WorkloadStatusRunning, - Labels: map[string]string{ - "toolhive.base-name": "test-workload", - }, - } - runtimeMgr.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(containerInfo, nil) - - statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - - // Check if supervisor is alive - return error (dead supervisor) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(0, errors.New("no PID found")) - - // When supervisor is dead, expect stop logic to be called - statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(0, errors.New("no PID found")) - runtimeMgr.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) - - // Allow any subsequent status updates (starting, error, etc.) 
- we don't care about the exact sequence - statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - - manager := &defaultManager{ - statuses: statusMgr, - runtime: runtimeMgr, - } - - _ = manager.restartContainerWorkload(context.Background(), "test-workload", false) - - // The important part is that the stop methods were called (verified by mock expectations) - // We don't care if the restart ultimately succeeds or fails - }) -} - -func TestDefaultManager_RunWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - runConfig *runner.RunConfig - setupMocks func(*statusMocks.MockStatusManager) - expectError bool - errorMsg string - }{ - { - name: "successful run - status creation", - runConfig: &runner.RunConfig{ - BaseName: "test-workload", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - // Expect starting status first, then error status when the runner fails - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStarting, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, gomock.Any()).Return(nil) - }, - expectError: true, // The runner will fail without proper setup - }, - { - name: "status creation failure", - runConfig: &runner.RunConfig{ - BaseName: "failing-workload", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", runtime.WorkloadStatusStarting, "").Return(errors.New("status creation failed")) - }, - expectError: true, - errorMsg: "failed to create workload status", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - 
tt.setupMocks(mockStatusMgr) - - manager := &defaultManager{ - statuses: mockStatusMgr, - } - - ctx := context.Background() - err := manager.RunWorkload(ctx, tt.runConfig) - - if tt.expectError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -func TestDefaultManager_validateSecretParameters(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - runConfig *runner.RunConfig - setupMocks func(*configMocks.MockProvider) - expectError bool - errorMsg string - }{ - { - name: "no secrets - should pass", - runConfig: &runner.RunConfig{ - Secrets: []string{}, - }, - setupMocks: func(*configMocks.MockProvider) {}, // No expectations - expectError: false, - }, - { - name: "config error", - runConfig: &runner.RunConfig{ - Secrets: []string{"secret1"}, - }, - setupMocks: func(cp *configMocks.MockProvider) { - mockConfig := &config.Config{} - cp.EXPECT().GetConfig().Return(mockConfig) - }, - expectError: true, - errorMsg: "error determining secrets provider type", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockConfigProvider := configMocks.NewMockProvider(ctrl) - tt.setupMocks(mockConfigProvider) - - manager := &defaultManager{ - configProvider: mockConfigProvider, - } - - ctx := context.Background() - err := manager.validateSecretParameters(ctx, tt.runConfig) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestDefaultManager_getWorkloadContainer(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) - expected *runtime.ContainerInfo - expectError bool - errorMsg string - }{ - { - name: "successful retrieval", - workloadName: 
"test-workload", - setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { - expectedContainer := runtime.ContainerInfo{ - Name: "test-workload", - State: runtime.WorkloadStatusRunning, - } - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(expectedContainer, nil) - }, - expected: &runtime.ContainerInfo{ - Name: "test-workload", - State: runtime.WorkloadStatusRunning, - }, - expectError: false, - }, - { - name: "workload not found", - workloadName: "missing-workload", - setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "missing-workload").Return(runtime.ContainerInfo{}, runtime.ErrWorkloadNotFound) - }, - expected: nil, - expectError: false, // getWorkloadContainer returns nil for not found, not error - }, - { - name: "runtime error", - workloadName: "error-workload", - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "error-workload").Return(runtime.ContainerInfo{}, errors.New("runtime failure")) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "error-workload", runtime.WorkloadStatusError, "runtime failure").Return(nil) - }, - expected: nil, - expectError: true, - errorMsg: "failed to find workload", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockRuntime, mockStatusMgr) - - manager := &defaultManager{ - runtime: mockRuntime, - statuses: mockStatusMgr, - } - - ctx := context.Background() - result, err := manager.getWorkloadContainer(ctx, tt.workloadName) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - } else { - require.NoError(t, err) - if tt.expected == nil { - assert.Nil(t, result) - } 
else { - assert.Equal(t, tt.expected, result) - } - } - }) - } -} - -func TestDefaultManager_removeContainer(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) - expectError bool - errorMsg string - }{ - { - name: "successful removal", - workloadName: "test-workload", - setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { - rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(nil) - }, - expectError: false, - }, - { - name: "removal failure", - workloadName: "failing-workload", - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - rt.EXPECT().RemoveWorkload(gomock.Any(), "failing-workload").Return(errors.New("removal failed")) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", runtime.WorkloadStatusError, "removal failed").Return(nil) - }, - expectError: true, - errorMsg: "failed to remove container", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockRuntime, mockStatusMgr) - - manager := &defaultManager{ - runtime: mockRuntime, - statuses: mockStatusMgr, - } - - ctx := context.Background() - err := manager.removeContainer(ctx, tt.workloadName) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestDefaultManager_needSecretsPassword(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - secretOptions []string - setupMocks func(*configMocks.MockProvider) - expected bool - }{ - { - name: "no secrets", - secretOptions: []string{}, - setupMocks: func(*configMocks.MockProvider) {}, // No expectations - expected: false, 
- }, - { - name: "has secrets but config access fails", - secretOptions: []string{"secret1"}, - setupMocks: func(cp *configMocks.MockProvider) { - mockConfig := &config.Config{} - cp.EXPECT().GetConfig().Return(mockConfig) - }, - expected: false, // Returns false when provider type detection fails - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockConfigProvider := configMocks.NewMockProvider(ctrl) - tt.setupMocks(mockConfigProvider) - - manager := &defaultManager{ - configProvider: mockConfigProvider, - } - - result := manager.needSecretsPassword(tt.secretOptions) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestDefaultManager_RunWorkloadDetached(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - runConfig *runner.RunConfig - setupMocks func(*statusMocks.MockStatusManager, *configMocks.MockProvider) - expectError bool - errorMsg string - }{ - { - name: "validation failure should not reach PID management", - runConfig: &runner.RunConfig{ - BaseName: "test-workload", - Secrets: []string{"invalid-secret"}, - }, - setupMocks: func(_ *statusMocks.MockStatusManager, cp *configMocks.MockProvider) { - // Mock config provider to cause validation failure - mockConfig := &config.Config{} - cp.EXPECT().GetConfig().Return(mockConfig) - // No SetWorkloadPID expectation since validation should fail first - }, - expectError: true, - errorMsg: "failed to validate workload parameters", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - mockConfigProvider := configMocks.NewMockProvider(ctrl) - tt.setupMocks(mockStatusMgr, mockConfigProvider) - - manager := &defaultManager{ - statuses: mockStatusMgr, - configProvider: mockConfigProvider, - } - - ctx := context.Background() - err := 
manager.RunWorkloadDetached(ctx, tt.runConfig) - - if tt.expectError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -// TestDefaultManager_RunWorkloadDetached_PIDManagement tests that PID management -// happens in the later stages of RunWorkloadDetached when the process actually starts. -// This is tested indirectly by verifying the behavior exists in the code flow. -func TestDefaultManager_RunWorkloadDetached_PIDManagement(t *testing.T) { - t.Parallel() - - // This test documents the expected behavior: - // 1. RunWorkloadDetached calls SetWorkloadPID after starting the detached process - // 2. The PID management happens after validation and process creation - // 3. SetWorkloadPID failures are logged as warnings but don't fail the operation - - // Since RunWorkloadDetached involves spawning actual processes and complex setup, - // we verify the PID management integration exists by checking the method signature - // and code structure rather than running the full integration. 
- - manager := &defaultManager{} - assert.NotNil(t, manager, "defaultManager should be instantiable") - - // Verify the method exists with the correct signature - var runWorkloadDetachedFunc interface{} = manager.RunWorkloadDetached - assert.NotNil(t, runWorkloadDetachedFunc, "RunWorkloadDetached method should exist") -} - -func TestAsyncOperationTimeout(t *testing.T) { - t.Parallel() - - // Test that the timeout constant is properly defined - assert.Equal(t, 5*time.Minute, AsyncOperationTimeout) -} - -func TestErrWorkloadNotRunning(t *testing.T) { - t.Parallel() - - // Test that the error is properly defined - assert.Error(t, ErrWorkloadNotRunning) - assert.Contains(t, ErrWorkloadNotRunning.Error(), "workload not running") -} - -func TestDefaultManager_ListWorkloads(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - listAll bool - labelFilters []string - setupMocks func(*statusMocks.MockStatusManager) - expected []core.Workload - expectError bool - errorMsg string - }{ - { - name: "successful listing without filters", - listAll: true, - labelFilters: []string{}, - setupMocks: func(sm *statusMocks.MockStatusManager) { - workloads := []core.Workload{ - {Name: "workload1", Status: runtime.WorkloadStatusRunning}, - {Name: "workload2", Status: runtime.WorkloadStatusStopped}, - } - sm.EXPECT().ListWorkloads(gomock.Any(), true, []string{}).Return(workloads, nil) - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - expected: []core.Workload{ - {Name: "workload1", Status: runtime.WorkloadStatusRunning}, - {Name: "workload2", Status: runtime.WorkloadStatusStopped}, - }, - expectError: false, - }, - { - name: "error from status manager", - listAll: false, - labelFilters: []string{"env=prod"}, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), false, []string{"env=prod"}).Return(nil, 
errors.New("database error")) - }, - expected: nil, - expectError: true, - errorMsg: "database error", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockStatusMgr) - - manager := &defaultManager{ - statuses: mockStatusMgr, - } - - ctx := context.Background() - result, err := manager.ListWorkloads(ctx, tt.listAll, tt.labelFilters...) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - } else { - // We expect this to succeed but might include remote workloads - // Since getRemoteWorkloadsFromState will likely fail in unit tests, - // we mainly verify the container workloads are returned - require.NoError(t, err) - assert.GreaterOrEqual(t, len(result), len(tt.expected)) - // Verify at least our expected container workloads are present - for _, expectedWorkload := range tt.expected { - found := false - for _, actualWorkload := range result { - if actualWorkload.Name == expectedWorkload.Name { - found = true - break - } - } - assert.True(t, found, fmt.Sprintf("Expected workload %s not found in result", expectedWorkload.Name)) - } - } - }) - } -} - -func TestDefaultManager_UpdateWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - expectError bool - errorMsg string - setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) - }{ - { - name: "invalid workload name with slash", - workloadName: "invalid/name", - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "invalid workload name with backslash", - workloadName: "invalid\\name", - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "invalid workload name with path traversal", - workloadName: "../invalid", - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "valid 
workload name returns errgroup immediately", - workloadName: "valid-workload", - expectError: false, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // Mock calls that will happen in the background goroutine - // We don't care about the success/failure, just that it doesn't panic - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "valid-workload"). - Return(runtime.ContainerInfo{}, errors.New("not found")).AnyTimes() - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "valid-workload", gomock.Any(), gomock.Any()). - Return(nil).AnyTimes() - }, - }, - { - name: "UpdateWorkload returns errgroup even if async operation will fail", - workloadName: "failing-workload", - expectError: false, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // The async operation will fail, but UpdateWorkload itself should succeed - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "failing-workload"). - Return(runtime.ContainerInfo{}, errors.New("container lookup failed")).AnyTimes() - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", gomock.Any(), gomock.Any()). 
- Return(nil).AnyTimes() - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - mockStatusManager := statusMocks.NewMockStatusManager(ctrl) - mockConfigProvider := configMocks.NewMockProvider(ctrl) - - if tt.setupMocks != nil { - tt.setupMocks(mockRuntime, mockStatusManager) - } - - manager := &defaultManager{ - runtime: mockRuntime, - statuses: mockStatusManager, - configProvider: mockConfigProvider, - } - - // Create a dummy RunConfig for testing - runConfig := &runner.RunConfig{ - ContainerName: tt.workloadName, - BaseName: tt.workloadName, - } - - ctx := context.Background() - group, err := manager.UpdateWorkload(ctx, tt.workloadName, runConfig) - - if tt.expectError { - assert.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - assert.Nil(t, group) - } else { - assert.NoError(t, err) - assert.NotNil(t, group) - // For valid cases, we get an errgroup but don't wait for completion - // The async operations inside are tested separately - } - }) - } -} - -func TestDefaultManager_updateSingleWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - runConfig *runner.RunConfig - setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) - expectError bool - errorMsg string - }{ - { - name: "stop operation fails", - workloadName: "test-workload", - runConfig: &runner.RunConfig{ - ContainerName: "test-workload", - BaseName: "test-workload", - Group: "default", - }, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // Mock the stop operation - return error for GetWorkloadInfo - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
- Return(runtime.ContainerInfo{}, errors.New("container lookup failed")).AnyTimes() - // Still expect status updates to be attempted - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil).AnyTimes() - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, "").Return(nil).AnyTimes() - }, - expectError: true, - errorMsg: "failed to stop workload", - }, - { - name: "successful stop and delete operations complete correctly", - workloadName: "test-workload", - runConfig: &runner.RunConfig{ - ContainerName: "test-workload", - BaseName: "test-workload", - Group: "default", - }, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // Mock stop operation - workload exists and can be stopped - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). - Return(runtime.ContainerInfo{ - Name: "test-workload", - State: "running", - Labels: map[string]string{"toolhive-basename": "test-workload"}, - }, nil) - // Mock GetWorkloadPID call from stopProcess - sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil) - rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) - sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil) - - // Mock delete operation - workload exists and can be deleted - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
- Return(runtime.ContainerInfo{Name: "test-workload"}, nil) - rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(nil) - - // Mock status updates for stop and delete phases - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopped, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusRemoving, "").Return(nil) - sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), "test-workload").Return(nil) - - // Mock RunWorkloadDetached calls - expect the ones that will be called - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStarting, "").Return(nil) - sm.EXPECT().SetWorkloadPID(gomock.Any(), "test-workload", gomock.Any()).Return(nil) - }, - expectError: false, // Test passes - update process completes successfully - }, - { - name: "delete operation fails after successful stop", - workloadName: "test-workload", - runConfig: &runner.RunConfig{ - ContainerName: "test-workload", - BaseName: "test-workload", - Group: "default", - }, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // Mock successful stop - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). - Return(runtime.ContainerInfo{ - Name: "test-workload", - State: "running", - Labels: map[string]string{"toolhive-basename": "test-workload"}, - }, nil) - // Mock GetWorkloadPID call from stopProcess - sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil) - rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) - sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil) - - // Mock failed delete - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
- Return(runtime.ContainerInfo{Name: "test-workload"}, nil) - rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(errors.New("delete failed")) - - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopped, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusRemoving, "").Return(nil) - // RemoveWorkload fails, so error status is set - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, "delete failed").Return(nil) - }, - expectError: true, - errorMsg: "failed to delete workload", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - mockStatusManager := statusMocks.NewMockStatusManager(ctrl) - mockConfigProvider := configMocks.NewMockProvider(ctrl) - - if tt.setupMocks != nil { - tt.setupMocks(mockRuntime, mockStatusManager) - } - - manager := &defaultManager{ - runtime: mockRuntime, - statuses: mockStatusManager, - configProvider: mockConfigProvider, - } - - err := manager.updateSingleWorkload(tt.workloadName, tt.runConfig) - - if tt.expectError { - assert.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - assert.NoError(t, err) - } - }) - } + assert.Equal(t, mockRuntime, cliMgr.runtime) + assert.Equal(t, mockConfigProvider, cliMgr.configProvider) + assert.NotNil(t, cliMgr.statuses) } diff --git a/pkg/workloads/mocks/mock_storage_driver.go b/pkg/workloads/mocks/mock_storage_driver.go new file mode 100644 index 000000000..905113660 --- /dev/null +++ b/pkg/workloads/mocks/mock_storage_driver.go @@ -0,0 +1,150 @@ +// Code generated by MockGen. DO NOT EDIT. 
+// Source: driver.go +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_storage_driver.go -package=mocks -source=driver.go StorageDriver +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + core "github.com/stacklok/toolhive/pkg/core" + runner "github.com/stacklok/toolhive/pkg/runner" + gomock "go.uber.org/mock/gomock" +) + +// MockStorageDriver is a mock of StorageDriver interface. +type MockStorageDriver struct { + ctrl *gomock.Controller + recorder *MockStorageDriverMockRecorder + isgomock struct{} +} + +// MockStorageDriverMockRecorder is the mock recorder for MockStorageDriver. +type MockStorageDriverMockRecorder struct { + mock *MockStorageDriver +} + +// NewMockStorageDriver creates a new mock instance. +func NewMockStorageDriver(ctrl *gomock.Controller) *MockStorageDriver { + mock := &MockStorageDriver{ctrl: ctrl} + mock.recorder = &MockStorageDriverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStorageDriver) EXPECT() *MockStorageDriverMockRecorder { + return m.recorder +} + +// DeleteWorkloadState mocks base method. +func (m *MockStorageDriver) DeleteWorkloadState(ctx context.Context, name string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteWorkloadState", ctx, name) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteWorkloadState indicates an expected call of DeleteWorkloadState. +func (mr *MockStorageDriverMockRecorder) DeleteWorkloadState(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkloadState", reflect.TypeOf((*MockStorageDriver)(nil).DeleteWorkloadState), ctx, name) +} + +// GetWorkload mocks base method. 
+func (m *MockStorageDriver) GetWorkload(ctx context.Context, name string) (core.Workload, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWorkload", ctx, name) + ret0, _ := ret[0].(core.Workload) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetWorkload indicates an expected call of GetWorkload. +func (mr *MockStorageDriverMockRecorder) GetWorkload(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkload", reflect.TypeOf((*MockStorageDriver)(nil).GetWorkload), ctx, name) +} + +// ListWorkloads mocks base method. +func (m *MockStorageDriver) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, listAll} + for _, a := range labelFilters { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListWorkloads", varargs...) + ret0, _ := ret[0].([]core.Workload) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListWorkloads indicates an expected call of ListWorkloads. +func (mr *MockStorageDriverMockRecorder) ListWorkloads(ctx, listAll any, labelFilters ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, listAll}, labelFilters...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloads", reflect.TypeOf((*MockStorageDriver)(nil).ListWorkloads), varargs...) +} + +// ListWorkloadsInGroup mocks base method. +func (m *MockStorageDriver) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListWorkloadsInGroup", ctx, groupName) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListWorkloadsInGroup indicates an expected call of ListWorkloadsInGroup. 
+func (mr *MockStorageDriverMockRecorder) ListWorkloadsInGroup(ctx, groupName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloadsInGroup", reflect.TypeOf((*MockStorageDriver)(nil).ListWorkloadsInGroup), ctx, groupName) +} + +// LoadWorkloadState mocks base method. +func (m *MockStorageDriver) LoadWorkloadState(ctx context.Context, name string) (*runner.RunConfig, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadWorkloadState", ctx, name) + ret0, _ := ret[0].(*runner.RunConfig) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadWorkloadState indicates an expected call of LoadWorkloadState. +func (mr *MockStorageDriverMockRecorder) LoadWorkloadState(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWorkloadState", reflect.TypeOf((*MockStorageDriver)(nil).LoadWorkloadState), ctx, name) +} + +// MoveToGroup mocks base method. +func (m *MockStorageDriver) MoveToGroup(ctx context.Context, names []string, from, to string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MoveToGroup", ctx, names, from, to) + ret0, _ := ret[0].(error) + return ret0 +} + +// MoveToGroup indicates an expected call of MoveToGroup. +func (mr *MockStorageDriverMockRecorder) MoveToGroup(ctx, names, from, to any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MoveToGroup", reflect.TypeOf((*MockStorageDriver)(nil).MoveToGroup), ctx, names, from, to) +} + +// SaveWorkloadState mocks base method. +func (m *MockStorageDriver) SaveWorkloadState(ctx context.Context, config *runner.RunConfig) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveWorkloadState", ctx, config) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveWorkloadState indicates an expected call of SaveWorkloadState. 
+func (mr *MockStorageDriverMockRecorder) SaveWorkloadState(ctx, config any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveWorkloadState", reflect.TypeOf((*MockStorageDriver)(nil).SaveWorkloadState), ctx, config) +} From 480db30eea114266c8aa2f319de6c43e3b2167f7 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Thu, 6 Nov 2025 16:45:02 +0000 Subject: [PATCH 02/16] removed unnecessary files --- cmd/thv/__debug_bin112695300 | 0 pkg/workloads/mocks/mock_storage_driver.go | 150 --------------------- 2 files changed, 150 deletions(-) delete mode 100644 cmd/thv/__debug_bin112695300 delete mode 100644 pkg/workloads/mocks/mock_storage_driver.go diff --git a/cmd/thv/__debug_bin112695300 b/cmd/thv/__debug_bin112695300 deleted file mode 100644 index e69de29bb..000000000 diff --git a/pkg/workloads/mocks/mock_storage_driver.go b/pkg/workloads/mocks/mock_storage_driver.go deleted file mode 100644 index 905113660..000000000 --- a/pkg/workloads/mocks/mock_storage_driver.go +++ /dev/null @@ -1,150 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: driver.go -// -// Generated by this command: -// -// mockgen -destination=mocks/mock_storage_driver.go -package=mocks -source=driver.go StorageDriver -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - core "github.com/stacklok/toolhive/pkg/core" - runner "github.com/stacklok/toolhive/pkg/runner" - gomock "go.uber.org/mock/gomock" -) - -// MockStorageDriver is a mock of StorageDriver interface. -type MockStorageDriver struct { - ctrl *gomock.Controller - recorder *MockStorageDriverMockRecorder - isgomock struct{} -} - -// MockStorageDriverMockRecorder is the mock recorder for MockStorageDriver. -type MockStorageDriverMockRecorder struct { - mock *MockStorageDriver -} - -// NewMockStorageDriver creates a new mock instance. 
-func NewMockStorageDriver(ctrl *gomock.Controller) *MockStorageDriver { - mock := &MockStorageDriver{ctrl: ctrl} - mock.recorder = &MockStorageDriverMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStorageDriver) EXPECT() *MockStorageDriverMockRecorder { - return m.recorder -} - -// DeleteWorkloadState mocks base method. -func (m *MockStorageDriver) DeleteWorkloadState(ctx context.Context, name string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteWorkloadState", ctx, name) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteWorkloadState indicates an expected call of DeleteWorkloadState. -func (mr *MockStorageDriverMockRecorder) DeleteWorkloadState(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkloadState", reflect.TypeOf((*MockStorageDriver)(nil).DeleteWorkloadState), ctx, name) -} - -// GetWorkload mocks base method. -func (m *MockStorageDriver) GetWorkload(ctx context.Context, name string) (core.Workload, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWorkload", ctx, name) - ret0, _ := ret[0].(core.Workload) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetWorkload indicates an expected call of GetWorkload. -func (mr *MockStorageDriverMockRecorder) GetWorkload(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkload", reflect.TypeOf((*MockStorageDriver)(nil).GetWorkload), ctx, name) -} - -// ListWorkloads mocks base method. -func (m *MockStorageDriver) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { - m.ctrl.T.Helper() - varargs := []any{ctx, listAll} - for _, a := range labelFilters { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "ListWorkloads", varargs...) 
- ret0, _ := ret[0].([]core.Workload) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListWorkloads indicates an expected call of ListWorkloads. -func (mr *MockStorageDriverMockRecorder) ListWorkloads(ctx, listAll any, labelFilters ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, listAll}, labelFilters...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloads", reflect.TypeOf((*MockStorageDriver)(nil).ListWorkloads), varargs...) -} - -// ListWorkloadsInGroup mocks base method. -func (m *MockStorageDriver) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListWorkloadsInGroup", ctx, groupName) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListWorkloadsInGroup indicates an expected call of ListWorkloadsInGroup. -func (mr *MockStorageDriverMockRecorder) ListWorkloadsInGroup(ctx, groupName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloadsInGroup", reflect.TypeOf((*MockStorageDriver)(nil).ListWorkloadsInGroup), ctx, groupName) -} - -// LoadWorkloadState mocks base method. -func (m *MockStorageDriver) LoadWorkloadState(ctx context.Context, name string) (*runner.RunConfig, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadWorkloadState", ctx, name) - ret0, _ := ret[0].(*runner.RunConfig) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// LoadWorkloadState indicates an expected call of LoadWorkloadState. -func (mr *MockStorageDriverMockRecorder) LoadWorkloadState(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWorkloadState", reflect.TypeOf((*MockStorageDriver)(nil).LoadWorkloadState), ctx, name) -} - -// MoveToGroup mocks base method. 
-func (m *MockStorageDriver) MoveToGroup(ctx context.Context, names []string, from, to string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MoveToGroup", ctx, names, from, to) - ret0, _ := ret[0].(error) - return ret0 -} - -// MoveToGroup indicates an expected call of MoveToGroup. -func (mr *MockStorageDriverMockRecorder) MoveToGroup(ctx, names, from, to any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MoveToGroup", reflect.TypeOf((*MockStorageDriver)(nil).MoveToGroup), ctx, names, from, to) -} - -// SaveWorkloadState mocks base method. -func (m *MockStorageDriver) SaveWorkloadState(ctx context.Context, config *runner.RunConfig) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SaveWorkloadState", ctx, config) - ret0, _ := ret[0].(error) - return ret0 -} - -// SaveWorkloadState indicates an expected call of SaveWorkloadState. -func (mr *MockStorageDriverMockRecorder) SaveWorkloadState(ctx, config any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveWorkloadState", reflect.TypeOf((*MockStorageDriver)(nil).SaveWorkloadState), ctx, config) -} From 0b09d382b674856d26f95b3453238ee590c153ff Mon Sep 17 00:00:00 2001 From: amirejaz Date: Mon, 10 Nov 2025 23:07:12 +0000 Subject: [PATCH 03/16] unified workload with separate workloads for cli and k8s --- cmd/vmcp/app/commands.go | 20 +++- pkg/vmcp/aggregator/cli_discoverer.go | 137 ++++++++++++++++++++++ pkg/vmcp/aggregator/discoverer.go | 132 +++------------------ pkg/vmcp/aggregator/discoverer_test.go | 27 +++-- pkg/vmcp/aggregator/k8s_discoverer.go | 155 +++++++++++++++++++++++++ pkg/workloads/k8s/workload.go | 45 +++++++ pkg/workloads/k8s_manager.go | 109 ++++------------- pkg/workloads/k8s_manager_interface.go | 49 ++++++++ pkg/workloads/k8s_manager_test.go | 130 ++++++--------------- pkg/workloads/manager.go | 32 ++--- 10 files changed, 511 insertions(+), 325 deletions(-) create mode 100644 
pkg/vmcp/aggregator/cli_discoverer.go create mode 100644 pkg/vmcp/aggregator/k8s_discoverer.go create mode 100644 pkg/workloads/k8s/workload.go create mode 100644 pkg/workloads/k8s_manager_interface.go diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 7d9b00ee2..f54677717 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -9,6 +9,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" + rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" @@ -220,9 +221,17 @@ func discoverBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, // Initialize managers for backend discovery logger.Info("Initializing workload and group managers") - workloadsManager, err := workloads.NewManager(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to create workloads manager: %w", err) + var workloadsManager interface{} + if rt.IsKubernetesRuntime() { + workloadsManager, err = workloads.NewK8SManagerFromContext(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to create Kubernetes workloads manager: %w", err) + } + } else { + workloadsManager, err = workloads.NewManager(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to create CLI workloads manager: %w", err) + } } groupsManager, err := groups.NewManager() @@ -231,7 +240,10 @@ func discoverBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, } // Create backend discoverer and discover backends - discoverer := aggregator.NewBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth) + discoverer, err := aggregator.NewBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth) + if err != nil { + return nil, nil, fmt.Errorf("failed to create backend discoverer: %w", err) + } logger.Infof("Discovering backends in group: %s", cfg.Group) backends, err := discoverer.Discover(ctx, cfg.Group) diff 
--git a/pkg/vmcp/aggregator/cli_discoverer.go b/pkg/vmcp/aggregator/cli_discoverer.go new file mode 100644 index 000000000..0f56ca609 --- /dev/null +++ b/pkg/vmcp/aggregator/cli_discoverer.go @@ -0,0 +1,137 @@ +// Package aggregator provides platform-agnostic backend discovery. +// This file contains the CLI-specific discoverer implementation. +package aggregator + +import ( + "context" + "fmt" + + "github.com/stacklok/toolhive/pkg/core" + "github.com/stacklok/toolhive/pkg/groups" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/workloads" +) + +// cliBackendDiscoverer discovers backend MCP servers from CLI workloads (containers). +// It works with workloads.Manager and core.Workload. +type cliBackendDiscoverer struct { + workloadsManager workloads.Manager + groupsManager groups.Manager + authConfig *config.OutgoingAuthConfig +} + +// NewCLIBackendDiscoverer creates a new CLI backend discoverer. +func NewCLIBackendDiscoverer( + workloadsManager workloads.Manager, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { + return &cliBackendDiscoverer{ + workloadsManager: workloadsManager, + groupsManager: groupsManager, + authConfig: authConfig, + } +} + +// Discover finds all backend workloads in the specified group. 
+func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { + logger.Infof("Discovering CLI backends in group %s", groupRef) + + // Verify that the group exists + exists, err := d.groupsManager.Exists(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to check if group exists: %w", err) + } + if !exists { + return nil, fmt.Errorf("group %s not found", groupRef) + } + + // Get all workload names in the group + workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to list workloads in group: %w", err) + } + + if len(workloadNames) == 0 { + logger.Infof("No workloads found in group %s", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) + + // Query each workload and convert to backend + var backends []vmcp.Backend + for _, name := range workloadNames { + workload, err := d.workloadsManager.GetWorkload(ctx, name) + if err != nil { + logger.Warnf("Failed to get workload %s: %v, skipping", name, err) + continue + } + + backend := d.convertCoreWorkload(workload, groupRef) + if backend != nil { + backends = append(backends, *backend) + } + } + + if len(backends) == 0 { + logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) + return backends, nil +} + +// convertCoreWorkload converts a core.Workload to a vmcp.Backend. 
+func (d *cliBackendDiscoverer) convertCoreWorkload(workload core.Workload, groupRef string) *vmcp.Backend { + // Skip workloads without a URL (not accessible) + if workload.URL == "" { + logger.Debugf("Skipping workload %s without URL", workload.Name) + return nil + } + + // Map workload status to backend health status + healthStatus := mapWorkloadStatusToHealth(workload.Status) + + // Convert core.Workload to vmcp.Backend + transportType := workload.ProxyMode + if transportType == "" { + // Fallback to TransportType if ProxyMode is not set (for direct transports) + transportType = workload.TransportType.String() + } + + backend := vmcp.Backend{ + ID: workload.Name, + Name: workload.Name, + BaseURL: workload.URL, + TransportType: transportType, + HealthStatus: healthStatus, + Metadata: make(map[string]string), + } + + // Apply authentication configuration if provided + authStrategy, authMetadata := d.authConfig.ResolveForBackend(workload.Name) + backend.AuthStrategy = authStrategy + backend.AuthMetadata = authMetadata + if authStrategy != "" { + logger.Debugf("Backend %s configured with auth strategy: %s", workload.Name, authStrategy) + } + + // Copy user labels to metadata first + for k, v := range workload.Labels { + backend.Metadata[k] = v + } + + // Set system metadata (these override user labels to prevent conflicts) + backend.Metadata["group"] = groupRef + backend.Metadata["tool_type"] = workload.ToolType + backend.Metadata["workload_status"] = string(workload.Status) + + logger.Debugf("Discovered backend %s: %s (%s) with health status %s", + backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) + + return &backend +} diff --git a/pkg/vmcp/aggregator/discoverer.go b/pkg/vmcp/aggregator/discoverer.go index ffceaf9d5..df82cf4ed 100644 --- a/pkg/vmcp/aggregator/discoverer.go +++ b/pkg/vmcp/aggregator/discoverer.go @@ -1,146 +1,48 @@ // Package aggregator provides platform-agnostic backend discovery. 
// // The BackendDiscoverer interface is defined in aggregator.go. -// The unified implementation (works for both CLI and Kubernetes) is in this file. +// This file contains the factory function that selects the appropriate discoverer +// based on the runtime environment (CLI or Kubernetes). package aggregator import ( - "context" "fmt" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/workloads" ) -// backendDiscoverer discovers backend MCP servers from workloads in a group. -// It works with both CLI (Docker/Podman) and Kubernetes environments via the unified workloads.Manager interface. -// This is a platform-agnostic implementation that automatically adapts to the runtime environment. -type backendDiscoverer struct { - workloadsManager workloads.Manager - groupsManager groups.Manager - authConfig *config.OutgoingAuthConfig -} - -// NewBackendDiscoverer creates a new backend discoverer. -// It discovers workloads from containers (CLI) or MCPServer CRDs (Kubernetes) managed by ToolHive. -// The workloads.Manager automatically selects the appropriate storage driver based on the runtime environment. +// NewBackendDiscoverer creates a new backend discoverer based on the runtime environment. +// It accepts interface{} for workloadsManager to handle both workloads.Manager (CLI) and workloads.K8SManager (Kubernetes). +// Type assertion happens once in this factory, not in discovery logic. // // The authConfig parameter configures authentication for discovered backends. // If nil, backends will have no authentication configured. 
func NewBackendDiscoverer( - workloadsManager workloads.Manager, + workloadsManager interface{}, groupsManager groups.Manager, authConfig *config.OutgoingAuthConfig, -) BackendDiscoverer { - return &backendDiscoverer{ - workloadsManager: workloadsManager, - groupsManager: groupsManager, - authConfig: authConfig, - } -} - -// Discover finds all backend workloads in the specified group. -// Returns all accessible backends with their health status marked based on workload status. -// The groupRef is the group name (e.g., "engineering-team"). -func (d *backendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { - logger.Infof("Discovering backends in group %s", groupRef) - - // Verify that the group exists - exists, err := d.groupsManager.Exists(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to check if group exists: %w", err) - } - if !exists { - return nil, fmt.Errorf("group %s not found", groupRef) - } - - // Get all workload names in the group - workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to list workloads in group: %w", err) - } - - if len(workloadNames) == 0 { - logger.Infof("No workloads found in group %s", groupRef) - return []vmcp.Backend{}, nil - } - - logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) - - // Query each workload and convert to backend - var backends []vmcp.Backend - for _, name := range workloadNames { - workload, err := d.workloadsManager.GetWorkload(ctx, name) - if err != nil { - logger.Warnf("Failed to get workload %s: %v, skipping", name, err) - continue +) (BackendDiscoverer, error) { + if rt.IsKubernetesRuntime() { + k8sMgr, ok := workloadsManager.(workloads.K8SManager) + if !ok { + return nil, fmt.Errorf("expected workloads.K8SManager in Kubernetes mode, got %T", workloadsManager) } - - // Skip workloads without a URL (not accessible) - if workload.URL == 
"" { - logger.Debugf("Skipping workload %s without URL", name) - continue - } - - // Map workload status to backend health status - healthStatus := mapWorkloadStatusToHealth(workload.Status) - - // Convert core.Workload to vmcp.Backend - // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. - // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. - // ProxyMode tells us which transport the vmcp client should use. - transportType := workload.ProxyMode - if transportType == "" { - // Fallback to TransportType if ProxyMode is not set (for direct transports) - transportType = workload.TransportType.String() - } - - backend := vmcp.Backend{ - ID: name, - Name: name, - BaseURL: workload.URL, - TransportType: transportType, - HealthStatus: healthStatus, - Metadata: make(map[string]string), - } - - // Apply authentication configuration if provided - authStrategy, authMetadata := d.authConfig.ResolveForBackend(name) - backend.AuthStrategy = authStrategy - backend.AuthMetadata = authMetadata - if authStrategy != "" { - logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) - } - - // Copy user labels to metadata first - for k, v := range workload.Labels { - backend.Metadata[k] = v - } - - // Set system metadata (these override user labels to prevent conflicts) - backend.Metadata["group"] = groupRef - backend.Metadata["tool_type"] = workload.ToolType - backend.Metadata["workload_status"] = string(workload.Status) - - backends = append(backends, backend) - logger.Debugf("Discovered backend %s: %s (%s) with health status %s", - backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) + return NewK8SBackendDiscoverer(k8sMgr, groupsManager, authConfig), nil } - if len(backends) == 0 { - logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) - return []vmcp.Backend{}, nil + cliMgr, ok := workloadsManager.(workloads.Manager) + if !ok { + 
return nil, fmt.Errorf("expected workloads.Manager in CLI mode, got %T", workloadsManager) } - - logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) - return backends, nil + return NewCLIBackendDiscoverer(cliMgr, groupsManager, authConfig), nil } // mapWorkloadStatusToHealth converts a workload status to a backend health status. +// This is used by the CLI discoverer. func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { switch status { case rt.WorkloadStatusRunning: diff --git a/pkg/vmcp/aggregator/discoverer_test.go b/pkg/vmcp/aggregator/discoverer_test.go index a8fe4e550..2ef43123e 100644 --- a/pkg/vmcp/aggregator/discoverer_test.go +++ b/pkg/vmcp/aggregator/discoverer_test.go @@ -45,7 +45,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + require.NoError(t, err) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -79,7 +80,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) - discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + require.NoError(t, err) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -108,7 +110,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), 
"workload2").Return(workloadWithoutURL, nil) - discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + require.NoError(t, err) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -133,7 +136,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + require.NoError(t, err) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -150,7 +154,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) - discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + require.NoError(t, err) backends, err := discoverer.Discover(context.Background(), "nonexistent-group") require.Error(t, err) @@ -168,7 +173,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) - discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + require.NoError(t, err) backends, err := discoverer.Discover(context.Background(), testGroupName) require.Error(t, err) @@ -187,7 +193,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) - discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + 
discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + require.NoError(t, err) backends, err := discoverer.Discover(context.Background(), "empty-group") require.NoError(t, err) @@ -214,7 +221,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil) - discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + require.NoError(t, err) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -240,7 +248,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). Return(core.Workload{}, errors.New("workload query failed")) - discoverer := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) + require.NoError(t, err) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) diff --git a/pkg/vmcp/aggregator/k8s_discoverer.go b/pkg/vmcp/aggregator/k8s_discoverer.go new file mode 100644 index 000000000..7ad3f6422 --- /dev/null +++ b/pkg/vmcp/aggregator/k8s_discoverer.go @@ -0,0 +1,155 @@ +// Package aggregator provides platform-agnostic backend discovery. +// This file contains the Kubernetes-specific discoverer implementation. 
+package aggregator + +import ( + "context" + "fmt" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/groups" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/workloads" + "github.com/stacklok/toolhive/pkg/workloads/k8s" +) + +// k8sBackendDiscoverer discovers backend MCP servers from Kubernetes workloads (MCPServer CRDs). +// It works with workloads.K8SManager and k8s.Workload. +type k8sBackendDiscoverer struct { + workloadsManager workloads.K8SManager + groupsManager groups.Manager + authConfig *config.OutgoingAuthConfig +} + +// NewK8SBackendDiscoverer creates a new Kubernetes backend discoverer. +func NewK8SBackendDiscoverer( + workloadsManager workloads.K8SManager, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { + return &k8sBackendDiscoverer{ + workloadsManager: workloadsManager, + groupsManager: groupsManager, + authConfig: authConfig, + } +} + +// Discover finds all backend workloads in the specified group. 
+func (d *k8sBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { + logger.Infof("Discovering Kubernetes backends in group %s", groupRef) + + // Verify that the group exists + exists, err := d.groupsManager.Exists(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to check if group exists: %w", err) + } + if !exists { + return nil, fmt.Errorf("group %s not found", groupRef) + } + + // Get all workload names in the group + workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to list workloads in group: %w", err) + } + + if len(workloadNames) == 0 { + logger.Infof("No workloads found in group %s", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) + + // Query each workload and convert to backend + var backends []vmcp.Backend + for _, name := range workloadNames { + workload, err := d.workloadsManager.GetWorkload(ctx, name) + if err != nil { + logger.Warnf("Failed to get workload %s: %v, skipping", name, err) + continue + } + + backend := d.convertK8SWorkload(workload, groupRef) + if backend != nil { + backends = append(backends, *backend) + } + } + + if len(backends) == 0 { + logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) + return backends, nil +} + +// convertK8SWorkload converts a k8s.Workload to a vmcp.Backend. 
+func (d *k8sBackendDiscoverer) convertK8SWorkload(workload k8s.Workload, groupRef string) *vmcp.Backend { + // Skip workloads without a URL (not accessible) + if workload.URL == "" { + logger.Debugf("Skipping workload %s without URL", workload.Name) + return nil + } + + // Map workload phase to backend health status + healthStatus := mapK8SWorkloadPhaseToHealth(workload.Phase) + + // Convert k8s.Workload to vmcp.Backend + transportType := workload.ProxyMode + if transportType == "" { + // Fallback to TransportType if ProxyMode is not set (for direct transports) + transportType = workload.TransportType.String() + } + + backend := vmcp.Backend{ + ID: workload.Name, + Name: workload.Name, + BaseURL: workload.URL, + TransportType: transportType, + HealthStatus: healthStatus, + Metadata: make(map[string]string), + } + + // Apply authentication configuration if provided + authStrategy, authMetadata := d.authConfig.ResolveForBackend(workload.Name) + backend.AuthStrategy = authStrategy + backend.AuthMetadata = authMetadata + if authStrategy != "" { + logger.Debugf("Backend %s configured with auth strategy: %s", workload.Name, authStrategy) + } + + // Copy user labels to metadata first + for k, v := range workload.Labels { + backend.Metadata[k] = v + } + + // Set system metadata (these override user labels to prevent conflicts) + backend.Metadata["group"] = groupRef + backend.Metadata["tool_type"] = workload.ToolType + backend.Metadata["workload_phase"] = string(workload.Phase) + backend.Metadata["namespace"] = workload.Namespace + + logger.Debugf("Discovered backend %s: %s (%s) with health status %s", + backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) + + return &backend +} + +// mapK8SWorkloadPhaseToHealth converts a MCPServerPhase to a backend health status. 
+func mapK8SWorkloadPhaseToHealth(phase mcpv1alpha1.MCPServerPhase) vmcp.BackendHealthStatus { + switch phase { + case mcpv1alpha1.MCPServerPhaseRunning: + return vmcp.BackendHealthy + case mcpv1alpha1.MCPServerPhaseFailed: + return vmcp.BackendUnhealthy + case mcpv1alpha1.MCPServerPhaseTerminating: + return vmcp.BackendUnhealthy + case mcpv1alpha1.MCPServerPhasePending: + return vmcp.BackendUnknown + default: + return vmcp.BackendUnknown + } +} diff --git a/pkg/workloads/k8s/workload.go b/pkg/workloads/k8s/workload.go new file mode 100644 index 000000000..e39a1014a --- /dev/null +++ b/pkg/workloads/k8s/workload.go @@ -0,0 +1,45 @@ +// Package k8s provides Kubernetes-specific domain models for workloads. +package k8s + +import ( + "time" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/transport/types" +) + +// Workload represents a Kubernetes workload (MCPServer CRD). +// This is the Kubernetes-specific domain model, separate from core.Workload. 
+type Workload struct { + // Name is the name of the MCPServer CRD + Name string + // Namespace is the Kubernetes namespace where the MCPServer is deployed + Namespace string + // Package specifies the container image used for this workload + Package string + // URL is the URL of the workload exposed by the ToolHive proxy + URL string + // Port is the port on which the workload is exposed + Port int + // ToolType is the type of tool this workload represents + // For now, it will always be "mcp" - representing an MCP server + ToolType string + // TransportType is the type of transport used for this workload + TransportType types.TransportType + // ProxyMode is the proxy mode that clients should use to connect + ProxyMode string + // Phase is the current phase of the MCPServer CRD + Phase mcpv1alpha1.MCPServerPhase + // StatusContext provides additional context about the workload's status + StatusContext string + // CreatedAt is the timestamp when the workload was created + CreatedAt time.Time + // Labels are user-defined labels (from annotations) + Labels map[string]string + // Group is the name of the group this workload belongs to, if any + Group string + // ToolsFilter is the filter on tools applied to the workload + ToolsFilter []string + // GroupRef is the reference to the MCPGroup (same as Group, but using CRD terminology) + GroupRef string +} diff --git a/pkg/workloads/k8s_manager.go b/pkg/workloads/k8s_manager.go index 0c880c755..5b813d4c0 100644 --- a/pkg/workloads/k8s_manager.go +++ b/pkg/workloads/k8s_manager.go @@ -1,4 +1,4 @@ -// Package workloads provides a Kubernetes-based implementation of the Manager interface. +// Package workloads provides a Kubernetes-based implementation of the K8SManager interface. // This file contains the Kubernetes implementation for operator environments. 
package workloads @@ -7,7 +7,6 @@ import ( "fmt" "time" - "golang.org/x/sync/errgroup" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/selection" @@ -15,16 +14,14 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/runner" "github.com/stacklok/toolhive/pkg/transport" transporttypes "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/workloads/k8s" workloadtypes "github.com/stacklok/toolhive/pkg/workloads/types" ) -// k8sManager implements the Manager interface for Kubernetes environments. +// k8sManager implements the K8SManager interface for Kubernetes environments. // In Kubernetes, the operator manages workload lifecycle via MCPServer CRDs. // This manager provides read-only operations and CRD-based storage. type k8sManager struct { @@ -33,24 +30,24 @@ type k8sManager struct { } // NewK8SManager creates a new Kubernetes-based workload manager. 
-func NewK8SManager(k8sClient client.Client, namespace string) (Manager, error) { +func NewK8SManager(k8sClient client.Client, namespace string) (K8SManager, error) { return &k8sManager{ k8sClient: k8sClient, namespace: namespace, }, nil } -func (k *k8sManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { +func (k *k8sManager) GetWorkload(ctx context.Context, workloadName string) (k8s.Workload, error) { mcpServer := &mcpv1alpha1.MCPServer{} key := types.NamespacedName{Name: workloadName, Namespace: k.namespace} if err := k.k8sClient.Get(ctx, key, mcpServer); err != nil { if errors.IsNotFound(err) { - return core.Workload{}, fmt.Errorf("%w: %s", rt.ErrWorkloadNotFound, workloadName) + return k8s.Workload{}, fmt.Errorf("MCPServer %s not found", workloadName) } - return core.Workload{}, fmt.Errorf("failed to get MCPServer: %w", err) + return k8s.Workload{}, fmt.Errorf("failed to get MCPServer: %w", err) } - return k.mcpServerToWorkload(mcpServer) + return k.mcpServerToK8SWorkload(mcpServer) } func (k *k8sManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { @@ -65,7 +62,7 @@ func (k *k8sManager) DoesWorkloadExist(ctx context.Context, workloadName string) return true, nil } -func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { +func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]k8s.Workload, error) { mcpServerList := &mcpv1alpha1.MCPServerList{} listOpts := []client.ListOption{ client.InNamespace(k.namespace), @@ -94,7 +91,7 @@ func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilte return nil, fmt.Errorf("failed to list MCPServers: %w", err) } - var workloads []core.Workload + var workloads []k8s.Workload for i := range mcpServerList.Items { mcpServer := &mcpServerList.Items[i] @@ -106,7 +103,7 @@ func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, 
labelFilte } } - workload, err := k.mcpServerToWorkload(mcpServer) + workload, err := k.mcpServerToK8SWorkload(mcpServer) if err != nil { logger.Warnf("Failed to convert MCPServer %s to workload: %v", mcpServer.Name, err) continue @@ -118,53 +115,13 @@ func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilte return workloads, nil } -// StopWorkloads is a no-op in Kubernetes mode. -// The operator manages workload lifecycle via MCPServer CRDs. -func (*k8sManager) StopWorkloads(_ context.Context, _ []string) (*errgroup.Group, error) { - logger.Warnf("StopWorkloads is not supported in Kubernetes mode. Use kubectl to manage MCPServer CRDs.") - group := &errgroup.Group{} - // Return empty group - no operations to perform - return group, nil -} - -// RunWorkload is a no-op in Kubernetes mode. -// Workloads are created via MCPServer CRDs managed by the operator. -func (*k8sManager) RunWorkload(_ context.Context, _ *runner.RunConfig) error { - return fmt.Errorf("RunWorkload is not supported in Kubernetes mode. Create MCPServer CRD instead") -} - -// RunWorkloadDetached is a no-op in Kubernetes mode. -// Workloads are created via MCPServer CRDs managed by the operator. -func (*k8sManager) RunWorkloadDetached(_ context.Context, _ *runner.RunConfig) error { - return fmt.Errorf("RunWorkloadDetached is not supported in Kubernetes mode. Create MCPServer CRD instead") -} - -// DeleteWorkloads is a no-op in Kubernetes mode. -// The operator manages workload lifecycle via MCPServer CRDs. -func (*k8sManager) DeleteWorkloads(_ context.Context, _ []string) (*errgroup.Group, error) { - logger.Warnf("DeleteWorkloads is not supported in Kubernetes mode. Use kubectl to delete MCPServer CRDs.") - group := &errgroup.Group{} - // Return empty group - no operations to perform - return group, nil -} - -// RestartWorkloads is a no-op in Kubernetes mode. -// The operator manages workload lifecycle via MCPServer CRDs. 
-func (*k8sManager) RestartWorkloads(_ context.Context, _ []string, _ bool) (*errgroup.Group, error) { - logger.Warnf("RestartWorkloads is not supported in Kubernetes mode. Use kubectl to restart MCPServer CRDs.") - group := &errgroup.Group{} - // Return empty group - no operations to perform - return group, nil -} - -// UpdateWorkload is a no-op in Kubernetes mode. -// The operator manages workload lifecycle via MCPServer CRDs. -func (*k8sManager) UpdateWorkload(_ context.Context, _ string, _ *runner.RunConfig) (*errgroup.Group, error) { - logger.Warnf("UpdateWorkload is not supported in Kubernetes mode. Update MCPServer CRD instead.") - group := &errgroup.Group{} - // Return empty group - no operations to perform - return group, nil -} +// Note: The following operations are not part of K8SManager interface: +// - StopWorkloads: Use kubectl to manage MCPServer CRDs +// - RunWorkload: Create MCPServer CRD instead +// - RunWorkloadDetached: Create MCPServer CRD instead +// - DeleteWorkloads: Use kubectl to delete MCPServer CRDs +// - RestartWorkloads: Use kubectl to restart MCPServer CRDs +// - UpdateWorkload: Update MCPServer CRD directly // GetLogs retrieves logs from the pod associated with the MCPServer. // Note: This requires a Kubernetes clientset for log streaming. @@ -240,11 +197,8 @@ func (k *k8sManager) ListWorkloadsInGroup(ctx context.Context, groupName string) return groupWorkloads, nil } -// mcpServerToWorkload converts an MCPServer CRD to a core.Workload. -func (k *k8sManager) mcpServerToWorkload(mcpServer *mcpv1alpha1.MCPServer) (core.Workload, error) { - // Map MCPServerPhase to runtime.WorkloadStatus - status := k.mcpServerPhaseToWorkloadStatus(mcpServer.Status.Phase) - +// mcpServerToK8SWorkload converts an MCPServer CRD to a k8s.Workload. 
+func (k *k8sManager) mcpServerToK8SWorkload(mcpServer *mcpv1alpha1.MCPServer) (k8s.Workload, error) { // Parse transport type transportType, err := transporttypes.ParseTransportType(mcpServer.Spec.Transport) if err != nil { @@ -297,40 +251,25 @@ func (k *k8sManager) mcpServerToWorkload(mcpServer *mcpv1alpha1.MCPServer) (core createdAt = time.Now() } - return core.Workload{ + return k8s.Workload{ Name: mcpServer.Name, + Namespace: mcpServer.Namespace, Package: mcpServer.Spec.Image, URL: url, ToolType: "mcp", TransportType: transportType, ProxyMode: effectiveProxyMode, - Status: status, + Phase: mcpServer.Status.Phase, StatusContext: mcpServer.Status.Message, CreatedAt: createdAt, Port: port, Labels: userLabels, Group: mcpServer.Spec.GroupRef, + GroupRef: mcpServer.Spec.GroupRef, ToolsFilter: toolsFilter, - Remote: false, // MCPServers are always container workloads in Kubernetes }, nil } -// mcpServerPhaseToWorkloadStatus maps MCPServerPhase to runtime.WorkloadStatus. -func (*k8sManager) mcpServerPhaseToWorkloadStatus(phase mcpv1alpha1.MCPServerPhase) rt.WorkloadStatus { - switch phase { - case mcpv1alpha1.MCPServerPhaseRunning: - return rt.WorkloadStatusRunning - case mcpv1alpha1.MCPServerPhasePending: - return rt.WorkloadStatusStarting - case mcpv1alpha1.MCPServerPhaseFailed: - return rt.WorkloadStatusError - case mcpv1alpha1.MCPServerPhaseTerminating: - return rt.WorkloadStatusStopping - default: - return rt.WorkloadStatusUnknown - } -} - // isStandardK8sAnnotation checks if an annotation key is a standard Kubernetes annotation. func (*k8sManager) isStandardK8sAnnotation(key string) bool { // Common Kubernetes annotation prefixes diff --git a/pkg/workloads/k8s_manager_interface.go b/pkg/workloads/k8s_manager_interface.go new file mode 100644 index 000000000..e6b559cf9 --- /dev/null +++ b/pkg/workloads/k8s_manager_interface.go @@ -0,0 +1,49 @@ +// Package workloads contains high-level logic for managing the lifecycle of +// ToolHive-managed containers. 
+package workloads + +import ( + "context" + + "github.com/stacklok/toolhive/pkg/workloads/k8s" +) + +// K8SManager manages MCPServer CRD workloads in Kubernetes. +// This interface is separate from Manager to avoid coupling Kubernetes workloads +// to the CLI container runtime interface. +// +//go:generate mockgen -destination=mocks/mock_k8s_manager.go -package=mocks -source=k8s_manager_interface.go K8SManager +type K8SManager interface { + // GetWorkload retrieves an MCPServer CRD by name + GetWorkload(ctx context.Context, workloadName string) (k8s.Workload, error) + + // ListWorkloads lists all MCPServer CRDs, optionally filtered by labels + // The `listAll` parameter determines whether to include workloads that are not running + // The optional `labelFilters` parameter allows filtering workloads by labels (format: key=value) + ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]k8s.Workload, error) + + // ListWorkloadsInGroup returns all workload names that belong to the specified group + ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) + + // DoesWorkloadExist checks if an MCPServer CRD with the given name exists + DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) + + // MoveToGroup moves the specified workloads from one group to another by updating their GroupRef + MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error + + // GetLogs retrieves logs from the pod associated with the MCPServer + // Note: This may not be fully implemented and may return an error + GetLogs(ctx context.Context, containerName string, follow bool) (string, error) + + // GetProxyLogs retrieves logs from the proxy container in the pod associated with the MCPServer + // Note: This may not be fully implemented and may return an error + GetProxyLogs(ctx context.Context, workloadName string) (string, error) + + // The following operations are not supported in Kubernetes 
mode (operator manages lifecycle): + // - RunWorkload: Workloads are created via MCPServer CRDs + // - RunWorkloadDetached: Workloads are created via MCPServer CRDs + // - StopWorkloads: Use kubectl to manage MCPServer CRDs + // - DeleteWorkloads: Use kubectl to manage MCPServer CRDs + // - RestartWorkloads: Use kubectl to manage MCPServer CRDs + // - UpdateWorkload: Update MCPServer CRD directly +} diff --git a/pkg/workloads/k8s_manager_test.go b/pkg/workloads/k8s_manager_test.go index 516c92697..e77fbf192 100644 --- a/pkg/workloads/k8s_manager_test.go +++ b/pkg/workloads/k8s_manager_test.go @@ -13,9 +13,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/runner" + "github.com/stacklok/toolhive/pkg/workloads/k8s" ) const ( @@ -106,7 +104,7 @@ func TestK8SManager_GetWorkload(t *testing.T) { setupMock func(*mockClient) wantError bool errorMsg string - expected core.Workload + expected k8s.Workload }{ { name: "successful get", @@ -119,7 +117,7 @@ func TestK8SManager_GetWorkload(t *testing.T) { mcpServer.Status.Phase = mcpv1alpha1.MCPServerPhaseRunning mcpServer.Spec.Transport = "streamable-http" mcpServer.Spec.ProxyPort = 8080 - mcpServer.Labels = map[string]string{ + mcpServer.Annotations = map[string]string{ "group": "test-group", } } @@ -127,10 +125,11 @@ func TestK8SManager_GetWorkload(t *testing.T) { } }, wantError: false, - expected: core.Workload{ - Name: "test-workload", - Status: rt.WorkloadStatusRunning, - URL: "http://127.0.0.1:8080/mcp", // URL is generated from spec + expected: k8s.Workload{ + Name: "test-workload", + Namespace: defaultNamespace, + Phase: mcpv1alpha1.MCPServerPhaseRunning, + URL: "http://127.0.0.1:8080/mcp", // URL is generated from spec Labels: map[string]string{ "group": "test-group", }, @@ -183,7 +182,7 @@ func 
TestK8SManager_GetWorkload(t *testing.T) { } else { require.NoError(t, err) assert.Equal(t, tt.expected.Name, result.Name) - assert.Equal(t, tt.expected.Status, result.Status) + assert.Equal(t, tt.expected.Phase, result.Phase) assert.Equal(t, tt.expected.URL, result.URL) } }) @@ -604,48 +603,6 @@ func TestK8SManager_NoOpMethods(t *testing.T) { ctx := context.Background() - t.Run("StopWorkloads returns empty group", func(t *testing.T) { - t.Parallel() - group, err := manager.StopWorkloads(ctx, []string{testWorkload1}) - require.NoError(t, err) - require.NotNil(t, group) - }) - - t.Run("RunWorkload returns error", func(t *testing.T) { - t.Parallel() - err := manager.RunWorkload(ctx, &runner.RunConfig{}) - require.Error(t, err) - assert.Contains(t, err.Error(), "not supported in Kubernetes mode") - }) - - t.Run("RunWorkloadDetached returns error", func(t *testing.T) { - t.Parallel() - err := manager.RunWorkloadDetached(ctx, &runner.RunConfig{}) - require.Error(t, err) - assert.Contains(t, err.Error(), "not supported in Kubernetes mode") - }) - - t.Run("DeleteWorkloads returns empty group", func(t *testing.T) { - t.Parallel() - group, err := manager.DeleteWorkloads(ctx, []string{testWorkload1}) - require.NoError(t, err) - require.NotNil(t, group) - }) - - t.Run("RestartWorkloads returns empty group", func(t *testing.T) { - t.Parallel() - group, err := manager.RestartWorkloads(ctx, []string{testWorkload1}, false) - require.NoError(t, err) - require.NotNil(t, group) - }) - - t.Run("UpdateWorkload returns empty group", func(t *testing.T) { - t.Parallel() - group, err := manager.UpdateWorkload(ctx, testWorkload1, &runner.RunConfig{}) - require.NoError(t, err) - require.NotNil(t, group) - }) - t.Run("GetLogs returns error", func(t *testing.T) { t.Parallel() logs, err := manager.GetLogs(ctx, testWorkload1, false) @@ -663,19 +620,20 @@ func TestK8SManager_NoOpMethods(t *testing.T) { }) } -func TestK8SManager_mcpServerToWorkload(t *testing.T) { +func 
TestK8SManager_mcpServerToK8SWorkload(t *testing.T) { t.Parallel() tests := []struct { name string mcpServer *mcpv1alpha1.MCPServer - expected core.Workload + expected k8s.Workload }{ { name: "running workload with HTTP transport", mcpServer: &mcpv1alpha1.MCPServer{ ObjectMeta: metav1.ObjectMeta{ - Name: "test-workload", + Name: "test-workload", + Namespace: defaultNamespace, Annotations: map[string]string{ "group": "test-group", "env": "prod", @@ -690,41 +648,46 @@ func TestK8SManager_mcpServerToWorkload(t *testing.T) { ProxyPort: 8080, }, }, - expected: core.Workload{ - Name: "test-workload", - Status: rt.WorkloadStatusRunning, - URL: "http://localhost:8080", - Labels: map[string]string{"group": "test-group", "env": "prod"}, + expected: k8s.Workload{ + Name: "test-workload", + Namespace: defaultNamespace, + Phase: mcpv1alpha1.MCPServerPhaseRunning, + URL: "http://localhost:8080", + Labels: map[string]string{"group": "test-group", "env": "prod"}, }, }, { name: "terminating workload", mcpServer: &mcpv1alpha1.MCPServer{ ObjectMeta: metav1.ObjectMeta{ - Name: "terminating-workload", + Name: "terminating-workload", + Namespace: defaultNamespace, }, Status: mcpv1alpha1.MCPServerStatus{ Phase: mcpv1alpha1.MCPServerPhaseTerminating, }, }, - expected: core.Workload{ - Name: "terminating-workload", - Status: rt.WorkloadStatusStopping, + expected: k8s.Workload{ + Name: "terminating-workload", + Namespace: defaultNamespace, + Phase: mcpv1alpha1.MCPServerPhaseTerminating, }, }, { name: "failed workload", mcpServer: &mcpv1alpha1.MCPServer{ ObjectMeta: metav1.ObjectMeta{ - Name: "failed-workload", + Name: "failed-workload", + Namespace: defaultNamespace, }, Status: mcpv1alpha1.MCPServerStatus{ Phase: mcpv1alpha1.MCPServerPhaseFailed, }, }, - expected: core.Workload{ - Name: "failed-workload", - Status: rt.WorkloadStatusError, + expected: k8s.Workload{ + Name: "failed-workload", + Namespace: defaultNamespace, + Phase: mcpv1alpha1.MCPServerPhaseFailed, }, }, } @@ -737,11 +700,12 
@@ func TestK8SManager_mcpServerToWorkload(t *testing.T) { namespace: defaultNamespace, } - result, err := manager.mcpServerToWorkload(tt.mcpServer) + result, err := manager.mcpServerToK8SWorkload(tt.mcpServer) require.NoError(t, err) assert.Equal(t, tt.expected.Name, result.Name) - assert.Equal(t, tt.expected.Status, result.Status) + assert.Equal(t, tt.expected.Namespace, result.Namespace) + assert.Equal(t, tt.expected.Phase, result.Phase) assert.Equal(t, tt.expected.URL, result.URL) if tt.expected.Labels != nil { assert.Equal(t, tt.expected.Labels, result.Labels) @@ -749,29 +713,3 @@ func TestK8SManager_mcpServerToWorkload(t *testing.T) { }) } } - -func TestK8SManager_mcpServerPhaseToWorkloadStatus(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - phase mcpv1alpha1.MCPServerPhase - expected rt.WorkloadStatus - }{ - {"running", mcpv1alpha1.MCPServerPhaseRunning, rt.WorkloadStatusRunning}, - {"pending", mcpv1alpha1.MCPServerPhasePending, rt.WorkloadStatusStarting}, - {"failed", mcpv1alpha1.MCPServerPhaseFailed, rt.WorkloadStatusError}, - {"terminating", mcpv1alpha1.MCPServerPhaseTerminating, rt.WorkloadStatusStopping}, - {"unknown", mcpv1alpha1.MCPServerPhase(""), rt.WorkloadStatusUnknown}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - manager := &k8sManager{} - result := manager.mcpServerPhaseToWorkloadStatus(tt.phase) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index 3490f0594..98fb744ae 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -64,46 +64,46 @@ type Manager interface { // ErrWorkloadNotRunning is returned when a container cannot be found by name. 
var ErrWorkloadNotRunning = fmt.Errorf("workload not running") -// NewManager creates a new workload manager based on the runtime environment: -// - In Kubernetes mode: returns a CRD-based manager that uses MCPServer CRDs -// - In local mode: returns a CLI/filesystem-based manager +// NewManager creates a new CLI workload manager. +// Returns Manager interface (existing behavior, unchanged). +// IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. func NewManager(ctx context.Context) (Manager, error) { if rt.IsKubernetesRuntime() { - return newK8SManager(ctx) + return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes environments") } return NewCLIManager(ctx) } -// NewManagerWithProvider creates a new workload manager with a custom config provider. +// NewManagerWithProvider creates a new CLI workload manager with a custom config provider. +// IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) (Manager, error) { if rt.IsKubernetesRuntime() { - return newK8SManager(ctx) + return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes environments") } return NewCLIManagerWithProvider(ctx, configProvider) } -// NewManagerFromRuntime creates a new workload manager from an existing runtime. +// NewManagerFromRuntime creates a new CLI workload manager from an existing runtime. +// IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. 
func NewManagerFromRuntime(rtRuntime rt.Runtime) (Manager, error) { if rt.IsKubernetesRuntime() { - // In Kubernetes mode, we need a k8s client, not a runtime - ctx := context.Background() - return newK8SManager(ctx) + return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes environments") } return NewCLIManagerFromRuntime(rtRuntime) } -// NewManagerFromRuntimeWithProvider creates a new workload manager from an existing runtime with a custom config provider. +// NewManagerFromRuntimeWithProvider creates a new CLI workload manager from an existing runtime with a custom config provider. +// IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. func NewManagerFromRuntimeWithProvider(rtRuntime rt.Runtime, configProvider config.Provider) (Manager, error) { if rt.IsKubernetesRuntime() { - // In Kubernetes mode, we need a k8s client, not a runtime - ctx := context.Background() - return newK8SManager(ctx) + return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes environments") } return NewCLIManagerFromRuntimeWithProvider(rtRuntime, configProvider) } -// newK8SManager creates a Kubernetes-based workload manager for Kubernetes environments -func newK8SManager(context.Context) (Manager, error) { +// NewK8SManagerFromContext creates a Kubernetes-based workload manager from context. +// It automatically sets up the Kubernetes client and detects the namespace. 
+func NewK8SManagerFromContext(ctx context.Context) (K8SManager, error) { // Create a scheme for controller-runtime client scheme := runtime.NewScheme() utilruntime.Must(clientgoscheme.AddToScheme(scheme)) From 068e2ce40a40bac3c190c8c44e24265143e7bedc Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 11 Nov 2025 14:23:42 +0000 Subject: [PATCH 04/16] refactor the constructor and fix tests --- cmd/vmcp/app/commands.go | 25 +++++++--------- pkg/vmcp/aggregator/cli_discoverer.go | 2 ++ pkg/vmcp/aggregator/discoverer.go | 41 -------------------------- pkg/vmcp/aggregator/discoverer_test.go | 27 ++++++----------- pkg/vmcp/aggregator/k8s_discoverer.go | 8 ++++- pkg/workloads/k8s_manager_test.go | 2 +- pkg/workloads/manager.go | 2 +- 7 files changed, 31 insertions(+), 76 deletions(-) delete mode 100644 pkg/vmcp/aggregator/discoverer.go diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index f54677717..5490f6209 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -221,28 +221,25 @@ func discoverBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, // Initialize managers for backend discovery logger.Info("Initializing workload and group managers") - var workloadsManager interface{} + groupsManager, err := groups.NewManager() + if err != nil { + return nil, nil, fmt.Errorf("failed to create groups manager: %w", err) + } + + // Create backend discoverer based on runtime environment + var discoverer aggregator.BackendDiscoverer if rt.IsKubernetesRuntime() { - workloadsManager, err = workloads.NewK8SManagerFromContext(ctx) + k8sWorkloadsManager, err := workloads.NewK8SManagerFromContext(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to create Kubernetes workloads manager: %w", err) } + discoverer = aggregator.NewK8SBackendDiscoverer(k8sWorkloadsManager, groupsManager, cfg.OutgoingAuth) } else { - workloadsManager, err = workloads.NewManager(ctx) + cliWorkloadsManager, err := workloads.NewManager(ctx) if err != nil { 
return nil, nil, fmt.Errorf("failed to create CLI workloads manager: %w", err) } - } - - groupsManager, err := groups.NewManager() - if err != nil { - return nil, nil, fmt.Errorf("failed to create groups manager: %w", err) - } - - // Create backend discoverer and discover backends - discoverer, err := aggregator.NewBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth) - if err != nil { - return nil, nil, fmt.Errorf("failed to create backend discoverer: %w", err) + discoverer = aggregator.NewCLIBackendDiscoverer(cliWorkloadsManager, groupsManager, cfg.OutgoingAuth) } logger.Infof("Discovering backends in group: %s", cfg.Group) diff --git a/pkg/vmcp/aggregator/cli_discoverer.go b/pkg/vmcp/aggregator/cli_discoverer.go index c577a93be..8c4b3cffc 100644 --- a/pkg/vmcp/aggregator/cli_discoverer.go +++ b/pkg/vmcp/aggregator/cli_discoverer.go @@ -25,6 +25,8 @@ type cliBackendDiscoverer struct { // // The authConfig parameter configures authentication for discovered backends. // If nil, backends will have no authentication configured. +// +// This is the CLI-specific constructor. For Kubernetes workloads, use NewK8SBackendDiscoverer. func NewCLIBackendDiscoverer( workloadsManager workloads.Manager, groupsManager groups.Manager, diff --git a/pkg/vmcp/aggregator/discoverer.go b/pkg/vmcp/aggregator/discoverer.go deleted file mode 100644 index de0d55dc4..000000000 --- a/pkg/vmcp/aggregator/discoverer.go +++ /dev/null @@ -1,41 +0,0 @@ -// Package aggregator provides platform-agnostic backend discovery. -// -// The BackendDiscoverer interface is defined in aggregator.go. -// This file contains the factory function that selects the appropriate discoverer -// based on the runtime environment (CLI or Kubernetes). 
-package aggregator - -import ( - "fmt" - - rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/pkg/workloads" -) - -// NewBackendDiscoverer creates a new backend discoverer based on the runtime environment. -// It accepts interface{} for workloadsManager to handle both workloads.Manager (CLI) and workloads.K8SManager (Kubernetes). -// Type assertion happens once in this factory, not in discovery logic. -// -// The authConfig parameter configures authentication for discovered backends. -// If nil, backends will have no authentication configured. -func NewBackendDiscoverer( - workloadsManager interface{}, - groupsManager groups.Manager, - authConfig *config.OutgoingAuthConfig, -) (BackendDiscoverer, error) { - if rt.IsKubernetesRuntime() { - k8sMgr, ok := workloadsManager.(workloads.K8SManager) - if !ok { - return nil, fmt.Errorf("expected workloads.K8SManager in Kubernetes mode, got %T", workloadsManager) - } - return NewK8SBackendDiscoverer(k8sMgr, groupsManager, authConfig), nil - } - - cliMgr, ok := workloadsManager.(workloads.Manager) - if !ok { - return nil, fmt.Errorf("expected workloads.Manager in CLI mode, got %T", workloadsManager) - } - return NewCLIBackendDiscoverer(cliMgr, groupsManager, authConfig), nil -} diff --git a/pkg/vmcp/aggregator/discoverer_test.go b/pkg/vmcp/aggregator/discoverer_test.go index 2ef43123e..5973fa640 100644 --- a/pkg/vmcp/aggregator/discoverer_test.go +++ b/pkg/vmcp/aggregator/discoverer_test.go @@ -45,8 +45,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) - require.NoError(t, err) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) 
backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -80,8 +79,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) - discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) - require.NoError(t, err) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -110,8 +108,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) - discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) - require.NoError(t, err) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -136,8 +133,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) - require.NoError(t, err) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -154,8 +150,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) - discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) - require.NoError(t, err) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) 
backends, err := discoverer.Discover(context.Background(), "nonexistent-group") require.Error(t, err) @@ -173,8 +168,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) - discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) - require.NoError(t, err) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.Error(t, err) @@ -193,8 +187,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) - discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) - require.NoError(t, err) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), "empty-group") require.NoError(t, err) @@ -221,8 +214,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil) - discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) - require.NoError(t, err) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -248,8 +240,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). 
Return(core.Workload{}, errors.New("workload query failed")) - discoverer, err := NewBackendDiscoverer(mockWorkloads, mockGroups, nil) - require.NoError(t, err) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) diff --git a/pkg/vmcp/aggregator/k8s_discoverer.go b/pkg/vmcp/aggregator/k8s_discoverer.go index 7ad3f6422..84be798ea 100644 --- a/pkg/vmcp/aggregator/k8s_discoverer.go +++ b/pkg/vmcp/aggregator/k8s_discoverer.go @@ -23,7 +23,13 @@ type k8sBackendDiscoverer struct { authConfig *config.OutgoingAuthConfig } -// NewK8SBackendDiscoverer creates a new Kubernetes backend discoverer. +// NewK8SBackendDiscoverer creates a new Kubernetes-based backend discoverer. +// It discovers workloads from MCPServer CRDs managed by the ToolHive operator in Kubernetes. +// +// The authConfig parameter configures authentication for discovered backends. +// If nil, backends will have no authentication configured. +// +// This is the Kubernetes-specific constructor. For CLI workloads, use NewCLIBackendDiscoverer. func NewK8SBackendDiscoverer( workloadsManager workloads.K8SManager, groupsManager groups.Manager, diff --git a/pkg/workloads/k8s_manager_test.go b/pkg/workloads/k8s_manager_test.go index e77fbf192..2a76c63a1 100644 --- a/pkg/workloads/k8s_manager_test.go +++ b/pkg/workloads/k8s_manager_test.go @@ -144,7 +144,7 @@ func TestK8SManager_GetWorkload(t *testing.T) { } }, wantError: true, - errorMsg: "workload not found", + errorMsg: "MCPServer non-existent not found", }, { name: "get error", diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index 98fb744ae..b3c8c85cf 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -103,7 +103,7 @@ func NewManagerFromRuntimeWithProvider(rtRuntime rt.Runtime, configProvider conf // NewK8SManagerFromContext creates a Kubernetes-based workload manager from context. 
// It automatically sets up the Kubernetes client and detects the namespace. -func NewK8SManagerFromContext(ctx context.Context) (K8SManager, error) { +func NewK8SManagerFromContext(_ context.Context) (K8SManager, error) { // Create a scheme for controller-runtime client scheme := runtime.NewScheme() utilruntime.Must(clientgoscheme.AddToScheme(scheme)) From 31e40e72372a1703c94cf5f1e3fa0f535709b476 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 11 Nov 2025 14:59:19 +0000 Subject: [PATCH 05/16] adds more tests --- ...coverer_test.go => cli_discoverer_test.go} | 22 +- pkg/vmcp/aggregator/discoverer.go | 8 + pkg/vmcp/aggregator/k8s_discoverer_test.go | 324 ++++++++++++++++++ pkg/vmcp/aggregator/testhelpers_test.go | 58 ++++ pkg/workloads/mocks/mock_k8s_manager.go | 151 ++++++++ 5 files changed, 562 insertions(+), 1 deletion(-) rename pkg/vmcp/aggregator/{discoverer_test.go => cli_discoverer_test.go} (92%) create mode 100644 pkg/vmcp/aggregator/discoverer.go create mode 100644 pkg/vmcp/aggregator/k8s_discoverer_test.go create mode 100644 pkg/workloads/mocks/mock_k8s_manager.go diff --git a/pkg/vmcp/aggregator/discoverer_test.go b/pkg/vmcp/aggregator/cli_discoverer_test.go similarity index 92% rename from pkg/vmcp/aggregator/discoverer_test.go rename to pkg/vmcp/aggregator/cli_discoverer_test.go index 5973fa640..23d8155b7 100644 --- a/pkg/vmcp/aggregator/discoverer_test.go +++ b/pkg/vmcp/aggregator/cli_discoverer_test.go @@ -19,7 +19,7 @@ import ( const testGroupName = "test-group" -func TestBackendDiscoverer_Discover(t *testing.T) { +func TestCLIBackendDiscoverer_Discover(t *testing.T) { t.Parallel() t.Run("successful discovery with multiple backends", func(t *testing.T) { @@ -247,4 +247,24 @@ func TestBackendDiscoverer_Discover(t *testing.T) { require.Len(t, backends, 1) assert.Equal(t, "good-workload", backends[0].ID) }) + + t.Run("returns error when list workloads fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() 
+ + mockWorkloads := workloadmocks.NewMockManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return(nil, errors.New("failed to list workloads")) + + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "failed to list workloads in group") + }) } diff --git a/pkg/vmcp/aggregator/discoverer.go b/pkg/vmcp/aggregator/discoverer.go new file mode 100644 index 000000000..b4cff44d4 --- /dev/null +++ b/pkg/vmcp/aggregator/discoverer.go @@ -0,0 +1,8 @@ +// Package aggregator provides platform-specific backend discovery implementations. +// +// This file serves as a navigation reference for backend discovery implementations: +// - CLI (Docker/Podman): see cli_discoverer.go +// - Kubernetes: see k8s_discoverer.go +// +// The BackendDiscoverer interface is defined in aggregator.go. 
+package aggregator diff --git a/pkg/vmcp/aggregator/k8s_discoverer_test.go b/pkg/vmcp/aggregator/k8s_discoverer_test.go new file mode 100644 index 000000000..2dccb05d5 --- /dev/null +++ b/pkg/vmcp/aggregator/k8s_discoverer_test.go @@ -0,0 +1,324 @@ +package aggregator + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/groups/mocks" + "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/workloads/k8s" + workloadmocks "github.com/stacklok/toolhive/pkg/workloads/mocks" +) + +func TestK8SBackendDiscoverer_Discover(t *testing.T) { + t.Parallel() + + t.Run("successful discovery with multiple backends", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + workload1 := newTestK8SWorkload("workload1", + withK8SToolType("github"), + withK8SLabels(map[string]string{"env": "prod"}), + withK8SNamespace("toolhive-system")) + + workload2 := newTestK8SWorkload("workload2", + withK8SURL("http://localhost:8081/mcp"), + withK8STransport(types.TransportTypeSSE), + withK8SToolType("jira"), + withK8SNamespace("toolhive-system")) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"workload1", "workload2"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, "workload1", backends[0].ID) + assert.Equal(t, "http://localhost:8080/mcp", backends[0].BaseURL) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "github", backends[0].Metadata["tool_type"]) + assert.Equal(t, "prod", backends[0].Metadata["env"]) + assert.Equal(t, "toolhive-system", backends[0].Metadata["namespace"]) + assert.Equal(t, "workload2", backends[1].ID) + assert.Equal(t, "sse", backends[1].TransportType) + }) + + t.Run("discovers workloads with different phases", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + runningWorkload := newTestK8SWorkload("running-workload", + withK8SPhase(mcpv1alpha1.MCPServerPhaseRunning)) + failedWorkload := newTestK8SWorkload("failed-workload", + withK8SPhase(mcpv1alpha1.MCPServerPhaseFailed), + withK8SURL("http://localhost:8081/mcp"), + withK8STransport(types.TransportTypeSSE)) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"running-workload", "failed-workload"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failed-workload").Return(failedWorkload, nil) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, "running-workload", backends[0].ID) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "failed-workload", backends[1].ID) + assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) + assert.Equal(t, string(mcpv1alpha1.MCPServerPhaseFailed), backends[1].Metadata["workload_phase"]) + }) + + t.Run("filters out workloads without URL", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + workloadWithURL := newTestK8SWorkload("workload1") + workloadWithoutURL := newTestK8SWorkload("workload2", withK8SURL("")) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"workload1", "workload2"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "workload1", backends[0].ID) + }) + + t.Run("returns empty list when all workloads lack URLs", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + workload1 := newTestK8SWorkload("workload1", withK8SURL("")) + workload2 := newTestK8SWorkload("workload2", + withK8SPhase(mcpv1alpha1.MCPServerPhaseTerminating), + withK8SURL("")) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"workload1", "workload2"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + assert.Empty(t, backends) + }) + + t.Run("returns error when group does not exist", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), "nonexistent-group") + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("returns error when group check fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "failed to check if group exists") + }) + + t.Run("returns empty list when group is empty", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) + 
mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), "empty-group") + + require.NoError(t, err) + assert.Empty(t, backends) + }) + + t.Run("discovers all workloads regardless of phase", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + terminatingWorkload := newTestK8SWorkload("terminating1", + withK8SPhase(mcpv1alpha1.MCPServerPhaseTerminating)) + failedWorkload := newTestK8SWorkload("failed1", + withK8SPhase(mcpv1alpha1.MCPServerPhaseFailed), + withK8SURL("http://localhost:8081/mcp"), + withK8STransport(types.TransportTypeSSE)) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"terminating1", "failed1"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "terminating1").Return(terminatingWorkload, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failed1").Return(failedWorkload, nil) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, vmcp.BackendUnhealthy, backends[0].HealthStatus) + assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) + }) + + t.Run("gracefully handles workload get failures", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + goodWorkload := newTestK8SWorkload("good-workload") + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"good-workload", "failing-workload"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "good-workload").Return(goodWorkload, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). 
+ Return(k8s.Workload{}, errors.New("MCPServer query failed")) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "good-workload", backends[0].ID) + }) + + t.Run("returns error when list workloads fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return(nil, errors.New("failed to list workloads")) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "failed to list workloads in group") + }) + + t.Run("handles pending phase correctly", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + pendingWorkload := newTestK8SWorkload("pending-workload", + withK8SPhase(mcpv1alpha1.MCPServerPhasePending)) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"pending-workload"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "pending-workload").Return(pendingWorkload, nil) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, vmcp.BackendUnknown, backends[0].HealthStatus) + assert.Equal(t, string(mcpv1alpha1.MCPServerPhasePending), backends[0].Metadata["workload_phase"]) + }) + + t.Run("includes namespace in metadata", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + workload := newTestK8SWorkload("workload1", + withK8SNamespace("custom-namespace")) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"workload1"}, nil) + mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload, nil) + + discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "custom-namespace", backends[0].Metadata["namespace"]) + }) +} diff --git a/pkg/vmcp/aggregator/testhelpers_test.go b/pkg/vmcp/aggregator/testhelpers_test.go index 0b766c508..a25d5b767 100644 --- a/pkg/vmcp/aggregator/testhelpers_test.go +++ b/pkg/vmcp/aggregator/testhelpers_test.go @@ -1,10 +1,12 @@ package aggregator import ( + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/workloads/k8s" ) // Test fixture builders to reduce verbosity in tests @@ -53,6 +55,62 @@ func withLabels(labels map[string]string) func(*core.Workload) { } } +// K8s workload test helpers + +func newTestK8SWorkload(name string, opts ...func(*k8s.Workload)) k8s.Workload { + w := k8s.Workload{ + Name: name, + Namespace: "default", + Phase: mcpv1alpha1.MCPServerPhaseRunning, + URL: "http://localhost:8080/mcp", + TransportType: types.TransportTypeStreamableHTTP, + ToolType: "mcp", + Group: testGroupName, + GroupRef: testGroupName, + Labels: make(map[string]string), + } + for _, opt := range opts { + opt(&w) + } + return w +} + +func withK8SPhase(phase mcpv1alpha1.MCPServerPhase) func(*k8s.Workload) { + return func(w *k8s.Workload) { + w.Phase = phase + } +} + +func withK8SURL(url string) func(*k8s.Workload) { + return func(w *k8s.Workload) { + w.URL = url + } +} + +func withK8STransport(transport types.TransportType) func(*k8s.Workload) { + return func(w *k8s.Workload) { + w.TransportType = transport + } +} + 
+func withK8SToolType(toolType string) func(*k8s.Workload) { + return func(w *k8s.Workload) { + w.ToolType = toolType + } +} + +func withK8SLabels(labels map[string]string) func(*k8s.Workload) { + return func(w *k8s.Workload) { + w.Labels = labels + } +} + +func withK8SNamespace(namespace string) func(*k8s.Workload) { + return func(w *k8s.Workload) { + w.Namespace = namespace + } +} + func newTestBackend(id string, opts ...func(*vmcp.Backend)) vmcp.Backend { b := vmcp.Backend{ ID: id, diff --git a/pkg/workloads/mocks/mock_k8s_manager.go b/pkg/workloads/mocks/mock_k8s_manager.go new file mode 100644 index 000000000..dbdec9361 --- /dev/null +++ b/pkg/workloads/mocks/mock_k8s_manager.go @@ -0,0 +1,151 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s_manager_interface.go +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_k8s_manager.go -package=mocks -source=k8s_manager_interface.go K8SManager +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + k8s "github.com/stacklok/toolhive/pkg/workloads/k8s" + gomock "go.uber.org/mock/gomock" +) + +// MockK8SManager is a mock of K8SManager interface. +type MockK8SManager struct { + ctrl *gomock.Controller + recorder *MockK8SManagerMockRecorder + isgomock struct{} +} + +// MockK8SManagerMockRecorder is the mock recorder for MockK8SManager. +type MockK8SManagerMockRecorder struct { + mock *MockK8SManager +} + +// NewMockK8SManager creates a new mock instance. +func NewMockK8SManager(ctrl *gomock.Controller) *MockK8SManager { + mock := &MockK8SManager{ctrl: ctrl} + mock.recorder = &MockK8SManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockK8SManager) EXPECT() *MockK8SManagerMockRecorder { + return m.recorder +} + +// DoesWorkloadExist mocks base method. 
+func (m *MockK8SManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoesWorkloadExist", ctx, workloadName) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DoesWorkloadExist indicates an expected call of DoesWorkloadExist. +func (mr *MockK8SManagerMockRecorder) DoesWorkloadExist(ctx, workloadName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesWorkloadExist", reflect.TypeOf((*MockK8SManager)(nil).DoesWorkloadExist), ctx, workloadName) +} + +// GetLogs mocks base method. +func (m *MockK8SManager) GetLogs(ctx context.Context, containerName string, follow bool) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLogs", ctx, containerName, follow) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLogs indicates an expected call of GetLogs. +func (mr *MockK8SManagerMockRecorder) GetLogs(ctx, containerName, follow any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogs", reflect.TypeOf((*MockK8SManager)(nil).GetLogs), ctx, containerName, follow) +} + +// GetProxyLogs mocks base method. +func (m *MockK8SManager) GetProxyLogs(ctx context.Context, workloadName string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxyLogs", ctx, workloadName) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProxyLogs indicates an expected call of GetProxyLogs. +func (mr *MockK8SManagerMockRecorder) GetProxyLogs(ctx, workloadName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyLogs", reflect.TypeOf((*MockK8SManager)(nil).GetProxyLogs), ctx, workloadName) +} + +// GetWorkload mocks base method. 
+func (m *MockK8SManager) GetWorkload(ctx context.Context, workloadName string) (k8s.Workload, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWorkload", ctx, workloadName) + ret0, _ := ret[0].(k8s.Workload) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetWorkload indicates an expected call of GetWorkload. +func (mr *MockK8SManagerMockRecorder) GetWorkload(ctx, workloadName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkload", reflect.TypeOf((*MockK8SManager)(nil).GetWorkload), ctx, workloadName) +} + +// ListWorkloads mocks base method. +func (m *MockK8SManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]k8s.Workload, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, listAll} + for _, a := range labelFilters { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ListWorkloads", varargs...) + ret0, _ := ret[0].([]k8s.Workload) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListWorkloads indicates an expected call of ListWorkloads. +func (mr *MockK8SManagerMockRecorder) ListWorkloads(ctx, listAll any, labelFilters ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, listAll}, labelFilters...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloads", reflect.TypeOf((*MockK8SManager)(nil).ListWorkloads), varargs...) +} + +// ListWorkloadsInGroup mocks base method. +func (m *MockK8SManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListWorkloadsInGroup", ctx, groupName) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListWorkloadsInGroup indicates an expected call of ListWorkloadsInGroup. 
+func (mr *MockK8SManagerMockRecorder) ListWorkloadsInGroup(ctx, groupName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloadsInGroup", reflect.TypeOf((*MockK8SManager)(nil).ListWorkloadsInGroup), ctx, groupName) +} + +// MoveToGroup mocks base method. +func (m *MockK8SManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom, groupTo string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MoveToGroup", ctx, workloadNames, groupFrom, groupTo) + ret0, _ := ret[0].(error) + return ret0 +} + +// MoveToGroup indicates an expected call of MoveToGroup. +func (mr *MockK8SManagerMockRecorder) MoveToGroup(ctx, workloadNames, groupFrom, groupTo any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MoveToGroup", reflect.TypeOf((*MockK8SManager)(nil).MoveToGroup), ctx, workloadNames, groupFrom, groupTo) +} From 2c921c7653aebd8940981ca5c7484353bc382287 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 11 Nov 2025 17:17:08 +0000 Subject: [PATCH 06/16] fixed thv listing --- pkg/workloads/cli_manager.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/workloads/cli_manager.go b/pkg/workloads/cli_manager.go index 1e2775e46..20ced9ded 100644 --- a/pkg/workloads/cli_manager.go +++ b/pkg/workloads/cli_manager.go @@ -1180,9 +1180,9 @@ func (d *cliManager) getRemoteWorkloadsFromState( // Map to core.Workload workload := core.Workload{ Name: name, - Package: runConfig.RemoteURL, + Package: "remote", URL: runConfig.RemoteURL, - ToolType: "mcp", + ToolType: "remote", TransportType: runConfig.Transport, ProxyMode: runConfig.ProxyMode.String(), Status: workloadStatus.Status, From 1d6d8077d14c8e23a8f72c79e6f9bf5948e0abbb Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 12 Nov 2025 01:09:38 +0000 Subject: [PATCH 07/16] checks the kubernetes client runtime instead of the environment --- pkg/workloads/manager.go | 19 
+++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index b3c8c85cf..86eaf329f 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -85,18 +85,29 @@ func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) // NewManagerFromRuntime creates a new CLI workload manager from an existing runtime. // IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. +// This function checks the runtime type directly, not the environment, to support +// cases like proxyrunner which runs in Kubernetes pods but uses Docker runtime. func NewManagerFromRuntime(rtRuntime rt.Runtime) (Manager, error) { - if rt.IsKubernetesRuntime() { - return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes environments") + // Check if the runtime is actually a Kubernetes runtime by type assertion + // The proxyrunner runs in pods but uses Docker runtime, so we check the runtime type, + // not the environment (which would always be Kubernetes in pods) + if _, ok := rtRuntime.(*kubernetes.Client); ok { + return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes runtime") } + return NewCLIManagerFromRuntime(rtRuntime) } // NewManagerFromRuntimeWithProvider creates a new CLI workload manager from an existing runtime with a custom config provider. // IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. +// This function checks the runtime type directly, not the environment, to support +// cases like proxyrunner which runs in Kubernetes pods but uses Docker runtime. 
func NewManagerFromRuntimeWithProvider(rtRuntime rt.Runtime, configProvider config.Provider) (Manager, error) { - if rt.IsKubernetesRuntime() { - return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes environments") + // Check if the runtime is actually a Kubernetes runtime by type assertion + // The proxyrunner runs in pods but uses Docker runtime, so we check the runtime type, + // not the environment (which would always be Kubernetes in pods) + if _, ok := rtRuntime.(*kubernetes.Client); ok { + return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes runtime") } return NewCLIManagerFromRuntimeWithProvider(rtRuntime, configProvider) } From 26a37ad251698649928c46880fc30af4e0526404 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 12 Nov 2025 11:39:47 +0000 Subject: [PATCH 08/16] fix e2e tests --- .../api/v1alpha1/mcpserver_types.go | 8 +++++- pkg/workloads/manager.go | 25 +++++-------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/cmd/thv-operator/api/v1alpha1/mcpserver_types.go b/cmd/thv-operator/api/v1alpha1/mcpserver_types.go index aa3b9d9ee..021a0a660 100644 --- a/cmd/thv-operator/api/v1alpha1/mcpserver_types.go +++ b/cmd/thv-operator/api/v1alpha1/mcpserver_types.go @@ -763,7 +763,13 @@ func (m *MCPServer) GetMcpPort() int32 { // the below is deprecated and will be removed in a future version // we need to keep it here to avoid breaking changes - return m.Spec.TargetPort + if m.Spec.TargetPort > 0 { + return m.Spec.TargetPort + } + + // Default to 8080 if no port is specified (matches GetProxyPort behavior) + // This is needed for HTTP-based transports (SSE, streamable-http) which require a target port + return 8080 } func init() { diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index 86eaf329f..a1d9c6f63 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -84,31 +84,18 @@ func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) } // 
NewManagerFromRuntime creates a new CLI workload manager from an existing runtime. -// IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. -// This function checks the runtime type directly, not the environment, to support -// cases like proxyrunner which runs in Kubernetes pods but uses Docker runtime. +// This function works with any runtime type. The status manager will automatically +// detect the environment and use the appropriate implementation. +// Proxyrunner uses this with Kubernetes runtime to create StatefulSets. func NewManagerFromRuntime(rtRuntime rt.Runtime) (Manager, error) { - // Check if the runtime is actually a Kubernetes runtime by type assertion - // The proxyrunner runs in pods but uses Docker runtime, so we check the runtime type, - // not the environment (which would always be Kubernetes in pods) - if _, ok := rtRuntime.(*kubernetes.Client); ok { - return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes runtime") - } - return NewCLIManagerFromRuntime(rtRuntime) } // NewManagerFromRuntimeWithProvider creates a new CLI workload manager from an existing runtime with a custom config provider. -// IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. -// This function checks the runtime type directly, not the environment, to support -// cases like proxyrunner which runs in Kubernetes pods but uses Docker runtime. +// This function works with any runtime type. The status manager will automatically +// detect the environment and use the appropriate implementation. +// Proxyrunner uses this with Kubernetes runtime to create StatefulSets. 
func NewManagerFromRuntimeWithProvider(rtRuntime rt.Runtime, configProvider config.Provider) (Manager, error) { - // Check if the runtime is actually a Kubernetes runtime by type assertion - // The proxyrunner runs in pods but uses Docker runtime, so we check the runtime type, - // not the environment (which would always be Kubernetes in pods) - if _, ok := rtRuntime.(*kubernetes.Client); ok { - return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes runtime") - } return NewCLIManagerFromRuntimeWithProvider(rtRuntime, configProvider) } From f3fd09fd724f9bed07bf01080e6476a8b36cf895 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 12 Nov 2025 14:14:12 +0000 Subject: [PATCH 09/16] improves test coverage --- pkg/workloads/cli_manager_test.go | 599 +++++++++++++++++++++++++++++- pkg/workloads/k8s_manager.go | 4 +- pkg/workloads/k8s_manager_test.go | 71 ++++ 3 files changed, 671 insertions(+), 3 deletions(-) diff --git a/pkg/workloads/cli_manager_test.go b/pkg/workloads/cli_manager_test.go index e12fd2a87..b6c5ba55c 100644 --- a/pkg/workloads/cli_manager_test.go +++ b/pkg/workloads/cli_manager_test.go @@ -4,14 +4,18 @@ import ( "context" "errors" "fmt" + "os" + "path/filepath" "testing" "time" + "github.com/adrg/xdg" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "golang.org/x/sync/errgroup" + "github.com/stacklok/toolhive/pkg/auth/remote" "github.com/stacklok/toolhive/pkg/config" configMocks "github.com/stacklok/toolhive/pkg/config/mocks" "github.com/stacklok/toolhive/pkg/container/runtime" @@ -961,6 +965,26 @@ func TestCLIManager_validateSecretParameters(t *testing.T) { setupMocks: func(*configMocks.MockProvider) {}, // No expectations expectError: false, }, + { + name: "no secrets and no remote auth - should pass", + runConfig: &runner.RunConfig{ + Secrets: []string{}, + RemoteAuthConfig: nil, + }, + setupMocks: func(*configMocks.MockProvider) {}, // No expectations + expectError: false, + }, + { + 
name: "remote auth without client secret - should pass", + runConfig: &runner.RunConfig{ + Secrets: []string{}, + RemoteAuthConfig: &remote.Config{ + ClientSecret: "", // Empty client secret + }, + }, + setupMocks: func(*configMocks.MockProvider) {}, // No expectations + expectError: false, + }, { name: "config error", runConfig: &runner.RunConfig{ @@ -973,6 +997,36 @@ func TestCLIManager_validateSecretParameters(t *testing.T) { expectError: true, errorMsg: "error determining secrets provider type", }, + { + name: "remote auth with client secret", + runConfig: &runner.RunConfig{ + Secrets: []string{}, + RemoteAuthConfig: &remote.Config{ + ClientSecret: "secret-value", + }, + }, + setupMocks: func(cp *configMocks.MockProvider) { + mockConfig := &config.Config{} + cp.EXPECT().GetConfig().Return(mockConfig) + }, + expectError: true, + errorMsg: "error determining secrets provider type", + }, + { + name: "both regular secrets and remote auth", + runConfig: &runner.RunConfig{ + Secrets: []string{"secret1"}, + RemoteAuthConfig: &remote.Config{ + ClientSecret: "secret-value", + }, + }, + setupMocks: func(cp *configMocks.MockProvider) { + mockConfig := &config.Config{} + cp.EXPECT().GetConfig().Return(mockConfig) + }, + expectError: true, + errorMsg: "error determining secrets provider type", + }, } for _, tt := range tests { @@ -994,7 +1048,9 @@ func TestCLIManager_validateSecretParameters(t *testing.T) { if tt.expectError { require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } } else { require.NoError(t, err) } @@ -1614,3 +1670,544 @@ func TestCLIManager_updateSingleWorkload(t *testing.T) { }) } } + +func TestNewCLIManager(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + wantError bool + }{ + { + name: "successful creation", + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := 
context.Background() + manager, err := NewCLIManager(ctx) + + // Note: This test may fail if Docker/Podman is not available + // That's acceptable - the function will return an error + if tt.wantError { + require.Error(t, err) + assert.Nil(t, manager) + } else { + // If runtime is available, manager should be created + if err == nil { + require.NotNil(t, manager) + cliMgr, ok := manager.(*cliManager) + require.True(t, ok) + assert.NotNil(t, cliMgr.runtime) + assert.NotNil(t, cliMgr.statuses) + assert.NotNil(t, cliMgr.configProvider) + } + // If runtime is not available, error is acceptable + } + }) + } +} + +func TestNewCLIManagerWithProvider(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConfigProvider := configMocks.NewMockProvider(ctrl) + + tests := []struct { + name string + configProvider config.Provider + wantError bool + }{ + { + name: "successful creation with provider", + configProvider: mockConfigProvider, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + manager, err := NewCLIManagerWithProvider(ctx, tt.configProvider) + + // Note: This test may fail if Docker/Podman is not available + if tt.wantError { + require.Error(t, err) + assert.Nil(t, manager) + } else { + // If runtime is available, manager should be created + if err == nil { + require.NotNil(t, manager) + cliMgr, ok := manager.(*cliManager) + require.True(t, ok) + assert.NotNil(t, cliMgr.runtime) + assert.NotNil(t, cliMgr.statuses) + assert.Equal(t, tt.configProvider, cliMgr.configProvider) + } + } + }) + } +} + +func TestCLIManager_stopRemoteWorkload(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tests := []struct { + name string + workloadName string + runConfig *runner.RunConfig + setupMocks func(*statusMocks.MockStatusManager) + wantError bool + errorMsg string + }{ + { + name: "successful stop", + 
workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-workload", + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(nil) + // stopProxyIfNeeded calls stopProcess which calls GetWorkloadPID + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-workload").Return(0, errors.New("no PID found")).AnyTimes() + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopped, "").Return(nil) + }, + wantError: false, + }, + { + name: "workload not found", + workloadName: "non-existent", + runConfig: &runner.RunConfig{ + BaseName: "non-existent", + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "non-existent").Return(core.Workload{}, runtime.ErrWorkloadNotFound) + }, + wantError: false, // Returns nil when workload not found + }, + { + name: "workload not running", + workloadName: "stopped-workload", + runConfig: &runner.RunConfig{ + BaseName: "stopped-workload", + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(core.Workload{ + Name: "stopped-workload", + Status: runtime.WorkloadStatusStopped, + }, nil) + }, + wantError: false, // Returns nil when workload not running + }, + { + name: "error getting workload", + workloadName: "error-workload", + runConfig: &runner.RunConfig{ + BaseName: "error-workload", + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "error-workload").Return(core.Workload{}, 
errors.New("database error")) + }, + wantError: true, + errorMsg: "failed to find workload", + }, + { + name: "error setting stopping status", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-workload", + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(errors.New("status error")) + // stopProxyIfNeeded calls stopProcess which calls GetWorkloadPID + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-workload").Return(0, errors.New("no PID found")).AnyTimes() + // Still sets stopped status even if stopping status fails + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopped, "").Return(nil) + }, + wantError: false, // Continues even if stopping status fails + }, + { + name: "empty base name", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "", // Empty base name + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(nil) + // stopProxyIfNeeded is called but does nothing if BaseName is empty + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopped, "").Return(nil) + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockStatusManager := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockStatusManager) + + manager := &cliManager{ + 
statuses: mockStatusManager, + } + + ctx := context.Background() + err := manager.stopRemoteWorkload(ctx, tt.workloadName, tt.runConfig) + + if tt.wantError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestCLIManager_GetProxyLogs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + setup func() (string, func()) // Returns log file path and cleanup function + wantError bool + errorMsg string + checkOutput bool + }{ + { + name: "successful read", + workloadName: "test-workload", + setup: func() (string, func()) { + // Create a temporary directory for XDG_DATA_HOME + tmpDir := t.TempDir() + + // Set XDG_DATA_HOME BEFORE calling xdg.DataFile + // xdg package reads this at initialization, so we need to set it early + oldXDG := os.Getenv("XDG_DATA_HOME") + os.Setenv("XDG_DATA_HOME", tmpDir) + + // Now get the actual path that xdg.DataFile will use + // This ensures we create the file at the exact location xdg will look + expectedLogPath, err := xdg.DataFile("toolhive/logs/test-workload.log") + require.NoError(t, err, "xdg.DataFile should succeed with XDG_DATA_HOME set") + + // Create the directory structure and file at the exact path xdg will use + require.NoError(t, os.MkdirAll(filepath.Dir(expectedLogPath), 0755)) + require.NoError(t, os.WriteFile(expectedLogPath, []byte("test log content"), 0600)) + + return expectedLogPath, func() { + if oldXDG == "" { + os.Unsetenv("XDG_DATA_HOME") + } else { + os.Setenv("XDG_DATA_HOME", oldXDG) + } + } + }, + wantError: false, + checkOutput: true, + }, + { + name: "log file not found", + workloadName: "non-existent", + setup: func() (string, func()) { + tmpDir := t.TempDir() + logDir := filepath.Join(tmpDir, "toolhive", "logs") + require.NoError(t, os.MkdirAll(logDir, 0755)) + // Don't create the log file + + oldXDG := os.Getenv("XDG_DATA_HOME") + os.Setenv("XDG_DATA_HOME", 
tmpDir) + + return "", func() { + if oldXDG == "" { + os.Unsetenv("XDG_DATA_HOME") + } else { + os.Setenv("XDG_DATA_HOME", oldXDG) + } + } + }, + wantError: true, + errorMsg: "proxy logs not found", + }, + { + name: "invalid workload name with path traversal", + workloadName: "../../etc/passwd", + setup: func() (string, func()) { + return "", func() {} + }, + wantError: true, + // xdg.DataFile may succeed even with path traversal, but file won't exist + // So we get "proxy logs not found" instead of "failed to get proxy log file path" + errorMsg: "proxy logs not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, cleanup := tt.setup() + defer cleanup() + + manager := &cliManager{} + + ctx := context.Background() + logs, err := manager.GetProxyLogs(ctx, tt.workloadName) + + if tt.wantError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + assert.Empty(t, logs) + } else { + require.NoError(t, err) + if tt.checkOutput { + assert.Contains(t, logs, "test log content") + } + } + }) + } +} + +func TestCLIManager_deleteRemoteWorkload(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tests := []struct { + name string + workloadName string + runConfig *runner.RunConfig + setupMocks func(*statusMocks.MockStatusManager) + wantError bool + errorMsg string + }{ + { + name: "successful delete", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-workload", + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusRemoving, "").Return(nil) + // stopProxyIfNeeded calls stopProcess which calls GetWorkloadPID + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-workload").Return(0, errors.New("no PID found")).AnyTimes() + sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), 
"remote-workload").Return(nil) + }, + wantError: false, + }, + { + name: "error setting removing status", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-workload", + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusRemoving, "").Return(errors.New("status error")) + }, + wantError: true, + errorMsg: "status error", // The function returns the error directly + }, + { + name: "error deleting status", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-workload", + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusRemoving, "").Return(nil) + // stopProxyIfNeeded calls stopProcess which calls GetWorkloadPID + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-workload").Return(0, errors.New("no PID found")).AnyTimes() + sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), "remote-workload").Return(errors.New("delete error")) + // Error is logged but not returned + }, + wantError: false, // Error is logged but function continues + }, + { + name: "empty base name", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "", // Empty base name + RemoteURL: "https://example.com/mcp", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusRemoving, "").Return(nil) + // stopProxyIfNeeded does nothing if BaseName is empty + sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), "remote-workload").Return(nil) + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + mockStatusManager := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockStatusManager) + + manager := 
&cliManager{ + statuses: mockStatusManager, + } + + ctx := context.Background() + err := manager.deleteRemoteWorkload(ctx, tt.workloadName, tt.runConfig) + + if tt.wantError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestCLIManager_cleanupTempPermissionProfile(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + baseName string + setupState func() (func(), error) // Returns cleanup function and error + wantError bool + errorMsg string + }{ + { + name: "no state file", + baseName: "non-existent", + setupState: func() (func(), error) { + // No state file exists + return func() {}, nil + }, + wantError: false, // Returns nil when state doesn't exist + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cleanup, err := tt.setupState() + defer cleanup() + require.NoError(t, err) + + manager := &cliManager{} + + ctx := context.Background() + err = manager.cleanupTempPermissionProfile(ctx, tt.baseName) + + if tt.wantError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + // Function returns nil when state doesn't exist or no profile to clean + assert.NoError(t, err) + } + }) + } +} + +func TestCLIManager_MoveToGroup(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadNames []string + groupFrom string + groupTo string + setupState func() (func(), error) // Returns cleanup function and error + wantError bool + errorMsg string + }{ + { + name: "invalid workload name", + workloadNames: []string{"../invalid"}, + groupFrom: "group1", + groupTo: "group2", + setupState: func() (func(), error) { + return func() {}, nil + }, + wantError: true, + errorMsg: "invalid workload name", + }, + { + name: "state file not found", + workloadNames: []string{"non-existent"}, + groupFrom: "group1", + groupTo: "group2", + 
setupState: func() (func(), error) { + return func() {}, nil + }, + wantError: true, + errorMsg: "failed to load runner state", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cleanup, err := tt.setupState() + defer cleanup() + require.NoError(t, err) + + manager := &cliManager{} + + ctx := context.Background() + err = manager.MoveToGroup(ctx, tt.workloadNames, tt.groupFrom, tt.groupTo) + + if tt.wantError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/workloads/k8s_manager.go b/pkg/workloads/k8s_manager.go index 5b813d4c0..f8943fb6d 100644 --- a/pkg/workloads/k8s_manager.go +++ b/pkg/workloads/k8s_manager.go @@ -5,6 +5,7 @@ package workloads import ( "context" "fmt" + "strings" "time" "k8s.io/apimachinery/pkg/api/errors" @@ -281,10 +282,9 @@ func (*k8sManager) isStandardK8sAnnotation(key string) bool { } for _, prefix := range standardPrefixes { - if len(key) >= len(prefix) && key[:len(prefix)] == prefix { + if strings.HasPrefix(key, prefix) { return true } } - return false } diff --git a/pkg/workloads/k8s_manager_test.go b/pkg/workloads/k8s_manager_test.go index 2a76c63a1..903d9fadd 100644 --- a/pkg/workloads/k8s_manager_test.go +++ b/pkg/workloads/k8s_manager_test.go @@ -713,3 +713,74 @@ func TestK8SManager_mcpServerToK8SWorkload(t *testing.T) { }) } } + +func TestK8SManager_isStandardK8sAnnotation(t *testing.T) { + t.Parallel() + + manager := &k8sManager{} + + tests := []struct { + name string + key string + expected bool + }{ + { + name: "kubectl annotation", + key: "kubectl.kubernetes.io/last-applied-configuration", + expected: true, + }, + { + name: "kubernetes.io annotation", + key: "kubernetes.io/created-by", + expected: true, + }, + { + name: "deployment.kubernetes.io annotation", + key: "deployment.kubernetes.io/revision", + expected: true, + }, + { + name: "k8s.io annotation", + key: 
"k8s.io/annotation", + expected: true, + }, + { + name: "user-defined annotation", + key: "custom/annotation", + expected: false, + }, + { + name: "empty string", + key: "", + expected: false, + }, + { + name: "short key", + key: "k", + expected: false, + }, + { + name: "partial match - not a prefix", + key: "my-kubectl.kubernetes.io/annotation", + expected: false, + }, + { + name: "exact prefix match", + key: "kubectl.kubernetes.io/", + expected: true, + }, + { + name: "case sensitive - uppercase", + key: "KUBECTL.kubernetes.io/annotation", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := manager.isStandardK8sAnnotation(tt.key) + assert.Equal(t, tt.expected, result, "isStandardK8sAnnotation(%q) = %v, want %v", tt.key, result, tt.expected) + }) + } +} From b70ac65ece1183b1a5fcb7bd3943ab1a1a29e67d Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 12 Nov 2025 14:58:11 +0000 Subject: [PATCH 10/16] refactor the k8s manager into separate package --- cmd/vmcp/app/commands.go | 3 +- pkg/vmcp/aggregator/k8s_discoverer.go | 7 +- pkg/vmcp/aggregator/k8s_discoverer_test.go | 26 +++--- pkg/workloads/cli_manager_test.go | 6 +- pkg/workloads/{k8s_manager.go => k8s/k8s.go} | 80 ++++++++++++------ .../{k8s_manager_test.go => k8s/k8s_test.go} | 83 +++++++++---------- .../manager.go} | 19 ++--- .../mocks/mock_manager.go} | 68 +++++++-------- pkg/workloads/manager.go | 41 +-------- 9 files changed, 163 insertions(+), 170 deletions(-) rename pkg/workloads/{k8s_manager.go => k8s/k8s.go} (75%) rename pkg/workloads/{k8s_manager_test.go => k8s/k8s_test.go} (90%) rename pkg/workloads/{k8s_manager_interface.go => k8s/manager.go} (76%) rename pkg/workloads/{mocks/mock_k8s_manager.go => k8s/mocks/mock_manager.go} (53%) diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 5490f6209..563a8dff3 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -20,6 +20,7 @@ import ( vmcprouter 
"github.com/stacklok/toolhive/pkg/vmcp/router" vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server" "github.com/stacklok/toolhive/pkg/workloads" + "github.com/stacklok/toolhive/pkg/workloads/k8s" ) var rootCmd = &cobra.Command{ @@ -229,7 +230,7 @@ func discoverBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, // Create backend discoverer based on runtime environment var discoverer aggregator.BackendDiscoverer if rt.IsKubernetesRuntime() { - k8sWorkloadsManager, err := workloads.NewK8SManagerFromContext(ctx) + k8sWorkloadsManager, err := k8s.NewManagerFromContext(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to create Kubernetes workloads manager: %w", err) } diff --git a/pkg/vmcp/aggregator/k8s_discoverer.go b/pkg/vmcp/aggregator/k8s_discoverer.go index 84be798ea..10cdd114c 100644 --- a/pkg/vmcp/aggregator/k8s_discoverer.go +++ b/pkg/vmcp/aggregator/k8s_discoverer.go @@ -11,14 +11,13 @@ import ( "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/pkg/workloads" "github.com/stacklok/toolhive/pkg/workloads/k8s" ) // k8sBackendDiscoverer discovers backend MCP servers from Kubernetes workloads (MCPServer CRDs). -// It works with workloads.K8SManager and k8s.Workload. +// It works with k8s.Manager and k8s.Workload. type k8sBackendDiscoverer struct { - workloadsManager workloads.K8SManager + workloadsManager k8s.Manager groupsManager groups.Manager authConfig *config.OutgoingAuthConfig } @@ -31,7 +30,7 @@ type k8sBackendDiscoverer struct { // // This is the Kubernetes-specific constructor. For CLI workloads, use NewCLIBackendDiscoverer. 
func NewK8SBackendDiscoverer( - workloadsManager workloads.K8SManager, + workloadsManager k8s.Manager, groupsManager groups.Manager, authConfig *config.OutgoingAuthConfig, ) BackendDiscoverer { diff --git a/pkg/vmcp/aggregator/k8s_discoverer_test.go b/pkg/vmcp/aggregator/k8s_discoverer_test.go index 2dccb05d5..cc1a3eebe 100644 --- a/pkg/vmcp/aggregator/k8s_discoverer_test.go +++ b/pkg/vmcp/aggregator/k8s_discoverer_test.go @@ -14,7 +14,7 @@ import ( "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/workloads/k8s" - workloadmocks "github.com/stacklok/toolhive/pkg/workloads/mocks" + k8smocks "github.com/stacklok/toolhive/pkg/workloads/k8s/mocks" ) func TestK8SBackendDiscoverer_Discover(t *testing.T) { @@ -25,7 +25,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) workload1 := newTestK8SWorkload("workload1", @@ -65,7 +65,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) runningWorkload := newTestK8SWorkload("running-workload", @@ -98,7 +98,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) workloadWithURL := newTestK8SWorkload("workload1") @@ -123,7 +123,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups 
:= mocks.NewMockManager(ctrl) workload1 := newTestK8SWorkload("workload1", withK8SURL("")) @@ -149,7 +149,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) @@ -167,7 +167,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) @@ -185,7 +185,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) @@ -203,7 +203,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) terminatingWorkload := newTestK8SWorkload("terminating1", @@ -233,7 +233,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) goodWorkload := newTestK8SWorkload("good-workload") @@ -258,7 +258,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := 
workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) @@ -278,7 +278,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) pendingWorkload := newTestK8SWorkload("pending-workload", @@ -303,7 +303,7 @@ func TestK8SBackendDiscoverer_Discover(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockWorkloads := workloadmocks.NewMockK8SManager(ctrl) + mockWorkloads := k8smocks.NewMockManager(ctrl) mockGroups := mocks.NewMockManager(ctrl) workload := newTestK8SWorkload("workload1", diff --git a/pkg/workloads/cli_manager_test.go b/pkg/workloads/cli_manager_test.go index b6c5ba55c..e00029d0e 100644 --- a/pkg/workloads/cli_manager_test.go +++ b/pkg/workloads/cli_manager_test.go @@ -1716,7 +1716,7 @@ func TestNewCLIManagerWithProvider(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - defer ctrl.Finish() + t.Cleanup(ctrl.Finish) mockConfigProvider := configMocks.NewMockProvider(ctrl) @@ -1762,7 +1762,7 @@ func TestCLIManager_stopRemoteWorkload(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - defer ctrl.Finish() + t.Cleanup(ctrl.Finish) tests := []struct { name string @@ -2008,7 +2008,7 @@ func TestCLIManager_deleteRemoteWorkload(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) - defer ctrl.Finish() + t.Cleanup(ctrl.Finish) tests := []struct { name string diff --git a/pkg/workloads/k8s_manager.go b/pkg/workloads/k8s/k8s.go similarity index 75% rename from pkg/workloads/k8s_manager.go rename to pkg/workloads/k8s/k8s.go index f8943fb6d..7d17783e0 100644 --- a/pkg/workloads/k8s_manager.go +++ b/pkg/workloads/k8s/k8s.go @@ -1,6 +1,6 @@ -// Package workloads provides a 
Kubernetes-based implementation of the K8SManager interface. +// Package k8s provides Kubernetes-specific workload management. // This file contains the Kubernetes implementation for operator environments. -package workloads +package k8s import ( "context" @@ -10,48 +10,78 @@ import ( "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/selection" "k8s.io/apimachinery/pkg/types" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/container/kubernetes" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport" transporttypes "github.com/stacklok/toolhive/pkg/transport/types" - "github.com/stacklok/toolhive/pkg/workloads/k8s" workloadtypes "github.com/stacklok/toolhive/pkg/workloads/types" ) -// k8sManager implements the K8SManager interface for Kubernetes environments. +// manager implements the Manager interface for Kubernetes environments. // In Kubernetes, the operator manages workload lifecycle via MCPServer CRDs. // This manager provides read-only operations and CRD-based storage. -type k8sManager struct { +type manager struct { k8sClient client.Client namespace string } -// NewK8SManager creates a new Kubernetes-based workload manager. -func NewK8SManager(k8sClient client.Client, namespace string) (K8SManager, error) { - return &k8sManager{ +// NewManager creates a new Kubernetes-based workload manager. 
+func NewManager(k8sClient client.Client, namespace string) (Manager, error) { + return &manager{ k8sClient: k8sClient, namespace: namespace, }, nil } -func (k *k8sManager) GetWorkload(ctx context.Context, workloadName string) (k8s.Workload, error) { +// NewManagerFromContext creates a Kubernetes-based workload manager from context. +// It automatically sets up the Kubernetes client and detects the namespace. +func NewManagerFromContext(_ context.Context) (Manager, error) { + // Create a scheme for controller-runtime client + scheme := runtime.NewScheme() + utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(mcpv1alpha1.AddToScheme(scheme)) + + // Get Kubernetes config + cfg, err := ctrl.GetConfig() + if err != nil { + return nil, fmt.Errorf("failed to get Kubernetes config: %w", err) + } + + // Create controller-runtime client + k8sClient, err := client.New(cfg, client.Options{Scheme: scheme}) + if err != nil { + return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) + } + + // Detect namespace + namespace := kubernetes.GetCurrentNamespace() + + return NewManager(k8sClient, namespace) +} + +func (k *manager) GetWorkload(ctx context.Context, workloadName string) (Workload, error) { mcpServer := &mcpv1alpha1.MCPServer{} key := types.NamespacedName{Name: workloadName, Namespace: k.namespace} if err := k.k8sClient.Get(ctx, key, mcpServer); err != nil { if errors.IsNotFound(err) { - return k8s.Workload{}, fmt.Errorf("MCPServer %s not found", workloadName) + return Workload{}, fmt.Errorf("MCPServer %s not found", workloadName) } - return k8s.Workload{}, fmt.Errorf("failed to get MCPServer: %w", err) + return Workload{}, fmt.Errorf("failed to get MCPServer: %w", err) } - return k.mcpServerToK8SWorkload(mcpServer) + return k.mcpServerToWorkload(mcpServer) } -func (k *k8sManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { +func (k *manager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, 
error) { mcpServer := &mcpv1alpha1.MCPServer{} key := types.NamespacedName{Name: workloadName, Namespace: k.namespace} if err := k.k8sClient.Get(ctx, key, mcpServer); err != nil { @@ -63,7 +93,7 @@ func (k *k8sManager) DoesWorkloadExist(ctx context.Context, workloadName string) return true, nil } -func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]k8s.Workload, error) { +func (k *manager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]Workload, error) { mcpServerList := &mcpv1alpha1.MCPServerList{} listOpts := []client.ListOption{ client.InNamespace(k.namespace), @@ -92,7 +122,7 @@ func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilte return nil, fmt.Errorf("failed to list MCPServers: %w", err) } - var workloads []k8s.Workload + var workloads []Workload for i := range mcpServerList.Items { mcpServer := &mcpServerList.Items[i] @@ -104,7 +134,7 @@ func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilte } } - workload, err := k.mcpServerToK8SWorkload(mcpServer) + workload, err := k.mcpServerToWorkload(mcpServer) if err != nil { logger.Warnf("Failed to convert MCPServer %s to workload: %v", mcpServer.Name, err) continue @@ -116,7 +146,7 @@ func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilte return workloads, nil } -// Note: The following operations are not part of K8SManager interface: +// Note: The following operations are not part of Manager interface: // - StopWorkloads: Use kubectl to manage MCPServer CRDs // - RunWorkload: Create MCPServer CRD instead // - RunWorkloadDetached: Create MCPServer CRD instead @@ -128,7 +158,7 @@ func (k *k8sManager) ListWorkloads(ctx context.Context, listAll bool, labelFilte // Note: This requires a Kubernetes clientset for log streaming. // For now, this returns an error indicating logs should be retrieved via kubectl. 
// TODO: Implement proper log retrieval using clientset or REST client. -func (k *k8sManager) GetLogs(_ context.Context, _ string, follow bool) (string, error) { +func (k *manager) GetLogs(_ context.Context, _ string, follow bool) (string, error) { if follow { return "", fmt.Errorf("follow mode is not supported. Use 'kubectl logs -f -n %s' to stream logs", k.namespace) } @@ -141,14 +171,14 @@ func (k *k8sManager) GetLogs(_ context.Context, _ string, follow bool) (string, // Note: This requires a Kubernetes clientset for log streaming. // For now, this returns an error indicating logs should be retrieved via kubectl. // TODO: Implement proper log retrieval using clientset or REST client. -func (k *k8sManager) GetProxyLogs(_ context.Context, _ string) (string, error) { +func (k *manager) GetProxyLogs(_ context.Context, _ string) (string, error) { return "", fmt.Errorf( "GetProxyLogs is not fully implemented in Kubernetes mode. Use 'kubectl logs -c proxy -n %s' to retrieve proxy logs", k.namespace) } // MoveToGroup moves the specified workloads from one group to another. -func (k *k8sManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { +func (k *manager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { for _, name := range workloadNames { mcpServer := &mcpv1alpha1.MCPServer{} key := types.NamespacedName{Name: name, Namespace: k.namespace} @@ -177,7 +207,7 @@ func (k *k8sManager) MoveToGroup(ctx context.Context, workloadNames []string, gr } // ListWorkloadsInGroup returns all workload names that belong to the specified group. 
-func (k *k8sManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { +func (k *manager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { mcpServerList := &mcpv1alpha1.MCPServerList{} listOpts := []client.ListOption{ client.InNamespace(k.namespace), @@ -198,8 +228,8 @@ func (k *k8sManager) ListWorkloadsInGroup(ctx context.Context, groupName string) return groupWorkloads, nil } -// mcpServerToK8SWorkload converts an MCPServer CRD to a k8s.Workload. -func (k *k8sManager) mcpServerToK8SWorkload(mcpServer *mcpv1alpha1.MCPServer) (k8s.Workload, error) { +// mcpServerToWorkload converts an MCPServer CRD to a Workload. +func (k *manager) mcpServerToWorkload(mcpServer *mcpv1alpha1.MCPServer) (Workload, error) { // Parse transport type transportType, err := transporttypes.ParseTransportType(mcpServer.Spec.Transport) if err != nil { @@ -252,7 +282,7 @@ func (k *k8sManager) mcpServerToK8SWorkload(mcpServer *mcpv1alpha1.MCPServer) (k createdAt = time.Now() } - return k8s.Workload{ + return Workload{ Name: mcpServer.Name, Namespace: mcpServer.Namespace, Package: mcpServer.Spec.Image, @@ -272,7 +302,7 @@ func (k *k8sManager) mcpServerToK8SWorkload(mcpServer *mcpv1alpha1.MCPServer) (k } // isStandardK8sAnnotation checks if an annotation key is a standard Kubernetes annotation. 
-func (*k8sManager) isStandardK8sAnnotation(key string) bool { +func (*manager) isStandardK8sAnnotation(key string) bool { // Common Kubernetes annotation prefixes standardPrefixes := []string{ "kubectl.kubernetes.io/", diff --git a/pkg/workloads/k8s_manager_test.go b/pkg/workloads/k8s/k8s_test.go similarity index 90% rename from pkg/workloads/k8s_manager_test.go rename to pkg/workloads/k8s/k8s_test.go index 903d9fadd..a4af3177e 100644 --- a/pkg/workloads/k8s_manager_test.go +++ b/pkg/workloads/k8s/k8s_test.go @@ -1,4 +1,4 @@ -package workloads +package k8s import ( "context" @@ -13,7 +13,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - "github.com/stacklok/toolhive/pkg/workloads/k8s" ) const ( @@ -50,7 +49,7 @@ func (m *mockClient) Update(ctx context.Context, obj client.Object, opts ...clie return nil } -func TestNewK8SManager(t *testing.T) { +func TestNewManager(t *testing.T) { t.Parallel() tests := []struct { @@ -77,25 +76,25 @@ func TestNewK8SManager(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - manager, err := NewK8SManager(tt.k8sClient, tt.namespace) + k8sManager, err := NewManager(tt.k8sClient, tt.namespace) if tt.wantError { require.Error(t, err) - assert.Nil(t, manager) + assert.Nil(t, k8sManager) } else { require.NoError(t, err) - require.NotNil(t, manager) + require.NotNil(t, k8sManager) - k8sMgr, ok := manager.(*k8sManager) + mgr, ok := k8sManager.(*manager) require.True(t, ok) - assert.Equal(t, tt.k8sClient, k8sMgr.k8sClient) - assert.Equal(t, tt.namespace, k8sMgr.namespace) + assert.Equal(t, tt.k8sClient, mgr.k8sClient) + assert.Equal(t, tt.namespace, mgr.namespace) } }) } } -func TestK8SManager_GetWorkload(t *testing.T) { +func TestManager_GetWorkload(t *testing.T) { t.Parallel() tests := []struct { @@ -104,7 +103,7 @@ func TestK8SManager_GetWorkload(t *testing.T) { setupMock func(*mockClient) wantError bool errorMsg string - expected k8s.Workload 
+ expected Workload }{ { name: "successful get", @@ -125,7 +124,7 @@ func TestK8SManager_GetWorkload(t *testing.T) { } }, wantError: false, - expected: k8s.Workload{ + expected: Workload{ Name: "test-workload", Namespace: defaultNamespace, Phase: mcpv1alpha1.MCPServerPhaseRunning, @@ -166,13 +165,13 @@ func TestK8SManager_GetWorkload(t *testing.T) { mockClient := &mockClient{} tt.setupMock(mockClient) - manager := &k8sManager{ + mgr := &manager{ k8sClient: mockClient, namespace: defaultNamespace, } ctx := context.Background() - result, err := manager.GetWorkload(ctx, tt.workloadName) + result, err := mgr.GetWorkload(ctx, tt.workloadName) if tt.wantError { require.Error(t, err) @@ -189,7 +188,7 @@ func TestK8SManager_GetWorkload(t *testing.T) { } } -func TestK8SManager_DoesWorkloadExist(t *testing.T) { +func TestManager_DoesWorkloadExist(t *testing.T) { t.Parallel() tests := []struct { @@ -244,13 +243,13 @@ func TestK8SManager_DoesWorkloadExist(t *testing.T) { mockClient := &mockClient{} tt.setupMock(mockClient) - manager := &k8sManager{ + mgr := &manager{ k8sClient: mockClient, namespace: defaultNamespace, } ctx := context.Background() - result, err := manager.DoesWorkloadExist(ctx, tt.workloadName) + result, err := mgr.DoesWorkloadExist(ctx, tt.workloadName) if tt.wantError { require.Error(t, err) @@ -262,7 +261,7 @@ func TestK8SManager_DoesWorkloadExist(t *testing.T) { } } -func TestK8SManager_ListWorkloadsInGroup(t *testing.T) { +func TestManager_ListWorkloadsInGroup(t *testing.T) { t.Parallel() tests := []struct { @@ -347,13 +346,13 @@ func TestK8SManager_ListWorkloadsInGroup(t *testing.T) { mockClient := &mockClient{} tt.setupMock(mockClient) - manager := &k8sManager{ + mgr := &manager{ k8sClient: mockClient, namespace: defaultNamespace, } ctx := context.Background() - result, err := manager.ListWorkloadsInGroup(ctx, tt.groupName) + result, err := mgr.ListWorkloadsInGroup(ctx, tt.groupName) if tt.wantError { require.Error(t, err) @@ -368,7 +367,7 @@ func 
TestK8SManager_ListWorkloadsInGroup(t *testing.T) { } } -func TestK8SManager_ListWorkloads(t *testing.T) { +func TestManager_ListWorkloads(t *testing.T) { t.Parallel() tests := []struct { @@ -458,13 +457,13 @@ func TestK8SManager_ListWorkloads(t *testing.T) { mockClient := &mockClient{} tt.setupMock(mockClient) - manager := &k8sManager{ + mgr := &manager{ k8sClient: mockClient, namespace: defaultNamespace, } ctx := context.Background() - result, err := manager.ListWorkloads(ctx, tt.listAll, tt.labelFilters...) + result, err := mgr.ListWorkloads(ctx, tt.listAll, tt.labelFilters...) if tt.wantError { require.Error(t, err) @@ -479,7 +478,7 @@ func TestK8SManager_ListWorkloads(t *testing.T) { } } -func TestK8SManager_MoveToGroup(t *testing.T) { +func TestManager_MoveToGroup(t *testing.T) { t.Parallel() tests := []struct { @@ -572,13 +571,13 @@ func TestK8SManager_MoveToGroup(t *testing.T) { mockClient := &mockClient{} tt.setupMock(mockClient) - manager := &k8sManager{ + mgr := &manager{ k8sClient: mockClient, namespace: defaultNamespace, } ctx := context.Background() - err := manager.MoveToGroup(ctx, tt.workloadNames, tt.groupFrom, tt.groupTo) + err := mgr.MoveToGroup(ctx, tt.workloadNames, tt.groupFrom, tt.groupTo) if tt.wantError { require.Error(t, err) @@ -592,11 +591,11 @@ func TestK8SManager_MoveToGroup(t *testing.T) { } } -func TestK8SManager_NoOpMethods(t *testing.T) { +func TestManager_NoOpMethods(t *testing.T) { t.Parallel() mockClient := &mockClient{} - manager := &k8sManager{ + mgr := &manager{ k8sClient: mockClient, namespace: defaultNamespace, } @@ -605,7 +604,7 @@ func TestK8SManager_NoOpMethods(t *testing.T) { t.Run("GetLogs returns error", func(t *testing.T) { t.Parallel() - logs, err := manager.GetLogs(ctx, testWorkload1, false) + logs, err := mgr.GetLogs(ctx, testWorkload1, false) require.Error(t, err) assert.Empty(t, logs) assert.Contains(t, err.Error(), "not fully implemented") @@ -613,20 +612,20 @@ func TestK8SManager_NoOpMethods(t *testing.T) { 
t.Run("GetProxyLogs returns error", func(t *testing.T) { t.Parallel() - logs, err := manager.GetProxyLogs(ctx, testWorkload1) + logs, err := mgr.GetProxyLogs(ctx, testWorkload1) require.Error(t, err) assert.Empty(t, logs) assert.Contains(t, err.Error(), "not fully implemented") }) } -func TestK8SManager_mcpServerToK8SWorkload(t *testing.T) { +func TestManager_mcpServerToWorkload(t *testing.T) { t.Parallel() tests := []struct { name string mcpServer *mcpv1alpha1.MCPServer - expected k8s.Workload + expected Workload }{ { name: "running workload with HTTP transport", @@ -648,7 +647,7 @@ func TestK8SManager_mcpServerToK8SWorkload(t *testing.T) { ProxyPort: 8080, }, }, - expected: k8s.Workload{ + expected: Workload{ Name: "test-workload", Namespace: defaultNamespace, Phase: mcpv1alpha1.MCPServerPhaseRunning, @@ -667,7 +666,7 @@ func TestK8SManager_mcpServerToK8SWorkload(t *testing.T) { Phase: mcpv1alpha1.MCPServerPhaseTerminating, }, }, - expected: k8s.Workload{ + expected: Workload{ Name: "terminating-workload", Namespace: defaultNamespace, Phase: mcpv1alpha1.MCPServerPhaseTerminating, @@ -684,7 +683,7 @@ func TestK8SManager_mcpServerToK8SWorkload(t *testing.T) { Phase: mcpv1alpha1.MCPServerPhaseFailed, }, }, - expected: k8s.Workload{ + expected: Workload{ Name: "failed-workload", Namespace: defaultNamespace, Phase: mcpv1alpha1.MCPServerPhaseFailed, @@ -696,11 +695,11 @@ func TestK8SManager_mcpServerToK8SWorkload(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - manager := &k8sManager{ + mgr := &manager{ namespace: defaultNamespace, } - result, err := manager.mcpServerToK8SWorkload(tt.mcpServer) + result, err := mgr.mcpServerToWorkload(tt.mcpServer) require.NoError(t, err) assert.Equal(t, tt.expected.Name, result.Name) @@ -714,10 +713,10 @@ func TestK8SManager_mcpServerToK8SWorkload(t *testing.T) { } } -func TestK8SManager_isStandardK8sAnnotation(t *testing.T) { +func TestManager_isStandardK8sAnnotation(t *testing.T) { t.Parallel() - manager := 
&k8sManager{} + mgr := &manager{} tests := []struct { name string @@ -771,7 +770,7 @@ func TestK8SManager_isStandardK8sAnnotation(t *testing.T) { }, { name: "case sensitive - uppercase", - key: "KUBECTL.kubernetes.io/annotation", + key: "KUBECTL.KUBERNETES.IO/annotation", expected: false, }, } @@ -779,8 +778,8 @@ func TestK8SManager_isStandardK8sAnnotation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - result := manager.isStandardK8sAnnotation(tt.key) - assert.Equal(t, tt.expected, result, "isStandardK8sAnnotation(%q) = %v, want %v", tt.key, result, tt.expected) + result := mgr.isStandardK8sAnnotation(tt.key) + assert.Equal(t, tt.expected, result) }) } } diff --git a/pkg/workloads/k8s_manager_interface.go b/pkg/workloads/k8s/manager.go similarity index 76% rename from pkg/workloads/k8s_manager_interface.go rename to pkg/workloads/k8s/manager.go index e6b559cf9..29330a73f 100644 --- a/pkg/workloads/k8s_manager_interface.go +++ b/pkg/workloads/k8s/manager.go @@ -1,26 +1,23 @@ -// Package workloads contains high-level logic for managing the lifecycle of -// ToolHive-managed containers. -package workloads +// Package k8s provides Kubernetes-specific workload management. +package k8s import ( "context" - - "github.com/stacklok/toolhive/pkg/workloads/k8s" ) -// K8SManager manages MCPServer CRD workloads in Kubernetes. -// This interface is separate from Manager to avoid coupling Kubernetes workloads +// Manager manages MCPServer CRD workloads in Kubernetes. +// This interface is separate from workloads.Manager to avoid coupling Kubernetes workloads // to the CLI container runtime interface. 
// -//go:generate mockgen -destination=mocks/mock_k8s_manager.go -package=mocks -source=k8s_manager_interface.go K8SManager -type K8SManager interface { +//go:generate mockgen -destination=mocks/mock_manager.go -package=mocks github.com/stacklok/toolhive/pkg/workloads/k8s Manager +type Manager interface { // GetWorkload retrieves an MCPServer CRD by name - GetWorkload(ctx context.Context, workloadName string) (k8s.Workload, error) + GetWorkload(ctx context.Context, workloadName string) (Workload, error) // ListWorkloads lists all MCPServer CRDs, optionally filtered by labels // The `listAll` parameter determines whether to include workloads that are not running // The optional `labelFilters` parameter allows filtering workloads by labels (format: key=value) - ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]k8s.Workload, error) + ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]Workload, error) // ListWorkloadsInGroup returns all workload names that belong to the specified group ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) diff --git a/pkg/workloads/mocks/mock_k8s_manager.go b/pkg/workloads/k8s/mocks/mock_manager.go similarity index 53% rename from pkg/workloads/mocks/mock_k8s_manager.go rename to pkg/workloads/k8s/mocks/mock_manager.go index dbdec9361..c91e73afb 100644 --- a/pkg/workloads/mocks/mock_k8s_manager.go +++ b/pkg/workloads/k8s/mocks/mock_manager.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: k8s_manager_interface.go +// Source: github.com/stacklok/toolhive/pkg/workloads/k8s (interfaces: Manager) // // Generated by this command: // -// mockgen -destination=mocks/mock_k8s_manager.go -package=mocks -source=k8s_manager_interface.go K8SManager +// mockgen -destination=mocks/mock_manager.go -package=mocks github.com/stacklok/toolhive/pkg/workloads/k8s Manager // // Package mocks is a generated GoMock package. 
@@ -17,32 +17,32 @@ import ( gomock "go.uber.org/mock/gomock" ) -// MockK8SManager is a mock of K8SManager interface. -type MockK8SManager struct { +// MockManager is a mock of Manager interface. +type MockManager struct { ctrl *gomock.Controller - recorder *MockK8SManagerMockRecorder + recorder *MockManagerMockRecorder isgomock struct{} } -// MockK8SManagerMockRecorder is the mock recorder for MockK8SManager. -type MockK8SManagerMockRecorder struct { - mock *MockK8SManager +// MockManagerMockRecorder is the mock recorder for MockManager. +type MockManagerMockRecorder struct { + mock *MockManager } -// NewMockK8SManager creates a new mock instance. -func NewMockK8SManager(ctrl *gomock.Controller) *MockK8SManager { - mock := &MockK8SManager{ctrl: ctrl} - mock.recorder = &MockK8SManagerMockRecorder{mock} +// NewMockManager creates a new mock instance. +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockK8SManager) EXPECT() *MockK8SManagerMockRecorder { +func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } // DoesWorkloadExist mocks base method. -func (m *MockK8SManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { +func (m *MockManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DoesWorkloadExist", ctx, workloadName) ret0, _ := ret[0].(bool) @@ -51,13 +51,13 @@ func (m *MockK8SManager) DoesWorkloadExist(ctx context.Context, workloadName str } // DoesWorkloadExist indicates an expected call of DoesWorkloadExist. 
-func (mr *MockK8SManagerMockRecorder) DoesWorkloadExist(ctx, workloadName any) *gomock.Call { +func (mr *MockManagerMockRecorder) DoesWorkloadExist(ctx, workloadName any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesWorkloadExist", reflect.TypeOf((*MockK8SManager)(nil).DoesWorkloadExist), ctx, workloadName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesWorkloadExist", reflect.TypeOf((*MockManager)(nil).DoesWorkloadExist), ctx, workloadName) } // GetLogs mocks base method. -func (m *MockK8SManager) GetLogs(ctx context.Context, containerName string, follow bool) (string, error) { +func (m *MockManager) GetLogs(ctx context.Context, containerName string, follow bool) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetLogs", ctx, containerName, follow) ret0, _ := ret[0].(string) @@ -66,13 +66,13 @@ func (m *MockK8SManager) GetLogs(ctx context.Context, containerName string, foll } // GetLogs indicates an expected call of GetLogs. -func (mr *MockK8SManagerMockRecorder) GetLogs(ctx, containerName, follow any) *gomock.Call { +func (mr *MockManagerMockRecorder) GetLogs(ctx, containerName, follow any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogs", reflect.TypeOf((*MockK8SManager)(nil).GetLogs), ctx, containerName, follow) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogs", reflect.TypeOf((*MockManager)(nil).GetLogs), ctx, containerName, follow) } // GetProxyLogs mocks base method. 
-func (m *MockK8SManager) GetProxyLogs(ctx context.Context, workloadName string) (string, error) { +func (m *MockManager) GetProxyLogs(ctx context.Context, workloadName string) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetProxyLogs", ctx, workloadName) ret0, _ := ret[0].(string) @@ -81,13 +81,13 @@ func (m *MockK8SManager) GetProxyLogs(ctx context.Context, workloadName string) } // GetProxyLogs indicates an expected call of GetProxyLogs. -func (mr *MockK8SManagerMockRecorder) GetProxyLogs(ctx, workloadName any) *gomock.Call { +func (mr *MockManagerMockRecorder) GetProxyLogs(ctx, workloadName any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyLogs", reflect.TypeOf((*MockK8SManager)(nil).GetProxyLogs), ctx, workloadName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyLogs", reflect.TypeOf((*MockManager)(nil).GetProxyLogs), ctx, workloadName) } // GetWorkload mocks base method. -func (m *MockK8SManager) GetWorkload(ctx context.Context, workloadName string) (k8s.Workload, error) { +func (m *MockManager) GetWorkload(ctx context.Context, workloadName string) (k8s.Workload, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetWorkload", ctx, workloadName) ret0, _ := ret[0].(k8s.Workload) @@ -96,13 +96,13 @@ func (m *MockK8SManager) GetWorkload(ctx context.Context, workloadName string) ( } // GetWorkload indicates an expected call of GetWorkload. -func (mr *MockK8SManagerMockRecorder) GetWorkload(ctx, workloadName any) *gomock.Call { +func (mr *MockManagerMockRecorder) GetWorkload(ctx, workloadName any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkload", reflect.TypeOf((*MockK8SManager)(nil).GetWorkload), ctx, workloadName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkload", reflect.TypeOf((*MockManager)(nil).GetWorkload), ctx, workloadName) } // ListWorkloads mocks base method. 
-func (m *MockK8SManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]k8s.Workload, error) { +func (m *MockManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]k8s.Workload, error) { m.ctrl.T.Helper() varargs := []any{ctx, listAll} for _, a := range labelFilters { @@ -115,14 +115,14 @@ func (m *MockK8SManager) ListWorkloads(ctx context.Context, listAll bool, labelF } // ListWorkloads indicates an expected call of ListWorkloads. -func (mr *MockK8SManagerMockRecorder) ListWorkloads(ctx, listAll any, labelFilters ...any) *gomock.Call { +func (mr *MockManagerMockRecorder) ListWorkloads(ctx, listAll any, labelFilters ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{ctx, listAll}, labelFilters...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloads", reflect.TypeOf((*MockK8SManager)(nil).ListWorkloads), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloads", reflect.TypeOf((*MockManager)(nil).ListWorkloads), varargs...) } // ListWorkloadsInGroup mocks base method. -func (m *MockK8SManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { +func (m *MockManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ListWorkloadsInGroup", ctx, groupName) ret0, _ := ret[0].([]string) @@ -131,13 +131,13 @@ func (m *MockK8SManager) ListWorkloadsInGroup(ctx context.Context, groupName str } // ListWorkloadsInGroup indicates an expected call of ListWorkloadsInGroup. 
-func (mr *MockK8SManagerMockRecorder) ListWorkloadsInGroup(ctx, groupName any) *gomock.Call { +func (mr *MockManagerMockRecorder) ListWorkloadsInGroup(ctx, groupName any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloadsInGroup", reflect.TypeOf((*MockK8SManager)(nil).ListWorkloadsInGroup), ctx, groupName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloadsInGroup", reflect.TypeOf((*MockManager)(nil).ListWorkloadsInGroup), ctx, groupName) } // MoveToGroup mocks base method. -func (m *MockK8SManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom, groupTo string) error { +func (m *MockManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom, groupTo string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "MoveToGroup", ctx, workloadNames, groupFrom, groupTo) ret0, _ := ret[0].(error) @@ -145,7 +145,7 @@ func (m *MockK8SManager) MoveToGroup(ctx context.Context, workloadNames []string } // MoveToGroup indicates an expected call of MoveToGroup. 
-func (mr *MockK8SManagerMockRecorder) MoveToGroup(ctx, workloadNames, groupFrom, groupTo any) *gomock.Call { +func (mr *MockManagerMockRecorder) MoveToGroup(ctx, workloadNames, groupFrom, groupTo any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MoveToGroup", reflect.TypeOf((*MockK8SManager)(nil).MoveToGroup), ctx, workloadNames, groupFrom, groupTo) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MoveToGroup", reflect.TypeOf((*MockManager)(nil).MoveToGroup), ctx, workloadNames, groupFrom, groupTo) } diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index a1d9c6f63..f34fea4f4 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -7,15 +7,8 @@ import ( "fmt" "golang.org/x/sync/errgroup" - "k8s.io/apimachinery/pkg/runtime" - utilruntime "k8s.io/apimachinery/pkg/util/runtime" - clientgoscheme "k8s.io/client-go/kubernetes/scheme" - ctrl "sigs.k8s.io/controller-runtime" - "sigs.k8s.io/controller-runtime/pkg/client" - mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" "github.com/stacklok/toolhive/pkg/config" - "github.com/stacklok/toolhive/pkg/container/kubernetes" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/runner" @@ -66,19 +59,19 @@ var ErrWorkloadNotRunning = fmt.Errorf("workload not running") // NewManager creates a new CLI workload manager. // Returns Manager interface (existing behavior, unchanged). -// IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. +// IMPORTANT: This function only works in CLI mode. For Kubernetes, use k8s.NewManagerFromContext() directly. 
func NewManager(ctx context.Context) (Manager, error) { if rt.IsKubernetesRuntime() { - return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes environments") + return nil, fmt.Errorf("use k8s.NewManagerFromContext() for Kubernetes environments") } return NewCLIManager(ctx) } // NewManagerWithProvider creates a new CLI workload manager with a custom config provider. -// IMPORTANT: This function only works in CLI mode. For Kubernetes, use NewK8SManager() directly. +// IMPORTANT: This function only works in CLI mode. For Kubernetes, use k8s.NewManagerFromContext() directly. func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) (Manager, error) { if rt.IsKubernetesRuntime() { - return nil, fmt.Errorf("use workloads.NewK8SManager() for Kubernetes environments") + return nil, fmt.Errorf("use k8s.NewManagerFromContext() for Kubernetes environments") } return NewCLIManagerWithProvider(ctx, configProvider) } @@ -98,29 +91,3 @@ func NewManagerFromRuntime(rtRuntime rt.Runtime) (Manager, error) { func NewManagerFromRuntimeWithProvider(rtRuntime rt.Runtime, configProvider config.Provider) (Manager, error) { return NewCLIManagerFromRuntimeWithProvider(rtRuntime, configProvider) } - -// NewK8SManagerFromContext creates a Kubernetes-based workload manager from context. -// It automatically sets up the Kubernetes client and detects the namespace. 
-func NewK8SManagerFromContext(_ context.Context) (K8SManager, error) { - // Create a scheme for controller-runtime client - scheme := runtime.NewScheme() - utilruntime.Must(clientgoscheme.AddToScheme(scheme)) - utilruntime.Must(mcpv1alpha1.AddToScheme(scheme)) - - // Get Kubernetes config - cfg, err := ctrl.GetConfig() - if err != nil { - return nil, fmt.Errorf("failed to get Kubernetes config: %w", err) - } - - // Create controller-runtime client - k8sClient, err := client.New(cfg, client.Options{Scheme: scheme}) - if err != nil { - return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) - } - - // Detect namespace - namespace := kubernetes.GetCurrentNamespace() - - return NewK8SManager(k8sClient, namespace) -} From dd7c0edaded4db5a52a66b924487207fb57f374f Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 12 Nov 2025 15:06:38 +0000 Subject: [PATCH 11/16] removed logs fns --- pkg/workloads/k8s/k8s.go | 25 ++------------------- pkg/workloads/k8s/k8s_test.go | 28 ----------------------- pkg/workloads/k8s/manager.go | 10 ++------- pkg/workloads/k8s/mocks/mock_manager.go | 30 ------------------------- 4 files changed, 4 insertions(+), 89 deletions(-) diff --git a/pkg/workloads/k8s/k8s.go b/pkg/workloads/k8s/k8s.go index 7d17783e0..bb2f99904 100644 --- a/pkg/workloads/k8s/k8s.go +++ b/pkg/workloads/k8s/k8s.go @@ -153,29 +153,8 @@ func (k *manager) ListWorkloads(ctx context.Context, listAll bool, labelFilters // - DeleteWorkloads: Use kubectl to delete MCPServer CRDs // - RestartWorkloads: Use kubectl to restart MCPServer CRDs // - UpdateWorkload: Update MCPServer CRD directly - -// GetLogs retrieves logs from the pod associated with the MCPServer. -// Note: This requires a Kubernetes clientset for log streaming. -// For now, this returns an error indicating logs should be retrieved via kubectl. -// TODO: Implement proper log retrieval using clientset or REST client. 
-func (k *manager) GetLogs(_ context.Context, _ string, follow bool) (string, error) { - if follow { - return "", fmt.Errorf("follow mode is not supported. Use 'kubectl logs -f -n %s' to stream logs", k.namespace) - } - return "", fmt.Errorf( - "GetLogs is not fully implemented in Kubernetes mode. Use 'kubectl logs -n %s' to retrieve logs", - k.namespace) -} - -// GetProxyLogs retrieves logs from the proxy container in the pod associated with the MCPServer. -// Note: This requires a Kubernetes clientset for log streaming. -// For now, this returns an error indicating logs should be retrieved via kubectl. -// TODO: Implement proper log retrieval using clientset or REST client. -func (k *manager) GetProxyLogs(_ context.Context, _ string) (string, error) { - return "", fmt.Errorf( - "GetProxyLogs is not fully implemented in Kubernetes mode. Use 'kubectl logs -c proxy -n %s' to retrieve proxy logs", - k.namespace) -} +// - GetLogs: Use 'kubectl logs -n ' to retrieve logs +// - GetProxyLogs: Use 'kubectl logs -c proxy -n ' to retrieve proxy logs // MoveToGroup moves the specified workloads from one group to another. 
func (k *manager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { diff --git a/pkg/workloads/k8s/k8s_test.go b/pkg/workloads/k8s/k8s_test.go index a4af3177e..726e4b867 100644 --- a/pkg/workloads/k8s/k8s_test.go +++ b/pkg/workloads/k8s/k8s_test.go @@ -591,34 +591,6 @@ func TestManager_MoveToGroup(t *testing.T) { } } -func TestManager_NoOpMethods(t *testing.T) { - t.Parallel() - - mockClient := &mockClient{} - mgr := &manager{ - k8sClient: mockClient, - namespace: defaultNamespace, - } - - ctx := context.Background() - - t.Run("GetLogs returns error", func(t *testing.T) { - t.Parallel() - logs, err := mgr.GetLogs(ctx, testWorkload1, false) - require.Error(t, err) - assert.Empty(t, logs) - assert.Contains(t, err.Error(), "not fully implemented") - }) - - t.Run("GetProxyLogs returns error", func(t *testing.T) { - t.Parallel() - logs, err := mgr.GetProxyLogs(ctx, testWorkload1) - require.Error(t, err) - assert.Empty(t, logs) - assert.Contains(t, err.Error(), "not fully implemented") - }) -} - func TestManager_mcpServerToWorkload(t *testing.T) { t.Parallel() diff --git a/pkg/workloads/k8s/manager.go b/pkg/workloads/k8s/manager.go index 29330a73f..a9eb86af5 100644 --- a/pkg/workloads/k8s/manager.go +++ b/pkg/workloads/k8s/manager.go @@ -28,14 +28,6 @@ type Manager interface { // MoveToGroup moves the specified workloads from one group to another by updating their GroupRef MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error - // GetLogs retrieves logs from the pod associated with the MCPServer - // Note: This may not be fully implemented and may return an error - GetLogs(ctx context.Context, containerName string, follow bool) (string, error) - - // GetProxyLogs retrieves logs from the proxy container in the pod associated with the MCPServer - // Note: This may not be fully implemented and may return an error - GetProxyLogs(ctx context.Context, workloadName string) (string, 
error) - // The following operations are not supported in Kubernetes mode (operator manages lifecycle): // - RunWorkload: Workloads are created via MCPServer CRDs // - RunWorkloadDetached: Workloads are created via MCPServer CRDs @@ -43,4 +35,6 @@ type Manager interface { // - DeleteWorkloads: Use kubectl to manage MCPServer CRDs // - RestartWorkloads: Use kubectl to manage MCPServer CRDs // - UpdateWorkload: Update MCPServer CRD directly + // - GetLogs: Use 'kubectl logs -n ' to retrieve logs + // - GetProxyLogs: Use 'kubectl logs -c proxy -n ' to retrieve proxy logs } diff --git a/pkg/workloads/k8s/mocks/mock_manager.go b/pkg/workloads/k8s/mocks/mock_manager.go index c91e73afb..01b8c0dfd 100644 --- a/pkg/workloads/k8s/mocks/mock_manager.go +++ b/pkg/workloads/k8s/mocks/mock_manager.go @@ -56,36 +56,6 @@ func (mr *MockManagerMockRecorder) DoesWorkloadExist(ctx, workloadName any) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesWorkloadExist", reflect.TypeOf((*MockManager)(nil).DoesWorkloadExist), ctx, workloadName) } -// GetLogs mocks base method. -func (m *MockManager) GetLogs(ctx context.Context, containerName string, follow bool) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLogs", ctx, containerName, follow) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetLogs indicates an expected call of GetLogs. -func (mr *MockManagerMockRecorder) GetLogs(ctx, containerName, follow any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogs", reflect.TypeOf((*MockManager)(nil).GetLogs), ctx, containerName, follow) -} - -// GetProxyLogs mocks base method. 
-func (m *MockManager) GetProxyLogs(ctx context.Context, workloadName string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetProxyLogs", ctx, workloadName) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetProxyLogs indicates an expected call of GetProxyLogs. -func (mr *MockManagerMockRecorder) GetProxyLogs(ctx, workloadName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyLogs", reflect.TypeOf((*MockManager)(nil).GetProxyLogs), ctx, workloadName) -} - // GetWorkload mocks base method. func (m *MockManager) GetWorkload(ctx context.Context, workloadName string) (k8s.Workload, error) { m.ctrl.T.Helper() From 0608b7a4ef19ce2378d4685c3a211824bb5cfa3e Mon Sep 17 00:00:00 2001 From: amirejaz Date: Thu, 13 Nov 2025 09:42:09 +0000 Subject: [PATCH 12/16] use pkg/k8s package for client and namespace --- pkg/workloads/k8s/k8s.go | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/pkg/workloads/k8s/k8s.go b/pkg/workloads/k8s/k8s.go index bb2f99904..2023a01f3 100644 --- a/pkg/workloads/k8s/k8s.go +++ b/pkg/workloads/k8s/k8s.go @@ -15,11 +15,10 @@ import ( "k8s.io/apimachinery/pkg/types" utilruntime "k8s.io/apimachinery/pkg/util/runtime" clientgoscheme "k8s.io/client-go/kubernetes/scheme" - ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - "github.com/stacklok/toolhive/pkg/container/kubernetes" + "github.com/stacklok/toolhive/pkg/k8s" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/transport" transporttypes "github.com/stacklok/toolhive/pkg/transport/types" @@ -50,20 +49,14 @@ func NewManagerFromContext(_ context.Context) (Manager, error) { utilruntime.Must(clientgoscheme.AddToScheme(scheme)) utilruntime.Must(mcpv1alpha1.AddToScheme(scheme)) - // Get Kubernetes config - cfg, err := ctrl.GetConfig() - 
if err != nil { - return nil, fmt.Errorf("failed to get Kubernetes config: %w", err) - } - // Create controller-runtime client - k8sClient, err := client.New(cfg, client.Options{Scheme: scheme}) + k8sClient, err := k8s.NewControllerRuntimeClient(scheme) if err != nil { return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) } // Detect namespace - namespace := kubernetes.GetCurrentNamespace() + namespace := k8s.GetCurrentNamespace() return NewManager(k8sClient, namespace) } From ed6b6d9a42845e16ed11ea3ea4777125c763fc22 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Fri, 14 Nov 2025 11:28:19 +0000 Subject: [PATCH 13/16] moved discoverer creation to factory inside pkg --- cmd/vmcp/app/commands.go | 19 ++-------- pkg/vmcp/aggregator/discoverer_factory.go | 44 +++++++++++++++++++++++ 2 files changed, 47 insertions(+), 16 deletions(-) create mode 100644 pkg/vmcp/aggregator/discoverer_factory.go diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index dea6be922..ed4abc36d 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -9,7 +9,6 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" - rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" @@ -20,8 +19,6 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/discovery" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server" - "github.com/stacklok/toolhive/pkg/workloads" - "github.com/stacklok/toolhive/pkg/workloads/k8s" ) var rootCmd = &cobra.Command{ @@ -229,19 +226,9 @@ func discoverBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, } // Create backend discoverer based on runtime environment - var discoverer aggregator.BackendDiscoverer - if rt.IsKubernetesRuntime() { - k8sWorkloadsManager, err := k8s.NewManagerFromContext(ctx) - if err != nil { - return nil, nil, 
fmt.Errorf("failed to create Kubernetes workloads manager: %w", err) - } - discoverer = aggregator.NewK8SBackendDiscoverer(k8sWorkloadsManager, groupsManager, cfg.OutgoingAuth) - } else { - cliWorkloadsManager, err := workloads.NewManager(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to create CLI workloads manager: %w", err) - } - discoverer = aggregator.NewCLIBackendDiscoverer(cliWorkloadsManager, groupsManager, cfg.OutgoingAuth) + discoverer, err := aggregator.NewBackendDiscoverer(ctx, groupsManager, cfg.OutgoingAuth) + if err != nil { + return nil, nil, fmt.Errorf("failed to create backend discoverer: %w", err) } logger.Infof("Discovering backends in group: %s", cfg.Group) diff --git a/pkg/vmcp/aggregator/discoverer_factory.go b/pkg/vmcp/aggregator/discoverer_factory.go new file mode 100644 index 000000000..2f7089268 --- /dev/null +++ b/pkg/vmcp/aggregator/discoverer_factory.go @@ -0,0 +1,44 @@ +package aggregator + +import ( + "context" + "fmt" + + rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/groups" + "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/workloads" + "github.com/stacklok/toolhive/pkg/workloads/k8s" +) + +// NewBackendDiscoverer creates a BackendDiscoverer based on the runtime environment. +// It automatically detects whether to use CLI (Docker/Podman) or Kubernetes discoverer +// and creates the appropriate workloads manager. 
+// +// Parameters: +// - ctx: Context for creating managers +// - groupsManager: Manager for group operations (must already be initialized) +// - authConfig: Outgoing authentication configuration for discovered backends +// +// Returns: +// - BackendDiscoverer: The appropriate discoverer for the current runtime +// - error: If manager creation fails +func NewBackendDiscoverer( + ctx context.Context, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) (BackendDiscoverer, error) { + if rt.IsKubernetesRuntime() { + k8sWorkloadsManager, err := k8s.NewManagerFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create Kubernetes workloads manager: %w", err) + } + return NewK8SBackendDiscoverer(k8sWorkloadsManager, groupsManager, authConfig), nil + } + + cliWorkloadsManager, err := workloads.NewManager(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create CLI workloads manager: %w", err) + } + return NewCLIBackendDiscoverer(cliWorkloadsManager, groupsManager, authConfig), nil +} From 5ed3dff49ab659050976c40e37e109a6a02f45e0 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Mon, 17 Nov 2025 10:58:18 +0000 Subject: [PATCH 14/16] refactor the workload manager --- pkg/vmcp/aggregator/cli_discoverer.go | 155 ---- pkg/vmcp/aggregator/cli_discoverer_test.go | 270 ------- pkg/vmcp/aggregator/discoverer.go | 167 ++++- pkg/vmcp/aggregator/discoverer_factory.go | 44 -- pkg/vmcp/aggregator/discoverer_test.go | 389 ++++++++++ pkg/vmcp/aggregator/k8s_discoverer.go | 160 ----- pkg/vmcp/aggregator/k8s_discoverer_test.go | 324 --------- pkg/vmcp/aggregator/testhelpers_test.go | 58 -- pkg/vmcp/workloads/cli.go | 107 +++ pkg/vmcp/workloads/discoverer.go | 25 + pkg/vmcp/workloads/k8s.go | 220 ++++++ pkg/vmcp/workloads/mocks/mock_discoverer.go | 72 ++ pkg/workloads/k8s/k8s.go | 292 -------- pkg/workloads/k8s/k8s_test.go | 757 -------------------- pkg/workloads/k8s/manager.go | 40 -- pkg/workloads/k8s/mocks/mock_manager.go | 121 ---- 
pkg/workloads/k8s/workload.go | 45 -- pkg/workloads/manager.go | 8 +- 18 files changed, 981 insertions(+), 2273 deletions(-) delete mode 100644 pkg/vmcp/aggregator/cli_discoverer.go delete mode 100644 pkg/vmcp/aggregator/cli_discoverer_test.go delete mode 100644 pkg/vmcp/aggregator/discoverer_factory.go create mode 100644 pkg/vmcp/aggregator/discoverer_test.go delete mode 100644 pkg/vmcp/aggregator/k8s_discoverer.go delete mode 100644 pkg/vmcp/aggregator/k8s_discoverer_test.go create mode 100644 pkg/vmcp/workloads/cli.go create mode 100644 pkg/vmcp/workloads/discoverer.go create mode 100644 pkg/vmcp/workloads/k8s.go create mode 100644 pkg/vmcp/workloads/mocks/mock_discoverer.go delete mode 100644 pkg/workloads/k8s/k8s.go delete mode 100644 pkg/workloads/k8s/k8s_test.go delete mode 100644 pkg/workloads/k8s/manager.go delete mode 100644 pkg/workloads/k8s/mocks/mock_manager.go delete mode 100644 pkg/workloads/k8s/workload.go diff --git a/pkg/vmcp/aggregator/cli_discoverer.go b/pkg/vmcp/aggregator/cli_discoverer.go deleted file mode 100644 index 8c4b3cffc..000000000 --- a/pkg/vmcp/aggregator/cli_discoverer.go +++ /dev/null @@ -1,155 +0,0 @@ -package aggregator - -import ( - "context" - "fmt" - - rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/pkg/workloads" -) - -// cliBackendDiscoverer discovers backend MCP servers from Docker/Podman workloads in a group. -// This is the CLI version of BackendDiscoverer that uses the workloads.Manager. -type cliBackendDiscoverer struct { - workloadsManager workloads.Manager - groupsManager groups.Manager - authConfig *config.OutgoingAuthConfig -} - -// NewCLIBackendDiscoverer creates a new CLI-based backend discoverer. -// It discovers workloads from Docker/Podman containers managed by ToolHive. 
-// -// The authConfig parameter configures authentication for discovered backends. -// If nil, backends will have no authentication configured. -// -// This is the CLI-specific constructor. For Kubernetes workloads, use NewK8SBackendDiscoverer. -func NewCLIBackendDiscoverer( - workloadsManager workloads.Manager, - groupsManager groups.Manager, - authConfig *config.OutgoingAuthConfig, -) BackendDiscoverer { - return &cliBackendDiscoverer{ - workloadsManager: workloadsManager, - groupsManager: groupsManager, - authConfig: authConfig, - } -} - -// Discover finds all backend workloads in the specified group. -// Returns all accessible backends with their health status marked based on workload status. -// The groupRef is the group name (e.g., "engineering-team"). -func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { - logger.Infof("Discovering backends in group %s", groupRef) - - // Verify that the group exists - exists, err := d.groupsManager.Exists(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to check if group exists: %w", err) - } - if !exists { - return nil, fmt.Errorf("group %s not found", groupRef) - } - - // Get all workload names in the group - workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to list workloads in group: %w", err) - } - - if len(workloadNames) == 0 { - logger.Infof("No workloads found in group %s", groupRef) - return []vmcp.Backend{}, nil - } - - logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) - - // Query each workload and convert to backend - var backends []vmcp.Backend - for _, name := range workloadNames { - workload, err := d.workloadsManager.GetWorkload(ctx, name) - if err != nil { - logger.Warnf("Failed to get workload %s: %v, skipping", name, err) - continue - } - - // Skip workloads without a URL (not accessible) - if workload.URL == "" 
{ - logger.Debugf("Skipping workload %s without URL", name) - continue - } - - // Map workload status to backend health status - healthStatus := mapWorkloadStatusToHealth(workload.Status) - - // Convert core.Workload to vmcp.Backend - // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. - // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. - // ProxyMode tells us which transport the vmcp client should use. - transportType := workload.ProxyMode - if transportType == "" { - // Fallback to TransportType if ProxyMode is not set (for direct transports) - transportType = workload.TransportType.String() - } - - backend := vmcp.Backend{ - ID: name, - Name: name, - BaseURL: workload.URL, - TransportType: transportType, - HealthStatus: healthStatus, - Metadata: make(map[string]string), - } - - // Apply authentication configuration if provided - authStrategy, authMetadata := d.authConfig.ResolveForBackend(name) - backend.AuthStrategy = authStrategy - backend.AuthMetadata = authMetadata - if authStrategy != "" { - logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) - } - - // Copy user labels to metadata first - for k, v := range workload.Labels { - backend.Metadata[k] = v - } - - // Set system metadata (these override user labels to prevent conflicts) - backend.Metadata["group"] = groupRef - backend.Metadata["tool_type"] = workload.ToolType - backend.Metadata["workload_status"] = string(workload.Status) - - backends = append(backends, backend) - logger.Debugf("Discovered backend %s: %s (%s) with health status %s", - backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) - } - - if len(backends) == 0 { - logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) - return []vmcp.Backend{}, nil - } - - logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) - return backends, nil -} - -// mapWorkloadStatusToHealth 
converts a workload status to a backend health status. -func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { - switch status { - case rt.WorkloadStatusRunning: - return vmcp.BackendHealthy - case rt.WorkloadStatusUnhealthy: - return vmcp.BackendUnhealthy - case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStopping, rt.WorkloadStatusRemoving: - return vmcp.BackendUnhealthy - case rt.WorkloadStatusStarting, rt.WorkloadStatusUnknown: - return vmcp.BackendUnknown - case rt.WorkloadStatusUnauthenticated: - return vmcp.BackendUnauthenticated - default: - return vmcp.BackendUnknown - } -} diff --git a/pkg/vmcp/aggregator/cli_discoverer_test.go b/pkg/vmcp/aggregator/cli_discoverer_test.go deleted file mode 100644 index 23d8155b7..000000000 --- a/pkg/vmcp/aggregator/cli_discoverer_test.go +++ /dev/null @@ -1,270 +0,0 @@ -package aggregator - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/groups/mocks" - "github.com/stacklok/toolhive/pkg/transport/types" - "github.com/stacklok/toolhive/pkg/vmcp" - workloadmocks "github.com/stacklok/toolhive/pkg/workloads/mocks" -) - -const testGroupName = "test-group" - -func TestCLIBackendDiscoverer_Discover(t *testing.T) { - t.Parallel() - - t.Run("successful discovery with multiple backends", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workload1 := newTestWorkload("workload1", - withToolType("github"), - withLabels(map[string]string{"env": "prod"})) - - workload2 := newTestWorkload("workload2", - withURL("http://localhost:8081/mcp"), - withTransport(types.TransportTypeSSE), - withToolType("jira")) - - 
mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"workload1", "workload2"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 2) - assert.Equal(t, "workload1", backends[0].ID) - assert.Equal(t, "http://localhost:8080/mcp", backends[0].BaseURL) - assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) - assert.Equal(t, "github", backends[0].Metadata["tool_type"]) - assert.Equal(t, "prod", backends[0].Metadata["env"]) - assert.Equal(t, "workload2", backends[1].ID) - assert.Equal(t, "sse", backends[1].TransportType) - }) - - t.Run("discovers workloads with different statuses", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - runningWorkload := newTestWorkload("running-workload") - stoppedWorkload := newTestWorkload("stopped-workload", - withStatus(runtime.WorkloadStatusStopped), - withURL("http://localhost:8081/mcp"), - withTransport(types.TransportTypeSSE)) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return([]string{"running-workload", "stopped-workload"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 2) - assert.Equal(t, "running-workload", backends[0].ID) - assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) - assert.Equal(t, "stopped-workload", backends[1].ID) - assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) - assert.Equal(t, "stopped", backends[1].Metadata["workload_status"]) - }) - - t.Run("filters out workloads without URL", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workloadWithURL := newTestWorkload("workload1") - workloadWithoutURL := newTestWorkload("workload2", withURL("")) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return([]string{"workload1", "workload2"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 1) - assert.Equal(t, "workload1", backends[0].ID) - }) - - t.Run("returns empty list when all workloads lack URLs", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workload1 := newTestWorkload("workload1", withURL("")) - workload2 := newTestWorkload("workload2", withStatus(runtime.WorkloadStatusStopped), withURL("")) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return([]string{"workload1", "workload2"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - assert.Empty(t, backends) - }) - - t.Run("returns error when group does not exist", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), "nonexistent-group") - - require.Error(t, err) - assert.Nil(t, backends) - assert.Contains(t, err.Error(), "not found") - }) - - t.Run("returns error when group check fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.Error(t, err) - assert.Nil(t, backends) - assert.Contains(t, err.Error(), "failed to check if group exists") - }) - - t.Run("returns empty list when group is empty", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) - 
mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), "empty-group") - - require.NoError(t, err) - assert.Empty(t, backends) - }) - - t.Run("discovers all workloads regardless of health status", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - stoppedWorkload := newTestWorkload("stopped1", withStatus(runtime.WorkloadStatusStopped)) - errorWorkload := newTestWorkload("error1", - withStatus(runtime.WorkloadStatusError), - withURL("http://localhost:8081/mcp"), - withTransport(types.TransportTypeSSE)) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"stopped1", "error1"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 2) - assert.Equal(t, vmcp.BackendUnhealthy, backends[0].HealthStatus) - assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) - }) - - t.Run("gracefully handles workload get failures", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - goodWorkload := newTestWorkload("good-workload") - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), 
testGroupName). - Return([]string{"good-workload", "failing-workload"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "good-workload").Return(goodWorkload, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). - Return(core.Workload{}, errors.New("workload query failed")) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 1) - assert.Equal(t, "good-workload", backends[0].ID) - }) - - t.Run("returns error when list workloads fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := workloadmocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return(nil, errors.New("failed to list workloads")) - - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.Error(t, err) - assert.Nil(t, backends) - assert.Contains(t, err.Error(), "failed to list workloads in group") - }) -} diff --git a/pkg/vmcp/aggregator/discoverer.go b/pkg/vmcp/aggregator/discoverer.go index b4cff44d4..aace5ef00 100644 --- a/pkg/vmcp/aggregator/discoverer.go +++ b/pkg/vmcp/aggregator/discoverer.go @@ -1,8 +1,169 @@ // Package aggregator provides platform-specific backend discovery implementations. 
// -// This file serves as a navigation reference for backend discovery implementations: -// - CLI (Docker/Podman): see cli_discoverer.go -// - Kubernetes: see k8s_discoverer.go +// This file contains: +// - Unified backend discoverer implementation (works with both CLI and Kubernetes) +// - Factory function to create BackendDiscoverer based on runtime environment +// - WorkloadDiscoverer interface and implementations are in pkg/vmcp/workloads // // The BackendDiscoverer interface is defined in aggregator.go. package aggregator + +import ( + "context" + "fmt" + + ct "github.com/stacklok/toolhive/pkg/container" + rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/groups" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/vmcp/workloads" + "github.com/stacklok/toolhive/pkg/workloads/statuses" +) + +// backendDiscoverer discovers backend MCP servers using a WorkloadDiscoverer. +// This is a unified discoverer that works with both CLI and Kubernetes workloads. +type backendDiscoverer struct { + workloadsManager workloads.Discoverer + groupsManager groups.Manager + authConfig *config.OutgoingAuthConfig +} + +// NewUnifiedBackendDiscoverer creates a unified backend discoverer that works with both +// CLI and Kubernetes workloads through the WorkloadDiscoverer interface. +// +// The authConfig parameter configures authentication for discovered backends. +// If nil, backends will have no authentication configured. +func NewUnifiedBackendDiscoverer( + workloadsManager workloads.Discoverer, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { + return &backendDiscoverer{ + workloadsManager: workloadsManager, + groupsManager: groupsManager, + authConfig: authConfig, + } +} + +// NewBackendDiscoverer creates a unified BackendDiscoverer based on the runtime environment. 
+// It automatically detects whether to use CLI (Docker/Podman) or Kubernetes workloads +// and creates the appropriate WorkloadDiscoverer implementation. +// +// Parameters: +// - ctx: Context for creating managers +// - groupsManager: Manager for group operations (must already be initialized) +// - authConfig: Outgoing authentication configuration for discovered backends +// +// Returns: +// - BackendDiscoverer: A unified discoverer that works with both CLI and Kubernetes workloads +// - error: If manager creation fails +func NewBackendDiscoverer( + ctx context.Context, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) (BackendDiscoverer, error) { + var workloadDiscoverer workloads.Discoverer + + if rt.IsKubernetesRuntime() { + k8sDiscoverer, err := workloads.NewK8SDiscoverer() + if err != nil { + return nil, fmt.Errorf("failed to create Kubernetes workload discoverer: %w", err) + } + workloadDiscoverer = k8sDiscoverer + } else { + // Create runtime and status manager for CLI workloads + runtime, err := ct.NewFactory().Create(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create runtime: %w", err) + } + + statusManager, err := statuses.NewStatusManager(runtime) + if err != nil { + return nil, fmt.Errorf("failed to create status manager: %w", err) + } + + workloadDiscoverer = workloads.NewCLIDiscoverer(statusManager) + } + + return NewUnifiedBackendDiscoverer(workloadDiscoverer, groupsManager, authConfig), nil +} + +// NewBackendDiscovererWithManager creates a unified BackendDiscoverer with a pre-configured +// WorkloadDiscoverer. This is useful for testing or when you already have a workload manager. +func NewBackendDiscovererWithManager( + workloadManager workloads.Discoverer, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { + return NewUnifiedBackendDiscoverer(workloadManager, groupsManager, authConfig) +} + +// Discover finds all backend workloads in the specified group. 
+// Returns all accessible backends with their health status marked based on workload status. +// The groupRef is the group name (e.g., "engineering-team"). +func (d *backendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { + logger.Infof("Discovering backends in group %s", groupRef) + + // Verify that the group exists + exists, err := d.groupsManager.Exists(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to check if group exists: %w", err) + } + if !exists { + return nil, fmt.Errorf("group %s not found", groupRef) + } + + // Get all workload names in the group + workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) + if err != nil { + return nil, fmt.Errorf("failed to list workloads in group: %w", err) + } + + if len(workloadNames) == 0 { + logger.Infof("No workloads found in group %s", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) + + // Query each workload and convert to backend + var backends []vmcp.Backend + for _, name := range workloadNames { + backend, err := d.workloadsManager.GetWorkload(ctx, name) + if err != nil { + logger.Warnf("Failed to get workload %s: %v, skipping", name, err) + continue + } + + // Skip workloads that are not accessible (GetWorkload returns nil) + if backend == nil { + continue + } + + // Apply authentication configuration if provided + authStrategy, authMetadata := d.authConfig.ResolveForBackend(name) + backend.AuthStrategy = authStrategy + backend.AuthMetadata = authMetadata + if authStrategy != "" { + logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) + } + + // Set group metadata (override user labels to prevent conflicts) + backend.Metadata["group"] = groupRef + + logger.Debugf("Discovered backend %s: %s (%s) with health status %s", + backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) + + backends 
= append(backends, *backend) + } + + if len(backends) == 0 { + logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) + return []vmcp.Backend{}, nil + } + + logger.Infof("Discovered %d backends in group %s", len(backends), groupRef) + return backends, nil +} diff --git a/pkg/vmcp/aggregator/discoverer_factory.go b/pkg/vmcp/aggregator/discoverer_factory.go deleted file mode 100644 index 2f7089268..000000000 --- a/pkg/vmcp/aggregator/discoverer_factory.go +++ /dev/null @@ -1,44 +0,0 @@ -package aggregator - -import ( - "context" - "fmt" - - rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/pkg/workloads" - "github.com/stacklok/toolhive/pkg/workloads/k8s" -) - -// NewBackendDiscoverer creates a BackendDiscoverer based on the runtime environment. -// It automatically detects whether to use CLI (Docker/Podman) or Kubernetes discoverer -// and creates the appropriate workloads manager. 
-// -// Parameters: -// - ctx: Context for creating managers -// - groupsManager: Manager for group operations (must already be initialized) -// - authConfig: Outgoing authentication configuration for discovered backends -// -// Returns: -// - BackendDiscoverer: The appropriate discoverer for the current runtime -// - error: If manager creation fails -func NewBackendDiscoverer( - ctx context.Context, - groupsManager groups.Manager, - authConfig *config.OutgoingAuthConfig, -) (BackendDiscoverer, error) { - if rt.IsKubernetesRuntime() { - k8sWorkloadsManager, err := k8s.NewManagerFromContext(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create Kubernetes workloads manager: %w", err) - } - return NewK8SBackendDiscoverer(k8sWorkloadsManager, groupsManager, authConfig), nil - } - - cliWorkloadsManager, err := workloads.NewManager(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create CLI workloads manager: %w", err) - } - return NewCLIBackendDiscoverer(cliWorkloadsManager, groupsManager, authConfig), nil -} diff --git a/pkg/vmcp/aggregator/discoverer_test.go b/pkg/vmcp/aggregator/discoverer_test.go new file mode 100644 index 000000000..e43e464eb --- /dev/null +++ b/pkg/vmcp/aggregator/discoverer_test.go @@ -0,0 +1,389 @@ +package aggregator + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/core" + "github.com/stacklok/toolhive/pkg/groups/mocks" + "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/vmcp/workloads" + discoverermocks "github.com/stacklok/toolhive/pkg/vmcp/workloads/mocks" + statusmocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" +) + +const testGroupName = "test-group" + +func 
TestBackendDiscoverer_Discover(t *testing.T) { + t.Parallel() + + t.Run("successful discovery with multiple backends", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + backend1 := &vmcp.Backend{ + ID: "workload1", + Name: "workload1", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{ + "tool_type": "github", + "workload_status": "running", + "env": "prod", + }, + } + backend2 := &vmcp.Backend{ + ID: "workload2", + Name: "workload2", + BaseURL: "http://localhost:8081/mcp", + TransportType: "sse", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{ + "tool_type": "jira", + "workload_status": "running", + }, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"workload1", "workload2"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(backend1, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(backend2, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, "workload1", backends[0].ID) + assert.Equal(t, "http://localhost:8080/mcp", backends[0].BaseURL) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "github", backends[0].Metadata["tool_type"]) + assert.Equal(t, "prod", backends[0].Metadata["env"]) + assert.Equal(t, "workload2", backends[1].ID) + assert.Equal(t, "sse", backends[1].TransportType) + }) + + t.Run("discovers workloads with different health statuses", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + healthyBackend := &vmcp.Backend{ + ID: "healthy-workload", + Name: "healthy-workload", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{"workload_status": "running"}, + } + unhealthyBackend := &vmcp.Backend{ + ID: "unhealthy-workload", + Name: "unhealthy-workload", + BaseURL: "http://localhost:8081/mcp", + TransportType: "sse", + HealthStatus: vmcp.BackendUnhealthy, + Metadata: map[string]string{"workload_status": "stopped"}, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"healthy-workload", "unhealthy-workload"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "healthy-workload").Return(healthyBackend, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "unhealthy-workload").Return(unhealthyBackend, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, "healthy-workload", backends[0].ID) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "unhealthy-workload", backends[1].ID) + assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) + assert.Equal(t, "stopped", backends[1].Metadata["workload_status"]) + }) + + t.Run("filters out workloads without URL", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + backendWithURL := &vmcp.Backend{ + ID: "workload1", + Name: "workload1", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{}, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"workload1", "workload2"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(backendWithURL, nil) + // workload2 has no URL, so GetWorkload returns nil + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(nil, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "workload1", backends[0].ID) + }) + + t.Run("returns empty list when all workloads lack URLs", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"workload1", "workload2"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(nil, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(nil, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + assert.Empty(t, backends) + }) + + t.Run("returns error when group does not exist", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), "nonexistent-group") + + require.Error(t, err) + 
assert.Nil(t, backends) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("returns error when group check fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "failed to check if group exists") + }) + + t.Run("returns empty list when group is empty", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), "empty-group") + + require.NoError(t, err) + assert.Empty(t, backends) + }) + + t.Run("gracefully handles workload get failures", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + goodBackend := &vmcp.Backend{ + ID: "good-workload", + Name: "good-workload", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{}, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + 
mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"good-workload", "failing-workload"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "good-workload").Return(goodBackend, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). + Return(nil, errors.New("workload query failed")) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "good-workload", backends[0].ID) + }) + + t.Run("returns error when list workloads fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return(nil, errors.New("failed to list workloads")) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.Error(t, err) + assert.Nil(t, backends) + assert.Contains(t, err.Error(), "failed to list workloads in group") + }) + + t.Run("applies authentication configuration", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockWorkloadDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + backend := &vmcp.Backend{ + ID: "workload1", + Name: "workload1", + BaseURL: "http://localhost:8080/mcp", + TransportType: "streamable-http", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{}, + } + + authConfig := &config.OutgoingAuthConfig{ + Backends: map[string]*config.BackendAuthStrategy{ + "workload1": { + Type: "bearer", + Metadata: map[string]any{ + "token": "test-token", + }, + }, + }, + } + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). + Return([]string{"workload1"}, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(backend, nil) + + discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, authConfig) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "bearer", backends[0].AuthStrategy) + assert.Equal(t, "test-token", backends[0].AuthMetadata["token"]) + }) +} + +// TestCLIWorkloadDiscoverer tests the CLI workload discoverer implementation +// to ensure it correctly converts CLI workloads to backends. 
+func TestCLIWorkloadDiscoverer(t *testing.T) { + t.Parallel() + + t.Run("converts CLI workload to backend correctly", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStatusManager := statusmocks.NewMockStatusManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + workload := newTestWorkload("workload1", + withToolType("github"), + withLabels(map[string]string{"env": "prod"})) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockStatusManager.EXPECT().ListWorkloads(gomock.Any(), true, nil). + Return([]core.Workload{workload}, nil) + mockStatusManager.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload, nil) + + cliDiscoverer := workloads.NewCLIDiscoverer(mockStatusManager) + discoverer := NewUnifiedBackendDiscoverer(cliDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 1) + assert.Equal(t, "workload1", backends[0].ID) + assert.Equal(t, "http://localhost:8080/mcp", backends[0].BaseURL) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, "github", backends[0].Metadata["tool_type"]) + assert.Equal(t, "prod", backends[0].Metadata["env"]) + }) + + t.Run("maps CLI workload statuses to health correctly", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStatusManager := statusmocks.NewMockStatusManager(ctrl) + mockGroups := mocks.NewMockManager(ctrl) + + runningWorkload := newTestWorkload("running-workload") + stoppedWorkload := newTestWorkload("stopped-workload", + withStatus(runtime.WorkloadStatusStopped), + withURL("http://localhost:8081/mcp"), + withTransport(types.TransportTypeSSE)) + + mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) + mockStatusManager.EXPECT().ListWorkloads(gomock.Any(), true, nil). 
+ Return([]core.Workload{runningWorkload, stoppedWorkload}, nil) + mockStatusManager.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) + mockStatusManager.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) + + cliDiscoverer := workloads.NewCLIDiscoverer(mockStatusManager) + discoverer := NewUnifiedBackendDiscoverer(cliDiscoverer, mockGroups, nil) + backends, err := discoverer.Discover(context.Background(), testGroupName) + + require.NoError(t, err) + require.Len(t, backends, 2) + assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) + assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) + }) +} diff --git a/pkg/vmcp/aggregator/k8s_discoverer.go b/pkg/vmcp/aggregator/k8s_discoverer.go deleted file mode 100644 index 10cdd114c..000000000 --- a/pkg/vmcp/aggregator/k8s_discoverer.go +++ /dev/null @@ -1,160 +0,0 @@ -// Package aggregator provides platform-agnostic backend discovery. -// This file contains the Kubernetes-specific discoverer implementation. -package aggregator - -import ( - "context" - "fmt" - - mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - "github.com/stacklok/toolhive/pkg/groups" - "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/pkg/workloads/k8s" -) - -// k8sBackendDiscoverer discovers backend MCP servers from Kubernetes workloads (MCPServer CRDs). -// It works with k8s.Manager and k8s.Workload. -type k8sBackendDiscoverer struct { - workloadsManager k8s.Manager - groupsManager groups.Manager - authConfig *config.OutgoingAuthConfig -} - -// NewK8SBackendDiscoverer creates a new Kubernetes-based backend discoverer. -// It discovers workloads from MCPServer CRDs managed by the ToolHive operator in Kubernetes. -// -// The authConfig parameter configures authentication for discovered backends. 
-// If nil, backends will have no authentication configured. -// -// This is the Kubernetes-specific constructor. For CLI workloads, use NewCLIBackendDiscoverer. -func NewK8SBackendDiscoverer( - workloadsManager k8s.Manager, - groupsManager groups.Manager, - authConfig *config.OutgoingAuthConfig, -) BackendDiscoverer { - return &k8sBackendDiscoverer{ - workloadsManager: workloadsManager, - groupsManager: groupsManager, - authConfig: authConfig, - } -} - -// Discover finds all backend workloads in the specified group. -func (d *k8sBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { - logger.Infof("Discovering Kubernetes backends in group %s", groupRef) - - // Verify that the group exists - exists, err := d.groupsManager.Exists(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to check if group exists: %w", err) - } - if !exists { - return nil, fmt.Errorf("group %s not found", groupRef) - } - - // Get all workload names in the group - workloadNames, err := d.workloadsManager.ListWorkloadsInGroup(ctx, groupRef) - if err != nil { - return nil, fmt.Errorf("failed to list workloads in group: %w", err) - } - - if len(workloadNames) == 0 { - logger.Infof("No workloads found in group %s", groupRef) - return []vmcp.Backend{}, nil - } - - logger.Debugf("Found %d workloads in group %s, discovering backends", len(workloadNames), groupRef) - - // Query each workload and convert to backend - var backends []vmcp.Backend - for _, name := range workloadNames { - workload, err := d.workloadsManager.GetWorkload(ctx, name) - if err != nil { - logger.Warnf("Failed to get workload %s: %v, skipping", name, err) - continue - } - - backend := d.convertK8SWorkload(workload, groupRef) - if backend != nil { - backends = append(backends, *backend) - } - } - - if len(backends) == 0 { - logger.Infof("No accessible backends found in group %s (all workloads lack URLs)", groupRef) - return []vmcp.Backend{}, nil - } - - logger.Infof("Discovered 
%d backends in group %s", len(backends), groupRef) - return backends, nil -} - -// convertK8SWorkload converts a k8s.Workload to a vmcp.Backend. -func (d *k8sBackendDiscoverer) convertK8SWorkload(workload k8s.Workload, groupRef string) *vmcp.Backend { - // Skip workloads without a URL (not accessible) - if workload.URL == "" { - logger.Debugf("Skipping workload %s without URL", workload.Name) - return nil - } - - // Map workload phase to backend health status - healthStatus := mapK8SWorkloadPhaseToHealth(workload.Phase) - - // Convert k8s.Workload to vmcp.Backend - transportType := workload.ProxyMode - if transportType == "" { - // Fallback to TransportType if ProxyMode is not set (for direct transports) - transportType = workload.TransportType.String() - } - - backend := vmcp.Backend{ - ID: workload.Name, - Name: workload.Name, - BaseURL: workload.URL, - TransportType: transportType, - HealthStatus: healthStatus, - Metadata: make(map[string]string), - } - - // Apply authentication configuration if provided - authStrategy, authMetadata := d.authConfig.ResolveForBackend(workload.Name) - backend.AuthStrategy = authStrategy - backend.AuthMetadata = authMetadata - if authStrategy != "" { - logger.Debugf("Backend %s configured with auth strategy: %s", workload.Name, authStrategy) - } - - // Copy user labels to metadata first - for k, v := range workload.Labels { - backend.Metadata[k] = v - } - - // Set system metadata (these override user labels to prevent conflicts) - backend.Metadata["group"] = groupRef - backend.Metadata["tool_type"] = workload.ToolType - backend.Metadata["workload_phase"] = string(workload.Phase) - backend.Metadata["namespace"] = workload.Namespace - - logger.Debugf("Discovered backend %s: %s (%s) with health status %s", - backend.ID, backend.BaseURL, backend.TransportType, backend.HealthStatus) - - return &backend -} - -// mapK8SWorkloadPhaseToHealth converts a MCPServerPhase to a backend health status. 
-func mapK8SWorkloadPhaseToHealth(phase mcpv1alpha1.MCPServerPhase) vmcp.BackendHealthStatus { - switch phase { - case mcpv1alpha1.MCPServerPhaseRunning: - return vmcp.BackendHealthy - case mcpv1alpha1.MCPServerPhaseFailed: - return vmcp.BackendUnhealthy - case mcpv1alpha1.MCPServerPhaseTerminating: - return vmcp.BackendUnhealthy - case mcpv1alpha1.MCPServerPhasePending: - return vmcp.BackendUnknown - default: - return vmcp.BackendUnknown - } -} diff --git a/pkg/vmcp/aggregator/k8s_discoverer_test.go b/pkg/vmcp/aggregator/k8s_discoverer_test.go deleted file mode 100644 index cc1a3eebe..000000000 --- a/pkg/vmcp/aggregator/k8s_discoverer_test.go +++ /dev/null @@ -1,324 +0,0 @@ -package aggregator - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - "github.com/stacklok/toolhive/pkg/groups/mocks" - "github.com/stacklok/toolhive/pkg/transport/types" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/workloads/k8s" - k8smocks "github.com/stacklok/toolhive/pkg/workloads/k8s/mocks" -) - -func TestK8SBackendDiscoverer_Discover(t *testing.T) { - t.Parallel() - - t.Run("successful discovery with multiple backends", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workload1 := newTestK8SWorkload("workload1", - withK8SToolType("github"), - withK8SLabels(map[string]string{"env": "prod"}), - withK8SNamespace("toolhive-system")) - - workload2 := newTestK8SWorkload("workload2", - withK8SURL("http://localhost:8081/mcp"), - withK8STransport(types.TransportTypeSSE), - withK8SToolType("jira"), - withK8SNamespace("toolhive-system")) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - 
mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"workload1", "workload2"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 2) - assert.Equal(t, "workload1", backends[0].ID) - assert.Equal(t, "http://localhost:8080/mcp", backends[0].BaseURL) - assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) - assert.Equal(t, "github", backends[0].Metadata["tool_type"]) - assert.Equal(t, "prod", backends[0].Metadata["env"]) - assert.Equal(t, "toolhive-system", backends[0].Metadata["namespace"]) - assert.Equal(t, "workload2", backends[1].ID) - assert.Equal(t, "sse", backends[1].TransportType) - }) - - t.Run("discovers workloads with different phases", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - runningWorkload := newTestK8SWorkload("running-workload", - withK8SPhase(mcpv1alpha1.MCPServerPhaseRunning)) - failedWorkload := newTestK8SWorkload("failed-workload", - withK8SPhase(mcpv1alpha1.MCPServerPhaseFailed), - withK8SURL("http://localhost:8081/mcp"), - withK8STransport(types.TransportTypeSSE)) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return([]string{"running-workload", "failed-workload"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failed-workload").Return(failedWorkload, nil) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 2) - assert.Equal(t, "running-workload", backends[0].ID) - assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) - assert.Equal(t, "failed-workload", backends[1].ID) - assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) - assert.Equal(t, string(mcpv1alpha1.MCPServerPhaseFailed), backends[1].Metadata["workload_phase"]) - }) - - t.Run("filters out workloads without URL", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workloadWithURL := newTestK8SWorkload("workload1") - workloadWithoutURL := newTestK8SWorkload("workload2", withK8SURL("")) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return([]string{"workload1", "workload2"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 1) - assert.Equal(t, "workload1", backends[0].ID) - }) - - t.Run("returns empty list when all workloads lack URLs", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workload1 := newTestK8SWorkload("workload1", withK8SURL("")) - workload2 := newTestK8SWorkload("workload2", - withK8SPhase(mcpv1alpha1.MCPServerPhaseTerminating), - withK8SURL("")) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return([]string{"workload1", "workload2"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - assert.Empty(t, backends) - }) - - t.Run("returns error when group does not exist", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), "nonexistent-group") - - require.Error(t, err) - assert.Nil(t, backends) - assert.Contains(t, err.Error(), "not found") - }) - - t.Run("returns error when group check fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.Error(t, err) - assert.Nil(t, backends) - assert.Contains(t, err.Error(), "failed to check if group exists") - }) - - t.Run("returns empty list when group is empty", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) - 
mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), "empty-group") - - require.NoError(t, err) - assert.Empty(t, backends) - }) - - t.Run("discovers all workloads regardless of phase", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - terminatingWorkload := newTestK8SWorkload("terminating1", - withK8SPhase(mcpv1alpha1.MCPServerPhaseTerminating)) - failedWorkload := newTestK8SWorkload("failed1", - withK8SPhase(mcpv1alpha1.MCPServerPhaseFailed), - withK8SURL("http://localhost:8081/mcp"), - withK8STransport(types.TransportTypeSSE)) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return([]string{"terminating1", "failed1"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "terminating1").Return(terminatingWorkload, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failed1").Return(failedWorkload, nil) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 2) - assert.Equal(t, vmcp.BackendUnhealthy, backends[0].HealthStatus) - assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) - }) - - t.Run("gracefully handles workload get failures", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - goodWorkload := newTestK8SWorkload("good-workload") - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"good-workload", "failing-workload"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "good-workload").Return(goodWorkload, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). - Return(k8s.Workload{}, errors.New("MCPServer query failed")) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 1) - assert.Equal(t, "good-workload", backends[0].ID) - }) - - t.Run("returns error when list workloads fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return(nil, errors.New("failed to list workloads")) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.Error(t, err) - assert.Nil(t, backends) - assert.Contains(t, err.Error(), "failed to list workloads in group") - }) - - t.Run("handles pending phase correctly", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - pendingWorkload := newTestK8SWorkload("pending-workload", - withK8SPhase(mcpv1alpha1.MCPServerPhasePending)) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). - Return([]string{"pending-workload"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "pending-workload").Return(pendingWorkload, nil) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 1) - assert.Equal(t, vmcp.BackendUnknown, backends[0].HealthStatus) - assert.Equal(t, string(mcpv1alpha1.MCPServerPhasePending), backends[0].Metadata["workload_phase"]) - }) - - t.Run("includes namespace in metadata", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockWorkloads := k8smocks.NewMockManager(ctrl) - mockGroups := mocks.NewMockManager(ctrl) - - workload := newTestK8SWorkload("workload1", - withK8SNamespace("custom-namespace")) - - mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
- Return([]string{"workload1"}, nil) - mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload, nil) - - discoverer := NewK8SBackendDiscoverer(mockWorkloads, mockGroups, nil) - backends, err := discoverer.Discover(context.Background(), testGroupName) - - require.NoError(t, err) - require.Len(t, backends, 1) - assert.Equal(t, "custom-namespace", backends[0].Metadata["namespace"]) - }) -} diff --git a/pkg/vmcp/aggregator/testhelpers_test.go b/pkg/vmcp/aggregator/testhelpers_test.go index a25d5b767..0b766c508 100644 --- a/pkg/vmcp/aggregator/testhelpers_test.go +++ b/pkg/vmcp/aggregator/testhelpers_test.go @@ -1,12 +1,10 @@ package aggregator import ( - mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/workloads/k8s" ) // Test fixture builders to reduce verbosity in tests @@ -55,62 +53,6 @@ func withLabels(labels map[string]string) func(*core.Workload) { } } -// K8s workload test helpers - -func newTestK8SWorkload(name string, opts ...func(*k8s.Workload)) k8s.Workload { - w := k8s.Workload{ - Name: name, - Namespace: "default", - Phase: mcpv1alpha1.MCPServerPhaseRunning, - URL: "http://localhost:8080/mcp", - TransportType: types.TransportTypeStreamableHTTP, - ToolType: "mcp", - Group: testGroupName, - GroupRef: testGroupName, - Labels: make(map[string]string), - } - for _, opt := range opts { - opt(&w) - } - return w -} - -func withK8SPhase(phase mcpv1alpha1.MCPServerPhase) func(*k8s.Workload) { - return func(w *k8s.Workload) { - w.Phase = phase - } -} - -func withK8SURL(url string) func(*k8s.Workload) { - return func(w *k8s.Workload) { - w.URL = url - } -} - -func withK8STransport(transport types.TransportType) func(*k8s.Workload) { - return func(w *k8s.Workload) { - w.TransportType = transport - } -} - 
-func withK8SToolType(toolType string) func(*k8s.Workload) { - return func(w *k8s.Workload) { - w.ToolType = toolType - } -} - -func withK8SLabels(labels map[string]string) func(*k8s.Workload) { - return func(w *k8s.Workload) { - w.Labels = labels - } -} - -func withK8SNamespace(namespace string) func(*k8s.Workload) { - return func(w *k8s.Workload) { - w.Namespace = namespace - } -} - func newTestBackend(id string, opts ...func(*vmcp.Backend)) vmcp.Backend { b := vmcp.Backend{ ID: id, diff --git a/pkg/vmcp/workloads/cli.go b/pkg/vmcp/workloads/cli.go new file mode 100644 index 000000000..d5e2b0ca8 --- /dev/null +++ b/pkg/vmcp/workloads/cli.go @@ -0,0 +1,107 @@ +package workloads + +import ( + "context" + + rt "github.com/stacklok/toolhive/pkg/container/runtime" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/workloads/statuses" +) + +// cliDiscoverer is a direct implementation of Discoverer for CLI workloads. +// It uses the status manager directly instead of going through workloads.Manager. +type cliDiscoverer struct { + statusManager statuses.StatusManager +} + +// NewCLIDiscoverer creates a new CLI workload discoverer that directly uses +// the status manager to discover workloads. +func NewCLIDiscoverer(statusManager statuses.StatusManager) Discoverer { + return &cliDiscoverer{ + statusManager: statusManager, + } +} + +// ListWorkloadsInGroup returns all workload names that belong to the specified group. 
+func (d *cliDiscoverer) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { + // List all workloads (including stopped ones) + workloads, err := d.statusManager.ListWorkloads(ctx, true, nil) + if err != nil { + return nil, err + } + + // Filter workloads that belong to the specified group + var groupWorkloads []string + for _, workload := range workloads { + if workload.Group == groupName { + groupWorkloads = append(groupWorkloads, workload.Name) + } + } + + return groupWorkloads, nil +} + +// GetWorkload retrieves workload details by name and converts it to a vmcp.Backend. +func (d *cliDiscoverer) GetWorkload(ctx context.Context, workloadName string) (*vmcp.Backend, error) { + workload, err := d.statusManager.GetWorkload(ctx, workloadName) + if err != nil { + return nil, err + } + + // Skip workloads without a URL (not accessible) + if workload.URL == "" { + logger.Debugf("Skipping workload %s without URL", workloadName) + return nil, nil + } + + // Map workload status to backend health status + healthStatus := mapCLIWorkloadStatusToHealth(workload.Status) + + // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. + // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. + // ProxyMode tells us which transport the vmcp client should use. 
+ transportType := workload.ProxyMode + if transportType == "" { + // Fallback to TransportType if ProxyMode is not set (for direct transports) + transportType = workload.TransportType.String() + } + + backend := &vmcp.Backend{ + ID: workload.Name, + Name: workload.Name, + BaseURL: workload.URL, + TransportType: transportType, + HealthStatus: healthStatus, + Metadata: make(map[string]string), + } + + // Copy user labels to metadata first + for k, v := range workload.Labels { + backend.Metadata[k] = v + } + + // Set system metadata (these override user labels to prevent conflicts) + backend.Metadata["tool_type"] = workload.ToolType + backend.Metadata["workload_status"] = string(workload.Status) + + return backend, nil +} + +// mapCLIWorkloadStatusToHealth converts a CLI WorkloadStatus to a backend health status. +func mapCLIWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { + switch status { + case rt.WorkloadStatusRunning: + return vmcp.BackendHealthy + case rt.WorkloadStatusUnhealthy: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStopping, rt.WorkloadStatusRemoving: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStarting, rt.WorkloadStatusUnknown: + return vmcp.BackendUnknown + case rt.WorkloadStatusUnauthenticated: + return vmcp.BackendUnauthenticated + default: + return vmcp.BackendUnknown + } +} diff --git a/pkg/vmcp/workloads/discoverer.go b/pkg/vmcp/workloads/discoverer.go new file mode 100644 index 000000000..0be4e89cc --- /dev/null +++ b/pkg/vmcp/workloads/discoverer.go @@ -0,0 +1,25 @@ +// Package workloads provides the WorkloadDiscoverer interface for discovering +// backend workloads in both CLI and Kubernetes environments. +package workloads + +import ( + "context" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// Discoverer is the interface for workload managers used by vmcp. 
+// This interface contains only the methods needed for backend discovery, +// allowing both CLI and Kubernetes managers to implement it. +// +//go:generate mockgen -destination=mocks/mock_discoverer.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/workloads Discoverer +type Discoverer interface { + // ListWorkloadsInGroup returns all workload names that belong to the specified group + ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) + + // GetWorkload retrieves workload details by name and converts it to a vmcp.Backend. + // The returned Backend should have all fields populated except AuthStrategy and AuthMetadata, + // which will be set by the discoverer based on the auth configuration. + // Returns nil if the workload exists but is not accessible (e.g., no URL). + GetWorkload(ctx context.Context, workloadName string) (*vmcp.Backend, error) +} diff --git a/pkg/vmcp/workloads/k8s.go b/pkg/vmcp/workloads/k8s.go new file mode 100644 index 000000000..834782f9e --- /dev/null +++ b/pkg/vmcp/workloads/k8s.go @@ -0,0 +1,220 @@ +package workloads + +import ( + "context" + "fmt" + "strings" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + "github.com/stacklok/toolhive/pkg/k8s" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/transport" + transporttypes "github.com/stacklok/toolhive/pkg/transport/types" + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// k8sDiscoverer is a direct implementation of Discoverer for Kubernetes workloads. +// It uses the Kubernetes client directly to query MCPServer CRDs instead of going through k8s.Manager. 
+type k8sDiscoverer struct { + k8sClient client.Client + namespace string +} + +// NewK8SDiscoverer creates a new Kubernetes workload discoverer that directly uses +// the Kubernetes client to discover MCPServer CRDs. +func NewK8SDiscoverer() (Discoverer, error) { + // Create a scheme for controller-runtime client + scheme := runtime.NewScheme() + utilruntime.Must(clientgoscheme.AddToScheme(scheme)) + utilruntime.Must(mcpv1alpha1.AddToScheme(scheme)) + + // Create controller-runtime client + k8sClient, err := k8s.NewControllerRuntimeClient(scheme) + if err != nil { + return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) + } + + // Detect namespace + namespace := k8s.GetCurrentNamespace() + + return &k8sDiscoverer{ + k8sClient: k8sClient, + namespace: namespace, + }, nil +} + +// ListWorkloadsInGroup returns all workload names that belong to the specified group. +func (d *k8sDiscoverer) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { + mcpServerList := &mcpv1alpha1.MCPServerList{} + listOpts := []client.ListOption{ + client.InNamespace(d.namespace), + } + + if err := d.k8sClient.List(ctx, mcpServerList, listOpts...); err != nil { + return nil, fmt.Errorf("failed to list MCPServers: %w", err) + } + + var groupWorkloads []string + for i := range mcpServerList.Items { + mcpServer := &mcpServerList.Items[i] + if mcpServer.Spec.GroupRef == groupName { + groupWorkloads = append(groupWorkloads, mcpServer.Name) + } + } + + return groupWorkloads, nil +} + +// GetWorkload retrieves workload details by name and converts it to a vmcp.Backend. 
+func (d *k8sDiscoverer) GetWorkload(ctx context.Context, workloadName string) (*vmcp.Backend, error) { + mcpServer := &mcpv1alpha1.MCPServer{} + key := client.ObjectKey{Name: workloadName, Namespace: d.namespace} + if err := d.k8sClient.Get(ctx, key, mcpServer); err != nil { + if errors.IsNotFound(err) { + return nil, fmt.Errorf("MCPServer %s not found", workloadName) + } + return nil, fmt.Errorf("failed to get MCPServer: %w", err) + } + + // Convert MCPServer to Backend + backend := d.mcpServerToBackend(mcpServer) + + // Skip workloads without a URL (not accessible) + if backend.BaseURL == "" { + logger.Debugf("Skipping workload %s without URL", workloadName) + return nil, nil + } + + return backend, nil +} + +// mcpServerToBackend converts an MCPServer CRD to a vmcp.Backend. +func (*k8sDiscoverer) mcpServerToBackend(mcpServer *mcpv1alpha1.MCPServer) *vmcp.Backend { + // Parse transport type + transportType, err := transporttypes.ParseTransportType(mcpServer.Spec.Transport) + if err != nil { + logger.Warnf("Failed to parse transport type %s for MCPServer %s: %v", mcpServer.Spec.Transport, mcpServer.Name, err) + transportType = transporttypes.TransportTypeStreamableHTTP + } + + // Calculate effective proxy mode + effectiveProxyMode := getEffectiveProxyMode(transportType, mcpServer.Spec.ProxyMode) + + // Generate URL from status or reconstruct from spec + url := mcpServer.Status.URL + if url == "" { + port := int(mcpServer.Spec.ProxyPort) + if port == 0 { + port = int(mcpServer.Spec.Port) // Fallback to deprecated Port field + } + if port > 0 { + url = transport.GenerateMCPServerURL(mcpServer.Spec.Transport, transport.LocalhostIPv4, port, mcpServer.Name, "") + } + } + + // Map workload phase to backend health status + healthStatus := mapK8SWorkloadPhaseToHealth(mcpServer.Status.Phase) + + // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. + // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. 
+ // ProxyMode tells us which transport the vmcp client should use. + transportTypeStr := effectiveProxyMode + if transportTypeStr == "" { + // Fallback to TransportType if ProxyMode is not set (for direct transports) + transportTypeStr = transportType.String() + if transportTypeStr == "" { + transportTypeStr = "unknown" + } + } + + // Extract user labels from annotations (Kubernetes doesn't have container labels like Docker) + userLabels := make(map[string]string) + if mcpServer.Annotations != nil { + // Filter out standard Kubernetes annotations + for key, value := range mcpServer.Annotations { + if !isStandardK8sAnnotation(key) { + userLabels[key] = value + } + } + } + + backend := &vmcp.Backend{ + ID: mcpServer.Name, + Name: mcpServer.Name, + BaseURL: url, + TransportType: transportTypeStr, + HealthStatus: healthStatus, + Metadata: make(map[string]string), + } + + // Copy user labels to metadata first + for k, v := range userLabels { + backend.Metadata[k] = v + } + + // Set system metadata (these override user labels to prevent conflicts) + backend.Metadata["tool_type"] = "mcp" + backend.Metadata["workload_status"] = string(mcpServer.Status.Phase) + if mcpServer.Namespace != "" { + backend.Metadata["namespace"] = mcpServer.Namespace + } + + return backend +} + +// mapK8SWorkloadPhaseToHealth converts a MCPServerPhase to a backend health status. +func mapK8SWorkloadPhaseToHealth(phase mcpv1alpha1.MCPServerPhase) vmcp.BackendHealthStatus { + switch phase { + case mcpv1alpha1.MCPServerPhaseRunning: + return vmcp.BackendHealthy + case mcpv1alpha1.MCPServerPhaseFailed: + return vmcp.BackendUnhealthy + case mcpv1alpha1.MCPServerPhaseTerminating: + return vmcp.BackendUnhealthy + case mcpv1alpha1.MCPServerPhasePending: + return vmcp.BackendUnknown + default: + return vmcp.BackendUnknown + } +} + +// getEffectiveProxyMode calculates the effective proxy mode based on transport type and configured proxy mode. 
+// This replicates the logic from pkg/workloads/types/proxy_mode.go +func getEffectiveProxyMode(transportType transporttypes.TransportType, configuredProxyMode string) string { + // If proxy mode is explicitly configured, use it + if configuredProxyMode != "" { + return configuredProxyMode + } + + // For stdio transports, default to streamable-http proxy mode + if transportType == transporttypes.TransportTypeStdio { + return transporttypes.ProxyModeStreamableHTTP.String() + } + + // For direct transports (SSE, streamable-http), use the transport type as proxy mode + return transportType.String() +} + +// isStandardK8sAnnotation checks if an annotation key is a standard Kubernetes annotation. +func isStandardK8sAnnotation(key string) bool { + // Common Kubernetes annotation prefixes + standardPrefixes := []string{ + "kubectl.kubernetes.io/", + "kubernetes.io/", + "deployment.kubernetes.io/", + "k8s.io/", + } + + for _, prefix := range standardPrefixes { + if strings.HasPrefix(key, prefix) { + return true + } + } + return false +} diff --git a/pkg/vmcp/workloads/mocks/mock_discoverer.go b/pkg/vmcp/workloads/mocks/mock_discoverer.go new file mode 100644 index 000000000..7d3719a76 --- /dev/null +++ b/pkg/vmcp/workloads/mocks/mock_discoverer.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/stacklok/toolhive/pkg/vmcp/workloads (interfaces: Discoverer) +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_discoverer.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/workloads Discoverer +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + vmcp "github.com/stacklok/toolhive/pkg/vmcp" + gomock "go.uber.org/mock/gomock" +) + +// MockDiscoverer is a mock of Discoverer interface. 
+type MockDiscoverer struct { + ctrl *gomock.Controller + recorder *MockDiscovererMockRecorder + isgomock struct{} +} + +// MockDiscovererMockRecorder is the mock recorder for MockDiscoverer. +type MockDiscovererMockRecorder struct { + mock *MockDiscoverer +} + +// NewMockDiscoverer creates a new mock instance. +func NewMockDiscoverer(ctrl *gomock.Controller) *MockDiscoverer { + mock := &MockDiscoverer{ctrl: ctrl} + mock.recorder = &MockDiscovererMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDiscoverer) EXPECT() *MockDiscovererMockRecorder { + return m.recorder +} + +// GetWorkload mocks base method. +func (m *MockDiscoverer) GetWorkload(ctx context.Context, workloadName string) (*vmcp.Backend, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWorkload", ctx, workloadName) + ret0, _ := ret[0].(*vmcp.Backend) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetWorkload indicates an expected call of GetWorkload. +func (mr *MockDiscovererMockRecorder) GetWorkload(ctx, workloadName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkload", reflect.TypeOf((*MockDiscoverer)(nil).GetWorkload), ctx, workloadName) +} + +// ListWorkloadsInGroup mocks base method. +func (m *MockDiscoverer) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListWorkloadsInGroup", ctx, groupName) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListWorkloadsInGroup indicates an expected call of ListWorkloadsInGroup. 
+func (mr *MockDiscovererMockRecorder) ListWorkloadsInGroup(ctx, groupName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloadsInGroup", reflect.TypeOf((*MockDiscoverer)(nil).ListWorkloadsInGroup), ctx, groupName) +} diff --git a/pkg/workloads/k8s/k8s.go b/pkg/workloads/k8s/k8s.go deleted file mode 100644 index 2023a01f3..000000000 --- a/pkg/workloads/k8s/k8s.go +++ /dev/null @@ -1,292 +0,0 @@ -// Package k8s provides Kubernetes-specific workload management. -// This file contains the Kubernetes implementation for operator environments. -package k8s - -import ( - "context" - "fmt" - "strings" - "time" - - "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/labels" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/selection" - "k8s.io/apimachinery/pkg/types" - utilruntime "k8s.io/apimachinery/pkg/util/runtime" - clientgoscheme "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client" - - mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - "github.com/stacklok/toolhive/pkg/k8s" - "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/transport" - transporttypes "github.com/stacklok/toolhive/pkg/transport/types" - workloadtypes "github.com/stacklok/toolhive/pkg/workloads/types" -) - -// manager implements the Manager interface for Kubernetes environments. -// In Kubernetes, the operator manages workload lifecycle via MCPServer CRDs. -// This manager provides read-only operations and CRD-based storage. -type manager struct { - k8sClient client.Client - namespace string -} - -// NewManager creates a new Kubernetes-based workload manager. -func NewManager(k8sClient client.Client, namespace string) (Manager, error) { - return &manager{ - k8sClient: k8sClient, - namespace: namespace, - }, nil -} - -// NewManagerFromContext creates a Kubernetes-based workload manager from context. 
-// It automatically sets up the Kubernetes client and detects the namespace. -func NewManagerFromContext(_ context.Context) (Manager, error) { - // Create a scheme for controller-runtime client - scheme := runtime.NewScheme() - utilruntime.Must(clientgoscheme.AddToScheme(scheme)) - utilruntime.Must(mcpv1alpha1.AddToScheme(scheme)) - - // Create controller-runtime client - k8sClient, err := k8s.NewControllerRuntimeClient(scheme) - if err != nil { - return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) - } - - // Detect namespace - namespace := k8s.GetCurrentNamespace() - - return NewManager(k8sClient, namespace) -} - -func (k *manager) GetWorkload(ctx context.Context, workloadName string) (Workload, error) { - mcpServer := &mcpv1alpha1.MCPServer{} - key := types.NamespacedName{Name: workloadName, Namespace: k.namespace} - if err := k.k8sClient.Get(ctx, key, mcpServer); err != nil { - if errors.IsNotFound(err) { - return Workload{}, fmt.Errorf("MCPServer %s not found", workloadName) - } - return Workload{}, fmt.Errorf("failed to get MCPServer: %w", err) - } - - return k.mcpServerToWorkload(mcpServer) -} - -func (k *manager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { - mcpServer := &mcpv1alpha1.MCPServer{} - key := types.NamespacedName{Name: workloadName, Namespace: k.namespace} - if err := k.k8sClient.Get(ctx, key, mcpServer); err != nil { - if errors.IsNotFound(err) { - return false, nil - } - return false, fmt.Errorf("failed to check if workload exists: %w", err) - } - return true, nil -} - -func (k *manager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]Workload, error) { - mcpServerList := &mcpv1alpha1.MCPServerList{} - listOpts := []client.ListOption{ - client.InNamespace(k.namespace), - } - - // Parse label filters if provided - if len(labelFilters) > 0 { - parsedFilters, err := workloadtypes.ParseLabelFilters(labelFilters) - if err != nil { - return nil, fmt.Errorf("failed to parse 
label filters: %w", err) - } - - // Build label selector from filters (equality matching) - labelSelector := labels.NewSelector() - for key, value := range parsedFilters { - requirement, err := labels.NewRequirement(key, selection.Equals, []string{value}) - if err != nil { - return nil, fmt.Errorf("failed to create label requirement: %w", err) - } - labelSelector = labelSelector.Add(*requirement) - } - listOpts = append(listOpts, client.MatchingLabelsSelector{Selector: labelSelector}) - } - - if err := k.k8sClient.List(ctx, mcpServerList, listOpts...); err != nil { - return nil, fmt.Errorf("failed to list MCPServers: %w", err) - } - - var workloads []Workload - for i := range mcpServerList.Items { - mcpServer := &mcpServerList.Items[i] - - // Filter by status if listAll is false - if !listAll { - phase := mcpServer.Status.Phase - if phase != mcpv1alpha1.MCPServerPhaseRunning { - continue - } - } - - workload, err := k.mcpServerToWorkload(mcpServer) - if err != nil { - logger.Warnf("Failed to convert MCPServer %s to workload: %v", mcpServer.Name, err) - continue - } - - workloads = append(workloads, workload) - } - - return workloads, nil -} - -// Note: The following operations are not part of Manager interface: -// - StopWorkloads: Use kubectl to manage MCPServer CRDs -// - RunWorkload: Create MCPServer CRD instead -// - RunWorkloadDetached: Create MCPServer CRD instead -// - DeleteWorkloads: Use kubectl to delete MCPServer CRDs -// - RestartWorkloads: Use kubectl to restart MCPServer CRDs -// - UpdateWorkload: Update MCPServer CRD directly -// - GetLogs: Use 'kubectl logs -n ' to retrieve logs -// - GetProxyLogs: Use 'kubectl logs -c proxy -n ' to retrieve proxy logs - -// MoveToGroup moves the specified workloads from one group to another. 
-func (k *manager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { - for _, name := range workloadNames { - mcpServer := &mcpv1alpha1.MCPServer{} - key := types.NamespacedName{Name: name, Namespace: k.namespace} - if err := k.k8sClient.Get(ctx, key, mcpServer); err != nil { - if errors.IsNotFound(err) { - return fmt.Errorf("MCPServer %s not found", name) - } - return fmt.Errorf("failed to get MCPServer: %w", err) - } - - // Verify the workload is in the expected group - if mcpServer.Spec.GroupRef != groupFrom { - return fmt.Errorf("workload %s is not in group %s (current group: %s)", name, groupFrom, mcpServer.Spec.GroupRef) - } - - // Update the group - mcpServer.Spec.GroupRef = groupTo - - // Update the MCPServer - if err := k.k8sClient.Update(ctx, mcpServer); err != nil { - return fmt.Errorf("failed to update MCPServer %s: %w", name, err) - } - } - - return nil -} - -// ListWorkloadsInGroup returns all workload names that belong to the specified group. -func (k *manager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { - mcpServerList := &mcpv1alpha1.MCPServerList{} - listOpts := []client.ListOption{ - client.InNamespace(k.namespace), - } - - if err := k.k8sClient.List(ctx, mcpServerList, listOpts...); err != nil { - return nil, fmt.Errorf("failed to list MCPServers: %w", err) - } - - var groupWorkloads []string - for i := range mcpServerList.Items { - mcpServer := &mcpServerList.Items[i] - if mcpServer.Spec.GroupRef == groupName { - groupWorkloads = append(groupWorkloads, mcpServer.Name) - } - } - - return groupWorkloads, nil -} - -// mcpServerToWorkload converts an MCPServer CRD to a Workload. 
-func (k *manager) mcpServerToWorkload(mcpServer *mcpv1alpha1.MCPServer) (Workload, error) { - // Parse transport type - transportType, err := transporttypes.ParseTransportType(mcpServer.Spec.Transport) - if err != nil { - logger.Warnf("Failed to parse transport type %s for MCPServer %s: %v", mcpServer.Spec.Transport, mcpServer.Name, err) - transportType = transporttypes.TransportTypeSSE - } - - // Calculate effective proxy mode - effectiveProxyMode := workloadtypes.GetEffectiveProxyMode(transportType, mcpServer.Spec.ProxyMode) - - // Generate URL from status or reconstruct from spec - url := mcpServer.Status.URL - if url == "" { - port := int(mcpServer.Spec.ProxyPort) - if port == 0 { - port = int(mcpServer.Spec.Port) // Fallback to deprecated Port field - } - if port > 0 { - url = transport.GenerateMCPServerURL(mcpServer.Spec.Transport, transport.LocalhostIPv4, port, mcpServer.Name, "") - } - } - - port := int(mcpServer.Spec.ProxyPort) - if port == 0 { - port = int(mcpServer.Spec.Port) // Fallback to deprecated Port field - } - - // Get tools filter from spec - toolsFilter := mcpServer.Spec.ToolsFilter - if mcpServer.Spec.ToolConfigRef != nil { - // If ToolConfigRef is set, we can't reconstruct the tools filter here - // The tools filter would be resolved by the operator - toolsFilter = []string{} - } - - // Extract user labels from annotations (Kubernetes doesn't have container labels like Docker) - userLabels := make(map[string]string) - if mcpServer.Annotations != nil { - // Filter out standard Kubernetes annotations - for key, value := range mcpServer.Annotations { - if !k.isStandardK8sAnnotation(key) { - userLabels[key] = value - } - } - } - - // Get creation timestamp - createdAt := mcpServer.CreationTimestamp.Time - if createdAt.IsZero() { - createdAt = time.Now() - } - - return Workload{ - Name: mcpServer.Name, - Namespace: mcpServer.Namespace, - Package: mcpServer.Spec.Image, - URL: url, - ToolType: "mcp", - TransportType: transportType, - ProxyMode: 
effectiveProxyMode, - Phase: mcpServer.Status.Phase, - StatusContext: mcpServer.Status.Message, - CreatedAt: createdAt, - Port: port, - Labels: userLabels, - Group: mcpServer.Spec.GroupRef, - GroupRef: mcpServer.Spec.GroupRef, - ToolsFilter: toolsFilter, - }, nil -} - -// isStandardK8sAnnotation checks if an annotation key is a standard Kubernetes annotation. -func (*manager) isStandardK8sAnnotation(key string) bool { - // Common Kubernetes annotation prefixes - standardPrefixes := []string{ - "kubectl.kubernetes.io/", - "kubernetes.io/", - "deployment.kubernetes.io/", - "k8s.io/", - } - - for _, prefix := range standardPrefixes { - if strings.HasPrefix(key, prefix) { - return true - } - } - return false -} diff --git a/pkg/workloads/k8s/k8s_test.go b/pkg/workloads/k8s/k8s_test.go deleted file mode 100644 index 726e4b867..000000000 --- a/pkg/workloads/k8s/k8s_test.go +++ /dev/null @@ -1,757 +0,0 @@ -package k8s - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - k8serrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime/schema" - "sigs.k8s.io/controller-runtime/pkg/client" - - mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" -) - -const ( - defaultNamespace = "default" - testWorkload1 = "workload1" -) - -// mockClient is a mock implementation of client.Client for testing -type mockClient struct { - client.Client - getFunc func(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error - listFunc func(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error - updateFunc func(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error -} - -func (m *mockClient) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { - if m.getFunc != nil { - return m.getFunc(ctx, key, obj, opts...) 
- } - return k8serrors.NewNotFound(schema.GroupResource{}, key.Name) -} - -func (m *mockClient) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { - if m.listFunc != nil { - return m.listFunc(ctx, list, opts...) - } - return nil -} - -func (m *mockClient) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { - if m.updateFunc != nil { - return m.updateFunc(ctx, obj, opts...) - } - return nil -} - -func TestNewManager(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - k8sClient client.Client - namespace string - wantError bool - }{ - { - name: "successful creation", - k8sClient: &mockClient{}, - namespace: defaultNamespace, - wantError: false, - }, - { - name: "empty namespace", - k8sClient: &mockClient{}, - namespace: "", - wantError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - k8sManager, err := NewManager(tt.k8sClient, tt.namespace) - - if tt.wantError { - require.Error(t, err) - assert.Nil(t, k8sManager) - } else { - require.NoError(t, err) - require.NotNil(t, k8sManager) - - mgr, ok := k8sManager.(*manager) - require.True(t, ok) - assert.Equal(t, tt.k8sClient, mgr.k8sClient) - assert.Equal(t, tt.namespace, mgr.namespace) - } - }) - } -} - -func TestManager_GetWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - setupMock func(*mockClient) - wantError bool - errorMsg string - expected Workload - }{ - { - name: "successful get", - workloadName: "test-workload", - setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { - if mcpServer, ok := obj.(*mcpv1alpha1.MCPServer); ok { - mcpServer.Name = "test-workload" - mcpServer.Namespace = defaultNamespace - mcpServer.Status.Phase = mcpv1alpha1.MCPServerPhaseRunning - mcpServer.Spec.Transport = "streamable-http" - mcpServer.Spec.ProxyPort = 8080 
- mcpServer.Annotations = map[string]string{ - "group": "test-group", - } - } - return nil - } - }, - wantError: false, - expected: Workload{ - Name: "test-workload", - Namespace: defaultNamespace, - Phase: mcpv1alpha1.MCPServerPhaseRunning, - URL: "http://127.0.0.1:8080/mcp", // URL is generated from spec - Labels: map[string]string{ - "group": "test-group", - }, - }, - }, - { - name: "workload not found", - workloadName: "non-existent", - setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, key client.ObjectKey, _ client.Object, _ ...client.GetOption) error { - return k8serrors.NewNotFound(schema.GroupResource{Resource: "mcpservers"}, key.Name) - } - }, - wantError: true, - errorMsg: "MCPServer non-existent not found", - }, - { - name: "get error", - workloadName: "error-workload", - setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, _ client.ObjectKey, _ client.Object, _ ...client.GetOption) error { - return k8serrors.NewInternalError(errors.New("internal error")) - } - }, - wantError: true, - errorMsg: "failed to get MCPServer", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - mockClient := &mockClient{} - tt.setupMock(mockClient) - - mgr := &manager{ - k8sClient: mockClient, - namespace: defaultNamespace, - } - - ctx := context.Background() - result, err := mgr.GetWorkload(ctx, tt.workloadName) - - if tt.wantError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - assert.Equal(t, tt.expected.Name, result.Name) - assert.Equal(t, tt.expected.Phase, result.Phase) - assert.Equal(t, tt.expected.URL, result.URL) - } - }) - } -} - -func TestManager_DoesWorkloadExist(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - setupMock func(*mockClient) - expected bool - wantError bool - }{ - { - name: "workload exists", - workloadName: "existing-workload", - 
setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { - if mcpServer, ok := obj.(*mcpv1alpha1.MCPServer); ok { - mcpServer.Name = "existing-workload" - } - return nil - } - }, - expected: true, - wantError: false, - }, - { - name: "workload does not exist", - workloadName: "non-existent", - setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, key client.ObjectKey, _ client.Object, _ ...client.GetOption) error { - return k8serrors.NewNotFound(schema.GroupResource{Resource: "mcpservers"}, key.Name) - } - }, - expected: false, - wantError: false, - }, - { - name: "get error", - workloadName: "error-workload", - setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, _ client.ObjectKey, _ client.Object, _ ...client.GetOption) error { - return k8serrors.NewInternalError(errors.New("internal error")) - } - }, - expected: false, - wantError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - mockClient := &mockClient{} - tt.setupMock(mockClient) - - mgr := &manager{ - k8sClient: mockClient, - namespace: defaultNamespace, - } - - ctx := context.Background() - result, err := mgr.DoesWorkloadExist(ctx, tt.workloadName) - - if tt.wantError { - require.Error(t, err) - } else { - require.NoError(t, err) - assert.Equal(t, tt.expected, result) - } - }) - } -} - -func TestManager_ListWorkloadsInGroup(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - groupName string - setupMock func(*mockClient) - expected []string - wantError bool - errorMsg string - }{ - { - name: "successful list", - groupName: "test-group", - setupMock: func(mc *mockClient) { - mc.listFunc = func(_ context.Context, list client.ObjectList, _ ...client.ListOption) error { - if mcpServerList, ok := list.(*mcpv1alpha1.MCPServerList); ok { - mcpServerList.Items = []mcpv1alpha1.MCPServer{ - { - ObjectMeta: 
metav1.ObjectMeta{ - Name: testWorkload1, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: "test-group", - }, - }, - { - ObjectMeta: metav1.ObjectMeta{ - Name: "workload2", - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: "test-group", - }, - }, - { - ObjectMeta: metav1.ObjectMeta{ - Name: "workload3", - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: "other-group", - }, - }, - } - } - return nil - } - }, - expected: []string{testWorkload1, "workload2"}, - wantError: false, - }, - { - name: "empty group", - groupName: "empty-group", - setupMock: func(mc *mockClient) { - mc.listFunc = func(_ context.Context, list client.ObjectList, _ ...client.ListOption) error { - if mcpServerList, ok := list.(*mcpv1alpha1.MCPServerList); ok { - mcpServerList.Items = []mcpv1alpha1.MCPServer{} - } - return nil - } - }, - expected: []string{}, - wantError: false, - }, - { - name: "list error", - groupName: "test-group", - setupMock: func(mc *mockClient) { - mc.listFunc = func(_ context.Context, _ client.ObjectList, _ ...client.ListOption) error { - return k8serrors.NewInternalError(errors.New("internal error")) - } - }, - expected: nil, - wantError: true, - errorMsg: "failed to list MCPServers", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - mockClient := &mockClient{} - tt.setupMock(mockClient) - - mgr := &manager{ - k8sClient: mockClient, - namespace: defaultNamespace, - } - - ctx := context.Background() - result, err := mgr.ListWorkloadsInGroup(ctx, tt.groupName) - - if tt.wantError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - assert.ElementsMatch(t, tt.expected, result) - } - }) - } -} - -func TestManager_ListWorkloads(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - listAll bool - labelFilters []string - setupMock func(*mockClient) - expected int - wantError bool - errorMsg string - }{ - { - name: 
"successful list all", - listAll: true, - setupMock: func(mc *mockClient) { - mc.listFunc = func(_ context.Context, list client.ObjectList, _ ...client.ListOption) error { - if mcpServerList, ok := list.(*mcpv1alpha1.MCPServerList); ok { - mcpServerList.Items = []mcpv1alpha1.MCPServer{ - { - ObjectMeta: metav1.ObjectMeta{Name: testWorkload1}, - Status: mcpv1alpha1.MCPServerStatus{Phase: mcpv1alpha1.MCPServerPhaseRunning}, - }, - { - ObjectMeta: metav1.ObjectMeta{Name: "workload2"}, - Status: mcpv1alpha1.MCPServerStatus{Phase: mcpv1alpha1.MCPServerPhaseTerminating}, - }, - } - } - return nil - } - }, - expected: 2, - wantError: false, - }, - { - name: "list with label filters", - listAll: true, - labelFilters: []string{"env=prod"}, - setupMock: func(mc *mockClient) { - mc.listFunc = func(_ context.Context, list client.ObjectList, _ ...client.ListOption) error { - if mcpServerList, ok := list.(*mcpv1alpha1.MCPServerList); ok { - mcpServerList.Items = []mcpv1alpha1.MCPServer{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: testWorkload1, - Labels: map[string]string{"env": "prod"}, - }, - Status: mcpv1alpha1.MCPServerStatus{Phase: mcpv1alpha1.MCPServerPhaseRunning}, - }, - } - } - return nil - } - }, - expected: 1, - wantError: false, - }, - { - name: "invalid label filter", - listAll: true, - labelFilters: []string{"invalid-filter"}, - setupMock: func(*mockClient) { - // No list call expected due to filter parsing error - }, - expected: 0, - wantError: true, - errorMsg: "failed to parse label filters", - }, - { - name: "list error", - listAll: true, - setupMock: func(mc *mockClient) { - mc.listFunc = func(_ context.Context, _ client.ObjectList, _ ...client.ListOption) error { - return k8serrors.NewInternalError(errors.New("internal error")) - } - }, - expected: 0, - wantError: true, - errorMsg: "failed to list MCPServers", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - mockClient := &mockClient{} - 
tt.setupMock(mockClient) - - mgr := &manager{ - k8sClient: mockClient, - namespace: defaultNamespace, - } - - ctx := context.Background() - result, err := mgr.ListWorkloads(ctx, tt.listAll, tt.labelFilters...) - - if tt.wantError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - assert.Len(t, result, tt.expected) - } - }) - } -} - -func TestManager_MoveToGroup(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadNames []string - groupFrom string - groupTo string - setupMock func(*mockClient) - wantError bool - errorMsg string - }{ - { - name: "successful move", - workloadNames: []string{testWorkload1}, - groupFrom: "old-group", - groupTo: "new-group", - setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { - if mcpServer, ok := obj.(*mcpv1alpha1.MCPServer); ok { - mcpServer.Name = testWorkload1 - mcpServer.Namespace = defaultNamespace - mcpServer.Spec.GroupRef = "old-group" - } - return nil - } - mc.updateFunc = func(_ context.Context, _ client.Object, _ ...client.UpdateOption) error { - return nil - } - }, - wantError: false, - }, - { - name: "workload not found", - workloadNames: []string{"non-existent"}, - groupFrom: "old-group", - groupTo: "new-group", - setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, key client.ObjectKey, _ client.Object, _ ...client.GetOption) error { - return k8serrors.NewNotFound(schema.GroupResource{Resource: "mcpservers"}, key.Name) - } - }, - wantError: true, - errorMsg: "MCPServer", - }, - { - name: "workload in different group", - workloadNames: []string{testWorkload1}, - groupFrom: "old-group", - groupTo: "new-group", - setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { - if mcpServer, ok := 
obj.(*mcpv1alpha1.MCPServer); ok { - mcpServer.Name = testWorkload1 - mcpServer.Namespace = defaultNamespace - mcpServer.Spec.GroupRef = "different-group" - } - return nil - } - }, - wantError: true, // Returns error when group doesn't match - errorMsg: "is not in group", - }, - { - name: "update error", - workloadNames: []string{testWorkload1}, - groupFrom: "old-group", - groupTo: "new-group", - setupMock: func(mc *mockClient) { - mc.getFunc = func(_ context.Context, _ client.ObjectKey, obj client.Object, _ ...client.GetOption) error { - if mcpServer, ok := obj.(*mcpv1alpha1.MCPServer); ok { - mcpServer.Name = testWorkload1 - mcpServer.Namespace = defaultNamespace - mcpServer.Spec.GroupRef = "old-group" - } - return nil - } - mc.updateFunc = func(_ context.Context, _ client.Object, _ ...client.UpdateOption) error { - return k8serrors.NewInternalError(errors.New("update failed")) - } - }, - wantError: true, - errorMsg: "failed to update MCPServer", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - mockClient := &mockClient{} - tt.setupMock(mockClient) - - mgr := &manager{ - k8sClient: mockClient, - namespace: defaultNamespace, - } - - ctx := context.Background() - err := mgr.MoveToGroup(ctx, tt.workloadNames, tt.groupFrom, tt.groupTo) - - if tt.wantError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -func TestManager_mcpServerToWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - mcpServer *mcpv1alpha1.MCPServer - expected Workload - }{ - { - name: "running workload with HTTP transport", - mcpServer: &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-workload", - Namespace: defaultNamespace, - Annotations: map[string]string{ - "group": "test-group", - "env": "prod", - }, - }, - Status: mcpv1alpha1.MCPServerStatus{ - Phase: mcpv1alpha1.MCPServerPhaseRunning, - URL: 
"http://localhost:8080", - }, - Spec: mcpv1alpha1.MCPServerSpec{ - Transport: "streamable-http", - ProxyPort: 8080, - }, - }, - expected: Workload{ - Name: "test-workload", - Namespace: defaultNamespace, - Phase: mcpv1alpha1.MCPServerPhaseRunning, - URL: "http://localhost:8080", - Labels: map[string]string{"group": "test-group", "env": "prod"}, - }, - }, - { - name: "terminating workload", - mcpServer: &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: "terminating-workload", - Namespace: defaultNamespace, - }, - Status: mcpv1alpha1.MCPServerStatus{ - Phase: mcpv1alpha1.MCPServerPhaseTerminating, - }, - }, - expected: Workload{ - Name: "terminating-workload", - Namespace: defaultNamespace, - Phase: mcpv1alpha1.MCPServerPhaseTerminating, - }, - }, - { - name: "failed workload", - mcpServer: &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: "failed-workload", - Namespace: defaultNamespace, - }, - Status: mcpv1alpha1.MCPServerStatus{ - Phase: mcpv1alpha1.MCPServerPhaseFailed, - }, - }, - expected: Workload{ - Name: "failed-workload", - Namespace: defaultNamespace, - Phase: mcpv1alpha1.MCPServerPhaseFailed, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - mgr := &manager{ - namespace: defaultNamespace, - } - - result, err := mgr.mcpServerToWorkload(tt.mcpServer) - require.NoError(t, err) - - assert.Equal(t, tt.expected.Name, result.Name) - assert.Equal(t, tt.expected.Namespace, result.Namespace) - assert.Equal(t, tt.expected.Phase, result.Phase) - assert.Equal(t, tt.expected.URL, result.URL) - if tt.expected.Labels != nil { - assert.Equal(t, tt.expected.Labels, result.Labels) - } - }) - } -} - -func TestManager_isStandardK8sAnnotation(t *testing.T) { - t.Parallel() - - mgr := &manager{} - - tests := []struct { - name string - key string - expected bool - }{ - { - name: "kubectl annotation", - key: "kubectl.kubernetes.io/last-applied-configuration", - expected: true, - }, - { - name: 
"kubernetes.io annotation", - key: "kubernetes.io/created-by", - expected: true, - }, - { - name: "deployment.kubernetes.io annotation", - key: "deployment.kubernetes.io/revision", - expected: true, - }, - { - name: "k8s.io annotation", - key: "k8s.io/annotation", - expected: true, - }, - { - name: "user-defined annotation", - key: "custom/annotation", - expected: false, - }, - { - name: "empty string", - key: "", - expected: false, - }, - { - name: "short key", - key: "k", - expected: false, - }, - { - name: "partial match - not a prefix", - key: "my-kubectl.kubernetes.io/annotation", - expected: false, - }, - { - name: "exact prefix match", - key: "kubectl.kubernetes.io/", - expected: true, - }, - { - name: "case sensitive - uppercase", - key: "KUBECTL.KUBERNETES.IO/annotation", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - result := mgr.isStandardK8sAnnotation(tt.key) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/pkg/workloads/k8s/manager.go b/pkg/workloads/k8s/manager.go deleted file mode 100644 index a9eb86af5..000000000 --- a/pkg/workloads/k8s/manager.go +++ /dev/null @@ -1,40 +0,0 @@ -// Package k8s provides Kubernetes-specific workload management. -package k8s - -import ( - "context" -) - -// Manager manages MCPServer CRD workloads in Kubernetes. -// This interface is separate from workloads.Manager to avoid coupling Kubernetes workloads -// to the CLI container runtime interface. 
-// -//go:generate mockgen -destination=mocks/mock_manager.go -package=mocks github.com/stacklok/toolhive/pkg/workloads/k8s Manager -type Manager interface { - // GetWorkload retrieves an MCPServer CRD by name - GetWorkload(ctx context.Context, workloadName string) (Workload, error) - - // ListWorkloads lists all MCPServer CRDs, optionally filtered by labels - // The `listAll` parameter determines whether to include workloads that are not running - // The optional `labelFilters` parameter allows filtering workloads by labels (format: key=value) - ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]Workload, error) - - // ListWorkloadsInGroup returns all workload names that belong to the specified group - ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) - - // DoesWorkloadExist checks if an MCPServer CRD with the given name exists - DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) - - // MoveToGroup moves the specified workloads from one group to another by updating their GroupRef - MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error - - // The following operations are not supported in Kubernetes mode (operator manages lifecycle): - // - RunWorkload: Workloads are created via MCPServer CRDs - // - RunWorkloadDetached: Workloads are created via MCPServer CRDs - // - StopWorkloads: Use kubectl to manage MCPServer CRDs - // - DeleteWorkloads: Use kubectl to manage MCPServer CRDs - // - RestartWorkloads: Use kubectl to manage MCPServer CRDs - // - UpdateWorkload: Update MCPServer CRD directly - // - GetLogs: Use 'kubectl logs -n ' to retrieve logs - // - GetProxyLogs: Use 'kubectl logs -c proxy -n ' to retrieve proxy logs -} diff --git a/pkg/workloads/k8s/mocks/mock_manager.go b/pkg/workloads/k8s/mocks/mock_manager.go deleted file mode 100644 index 01b8c0dfd..000000000 --- a/pkg/workloads/k8s/mocks/mock_manager.go +++ /dev/null @@ -1,121 +0,0 @@ -// 
Code generated by MockGen. DO NOT EDIT. -// Source: github.com/stacklok/toolhive/pkg/workloads/k8s (interfaces: Manager) -// -// Generated by this command: -// -// mockgen -destination=mocks/mock_manager.go -package=mocks github.com/stacklok/toolhive/pkg/workloads/k8s Manager -// - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - - k8s "github.com/stacklok/toolhive/pkg/workloads/k8s" - gomock "go.uber.org/mock/gomock" -) - -// MockManager is a mock of Manager interface. -type MockManager struct { - ctrl *gomock.Controller - recorder *MockManagerMockRecorder - isgomock struct{} -} - -// MockManagerMockRecorder is the mock recorder for MockManager. -type MockManagerMockRecorder struct { - mock *MockManager -} - -// NewMockManager creates a new mock instance. -func NewMockManager(ctrl *gomock.Controller) *MockManager { - mock := &MockManager{ctrl: ctrl} - mock.recorder = &MockManagerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockManager) EXPECT() *MockManagerMockRecorder { - return m.recorder -} - -// DoesWorkloadExist mocks base method. -func (m *MockManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DoesWorkloadExist", ctx, workloadName) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DoesWorkloadExist indicates an expected call of DoesWorkloadExist. -func (mr *MockManagerMockRecorder) DoesWorkloadExist(ctx, workloadName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesWorkloadExist", reflect.TypeOf((*MockManager)(nil).DoesWorkloadExist), ctx, workloadName) -} - -// GetWorkload mocks base method. 
-func (m *MockManager) GetWorkload(ctx context.Context, workloadName string) (k8s.Workload, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWorkload", ctx, workloadName) - ret0, _ := ret[0].(k8s.Workload) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetWorkload indicates an expected call of GetWorkload. -func (mr *MockManagerMockRecorder) GetWorkload(ctx, workloadName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkload", reflect.TypeOf((*MockManager)(nil).GetWorkload), ctx, workloadName) -} - -// ListWorkloads mocks base method. -func (m *MockManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]k8s.Workload, error) { - m.ctrl.T.Helper() - varargs := []any{ctx, listAll} - for _, a := range labelFilters { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "ListWorkloads", varargs...) - ret0, _ := ret[0].([]k8s.Workload) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListWorkloads indicates an expected call of ListWorkloads. -func (mr *MockManagerMockRecorder) ListWorkloads(ctx, listAll any, labelFilters ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, listAll}, labelFilters...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloads", reflect.TypeOf((*MockManager)(nil).ListWorkloads), varargs...) -} - -// ListWorkloadsInGroup mocks base method. -func (m *MockManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListWorkloadsInGroup", ctx, groupName) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListWorkloadsInGroup indicates an expected call of ListWorkloadsInGroup. 
-func (mr *MockManagerMockRecorder) ListWorkloadsInGroup(ctx, groupName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkloadsInGroup", reflect.TypeOf((*MockManager)(nil).ListWorkloadsInGroup), ctx, groupName) -} - -// MoveToGroup mocks base method. -func (m *MockManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom, groupTo string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MoveToGroup", ctx, workloadNames, groupFrom, groupTo) - ret0, _ := ret[0].(error) - return ret0 -} - -// MoveToGroup indicates an expected call of MoveToGroup. -func (mr *MockManagerMockRecorder) MoveToGroup(ctx, workloadNames, groupFrom, groupTo any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MoveToGroup", reflect.TypeOf((*MockManager)(nil).MoveToGroup), ctx, workloadNames, groupFrom, groupTo) -} diff --git a/pkg/workloads/k8s/workload.go b/pkg/workloads/k8s/workload.go deleted file mode 100644 index e39a1014a..000000000 --- a/pkg/workloads/k8s/workload.go +++ /dev/null @@ -1,45 +0,0 @@ -// Package k8s provides Kubernetes-specific domain models for workloads. -package k8s - -import ( - "time" - - mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - "github.com/stacklok/toolhive/pkg/transport/types" -) - -// Workload represents a Kubernetes workload (MCPServer CRD). -// This is the Kubernetes-specific domain model, separate from core.Workload. 
-type Workload struct { - // Name is the name of the MCPServer CRD - Name string - // Namespace is the Kubernetes namespace where the MCPServer is deployed - Namespace string - // Package specifies the container image used for this workload - Package string - // URL is the URL of the workload exposed by the ToolHive proxy - URL string - // Port is the port on which the workload is exposed - Port int - // ToolType is the type of tool this workload represents - // For now, it will always be "mcp" - representing an MCP server - ToolType string - // TransportType is the type of transport used for this workload - TransportType types.TransportType - // ProxyMode is the proxy mode that clients should use to connect - ProxyMode string - // Phase is the current phase of the MCPServer CRD - Phase mcpv1alpha1.MCPServerPhase - // StatusContext provides additional context about the workload's status - StatusContext string - // CreatedAt is the timestamp when the workload was created - CreatedAt time.Time - // Labels are user-defined labels (from annotations) - Labels map[string]string - // Group is the name of the group this workload belongs to, if any - Group string - // ToolsFilter is the filter on tools applied to the workload - ToolsFilter []string - // GroupRef is the reference to the MCPGroup (same as Group, but using CRD terminology) - GroupRef string -} diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index f34fea4f4..1aa7a3c17 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -59,19 +59,19 @@ var ErrWorkloadNotRunning = fmt.Errorf("workload not running") // NewManager creates a new CLI workload manager. // Returns Manager interface (existing behavior, unchanged). -// IMPORTANT: This function only works in CLI mode. For Kubernetes, use k8s.NewManagerFromContext() directly. +// IMPORTANT: This function only works in CLI mode. 
func NewManager(ctx context.Context) (Manager, error) { if rt.IsKubernetesRuntime() { - return nil, fmt.Errorf("use k8s.NewManagerFromContext() for Kubernetes environments") + return nil, fmt.Errorf("workload manager is not available in Kubernetes environments") } return NewCLIManager(ctx) } // NewManagerWithProvider creates a new CLI workload manager with a custom config provider. -// IMPORTANT: This function only works in CLI mode. For Kubernetes, use k8s.NewManagerFromContext() directly. +// IMPORTANT: This function only works in CLI mode. func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) (Manager, error) { if rt.IsKubernetesRuntime() { - return nil, fmt.Errorf("use k8s.NewManagerFromContext() for Kubernetes environments") + return nil, fmt.Errorf("workload manager is not available in Kubernetes environments") } return NewCLIManagerWithProvider(ctx, configProvider) } From 8e819a06df9b12a3c24ac137036c5d7ccdcee525 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Mon, 17 Nov 2025 11:19:09 +0000 Subject: [PATCH 15/16] removed changes regarding the workload manager refactor --- cmd/vmcp/app/commands.go | 2 +- pkg/workloads/cli_manager.go | 1205 ---------------- pkg/workloads/cli_manager_test.go | 2213 ----------------------------- pkg/workloads/manager.go | 1236 +++++++++++++++- pkg/workloads/manager_test.go | 1622 ++++++++++++++++++++- 5 files changed, 2827 insertions(+), 3451 deletions(-) delete mode 100644 pkg/workloads/cli_manager.go delete mode 100644 pkg/workloads/cli_manager_test.go diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index e7e82eef7..53d6dd638 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -226,7 +226,7 @@ func discoverBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, } // Initialize managers for backend discovery - logger.Info("Initializing workload and group managers") + logger.Info("Initializing group manager") groupsManager, err := groups.NewManager() if err 
!= nil { return nil, nil, fmt.Errorf("failed to create groups manager: %w", err) diff --git a/pkg/workloads/cli_manager.go b/pkg/workloads/cli_manager.go deleted file mode 100644 index 20ced9ded..000000000 --- a/pkg/workloads/cli_manager.go +++ /dev/null @@ -1,1205 +0,0 @@ -// Package workloads provides a CLI-based implementation of the Manager interface. -// This file contains the CLI (Docker/Podman) implementation for local environments. -package workloads - -import ( - "context" - "errors" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "time" - - "github.com/adrg/xdg" - "golang.org/x/sync/errgroup" - - "github.com/stacklok/toolhive/pkg/client" - "github.com/stacklok/toolhive/pkg/config" - ct "github.com/stacklok/toolhive/pkg/container" - rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/labels" - "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/process" - "github.com/stacklok/toolhive/pkg/runner" - "github.com/stacklok/toolhive/pkg/secrets" - "github.com/stacklok/toolhive/pkg/state" - "github.com/stacklok/toolhive/pkg/workloads/statuses" - "github.com/stacklok/toolhive/pkg/workloads/types" -) - -// AsyncOperationTimeout is the timeout for async workload operations -const AsyncOperationTimeout = 5 * time.Minute - -// removeClientConfigurations removes client configuration files for a workload. -// TODO: Move to dedicated config management interface. 
-func removeClientConfigurations(containerName string, isAuxiliary bool) error { - // Get the workload's group by loading its run config - // Note: This is a standalone function, so we use runner.LoadState directly - // In the future, this should be refactored to use the driver - runConfig, err := runner.LoadState(context.Background(), containerName) - var group string - if err != nil { - // Only warn for non-auxiliary workloads since auxiliary workloads don't have run configs - if !isAuxiliary { - logger.Warnf("Warning: Failed to load run config for %s, will use backward compatible behavior: %v", containerName, err) - } - // Continue with empty group (backward compatibility) - } else { - group = runConfig.Group - } - - clientManager, err := client.NewManager(context.Background()) - if err != nil { - logger.Warnf("Warning: Failed to create client manager for %s, skipping client config removal: %v", containerName, err) - return nil - } - - return clientManager.RemoveServerFromClients(context.Background(), containerName, group) -} - -// cliManager implements the Manager interface for CLI (Docker/Podman) environments. -type cliManager struct { - runtime rt.Runtime - statuses statuses.StatusManager - configProvider config.Provider -} - -// NewCLIManager creates a new CLI-based workload manager. -func NewCLIManager(ctx context.Context) (Manager, error) { - return NewCLIManagerWithProvider(ctx, config.NewDefaultProvider()) -} - -// NewCLIManagerWithProvider creates a new CLI-based workload manager with a custom config provider. 
-func NewCLIManagerWithProvider(ctx context.Context, configProvider config.Provider) (Manager, error) { - runtime, err := ct.NewFactory().Create(ctx) - if err != nil { - return nil, err - } - - statusManager, err := statuses.NewStatusManager(runtime) - if err != nil { - return nil, fmt.Errorf("failed to create status manager: %w", err) - } - - return &cliManager{ - runtime: runtime, - statuses: statusManager, - configProvider: configProvider, - }, nil -} - -// NewCLIManagerFromRuntime creates a new CLI-based workload manager from an existing runtime. -func NewCLIManagerFromRuntime(runtime rt.Runtime) (Manager, error) { - return NewCLIManagerFromRuntimeWithProvider(runtime, config.NewDefaultProvider()) -} - -// NewCLIManagerFromRuntimeWithProvider creates a new CLI-based workload manager -// from an existing runtime with a custom config provider. -func NewCLIManagerFromRuntimeWithProvider(runtime rt.Runtime, configProvider config.Provider) (Manager, error) { - statusManager, err := statuses.NewStatusManager(runtime) - if err != nil { - return nil, fmt.Errorf("failed to create status manager: %w", err) - } - - return &cliManager{ - runtime: runtime, - statuses: statusManager, - configProvider: configProvider, - }, nil -} - -func (d *cliManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { - return d.statuses.GetWorkload(ctx, workloadName) -} - -func (d *cliManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { - // check if workload exists by trying to get it - workload, err := d.statuses.GetWorkload(ctx, workloadName) - if err != nil { - if errors.Is(err, rt.ErrWorkloadNotFound) { - return false, nil - } - return false, fmt.Errorf("failed to check if workload exists: %w", err) - } - - // now check if the workload is not in error - if workload.Status == rt.WorkloadStatusError { - return false, nil - } - return true, nil -} - -func (d *cliManager) ListWorkloads(ctx context.Context, listAll bool, 
labelFilters ...string) ([]core.Workload, error) { - // Get container workloads from status manager - containerWorkloads, err := d.statuses.ListWorkloads(ctx, listAll, labelFilters) - if err != nil { - return nil, err - } - - // Get remote workloads from the state store - remoteWorkloads, err := d.getRemoteWorkloadsFromState(ctx, listAll, labelFilters) - if err != nil { - logger.Warnf("Failed to get remote workloads from state: %v", err) - // Continue with container workloads only - } else { - // Combine container and remote workloads - containerWorkloads = append(containerWorkloads, remoteWorkloads...) - } - - return containerWorkloads, nil -} - -func (d *cliManager) StopWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { - // Validate all workload names to prevent path traversal attacks - for _, name := range names { - if err := types.ValidateWorkloadName(name); err != nil { - return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) - } - // Ensure workload name does not contain path traversal or separators - if strings.Contains(name, "..") || strings.ContainsAny(name, "/\\") { - return nil, fmt.Errorf("invalid workload name '%s': contains forbidden characters", name) - } - } - - group := &errgroup.Group{} - // Process each workload - for _, name := range names { - group.Go(func() error { - return d.stopSingleWorkload(name) - }) - } - - return group, nil -} - -// stopSingleWorkload stops a single workload (container or remote) -func (d *cliManager) stopSingleWorkload(name string) error { - // Create a child context with a longer timeout - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer cancel() - - // First, try to load the run configuration to check if it's a remote workload - runConfig, err := runner.LoadState(childCtx, name) - if err != nil { - // If we can't load the state, it might be a container workload or the workload doesn't exist - // Try to stop it as a container workload - 
return d.stopContainerWorkload(childCtx, name) - } - - // Check if this is a remote workload - if runConfig.RemoteURL != "" { - return d.stopRemoteWorkload(childCtx, name, runConfig) - } - - // This is a container-based workload - return d.stopContainerWorkload(childCtx, name) -} - -// stopRemoteWorkload stops a remote workload -func (d *cliManager) stopRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { - logger.Infof("Stopping remote workload %s...", name) - - // Check if the workload is running by checking its status - workload, err := d.statuses.GetWorkload(ctx, name) - if err != nil { - if errors.Is(err, rt.ErrWorkloadNotFound) { - // Log but don't fail the entire operation for not found workload - logger.Warnf("Warning: Failed to stop workload %s: %v", name, err) - return nil - } - return fmt.Errorf("failed to find workload %s: %v", name, err) - } - - if workload.Status != rt.WorkloadStatusRunning { - logger.Warnf("Warning: Failed to stop workload %s: %v", name, ErrWorkloadNotRunning) - return nil - } - - // Set status to stopping - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopping: %v", name, err) - } - - // Stop proxy if running - if runConfig.BaseName != "" { - d.stopProxyIfNeeded(ctx, name, runConfig.BaseName) - } - - // For remote workloads, we only need to clean up client configurations - // The saved state should be preserved for restart capability - if err := removeClientConfigurations(name, false); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } else { - logger.Infof("Client configurations for %s removed", name) - } - - // Set status to stopped - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopped, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopped: %v", name, err) - } - logger.Infof("Remote workload %s stopped 
successfully", name) - return nil -} - -// stopContainerWorkload stops a container-based workload -func (d *cliManager) stopContainerWorkload(ctx context.Context, name string) error { - container, err := d.runtime.GetWorkloadInfo(ctx, name) - if err != nil { - if errors.Is(err, rt.ErrWorkloadNotFound) { - // Log but don't fail the entire operation for not found containers - logger.Warnf("Warning: Failed to stop workload %s: %v", name, err) - return nil - } - return fmt.Errorf("failed to find workload %s: %v", name, err) - } - - running := container.IsRunning() - if !running { - // Log but don't fail the entire operation for not running containers - logger.Warnf("Warning: Failed to stop workload %s: %v", name, ErrWorkloadNotRunning) - return nil - } - - // Transition workload to `stopping` state. - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopping: %v", name, err) - } - - // Use the existing stopWorkloads method for container workloads - return d.stopSingleContainerWorkload(ctx, &container) -} - -func (d *cliManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error { - // Ensure that the workload has a status entry before starting the process. - if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil { - // Failure to create the initial state is a fatal error. - return fmt.Errorf("failed to create workload status: %v", err) - } - - mcpRunner := runner.NewRunner(runConfig, d.statuses) - err := mcpRunner.Run(ctx) - if err != nil { - // If the run failed, we should set the status to error. 
- if statusErr := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr) - } - } - return err -} - -func (d *cliManager) validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { - // If there are run secrets, validate them - - hasRegularSecrets := len(runConfig.Secrets) > 0 - hasRemoteAuthSecret := runConfig.RemoteAuthConfig != nil && runConfig.RemoteAuthConfig.ClientSecret != "" - - if hasRegularSecrets || hasRemoteAuthSecret { - cfg := d.configProvider.GetConfig() - - providerType, err := cfg.Secrets.GetProviderType() - if err != nil { - return fmt.Errorf("error determining secrets provider type: %w", err) - } - - secretManager, err := secrets.CreateSecretProvider(providerType) - if err != nil { - return fmt.Errorf("error instantiating secret manager: %w", err) - } - - err = runConfig.ValidateSecrets(ctx, secretManager) - if err != nil { - return fmt.Errorf("error processing secrets: %w", err) - } - } - return nil -} - -func (d *cliManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error { - // before running, validate the parameters for the workload - err := d.validateSecretParameters(ctx, runConfig) - if err != nil { - return fmt.Errorf("failed to validate workload parameters: %w", err) - } - - // Get the current executable path - execPath, err := os.Executable() - if err != nil { - return fmt.Errorf("failed to get executable path: %v", err) - } - - // Create a log file for the detached process - logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", runConfig.BaseName)) - if err != nil { - return fmt.Errorf("failed to create log file path: %v", err) - } - // #nosec G304 - This is safe as baseName is generated by the application - logFile, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) - if err != nil { - 
logger.Warnf("Warning: Failed to create log file: %v", err) - } else { - defer logFile.Close() - logger.Infof("Logging to: %s", logFilePath) - } - - // Use the restart command to start the detached process - // The config has already been saved to disk, so restart can load it - detachedArgs := []string{"restart", runConfig.BaseName, "--foreground"} - - // Create a new command - // #nosec G204 - This is safe as execPath is the path to the current binary - detachedCmd := exec.Command(execPath, detachedArgs...) - - // Set environment variables for the detached process - detachedCmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", process.ToolHiveDetachedEnv, process.ToolHiveDetachedValue)) - - // If we need the decrypt password, set it as an environment variable in the detached process. - // NOTE: This breaks the abstraction slightly since this is only relevant for the CLI, but there - // are checks inside `GetSecretsPassword` to ensure this does not get called in a detached process. - // This will be addressed in a future re-think of the secrets manager interface. - if d.needSecretsPassword(runConfig.Secrets) { - password, err := secrets.GetSecretsPassword("") - if err != nil { - return fmt.Errorf("failed to get secrets password: %v", err) - } - detachedCmd.Env = append(detachedCmd.Env, fmt.Sprintf("%s=%s", secrets.PasswordEnvVar, password)) - } - - // Redirect stdout and stderr to the log file if it was created successfully - if logFile != nil { - detachedCmd.Stdout = logFile - detachedCmd.Stderr = logFile - } else { - // Otherwise, discard the output - detachedCmd.Stdout = nil - detachedCmd.Stderr = nil - } - - // Detach the process from the terminal - detachedCmd.Stdin = nil - detachedCmd.SysProcAttr = getSysProcAttr() - - // Ensure that the workload has a status entry before starting the process. - if err = d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil { - // Failure to create the initial state is a fatal error. 
- return fmt.Errorf("failed to create workload status: %v", err) - } - - // Start the detached process - if err := detachedCmd.Start(); err != nil { - // If the start failed, we need to set the status to error before returning. - if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, ""); err != nil { - logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, err) - } - return fmt.Errorf("failed to start detached process: %v", err) - } - - // Write the PID to a file so the stop command can kill the process - // TODO: Stop writing to PID file once we migrate over to statuses fully. - if err := process.WritePIDFile(runConfig.BaseName, detachedCmd.Process.Pid); err != nil { - logger.Warnf("Warning: Failed to write PID file: %v", err) - } - if err := d.statuses.SetWorkloadPID(ctx, runConfig.BaseName, detachedCmd.Process.Pid); err != nil { - logger.Warnf("Failed to set workload %s PID: %v", runConfig.BaseName, err) - } - - logger.Infof("MCP server is running in the background (PID: %d)", detachedCmd.Process.Pid) - logger.Infof("Use 'thv stop %s' to stop the server", runConfig.ContainerName) - - return nil -} - -func (d *cliManager) GetLogs(ctx context.Context, workloadName string, follow bool) (string, error) { - // Get the logs from the runtime - logs, err := d.runtime.GetWorkloadLogs(ctx, workloadName, follow) - if err != nil { - // Propagate the error if the container is not found - if errors.Is(err, rt.ErrWorkloadNotFound) { - return "", fmt.Errorf("%w: %s", rt.ErrWorkloadNotFound, workloadName) - } - return "", fmt.Errorf("failed to get container logs %s: %v", workloadName, err) - } - - return logs, nil -} - -// GetProxyLogs retrieves proxy logs from the filesystem -func (*cliManager) GetProxyLogs(_ context.Context, workloadName string) (string, error) { - // Get the proxy log file path - logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", workloadName)) - if err != nil { - return "", 
fmt.Errorf("failed to get proxy log file path for workload %s: %w", workloadName, err) - } - - // Clean the file path to prevent path traversal - cleanLogFilePath := filepath.Clean(logFilePath) - - // Check if the log file exists - if _, err := os.Stat(cleanLogFilePath); os.IsNotExist(err) { - return "", fmt.Errorf("proxy logs not found for workload %s", workloadName) - } - - // Read and return the entire log file - content, err := os.ReadFile(cleanLogFilePath) - if err != nil { - return "", fmt.Errorf("failed to read proxy log for workload %s: %w", workloadName, err) - } - - return string(content), nil -} - -// deleteWorkload handles deletion of a single workload -func (d *cliManager) deleteWorkload(name string) error { - // Create a child context with a longer timeout - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer cancel() - - // First, check if this is a remote workload by trying to load its run configuration - runConfig, err := runner.LoadState(childCtx, name) - if err != nil { - // If we can't load the state, it might be a container workload or the workload doesn't exist - // Continue with the container-based deletion logic - return d.deleteContainerWorkload(childCtx, name) - } - - // If this is a remote workload (has RemoteURL), handle it differently - if runConfig.RemoteURL != "" { - return d.deleteRemoteWorkload(childCtx, name, runConfig) - } - - // This is a container-based workload, use the existing logic - return d.deleteContainerWorkload(childCtx, name) -} - -// deleteRemoteWorkload handles deletion of a remote workload -func (d *cliManager) deleteRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { - logger.Infof("Removing remote workload %s...", name) - - // Set status to removing - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusRemoving, ""); err != nil { - logger.Warnf("Failed to set workload %s status to removing: %v", name, err) - return err - } - 
- // Stop proxy if running - if runConfig.BaseName != "" { - d.stopProxyIfNeeded(ctx, name, runConfig.BaseName) - } - - // Clean up associated resources (remote workloads are not auxiliary) - d.cleanupWorkloadResources(ctx, name, runConfig.BaseName, false) - - // Remove the workload status from the status store - if err := d.statuses.DeleteWorkloadStatus(ctx, name); err != nil { - logger.Warnf("failed to delete workload status for %s: %v", name, err) - } - - logger.Infof("Remote workload %s removed successfully", name) - return nil -} - -// deleteContainerWorkload handles deletion of a container-based workload (existing logic) -func (d *cliManager) deleteContainerWorkload(ctx context.Context, name string) error { - - // Find and validate the container - container, err := d.getWorkloadContainer(ctx, name) - if err != nil { - return err - } - - // Set status to removing - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusRemoving, ""); err != nil { - logger.Warnf("Failed to set workload %s status to removing: %v", name, err) - } - - if container != nil { - containerLabels := container.Labels - baseName := labels.GetContainerBaseName(containerLabels) - - // Stop proxy if running (skip for auxiliary workloads like inspector) - if container.IsRunning() { - // Skip proxy stopping for auxiliary workloads that don't use proxy processes - if labels.IsAuxiliaryWorkload(containerLabels) { - logger.Debugf("Skipping proxy stop for auxiliary workload %s", name) - } else { - d.stopProxyIfNeeded(ctx, name, baseName) - } - } - - // Remove the container - if err := d.removeContainer(ctx, name); err != nil { - return err - } - - // Clean up associated resources - d.cleanupWorkloadResources(ctx, name, baseName, labels.IsAuxiliaryWorkload(containerLabels)) - } - - // Remove the workload status from the status store - if err := d.statuses.DeleteWorkloadStatus(ctx, name); err != nil { - logger.Warnf("failed to delete workload status for %s: %v", name, err) - } - - return 
nil -} - -// getWorkloadContainer retrieves workload container info with error handling -func (d *cliManager) getWorkloadContainer(ctx context.Context, name string) (*rt.ContainerInfo, error) { - container, err := d.runtime.GetWorkloadInfo(ctx, name) - if err != nil { - if errors.Is(err, rt.ErrWorkloadNotFound) { - // Log but don't fail the entire operation for not found containers - logger.Warnf("Warning: Failed to get workload %s: %v", name, err) - return nil, nil - } - if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) - } - return nil, fmt.Errorf("failed to find workload %s: %v", name, err) - } - return &container, nil -} - -// isSupervisorProcessAlive checks if the supervisor process for a workload is alive -// by checking if a PID exists. If a PID exists, we assume the supervisor is running. -// This is a reasonable assumption because: -// - If the supervisor exits cleanly, it cleans up the PID -// - If killed unexpectedly, the PID remains but stopProcess will handle it gracefully -// - The main issue we're preventing is accumulating zombie supervisors from repeated restarts -func (d *cliManager) isSupervisorProcessAlive(ctx context.Context, name string) bool { - if name == "" { - return false - } - - // Try to read the PID - if it exists, assume supervisor is running - _, err := d.statuses.GetWorkloadPID(ctx, name) - if err != nil { - // No PID found, supervisor is not running - return false - } - - // PID exists, assume supervisor is alive - return true -} - -// stopProcess stops the proxy process associated with the container -func (d *cliManager) stopProcess(ctx context.Context, name string) { - if name == "" { - logger.Warnf("Warning: Could not find base container name in labels") - return - } - - // Try to read the PID and kill the process - pid, err := d.statuses.GetWorkloadPID(ctx, name) - if err != nil { - 
logger.Errorf("No PID file found for %s, proxy may not be running in detached mode", name) - return - } - - // PID file found, try to kill the process - logger.Infof("Stopping proxy process (PID: %d)...", pid) - if err := process.KillProcess(pid); err != nil { - logger.Warnf("Warning: Failed to kill proxy process: %v", err) - } else { - logger.Info("Proxy process stopped") - } - - // Clean up PID file after successful kill - if err := process.RemovePIDFile(name); err != nil { - logger.Warnf("Warning: Failed to remove PID file: %v", err) - } -} - -// stopProxyIfNeeded stops the proxy process if the workload has a base name -func (d *cliManager) stopProxyIfNeeded(ctx context.Context, name, baseName string) { - logger.Infof("Removing proxy process for %s...", name) - if baseName != "" { - d.stopProcess(ctx, baseName) - } -} - -// removeContainer removes the container from the runtime -func (d *cliManager) removeContainer(ctx context.Context, name string) error { - logger.Infof("Removing container %s...", name) - if err := d.runtime.RemoveWorkload(ctx, name); err != nil { - if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) - } - return fmt.Errorf("failed to remove container: %v", err) - } - return nil -} - -// cleanupWorkloadResources cleans up all resources associated with a workload -func (d *cliManager) cleanupWorkloadResources(ctx context.Context, name, baseName string, isAuxiliary bool) { - if baseName == "" { - return - } - - // Clean up temporary permission profile - if err := d.cleanupTempPermissionProfile(ctx, baseName); err != nil { - logger.Warnf("Warning: Failed to cleanup temporary permission profile: %v", err) - } - - // Remove client configurations - if err := removeClientConfigurations(name, isAuxiliary); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } else { - 
logger.Infof("Client configurations for %s removed", name) - } - - // Delete the saved state last (skip for auxiliary workloads that don't have run configs) - if !isAuxiliary { - if err := state.DeleteSavedRunConfig(ctx, baseName); err != nil { - logger.Warnf("Warning: Failed to delete saved state: %v", err) - } else { - logger.Infof("Saved state for %s removed", baseName) - } - } else { - logger.Debugf("Skipping saved state deletion for auxiliary workload %s", name) - } - - logger.Infof("Container %s removed", name) -} - -func (d *cliManager) DeleteWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { - // Validate all workload names to prevent path traversal attacks - for _, name := range names { - if err := types.ValidateWorkloadName(name); err != nil { - return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) - } - } - - group := &errgroup.Group{} - - for _, name := range names { - group.Go(func() error { - return d.deleteWorkload(name) - }) - } - - return group, nil -} - -// RestartWorkloads restarts the specified workloads by name. 
-func (d *cliManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) { - // Validate all workload names to prevent path traversal attacks - for _, name := range names { - if err := types.ValidateWorkloadName(name); err != nil { - return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) - } - } - - group := &errgroup.Group{} - - for _, name := range names { - group.Go(func() error { - return d.restartSingleWorkload(name, foreground) - }) - } - - return group, nil -} - -// UpdateWorkload updates a workload by stopping, deleting, and recreating it -func (d *cliManager) UpdateWorkload(_ context.Context, workloadName string, newConfig *runner.RunConfig) (*errgroup.Group, error) { //nolint:lll - // Validate workload name - if err := types.ValidateWorkloadName(workloadName); err != nil { - return nil, fmt.Errorf("invalid workload name '%s': %w", workloadName, err) - } - - group := &errgroup.Group{} - group.Go(func() error { - return d.updateSingleWorkload(workloadName, newConfig) - }) - return group, nil -} - -// updateSingleWorkload handles the update logic for a single workload -func (d *cliManager) updateSingleWorkload(workloadName string, newConfig *runner.RunConfig) error { - // Create a child context with a longer timeout - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer cancel() - - logger.Infof("Starting update for workload %s", workloadName) - - // Stop the existing workload - if err := d.stopSingleWorkload(workloadName); err != nil { - return fmt.Errorf("failed to stop workload: %w", err) - } - logger.Infof("Successfully stopped workload %s", workloadName) - - // Delete the existing workload - if err := d.deleteWorkload(workloadName); err != nil { - return fmt.Errorf("failed to delete workload: %w", err) - } - logger.Infof("Successfully deleted workload %s", workloadName) - - // Save the new workload configuration state - if err := 
newConfig.SaveState(childCtx); err != nil { - logger.Errorf("Failed to save workload config: %v", err) - return fmt.Errorf("failed to save workload config: %w", err) - } - - // Step 3: Start the new workload - // TODO: This currently just handles detached processes and wouldn't work for - // foreground CLI executions. Should be refactored to support both modes. - if err := d.RunWorkloadDetached(childCtx, newConfig); err != nil { - return fmt.Errorf("failed to start new workload: %w", err) - } - - logger.Infof("Successfully completed update for workload %s", workloadName) - return nil -} - -// restartSingleWorkload handles the restart logic for a single workload -func (d *cliManager) restartSingleWorkload(name string, foreground bool) error { - // Create a child context with a longer timeout - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer cancel() - - // First, try to load the run configuration to check if it's a remote workload - runConfig, err := runner.LoadState(childCtx, name) - if err != nil { - // If we can't load the state, it might be a container workload or the workload doesn't exist - // Try to restart it as a container workload - return d.restartContainerWorkload(childCtx, name, foreground) - } - - // Check if this is a remote workload - if runConfig.RemoteURL != "" { - return d.restartRemoteWorkload(childCtx, name, runConfig, foreground) - } - - // This is a container-based workload - return d.restartContainerWorkload(childCtx, name, foreground) -} - -// restartRemoteWorkload handles restarting a remote workload -func (d *cliManager) restartRemoteWorkload( - ctx context.Context, - name string, - runConfig *runner.RunConfig, - foreground bool, -) error { - // Get workload status using the status manager - workload, err := d.statuses.GetWorkload(ctx, name) - if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) { - return err - } - - // If workload is already running, check if the supervisor process is 
healthy - if err == nil && workload.Status == rt.WorkloadStatusRunning { - // Check if the supervisor process is actually alive - supervisorAlive := d.isSupervisorProcessAlive(ctx, runConfig.BaseName) - - if supervisorAlive { - // Workload is running and healthy - preserve old behavior (no-op) - logger.Infof("Remote workload %s is already running", name) - return nil - } - - // Supervisor is dead/missing - we need to clean up and restart to fix the damaged state - logger.Infof("Remote workload %s is running but supervisor is dead, cleaning up before restart", name) - - // Set status to stopping - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopping: %v", name, err) - } - - // Stop the supervisor process (proxy) if it exists (may already be dead) - // This ensures we clean up any orphaned supervisor processes - d.stopProxyIfNeeded(ctx, name, runConfig.BaseName) - - // Clean up client configurations - if err := removeClientConfigurations(name, false); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } - - // Set status to stopped after cleanup is complete - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopped, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopped: %v", name, err) - } - } - - // Load runner configuration from state - mcpRunner, err := d.loadRunnerFromState(ctx, runConfig.BaseName) - if err != nil { - return fmt.Errorf("failed to load state for %s: %v", runConfig.BaseName, err) - } - - // Set status to starting - if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStarting, ""); err != nil { - logger.Warnf("Failed to set workload %s status to starting: %v", name, err) - } - - logger.Infof("Loaded configuration from state for %s", runConfig.BaseName) - - // Start the remote workload using the loaded runner - // Use background context to avoid timeout 
cancellation - same reasoning as container workloads - return d.startWorkload(context.Background(), name, mcpRunner, foreground) -} - -// restartContainerWorkload handles restarting a container-based workload -// -//nolint:gocyclo // Complexity is justified - handles multiple restart scenarios and edge cases -func (d *cliManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error { - // Get container info to resolve partial names and extract proper workload name - var containerName string - var workloadName string - - container, err := d.runtime.GetWorkloadInfo(ctx, name) - if err == nil { - // If we found the container, use its actual container name for runtime operations - containerName = container.Name - // Extract the workload name (base name) from container labels for status operations - workloadName = labels.GetContainerBaseName(container.Labels) - if workloadName == "" { - // Fallback to the provided name if base name is not available - workloadName = name - } - } else { - // If container not found, use the provided name as both container and workload name - containerName = name - workloadName = name - } - - // Get workload status using the status manager - workload, err := d.statuses.GetWorkload(ctx, name) - if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) { - return err - } - - // Check if workload is running and healthy (including supervisor process) - if err == nil && workload.Status == rt.WorkloadStatusRunning { - // Check if the supervisor process is actually alive - supervisorAlive := d.isSupervisorProcessAlive(ctx, workloadName) - - if supervisorAlive { - // Workload is running and healthy - preserve old behavior (no-op) - logger.Infof("Container %s is already running", containerName) - return nil - } - - // Supervisor is dead/missing - we need to clean up and restart to fix the damaged state - logger.Infof("Container %s is running but supervisor is dead, cleaning up before restart", containerName) - } - - // 
Check if we need to stop the workload before restarting - // This happens when: 1) container is running, or 2) inconsistent state - shouldStop := false - if err == nil && workload.Status == rt.WorkloadStatusRunning { - // Workload status shows running (and supervisor is dead, otherwise we would have returned above) - shouldStop = true - } else if container.IsRunning() { - // Container is running but status is not running (inconsistent state) - shouldStop = true - } - - // If we need to stop, do it now (including cleanup of any remaining supervisor process) - if shouldStop { - logger.Infof("Stopping container %s before restart", containerName) - - // Set status to stopping - if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStopping, ""); err != nil { - logger.Debugf("Failed to set workload %s status to stopping: %v", workloadName, err) - } - - // Stop the supervisor process (proxy) if it exists (may already be dead) - // This ensures we clean up any orphaned supervisor processes - if !labels.IsAuxiliaryWorkload(container.Labels) { - d.stopProcess(ctx, workloadName) - } - - // Now stop the container if it's running - if container.IsRunning() { - if err := d.runtime.StopWorkload(ctx, containerName); err != nil { - if statusErr := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", workloadName, statusErr) - } - return fmt.Errorf("failed to stop container %s: %v", containerName, err) - } - logger.Infof("Container %s stopped", containerName) - } - - // Clean up client configurations - if err := removeClientConfigurations(workloadName, labels.IsAuxiliaryWorkload(container.Labels)); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } - - // Set status to stopped after cleanup is complete - if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStopped, ""); err != nil { - 
logger.Debugf("Failed to set workload %s status to stopped: %v", workloadName, err) - } - } - - // Load runner configuration from state - mcpRunner, err := d.loadRunnerFromState(ctx, workloadName) - if err != nil { - return fmt.Errorf("failed to load state for %s: %v", workloadName, err) - } - - // Set workload status to starting - use the workload name for status operations - if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStarting, ""); err != nil { - logger.Warnf("Failed to set workload %s status to starting: %v", workloadName, err) - } - logger.Infof("Loaded configuration from state for %s", workloadName) - - // Start the workload with background context to avoid timeout cancellation - // The ctx with AsyncOperationTimeout is only for the restart setup operations, - // but the actual workload should run indefinitely with its own lifecycle management - // Use workload name for user-facing operations - return d.startWorkload(context.Background(), workloadName, mcpRunner, foreground) -} - -// startWorkload starts the workload in either foreground or background mode -func (d *cliManager) startWorkload(ctx context.Context, name string, mcpRunner *runner.Runner, foreground bool) error { - logger.Infof("Starting tooling server %s...", name) - - var err error - if foreground { - err = d.RunWorkload(ctx, mcpRunner.Config) - } else { - err = d.RunWorkloadDetached(ctx, mcpRunner.Config) - } - - if err != nil { - // If we could not start the workload, set the status to error before returning - if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, ""); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) - } - } - return err -} - -// loadRunnerFromState attempts to load a Runner from the state store -func (d *cliManager) loadRunnerFromState(ctx context.Context, baseName string) (*runner.Runner, error) { - // Load the run config from state - runConfig, err := 
runner.LoadState(ctx, baseName) - if err != nil { - return nil, err - } - - if runConfig.RemoteURL != "" { - // For remote workloads, we don't need a deployer - runConfig.Deployer = nil - } else { - // Update the runtime in the loaded configuration - runConfig.Deployer = d.runtime - } - - // Create a new runner with the loaded configuration - return runner.NewRunner(runConfig, d.statuses), nil -} - -func (d *cliManager) needSecretsPassword(secretOptions []string) bool { - // If the user did not ask for any secrets, then don't attempt to instantiate - // the secrets manager. - if len(secretOptions) == 0 { - return false - } - // Ignore err - if the flag is not set, it's not needed. - providerType, _ := d.configProvider.GetConfig().Secrets.GetProviderType() - return providerType == secrets.EncryptedType -} - -// cleanupTempPermissionProfile cleans up temporary permission profile files for a given base name -func (*cliManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error { - // Try to load the saved configuration to get the permission profile path - runConfig, err := runner.LoadState(ctx, baseName) - if err != nil { - // If we can't load the state, there's nothing to clean up - logger.Debugf("Could not load state for %s, skipping permission profile cleanup: %v", baseName, err) - return nil - } - - // Clean up the temporary permission profile if it exists - if runConfig.PermissionProfileNameOrPath != "" { - if err := runner.CleanupTempPermissionProfile(runConfig.PermissionProfileNameOrPath); err != nil { - return fmt.Errorf("failed to cleanup temporary permission profile: %v", err) - } - } - - return nil -} - -// stopSingleContainerWorkload stops a single container workload -func (d *cliManager) stopSingleContainerWorkload(ctx context.Context, workload *rt.ContainerInfo) error { - childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) - defer cancel() - - name := labels.GetContainerBaseName(workload.Labels) - // 
Stop the proxy process (skip for auxiliary workloads like inspector) - if labels.IsAuxiliaryWorkload(workload.Labels) { - logger.Debugf("Skipping proxy stop for auxiliary workload %s", name) - } else { - d.stopProcess(ctx, name) - } - - // TODO: refactor the StopProcess function to stop dealing explicitly with PID files. - // Note that this is not a blocker for k8s since this code path is not called there. - if err := d.statuses.ResetWorkloadPID(ctx, name); err != nil { - logger.Warnf("Warning: Failed to reset workload %s PID: %v", name, err) - } - - logger.Infof("Stopping containers for %s...", name) - // Stop the container - if err := d.runtime.StopWorkload(childCtx, workload.Name); err != nil { - if statusErr := d.statuses.SetWorkloadStatus(childCtx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) - } - return fmt.Errorf("failed to stop container: %w", err) - } - - if err := removeClientConfigurations(name, labels.IsAuxiliaryWorkload(workload.Labels)); err != nil { - logger.Warnf("Warning: Failed to remove client configurations: %v", err) - } else { - logger.Infof("Client configurations for %s removed", name) - } - - if err := d.statuses.SetWorkloadStatus(childCtx, name, rt.WorkloadStatusStopped, ""); err != nil { - logger.Warnf("Failed to set workload %s status to stopped: %v", name, err) - } - logger.Infof("Successfully stopped %s...", name) - return nil -} - -// MoveToGroup moves the specified workloads from one group to another by updating their runconfig. 
-func (*cliManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { - for _, name := range workloadNames { - // Validate workload name - if err := types.ValidateWorkloadName(name); err != nil { - return fmt.Errorf("invalid workload name %s: %w", name, err) - } - - // Load the runner state to check and update the configuration - runnerConfig, err := runner.LoadState(ctx, name) - if err != nil { - return fmt.Errorf("failed to load runner state for workload %s: %w", name, err) - } - - // Check if the workload is actually in the specified group - if runnerConfig.Group != groupFrom { - logger.Debugf("Workload %s is not in group %s (current group: %s), skipping", - name, groupFrom, runnerConfig.Group) - continue - } - - // Move the workload to the target group - runnerConfig.Group = groupTo - - // Save the updated configuration - if err = runnerConfig.SaveState(ctx); err != nil { - return fmt.Errorf("failed to save updated configuration for workload %s: %w", name, err) - } - - logger.Infof("Moved workload %s from group %s to %s", name, groupFrom, groupTo) - } - - return nil -} - -// ListWorkloadsInGroup returns all workload names that belong to the specified group -func (d *cliManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { - workloads, err := d.ListWorkloads(ctx, true) // listAll=true to include stopped workloads - if err != nil { - return nil, fmt.Errorf("failed to list workloads: %w", err) - } - - // Filter workloads that belong to the specified group - var groupWorkloads []string - for _, workload := range workloads { - if workload.Group == groupName { - groupWorkloads = append(groupWorkloads, workload.Name) - } - } - - return groupWorkloads, nil -} - -// getRemoteWorkloadsFromState retrieves remote servers from the state store -func (d *cliManager) getRemoteWorkloadsFromState( - ctx context.Context, - listAll bool, - labelFilters []string, -) ([]core.Workload, error) { - // 
Create a state store - store, err := state.NewRunConfigStore(state.DefaultAppName) - if err != nil { - return nil, fmt.Errorf("failed to create state store: %w", err) - } - - // List all configurations - configNames, err := store.List(ctx) - if err != nil { - return nil, fmt.Errorf("failed to list configurations: %w", err) - } - - // Parse the filters into a format we can use for matching - parsedFilters, err := types.ParseLabelFilters(labelFilters) - if err != nil { - return nil, fmt.Errorf("failed to parse label filters: %v", err) - } - - var remoteWorkloads []core.Workload - - for _, name := range configNames { - // Load the run configuration - runConfig, err := runner.LoadState(ctx, name) - if err != nil { - logger.Warnf("failed to load state for %s: %v", name, err) - continue - } - - // Only include remote servers (those with RemoteURL set) - if runConfig.RemoteURL == "" { - continue - } - - // Check the status from the status file - workloadStatus, err := d.statuses.GetWorkload(ctx, name) - if err != nil { - if errors.Is(err, rt.ErrWorkloadNotFound) { - // If status not found, assume stopped - workloadStatus = core.Workload{ - Status: rt.WorkloadStatusStopped, - } - } else { - logger.Warnf("failed to get workload status for %s: %v", name, err) - continue - } - } - - // If not listing all, only include running workloads - if !listAll && workloadStatus.Status != rt.WorkloadStatusRunning { - continue - } - - // Map to core.Workload - workload := core.Workload{ - Name: name, - Package: "remote", - URL: runConfig.RemoteURL, - ToolType: "remote", - TransportType: runConfig.Transport, - ProxyMode: runConfig.ProxyMode.String(), - Status: workloadStatus.Status, - StatusContext: workloadStatus.StatusContext, - CreatedAt: workloadStatus.CreatedAt, - Port: int(runConfig.Port), - Labels: runConfig.ContainerLabels, - Group: runConfig.Group, - ToolsFilter: runConfig.ToolsFilter, - Remote: true, - } - - // If label filters are provided, check if the workload matches them. 
- if types.MatchesLabelFilters(workload.Labels, parsedFilters) { - remoteWorkloads = append(remoteWorkloads, workload) - } - } - - return remoteWorkloads, nil -} diff --git a/pkg/workloads/cli_manager_test.go b/pkg/workloads/cli_manager_test.go deleted file mode 100644 index e00029d0e..000000000 --- a/pkg/workloads/cli_manager_test.go +++ /dev/null @@ -1,2213 +0,0 @@ -package workloads - -import ( - "context" - "errors" - "fmt" - "os" - "path/filepath" - "testing" - "time" - - "github.com/adrg/xdg" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "golang.org/x/sync/errgroup" - - "github.com/stacklok/toolhive/pkg/auth/remote" - "github.com/stacklok/toolhive/pkg/config" - configMocks "github.com/stacklok/toolhive/pkg/config/mocks" - "github.com/stacklok/toolhive/pkg/container/runtime" - runtimeMocks "github.com/stacklok/toolhive/pkg/container/runtime/mocks" - "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/runner" - statusMocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" -) - -func TestCLIManager_ListWorkloadsInGroup(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - groupName string - mockWorkloads []core.Workload - expectedNames []string - expectError bool - setupStatusMgr func(*statusMocks.MockStatusManager) - }{ - { - name: "non existent group returns empty list", - groupName: "non-group", - mockWorkloads: []core.Workload{ - {Name: "workload1", Group: "other-group"}, - {Name: "workload2", Group: "another-group"}, - }, - expectedNames: []string{}, - expectError: false, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ - {Name: "workload1", Group: "other-group"}, - {Name: "workload2", Group: "another-group"}, - }, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: 
runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - }, - { - name: "multiple workloads in group", - groupName: "test-group", - mockWorkloads: []core.Workload{ - {Name: "workload1", Group: "test-group"}, - {Name: "workload2", Group: "other-group"}, - {Name: "workload3", Group: "test-group"}, - {Name: "workload4", Group: "test-group"}, - }, - expectedNames: []string{"workload1", "workload3", "workload4"}, - expectError: false, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ - {Name: "workload1", Group: "test-group"}, - {Name: "workload2", Group: "other-group"}, - {Name: "workload3", Group: "test-group"}, - {Name: "workload4", Group: "test-group"}, - }, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - }, - { - name: "workloads with empty group names", - groupName: "", - mockWorkloads: []core.Workload{ - {Name: "workload1", Group: ""}, - {Name: "workload2", Group: "test-group"}, - {Name: "workload3", Group: ""}, - }, - expectedNames: []string{"workload1", "workload3"}, - expectError: false, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ - {Name: "workload1", Group: ""}, - {Name: "workload2", Group: "test-group"}, - {Name: "workload3", Group: ""}, - }, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - }, - { - name: "includes stopped workloads", - groupName: "test-group", - mockWorkloads: []core.Workload{ - {Name: "running-workload", Group: "test-group", Status: runtime.WorkloadStatusRunning}, - {Name: "stopped-workload", Group: "test-group", Status: runtime.WorkloadStatusStopped}, - {Name: "other-group-workload", 
Group: "other-group", Status: runtime.WorkloadStatusRunning}, - }, - expectedNames: []string{"running-workload", "stopped-workload"}, - expectError: false, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ - {Name: "running-workload", Group: "test-group", Status: runtime.WorkloadStatusRunning}, - {Name: "stopped-workload", Group: "test-group", Status: runtime.WorkloadStatusStopped}, - {Name: "other-group-workload", Group: "other-group", Status: runtime.WorkloadStatusRunning}, - }, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - }, - { - name: "error from ListWorkloads propagated", - groupName: "test-group", - expectedNames: nil, - expectError: true, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return(nil, assert.AnError) - }, - }, - { - name: "no workloads", - groupName: "test-group", - mockWorkloads: []core.Workload{}, - expectedNames: []string{}, - expectError: false, - setupStatusMgr: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{}, nil) - - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupStatusMgr(mockStatusMgr) - - manager := &cliManager{ - runtime: nil, // Not needed for this test - statuses: mockStatusMgr, - } - - ctx := context.Background() - result, err := manager.ListWorkloadsInGroup(ctx, tt.groupName) - - if tt.expectError { - 
require.Error(t, err) - assert.Contains(t, err.Error(), "failed to list workloads") - return - } - - require.NoError(t, err) - assert.ElementsMatch(t, tt.expectedNames, result) - }) - } -} - -func TestCLIManager_DoesWorkloadExist(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - setupMocks func(*statusMocks.MockStatusManager) - expected bool - expectError bool - }{ - { - name: "workload exists and running", - workloadName: "test-workload", - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - }, - expected: true, - expectError: false, - }, - { - name: "workload exists but in error state", - workloadName: "error-workload", - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "error-workload").Return(core.Workload{ - Name: "error-workload", - Status: runtime.WorkloadStatusError, - }, nil) - }, - expected: false, - expectError: false, - }, - { - name: "workload not found", - workloadName: "missing-workload", - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "missing-workload").Return(core.Workload{}, runtime.ErrWorkloadNotFound) - }, - expected: false, - expectError: false, - }, - { - name: "error getting workload", - workloadName: "problematic-workload", - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "problematic-workload").Return(core.Workload{}, errors.New("database error")) - }, - expected: false, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockStatusMgr) - - manager := &cliManager{ - statuses: mockStatusMgr, - } - - ctx := 
context.Background() - result, err := manager.DoesWorkloadExist(ctx, tt.workloadName) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to check if workload exists") - } else { - require.NoError(t, err) - assert.Equal(t, tt.expected, result) - } - }) - } -} - -func TestCLIManager_GetWorkload(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - expectedWorkload := core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - } - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - mockStatusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(expectedWorkload, nil) - - manager := &cliManager{ - statuses: mockStatusMgr, - } - - ctx := context.Background() - result, err := manager.GetWorkload(ctx, "test-workload") - - require.NoError(t, err) - assert.Equal(t, expectedWorkload, result) -} - -func TestCLIManager_GetLogs(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - follow bool - setupMocks func(*runtimeMocks.MockRuntime) - expectedLogs string - expectError bool - errorMsg string - }{ - { - name: "successful log retrieval", - workloadName: "test-workload", - follow: false, - setupMocks: func(rt *runtimeMocks.MockRuntime) { - rt.EXPECT().GetWorkloadLogs(gomock.Any(), "test-workload", false).Return("test log content", nil) - }, - expectedLogs: "test log content", - expectError: false, - }, - { - name: "workload not found", - workloadName: "missing-workload", - follow: false, - setupMocks: func(rt *runtimeMocks.MockRuntime) { - rt.EXPECT().GetWorkloadLogs(gomock.Any(), "missing-workload", false).Return("", runtime.ErrWorkloadNotFound) - }, - expectedLogs: "", - expectError: true, - errorMsg: "workload not found", - }, - { - name: "runtime error", - workloadName: "error-workload", - follow: true, - setupMocks: func(rt *runtimeMocks.MockRuntime) { - rt.EXPECT().GetWorkloadLogs(gomock.Any(), "error-workload", 
true).Return("", errors.New("runtime failure")) - }, - expectedLogs: "", - expectError: true, - errorMsg: "failed to get container logs", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - tt.setupMocks(mockRuntime) - - manager := &cliManager{ - runtime: mockRuntime, - } - - ctx := context.Background() - logs, err := manager.GetLogs(ctx, tt.workloadName, tt.follow) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - } else { - require.NoError(t, err) - assert.Equal(t, tt.expectedLogs, logs) - } - }) - } -} - -func TestCLIManager_StopWorkloads(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadNames []string - expectError bool - errorMsg string - }{ - { - name: "invalid workload name with path traversal", - workloadNames: []string{"../etc/passwd"}, - expectError: true, - errorMsg: "path traversal", - }, - { - name: "invalid workload name with slash", - workloadNames: []string{"workload/name"}, - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "empty workload name list", - workloadNames: []string{}, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - manager := &cliManager{} - - ctx := context.Background() - group, err := manager.StopWorkloads(ctx, tt.workloadNames) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - assert.Nil(t, group) - } else { - require.NoError(t, err) - assert.NotNil(t, group) - assert.IsType(t, &errgroup.Group{}, group) - } - }) - } -} - -func TestCLIManager_DeleteWorkloads(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadNames []string - expectError bool - errorMsg string - }{ - { - name: "invalid workload name", - workloadNames: 
[]string{"../../../etc/passwd"}, - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "mixed valid and invalid names", - workloadNames: []string{"valid-name", "invalid../name"}, - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "empty workload name list", - workloadNames: []string{}, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - manager := &cliManager{} - - ctx := context.Background() - group, err := manager.DeleteWorkloads(ctx, tt.workloadNames) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - assert.Nil(t, group) - } else { - require.NoError(t, err) - assert.NotNil(t, group) - assert.IsType(t, &errgroup.Group{}, group) - } - }) - } -} - -func TestCLIManager_RestartWorkloads(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadNames []string - foreground bool - expectError bool - errorMsg string - }{ - { - name: "invalid workload name", - workloadNames: []string{"invalid/name"}, - foreground: false, - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "empty workload name list", - workloadNames: []string{}, - foreground: false, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - manager := &cliManager{} - - ctx := context.Background() - group, err := manager.RestartWorkloads(ctx, tt.workloadNames, tt.foreground) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - assert.Nil(t, group) - } else { - require.NoError(t, err) - assert.NotNil(t, group) - assert.IsType(t, &errgroup.Group{}, group) - } - }) - } -} - -func TestCLIManager_restartRemoteWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - runConfig *runner.RunConfig - foreground bool - setupMocks func(*statusMocks.MockStatusManager) - 
expectError bool - errorMsg string - }{ - { - name: "remote workload already running with healthy supervisor", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-base", - RemoteURL: "http://example.com", - }, - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - // Check if supervisor is alive - return valid PID (supervisor is healthy) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(12345, nil) - }, - // With healthy supervisor, restart should return early (no-op) - expectError: false, - }, - { - name: "remote workload already running with dead supervisor", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-base", - RemoteURL: "http://example.com", - }, - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - // Check if supervisor is alive - return error (supervisor is dead) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(0, errors.New("no PID found")) - // With dead supervisor, restart proceeds with cleanup and restart - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(nil) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(0, errors.New("no PID found")) - // Allow any subsequent status updates - sm.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - sm.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - }, - // Restart now proceeds to load state which fails in tests (can't mock runner.LoadState easily) - expectError: true, - errorMsg: 
"failed to load state", - }, - { - name: "status manager error", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-base", - RemoteURL: "http://example.com", - }, - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{}, errors.New("status manager error")) - }, - expectError: true, - errorMsg: "status manager error", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(statusMgr) - - manager := &cliManager{ - statuses: statusMgr, - } - - err := manager.restartRemoteWorkload(context.Background(), tt.workloadName, tt.runConfig, tt.foreground) - - if tt.expectError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -func TestCLIManager_restartContainerWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - foreground bool - setupMocks func(*statusMocks.MockStatusManager, *runtimeMocks.MockRuntime) - expectError bool - errorMsg string - }{ - { - name: "container workload already running with healthy supervisor", - workloadName: "container-workload", - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { - // Mock container info - rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ - Name: "container-workload", - State: runtime.WorkloadStatusRunning, - Labels: map[string]string{ - "toolhive.base-name": "container-workload", - }, - }, nil) - sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{ - Name: "container-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - // Check if supervisor 
is alive - return valid PID (supervisor is healthy) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(12345, nil) - }, - // With healthy supervisor, restart should return early (no-op) - expectError: false, - }, - { - name: "container workload already running with dead supervisor", - workloadName: "container-workload", - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { - // Mock container info - rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ - Name: "container-workload", - State: runtime.WorkloadStatusRunning, - Labels: map[string]string{ - "toolhive.base-name": "container-workload", - }, - }, nil) - sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{ - Name: "container-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - // Check if supervisor is alive - return error (supervisor is dead) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(0, errors.New("no PID found")) - // With dead supervisor, restart proceeds with cleanup and restart - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "container-workload", runtime.WorkloadStatusStopping, "").Return(nil) - sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(0, errors.New("no PID found")) - rm.EXPECT().StopWorkload(gomock.Any(), "container-workload").Return(nil) - // Allow any subsequent status updates - sm.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - sm.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - }, - // Restart now proceeds to load state which fails in tests (can't mock runner.LoadState easily) - expectError: true, - errorMsg: "failed to load state", - }, - { - name: "status manager error", - workloadName: "container-workload", - foreground: false, - setupMocks: func(sm *statusMocks.MockStatusManager, rm 
*runtimeMocks.MockRuntime) { - // Mock container info - rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ - Name: "container-workload", - State: "running", - Labels: map[string]string{ - "toolhive.base-name": "container-workload", - }, - }, nil) - sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{}, errors.New("status manager error")) - }, - expectError: true, - errorMsg: "status manager error", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) - - tt.setupMocks(statusMgr, runtimeMgr) - - manager := &cliManager{ - statuses: statusMgr, - runtime: runtimeMgr, - } - - err := manager.restartContainerWorkload(context.Background(), tt.workloadName, tt.foreground) - - if tt.expectError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -// TestCLIManager_restartLogicConsistency tests restart behavior with healthy vs dead supervisor -func TestCLIManager_restartLogicConsistency(t *testing.T) { - t.Parallel() - - t.Run("remote_workload_healthy_supervisor_no_restart", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - - statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - - // Check if supervisor is alive - return valid PID (healthy) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(12345, nil) - - manager := &cliManager{ - statuses: statusMgr, - } - - runConfig := &runner.RunConfig{ - BaseName: "test-base", - RemoteURL: "http://example.com", - } - - err := 
manager.restartRemoteWorkload(context.Background(), "test-workload", runConfig, false) - - // With healthy supervisor, restart should return successfully without doing anything - require.NoError(t, err) - }) - - t.Run("remote_workload_dead_supervisor_calls_stop", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - - statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - - // Check if supervisor is alive - return error (dead supervisor) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(0, errors.New("no PID found")) - - // When supervisor is dead, expect stop logic to be called - statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(0, errors.New("no PID found")) - - // Allow any subsequent status updates - we don't care about the exact sequence - statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - - manager := &cliManager{ - statuses: statusMgr, - } - - runConfig := &runner.RunConfig{ - BaseName: "test-base", - RemoteURL: "http://example.com", - } - - _ = manager.restartRemoteWorkload(context.Background(), "test-workload", runConfig, false) - - // The important part is that the stop methods were called (verified by mock expectations) - // We don't care if the restart ultimately succeeds or fails - }) - - t.Run("container_workload_healthy_supervisor_no_restart", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - runtimeMgr := 
runtimeMocks.NewMockRuntime(ctrl) - - containerInfo := runtime.ContainerInfo{ - Name: "test-workload", - State: runtime.WorkloadStatusRunning, - Labels: map[string]string{ - "toolhive.base-name": "test-workload", - }, - } - runtimeMgr.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(containerInfo, nil) - - statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - - // Check if supervisor is alive - return valid PID (healthy) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(12345, nil) - - manager := &cliManager{ - statuses: statusMgr, - runtime: runtimeMgr, - } - - err := manager.restartContainerWorkload(context.Background(), "test-workload", false) - - // With healthy supervisor, restart should return successfully without doing anything - require.NoError(t, err) - }) - - t.Run("container_workload_dead_supervisor_calls_stop", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - statusMgr := statusMocks.NewMockStatusManager(ctrl) - runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) - - containerInfo := runtime.ContainerInfo{ - Name: "test-workload", - State: runtime.WorkloadStatusRunning, - Labels: map[string]string{ - "toolhive.base-name": "test-workload", - }, - } - runtimeMgr.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(containerInfo, nil) - - statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ - Name: "test-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - - // Check if supervisor is alive - return error (dead supervisor) - statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(0, errors.New("no PID found")) - - // When supervisor is dead, expect stop logic to be called - statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) - 
statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(0, errors.New("no PID found")) - runtimeMgr.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) - - // Allow any subsequent status updates (starting, error, etc.) - we don't care about the exact sequence - statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - - manager := &cliManager{ - statuses: statusMgr, - runtime: runtimeMgr, - } - - _ = manager.restartContainerWorkload(context.Background(), "test-workload", false) - - // The important part is that the stop methods were called (verified by mock expectations) - // We don't care if the restart ultimately succeeds or fails - }) -} - -func TestCLIManager_RunWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - runConfig *runner.RunConfig - setupMocks func(*statusMocks.MockStatusManager) - expectError bool - errorMsg string - }{ - { - name: "successful run - status creation", - runConfig: &runner.RunConfig{ - BaseName: "test-workload", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - // Expect starting status first, then error status when the runner fails - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStarting, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, gomock.Any()).Return(nil) - }, - expectError: true, // The runner will fail without proper setup - }, - { - name: "status creation failure", - runConfig: &runner.RunConfig{ - BaseName: "failing-workload", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", runtime.WorkloadStatusStarting, "").Return(errors.New("status creation failed")) - }, - expectError: true, - errorMsg: "failed to create workload status", 
- }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockStatusMgr) - - manager := &cliManager{ - statuses: mockStatusMgr, - } - - ctx := context.Background() - err := manager.RunWorkload(ctx, tt.runConfig) - - if tt.expectError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -func TestCLIManager_validateSecretParameters(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - runConfig *runner.RunConfig - setupMocks func(*configMocks.MockProvider) - expectError bool - errorMsg string - }{ - { - name: "no secrets - should pass", - runConfig: &runner.RunConfig{ - Secrets: []string{}, - }, - setupMocks: func(*configMocks.MockProvider) {}, // No expectations - expectError: false, - }, - { - name: "no secrets and no remote auth - should pass", - runConfig: &runner.RunConfig{ - Secrets: []string{}, - RemoteAuthConfig: nil, - }, - setupMocks: func(*configMocks.MockProvider) {}, // No expectations - expectError: false, - }, - { - name: "remote auth without client secret - should pass", - runConfig: &runner.RunConfig{ - Secrets: []string{}, - RemoteAuthConfig: &remote.Config{ - ClientSecret: "", // Empty client secret - }, - }, - setupMocks: func(*configMocks.MockProvider) {}, // No expectations - expectError: false, - }, - { - name: "config error", - runConfig: &runner.RunConfig{ - Secrets: []string{"secret1"}, - }, - setupMocks: func(cp *configMocks.MockProvider) { - mockConfig := &config.Config{} - cp.EXPECT().GetConfig().Return(mockConfig) - }, - expectError: true, - errorMsg: "error determining secrets provider type", - }, - { - name: "remote auth with client secret", - runConfig: &runner.RunConfig{ - Secrets: []string{}, - RemoteAuthConfig: &remote.Config{ - ClientSecret: 
"secret-value", - }, - }, - setupMocks: func(cp *configMocks.MockProvider) { - mockConfig := &config.Config{} - cp.EXPECT().GetConfig().Return(mockConfig) - }, - expectError: true, - errorMsg: "error determining secrets provider type", - }, - { - name: "both regular secrets and remote auth", - runConfig: &runner.RunConfig{ - Secrets: []string{"secret1"}, - RemoteAuthConfig: &remote.Config{ - ClientSecret: "secret-value", - }, - }, - setupMocks: func(cp *configMocks.MockProvider) { - mockConfig := &config.Config{} - cp.EXPECT().GetConfig().Return(mockConfig) - }, - expectError: true, - errorMsg: "error determining secrets provider type", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockConfigProvider := configMocks.NewMockProvider(ctrl) - tt.setupMocks(mockConfigProvider) - - manager := &cliManager{ - configProvider: mockConfigProvider, - } - - ctx := context.Background() - err := manager.validateSecretParameters(ctx, tt.runConfig) - - if tt.expectError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -func TestCLIManager_getWorkloadContainer(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) - expected *runtime.ContainerInfo - expectError bool - errorMsg string - }{ - { - name: "successful retrieval", - workloadName: "test-workload", - setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { - expectedContainer := runtime.ContainerInfo{ - Name: "test-workload", - State: runtime.WorkloadStatusRunning, - } - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(expectedContainer, nil) - }, - expected: &runtime.ContainerInfo{ - Name: "test-workload", - State: runtime.WorkloadStatusRunning, - }, - 
expectError: false, - }, - { - name: "workload not found", - workloadName: "missing-workload", - setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "missing-workload").Return(runtime.ContainerInfo{}, runtime.ErrWorkloadNotFound) - }, - expected: nil, - expectError: false, // getWorkloadContainer returns nil for not found, not error - }, - { - name: "runtime error", - workloadName: "error-workload", - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "error-workload").Return(runtime.ContainerInfo{}, errors.New("runtime failure")) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "error-workload", runtime.WorkloadStatusError, "runtime failure").Return(nil) - }, - expected: nil, - expectError: true, - errorMsg: "failed to find workload", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockRuntime, mockStatusMgr) - - manager := &cliManager{ - runtime: mockRuntime, - statuses: mockStatusMgr, - } - - ctx := context.Background() - result, err := manager.getWorkloadContainer(ctx, tt.workloadName) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - } else { - require.NoError(t, err) - if tt.expected == nil { - assert.Nil(t, result) - } else { - assert.Equal(t, tt.expected, result) - } - } - }) - } -} - -func TestCLIManager_removeContainer(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) - expectError bool - errorMsg string - }{ - { - name: "successful removal", - workloadName: "test-workload", - setupMocks: func(rt 
*runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { - rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(nil) - }, - expectError: false, - }, - { - name: "removal failure", - workloadName: "failing-workload", - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - rt.EXPECT().RemoveWorkload(gomock.Any(), "failing-workload").Return(errors.New("removal failed")) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", runtime.WorkloadStatusError, "removal failed").Return(nil) - }, - expectError: true, - errorMsg: "failed to remove container", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockRuntime, mockStatusMgr) - - manager := &cliManager{ - runtime: mockRuntime, - statuses: mockStatusMgr, - } - - ctx := context.Background() - err := manager.removeContainer(ctx, tt.workloadName) - - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestCLIManager_needSecretsPassword(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - secretOptions []string - setupMocks func(*configMocks.MockProvider) - expected bool - }{ - { - name: "no secrets", - secretOptions: []string{}, - setupMocks: func(*configMocks.MockProvider) {}, // No expectations - expected: false, - }, - { - name: "has secrets but config access fails", - secretOptions: []string{"secret1"}, - setupMocks: func(cp *configMocks.MockProvider) { - mockConfig := &config.Config{} - cp.EXPECT().GetConfig().Return(mockConfig) - }, - expected: false, // Returns false when provider type detection fails - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := 
gomock.NewController(t) - defer ctrl.Finish() - - mockConfigProvider := configMocks.NewMockProvider(ctrl) - tt.setupMocks(mockConfigProvider) - - manager := &cliManager{ - configProvider: mockConfigProvider, - } - - result := manager.needSecretsPassword(tt.secretOptions) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestCLIManager_RunWorkloadDetached(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - runConfig *runner.RunConfig - setupMocks func(*statusMocks.MockStatusManager, *configMocks.MockProvider) - expectError bool - errorMsg string - }{ - { - name: "validation failure should not reach PID management", - runConfig: &runner.RunConfig{ - BaseName: "test-workload", - Secrets: []string{"invalid-secret"}, - }, - setupMocks: func(_ *statusMocks.MockStatusManager, cp *configMocks.MockProvider) { - // Mock config provider to cause validation failure - mockConfig := &config.Config{} - cp.EXPECT().GetConfig().Return(mockConfig) - // No SetWorkloadPID expectation since validation should fail first - }, - expectError: true, - errorMsg: "failed to validate workload parameters", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - mockConfigProvider := configMocks.NewMockProvider(ctrl) - tt.setupMocks(mockStatusMgr, mockConfigProvider) - - manager := &cliManager{ - statuses: mockStatusMgr, - configProvider: mockConfigProvider, - } - - ctx := context.Background() - err := manager.RunWorkloadDetached(ctx, tt.runConfig) - - if tt.expectError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - require.NoError(t, err) - } - }) - } -} - -// TestCLIManager_RunWorkloadDetached_PIDManagement tests that PID management -// happens in the later stages of RunWorkloadDetached when the process actually starts. 
-// This is tested indirectly by verifying the behavior exists in the code flow. -func TestCLIManager_RunWorkloadDetached_PIDManagement(t *testing.T) { - t.Parallel() - - // This test documents the expected behavior: - // 1. RunWorkloadDetached calls SetWorkloadPID after starting the detached process - // 2. The PID management happens after validation and process creation - // 3. SetWorkloadPID failures are logged as warnings but don't fail the operation - - // Since RunWorkloadDetached involves spawning actual processes and complex setup, - // we verify the PID management integration exists by checking the method signature - // and code structure rather than running the full integration. - - manager := &cliManager{} - assert.NotNil(t, manager, "cliManager should be instantiable") - - // Verify the method exists with the correct signature - var runWorkloadDetachedFunc interface{} = manager.RunWorkloadDetached - assert.NotNil(t, runWorkloadDetachedFunc, "RunWorkloadDetached method should exist") -} - -func TestAsyncOperationTimeout(t *testing.T) { - t.Parallel() - - // Test that the timeout constant is properly defined - assert.Equal(t, 5*time.Minute, AsyncOperationTimeout) -} - -func TestErrWorkloadNotRunning(t *testing.T) { - t.Parallel() - - // Test that the error is properly defined - assert.Error(t, ErrWorkloadNotRunning) - assert.Contains(t, ErrWorkloadNotRunning.Error(), "workload not running") -} - -func TestCLIManager_ListWorkloads(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - listAll bool - labelFilters []string - setupMocks func(*statusMocks.MockStatusManager) - expected []core.Workload - expectError bool - errorMsg string - }{ - { - name: "successful listing without filters", - listAll: true, - labelFilters: []string{}, - setupMocks: func(sm *statusMocks.MockStatusManager) { - workloads := []core.Workload{ - {Name: "workload1", Status: runtime.WorkloadStatusRunning}, - {Name: "workload2", Status: runtime.WorkloadStatusStopped}, - 
} - sm.EXPECT().ListWorkloads(gomock.Any(), true, []string{}).Return(workloads, nil) - sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil).AnyTimes() - }, - expected: []core.Workload{ - {Name: "workload1", Status: runtime.WorkloadStatusRunning}, - {Name: "workload2", Status: runtime.WorkloadStatusStopped}, - }, - expectError: false, - }, - { - name: "error from status manager", - listAll: false, - labelFilters: []string{"env=prod"}, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().ListWorkloads(gomock.Any(), false, []string{"env=prod"}).Return(nil, errors.New("database error")) - }, - expected: nil, - expectError: true, - errorMsg: "database error", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockStatusMgr) - - manager := &cliManager{ - statuses: mockStatusMgr, - } - - ctx := context.Background() - result, err := manager.ListWorkloads(ctx, tt.listAll, tt.labelFilters...) 
- - if tt.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errorMsg) - } else { - // We expect this to succeed but might include remote workloads - // Since getRemoteWorkloadsFromState will likely fail in unit tests, - // we mainly verify the container workloads are returned - require.NoError(t, err) - assert.GreaterOrEqual(t, len(result), len(tt.expected)) - // Verify at least our expected container workloads are present - for _, expectedWorkload := range tt.expected { - found := false - for _, actualWorkload := range result { - if actualWorkload.Name == expectedWorkload.Name { - found = true - break - } - } - assert.True(t, found, fmt.Sprintf("Expected workload %s not found in result", expectedWorkload.Name)) - } - } - }) - } -} - -func TestCLIManager_UpdateWorkload(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - expectError bool - errorMsg string - setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) - }{ - { - name: "invalid workload name with slash", - workloadName: "invalid/name", - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "invalid workload name with backslash", - workloadName: "invalid\\name", - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "invalid workload name with path traversal", - workloadName: "../invalid", - expectError: true, - errorMsg: "invalid workload name", - }, - { - name: "valid workload name returns errgroup immediately", - workloadName: "valid-workload", - expectError: false, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // Mock calls that will happen in the background goroutine - // We don't care about the success/failure, just that it doesn't panic - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "valid-workload"). 
- Return(runtime.ContainerInfo{}, errors.New("not found")).AnyTimes() - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "valid-workload", gomock.Any(), gomock.Any()). - Return(nil).AnyTimes() - }, - }, - { - name: "UpdateWorkload returns errgroup even if async operation will fail", - workloadName: "failing-workload", - expectError: false, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // The async operation will fail, but UpdateWorkload itself should succeed - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "failing-workload"). - Return(runtime.ContainerInfo{}, errors.New("container lookup failed")).AnyTimes() - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", gomock.Any(), gomock.Any()). - Return(nil).AnyTimes() - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - mockStatusManager := statusMocks.NewMockStatusManager(ctrl) - mockConfigProvider := configMocks.NewMockProvider(ctrl) - - if tt.setupMocks != nil { - tt.setupMocks(mockRuntime, mockStatusManager) - } - - manager := &cliManager{ - runtime: mockRuntime, - statuses: mockStatusManager, - configProvider: mockConfigProvider, - } - - // Create a dummy RunConfig for testing - runConfig := &runner.RunConfig{ - ContainerName: tt.workloadName, - BaseName: tt.workloadName, - } - - ctx := context.Background() - group, err := manager.UpdateWorkload(ctx, tt.workloadName, runConfig) - - if tt.expectError { - assert.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - assert.Nil(t, group) - } else { - assert.NoError(t, err) - assert.NotNil(t, group) - // For valid cases, we get an errgroup but don't wait for completion - // The async operations inside are tested separately - } - }) - } -} - -func TestCLIManager_updateSingleWorkload(t *testing.T) { - t.Parallel() - - tests := 
[]struct { - name string - workloadName string - runConfig *runner.RunConfig - setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) - expectError bool - errorMsg string - }{ - { - name: "stop operation fails", - workloadName: "test-workload", - runConfig: &runner.RunConfig{ - ContainerName: "test-workload", - BaseName: "test-workload", - Group: "default", - }, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // Mock the stop operation - return error for GetWorkloadInfo - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). - Return(runtime.ContainerInfo{}, errors.New("container lookup failed")).AnyTimes() - // Still expect status updates to be attempted - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil).AnyTimes() - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, "").Return(nil).AnyTimes() - }, - expectError: true, - errorMsg: "failed to stop workload", - }, - { - name: "successful stop and delete operations complete correctly", - workloadName: "test-workload", - runConfig: &runner.RunConfig{ - ContainerName: "test-workload", - BaseName: "test-workload", - Group: "default", - }, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // Mock stop operation - workload exists and can be stopped - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
- Return(runtime.ContainerInfo{ - Name: "test-workload", - State: "running", - Labels: map[string]string{"toolhive-basename": "test-workload"}, - }, nil) - // Mock GetWorkloadPID call from stopProcess - sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil) - rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) - sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil) - - // Mock delete operation - workload exists and can be deleted - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). - Return(runtime.ContainerInfo{Name: "test-workload"}, nil) - rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(nil) - - // Mock status updates for stop and delete phases - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopped, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusRemoving, "").Return(nil) - sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), "test-workload").Return(nil) - - // Mock RunWorkloadDetached calls - expect the ones that will be called - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStarting, "").Return(nil) - sm.EXPECT().SetWorkloadPID(gomock.Any(), "test-workload", gomock.Any()).Return(nil) - }, - expectError: false, // Test passes - update process completes successfully - }, - { - name: "delete operation fails after successful stop", - workloadName: "test-workload", - runConfig: &runner.RunConfig{ - ContainerName: "test-workload", - BaseName: "test-workload", - Group: "default", - }, - setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { - // Mock successful stop - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
- Return(runtime.ContainerInfo{ - Name: "test-workload", - State: "running", - Labels: map[string]string{"toolhive-basename": "test-workload"}, - }, nil) - // Mock GetWorkloadPID call from stopProcess - sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil) - rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) - sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil) - - // Mock failed delete - rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). - Return(runtime.ContainerInfo{Name: "test-workload"}, nil) - rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(errors.New("delete failed")) - - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopped, "").Return(nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusRemoving, "").Return(nil) - // RemoveWorkload fails, so error status is set - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, "delete failed").Return(nil) - }, - expectError: true, - errorMsg: "failed to delete workload", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockRuntime := runtimeMocks.NewMockRuntime(ctrl) - mockStatusManager := statusMocks.NewMockStatusManager(ctrl) - mockConfigProvider := configMocks.NewMockProvider(ctrl) - - if tt.setupMocks != nil { - tt.setupMocks(mockRuntime, mockStatusManager) - } - - manager := &cliManager{ - runtime: mockRuntime, - statuses: mockStatusManager, - configProvider: mockConfigProvider, - } - - err := manager.updateSingleWorkload(tt.workloadName, tt.runConfig) - - if tt.expectError { - assert.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - assert.NoError(t, err) 
- } - }) - } -} - -func TestNewCLIManager(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - wantError bool - }{ - { - name: "successful creation", - wantError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - manager, err := NewCLIManager(ctx) - - // Note: This test may fail if Docker/Podman is not available - // That's acceptable - the function will return an error - if tt.wantError { - require.Error(t, err) - assert.Nil(t, manager) - } else { - // If runtime is available, manager should be created - if err == nil { - require.NotNil(t, manager) - cliMgr, ok := manager.(*cliManager) - require.True(t, ok) - assert.NotNil(t, cliMgr.runtime) - assert.NotNil(t, cliMgr.statuses) - assert.NotNil(t, cliMgr.configProvider) - } - // If runtime is not available, error is acceptable - } - }) - } -} - -func TestNewCLIManagerWithProvider(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - mockConfigProvider := configMocks.NewMockProvider(ctrl) - - tests := []struct { - name string - configProvider config.Provider - wantError bool - }{ - { - name: "successful creation with provider", - configProvider: mockConfigProvider, - wantError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - manager, err := NewCLIManagerWithProvider(ctx, tt.configProvider) - - // Note: This test may fail if Docker/Podman is not available - if tt.wantError { - require.Error(t, err) - assert.Nil(t, manager) - } else { - // If runtime is available, manager should be created - if err == nil { - require.NotNil(t, manager) - cliMgr, ok := manager.(*cliManager) - require.True(t, ok) - assert.NotNil(t, cliMgr.runtime) - assert.NotNil(t, cliMgr.statuses) - assert.Equal(t, tt.configProvider, cliMgr.configProvider) - } - } - }) - } -} - -func TestCLIManager_stopRemoteWorkload(t 
*testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - tests := []struct { - name string - workloadName string - runConfig *runner.RunConfig - setupMocks func(*statusMocks.MockStatusManager) - wantError bool - errorMsg string - }{ - { - name: "successful stop", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-workload", - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(nil) - // stopProxyIfNeeded calls stopProcess which calls GetWorkloadPID - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-workload").Return(0, errors.New("no PID found")).AnyTimes() - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopped, "").Return(nil) - }, - wantError: false, - }, - { - name: "workload not found", - workloadName: "non-existent", - runConfig: &runner.RunConfig{ - BaseName: "non-existent", - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "non-existent").Return(core.Workload{}, runtime.ErrWorkloadNotFound) - }, - wantError: false, // Returns nil when workload not found - }, - { - name: "workload not running", - workloadName: "stopped-workload", - runConfig: &runner.RunConfig{ - BaseName: "stopped-workload", - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(core.Workload{ - Name: "stopped-workload", - Status: runtime.WorkloadStatusStopped, - }, nil) - }, - wantError: false, // Returns nil when workload not running - }, - { - name: "error getting 
workload", - workloadName: "error-workload", - runConfig: &runner.RunConfig{ - BaseName: "error-workload", - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "error-workload").Return(core.Workload{}, errors.New("database error")) - }, - wantError: true, - errorMsg: "failed to find workload", - }, - { - name: "error setting stopping status", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-workload", - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(errors.New("status error")) - // stopProxyIfNeeded calls stopProcess which calls GetWorkloadPID - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-workload").Return(0, errors.New("no PID found")).AnyTimes() - // Still sets stopped status even if stopping status fails - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopped, "").Return(nil) - }, - wantError: false, // Continues even if stopping status fails - }, - { - name: "empty base name", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "", // Empty base name - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ - Name: "remote-workload", - Status: runtime.WorkloadStatusRunning, - }, nil) - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(nil) - // stopProxyIfNeeded is called but does nothing if BaseName is empty - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", 
runtime.WorkloadStatusStopped, "").Return(nil) - }, - wantError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - mockStatusManager := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockStatusManager) - - manager := &cliManager{ - statuses: mockStatusManager, - } - - ctx := context.Background() - err := manager.stopRemoteWorkload(ctx, tt.workloadName, tt.runConfig) - - if tt.wantError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestCLIManager_GetProxyLogs(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadName string - setup func() (string, func()) // Returns log file path and cleanup function - wantError bool - errorMsg string - checkOutput bool - }{ - { - name: "successful read", - workloadName: "test-workload", - setup: func() (string, func()) { - // Create a temporary directory for XDG_DATA_HOME - tmpDir := t.TempDir() - - // Set XDG_DATA_HOME BEFORE calling xdg.DataFile - // xdg package reads this at initialization, so we need to set it early - oldXDG := os.Getenv("XDG_DATA_HOME") - os.Setenv("XDG_DATA_HOME", tmpDir) - - // Now get the actual path that xdg.DataFile will use - // This ensures we create the file at the exact location xdg will look - expectedLogPath, err := xdg.DataFile("toolhive/logs/test-workload.log") - require.NoError(t, err, "xdg.DataFile should succeed with XDG_DATA_HOME set") - - // Create the directory structure and file at the exact path xdg will use - require.NoError(t, os.MkdirAll(filepath.Dir(expectedLogPath), 0755)) - require.NoError(t, os.WriteFile(expectedLogPath, []byte("test log content"), 0600)) - - return expectedLogPath, func() { - if oldXDG == "" { - os.Unsetenv("XDG_DATA_HOME") - } else { - os.Setenv("XDG_DATA_HOME", oldXDG) - } - } - }, - wantError: false, - checkOutput: true, - }, - { - name: "log file not found", - 
workloadName: "non-existent", - setup: func() (string, func()) { - tmpDir := t.TempDir() - logDir := filepath.Join(tmpDir, "toolhive", "logs") - require.NoError(t, os.MkdirAll(logDir, 0755)) - // Don't create the log file - - oldXDG := os.Getenv("XDG_DATA_HOME") - os.Setenv("XDG_DATA_HOME", tmpDir) - - return "", func() { - if oldXDG == "" { - os.Unsetenv("XDG_DATA_HOME") - } else { - os.Setenv("XDG_DATA_HOME", oldXDG) - } - } - }, - wantError: true, - errorMsg: "proxy logs not found", - }, - { - name: "invalid workload name with path traversal", - workloadName: "../../etc/passwd", - setup: func() (string, func()) { - return "", func() {} - }, - wantError: true, - // xdg.DataFile may succeed even with path traversal, but file won't exist - // So we get "proxy logs not found" instead of "failed to get proxy log file path" - errorMsg: "proxy logs not found", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - _, cleanup := tt.setup() - defer cleanup() - - manager := &cliManager{} - - ctx := context.Background() - logs, err := manager.GetProxyLogs(ctx, tt.workloadName) - - if tt.wantError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - assert.Empty(t, logs) - } else { - require.NoError(t, err) - if tt.checkOutput { - assert.Contains(t, logs, "test log content") - } - } - }) - } -} - -func TestCLIManager_deleteRemoteWorkload(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - tests := []struct { - name string - workloadName string - runConfig *runner.RunConfig - setupMocks func(*statusMocks.MockStatusManager) - wantError bool - errorMsg string - }{ - { - name: "successful delete", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-workload", - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().SetWorkloadStatus(gomock.Any(), 
"remote-workload", runtime.WorkloadStatusRemoving, "").Return(nil) - // stopProxyIfNeeded calls stopProcess which calls GetWorkloadPID - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-workload").Return(0, errors.New("no PID found")).AnyTimes() - sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), "remote-workload").Return(nil) - }, - wantError: false, - }, - { - name: "error setting removing status", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-workload", - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusRemoving, "").Return(errors.New("status error")) - }, - wantError: true, - errorMsg: "status error", // The function returns the error directly - }, - { - name: "error deleting status", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "remote-workload", - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusRemoving, "").Return(nil) - // stopProxyIfNeeded calls stopProcess which calls GetWorkloadPID - sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-workload").Return(0, errors.New("no PID found")).AnyTimes() - sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), "remote-workload").Return(errors.New("delete error")) - // Error is logged but not returned - }, - wantError: false, // Error is logged but function continues - }, - { - name: "empty base name", - workloadName: "remote-workload", - runConfig: &runner.RunConfig{ - BaseName: "", // Empty base name - RemoteURL: "https://example.com/mcp", - }, - setupMocks: func(sm *statusMocks.MockStatusManager) { - sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusRemoving, "").Return(nil) - // stopProxyIfNeeded does nothing if BaseName is empty - 
sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), "remote-workload").Return(nil) - }, - wantError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - mockStatusManager := statusMocks.NewMockStatusManager(ctrl) - tt.setupMocks(mockStatusManager) - - manager := &cliManager{ - statuses: mockStatusManager, - } - - ctx := context.Background() - err := manager.deleteRemoteWorkload(ctx, tt.workloadName, tt.runConfig) - - if tt.wantError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestCLIManager_cleanupTempPermissionProfile(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - baseName string - setupState func() (func(), error) // Returns cleanup function and error - wantError bool - errorMsg string - }{ - { - name: "no state file", - baseName: "non-existent", - setupState: func() (func(), error) { - // No state file exists - return func() {}, nil - }, - wantError: false, // Returns nil when state doesn't exist - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - cleanup, err := tt.setupState() - defer cleanup() - require.NoError(t, err) - - manager := &cliManager{} - - ctx := context.Background() - err = manager.cleanupTempPermissionProfile(ctx, tt.baseName) - - if tt.wantError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - // Function returns nil when state doesn't exist or no profile to clean - assert.NoError(t, err) - } - }) - } -} - -func TestCLIManager_MoveToGroup(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - workloadNames []string - groupFrom string - groupTo string - setupState func() (func(), error) // Returns cleanup function and error - wantError bool - errorMsg string - }{ - { - name: "invalid workload name", - workloadNames: 
[]string{"../invalid"}, - groupFrom: "group1", - groupTo: "group2", - setupState: func() (func(), error) { - return func() {}, nil - }, - wantError: true, - errorMsg: "invalid workload name", - }, - { - name: "state file not found", - workloadNames: []string{"non-existent"}, - groupFrom: "group1", - groupTo: "group2", - setupState: func() (func(), error) { - return func() {}, nil - }, - wantError: true, - errorMsg: "failed to load runner state", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - cleanup, err := tt.setupState() - defer cleanup() - require.NoError(t, err) - - manager := &cliManager{} - - ctx := context.Background() - err = manager.MoveToGroup(ctx, tt.workloadNames, tt.groupFrom, tt.groupTo) - - if tt.wantError { - require.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index 1aa7a3c17..7aac9e5d6 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -4,14 +4,31 @@ package workloads import ( "context" + "errors" "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + "github.com/adrg/xdg" "golang.org/x/sync/errgroup" + "github.com/stacklok/toolhive/pkg/client" "github.com/stacklok/toolhive/pkg/config" + ct "github.com/stacklok/toolhive/pkg/container" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/core" + "github.com/stacklok/toolhive/pkg/labels" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/process" "github.com/stacklok/toolhive/pkg/runner" + "github.com/stacklok/toolhive/pkg/secrets" + "github.com/stacklok/toolhive/pkg/state" + "github.com/stacklok/toolhive/pkg/transport" + "github.com/stacklok/toolhive/pkg/workloads/statuses" + "github.com/stacklok/toolhive/pkg/workloads/types" ) // Manager is responsible for managing the state of 
ToolHive-managed containers. @@ -54,40 +71,1213 @@ type Manager interface { DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) } +type defaultManager struct { + runtime rt.Runtime + statuses statuses.StatusManager + configProvider config.Provider +} + // ErrWorkloadNotRunning is returned when a container cannot be found by name. var ErrWorkloadNotRunning = fmt.Errorf("workload not running") -// NewManager creates a new CLI workload manager. -// Returns Manager interface (existing behavior, unchanged). -// IMPORTANT: This function only works in CLI mode. +const ( + // AsyncOperationTimeout is the timeout for async workload operations + AsyncOperationTimeout = 5 * time.Minute +) + +// NewManager creates a new container manager instance. func NewManager(ctx context.Context) (Manager, error) { - if rt.IsKubernetesRuntime() { - return nil, fmt.Errorf("workload manager is not available in Kubernetes environments") + runtime, err := ct.NewFactory().Create(ctx) + if err != nil { + return nil, err + } + + statusManager, err := statuses.NewStatusManager(runtime) + if err != nil { + return nil, fmt.Errorf("failed to create status manager: %w", err) } - return NewCLIManager(ctx) + + return &defaultManager{ + runtime: runtime, + statuses: statusManager, + configProvider: config.NewDefaultProvider(), + }, nil } -// NewManagerWithProvider creates a new CLI workload manager with a custom config provider. -// IMPORTANT: This function only works in CLI mode. +// NewManagerWithProvider creates a new container manager instance with a custom config provider. 
func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) (Manager, error) { - if rt.IsKubernetesRuntime() { - return nil, fmt.Errorf("workload manager is not available in Kubernetes environments") + runtime, err := ct.NewFactory().Create(ctx) + if err != nil { + return nil, err + } + + statusManager, err := statuses.NewStatusManager(runtime) + if err != nil { + return nil, fmt.Errorf("failed to create status manager: %w", err) + } + + return &defaultManager{ + runtime: runtime, + statuses: statusManager, + configProvider: configProvider, + }, nil +} + +// NewManagerFromRuntime creates a new container manager instance from an existing runtime. +func NewManagerFromRuntime(runtime rt.Runtime) (Manager, error) { + statusManager, err := statuses.NewStatusManager(runtime) + if err != nil { + return nil, fmt.Errorf("failed to create status manager: %w", err) + } + + return &defaultManager{ + runtime: runtime, + statuses: statusManager, + configProvider: config.NewDefaultProvider(), + }, nil +} + +// NewManagerFromRuntimeWithProvider creates a new container manager instance from an existing runtime with a +// custom config provider. +func NewManagerFromRuntimeWithProvider(runtime rt.Runtime, configProvider config.Provider) (Manager, error) { + statusManager, err := statuses.NewStatusManager(runtime) + if err != nil { + return nil, fmt.Errorf("failed to create status manager: %w", err) + } + + return &defaultManager{ + runtime: runtime, + statuses: statusManager, + configProvider: configProvider, + }, nil +} + +func (d *defaultManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { + // For the sake of minimizing changes, delegate to the status manager. + // Whether this method should still belong to the workload manager is TBD. 
+ return d.statuses.GetWorkload(ctx, workloadName) +} + +func (d *defaultManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { + // check if workload exists by trying to get it + workload, err := d.statuses.GetWorkload(ctx, workloadName) + if err != nil { + if errors.Is(err, rt.ErrWorkloadNotFound) { + return false, nil + } + return false, fmt.Errorf("failed to check if workload exists: %w", err) + } + + // now check if the workload is not in error + if workload.Status == rt.WorkloadStatusError { + return false, nil + } + return true, nil +} + +func (d *defaultManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { + // For the sake of minimizing changes, delegate to the status manager. + // Whether this method should still belong to the workload manager is TBD. + containerWorkloads, err := d.statuses.ListWorkloads(ctx, listAll, labelFilters) + if err != nil { + return nil, err + } + + // Get remote workloads from the state store + remoteWorkloads, err := d.getRemoteWorkloadsFromState(ctx, listAll, labelFilters) + if err != nil { + logger.Warnf("Failed to get remote workloads from state: %v", err) + // Continue with container workloads only + } else { + // Combine container and remote workloads + containerWorkloads = append(containerWorkloads, remoteWorkloads...) 
+ } + + return containerWorkloads, nil +} + +func (d *defaultManager) StopWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { + // Validate all workload names to prevent path traversal attacks + for _, name := range names { + if err := types.ValidateWorkloadName(name); err != nil { + return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) + } + // Ensure workload name does not contain path traversal or separators + if strings.Contains(name, "..") || strings.ContainsAny(name, "/\\") { + return nil, fmt.Errorf("invalid workload name '%s': contains forbidden characters", name) + } + } + + group := &errgroup.Group{} + // Process each workload + for _, name := range names { + group.Go(func() error { + return d.stopSingleWorkload(name) + }) + } + + return group, nil +} + +// stopSingleWorkload stops a single workload (container or remote) +func (d *defaultManager) stopSingleWorkload(name string) error { + // Create a child context with a longer timeout + childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) + defer cancel() + + // First, try to load the run configuration to check if it's a remote workload + runConfig, err := runner.LoadState(childCtx, name) + if err != nil { + // If we can't load the state, it might be a container workload or the workload doesn't exist + // Try to stop it as a container workload + return d.stopContainerWorkload(childCtx, name) + } + + // Check if this is a remote workload + if runConfig.RemoteURL != "" { + return d.stopRemoteWorkload(childCtx, name, runConfig) + } + + // This is a container-based workload + return d.stopContainerWorkload(childCtx, name) +} + +// stopRemoteWorkload stops a remote workload +func (d *defaultManager) stopRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { + logger.Infof("Stopping remote workload %s...", name) + + // Check if the workload is running by checking its status + workload, err := 
d.statuses.GetWorkload(ctx, name) + if err != nil { + if errors.Is(err, rt.ErrWorkloadNotFound) { + // Log but don't fail the entire operation for not found workload + logger.Warnf("Warning: Failed to stop workload %s: %v", name, err) + return nil + } + return fmt.Errorf("failed to find workload %s: %v", name, err) + } + + if workload.Status != rt.WorkloadStatusRunning { + logger.Warnf("Warning: Failed to stop workload %s: %v", name, ErrWorkloadNotRunning) + return nil + } + + // Set status to stopping + if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { + logger.Debugf("Failed to set workload %s status to stopping: %v", name, err) + } + + // Stop proxy if running + if runConfig.BaseName != "" { + d.stopProxyIfNeeded(ctx, name, runConfig.BaseName) + } + + // For remote workloads, we only need to clean up client configurations + // The saved state should be preserved for restart capability + if err := removeClientConfigurations(name, false); err != nil { + logger.Warnf("Warning: Failed to remove client configurations: %v", err) + } else { + logger.Infof("Client configurations for %s removed", name) + } + + // Set status to stopped + if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopped, ""); err != nil { + logger.Debugf("Failed to set workload %s status to stopped: %v", name, err) + } + logger.Infof("Remote workload %s stopped successfully", name) + return nil +} + +// stopContainerWorkload stops a container-based workload +func (d *defaultManager) stopContainerWorkload(ctx context.Context, name string) error { + container, err := d.runtime.GetWorkloadInfo(ctx, name) + if err != nil { + if errors.Is(err, rt.ErrWorkloadNotFound) { + // Log but don't fail the entire operation for not found containers + logger.Warnf("Warning: Failed to stop workload %s: %v", name, err) + return nil + } + return fmt.Errorf("failed to find workload %s: %v", name, err) + } + + running := container.IsRunning() + if !running { 
+ // Log but don't fail the entire operation for not running containers + logger.Warnf("Warning: Failed to stop workload %s: %v", name, ErrWorkloadNotRunning) + return nil + } + + // Transition workload to `stopping` state. + if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { + logger.Debugf("Failed to set workload %s status to stopping: %v", name, err) + } + + // Use the existing stopWorkloads method for container workloads + return d.stopSingleContainerWorkload(ctx, &container) +} + +func (d *defaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error { + // Ensure that the workload has a status entry before starting the process. + if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil { + // Failure to create the initial state is a fatal error. + return fmt.Errorf("failed to create workload status: %v", err) + } + + mcpRunner := runner.NewRunner(runConfig, d.statuses) + err := mcpRunner.Run(ctx) + if err != nil { + // If the run failed, we should set the status to error. 
+ if statusErr := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, err.Error()); statusErr != nil { + logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr) + } + } + return err +} + +func (d *defaultManager) validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { + // If there are run secrets, validate them + + hasRegularSecrets := len(runConfig.Secrets) > 0 + hasRemoteAuthSecret := runConfig.RemoteAuthConfig != nil && runConfig.RemoteAuthConfig.ClientSecret != "" + + if hasRegularSecrets || hasRemoteAuthSecret { + cfg := d.configProvider.GetConfig() + + providerType, err := cfg.Secrets.GetProviderType() + if err != nil { + return fmt.Errorf("error determining secrets provider type: %w", err) + } + + secretManager, err := secrets.CreateSecretProvider(providerType) + if err != nil { + return fmt.Errorf("error instantiating secret manager: %w", err) + } + + err = runConfig.ValidateSecrets(ctx, secretManager) + if err != nil { + return fmt.Errorf("error processing secrets: %w", err) + } + } + return nil +} + +func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error { + // before running, validate the parameters for the workload + err := d.validateSecretParameters(ctx, runConfig) + if err != nil { + return fmt.Errorf("failed to validate workload parameters: %w", err) + } + + // Get the current executable path + execPath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %v", err) + } + + // Create a log file for the detached process + logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", runConfig.BaseName)) + if err != nil { + return fmt.Errorf("failed to create log file path: %v", err) + } + // #nosec G304 - This is safe as baseName is generated by the application + logFile, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) + if err != nil { + 
logger.Warnf("Warning: Failed to create log file: %v", err) + } else { + defer logFile.Close() + logger.Infof("Logging to: %s", logFilePath) + } + + // Use the restart command to start the detached process + // The config has already been saved to disk, so restart can load it + detachedArgs := []string{"restart", runConfig.BaseName, "--foreground"} + + // Create a new command + // #nosec G204 - This is safe as execPath is the path to the current binary + detachedCmd := exec.Command(execPath, detachedArgs...) + + // Set environment variables for the detached process + detachedCmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", process.ToolHiveDetachedEnv, process.ToolHiveDetachedValue)) + + // If we need the decrypt password, set it as an environment variable in the detached process. + // NOTE: This breaks the abstraction slightly since this is only relevant for the CLI, but there + // are checks inside `GetSecretsPassword` to ensure this does not get called in a detached process. + // This will be addressed in a future re-think of the secrets manager interface. + if d.needSecretsPassword(runConfig.Secrets) { + password, err := secrets.GetSecretsPassword("") + if err != nil { + return fmt.Errorf("failed to get secrets password: %v", err) + } + detachedCmd.Env = append(detachedCmd.Env, fmt.Sprintf("%s=%s", secrets.PasswordEnvVar, password)) + } + + // Redirect stdout and stderr to the log file if it was created successfully + if logFile != nil { + detachedCmd.Stdout = logFile + detachedCmd.Stderr = logFile + } else { + // Otherwise, discard the output + detachedCmd.Stdout = nil + detachedCmd.Stderr = nil + } + + // Detach the process from the terminal + detachedCmd.Stdin = nil + detachedCmd.SysProcAttr = getSysProcAttr() + + // Ensure that the workload has a status entry before starting the process. + if err = d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil { + // Failure to create the initial state is a fatal error. 
+ return fmt.Errorf("failed to create workload status: %v", err) + } + + // Start the detached process + if err := detachedCmd.Start(); err != nil { + // If the start failed, we need to set the status to error before returning. + if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, ""); err != nil { + logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, err) + } + return fmt.Errorf("failed to start detached process: %v", err) + } + + // Write the PID to a file so the stop command can kill the process + // TODO: Stop writing to PID file once we migrate over to statuses fully. + if err := process.WritePIDFile(runConfig.BaseName, detachedCmd.Process.Pid); err != nil { + logger.Warnf("Warning: Failed to write PID file: %v", err) + } + if err := d.statuses.SetWorkloadPID(ctx, runConfig.BaseName, detachedCmd.Process.Pid); err != nil { + logger.Warnf("Failed to set workload %s PID: %v", runConfig.BaseName, err) + } + + logger.Infof("MCP server is running in the background (PID: %d)", detachedCmd.Process.Pid) + logger.Infof("Use 'thv stop %s' to stop the server", runConfig.ContainerName) + + return nil +} + +func (d *defaultManager) GetLogs(ctx context.Context, workloadName string, follow bool) (string, error) { + // Get the logs from the runtime + logs, err := d.runtime.GetWorkloadLogs(ctx, workloadName, follow) + if err != nil { + // Propagate the error if the container is not found + if errors.Is(err, rt.ErrWorkloadNotFound) { + return "", fmt.Errorf("%w: %s", rt.ErrWorkloadNotFound, workloadName) + } + return "", fmt.Errorf("failed to get container logs %s: %v", workloadName, err) + } + + return logs, nil +} + +// GetProxyLogs retrieves proxy logs from the filesystem +func (*defaultManager) GetProxyLogs(_ context.Context, workloadName string) (string, error) { + // Get the proxy log file path + logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", workloadName)) + if err != nil { + return 
"", fmt.Errorf("failed to get proxy log file path for workload %s: %w", workloadName, err) + } + + // Clean the file path to prevent path traversal + cleanLogFilePath := filepath.Clean(logFilePath) + + // Check if the log file exists + if _, err := os.Stat(cleanLogFilePath); os.IsNotExist(err) { + return "", fmt.Errorf("proxy logs not found for workload %s", workloadName) + } + + // Read and return the entire log file + content, err := os.ReadFile(cleanLogFilePath) + if err != nil { + return "", fmt.Errorf("failed to read proxy log for workload %s: %w", workloadName, err) + } + + return string(content), nil +} + +// deleteWorkload handles deletion of a single workload +func (d *defaultManager) deleteWorkload(name string) error { + // Create a child context with a longer timeout + childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) + defer cancel() + + // First, check if this is a remote workload by trying to load its run configuration + runConfig, err := runner.LoadState(childCtx, name) + if err != nil { + // If we can't load the state, it might be a container workload or the workload doesn't exist + // Continue with the container-based deletion logic + return d.deleteContainerWorkload(childCtx, name) + } + + // If this is a remote workload (has RemoteURL), handle it differently + if runConfig.RemoteURL != "" { + return d.deleteRemoteWorkload(childCtx, name, runConfig) + } + + // This is a container-based workload, use the existing logic + return d.deleteContainerWorkload(childCtx, name) +} + +// deleteRemoteWorkload handles deletion of a remote workload +func (d *defaultManager) deleteRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { + logger.Infof("Removing remote workload %s...", name) + + // Set status to removing + if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusRemoving, ""); err != nil { + logger.Warnf("Failed to set workload %s status to removing: %v", name, err) + 
return err + } + + // Stop proxy if running + if runConfig.BaseName != "" { + d.stopProxyIfNeeded(ctx, name, runConfig.BaseName) + } + + // Clean up associated resources (remote workloads are not auxiliary) + d.cleanupWorkloadResources(ctx, name, runConfig.BaseName, false) + + // Remove the workload status from the status store + if err := d.statuses.DeleteWorkloadStatus(ctx, name); err != nil { + logger.Warnf("failed to delete workload status for %s: %v", name, err) + } + + logger.Infof("Remote workload %s removed successfully", name) + return nil +} + +// deleteContainerWorkload handles deletion of a container-based workload (existing logic) +func (d *defaultManager) deleteContainerWorkload(ctx context.Context, name string) error { + + // Find and validate the container + container, err := d.getWorkloadContainer(ctx, name) + if err != nil { + return err + } + + // Set status to removing + if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusRemoving, ""); err != nil { + logger.Warnf("Failed to set workload %s status to removing: %v", name, err) + } + + if container != nil { + containerLabels := container.Labels + baseName := labels.GetContainerBaseName(containerLabels) + + // Stop proxy if running (skip for auxiliary workloads like inspector) + if container.IsRunning() { + // Skip proxy stopping for auxiliary workloads that don't use proxy processes + if labels.IsAuxiliaryWorkload(containerLabels) { + logger.Debugf("Skipping proxy stop for auxiliary workload %s", name) + } else { + d.stopProxyIfNeeded(ctx, name, baseName) + } + } + + // Remove the container + if err := d.removeContainer(ctx, name); err != nil { + return err + } + + // Clean up associated resources + d.cleanupWorkloadResources(ctx, name, baseName, labels.IsAuxiliaryWorkload(containerLabels)) + } + + // Remove the workload status from the status store + if err := d.statuses.DeleteWorkloadStatus(ctx, name); err != nil { + logger.Warnf("failed to delete workload status for %s: %v", 
name, err) + } + + return nil +} + +// getWorkloadContainer retrieves workload container info with error handling +func (d *defaultManager) getWorkloadContainer(ctx context.Context, name string) (*rt.ContainerInfo, error) { + container, err := d.runtime.GetWorkloadInfo(ctx, name) + if err != nil { + if errors.Is(err, rt.ErrWorkloadNotFound) { + // Log but don't fail the entire operation for not found containers + logger.Warnf("Warning: Failed to get workload %s: %v", name, err) + return nil, nil + } + if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { + logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) + } + return nil, fmt.Errorf("failed to find workload %s: %v", name, err) + } + return &container, nil +} + +// isSupervisorProcessAlive checks if the supervisor process for a workload is alive +// by checking if a PID exists. If a PID exists, we assume the supervisor is running. +// This is a reasonable assumption because: +// - If the supervisor exits cleanly, it cleans up the PID +// - If killed unexpectedly, the PID remains but stopProcess will handle it gracefully +// - The main issue we're preventing is accumulating zombie supervisors from repeated restarts +func (d *defaultManager) isSupervisorProcessAlive(ctx context.Context, name string) bool { + if name == "" { + return false + } + + // Try to read the PID - if it exists, assume supervisor is running + _, err := d.statuses.GetWorkloadPID(ctx, name) + if err != nil { + // No PID found, supervisor is not running + return false + } + + // PID exists, assume supervisor is alive + return true +} + +// stopProcess stops the proxy process associated with the container +func (d *defaultManager) stopProcess(ctx context.Context, name string) { + if name == "" { + logger.Warnf("Warning: Could not find base container name in labels") + return + } + + // Try to read the PID and kill the process + pid, err := 
d.statuses.GetWorkloadPID(ctx, name) + if err != nil { + logger.Errorf("No PID file found for %s, proxy may not be running in detached mode", name) + return + } + + // PID file found, try to kill the process + logger.Infof("Stopping proxy process (PID: %d)...", pid) + if err := process.KillProcess(pid); err != nil { + logger.Warnf("Warning: Failed to kill proxy process: %v", err) + } else { + logger.Info("Proxy process stopped") + } + + // Clean up PID file after successful kill + if err := process.RemovePIDFile(name); err != nil { + logger.Warnf("Warning: Failed to remove PID file: %v", err) + } +} + +// stopProxyIfNeeded stops the proxy process if the workload has a base name +func (d *defaultManager) stopProxyIfNeeded(ctx context.Context, name, baseName string) { + logger.Infof("Removing proxy process for %s...", name) + if baseName != "" { + d.stopProcess(ctx, baseName) + } +} + +// removeContainer removes the container from the runtime +func (d *defaultManager) removeContainer(ctx context.Context, name string) error { + logger.Infof("Removing container %s...", name) + if err := d.runtime.RemoveWorkload(ctx, name); err != nil { + if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { + logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) + } + return fmt.Errorf("failed to remove container: %v", err) + } + return nil +} + +// cleanupWorkloadResources cleans up all resources associated with a workload +func (d *defaultManager) cleanupWorkloadResources(ctx context.Context, name, baseName string, isAuxiliary bool) { + if baseName == "" { + return + } + + // Clean up temporary permission profile + if err := d.cleanupTempPermissionProfile(ctx, baseName); err != nil { + logger.Warnf("Warning: Failed to cleanup temporary permission profile: %v", err) + } + + // Remove client configurations + if err := removeClientConfigurations(name, isAuxiliary); err != nil { + logger.Warnf("Warning: 
Failed to remove client configurations: %v", err) + } else { + logger.Infof("Client configurations for %s removed", name) + } + + // Delete the saved state last (skip for auxiliary workloads that don't have run configs) + if !isAuxiliary { + if err := state.DeleteSavedRunConfig(ctx, baseName); err != nil { + logger.Warnf("Warning: Failed to delete saved state: %v", err) + } else { + logger.Infof("Saved state for %s removed", baseName) + } + } else { + logger.Debugf("Skipping saved state deletion for auxiliary workload %s", name) + } + + logger.Infof("Container %s removed", name) +} + +func (d *defaultManager) DeleteWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { + // Validate all workload names to prevent path traversal attacks + for _, name := range names { + if err := types.ValidateWorkloadName(name); err != nil { + return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) + } + } + + group := &errgroup.Group{} + + for _, name := range names { + group.Go(func() error { + return d.deleteWorkload(name) + }) } - return NewCLIManagerWithProvider(ctx, configProvider) + + return group, nil } -// NewManagerFromRuntime creates a new CLI workload manager from an existing runtime. -// This function works with any runtime type. The status manager will automatically -// detect the environment and use the appropriate implementation. -// Proxyrunner uses this with Kubernetes runtime to create StatefulSets. -func NewManagerFromRuntime(rtRuntime rt.Runtime) (Manager, error) { - return NewCLIManagerFromRuntime(rtRuntime) +// RestartWorkloads restarts the specified workloads by name. 
+func (d *defaultManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) { + // Validate all workload names to prevent path traversal attacks + for _, name := range names { + if err := types.ValidateWorkloadName(name); err != nil { + return nil, fmt.Errorf("invalid workload name '%s': %w", name, err) + } + } + + group := &errgroup.Group{} + + for _, name := range names { + group.Go(func() error { + return d.restartSingleWorkload(name, foreground) + }) + } + + return group, nil } -// NewManagerFromRuntimeWithProvider creates a new CLI workload manager from an existing runtime with a custom config provider. -// This function works with any runtime type. The status manager will automatically -// detect the environment and use the appropriate implementation. -// Proxyrunner uses this with Kubernetes runtime to create StatefulSets. -func NewManagerFromRuntimeWithProvider(rtRuntime rt.Runtime, configProvider config.Provider) (Manager, error) { - return NewCLIManagerFromRuntimeWithProvider(rtRuntime, configProvider) +// UpdateWorkload updates a workload by stopping, deleting, and recreating it +func (d *defaultManager) UpdateWorkload(_ context.Context, workloadName string, newConfig *runner.RunConfig) (*errgroup.Group, error) { //nolint:lll + // Validate workload name + if err := types.ValidateWorkloadName(workloadName); err != nil { + return nil, fmt.Errorf("invalid workload name '%s': %w", workloadName, err) + } + + group := &errgroup.Group{} + group.Go(func() error { + return d.updateSingleWorkload(workloadName, newConfig) + }) + return group, nil +} + +// updateSingleWorkload handles the update logic for a single workload +func (d *defaultManager) updateSingleWorkload(workloadName string, newConfig *runner.RunConfig) error { + // Create a child context with a longer timeout + childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) + defer cancel() + + logger.Infof("Starting update for workload 
%s", workloadName) + + // Stop the existing workload + if err := d.stopSingleWorkload(workloadName); err != nil { + return fmt.Errorf("failed to stop workload: %w", err) + } + logger.Infof("Successfully stopped workload %s", workloadName) + + // Delete the existing workload + if err := d.deleteWorkload(workloadName); err != nil { + return fmt.Errorf("failed to delete workload: %w", err) + } + logger.Infof("Successfully deleted workload %s", workloadName) + + // Save the new workload configuration state + if err := newConfig.SaveState(childCtx); err != nil { + logger.Errorf("Failed to save workload config: %v", err) + return fmt.Errorf("failed to save workload config: %w", err) + } + + // Step 3: Start the new workload + // TODO: This currently just handles detached processes and wouldn't work for + // foreground CLI executions. Should be refactored to support both modes. + if err := d.RunWorkloadDetached(childCtx, newConfig); err != nil { + return fmt.Errorf("failed to start new workload: %w", err) + } + + logger.Infof("Successfully completed update for workload %s", workloadName) + return nil +} + +// restartSingleWorkload handles the restart logic for a single workload +func (d *defaultManager) restartSingleWorkload(name string, foreground bool) error { + // Create a child context with a longer timeout + childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) + defer cancel() + + // First, try to load the run configuration to check if it's a remote workload + runConfig, err := runner.LoadState(childCtx, name) + if err != nil { + // If we can't load the state, it might be a container workload or the workload doesn't exist + // Try to restart it as a container workload + return d.restartContainerWorkload(childCtx, name, foreground) + } + + // Check if this is a remote workload + if runConfig.RemoteURL != "" { + return d.restartRemoteWorkload(childCtx, name, runConfig, foreground) + } + + // This is a container-based workload + return 
d.restartContainerWorkload(childCtx, name, foreground) +} + +// restartRemoteWorkload handles restarting a remote workload +func (d *defaultManager) restartRemoteWorkload( + ctx context.Context, + name string, + runConfig *runner.RunConfig, + foreground bool, +) error { + // Get workload status using the status manager + workload, err := d.statuses.GetWorkload(ctx, name) + if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) { + return err + } + + // If workload is already running, check if the supervisor process is healthy + if err == nil && workload.Status == rt.WorkloadStatusRunning { + // Check if the supervisor process is actually alive + supervisorAlive := d.isSupervisorProcessAlive(ctx, runConfig.BaseName) + + if supervisorAlive { + // Workload is running and healthy - preserve old behavior (no-op) + logger.Infof("Remote workload %s is already running", name) + return nil + } + + // Supervisor is dead/missing - we need to clean up and restart to fix the damaged state + logger.Infof("Remote workload %s is running but supervisor is dead, cleaning up before restart", name) + + // Set status to stopping + if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopping, ""); err != nil { + logger.Debugf("Failed to set workload %s status to stopping: %v", name, err) + } + + // Stop the supervisor process (proxy) if it exists (may already be dead) + // This ensures we clean up any orphaned supervisor processes + d.stopProxyIfNeeded(ctx, name, runConfig.BaseName) + + // Clean up client configurations + if err := removeClientConfigurations(name, false); err != nil { + logger.Warnf("Warning: Failed to remove client configurations: %v", err) + } + + // Set status to stopped after cleanup is complete + if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStopped, ""); err != nil { + logger.Debugf("Failed to set workload %s status to stopped: %v", name, err) + } + } + + // Load runner configuration from state + mcpRunner, err := 
d.loadRunnerFromState(ctx, runConfig.BaseName) + if err != nil { + return fmt.Errorf("failed to load state for %s: %v", runConfig.BaseName, err) + } + + // Set status to starting + if err := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusStarting, ""); err != nil { + logger.Warnf("Failed to set workload %s status to starting: %v", name, err) + } + + logger.Infof("Loaded configuration from state for %s", runConfig.BaseName) + + // Start the remote workload using the loaded runner + // Use background context to avoid timeout cancellation - same reasoning as container workloads + return d.startWorkload(context.Background(), name, mcpRunner, foreground) +} + +// restartContainerWorkload handles restarting a container-based workload +// +//nolint:gocyclo // Complexity is justified - handles multiple restart scenarios and edge cases +func (d *defaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error { + // Get container info to resolve partial names and extract proper workload name + var containerName string + var workloadName string + + container, err := d.runtime.GetWorkloadInfo(ctx, name) + if err == nil { + // If we found the container, use its actual container name for runtime operations + containerName = container.Name + // Extract the workload name (base name) from container labels for status operations + workloadName = labels.GetContainerBaseName(container.Labels) + if workloadName == "" { + // Fallback to the provided name if base name is not available + workloadName = name + } + } else { + // If container not found, use the provided name as both container and workload name + containerName = name + workloadName = name + } + + // Get workload status using the status manager + workload, err := d.statuses.GetWorkload(ctx, name) + if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) { + return err + } + + // Check if workload is running and healthy (including supervisor process) + if err == nil && workload.Status 
== rt.WorkloadStatusRunning { + // Check if the supervisor process is actually alive + supervisorAlive := d.isSupervisorProcessAlive(ctx, workloadName) + + if supervisorAlive { + // Workload is running and healthy - preserve old behavior (no-op) + logger.Infof("Container %s is already running", containerName) + return nil + } + + // Supervisor is dead/missing - we need to clean up and restart to fix the damaged state + logger.Infof("Container %s is running but supervisor is dead, cleaning up before restart", containerName) + } + + // Check if we need to stop the workload before restarting + // This happens when: 1) container is running, or 2) inconsistent state + shouldStop := false + if err == nil && workload.Status == rt.WorkloadStatusRunning { + // Workload status shows running (and supervisor is dead, otherwise we would have returned above) + shouldStop = true + } else if container.IsRunning() { + // Container is running but status is not running (inconsistent state) + shouldStop = true + } + + // If we need to stop, do it now (including cleanup of any remaining supervisor process) + if shouldStop { + logger.Infof("Stopping container %s before restart", containerName) + + // Set status to stopping + if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStopping, ""); err != nil { + logger.Debugf("Failed to set workload %s status to stopping: %v", workloadName, err) + } + + // Stop the supervisor process (proxy) if it exists (may already be dead) + // This ensures we clean up any orphaned supervisor processes + if !labels.IsAuxiliaryWorkload(container.Labels) { + d.stopProcess(ctx, workloadName) + } + + // Now stop the container if it's running + if container.IsRunning() { + if err := d.runtime.StopWorkload(ctx, containerName); err != nil { + if statusErr := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusError, err.Error()); statusErr != nil { + logger.Warnf("Failed to set workload %s status to error: %v", workloadName, 
statusErr) + } + return fmt.Errorf("failed to stop container %s: %v", containerName, err) + } + logger.Infof("Container %s stopped", containerName) + } + + // Clean up client configurations + if err := removeClientConfigurations(workloadName, labels.IsAuxiliaryWorkload(container.Labels)); err != nil { + logger.Warnf("Warning: Failed to remove client configurations: %v", err) + } + + // Set status to stopped after cleanup is complete + if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStopped, ""); err != nil { + logger.Debugf("Failed to set workload %s status to stopped: %v", workloadName, err) + } + } + + // Load runner configuration from state + mcpRunner, err := d.loadRunnerFromState(ctx, workloadName) + if err != nil { + return fmt.Errorf("failed to load state for %s: %v", workloadName, err) + } + + // Set workload status to starting - use the workload name for status operations + if err := d.statuses.SetWorkloadStatus(ctx, workloadName, rt.WorkloadStatusStarting, ""); err != nil { + logger.Warnf("Failed to set workload %s status to starting: %v", workloadName, err) + } + logger.Infof("Loaded configuration from state for %s", workloadName) + + // Start the workload with background context to avoid timeout cancellation + // The ctx with AsyncOperationTimeout is only for the restart setup operations, + // but the actual workload should run indefinitely with its own lifecycle management + // Use workload name for user-facing operations + return d.startWorkload(context.Background(), workloadName, mcpRunner, foreground) +} + +// startWorkload starts the workload in either foreground or background mode +func (d *defaultManager) startWorkload(ctx context.Context, name string, mcpRunner *runner.Runner, foreground bool) error { + logger.Infof("Starting tooling server %s...", name) + + var err error + if foreground { + err = d.RunWorkload(ctx, mcpRunner.Config) + } else { + err = d.RunWorkloadDetached(ctx, mcpRunner.Config) + } + + if err != nil 
{ + // If we could not start the workload, set the status to error before returning + if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, ""); statusErr != nil { + logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) + } + } + return err +} + +// TODO: Move to dedicated config management interface. +// updateClientConfigurations updates client configuration files with the MCP server URL +func removeClientConfigurations(containerName string, isAuxiliary bool) error { + // Get the workload's group by loading its run config + runConfig, err := runner.LoadState(context.Background(), containerName) + var group string + if err != nil { + // Only warn for non-auxiliary workloads since auxiliary workloads don't have run configs + if !isAuxiliary { + logger.Warnf("Warning: Failed to load run config for %s, will use backward compatible behavior: %v", containerName, err) + } + // Continue with empty group (backward compatibility) + } else { + group = runConfig.Group + } + + clientManager, err := client.NewManager(context.Background()) + if err != nil { + logger.Warnf("Warning: Failed to create client manager for %s, skipping client config removal: %v", containerName, err) + return nil + } + + return clientManager.RemoveServerFromClients(context.Background(), containerName, group) +} + +// loadRunnerFromState attempts to load a Runner from the state store +func (d *defaultManager) loadRunnerFromState(ctx context.Context, baseName string) (*runner.Runner, error) { + // Load the run config from the state store + runConfig, err := runner.LoadState(ctx, baseName) + if err != nil { + return nil, err + } + + if runConfig.RemoteURL != "" { + // For remote workloads, we don't need a deployer + runConfig.Deployer = nil + } else { + // Update the runtime in the loaded configuration + runConfig.Deployer = d.runtime + } + + // Create a new runner with the loaded configuration + return runner.NewRunner(runConfig, d.statuses), nil +} + 
+func (d *defaultManager) needSecretsPassword(secretOptions []string) bool { + // If the user did not ask for any secrets, then don't attempt to instantiate + // the secrets manager. + if len(secretOptions) == 0 { + return false + } + // Ignore err - if the flag is not set, it's not needed. + providerType, _ := d.configProvider.GetConfig().Secrets.GetProviderType() + return providerType == secrets.EncryptedType +} + +// cleanupTempPermissionProfile cleans up temporary permission profile files for a given base name +func (*defaultManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error { + // Try to load the saved configuration to get the permission profile path + runConfig, err := runner.LoadState(ctx, baseName) + if err != nil { + // If we can't load the state, there's nothing to clean up + logger.Debugf("Could not load state for %s, skipping permission profile cleanup: %v", baseName, err) + return nil + } + + // Clean up the temporary permission profile if it exists + if runConfig.PermissionProfileNameOrPath != "" { + if err := runner.CleanupTempPermissionProfile(runConfig.PermissionProfileNameOrPath); err != nil { + return fmt.Errorf("failed to cleanup temporary permission profile: %v", err) + } + } + + return nil +} + +// stopSingleContainerWorkload stops a single container workload +func (d *defaultManager) stopSingleContainerWorkload(ctx context.Context, workload *rt.ContainerInfo) error { + childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) + defer cancel() + + name := labels.GetContainerBaseName(workload.Labels) + // Stop the proxy process (skip for auxiliary workloads like inspector) + if labels.IsAuxiliaryWorkload(workload.Labels) { + logger.Debugf("Skipping proxy stop for auxiliary workload %s", name) + } else { + d.stopProcess(ctx, name) + } + + // TODO: refactor the StopProcess function to stop dealing explicitly with PID files. 
+ // Note that this is not a blocker for k8s since this code path is not called there. + if err := d.statuses.ResetWorkloadPID(ctx, name); err != nil { + logger.Warnf("Warning: Failed to reset workload %s PID: %v", name, err) + } + + logger.Infof("Stopping containers for %s...", name) + // Stop the container + if err := d.runtime.StopWorkload(childCtx, workload.Name); err != nil { + if statusErr := d.statuses.SetWorkloadStatus(childCtx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { + logger.Warnf("Failed to set workload %s status to error: %v", name, statusErr) + } + return fmt.Errorf("failed to stop container: %w", err) + } + + if err := removeClientConfigurations(name, labels.IsAuxiliaryWorkload(workload.Labels)); err != nil { + logger.Warnf("Warning: Failed to remove client configurations: %v", err) + } else { + logger.Infof("Client configurations for %s removed", name) + } + + if err := d.statuses.SetWorkloadStatus(childCtx, name, rt.WorkloadStatusStopped, ""); err != nil { + logger.Warnf("Failed to set workload %s status to stopped: %v", name, err) + } + logger.Infof("Successfully stopped %s...", name) + return nil +} + +// MoveToGroup moves the specified workloads from one group to another by updating their runconfig. 
+func (*defaultManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { + for _, workloadName := range workloadNames { + // Validate workload name + if err := types.ValidateWorkloadName(workloadName); err != nil { + return fmt.Errorf("invalid workload name %s: %w", workloadName, err) + } + + // Load the runner state to check and update the configuration + runnerConfig, err := runner.LoadState(ctx, workloadName) + if err != nil { + return fmt.Errorf("failed to load runner state for workload %s: %w", workloadName, err) + } + + // Check if the workload is actually in the specified group + if runnerConfig.Group != groupFrom { + logger.Debugf("Workload %s is not in group %s (current group: %s), skipping", + workloadName, groupFrom, runnerConfig.Group) + continue + } + + // Move the workload to the default group + runnerConfig.Group = groupTo + + // Save the updated configuration + if err = runnerConfig.SaveState(ctx); err != nil { + return fmt.Errorf("failed to save updated configuration for workload %s: %w", workloadName, err) + } + + logger.Infof("Moved workload %s to default group", workloadName) + } + + return nil +} + +// ListWorkloadsInGroup returns all workload names that belong to the specified group +func (d *defaultManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { + workloads, err := d.ListWorkloads(ctx, true) // listAll=true to include stopped workloads + if err != nil { + return nil, fmt.Errorf("failed to list workloads: %w", err) + } + + // Filter workloads that belong to the specified group + var groupWorkloads []string + for _, workload := range workloads { + if workload.Group == groupName { + groupWorkloads = append(groupWorkloads, workload.Name) + } + } + + return groupWorkloads, nil +} + +// getRemoteWorkloadsFromState retrieves remote servers from the state store +func (d *defaultManager) getRemoteWorkloadsFromState( + ctx context.Context, + listAll bool, + 
labelFilters []string, +) ([]core.Workload, error) { + // Create a state store + store, err := state.NewRunConfigStore(state.DefaultAppName) + if err != nil { + return nil, fmt.Errorf("failed to create state store: %w", err) + } + + // List all configurations + configNames, err := store.List(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list configurations: %w", err) + } + + // Parse the filters into a format we can use for matching + parsedFilters, err := types.ParseLabelFilters(labelFilters) + if err != nil { + return nil, fmt.Errorf("failed to parse label filters: %v", err) + } + + var remoteWorkloads []core.Workload + + for _, name := range configNames { + // Load the run configuration + runConfig, err := runner.LoadState(ctx, name) + if err != nil { + logger.Warnf("failed to load state for %s: %v", name, err) + continue + } + + // Only include remote servers (those with RemoteURL set) + if runConfig.RemoteURL == "" { + continue + } + + // Check the status from the status file + workloadStatus, err := d.statuses.GetWorkload(ctx, name) + if err != nil { + logger.Warnf("failed to get status for remote workload %s: %v", name, err) + continue + } + + // Apply listAll filter - only include running workloads unless listAll is true + if !listAll && workloadStatus.Status != rt.WorkloadStatusRunning { + continue + } + + // Use the transport type directly since it's already parsed + transportType := runConfig.Transport + + // Generate the local proxy URL (not the remote server URL) + proxyURL := "" + if runConfig.Port > 0 { + proxyURL = transport.GenerateMCPServerURL( + transportType.String(), + transport.LocalhostIPv4, + runConfig.Port, + name, + runConfig.RemoteURL, // Pass remote URL to preserve path + ) + } + + // Calculate the effective proxy mode that clients should use + effectiveProxyMode := types.GetEffectiveProxyMode(transportType, string(runConfig.ProxyMode)) + + // Create a workload from the run configuration + workload := core.Workload{ + Name: 
name, + Package: "remote", + Status: workloadStatus.Status, + URL: proxyURL, + Port: runConfig.Port, + TransportType: transportType, + ProxyMode: effectiveProxyMode, + ToolType: "remote", + Group: runConfig.Group, + CreatedAt: workloadStatus.CreatedAt, + Labels: runConfig.ContainerLabels, + Remote: true, + } + + // Apply label filtering + if types.MatchesLabelFilters(workload.Labels, parsedFilters) { + remoteWorkloads = append(remoteWorkloads, workload) + } + } + + return remoteWorkloads, nil } diff --git a/pkg/workloads/manager_test.go b/pkg/workloads/manager_test.go index d1421e63f..ea971127e 100644 --- a/pkg/workloads/manager_test.go +++ b/pkg/workloads/manager_test.go @@ -1,16 +1,185 @@ package workloads import ( + "context" + "errors" + "fmt" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "golang.org/x/sync/errgroup" + "github.com/stacklok/toolhive/pkg/config" configMocks "github.com/stacklok/toolhive/pkg/config/mocks" + "github.com/stacklok/toolhive/pkg/container/runtime" runtimeMocks "github.com/stacklok/toolhive/pkg/container/runtime/mocks" + "github.com/stacklok/toolhive/pkg/core" + "github.com/stacklok/toolhive/pkg/runner" + statusMocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" ) +func TestDefaultManager_ListWorkloadsInGroup(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + groupName string + mockWorkloads []core.Workload + expectedNames []string + expectError bool + setupStatusMgr func(*statusMocks.MockStatusManager) + }{ + { + name: "non existent group returns empty list", + groupName: "non-group", + mockWorkloads: []core.Workload{ + {Name: "workload1", Group: "other-group"}, + {Name: "workload2", Group: "another-group"}, + }, + expectedNames: []string{}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ + {Name: "workload1", 
Group: "other-group"}, + {Name: "workload2", Group: "another-group"}, + }, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + { + name: "multiple workloads in group", + groupName: "test-group", + mockWorkloads: []core.Workload{ + {Name: "workload1", Group: "test-group"}, + {Name: "workload2", Group: "other-group"}, + {Name: "workload3", Group: "test-group"}, + {Name: "workload4", Group: "test-group"}, + }, + expectedNames: []string{"workload1", "workload3", "workload4"}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ + {Name: "workload1", Group: "test-group"}, + {Name: "workload2", Group: "other-group"}, + {Name: "workload3", Group: "test-group"}, + {Name: "workload4", Group: "test-group"}, + }, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + { + name: "workloads with empty group names", + groupName: "", + mockWorkloads: []core.Workload{ + {Name: "workload1", Group: ""}, + {Name: "workload2", Group: "test-group"}, + {Name: "workload3", Group: ""}, + }, + expectedNames: []string{"workload1", "workload3"}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ + {Name: "workload1", Group: ""}, + {Name: "workload2", Group: "test-group"}, + {Name: "workload3", Group: ""}, + }, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + { + name: "includes stopped workloads", + groupName: "test-group", + mockWorkloads: []core.Workload{ + {Name: 
"running-workload", Group: "test-group", Status: runtime.WorkloadStatusRunning}, + {Name: "stopped-workload", Group: "test-group", Status: runtime.WorkloadStatusStopped}, + {Name: "other-group-workload", Group: "other-group", Status: runtime.WorkloadStatusRunning}, + }, + expectedNames: []string{"running-workload", "stopped-workload"}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{ + {Name: "running-workload", Group: "test-group", Status: runtime.WorkloadStatusRunning}, + {Name: "stopped-workload", Group: "test-group", Status: runtime.WorkloadStatusStopped}, + {Name: "other-group-workload", Group: "other-group", Status: runtime.WorkloadStatusRunning}, + }, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + { + name: "error from ListWorkloads propagated", + groupName: "test-group", + expectedNames: nil, + expectError: true, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return(nil, assert.AnError) + }, + }, + { + name: "no workloads", + groupName: "test-group", + mockWorkloads: []core.Workload{}, + expectedNames: []string{}, + expectError: false, + setupStatusMgr: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), true, gomock.Any()).Return([]core.Workload{}, nil) + + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupStatusMgr(mockStatusMgr) + + manager := &defaultManager{ + 
runtime: nil, // Not needed for this test + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.ListWorkloadsInGroup(ctx, tt.groupName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to list workloads") + return + } + + require.NoError(t, err) + assert.ElementsMatch(t, tt.expectedNames, result) + }) + } +} + func TestNewManagerFromRuntime(t *testing.T) { t.Parallel() @@ -26,12 +195,12 @@ func TestNewManagerFromRuntime(t *testing.T) { require.NoError(t, err) require.NotNil(t, manager) - // Verify it's a cliManager with the runtime set - cliMgr, ok := manager.(*cliManager) + // Verify it's a defaultManager with the runtime set + defaultMgr, ok := manager.(*defaultManager) require.True(t, ok) - assert.Equal(t, mockRuntime, cliMgr.runtime) - assert.NotNil(t, cliMgr.statuses) - assert.NotNil(t, cliMgr.configProvider) + assert.Equal(t, mockRuntime, defaultMgr.runtime) + assert.NotNil(t, defaultMgr.statuses) + assert.NotNil(t, defaultMgr.configProvider) } func TestNewManagerFromRuntimeWithProvider(t *testing.T) { @@ -48,9 +217,1444 @@ func TestNewManagerFromRuntimeWithProvider(t *testing.T) { require.NoError(t, err) require.NotNil(t, manager) - cliMgr, ok := manager.(*cliManager) + defaultMgr, ok := manager.(*defaultManager) require.True(t, ok) - assert.Equal(t, mockRuntime, cliMgr.runtime) - assert.Equal(t, mockConfigProvider, cliMgr.configProvider) - assert.NotNil(t, cliMgr.statuses) + assert.Equal(t, mockRuntime, defaultMgr.runtime) + assert.Equal(t, mockConfigProvider, defaultMgr.configProvider) + assert.NotNil(t, defaultMgr.statuses) +} + +func TestDefaultManager_DoesWorkloadExist(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + setupMocks func(*statusMocks.MockStatusManager) + expected bool + expectError bool + }{ + { + name: "workload exists and running", + workloadName: "test-workload", + setupMocks: func(sm *statusMocks.MockStatusManager) { + 
sm.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + }, + expected: true, + expectError: false, + }, + { + name: "workload exists but in error state", + workloadName: "error-workload", + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "error-workload").Return(core.Workload{ + Name: "error-workload", + Status: runtime.WorkloadStatusError, + }, nil) + }, + expected: false, + expectError: false, + }, + { + name: "workload not found", + workloadName: "missing-workload", + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "missing-workload").Return(core.Workload{}, runtime.ErrWorkloadNotFound) + }, + expected: false, + expectError: false, + }, + { + name: "error getting workload", + workloadName: "problematic-workload", + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "problematic-workload").Return(core.Workload{}, errors.New("database error")) + }, + expected: false, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockStatusMgr) + + manager := &defaultManager{ + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.DoesWorkloadExist(ctx, tt.workloadName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to check if workload exists") + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestDefaultManager_GetWorkload(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + expectedWorkload := core.Workload{ + Name: "test-workload", 
+ Status: runtime.WorkloadStatusRunning, + } + + mockStatusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(expectedWorkload, nil) + + manager := &defaultManager{ + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.GetWorkload(ctx, "test-workload") + + require.NoError(t, err) + assert.Equal(t, expectedWorkload, result) +} + +func TestDefaultManager_GetLogs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + follow bool + setupMocks func(*runtimeMocks.MockRuntime) + expectedLogs string + expectError bool + errorMsg string + }{ + { + name: "successful log retrieval", + workloadName: "test-workload", + follow: false, + setupMocks: func(rt *runtimeMocks.MockRuntime) { + rt.EXPECT().GetWorkloadLogs(gomock.Any(), "test-workload", false).Return("test log content", nil) + }, + expectedLogs: "test log content", + expectError: false, + }, + { + name: "workload not found", + workloadName: "missing-workload", + follow: false, + setupMocks: func(rt *runtimeMocks.MockRuntime) { + rt.EXPECT().GetWorkloadLogs(gomock.Any(), "missing-workload", false).Return("", runtime.ErrWorkloadNotFound) + }, + expectedLogs: "", + expectError: true, + errorMsg: "workload not found", + }, + { + name: "runtime error", + workloadName: "error-workload", + follow: true, + setupMocks: func(rt *runtimeMocks.MockRuntime) { + rt.EXPECT().GetWorkloadLogs(gomock.Any(), "error-workload", true).Return("", errors.New("runtime failure")) + }, + expectedLogs: "", + expectError: true, + errorMsg: "failed to get container logs", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + tt.setupMocks(mockRuntime) + + manager := &defaultManager{ + runtime: mockRuntime, + } + + ctx := context.Background() + logs, err := manager.GetLogs(ctx, tt.workloadName, tt.follow) + + if 
tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedLogs, logs) + } + }) + } +} + +func TestDefaultManager_StopWorkloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadNames []string + expectError bool + errorMsg string + }{ + { + name: "invalid workload name with path traversal", + workloadNames: []string{"../etc/passwd"}, + expectError: true, + errorMsg: "path traversal", + }, + { + name: "invalid workload name with slash", + workloadNames: []string{"workload/name"}, + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "empty workload name list", + workloadNames: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + manager := &defaultManager{} + + ctx := context.Background() + group, err := manager.StopWorkloads(ctx, tt.workloadNames) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, group) + } else { + require.NoError(t, err) + assert.NotNil(t, group) + assert.IsType(t, &errgroup.Group{}, group) + } + }) + } +} + +func TestDefaultManager_DeleteWorkloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadNames []string + expectError bool + errorMsg string + }{ + { + name: "invalid workload name", + workloadNames: []string{"../../../etc/passwd"}, + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "mixed valid and invalid names", + workloadNames: []string{"valid-name", "invalid../name"}, + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "empty workload name list", + workloadNames: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + manager := &defaultManager{} + + ctx := context.Background() + group, err := 
manager.DeleteWorkloads(ctx, tt.workloadNames) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, group) + } else { + require.NoError(t, err) + assert.NotNil(t, group) + assert.IsType(t, &errgroup.Group{}, group) + } + }) + } +} + +func TestDefaultManager_RestartWorkloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadNames []string + foreground bool + expectError bool + errorMsg string + }{ + { + name: "invalid workload name", + workloadNames: []string{"invalid/name"}, + foreground: false, + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "empty workload name list", + workloadNames: []string{}, + foreground: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + manager := &defaultManager{} + + ctx := context.Background() + group, err := manager.RestartWorkloads(ctx, tt.workloadNames, tt.foreground) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, group) + } else { + require.NoError(t, err) + assert.NotNil(t, group) + assert.IsType(t, &errgroup.Group{}, group) + } + }) + } +} + +func TestDefaultManager_restartRemoteWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + runConfig *runner.RunConfig + foreground bool + setupMocks func(*statusMocks.MockStatusManager) + expectError bool + errorMsg string + }{ + { + name: "remote workload already running with healthy supervisor", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-base", + RemoteURL: "http://example.com", + }, + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + // Check if supervisor is alive - 
return valid PID (supervisor is healthy) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(12345, nil) + }, + // With healthy supervisor, restart should return early (no-op) + expectError: false, + }, + { + name: "remote workload already running with dead supervisor", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-base", + RemoteURL: "http://example.com", + }, + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + // Check if supervisor is alive - return error (supervisor is dead) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(0, errors.New("no PID found")) + // With dead supervisor, restart proceeds with cleanup and restart + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "remote-workload", runtime.WorkloadStatusStopping, "").Return(nil) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "remote-base").Return(0, errors.New("no PID found")) + // Allow any subsequent status updates + sm.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + sm.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + }, + // Restart now proceeds to load state which fails in tests (can't mock runner.LoadState easily) + expectError: true, + errorMsg: "failed to load state", + }, + { + name: "status manager error", + workloadName: "remote-workload", + runConfig: &runner.RunConfig{ + BaseName: "remote-base", + RemoteURL: "http://example.com", + }, + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().GetWorkload(gomock.Any(), "remote-workload").Return(core.Workload{}, errors.New("status manager error")) + }, + expectError: true, + errorMsg: "status manager error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, 
func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(statusMgr) + + manager := &defaultManager{ + statuses: statusMgr, + } + + err := manager.restartRemoteWorkload(context.Background(), tt.workloadName, tt.runConfig, tt.foreground) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestDefaultManager_restartContainerWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + foreground bool + setupMocks func(*statusMocks.MockStatusManager, *runtimeMocks.MockRuntime) + expectError bool + errorMsg string + }{ + { + name: "container workload already running with healthy supervisor", + workloadName: "container-workload", + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { + // Mock container info + rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ + Name: "container-workload", + State: runtime.WorkloadStatusRunning, + Labels: map[string]string{ + "toolhive.base-name": "container-workload", + }, + }, nil) + sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{ + Name: "container-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + // Check if supervisor is alive - return valid PID (supervisor is healthy) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(12345, nil) + }, + // With healthy supervisor, restart should return early (no-op) + expectError: false, + }, + { + name: "container workload already running with dead supervisor", + workloadName: "container-workload", + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { + // Mock container info + 
rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ + Name: "container-workload", + State: runtime.WorkloadStatusRunning, + Labels: map[string]string{ + "toolhive.base-name": "container-workload", + }, + }, nil) + sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{ + Name: "container-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + // Check if supervisor is alive - return error (supervisor is dead) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(0, errors.New("no PID found")) + // With dead supervisor, restart proceeds with cleanup and restart + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "container-workload", runtime.WorkloadStatusStopping, "").Return(nil) + sm.EXPECT().GetWorkloadPID(gomock.Any(), "container-workload").Return(0, errors.New("no PID found")) + rm.EXPECT().StopWorkload(gomock.Any(), "container-workload").Return(nil) + // Allow any subsequent status updates + sm.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + sm.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + }, + // Restart now proceeds to load state which fails in tests (can't mock runner.LoadState easily) + expectError: true, + errorMsg: "failed to load state", + }, + { + name: "status manager error", + workloadName: "container-workload", + foreground: false, + setupMocks: func(sm *statusMocks.MockStatusManager, rm *runtimeMocks.MockRuntime) { + // Mock container info + rm.EXPECT().GetWorkloadInfo(gomock.Any(), "container-workload").Return(runtime.ContainerInfo{ + Name: "container-workload", + State: "running", + Labels: map[string]string{ + "toolhive.base-name": "container-workload", + }, + }, nil) + sm.EXPECT().GetWorkload(gomock.Any(), "container-workload").Return(core.Workload{}, errors.New("status manager error")) + }, + expectError: true, + errorMsg: "status manager error", + }, + 
} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) + + tt.setupMocks(statusMgr, runtimeMgr) + + manager := &defaultManager{ + statuses: statusMgr, + runtime: runtimeMgr, + } + + err := manager.restartContainerWorkload(context.Background(), tt.workloadName, tt.foreground) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestDefaultManager_restartLogicConsistency tests restart behavior with healthy vs dead supervisor +func TestDefaultManager_restartLogicConsistency(t *testing.T) { + t.Parallel() + + t.Run("remote_workload_healthy_supervisor_no_restart", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + + statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + + // Check if supervisor is alive - return valid PID (healthy) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(12345, nil) + + manager := &defaultManager{ + statuses: statusMgr, + } + + runConfig := &runner.RunConfig{ + BaseName: "test-base", + RemoteURL: "http://example.com", + } + + err := manager.restartRemoteWorkload(context.Background(), "test-workload", runConfig, false) + + // With healthy supervisor, restart should return successfully without doing anything + require.NoError(t, err) + }) + + t.Run("remote_workload_dead_supervisor_calls_stop", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + + statusMgr.EXPECT().GetWorkload(gomock.Any(), 
"test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + + // Check if supervisor is alive - return error (dead supervisor) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(0, errors.New("no PID found")) + + // When supervisor is dead, expect stop logic to be called + statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(0, errors.New("no PID found")) + + // Allow any subsequent status updates - we don't care about the exact sequence + statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + + manager := &defaultManager{ + statuses: statusMgr, + } + + runConfig := &runner.RunConfig{ + BaseName: "test-base", + RemoteURL: "http://example.com", + } + + _ = manager.restartRemoteWorkload(context.Background(), "test-workload", runConfig, false) + + // The important part is that the stop methods were called (verified by mock expectations) + // We don't care if the restart ultimately succeeds or fails + }) + + t.Run("container_workload_healthy_supervisor_no_restart", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) + + containerInfo := runtime.ContainerInfo{ + Name: "test-workload", + State: runtime.WorkloadStatusRunning, + Labels: map[string]string{ + "toolhive.base-name": "test-workload", + }, + } + runtimeMgr.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(containerInfo, nil) + + statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + + 
// Check if supervisor is alive - return valid PID (healthy) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(12345, nil) + + manager := &defaultManager{ + statuses: statusMgr, + runtime: runtimeMgr, + } + + err := manager.restartContainerWorkload(context.Background(), "test-workload", false) + + // With healthy supervisor, restart should return successfully without doing anything + require.NoError(t, err) + }) + + t.Run("container_workload_dead_supervisor_calls_stop", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + statusMgr := statusMocks.NewMockStatusManager(ctrl) + runtimeMgr := runtimeMocks.NewMockRuntime(ctrl) + + containerInfo := runtime.ContainerInfo{ + Name: "test-workload", + State: runtime.WorkloadStatusRunning, + Labels: map[string]string{ + "toolhive.base-name": "test-workload", + }, + } + runtimeMgr.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(containerInfo, nil) + + statusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(core.Workload{ + Name: "test-workload", + Status: runtime.WorkloadStatusRunning, + }, nil) + + // Check if supervisor is alive - return error (dead supervisor) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(0, errors.New("no PID found")) + + // When supervisor is dead, expect stop logic to be called + statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) + statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(0, errors.New("no PID found")) + runtimeMgr.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) + + // Allow any subsequent status updates (starting, error, etc.) 
- we don't care about the exact sequence + statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + + manager := &defaultManager{ + statuses: statusMgr, + runtime: runtimeMgr, + } + + _ = manager.restartContainerWorkload(context.Background(), "test-workload", false) + + // The important part is that the stop methods were called (verified by mock expectations) + // We don't care if the restart ultimately succeeds or fails + }) +} + +func TestDefaultManager_RunWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + runConfig *runner.RunConfig + setupMocks func(*statusMocks.MockStatusManager) + expectError bool + errorMsg string + }{ + { + name: "successful run - status creation", + runConfig: &runner.RunConfig{ + BaseName: "test-workload", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + // Expect starting status first, then error status when the runner fails + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStarting, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, gomock.Any()).Return(nil) + }, + expectError: true, // The runner will fail without proper setup + }, + { + name: "status creation failure", + runConfig: &runner.RunConfig{ + BaseName: "failing-workload", + }, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", runtime.WorkloadStatusStarting, "").Return(errors.New("status creation failed")) + }, + expectError: true, + errorMsg: "failed to create workload status", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + 
tt.setupMocks(mockStatusMgr) + + manager := &defaultManager{ + statuses: mockStatusMgr, + } + + ctx := context.Background() + err := manager.RunWorkload(ctx, tt.runConfig) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestDefaultManager_validateSecretParameters(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + runConfig *runner.RunConfig + setupMocks func(*configMocks.MockProvider) + expectError bool + errorMsg string + }{ + { + name: "no secrets - should pass", + runConfig: &runner.RunConfig{ + Secrets: []string{}, + }, + setupMocks: func(*configMocks.MockProvider) {}, // No expectations + expectError: false, + }, + { + name: "config error", + runConfig: &runner.RunConfig{ + Secrets: []string{"secret1"}, + }, + setupMocks: func(cp *configMocks.MockProvider) { + mockConfig := &config.Config{} + cp.EXPECT().GetConfig().Return(mockConfig) + }, + expectError: true, + errorMsg: "error determining secrets provider type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConfigProvider := configMocks.NewMockProvider(ctrl) + tt.setupMocks(mockConfigProvider) + + manager := &defaultManager{ + configProvider: mockConfigProvider, + } + + ctx := context.Background() + err := manager.validateSecretParameters(ctx, tt.runConfig) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestDefaultManager_getWorkloadContainer(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) + expected *runtime.ContainerInfo + expectError bool + errorMsg string + }{ + { + name: "successful retrieval", + workloadName: 
"test-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { + expectedContainer := runtime.ContainerInfo{ + Name: "test-workload", + State: runtime.WorkloadStatusRunning, + } + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload").Return(expectedContainer, nil) + }, + expected: &runtime.ContainerInfo{ + Name: "test-workload", + State: runtime.WorkloadStatusRunning, + }, + expectError: false, + }, + { + name: "workload not found", + workloadName: "missing-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "missing-workload").Return(runtime.ContainerInfo{}, runtime.ErrWorkloadNotFound) + }, + expected: nil, + expectError: false, // getWorkloadContainer returns nil for not found, not error + }, + { + name: "runtime error", + workloadName: "error-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "error-workload").Return(runtime.ContainerInfo{}, errors.New("runtime failure")) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "error-workload", runtime.WorkloadStatusError, "runtime failure").Return(nil) + }, + expected: nil, + expectError: true, + errorMsg: "failed to find workload", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockRuntime, mockStatusMgr) + + manager := &defaultManager{ + runtime: mockRuntime, + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.getWorkloadContainer(ctx, tt.workloadName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + if tt.expected == nil { + assert.Nil(t, result) + } 
else { + assert.Equal(t, tt.expected, result) + } + } + }) + } +} + +func TestDefaultManager_removeContainer(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) + expectError bool + errorMsg string + }{ + { + name: "successful removal", + workloadName: "test-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, _ *statusMocks.MockStatusManager) { + rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(nil) + }, + expectError: false, + }, + { + name: "removal failure", + workloadName: "failing-workload", + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + rt.EXPECT().RemoveWorkload(gomock.Any(), "failing-workload").Return(errors.New("removal failed")) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", runtime.WorkloadStatusError, "removal failed").Return(nil) + }, + expectError: true, + errorMsg: "failed to remove container", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockRuntime, mockStatusMgr) + + manager := &defaultManager{ + runtime: mockRuntime, + statuses: mockStatusMgr, + } + + ctx := context.Background() + err := manager.removeContainer(ctx, tt.workloadName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestDefaultManager_needSecretsPassword(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + secretOptions []string + setupMocks func(*configMocks.MockProvider) + expected bool + }{ + { + name: "no secrets", + secretOptions: []string{}, + setupMocks: func(*configMocks.MockProvider) {}, // No expectations + expected: false, 
+ }, + { + name: "has secrets but config access fails", + secretOptions: []string{"secret1"}, + setupMocks: func(cp *configMocks.MockProvider) { + mockConfig := &config.Config{} + cp.EXPECT().GetConfig().Return(mockConfig) + }, + expected: false, // Returns false when provider type detection fails + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConfigProvider := configMocks.NewMockProvider(ctrl) + tt.setupMocks(mockConfigProvider) + + manager := &defaultManager{ + configProvider: mockConfigProvider, + } + + result := manager.needSecretsPassword(tt.secretOptions) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestDefaultManager_RunWorkloadDetached(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + runConfig *runner.RunConfig + setupMocks func(*statusMocks.MockStatusManager, *configMocks.MockProvider) + expectError bool + errorMsg string + }{ + { + name: "validation failure should not reach PID management", + runConfig: &runner.RunConfig{ + BaseName: "test-workload", + Secrets: []string{"invalid-secret"}, + }, + setupMocks: func(_ *statusMocks.MockStatusManager, cp *configMocks.MockProvider) { + // Mock config provider to cause validation failure + mockConfig := &config.Config{} + cp.EXPECT().GetConfig().Return(mockConfig) + // No SetWorkloadPID expectation since validation should fail first + }, + expectError: true, + errorMsg: "failed to validate workload parameters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + mockConfigProvider := configMocks.NewMockProvider(ctrl) + tt.setupMocks(mockStatusMgr, mockConfigProvider) + + manager := &defaultManager{ + statuses: mockStatusMgr, + configProvider: mockConfigProvider, + } + + ctx := context.Background() + err := 
manager.RunWorkloadDetached(ctx, tt.runConfig) + + if tt.expectError { + require.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// TestDefaultManager_RunWorkloadDetached_PIDManagement tests that PID management +// happens in the later stages of RunWorkloadDetached when the process actually starts. +// This is tested indirectly by verifying the behavior exists in the code flow. +func TestDefaultManager_RunWorkloadDetached_PIDManagement(t *testing.T) { + t.Parallel() + + // This test documents the expected behavior: + // 1. RunWorkloadDetached calls SetWorkloadPID after starting the detached process + // 2. The PID management happens after validation and process creation + // 3. SetWorkloadPID failures are logged as warnings but don't fail the operation + + // Since RunWorkloadDetached involves spawning actual processes and complex setup, + // we verify the PID management integration exists by checking the method signature + // and code structure rather than running the full integration. 
+ + manager := &defaultManager{} + assert.NotNil(t, manager, "defaultManager should be instantiable") + + // Verify the method exists with the correct signature + var runWorkloadDetachedFunc interface{} = manager.RunWorkloadDetached + assert.NotNil(t, runWorkloadDetachedFunc, "RunWorkloadDetached method should exist") +} + +func TestAsyncOperationTimeout(t *testing.T) { + t.Parallel() + + // Test that the timeout constant is properly defined + assert.Equal(t, 5*time.Minute, AsyncOperationTimeout) +} + +func TestErrWorkloadNotRunning(t *testing.T) { + t.Parallel() + + // Test that the error is properly defined + assert.Error(t, ErrWorkloadNotRunning) + assert.Contains(t, ErrWorkloadNotRunning.Error(), "workload not running") +} + +func TestDefaultManager_ListWorkloads(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + listAll bool + labelFilters []string + setupMocks func(*statusMocks.MockStatusManager) + expected []core.Workload + expectError bool + errorMsg string + }{ + { + name: "successful listing without filters", + listAll: true, + labelFilters: []string{}, + setupMocks: func(sm *statusMocks.MockStatusManager) { + workloads := []core.Workload{ + {Name: "workload1", Status: runtime.WorkloadStatusRunning}, + {Name: "workload2", Status: runtime.WorkloadStatusStopped}, + } + sm.EXPECT().ListWorkloads(gomock.Any(), true, []string{}).Return(workloads, nil) + sm.EXPECT().GetWorkload(gomock.Any(), gomock.Any()).Return(core.Workload{ + Name: "remote-workload", + Status: runtime.WorkloadStatusRunning, + }, nil).AnyTimes() + }, + expected: []core.Workload{ + {Name: "workload1", Status: runtime.WorkloadStatusRunning}, + {Name: "workload2", Status: runtime.WorkloadStatusStopped}, + }, + expectError: false, + }, + { + name: "error from status manager", + listAll: false, + labelFilters: []string{"env=prod"}, + setupMocks: func(sm *statusMocks.MockStatusManager) { + sm.EXPECT().ListWorkloads(gomock.Any(), false, []string{"env=prod"}).Return(nil, 
errors.New("database error")) + }, + expected: nil, + expectError: true, + errorMsg: "database error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + tt.setupMocks(mockStatusMgr) + + manager := &defaultManager{ + statuses: mockStatusMgr, + } + + ctx := context.Background() + result, err := manager.ListWorkloads(ctx, tt.listAll, tt.labelFilters...) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + // We expect this to succeed but might include remote workloads + // Since getRemoteWorkloadsFromState will likely fail in unit tests, + // we mainly verify the container workloads are returned + require.NoError(t, err) + assert.GreaterOrEqual(t, len(result), len(tt.expected)) + // Verify at least our expected container workloads are present + for _, expectedWorkload := range tt.expected { + found := false + for _, actualWorkload := range result { + if actualWorkload.Name == expectedWorkload.Name { + found = true + break + } + } + assert.True(t, found, fmt.Sprintf("Expected workload %s not found in result", expectedWorkload.Name)) + } + } + }) + } +} + +func TestDefaultManager_UpdateWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + expectError bool + errorMsg string + setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) + }{ + { + name: "invalid workload name with slash", + workloadName: "invalid/name", + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "invalid workload name with backslash", + workloadName: "invalid\\name", + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "invalid workload name with path traversal", + workloadName: "../invalid", + expectError: true, + errorMsg: "invalid workload name", + }, + { + name: "valid 
workload name returns errgroup immediately", + workloadName: "valid-workload", + expectError: false, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // Mock calls that will happen in the background goroutine + // We don't care about the success/failure, just that it doesn't panic + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "valid-workload"). + Return(runtime.ContainerInfo{}, errors.New("not found")).AnyTimes() + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "valid-workload", gomock.Any(), gomock.Any()). + Return(nil).AnyTimes() + }, + }, + { + name: "UpdateWorkload returns errgroup even if async operation will fail", + workloadName: "failing-workload", + expectError: false, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // The async operation will fail, but UpdateWorkload itself should succeed + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "failing-workload"). + Return(runtime.ContainerInfo{}, errors.New("container lookup failed")).AnyTimes() + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "failing-workload", gomock.Any(), gomock.Any()). 
+ Return(nil).AnyTimes() + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + mockStatusManager := statusMocks.NewMockStatusManager(ctrl) + mockConfigProvider := configMocks.NewMockProvider(ctrl) + + if tt.setupMocks != nil { + tt.setupMocks(mockRuntime, mockStatusManager) + } + + manager := &defaultManager{ + runtime: mockRuntime, + statuses: mockStatusManager, + configProvider: mockConfigProvider, + } + + // Create a dummy RunConfig for testing + runConfig := &runner.RunConfig{ + ContainerName: tt.workloadName, + BaseName: tt.workloadName, + } + + ctx := context.Background() + group, err := manager.UpdateWorkload(ctx, tt.workloadName, runConfig) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + assert.Nil(t, group) + } else { + assert.NoError(t, err) + assert.NotNil(t, group) + // For valid cases, we get an errgroup but don't wait for completion + // The async operations inside are tested separately + } + }) + } +} + +func TestDefaultManager_updateSingleWorkload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + workloadName string + runConfig *runner.RunConfig + setupMocks func(*runtimeMocks.MockRuntime, *statusMocks.MockStatusManager) + expectError bool + errorMsg string + }{ + { + name: "stop operation fails", + workloadName: "test-workload", + runConfig: &runner.RunConfig{ + ContainerName: "test-workload", + BaseName: "test-workload", + Group: "default", + }, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // Mock the stop operation - return error for GetWorkloadInfo + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
+ Return(runtime.ContainerInfo{}, errors.New("container lookup failed")).AnyTimes() + // Still expect status updates to be attempted + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil).AnyTimes() + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, "").Return(nil).AnyTimes() + }, + expectError: true, + errorMsg: "failed to stop workload", + }, + { + name: "successful stop and delete operations complete correctly", + workloadName: "test-workload", + runConfig: &runner.RunConfig{ + ContainerName: "test-workload", + BaseName: "test-workload", + Group: "default", + }, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // Mock stop operation - workload exists and can be stopped + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). + Return(runtime.ContainerInfo{ + Name: "test-workload", + State: "running", + Labels: map[string]string{"toolhive-basename": "test-workload"}, + }, nil) + // Mock GetWorkloadPID call from stopProcess + sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil) + rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) + sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil) + + // Mock delete operation - workload exists and can be deleted + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
+ Return(runtime.ContainerInfo{Name: "test-workload"}, nil) + rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(nil) + + // Mock status updates for stop and delete phases + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopped, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusRemoving, "").Return(nil) + sm.EXPECT().DeleteWorkloadStatus(gomock.Any(), "test-workload").Return(nil) + + // Mock RunWorkloadDetached calls - expect the ones that will be called + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStarting, "").Return(nil) + sm.EXPECT().SetWorkloadPID(gomock.Any(), "test-workload", gomock.Any()).Return(nil) + }, + expectError: false, // Test passes - update process completes successfully + }, + { + name: "delete operation fails after successful stop", + workloadName: "test-workload", + runConfig: &runner.RunConfig{ + ContainerName: "test-workload", + BaseName: "test-workload", + Group: "default", + }, + setupMocks: func(rt *runtimeMocks.MockRuntime, sm *statusMocks.MockStatusManager) { + // Mock successful stop + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). + Return(runtime.ContainerInfo{ + Name: "test-workload", + State: "running", + Labels: map[string]string{"toolhive-basename": "test-workload"}, + }, nil) + // Mock GetWorkloadPID call from stopProcess + sm.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(1234, nil) + rt.EXPECT().StopWorkload(gomock.Any(), "test-workload").Return(nil) + sm.EXPECT().ResetWorkloadPID(gomock.Any(), "test-workload").Return(nil) + + // Mock failed delete + rt.EXPECT().GetWorkloadInfo(gomock.Any(), "test-workload"). 
+ Return(runtime.ContainerInfo{Name: "test-workload"}, nil) + rt.EXPECT().RemoveWorkload(gomock.Any(), "test-workload").Return(errors.New("delete failed")) + + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopping, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStopped, "").Return(nil) + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusRemoving, "").Return(nil) + // RemoveWorkload fails, so error status is set + sm.EXPECT().SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, "delete failed").Return(nil) + }, + expectError: true, + errorMsg: "failed to delete workload", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + mockStatusManager := statusMocks.NewMockStatusManager(ctrl) + mockConfigProvider := configMocks.NewMockProvider(ctrl) + + if tt.setupMocks != nil { + tt.setupMocks(mockRuntime, mockStatusManager) + } + + manager := &defaultManager{ + runtime: mockRuntime, + statuses: mockStatusManager, + configProvider: mockConfigProvider, + } + + err := manager.updateSingleWorkload(tt.workloadName, tt.runConfig) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } } From e14a9ea7d611bf0d957b10256c2196ea9b79df25 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Mon, 17 Nov 2025 16:16:13 +0000 Subject: [PATCH 16/16] use default workload manager as the cli implmentation for workload discoverer --- pkg/vmcp/aggregator/discoverer.go | 36 ++--- pkg/vmcp/aggregator/discoverer_test.go | 99 +++++++----- pkg/vmcp/aggregator/testhelpers_test.go | 49 ------ pkg/vmcp/workloads/cli.go | 107 ------------- pkg/vmcp/workloads/discoverer.go | 4 +- 
pkg/vmcp/workloads/k8s.go | 4 +- pkg/vmcp/workloads/mocks/mock_discoverer.go | 12 +- pkg/workloads/manager.go | 163 +++++++++++++++----- pkg/workloads/manager_test.go | 50 +++--- 9 files changed, 233 insertions(+), 291 deletions(-) delete mode 100644 pkg/vmcp/workloads/cli.go diff --git a/pkg/vmcp/aggregator/discoverer.go b/pkg/vmcp/aggregator/discoverer.go index aace5ef00..4286eae0b 100644 --- a/pkg/vmcp/aggregator/discoverer.go +++ b/pkg/vmcp/aggregator/discoverer.go @@ -12,14 +12,13 @@ import ( "context" "fmt" - ct "github.com/stacklok/toolhive/pkg/container" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/vmcp/workloads" - "github.com/stacklok/toolhive/pkg/workloads/statuses" + workloadsmgr "github.com/stacklok/toolhive/pkg/workloads" ) // backendDiscoverer discovers backend MCP servers using a WorkloadDiscoverer. 
@@ -73,20 +72,12 @@ func NewBackendDiscoverer( } workloadDiscoverer = k8sDiscoverer } else { - // Create runtime and status manager for CLI workloads - runtime, err := ct.NewFactory().Create(ctx) + manager, err := workloadsmgr.NewManager(ctx) if err != nil { - return nil, fmt.Errorf("failed to create runtime: %w", err) + return nil, fmt.Errorf("failed to create workload manager: %w", err) } - - statusManager, err := statuses.NewStatusManager(runtime) - if err != nil { - return nil, fmt.Errorf("failed to create status manager: %w", err) - } - - workloadDiscoverer = workloads.NewCLIDiscoverer(statusManager) + workloadDiscoverer = manager } - return NewUnifiedBackendDiscoverer(workloadDiscoverer, groupsManager, authConfig), nil } @@ -131,7 +122,7 @@ func (d *backendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vm // Query each workload and convert to backend var backends []vmcp.Backend for _, name := range workloadNames { - backend, err := d.workloadsManager.GetWorkload(ctx, name) + backend, err := d.workloadsManager.GetWorkloadAsVMCPBackend(ctx, name) if err != nil { logger.Warnf("Failed to get workload %s: %v, skipping", name, err) continue @@ -143,14 +134,21 @@ func (d *backendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vm } // Apply authentication configuration if provided - authStrategy, authMetadata := d.authConfig.ResolveForBackend(name) - backend.AuthStrategy = authStrategy - backend.AuthMetadata = authMetadata - if authStrategy != "" { - logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) + var authStrategy string + var authMetadata map[string]any + if d.authConfig != nil { + authStrategy, authMetadata = d.authConfig.ResolveForBackend(name) + backend.AuthStrategy = authStrategy + backend.AuthMetadata = authMetadata + if authStrategy != "" { + logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) + } } // Set group metadata (override user labels to prevent conflicts) + 
if backend.Metadata == nil { + backend.Metadata = make(map[string]string) + } backend.Metadata["group"] = groupRef logger.Debugf("Discovered backend %s: %s (%s) with health status %s", diff --git a/pkg/vmcp/aggregator/discoverer_test.go b/pkg/vmcp/aggregator/discoverer_test.go index e43e464eb..70311fa2a 100644 --- a/pkg/vmcp/aggregator/discoverer_test.go +++ b/pkg/vmcp/aggregator/discoverer_test.go @@ -9,15 +9,10 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/core" "github.com/stacklok/toolhive/pkg/groups/mocks" - "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/pkg/vmcp/workloads" discoverermocks "github.com/stacklok/toolhive/pkg/vmcp/workloads/mocks" - statusmocks "github.com/stacklok/toolhive/pkg/workloads/statuses/mocks" ) const testGroupName = "test-group" @@ -60,8 +55,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
Return([]string{"workload1", "workload2"}, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(backend1, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(backend2, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(backend1, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload2").Return(backend2, nil) discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) @@ -105,8 +100,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). Return([]string{"healthy-workload", "unhealthy-workload"}, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "healthy-workload").Return(healthyBackend, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "unhealthy-workload").Return(unhealthyBackend, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "healthy-workload").Return(healthyBackend, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "unhealthy-workload").Return(unhealthyBackend, nil) discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) @@ -140,9 +135,9 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
Return([]string{"workload1", "workload2"}, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(backendWithURL, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(backendWithURL, nil) // workload2 has no URL, so GetWorkload returns nil - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(nil, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload2").Return(nil, nil) discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) @@ -163,8 +158,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). Return([]string{"workload1", "workload2"}, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(nil, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(nil, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(nil, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload2").Return(nil, nil) discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) @@ -247,8 +242,8 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). Return([]string{"good-workload", "failing-workload"}, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "good-workload").Return(goodBackend, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). 
+ mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "good-workload").Return(goodBackend, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "failing-workload"). Return(nil, errors.New("workload query failed")) discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, nil) @@ -310,7 +305,7 @@ func TestBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) mockWorkloadDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). Return([]string{"workload1"}, nil) - mockWorkloadDiscoverer.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(backend, nil) + mockWorkloadDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(backend, nil) discoverer := NewUnifiedBackendDiscoverer(mockWorkloadDiscoverer, mockGroups, authConfig) backends, err := discoverer.Discover(context.Background(), testGroupName) @@ -332,20 +327,23 @@ func TestCLIWorkloadDiscoverer(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) - mockStatusManager := statusmocks.NewMockStatusManager(ctrl) + mockManager := discoverermocks.NewMockDiscoverer(ctrl) mockGroups := mocks.NewMockManager(ctrl) - workload := newTestWorkload("workload1", - withToolType("github"), - withLabels(map[string]string{"env": "prod"})) + backend := &vmcp.Backend{ + ID: "workload1", + Name: "workload1", + BaseURL: "http://localhost:8080/mcp", + HealthStatus: vmcp.BackendHealthy, + Metadata: map[string]string{"tool_type": "github", "env": "prod"}, + } mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockStatusManager.EXPECT().ListWorkloads(gomock.Any(), true, nil). - Return([]core.Workload{workload}, nil) - mockStatusManager.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload, nil) + mockManager.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"workload1"}, nil) + mockManager.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "workload1").Return(backend, nil) - cliDiscoverer := workloads.NewCLIDiscoverer(mockStatusManager) - discoverer := NewUnifiedBackendDiscoverer(cliDiscoverer, mockGroups, nil) + discoverer := NewUnifiedBackendDiscoverer(mockManager, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -362,28 +360,51 @@ func TestCLIWorkloadDiscoverer(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) - mockStatusManager := statusmocks.NewMockStatusManager(ctrl) + mockDiscoverer := discoverermocks.NewMockDiscoverer(ctrl) mockGroups := mocks.NewMockManager(ctrl) - runningWorkload := newTestWorkload("running-workload") - stoppedWorkload := newTestWorkload("stopped-workload", - withStatus(runtime.WorkloadStatusStopped), - withURL("http://localhost:8081/mcp"), - withTransport(types.TransportTypeSSE)) + runningBackend := &vmcp.Backend{ + ID: "running-workload", + Name: "running-workload", + BaseURL: "http://localhost:8080/mcp", + HealthStatus: vmcp.BackendHealthy, + } + stoppedBackend := &vmcp.Backend{ + ID: "stopped-workload", + Name: "stopped-workload", + BaseURL: "http://localhost:8081/mcp", + HealthStatus: vmcp.BackendUnhealthy, + } mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(true, nil) - mockStatusManager.EXPECT().ListWorkloads(gomock.Any(), true, nil). - Return([]core.Workload{runningWorkload, stoppedWorkload}, nil) - mockStatusManager.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) - mockStatusManager.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) + mockDiscoverer.EXPECT().ListWorkloadsInGroup(gomock.Any(), testGroupName). 
+ Return([]string{"running-workload", "stopped-workload"}, nil) + // The discoverer iterates through all workloads in order + mockDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "running-workload").Return(runningBackend, nil) + mockDiscoverer.EXPECT().GetWorkloadAsVMCPBackend(gomock.Any(), "stopped-workload").Return(stoppedBackend, nil) - cliDiscoverer := workloads.NewCLIDiscoverer(mockStatusManager) - discoverer := NewUnifiedBackendDiscoverer(cliDiscoverer, mockGroups, nil) + discoverer := NewUnifiedBackendDiscoverer(mockDiscoverer, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) require.Len(t, backends, 2) - assert.Equal(t, vmcp.BackendHealthy, backends[0].HealthStatus) - assert.Equal(t, vmcp.BackendUnhealthy, backends[1].HealthStatus) + // Sort backends by name to ensure consistent ordering for assertions + if backends[0].Name > backends[1].Name { + backends[0], backends[1] = backends[1], backends[0] + } + // Find the correct backend by name + var running, stopped *vmcp.Backend + for i := range backends { + if backends[i].Name == "running-workload" { + running = &backends[i] + } + if backends[i].Name == "stopped-workload" { + stopped = &backends[i] + } + } + require.NotNil(t, running, "running-workload should be found") + require.NotNil(t, stopped, "stopped-workload should be found") + assert.Equal(t, vmcp.BackendHealthy, running.HealthStatus) + assert.Equal(t, vmcp.BackendUnhealthy, stopped.HealthStatus) }) } diff --git a/pkg/vmcp/aggregator/testhelpers_test.go b/pkg/vmcp/aggregator/testhelpers_test.go index 0b766c508..6aeafd327 100644 --- a/pkg/vmcp/aggregator/testhelpers_test.go +++ b/pkg/vmcp/aggregator/testhelpers_test.go @@ -1,58 +1,9 @@ package aggregator import ( - "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/core" - "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/vmcp" ) -// Test fixture builders 
to reduce verbosity in tests - -func newTestWorkload(name string, opts ...func(*core.Workload)) core.Workload { - w := core.Workload{ - Name: name, - Status: runtime.WorkloadStatusRunning, - URL: "http://localhost:8080/mcp", - TransportType: types.TransportTypeStreamableHTTP, - Group: testGroupName, - } - for _, opt := range opts { - opt(&w) - } - return w -} - -func withStatus(status runtime.WorkloadStatus) func(*core.Workload) { - return func(w *core.Workload) { - w.Status = status - } -} - -func withURL(url string) func(*core.Workload) { - return func(w *core.Workload) { - w.URL = url - } -} - -func withTransport(transport types.TransportType) func(*core.Workload) { - return func(w *core.Workload) { - w.TransportType = transport - } -} - -func withToolType(toolType string) func(*core.Workload) { - return func(w *core.Workload) { - w.ToolType = toolType - } -} - -func withLabels(labels map[string]string) func(*core.Workload) { - return func(w *core.Workload) { - w.Labels = labels - } -} - func newTestBackend(id string, opts ...func(*vmcp.Backend)) vmcp.Backend { b := vmcp.Backend{ ID: id, diff --git a/pkg/vmcp/workloads/cli.go b/pkg/vmcp/workloads/cli.go deleted file mode 100644 index d5e2b0ca8..000000000 --- a/pkg/vmcp/workloads/cli.go +++ /dev/null @@ -1,107 +0,0 @@ -package workloads - -import ( - "context" - - rt "github.com/stacklok/toolhive/pkg/container/runtime" - "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/workloads/statuses" -) - -// cliDiscoverer is a direct implementation of Discoverer for CLI workloads. -// It uses the status manager directly instead of going through workloads.Manager. -type cliDiscoverer struct { - statusManager statuses.StatusManager -} - -// NewCLIDiscoverer creates a new CLI workload discoverer that directly uses -// the status manager to discover workloads. 
-func NewCLIDiscoverer(statusManager statuses.StatusManager) Discoverer { - return &cliDiscoverer{ - statusManager: statusManager, - } -} - -// ListWorkloadsInGroup returns all workload names that belong to the specified group. -func (d *cliDiscoverer) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { - // List all workloads (including stopped ones) - workloads, err := d.statusManager.ListWorkloads(ctx, true, nil) - if err != nil { - return nil, err - } - - // Filter workloads that belong to the specified group - var groupWorkloads []string - for _, workload := range workloads { - if workload.Group == groupName { - groupWorkloads = append(groupWorkloads, workload.Name) - } - } - - return groupWorkloads, nil -} - -// GetWorkload retrieves workload details by name and converts it to a vmcp.Backend. -func (d *cliDiscoverer) GetWorkload(ctx context.Context, workloadName string) (*vmcp.Backend, error) { - workload, err := d.statusManager.GetWorkload(ctx, workloadName) - if err != nil { - return nil, err - } - - // Skip workloads without a URL (not accessible) - if workload.URL == "" { - logger.Debugf("Skipping workload %s without URL", workloadName) - return nil, nil - } - - // Map workload status to backend health status - healthStatus := mapCLIWorkloadStatusToHealth(workload.Status) - - // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. - // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. - // ProxyMode tells us which transport the vmcp client should use. 
- transportType := workload.ProxyMode - if transportType == "" { - // Fallback to TransportType if ProxyMode is not set (for direct transports) - transportType = workload.TransportType.String() - } - - backend := &vmcp.Backend{ - ID: workload.Name, - Name: workload.Name, - BaseURL: workload.URL, - TransportType: transportType, - HealthStatus: healthStatus, - Metadata: make(map[string]string), - } - - // Copy user labels to metadata first - for k, v := range workload.Labels { - backend.Metadata[k] = v - } - - // Set system metadata (these override user labels to prevent conflicts) - backend.Metadata["tool_type"] = workload.ToolType - backend.Metadata["workload_status"] = string(workload.Status) - - return backend, nil -} - -// mapCLIWorkloadStatusToHealth converts a CLI WorkloadStatus to a backend health status. -func mapCLIWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { - switch status { - case rt.WorkloadStatusRunning: - return vmcp.BackendHealthy - case rt.WorkloadStatusUnhealthy: - return vmcp.BackendUnhealthy - case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStopping, rt.WorkloadStatusRemoving: - return vmcp.BackendUnhealthy - case rt.WorkloadStatusStarting, rt.WorkloadStatusUnknown: - return vmcp.BackendUnknown - case rt.WorkloadStatusUnauthenticated: - return vmcp.BackendUnauthenticated - default: - return vmcp.BackendUnknown - } -} diff --git a/pkg/vmcp/workloads/discoverer.go b/pkg/vmcp/workloads/discoverer.go index 0be4e89cc..fb11f947e 100644 --- a/pkg/vmcp/workloads/discoverer.go +++ b/pkg/vmcp/workloads/discoverer.go @@ -17,9 +17,9 @@ type Discoverer interface { // ListWorkloadsInGroup returns all workload names that belong to the specified group ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) - // GetWorkload retrieves workload details by name and converts it to a vmcp.Backend. 
+ // GetWorkloadAsVMCPBackend retrieves workload details by name and converts it to a vmcp.Backend. // The returned Backend should have all fields populated except AuthStrategy and AuthMetadata, // which will be set by the discoverer based on the auth configuration. // Returns nil if the workload exists but is not accessible (e.g., no URL). - GetWorkload(ctx context.Context, workloadName string) (*vmcp.Backend, error) + GetWorkloadAsVMCPBackend(ctx context.Context, workloadName string) (*vmcp.Backend, error) } diff --git a/pkg/vmcp/workloads/k8s.go b/pkg/vmcp/workloads/k8s.go index 834782f9e..90dd2b5c0 100644 --- a/pkg/vmcp/workloads/k8s.go +++ b/pkg/vmcp/workloads/k8s.go @@ -71,8 +71,8 @@ func (d *k8sDiscoverer) ListWorkloadsInGroup(ctx context.Context, groupName stri return groupWorkloads, nil } -// GetWorkload retrieves workload details by name and converts it to a vmcp.Backend. -func (d *k8sDiscoverer) GetWorkload(ctx context.Context, workloadName string) (*vmcp.Backend, error) { +// GetWorkloadAsVMCPBackend retrieves workload details by name and converts it to a vmcp.Backend. +func (d *k8sDiscoverer) GetWorkloadAsVMCPBackend(ctx context.Context, workloadName string) (*vmcp.Backend, error) { mcpServer := &mcpv1alpha1.MCPServer{} key := client.ObjectKey{Name: workloadName, Namespace: d.namespace} if err := d.k8sClient.Get(ctx, key, mcpServer); err != nil { diff --git a/pkg/vmcp/workloads/mocks/mock_discoverer.go b/pkg/vmcp/workloads/mocks/mock_discoverer.go index 7d3719a76..daed3e8d9 100644 --- a/pkg/vmcp/workloads/mocks/mock_discoverer.go +++ b/pkg/vmcp/workloads/mocks/mock_discoverer.go @@ -41,19 +41,19 @@ func (m *MockDiscoverer) EXPECT() *MockDiscovererMockRecorder { return m.recorder } -// GetWorkload mocks base method. -func (m *MockDiscoverer) GetWorkload(ctx context.Context, workloadName string) (*vmcp.Backend, error) { +// GetWorkloadAsVMCPBackend mocks base method. 
+func (m *MockDiscoverer) GetWorkloadAsVMCPBackend(ctx context.Context, workloadName string) (*vmcp.Backend, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWorkload", ctx, workloadName) + ret := m.ctrl.Call(m, "GetWorkloadAsVMCPBackend", ctx, workloadName) ret0, _ := ret[0].(*vmcp.Backend) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetWorkload indicates an expected call of GetWorkload. -func (mr *MockDiscovererMockRecorder) GetWorkload(ctx, workloadName any) *gomock.Call { +// GetWorkloadAsVMCPBackend indicates an expected call of GetWorkloadAsVMCPBackend. +func (mr *MockDiscovererMockRecorder) GetWorkloadAsVMCPBackend(ctx, workloadName any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkload", reflect.TypeOf((*MockDiscoverer)(nil).GetWorkload), ctx, workloadName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkloadAsVMCPBackend", reflect.TypeOf((*MockDiscoverer)(nil).GetWorkloadAsVMCPBackend), ctx, workloadName) } // ListWorkloadsInGroup mocks base method. diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index 7aac9e5d6..6389a037f 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -27,6 +27,7 @@ import ( "github.com/stacklok/toolhive/pkg/secrets" "github.com/stacklok/toolhive/pkg/state" "github.com/stacklok/toolhive/pkg/transport" + "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/workloads/statuses" "github.com/stacklok/toolhive/pkg/workloads/types" ) @@ -71,7 +72,8 @@ type Manager interface { DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) } -type defaultManager struct { +// DefaultManager is the default implementation of the Manager interface. +type DefaultManager struct { runtime rt.Runtime statuses statuses.StatusManager configProvider config.Provider @@ -86,7 +88,7 @@ const ( ) // NewManager creates a new container manager instance. 
-func NewManager(ctx context.Context) (Manager, error) { +func NewManager(ctx context.Context) (*DefaultManager, error) { runtime, err := ct.NewFactory().Create(ctx) if err != nil { return nil, err @@ -97,7 +99,7 @@ func NewManager(ctx context.Context) (Manager, error) { return nil, fmt.Errorf("failed to create status manager: %w", err) } - return &defaultManager{ + return &DefaultManager{ runtime: runtime, statuses: statusManager, configProvider: config.NewDefaultProvider(), @@ -116,7 +118,7 @@ func NewManagerWithProvider(ctx context.Context, configProvider config.Provider) return nil, fmt.Errorf("failed to create status manager: %w", err) } - return &defaultManager{ + return &DefaultManager{ runtime: runtime, statuses: statusManager, configProvider: configProvider, @@ -130,7 +132,7 @@ func NewManagerFromRuntime(runtime rt.Runtime) (Manager, error) { return nil, fmt.Errorf("failed to create status manager: %w", err) } - return &defaultManager{ + return &DefaultManager{ runtime: runtime, statuses: statusManager, configProvider: config.NewDefaultProvider(), @@ -145,20 +147,90 @@ func NewManagerFromRuntimeWithProvider(runtime rt.Runtime, configProvider config return nil, fmt.Errorf("failed to create status manager: %w", err) } - return &defaultManager{ + return &DefaultManager{ runtime: runtime, statuses: statusManager, configProvider: configProvider, }, nil } -func (d *defaultManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { +// GetWorkload retrieves details of the named workload including its status. +func (d *DefaultManager) GetWorkload(ctx context.Context, workloadName string) (core.Workload, error) { // For the sake of minimizing changes, delegate to the status manager. // Whether this method should still belong to the workload manager is TBD. 
return d.statuses.GetWorkload(ctx, workloadName) } -func (d *defaultManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { +// GetWorkloadAsVMCPBackend retrieves a workload and converts it to a vmcp.Backend. +// This method eliminates indirection by directly returning the vmcp.Backend type +// needed by vmcp workload discovery, avoiding the need for callers to convert +// from core.Workload to vmcp.Backend. +// Returns nil if the workload exists but is not accessible (e.g., no URL). +func (d *DefaultManager) GetWorkloadAsVMCPBackend(ctx context.Context, workloadName string) (*vmcp.Backend, error) { + workload, err := d.statuses.GetWorkload(ctx, workloadName) + if err != nil { + return nil, err + } + + // Skip workloads without a URL (not accessible) + if workload.URL == "" { + logger.Debugf("Skipping workload %s without URL", workloadName) + return nil, nil + } + + // Map workload status to backend health status + healthStatus := mapWorkloadStatusToVMCPHealth(workload.Status) + + // Use ProxyMode instead of TransportType to reflect how ToolHive is exposing the workload. + // For stdio MCP servers, ToolHive proxies them via SSE or streamable-http. + // ProxyMode tells us which transport the vmcp client should use. 
+ transportType := workload.ProxyMode + if transportType == "" { + // Fallback to TransportType if ProxyMode is not set (for direct transports) + transportType = workload.TransportType.String() + } + + backend := &vmcp.Backend{ + ID: workload.Name, + Name: workload.Name, + BaseURL: workload.URL, + TransportType: transportType, + HealthStatus: healthStatus, + Metadata: make(map[string]string), + } + + // Copy user labels to metadata first + for k, v := range workload.Labels { + backend.Metadata[k] = v + } + + // Set system metadata (these override user labels to prevent conflicts) + backend.Metadata["tool_type"] = workload.ToolType + backend.Metadata["workload_status"] = string(workload.Status) + + return backend, nil +} + +// mapWorkloadStatusToVMCPHealth converts a WorkloadStatus to a vmcp BackendHealthStatus. +func mapWorkloadStatusToVMCPHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { + switch status { + case rt.WorkloadStatusRunning: + return vmcp.BackendHealthy + case rt.WorkloadStatusUnhealthy: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStopped, rt.WorkloadStatusError, rt.WorkloadStatusStopping, rt.WorkloadStatusRemoving: + return vmcp.BackendUnhealthy + case rt.WorkloadStatusStarting, rt.WorkloadStatusUnknown: + return vmcp.BackendUnknown + case rt.WorkloadStatusUnauthenticated: + return vmcp.BackendUnauthenticated + default: + return vmcp.BackendUnknown + } +} + +// DoesWorkloadExist checks if a workload with the given name exists. 
+func (d *DefaultManager) DoesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { // check if workload exists by trying to get it workload, err := d.statuses.GetWorkload(ctx, workloadName) if err != nil { @@ -175,7 +247,8 @@ func (d *defaultManager) DoesWorkloadExist(ctx context.Context, workloadName str return true, nil } -func (d *defaultManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { +// ListWorkloads retrieves the states of all workloads. +func (d *DefaultManager) ListWorkloads(ctx context.Context, listAll bool, labelFilters ...string) ([]core.Workload, error) { // For the sake of minimizing changes, delegate to the status manager. // Whether this method should still belong to the workload manager is TBD. containerWorkloads, err := d.statuses.ListWorkloads(ctx, listAll, labelFilters) @@ -196,7 +269,8 @@ func (d *defaultManager) ListWorkloads(ctx context.Context, listAll bool, labelF return containerWorkloads, nil } -func (d *defaultManager) StopWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { +// StopWorkloads stops the specified workloads by name. 
+func (d *DefaultManager) StopWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { // Validate all workload names to prevent path traversal attacks for _, name := range names { if err := types.ValidateWorkloadName(name); err != nil { @@ -220,7 +294,7 @@ func (d *defaultManager) StopWorkloads(_ context.Context, names []string) (*errg } // stopSingleWorkload stops a single workload (container or remote) -func (d *defaultManager) stopSingleWorkload(name string) error { +func (d *DefaultManager) stopSingleWorkload(name string) error { // Create a child context with a longer timeout childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -243,7 +317,7 @@ func (d *defaultManager) stopSingleWorkload(name string) error { } // stopRemoteWorkload stops a remote workload -func (d *defaultManager) stopRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { +func (d *DefaultManager) stopRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { logger.Infof("Stopping remote workload %s...", name) // Check if the workload is running by checking its status @@ -289,7 +363,7 @@ func (d *defaultManager) stopRemoteWorkload(ctx context.Context, name string, ru } // stopContainerWorkload stops a container-based workload -func (d *defaultManager) stopContainerWorkload(ctx context.Context, name string) error { +func (d *DefaultManager) stopContainerWorkload(ctx context.Context, name string) error { container, err := d.runtime.GetWorkloadInfo(ctx, name) if err != nil { if errors.Is(err, rt.ErrWorkloadNotFound) { @@ -316,7 +390,8 @@ func (d *defaultManager) stopContainerWorkload(ctx context.Context, name string) return d.stopSingleContainerWorkload(ctx, &container) } -func (d *defaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error { +// RunWorkload runs a workload in the foreground. 
+func (d *DefaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error { // Ensure that the workload has a status entry before starting the process. if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil { // Failure to create the initial state is a fatal error. @@ -334,7 +409,8 @@ func (d *defaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunC return err } -func (d *defaultManager) validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { +// validateSecretParameters validates the secret parameters for a workload. +func (d *DefaultManager) validateSecretParameters(ctx context.Context, runConfig *runner.RunConfig) error { // If there are run secrets, validate them hasRegularSecrets := len(runConfig.Secrets) > 0 @@ -361,7 +437,8 @@ func (d *defaultManager) validateSecretParameters(ctx context.Context, runConfig return nil } -func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error { +// RunWorkloadDetached runs a workload in the background. +func (d *DefaultManager) RunWorkloadDetached(ctx context.Context, runConfig *runner.RunConfig) error { // before running, validate the parameters for the workload err := d.validateSecretParameters(ctx, runConfig) if err != nil { @@ -455,7 +532,8 @@ func (d *defaultManager) RunWorkloadDetached(ctx context.Context, runConfig *run return nil } -func (d *defaultManager) GetLogs(ctx context.Context, workloadName string, follow bool) (string, error) { +// GetLogs retrieves the logs of a container. 
+func (d *DefaultManager) GetLogs(ctx context.Context, workloadName string, follow bool) (string, error) { // Get the logs from the runtime logs, err := d.runtime.GetWorkloadLogs(ctx, workloadName, follow) if err != nil { @@ -470,7 +548,7 @@ func (d *defaultManager) GetLogs(ctx context.Context, workloadName string, follo } // GetProxyLogs retrieves proxy logs from the filesystem -func (*defaultManager) GetProxyLogs(_ context.Context, workloadName string) (string, error) { +func (*DefaultManager) GetProxyLogs(_ context.Context, workloadName string) (string, error) { // Get the proxy log file path logFilePath, err := xdg.DataFile(fmt.Sprintf("toolhive/logs/%s.log", workloadName)) if err != nil { @@ -495,7 +573,7 @@ func (*defaultManager) GetProxyLogs(_ context.Context, workloadName string) (str } // deleteWorkload handles deletion of a single workload -func (d *defaultManager) deleteWorkload(name string) error { +func (d *DefaultManager) deleteWorkload(name string) error { // Create a child context with a longer timeout childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -518,7 +596,7 @@ func (d *defaultManager) deleteWorkload(name string) error { } // deleteRemoteWorkload handles deletion of a remote workload -func (d *defaultManager) deleteRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { +func (d *DefaultManager) deleteRemoteWorkload(ctx context.Context, name string, runConfig *runner.RunConfig) error { logger.Infof("Removing remote workload %s...", name) // Set status to removing @@ -545,7 +623,7 @@ func (d *defaultManager) deleteRemoteWorkload(ctx context.Context, name string, } // deleteContainerWorkload handles deletion of a container-based workload (existing logic) -func (d *defaultManager) deleteContainerWorkload(ctx context.Context, name string) error { +func (d *DefaultManager) deleteContainerWorkload(ctx context.Context, name string) error { // Find and validate the 
container container, err := d.getWorkloadContainer(ctx, name) @@ -590,7 +668,7 @@ func (d *defaultManager) deleteContainerWorkload(ctx context.Context, name strin } // getWorkloadContainer retrieves workload container info with error handling -func (d *defaultManager) getWorkloadContainer(ctx context.Context, name string) (*rt.ContainerInfo, error) { +func (d *DefaultManager) getWorkloadContainer(ctx context.Context, name string) (*rt.ContainerInfo, error) { container, err := d.runtime.GetWorkloadInfo(ctx, name) if err != nil { if errors.Is(err, rt.ErrWorkloadNotFound) { @@ -612,7 +690,7 @@ func (d *defaultManager) getWorkloadContainer(ctx context.Context, name string) // - If the supervisor exits cleanly, it cleans up the PID // - If killed unexpectedly, the PID remains but stopProcess will handle it gracefully // - The main issue we're preventing is accumulating zombie supervisors from repeated restarts -func (d *defaultManager) isSupervisorProcessAlive(ctx context.Context, name string) bool { +func (d *DefaultManager) isSupervisorProcessAlive(ctx context.Context, name string) bool { if name == "" { return false } @@ -629,7 +707,7 @@ func (d *defaultManager) isSupervisorProcessAlive(ctx context.Context, name stri } // stopProcess stops the proxy process associated with the container -func (d *defaultManager) stopProcess(ctx context.Context, name string) { +func (d *DefaultManager) stopProcess(ctx context.Context, name string) { if name == "" { logger.Warnf("Warning: Could not find base container name in labels") return @@ -657,7 +735,7 @@ func (d *defaultManager) stopProcess(ctx context.Context, name string) { } // stopProxyIfNeeded stops the proxy process if the workload has a base name -func (d *defaultManager) stopProxyIfNeeded(ctx context.Context, name, baseName string) { +func (d *DefaultManager) stopProxyIfNeeded(ctx context.Context, name, baseName string) { logger.Infof("Removing proxy process for %s...", name) if baseName != "" { d.stopProcess(ctx, 
baseName) @@ -665,7 +743,7 @@ func (d *defaultManager) stopProxyIfNeeded(ctx context.Context, name, baseName s } // removeContainer removes the container from the runtime -func (d *defaultManager) removeContainer(ctx context.Context, name string) error { +func (d *DefaultManager) removeContainer(ctx context.Context, name string) error { logger.Infof("Removing container %s...", name) if err := d.runtime.RemoveWorkload(ctx, name); err != nil { if statusErr := d.statuses.SetWorkloadStatus(ctx, name, rt.WorkloadStatusError, err.Error()); statusErr != nil { @@ -677,7 +755,7 @@ func (d *defaultManager) removeContainer(ctx context.Context, name string) error } // cleanupWorkloadResources cleans up all resources associated with a workload -func (d *defaultManager) cleanupWorkloadResources(ctx context.Context, name, baseName string, isAuxiliary bool) { +func (d *DefaultManager) cleanupWorkloadResources(ctx context.Context, name, baseName string, isAuxiliary bool) { if baseName == "" { return } @@ -708,7 +786,8 @@ func (d *defaultManager) cleanupWorkloadResources(ctx context.Context, name, bas logger.Infof("Container %s removed", name) } -func (d *defaultManager) DeleteWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { +// DeleteWorkloads deletes the specified workloads by name. +func (d *DefaultManager) DeleteWorkloads(_ context.Context, names []string) (*errgroup.Group, error) { // Validate all workload names to prevent path traversal attacks for _, name := range names { if err := types.ValidateWorkloadName(name); err != nil { @@ -728,7 +807,7 @@ func (d *defaultManager) DeleteWorkloads(_ context.Context, names []string) (*er } // RestartWorkloads restarts the specified workloads by name. 
-func (d *defaultManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) { +func (d *DefaultManager) RestartWorkloads(_ context.Context, names []string, foreground bool) (*errgroup.Group, error) { // Validate all workload names to prevent path traversal attacks for _, name := range names { if err := types.ValidateWorkloadName(name); err != nil { @@ -748,7 +827,7 @@ func (d *defaultManager) RestartWorkloads(_ context.Context, names []string, for } // UpdateWorkload updates a workload by stopping, deleting, and recreating it -func (d *defaultManager) UpdateWorkload(_ context.Context, workloadName string, newConfig *runner.RunConfig) (*errgroup.Group, error) { //nolint:lll +func (d *DefaultManager) UpdateWorkload(_ context.Context, workloadName string, newConfig *runner.RunConfig) (*errgroup.Group, error) { //nolint:lll // Validate workload name if err := types.ValidateWorkloadName(workloadName); err != nil { return nil, fmt.Errorf("invalid workload name '%s': %w", workloadName, err) @@ -762,7 +841,7 @@ func (d *defaultManager) UpdateWorkload(_ context.Context, workloadName string, } // updateSingleWorkload handles the update logic for a single workload -func (d *defaultManager) updateSingleWorkload(workloadName string, newConfig *runner.RunConfig) error { +func (d *DefaultManager) updateSingleWorkload(workloadName string, newConfig *runner.RunConfig) error { // Create a child context with a longer timeout childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -799,7 +878,7 @@ func (d *defaultManager) updateSingleWorkload(workloadName string, newConfig *ru } // restartSingleWorkload handles the restart logic for a single workload -func (d *defaultManager) restartSingleWorkload(name string, foreground bool) error { +func (d *DefaultManager) restartSingleWorkload(name string, foreground bool) error { // Create a child context with a longer timeout childCtx, cancel := 
context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -822,7 +901,7 @@ func (d *defaultManager) restartSingleWorkload(name string, foreground bool) err } // restartRemoteWorkload handles restarting a remote workload -func (d *defaultManager) restartRemoteWorkload( +func (d *DefaultManager) restartRemoteWorkload( ctx context.Context, name string, runConfig *runner.RunConfig, @@ -889,7 +968,7 @@ func (d *defaultManager) restartRemoteWorkload( // restartContainerWorkload handles restarting a container-based workload // //nolint:gocyclo // Complexity is justified - handles multiple restart scenarios and edge cases -func (d *defaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error { +func (d *DefaultManager) restartContainerWorkload(ctx context.Context, name string, foreground bool) error { // Get container info to resolve partial names and extract proper workload name var containerName string var workloadName string @@ -999,7 +1078,7 @@ func (d *defaultManager) restartContainerWorkload(ctx context.Context, name stri } // startWorkload starts the workload in either foreground or background mode -func (d *defaultManager) startWorkload(ctx context.Context, name string, mcpRunner *runner.Runner, foreground bool) error { +func (d *DefaultManager) startWorkload(ctx context.Context, name string, mcpRunner *runner.Runner, foreground bool) error { logger.Infof("Starting tooling server %s...", name) var err error @@ -1044,7 +1123,7 @@ func removeClientConfigurations(containerName string, isAuxiliary bool) error { } // loadRunnerFromState attempts to load a Runner from the state store -func (d *defaultManager) loadRunnerFromState(ctx context.Context, baseName string) (*runner.Runner, error) { +func (d *DefaultManager) loadRunnerFromState(ctx context.Context, baseName string) (*runner.Runner, error) { // Load the run config from the state store runConfig, err := runner.LoadState(ctx, baseName) if err != nil { @@ 
-1063,7 +1142,7 @@ func (d *defaultManager) loadRunnerFromState(ctx context.Context, baseName strin return runner.NewRunner(runConfig, d.statuses), nil } -func (d *defaultManager) needSecretsPassword(secretOptions []string) bool { +func (d *DefaultManager) needSecretsPassword(secretOptions []string) bool { // If the user did not ask for any secrets, then don't attempt to instantiate // the secrets manager. if len(secretOptions) == 0 { @@ -1075,7 +1154,7 @@ func (d *defaultManager) needSecretsPassword(secretOptions []string) bool { } // cleanupTempPermissionProfile cleans up temporary permission profile files for a given base name -func (*defaultManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error { +func (*DefaultManager) cleanupTempPermissionProfile(ctx context.Context, baseName string) error { // Try to load the saved configuration to get the permission profile path runConfig, err := runner.LoadState(ctx, baseName) if err != nil { @@ -1095,7 +1174,7 @@ func (*defaultManager) cleanupTempPermissionProfile(ctx context.Context, baseNam } // stopSingleContainerWorkload stops a single container workload -func (d *defaultManager) stopSingleContainerWorkload(ctx context.Context, workload *rt.ContainerInfo) error { +func (d *DefaultManager) stopSingleContainerWorkload(ctx context.Context, workload *rt.ContainerInfo) error { childCtx, cancel := context.WithTimeout(context.Background(), AsyncOperationTimeout) defer cancel() @@ -1136,7 +1215,7 @@ func (d *defaultManager) stopSingleContainerWorkload(ctx context.Context, worklo } // MoveToGroup moves the specified workloads from one group to another by updating their runconfig. 
-func (*defaultManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { +func (*DefaultManager) MoveToGroup(ctx context.Context, workloadNames []string, groupFrom string, groupTo string) error { for _, workloadName := range workloadNames { // Validate workload name if err := types.ValidateWorkloadName(workloadName); err != nil { @@ -1171,7 +1250,7 @@ func (*defaultManager) MoveToGroup(ctx context.Context, workloadNames []string, } // ListWorkloadsInGroup returns all workload names that belong to the specified group -func (d *defaultManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { +func (d *DefaultManager) ListWorkloadsInGroup(ctx context.Context, groupName string) ([]string, error) { workloads, err := d.ListWorkloads(ctx, true) // listAll=true to include stopped workloads if err != nil { return nil, fmt.Errorf("failed to list workloads: %w", err) @@ -1189,7 +1268,7 @@ func (d *defaultManager) ListWorkloadsInGroup(ctx context.Context, groupName str } // getRemoteWorkloadsFromState retrieves remote servers from the state store -func (d *defaultManager) getRemoteWorkloadsFromState( +func (d *DefaultManager) getRemoteWorkloadsFromState( ctx context.Context, listAll bool, labelFilters []string, diff --git a/pkg/workloads/manager_test.go b/pkg/workloads/manager_test.go index ea971127e..a49603de5 100644 --- a/pkg/workloads/manager_test.go +++ b/pkg/workloads/manager_test.go @@ -160,7 +160,7 @@ func TestDefaultManager_ListWorkloadsInGroup(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupStatusMgr(mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ runtime: nil, // Not needed for this test statuses: mockStatusMgr, } @@ -196,7 +196,7 @@ func TestNewManagerFromRuntime(t *testing.T) { require.NotNil(t, manager) // Verify it's a defaultManager with the runtime set - defaultMgr, ok := manager.(*defaultManager) + defaultMgr, ok := 
manager.(*DefaultManager) require.True(t, ok) assert.Equal(t, mockRuntime, defaultMgr.runtime) assert.NotNil(t, defaultMgr.statuses) @@ -217,7 +217,7 @@ func TestNewManagerFromRuntimeWithProvider(t *testing.T) { require.NoError(t, err) require.NotNil(t, manager) - defaultMgr, ok := manager.(*defaultManager) + defaultMgr, ok := manager.(*DefaultManager) require.True(t, ok) assert.Equal(t, mockRuntime, defaultMgr.runtime) assert.Equal(t, mockConfigProvider, defaultMgr.configProvider) @@ -288,7 +288,7 @@ func TestDefaultManager_DoesWorkloadExist(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, } @@ -320,7 +320,7 @@ func TestDefaultManager_GetWorkload(t *testing.T) { mockStatusMgr.EXPECT().GetWorkload(gomock.Any(), "test-workload").Return(expectedWorkload, nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, } @@ -387,7 +387,7 @@ func TestDefaultManager_GetLogs(t *testing.T) { mockRuntime := runtimeMocks.NewMockRuntime(ctrl) tt.setupMocks(mockRuntime) - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, } @@ -437,7 +437,7 @@ func TestDefaultManager_StopWorkloads(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - manager := &defaultManager{} + manager := &DefaultManager{} ctx := context.Background() group, err := manager.StopWorkloads(ctx, tt.workloadNames) @@ -487,7 +487,7 @@ func TestDefaultManager_DeleteWorkloads(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - manager := &defaultManager{} + manager := &DefaultManager{} ctx := context.Background() group, err := manager.DeleteWorkloads(ctx, tt.workloadNames) @@ -534,7 +534,7 @@ func TestDefaultManager_RestartWorkloads(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - manager := &defaultManager{} + manager := &DefaultManager{} ctx := context.Background() group, err := 
manager.RestartWorkloads(ctx, tt.workloadNames, tt.foreground) @@ -635,7 +635,7 @@ func TestDefaultManager_restartRemoteWorkload(t *testing.T) { statusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(statusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, } @@ -750,7 +750,7 @@ func TestDefaultManager_restartContainerWorkload(t *testing.T) { tt.setupMocks(statusMgr, runtimeMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, runtime: runtimeMgr, } @@ -788,7 +788,7 @@ func TestDefaultManager_restartLogicConsistency(t *testing.T) { // Check if supervisor is alive - return valid PID (healthy) statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-base").Return(12345, nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, } @@ -826,7 +826,7 @@ func TestDefaultManager_restartLogicConsistency(t *testing.T) { statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, } @@ -866,7 +866,7 @@ func TestDefaultManager_restartLogicConsistency(t *testing.T) { // Check if supervisor is alive - return valid PID (healthy) statusMgr.EXPECT().GetWorkloadPID(gomock.Any(), "test-workload").Return(12345, nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, runtime: runtimeMgr, } @@ -911,7 +911,7 @@ func TestDefaultManager_restartLogicConsistency(t *testing.T) { statusMgr.EXPECT().SetWorkloadStatus(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) statusMgr.EXPECT().SetWorkloadPID(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: statusMgr, runtime: runtimeMgr, } @@ -968,7 +968,7 @@ func 
TestDefaultManager_RunWorkload(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, } @@ -1029,7 +1029,7 @@ func TestDefaultManager_validateSecretParameters(t *testing.T) { mockConfigProvider := configMocks.NewMockProvider(ctrl) tt.setupMocks(mockConfigProvider) - manager := &defaultManager{ + manager := &DefaultManager{ configProvider: mockConfigProvider, } @@ -1106,7 +1106,7 @@ func TestDefaultManager_getWorkloadContainer(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockRuntime, mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, statuses: mockStatusMgr, } @@ -1170,7 +1170,7 @@ func TestDefaultManager_removeContainer(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockRuntime, mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, statuses: mockStatusMgr, } @@ -1224,7 +1224,7 @@ func TestDefaultManager_needSecretsPassword(t *testing.T) { mockConfigProvider := configMocks.NewMockProvider(ctrl) tt.setupMocks(mockConfigProvider) - manager := &defaultManager{ + manager := &DefaultManager{ configProvider: mockConfigProvider, } @@ -1272,7 +1272,7 @@ func TestDefaultManager_RunWorkloadDetached(t *testing.T) { mockConfigProvider := configMocks.NewMockProvider(ctrl) tt.setupMocks(mockStatusMgr, mockConfigProvider) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, configProvider: mockConfigProvider, } @@ -1307,7 +1307,7 @@ func TestDefaultManager_RunWorkloadDetached_PIDManagement(t *testing.T) { // we verify the PID management integration exists by checking the method signature // and code structure rather than running the full integration. 
- manager := &defaultManager{} + manager := &DefaultManager{} assert.NotNil(t, manager, "defaultManager should be instantiable") // Verify the method exists with the correct signature @@ -1386,7 +1386,7 @@ func TestDefaultManager_ListWorkloads(t *testing.T) { mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) tt.setupMocks(mockStatusMgr) - manager := &defaultManager{ + manager := &DefaultManager{ statuses: mockStatusMgr, } @@ -1488,7 +1488,7 @@ func TestDefaultManager_UpdateWorkload(t *testing.T) { tt.setupMocks(mockRuntime, mockStatusManager) } - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, statuses: mockStatusManager, configProvider: mockConfigProvider, @@ -1639,7 +1639,7 @@ func TestDefaultManager_updateSingleWorkload(t *testing.T) { tt.setupMocks(mockRuntime, mockStatusManager) } - manager := &defaultManager{ + manager := &DefaultManager{ runtime: mockRuntime, statuses: mockStatusManager, configProvider: mockConfigProvider,