diff --git a/pkg/agent/clientset.go b/pkg/agent/clientset.go index db44b5e4a..d961ae5e1 100644 --- a/pkg/agent/clientset.go +++ b/pkg/agent/clientset.go @@ -255,6 +255,14 @@ func (cs *ClientSet) sync() { } func (cs *ClientSet) connectOnce() error { + // Skip establishing new connections while draining: a successful receive on cs.drainCh is treated as the draining signal + select { + case <-cs.drainCh: + klog.V(2).InfoS("Skipping connectOnce - agent is draining") + return nil + default: + } + serverCount := cs.determineServerCount() // If not in syncForever mode, we only connect if we have fewer connections than the server count. diff --git a/pkg/server/backend_manager.go b/pkg/server/backend_manager.go index 97b5d8c7d..0ab55c285 100644 --- a/pkg/server/backend_manager.go +++ b/pkg/server/backend_manager.go @@ -49,6 +49,25 @@ type Backend struct { // cached from conn.Context() id string idents header.Identifiers + + // draining is set once the agent behind this backend sends a DRAIN packet; a draining backend is picked for new requests only as a fallback when every candidate backend is draining + draining bool + // mu guards the draining field + mu sync.RWMutex +} + +// IsDraining reports whether SetDraining has been called on this backend; the flag is read under mu's read lock +func (b *Backend) IsDraining() bool { + b.mu.RLock() + defer b.mu.RUnlock() + return b.draining +} + +// SetDraining marks the backend as draining under mu's write lock; nothing in this change resets the flag +func (b *Backend) SetDraining() { + b.mu.Lock() + defer b.mu.Unlock() + b.draining = true } func (b *Backend) Send(p *client.Packet) error { @@ -346,9 +365,36 @@ func (s *DefaultBackendStorage) GetRandomBackend() (*Backend, error) { if len(s.backends) == 0 { return nil, &ErrNotFound{} } - agentID := s.agentIDs[s.random.Intn(len(s.agentIDs))] - klog.V(3).InfoS("Pick agent as backend", "agentID", agentID) - // always return the first connection to an agent, because the agent - // will close later connections if there are multiple. 
- return s.backends[agentID][0], nil + + var firstDrainingBackend *Backend + + // Start at a random index into s.agentIDs and probe every agent once, in order, so the pick stays randomized while draining agents are skipped + startIdx := s.random.Intn(len(s.agentIDs)) + for i := 0; i < len(s.agentIDs); i++ { + // Wrap past the end of s.agentIDs back to index 0 + currentIdx := (startIdx + i) % len(s.agentIDs) + agentID := s.agentIDs[currentIdx] + // always return the first connection to an agent, because the agent + // will close later connections if there are multiple. + backend := s.backends[agentID][0] + + if !backend.IsDraining() { + klog.V(3).InfoS("Pick agent as backend", "agentID", agentID) + return backend, nil + } + + // Remember the first draining backend seen; it is returned only if no non-draining backend is found + if firstDrainingBackend == nil { + firstDrainingBackend = backend + } + } + + // Every agent is draining: return a draining backend rather than ErrNotFound + if firstDrainingBackend != nil { + agentID := firstDrainingBackend.id + klog.V(2).InfoS("No non-draining backends available, using draining backend as fallback", "agentID", agentID) + return firstDrainingBackend, nil + } + + return nil, &ErrNotFound{} } diff --git a/pkg/server/desthost_backend_manager.go b/pkg/server/desthost_backend_manager.go index 280065775..d2a3e0f14 100644 --- a/pkg/server/desthost_backend_manager.go +++ b/pkg/server/desthost_backend_manager.go @@ -79,8 +79,25 @@ func (dibm *DestHostBackendManager) Backend(ctx context.Context) (*Backend, erro if destHost != "" { bes, exist := dibm.backends[destHost] if exist && len(bes) > 0 { - klog.V(5).InfoS("Get the backend through the DestHostBackendManager", "destHost", destHost) - return dibm.backends[destHost][0], nil + var firstDrainingBackend *Backend + + // Prefer the first backend registered for destHost that is not draining + for _, backend := range bes { + if !backend.IsDraining() { + klog.V(5).InfoS("Get the backend through the DestHostBackendManager", "destHost", destHost) + return backend, nil + } + // Remember the first draining backend seen; it is returned only if no non-draining backend is found + if firstDrainingBackend == nil { + firstDrainingBackend = 
backend + } + } + + // Every backend for destHost is draining: return the first one rather than ErrNotFound + if firstDrainingBackend != nil { + klog.V(4).InfoS("All backends for destination host are draining, using one as fallback", "destHost", destHost) + return firstDrainingBackend, nil + } } } return nil, &ErrNotFound{} diff --git a/pkg/server/server.go b/pkg/server/server.go index ffec433c3..b3ef519df 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -999,6 +999,8 @@ func (s *ProxyServer) serveRecvBackend(backend *Backend, agentID string, recvCh case client.PacketType_DRAIN: klog.V(2).InfoS("agent is draining", "agentID", agentID) + backend.SetDraining() + klog.V(2).InfoS("marked backend as draining, will not route new requests to this agent", "agentID", agentID) default: klog.V(5).InfoS("Ignoring unrecognized packet from backend", "packet", pkt, "agentID", agentID) }