diff --git a/apis/aga/v1beta1/globalaccelerator_types.go b/apis/aga/v1beta1/globalaccelerator_types.go index 55bb619a5d..aa7e98209f 100644 --- a/apis/aga/v1beta1/globalaccelerator_types.go +++ b/apis/aga/v1beta1/globalaccelerator_types.go @@ -48,6 +48,7 @@ const ( ) // PortRange defines the port range for Global Accelerator listeners. +// +kubebuilder:validation:XValidation:rule="self.fromPort <= self.toPort",message="FromPort must be less than or equal to ToPort" type PortRange struct { // FromPort is the first port in the range of ports, inclusive. // +kubebuilder:validation:Minimum=1 diff --git a/config/crd/aga/aga-crds.yaml b/config/crd/aga/aga-crds.yaml index 04076af7d2..032fe9a2a8 100644 --- a/config/crd/aga/aga-crds.yaml +++ b/config/crd/aga/aga-crds.yaml @@ -264,6 +264,9 @@ spec: - fromPort - toPort type: object + x-kubernetes-validations: + - message: FromPort must be less than or equal to ToPort + rule: self.fromPort <= self.toPort maxItems: 10 minItems: 1 type: array diff --git a/config/crd/aga/aga.k8s.aws_globalaccelerators.yaml b/config/crd/aga/aga.k8s.aws_globalaccelerators.yaml index 04076af7d2..032fe9a2a8 100644 --- a/config/crd/aga/aga.k8s.aws_globalaccelerators.yaml +++ b/config/crd/aga/aga.k8s.aws_globalaccelerators.yaml @@ -264,6 +264,9 @@ spec: - fromPort - toPort type: object + x-kubernetes-validations: + - message: FromPort must be less than or equal to ToPort + rule: self.fromPort <= self.toPort maxItems: 10 minItems: 1 type: array diff --git a/config/webhook/globalaccelerator_validator_patch.yaml b/config/webhook/globalaccelerator_validator_patch.yaml new file mode 100644 index 0000000000..e6313245d9 --- /dev/null +++ b/config/webhook/globalaccelerator_validator_patch.yaml @@ -0,0 +1,18 @@ +# This patch adds the GlobalAccelerator validator webhook configuration to the webhook configurations +apiVersion: admissionregistration.k8s.io/v1 +kind: ValidatingWebhookConfiguration +metadata: + name: webhook-configuration +webhooks: + - name: vglobalaccelerator.aga.k8s.aws + rules: + - apiGroups: + - "aga.k8s.aws" + apiVersions: + - v1beta1 + operations: + - CREATE + - UPDATE + resources: + - globalaccelerators + scope: "Namespaced" diff --git a/config/webhook/kustomization.yaml b/config/webhook/kustomization.yaml index 20d98aca4c..7147059ebd 100644 --- a/config/webhook/kustomization.yaml +++ b/config/webhook/kustomization.yaml @@ -9,3 +9,4 @@ patchesStrategicMerge: - pod_mutator_patch.yaml - service_mutator_patch.yaml - ingressclassparams_validator_patch.yaml + - globalaccelerator_validator_patch.yaml diff --git a/config/webhook/manifests.yaml b/config/webhook/manifests.yaml index 00793b4707..7deb75f1f8 100644 --- a/config/webhook/manifests.yaml +++ b/config/webhook/manifests.yaml @@ -68,6 +68,27 @@ kind: ValidatingWebhookConfiguration metadata: name: webhook webhooks: + - admissionReviewVersions: + - v1beta1 + clientConfig: + service: + name: webhook-service + namespace: system + path: /validate-aga-k8s-aws-v1beta1-globalaccelerator + failurePolicy: Fail + matchPolicy: Equivalent + name: vglobalaccelerator.aga.k8s.aws + rules: + - apiGroups: + - aga.k8s.aws + apiVersions: + - v1beta1 + operations: + - CREATE + - UPDATE + resources: + - globalaccelerators + sideEffects: None - admissionReviewVersions: - v1beta1 clientConfig: diff --git a/controllers/aga/globalaccelerator_controller.go b/controllers/aga/globalaccelerator_controller.go index 426277f991..354b9b2bf3 100644 --- a/controllers/aga/globalaccelerator_controller.go +++ b/controllers/aga/globalaccelerator_controller.go @@ -275,7 +275,9 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co func (r *globalAcceleratorReconciler) cleanupGlobalAcceleratorResources(ctx context.Context, ga *agaapi.GlobalAccelerator) error { r.logger.Info("Cleaning up GlobalAccelerator resources", "globalAccelerator", k8s.NamespacedName(ga)) - // TODO we will handle cleaning up dependent resources when we implement those + // Our enhanced AcceleratorManager now handles deletion of listeners before accelerator. + // TODO: This will be enhanced to delete endpoint groups and endpoints + // before deleting listeners and accelerator (when those features are implemented) // 1. Find the accelerator ARN from the CRD status if ga.Status.AcceleratorARN == nil { r.logger.Info("No accelerator ARN found in status, nothing to clean up", "globalAccelerator", k8s.NamespacedName(ga)) diff --git a/helm/aws-load-balancer-controller/crds/aga-crds.yaml b/helm/aws-load-balancer-controller/crds/aga-crds.yaml index 04076af7d2..032fe9a2a8 100644 --- a/helm/aws-load-balancer-controller/crds/aga-crds.yaml +++ b/helm/aws-load-balancer-controller/crds/aga-crds.yaml @@ -264,6 +264,9 @@ spec: - fromPort - toPort type: object + x-kubernetes-validations: + - message: FromPort must be less than or equal to ToPort + rule: self.fromPort <= self.toPort maxItems: 10 minItems: 1 type: array diff --git a/main.go b/main.go index 17aa949421..7aae17fd9c 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,7 @@ import ( "fmt" "k8s.io/apimachinery/pkg/util/sets" "os" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/shared_utils" elbv2gw "sigs.k8s.io/aws-load-balancer-controller/apis/gateway/v1beta1" @@ -65,6 +66,7 @@ import ( "sigs.k8s.io/aws-load-balancer-controller/pkg/runtime" "sigs.k8s.io/aws-load-balancer-controller/pkg/targetgroupbinding" "sigs.k8s.io/aws-load-balancer-controller/pkg/version" + agawebhook "sigs.k8s.io/aws-load-balancer-controller/webhooks/aga" corewebhook "sigs.k8s.io/aws-load-balancer-controller/webhooks/core" elbv2webhook "sigs.k8s.io/aws-load-balancer-controller/webhooks/elbv2" networkingwebhook "sigs.k8s.io/aws-load-balancer-controller/webhooks/networking" @@ -236,7 +238,7 @@ func main() { } // Setup GlobalAccelerator controller only if enabled - if shared_utils.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) { + if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) { agaReconciler := agacontroller.NewGlobalAcceleratorReconciler(mgr.GetClient(), mgr.GetEventRecorderFor("globalAccelerator"), finalizerManager, controllerCFG, cloud, ctrl.Log.WithName("controllers").WithName("globalAccelerator"), lbcMetricsCollector, reconcileCounters) if err := agaReconciler.SetupWithManager(ctx, mgr, clientSet); err != nil { @@ -415,6 +417,11 @@ func main() { elbv2webhook.NewTargetGroupBindingMutator(cloud.ELBV2(), ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) elbv2webhook.NewTargetGroupBindingValidator(mgr.GetClient(), cloud.ELBV2(), cloud.VpcID(), ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) networkingwebhook.NewIngressValidator(mgr.GetClient(), controllerCFG.IngressConfig, ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) + + // Setup GlobalAccelerator validator only if enabled + if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) { + agawebhook.NewGlobalAcceleratorValidator(ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) + } //+kubebuilder:scaffold:builder go func() { diff --git a/pkg/aga/model_build_listener.go b/pkg/aga/model_build_listener.go new file mode 100644 index 0000000000..551e0f1ab2 --- /dev/null +++ b/pkg/aga/model_build_listener.go @@ -0,0 +1,128 @@ +package aga + +import ( + "context" + "fmt" + "github.com/pkg/errors" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +// listenerBuilder builds Listener model resources +type listenerBuilder interface { + Build(ctx context.Context, stack core.Stack, accelerator *agamodel.Accelerator, listeners []agaapi.GlobalAcceleratorListener) ([]*agamodel.Listener, error) +} + +// NewListenerBuilder constructs new listenerBuilder +func NewListenerBuilder() listenerBuilder { + return &defaultListenerBuilder{} +} + +var _ listenerBuilder = &defaultListenerBuilder{} + +type defaultListenerBuilder struct{} + +// Build builds Listener model resources +func (b *defaultListenerBuilder) Build(ctx context.Context, stack core.Stack, accelerator *agamodel.Accelerator, listeners []agaapi.GlobalAcceleratorListener) ([]*agamodel.Listener, error) { + if listeners == nil || len(listeners) == 0 { + return nil, nil + } + + var result []*agamodel.Listener + for i, listener := range listeners { + listenerModel, err := buildListener(ctx, stack, accelerator, listener, i) + if err != nil { + return nil, err + } + result = append(result, listenerModel) + } + return result, nil +} + +// buildListener builds a single Listener model resource +func buildListener(ctx context.Context, stack core.Stack, accelerator *agamodel.Accelerator, listener agaapi.GlobalAcceleratorListener, index int) (*agamodel.Listener, error) { + spec, err := buildListenerSpec(ctx, accelerator, listener) + if err != nil { + return nil, err + } + + resourceID := fmt.Sprintf("Listener-%d", index) + listenerModel := agamodel.NewListener(stack, resourceID, spec, accelerator) + return listenerModel, nil +} + +// buildListenerSpec builds the ListenerSpec for a single Listener model resource +func buildListenerSpec(ctx context.Context, accelerator *agamodel.Accelerator, listener agaapi.GlobalAcceleratorListener) (agamodel.ListenerSpec, error) { + protocol, err := buildListenerProtocol(ctx, listener) + if err != nil { + return agamodel.ListenerSpec{}, err + } + + portRanges, err := buildListenerPortRanges(ctx, listener) + if err != nil { + return agamodel.ListenerSpec{}, err + } + + clientAffinity := buildListenerClientAffinity(ctx, listener) + + return agamodel.ListenerSpec{ + AcceleratorARN: accelerator.AcceleratorARN(), + Protocol: protocol, + PortRanges: portRanges, + ClientAffinity: clientAffinity, + }, nil +} + +// buildListenerProtocol determines the protocol for the listener +func buildListenerProtocol(_ context.Context, listener agaapi.GlobalAcceleratorListener) (agamodel.Protocol, error) { + if listener.Protocol == nil { + // TODO: Auto-discovery feature - Auto-determine protocol from endpoints if nil + // For now, default to TCP + return agamodel.ProtocolTCP, nil + } + + switch *listener.Protocol { + case agaapi.GlobalAcceleratorProtocolTCP: + return agamodel.ProtocolTCP, nil + case agaapi.GlobalAcceleratorProtocolUDP: + return agamodel.ProtocolUDP, nil + default: + return "", errors.Errorf("unsupported protocol: %s", *listener.Protocol) + } +} + +// buildListenerPortRanges determines the port ranges for the listener +func buildListenerPortRanges(_ context.Context, listener agaapi.GlobalAcceleratorListener) ([]agamodel.PortRange, error) { + if listener.PortRanges == nil { + // TODO: Auto-discovery feature - Auto-determine port ranges from endpoints if nil + // For now, default to port 80 + return []agamodel.PortRange{{ + FromPort: 80, + ToPort: 80, + }}, nil + } + + var portRanges []agamodel.PortRange + for _, pr := range *listener.PortRanges { + // Required validations are already done webhooks and CEL + portRanges = append(portRanges, agamodel.PortRange{ + FromPort: pr.FromPort, + ToPort: pr.ToPort, + }) + } + return portRanges, nil +} + +// buildListenerClientAffinity determines the client affinity for the listener +func buildListenerClientAffinity(_ context.Context, listener agaapi.GlobalAcceleratorListener) agamodel.ClientAffinity { + switch listener.ClientAffinity { + case agaapi.ClientAffinitySourceIP: + return agamodel.ClientAffinitySourceIP + case agaapi.ClientAffinityNone: + return agamodel.ClientAffinityNone + default: + // Default to NONE as per AWS Global Accelerator behavior + return agamodel.ClientAffinityNone + } +} diff --git a/pkg/aga/model_build_listener_test.go b/pkg/aga/model_build_listener_test.go new file mode 100644 index 0000000000..e74ba00360 --- /dev/null +++ b/pkg/aga/model_build_listener_test.go @@ -0,0 +1,487 @@ +package aga + +import ( + "context" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "testing" + + "github.com/stretchr/testify/assert" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +func TestDefaultListenerBuilder_Build(t *testing.T) { + // Protocol references for direct pointer usage + protocolTCP := agaapi.GlobalAcceleratorProtocolTCP + protocolUDP := agaapi.GlobalAcceleratorProtocolUDP + + tests := []struct { + name string + listeners []agaapi.GlobalAcceleratorListener + wantListeners int + wantErr bool + }{ + { + name: "with nil listeners", + listeners: nil, + wantListeners: 0, + wantErr: false, + }, + { + name: "with empty listeners", + listeners: []agaapi.GlobalAcceleratorListener{}, + wantListeners: 0, + wantErr: false, + }, + { + name: "with single TCP listener", + listeners: []agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + wantListeners: 1, + wantErr: false, + }, + { + name: "with single UDP listener", + listeners: []agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolUDP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 53, + ToPort: 53, + }, + }, + ClientAffinity: agaapi.ClientAffinitySourceIP, + }, + }, + wantListeners: 1, + wantErr: false, + }, + { + name: "with multiple listeners", + listeners: []agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + { + Protocol: &protocolUDP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 53, + ToPort: 53, + }, + }, + ClientAffinity: agaapi.ClientAffinitySourceIP, + }, + }, + wantListeners: 2, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup test context + ctx := context.Background() + stack := core.NewDefaultStack(core.StackID{Namespace: "test-ns", Name: "test-name"}) + accelerator := createTestAccelerator(stack) + + // Create listener builder and build listeners + builder := NewListenerBuilder() + listeners, err := builder.Build(ctx, stack, accelerator, tt.listeners) + + // Check results + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.wantListeners == 0 { + assert.Nil(t, listeners) + } else { + assert.Equal(t, tt.wantListeners, len(listeners)) + } + } + }) + } +} + +func TestDefaultListenerBuilder_buildListenerSpec(t *testing.T) { + // Protocol references for direct pointer usage + protocolTCP := agaapi.GlobalAcceleratorProtocolTCP + protocolUDP := agaapi.GlobalAcceleratorProtocolUDP + + // Setup test context + ctx := context.Background() + stack := core.NewDefaultStack(core.StackID{Namespace: "test-ns", Name: "test-name"}) + accelerator := createTestAccelerator(stack) + + tests := []struct { + name string + listener agaapi.GlobalAcceleratorListener + wantProtocol agamodel.Protocol + wantAffinity agamodel.ClientAffinity + wantPorts []agamodel.PortRange + wantErr bool + }{ + { + name: "with TCP protocol", + listener: agaapi.GlobalAcceleratorListener{ + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + wantProtocol: agamodel.ProtocolTCP, + wantAffinity: agamodel.ClientAffinityNone, + wantPorts: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + wantErr: false, + }, + { + name: "with UDP protocol", + listener: agaapi.GlobalAcceleratorListener{ + Protocol: &protocolUDP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 53, + ToPort: 53, + }, + }, + ClientAffinity: agaapi.ClientAffinitySourceIP, + }, + wantProtocol: agamodel.ProtocolUDP, + wantAffinity: agamodel.ClientAffinitySourceIP, + wantPorts: []agamodel.PortRange{ + { + FromPort: 53, + ToPort: 53, + }, + }, + wantErr: false, + }, + { + name: "with nil protocol (should default to TCP)", + listener: agaapi.GlobalAcceleratorListener{ + Protocol: nil, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + wantProtocol: agamodel.ProtocolTCP, + wantAffinity: agamodel.ClientAffinityNone, + wantPorts: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + wantErr: false, + }, + { + name: "with nil port ranges (should default to port 80)", + listener: agaapi.GlobalAcceleratorListener{ + Protocol: &protocolTCP, + PortRanges: nil, + ClientAffinity: agaapi.ClientAffinityNone, + }, + wantProtocol: agamodel.ProtocolTCP, + wantAffinity: agamodel.ClientAffinityNone, + wantPorts: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // Build listener spec + spec, err := buildListenerSpec(ctx, accelerator, tt.listener) + + // Check results + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantProtocol, spec.Protocol) + assert.Equal(t, tt.wantAffinity, spec.ClientAffinity) + assert.Equal(t, tt.wantPorts, spec.PortRanges) + // AcceleratorARN is a token that will be resolved later, not a direct string + assert.NotNil(t, spec.AcceleratorARN) + } + }) + } +} + +// Helper function to create a test accelerator +func createTestAccelerator(stack core.Stack) *agamodel.Accelerator { + spec := agamodel.AcceleratorSpec{ + Name: "test-accelerator", + Enabled: awssdk.Bool(true), + Tags: map[string]string{"Key": "Value"}, + } + + accelerator := agamodel.NewAccelerator(stack, "test-accelerator", spec, &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga", + Namespace: "default", + }, + }) + + // Set the accelerator status to simulate it being fulfilled + accelerator.SetStatus(agamodel.AcceleratorStatus{ + AcceleratorARN: "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd-abcd-1234-abcd-1234abcdefgh", + DNSName: "a1234abcd5678efghi.awsglobalaccelerator.com", + Status: "DEPLOYED", + }) + + return accelerator +} + +func TestBuildListenerProtocol(t *testing.T) { + // Protocol references for direct pointer usage + protocolTCP := agaapi.GlobalAcceleratorProtocolTCP + protocolUDP := agaapi.GlobalAcceleratorProtocolUDP + invalidProtocol := agaapi.GlobalAcceleratorProtocol("INVALID") + + tests := []struct { + name string + listener agaapi.GlobalAcceleratorListener + wantProtocol agamodel.Protocol + wantErr bool + wantErrString string + }{ + { + name: "with nil protocol (should default to TCP)", + listener: agaapi.GlobalAcceleratorListener{ + Protocol: nil, + }, + wantProtocol: agamodel.ProtocolTCP, + wantErr: false, + }, + { + name: "with TCP protocol", + listener: agaapi.GlobalAcceleratorListener{ + Protocol: &protocolTCP, + }, + wantProtocol: agamodel.ProtocolTCP, + wantErr: false, + }, + { + name: "with UDP protocol", + listener: agaapi.GlobalAcceleratorListener{ + Protocol: &protocolUDP, + }, + wantProtocol: agamodel.ProtocolUDP, + wantErr: false, + }, + { + name: "with invalid protocol", + listener: agaapi.GlobalAcceleratorListener{ + Protocol: &invalidProtocol, + }, + wantProtocol: "", + wantErr: true, + wantErrString: "unsupported protocol: INVALID", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test context + ctx := context.Background() + + // Call function + protocol, err := buildListenerProtocol(ctx, tt.listener) + + // Check results + if tt.wantErr { + assert.Error(t, err) + if tt.wantErrString != "" { + assert.Contains(t, err.Error(), tt.wantErrString) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantProtocol, protocol) + } + }) + } +} + +func TestBuildListenerPortRanges(t *testing.T) { + tests := []struct { + name string + listener agaapi.GlobalAcceleratorListener + wantPorts []agamodel.PortRange + wantErr bool + }{ + { + name: "with nil port ranges (should default to port 80)", + listener: agaapi.GlobalAcceleratorListener{ + PortRanges: nil, + }, + wantPorts: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + wantErr: false, + }, + { + name: "with single port range", + listener: agaapi.GlobalAcceleratorListener{ + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 443, + ToPort: 443, + }, + }, + }, + wantPorts: []agamodel.PortRange{ + { + FromPort: 443, + ToPort: 443, + }, + }, + wantErr: false, + }, + { + name: "with multiple port ranges", + listener: agaapi.GlobalAcceleratorListener{ + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8090, + }, + }, + }, + wantPorts: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8090, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test context + ctx := context.Background() + + // Call function + portRanges, err := buildListenerPortRanges(ctx, tt.listener) + + // Check results + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantPorts, portRanges) + } + }) + } +} + +func TestBuildListenerClientAffinity(t *testing.T) { + tests := []struct { + name string + listener agaapi.GlobalAcceleratorListener + wantAffinity agamodel.ClientAffinity + }{ + { + name: "with NONE client affinity", + listener: agaapi.GlobalAcceleratorListener{ + ClientAffinity: agaapi.ClientAffinityNone, + }, + wantAffinity: agamodel.ClientAffinityNone, + }, + { + name: "with SOURCE_IP client affinity", + listener: agaapi.GlobalAcceleratorListener{ + ClientAffinity: agaapi.ClientAffinitySourceIP, + }, + wantAffinity: agamodel.ClientAffinitySourceIP, + }, + { + name: "with invalid client affinity (should default to NONE)", + listener: agaapi.GlobalAcceleratorListener{ + ClientAffinity: "INVALID", + }, + wantAffinity: agamodel.ClientAffinityNone, + }, + { + name: "with empty client affinity (should default to NONE)", + listener: agaapi.GlobalAcceleratorListener{ + ClientAffinity: "", + }, + wantAffinity: agamodel.ClientAffinityNone, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test context + ctx := context.Background() + + // Call function + clientAffinity := buildListenerClientAffinity(ctx, tt.listener) + + // Check results + assert.Equal(t, tt.wantAffinity, clientAffinity) + }) + } +} diff --git a/pkg/aga/model_builder.go b/pkg/aga/model_builder.go index 7b8333667a..d4938ab291 100644 --- a/pkg/aga/model_builder.go +++ b/pkg/aga/model_builder.go @@ -62,7 +62,6 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccele // Create fresh builder instances for each reconciliation acceleratorBuilder := NewAcceleratorBuilder(b.trackingProvider, b.clusterName, b.clusterRegion, b.defaultTags, b.externalManagedTags, b.featureGates.Enabled(config.EnableDefaultTagsLowPriority)) // TODO - // listenerBuilder := NewListenerBuilder() // endpointGroupBuilder := NewEndpointGroupBuilder() // endpointBuilder := NewEndpointBuilder() @@ -72,8 +71,19 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccele return nil, nil, err } + // Build Listeners if specified + var listeners []*agamodel.Listener + if ga.Spec.Listeners != nil { + // Create builder for listeners and endpoints + listenerBuilder := NewListenerBuilder() + listeners, err = listenerBuilder.Build(ctx, stack, accelerator, *ga.Spec.Listeners) + if err != nil { + return nil, nil, err + } + } + + b.logger.V(1).Info("Listeners built", "listeners", listeners) // TODO: Add other resource builders - // listeners, err := listenerBuilder.Build(ctx, stack, accelerator, ga.Spec.Listeners) // endpointGroups, err := endpointGroupBuilder.Build(ctx, stack, listeners, ga.Spec.Listeners) // endpoints, err := endpointBuilder.Build(ctx, stack, endpointGroups, ga.Spec.Listeners) diff --git a/pkg/shared_utils/aga_utils.go b/pkg/aga/utils.go similarity index 97% rename from pkg/shared_utils/aga_utils.go rename to pkg/aga/utils.go index 15675e65c7..1f067e25e6 100644 --- a/pkg/shared_utils/aga_utils.go +++ b/pkg/aga/utils.go @@ -1,4 +1,4 @@ -package shared_utils +package aga import ( "strings" diff --git a/pkg/aga/utils_test.go b/pkg/aga/utils_test.go new file mode 100644 index 0000000000..6bfa40c5a0 --- /dev/null +++ b/pkg/aga/utils_test.go @@ -0,0 +1,95 @@ +package aga + +import ( + "testing" + + "sigs.k8s.io/aws-load-balancer-controller/pkg/config" +) + +func TestIsAGAControllerEnabled(t *testing.T) { + tests := []struct { + name string + featureGates config.FeatureGates + region string + want bool + }{ + { + name: "Feature gate disabled", + featureGates: func() config.FeatureGates { + fg := config.NewFeatureGates() + fg.Disable(config.AGAController) + return fg + }(), + region: "us-west-2", + want: false, + }, + { + name: "Feature gate enabled, standard region", + featureGates: func() config.FeatureGates { + fg := config.NewFeatureGates() + fg.Enable(config.AGAController) + return fg + }(), + region: "us-west-2", + want: true, + }, + { + name: "Feature gate enabled, China region", + featureGates: func() config.FeatureGates { + fg := config.NewFeatureGates() + fg.Enable(config.AGAController) + return fg + }(), + region: "cn-north-1", + want: false, + }, + { + name: "Feature gate enabled, GovCloud region", + featureGates: func() config.FeatureGates { + fg := config.NewFeatureGates() + fg.Enable(config.AGAController) + return fg + }(), + region: "us-gov-west-1", + want: false, + }, + { + name: "Feature gate enabled, ISO region", + featureGates: func() config.FeatureGates { + fg := config.NewFeatureGates() + fg.Enable(config.AGAController) + return fg + }(), + region: "us-iso-east-1", + want: false, + }, + { + name: "Feature gate enabled, ISO-E region", + featureGates: func() config.FeatureGates { + fg := config.NewFeatureGates() + fg.Enable(config.AGAController) + return fg + }(), + region: "eu-isoe-west-1", + want: false, + }, + { + name: "Feature gate enabled, upper case region", + featureGates: func() config.FeatureGates { + fg := config.NewFeatureGates() + fg.Enable(config.AGAController) + return fg + }(), + region: "US-WEST-2", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsAGAControllerEnabled(tt.featureGates, tt.region); got != tt.want { + t.Errorf("IsAGAControllerEnabled() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/aws/services/globalaccelerator.go b/pkg/aws/services/globalaccelerator.go index 6d388ce098..364f12d703 100644 --- a/pkg/aws/services/globalaccelerator.go +++ b/pkg/aws/services/globalaccelerator.go @@ -23,6 +23,24 @@ type GlobalAccelerator interface { // DeleteAccelerator deletes an accelerator. DeleteAcceleratorWithContext(ctx context.Context, input *globalaccelerator.DeleteAcceleratorInput) (*globalaccelerator.DeleteAcceleratorOutput, error) + // CreateListener creates a new listener. + CreateListenerWithContext(ctx context.Context, input *globalaccelerator.CreateListenerInput) (*globalaccelerator.CreateListenerOutput, error) + + // DescribeListener describes a listener. + DescribeListenerWithContext(ctx context.Context, input *globalaccelerator.DescribeListenerInput) (*globalaccelerator.DescribeListenerOutput, error) + + // UpdateListener updates a listener. + UpdateListenerWithContext(ctx context.Context, input *globalaccelerator.UpdateListenerInput) (*globalaccelerator.UpdateListenerOutput, error) + + // DeleteListener deletes a listener. + DeleteListenerWithContext(ctx context.Context, input *globalaccelerator.DeleteListenerInput) (*globalaccelerator.DeleteListenerOutput, error) + + // wrapper to ListListeners API, which aggregates paged results into list. + ListListenersAsList(ctx context.Context, input *globalaccelerator.ListListenersInput) ([]types.Listener, error) + + // ListListenersForAccelerator lists all listeners for an accelerator. + ListListenersForAcceleratorWithContext(ctx context.Context, input *globalaccelerator.ListListenersInput) (*globalaccelerator.ListListenersOutput, error) + // TagResource tags a resource. TagResourceWithContext(ctx context.Context, input *globalaccelerator.TagResourceInput) (*globalaccelerator.TagResourceOutput, error) @@ -117,3 +135,60 @@ func (c *defaultGlobalAccelerator) ListTagsForResourceWithContext(ctx context.Co } return client.ListTagsForResource(ctx, input) } + +func (c *defaultGlobalAccelerator) CreateListenerWithContext(ctx context.Context, input *globalaccelerator.CreateListenerInput) (*globalaccelerator.CreateListenerOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "CreateListener") + if err != nil { + return nil, err + } + return client.CreateListener(ctx, input) +} + +func (c *defaultGlobalAccelerator) DescribeListenerWithContext(ctx context.Context, input *globalaccelerator.DescribeListenerInput) (*globalaccelerator.DescribeListenerOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "DescribeListener") + if err != nil { + return nil, err + } + return client.DescribeListener(ctx, input) +} + +func (c *defaultGlobalAccelerator) UpdateListenerWithContext(ctx context.Context, input *globalaccelerator.UpdateListenerInput) (*globalaccelerator.UpdateListenerOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "UpdateListener") + if err != nil { + return nil, err + } + return client.UpdateListener(ctx, input) +} + +func (c *defaultGlobalAccelerator) DeleteListenerWithContext(ctx context.Context, input *globalaccelerator.DeleteListenerInput) (*globalaccelerator.DeleteListenerOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "DeleteListener") + if err != nil { + return nil, err + } + return client.DeleteListener(ctx, input) +} + +func (c *defaultGlobalAccelerator) ListListenersForAcceleratorWithContext(ctx context.Context, input *globalaccelerator.ListListenersInput) (*globalaccelerator.ListListenersOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "ListListeners") + if err != nil { + return nil, err + } + return client.ListListeners(ctx, input) +} + +func (c *defaultGlobalAccelerator) ListListenersAsList(ctx context.Context, input *globalaccelerator.ListListenersInput) ([]types.Listener, error) { + var result []types.Listener + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "ListListeners") + if err != nil { + return nil, err + } + paginator := globalaccelerator.NewListListenersPaginator(client, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return nil, err + } + result = append(result, output.Listeners...) + } + return result, nil +} diff --git a/pkg/aws/services/globalaccelerator_mocks.go b/pkg/aws/services/globalaccelerator_mocks.go index 3ccc9dfafd..e4989fa975 100644 --- a/pkg/aws/services/globalaccelerator_mocks.go +++ b/pkg/aws/services/globalaccelerator_mocks.go @@ -51,6 +51,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) CreateAcceleratorWithContext(arg0, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).CreateAcceleratorWithContext), arg0, arg1) } +// CreateListenerWithContext mocks base method. +func (m *MockGlobalAccelerator) CreateListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.CreateListenerInput) (*globalaccelerator.CreateListenerOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateListenerWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.CreateListenerOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateListenerWithContext indicates an expected call of CreateListenerWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) CreateListenerWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateListenerWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).CreateListenerWithContext), arg0, arg1) +} + // DeleteAcceleratorWithContext mocks base method. func (m *MockGlobalAccelerator) DeleteAcceleratorWithContext(arg0 context.Context, arg1 *globalaccelerator.DeleteAcceleratorInput) (*globalaccelerator.DeleteAcceleratorOutput, error) { m.ctrl.T.Helper() @@ -66,6 +81,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) DeleteAcceleratorWithContext(arg0, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DeleteAcceleratorWithContext), arg0, arg1) } +// DeleteListenerWithContext mocks base method. +func (m *MockGlobalAccelerator) DeleteListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.DeleteListenerInput) (*globalaccelerator.DeleteListenerOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteListenerWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.DeleteListenerOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteListenerWithContext indicates an expected call of DeleteListenerWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) DeleteListenerWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteListenerWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DeleteListenerWithContext), arg0, arg1) +} + // DescribeAcceleratorWithContext mocks base method. func (m *MockGlobalAccelerator) DescribeAcceleratorWithContext(arg0 context.Context, arg1 *globalaccelerator.DescribeAcceleratorInput) (*globalaccelerator.DescribeAcceleratorOutput, error) { m.ctrl.T.Helper() @@ -81,6 +111,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) DescribeAcceleratorWithContext(arg0 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DescribeAcceleratorWithContext), arg0, arg1) } +// DescribeListenerWithContext mocks base method. +func (m *MockGlobalAccelerator) DescribeListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.DescribeListenerInput) (*globalaccelerator.DescribeListenerOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DescribeListenerWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.DescribeListenerOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DescribeListenerWithContext indicates an expected call of DescribeListenerWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) DescribeListenerWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeListenerWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DescribeListenerWithContext), arg0, arg1) +} + // ListAcceleratorsAsList mocks base method. func (m *MockGlobalAccelerator) ListAcceleratorsAsList(arg0 context.Context, arg1 *globalaccelerator.ListAcceleratorsInput) ([]types.Accelerator, error) { m.ctrl.T.Helper() @@ -96,6 +141,36 @@ func (mr *MockGlobalAcceleratorMockRecorder) ListAcceleratorsAsList(arg0, arg1 i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAcceleratorsAsList", reflect.TypeOf((*MockGlobalAccelerator)(nil).ListAcceleratorsAsList), arg0, arg1) } +// ListListenersAsList mocks base method. +func (m *MockGlobalAccelerator) ListListenersAsList(arg0 context.Context, arg1 *globalaccelerator.ListListenersInput) ([]types.Listener, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListListenersAsList", arg0, arg1) + ret0, _ := ret[0].([]types.Listener) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListListenersAsList indicates an expected call of ListListenersAsList. +func (mr *MockGlobalAcceleratorMockRecorder) ListListenersAsList(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListListenersAsList", reflect.TypeOf((*MockGlobalAccelerator)(nil).ListListenersAsList), arg0, arg1) +} + +// ListListenersForAcceleratorWithContext mocks base method. +func (m *MockGlobalAccelerator) ListListenersForAcceleratorWithContext(arg0 context.Context, arg1 *globalaccelerator.ListListenersInput) (*globalaccelerator.ListListenersOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListListenersForAcceleratorWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.ListListenersOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListListenersForAcceleratorWithContext indicates an expected call of ListListenersForAcceleratorWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) ListListenersForAcceleratorWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListListenersForAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).ListListenersForAcceleratorWithContext), arg0, arg1) +} + // ListTagsForResourceWithContext mocks base method. func (m *MockGlobalAccelerator) ListTagsForResourceWithContext(arg0 context.Context, arg1 *globalaccelerator.ListTagsForResourceInput) (*globalaccelerator.ListTagsForResourceOutput, error) { m.ctrl.T.Helper() @@ -155,3 +230,18 @@ func (mr *MockGlobalAcceleratorMockRecorder) UpdateAcceleratorWithContext(arg0, mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).UpdateAcceleratorWithContext), arg0, arg1) } + +// UpdateListenerWithContext mocks base method. +func (m *MockGlobalAccelerator) UpdateListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.UpdateListenerInput) (*globalaccelerator.UpdateListenerOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateListenerWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.UpdateListenerOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateListenerWithContext indicates an expected call of UpdateListenerWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) UpdateListenerWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateListenerWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).UpdateListenerWithContext), arg0, arg1) +} diff --git a/pkg/deploy/aga/accelerator_manager.go b/pkg/deploy/aga/accelerator_manager.go index 13d96607a8..763826eb76 100644 --- a/pkg/deploy/aga/accelerator_manager.go +++ b/pkg/deploy/aga/accelerator_manager.go @@ -27,11 +27,12 @@ type AcceleratorManager interface { } // NewDefaultAcceleratorManager constructs new defaultAcceleratorManager. -func NewDefaultAcceleratorManager(gaService services.GlobalAccelerator, trackingProvider tracking.Provider, taggingManager TaggingManager, externalManagedTags []string, logger logr.Logger) *defaultAcceleratorManager { +func NewDefaultAcceleratorManager(gaService services.GlobalAccelerator, trackingProvider tracking.Provider, taggingManager TaggingManager, listenerManager ListenerManager, externalManagedTags []string, logger logr.Logger) *defaultAcceleratorManager { return &defaultAcceleratorManager{ gaService: gaService, trackingProvider: trackingProvider, taggingManager: taggingManager, + listenerManager: listenerManager, externalManagedTags: externalManagedTags, logger: logger, } @@ -44,6 +45,7 @@ type defaultAcceleratorManager struct { gaService services.GlobalAccelerator trackingProvider tracking.Provider taggingManager TaggingManager + listenerManager ListenerManager externalManagedTags []string logger logr.Logger } @@ -162,7 +164,29 @@ func (m *defaultAcceleratorManager) Delete(ctx context.Context, sdkAccelerator A } } - // Step 2: Delete the accelerator + // Step 2: Delete all listeners associated with this accelerator + // TODO: This will be enhanced to delete endpoint groups and endpoints + // before deleting listeners (when those features are implemented) + listeners, err := m.listListeners(ctx, acceleratorARN) + if err != nil { + var apiErr *agatypes.AcceleratorNotFoundException + if errors.As(err, &apiErr) { + m.logger.Info("Accelerator not found, assuming already deleted", "acceleratorARN", acceleratorARN) + return nil + } + return fmt.Errorf("failed to list listeners for accelerator: %w", err) + } + + for _, listener := range listeners { + listenerARN := awssdk.ToString(listener.ListenerArn) + m.logger.Info("Deleting listener for accelerator", "listenerARN", listenerARN, "acceleratorARN", acceleratorARN) + + if err := m.listenerManager.Delete(ctx, listenerARN); err != nil { + return fmt.Errorf("failed to delete listener %s: %w", listenerARN, err) + } + } + + // Step 3: Delete the accelerator deleteInput := &globalaccelerator.DeleteAcceleratorInput{ AcceleratorArn: aws.String(acceleratorARN), } @@ -176,6 +200,14 @@ func (m *defaultAcceleratorManager) Delete(ctx context.Context, sdkAccelerator A Message: "Accelerator is not fully disabled yet", } } + + // Check if accelerator was already deleted + var apiErr *agatypes.AcceleratorNotFoundException + if errors.As(err, &apiErr) { + m.logger.Info("Accelerator already deleted", "acceleratorARN", acceleratorARN) + return nil + } + return fmt.Errorf("failed to delete accelerator: %w", err) } @@ -249,6 +281,15 @@ func (m *defaultAcceleratorManager) getIdempotencyToken(resAccelerator *agamodel return resAccelerator.GetCRDUID() } +// listListeners lists all listeners for a given accelerator +func (m *defaultAcceleratorManager) listListeners(ctx context.Context, acceleratorARN string) ([]agatypes.Listener, error) { + listInput := &globalaccelerator.ListListenersInput{ + AcceleratorArn: aws.String(acceleratorARN), + } + + return m.gaService.ListListenersAsList(ctx, listInput) +} + func (m *defaultAcceleratorManager) buildAcceleratorStatus(accelerator *agatypes.Accelerator) agamodel.AcceleratorStatus { status := agamodel.AcceleratorStatus{ AcceleratorARN: *accelerator.AcceleratorArn, diff --git a/pkg/deploy/aga/listener_manager.go b/pkg/deploy/aga/listener_manager.go new file mode 100644 index 0000000000..9fb67c4b61 --- /dev/null +++ b/pkg/deploy/aga/listener_manager.go @@ -0,0 +1,247 @@ +package aga + +import ( + "context" + "fmt" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/pkg/errors" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +// ListenerManager is responsible for managing AWS Global Accelerator listeners. +type ListenerManager interface { + // Create creates a listener. + Create(ctx context.Context, resListener *agamodel.Listener) (agamodel.ListenerStatus, error) + + // Update updates a listener. + Update(ctx context.Context, resListener *agamodel.Listener, sdkListener *ListenerResource) (agamodel.ListenerStatus, error) + + // Delete deletes a listener. + Delete(ctx context.Context, listenerARN string) error +} + +// NewDefaultListenerManager constructs new defaultListenerManager. +func NewDefaultListenerManager(gaService services.GlobalAccelerator, logger logr.Logger) *defaultListenerManager { + return &defaultListenerManager{ + gaService: gaService, + logger: logger, + } +} + +var _ ListenerManager = &defaultListenerManager{} + +// defaultListenerManager is the default implementation for ListenerManager. +type defaultListenerManager struct { + gaService services.GlobalAccelerator + logger logr.Logger +} + +// convertPortRangesToSDK converts model port ranges to SDK port ranges +func convertPortRangesToSDK(modelPortRanges []agamodel.PortRange) []agatypes.PortRange { + sdkPortRanges := make([]agatypes.PortRange, 0, len(modelPortRanges)) + for _, pr := range modelPortRanges { + sdkPortRanges = append(sdkPortRanges, agatypes.PortRange{ + FromPort: aws.Int32(pr.FromPort), + ToPort: aws.Int32(pr.ToPort), + }) + } + return sdkPortRanges +} + +func (m *defaultListenerManager) buildSDKCreateListenerInput(_ context.Context, resListener *agamodel.Listener) (*globalaccelerator.CreateListenerInput, error) { + acceleratorARN, err := resListener.Spec.AcceleratorARN.Resolve(context.Background()) + if err != nil { + return nil, errors.Wrap(err, "failed to resolve accelerator ARN") + } + + // Convert port ranges to AWS SDK format + portRanges := convertPortRangesToSDK(resListener.Spec.PortRanges) + + // Build create input + createInput := &globalaccelerator.CreateListenerInput{ + AcceleratorArn: aws.String(acceleratorARN), + Protocol: agatypes.Protocol(resListener.Spec.Protocol), + PortRanges: portRanges, + } + + // Add client affinity if specified + if resListener.Spec.ClientAffinity != "" { + createInput.ClientAffinity = agatypes.ClientAffinity(resListener.Spec.ClientAffinity) + } + + return createInput, nil +} + +func (m *defaultListenerManager) Create(ctx context.Context, resListener *agamodel.Listener) (agamodel.ListenerStatus, error) { + // Build create input + createInput, err := m.buildSDKCreateListenerInput(ctx, resListener) + if err != nil { + return agamodel.ListenerStatus{}, err + } + + // Create listener + m.logger.Info("Creating listener", + "stackID", resListener.Stack().StackID(), + "resourceID", resListener.ID()) + createOutput, err := m.gaService.CreateListenerWithContext(ctx, createInput) + if err != nil { + return agamodel.ListenerStatus{}, fmt.Errorf("failed to create listener: %w", err) + } + + listener := createOutput.Listener + m.logger.Info("Successfully created listener", + "stackID", resListener.Stack().StackID(), + "resourceID", resListener.ID(), + "listenerARN", *listener.ListenerArn) + + return agamodel.ListenerStatus{ + ListenerARN: *listener.ListenerArn, + }, nil +} + +func (m *defaultListenerManager) buildSDKUpdateListenerInput(ctx context.Context, resListener *agamodel.Listener, sdkListener *ListenerResource) (*globalaccelerator.UpdateListenerInput, error) { + // Convert port ranges to AWS SDK format + portRanges := convertPortRangesToSDK(resListener.Spec.PortRanges) + + // Build update input + updateInput := &globalaccelerator.UpdateListenerInput{ + ListenerArn: sdkListener.Listener.ListenerArn, + Protocol: agatypes.Protocol(resListener.Spec.Protocol), + PortRanges: portRanges, + } + + // Add client affinity if specified + if resListener.Spec.ClientAffinity != "" { + updateInput.ClientAffinity = agatypes.ClientAffinity(resListener.Spec.ClientAffinity) + } + + return updateInput, nil +} + +func (m *defaultListenerManager) Update(ctx context.Context, resListener *agamodel.Listener, sdkListener *ListenerResource) (agamodel.ListenerStatus, error) { + // Check if the listener actually needs an update + if !m.isSDKListenerSettingsDrifted(resListener, sdkListener) { + m.logger.Info("No drift detected in listener settings, skipping update", + "stackID", resListener.Stack().StackID(), + "resourceID", resListener.ID(), + "listenerARN", *sdkListener.Listener.ListenerArn) + return agamodel.ListenerStatus{ + ListenerARN: *sdkListener.Listener.ListenerArn, + }, nil + } + + m.logger.Info("Drift detected in listener settings, updating", + "stackID", resListener.Stack().StackID(), + "resourceID", resListener.ID(), + "listenerARN", *sdkListener.Listener.ListenerArn) + + // Build update input + updateInput, err := m.buildSDKUpdateListenerInput(ctx, resListener, sdkListener) + if err != nil { + return agamodel.ListenerStatus{}, err + } + + // Update listener + updateOutput, err := m.gaService.UpdateListenerWithContext(ctx, updateInput) + if err != nil { + return agamodel.ListenerStatus{}, fmt.Errorf("failed to update listener: %w", err) + } + updatedListener := updateOutput.Listener + + m.logger.Info("Successfully updated listener", + "stackID", resListener.Stack().StackID(), + "resourceID", resListener.ID(), + "listenerARN", *updatedListener.ListenerArn) + + return agamodel.ListenerStatus{ + ListenerARN: *updatedListener.ListenerArn, + }, nil +} + +func (m *defaultListenerManager) Delete(ctx context.Context, listenerARN string) error { + // TODO: This will be enhanced to check for and delete endpoint groups + // before deleting the listener (when those features are implemented) + + m.logger.Info("Deleting listener", "listenerARN", listenerARN) + + // First check if the listener exists to avoid errors on already deleted resources + descInput := &globalaccelerator.DescribeListenerInput{ + ListenerArn: aws.String(listenerARN), + } + + _, err := m.gaService.DescribeListenerWithContext(ctx, descInput) + if err != nil { + // If the listener doesn't exist, consider it already deleted + var apiErr *agatypes.ListenerNotFoundException + if errors.As(err, &apiErr) { + m.logger.Info("Listener already deleted", "listenerARN", listenerARN) + return nil + } + return fmt.Errorf("failed to describe listener during deletion: %w", err) + } + + deleteInput := &globalaccelerator.DeleteListenerInput{ + ListenerArn: aws.String(listenerARN), + } + + if _, err := m.gaService.DeleteListenerWithContext(ctx, deleteInput); err != nil { + // Check if it's a not found error - the listener might have been deleted between our check and delete + var apiErr *agatypes.ListenerNotFoundException + if errors.As(err, &apiErr) { + m.logger.Info("Listener already deleted", "listenerARN", listenerARN) + return nil + } + return fmt.Errorf("failed to delete listener: %w", err) + } + + m.logger.Info("Successfully deleted listener", "listenerARN", listenerARN) + return nil +} + +// isSDKListenerSettingsDrifted checks if the listener configuration has drifted from the desired state +func (m *defaultListenerManager) isSDKListenerSettingsDrifted(resListener *agamodel.Listener, sdkListener *ListenerResource) bool { + // Check if protocol differs + if string(resListener.Spec.Protocol) != string(sdkListener.Listener.Protocol) { + return true + } + + // Check if client affinity differs + if string(resListener.Spec.ClientAffinity) != string(sdkListener.Listener.ClientAffinity) { + return true + } + + // Check if port ranges differ + if !m.arePortRangesEqual(resListener.Spec.PortRanges, sdkListener.Listener.PortRanges) { + return true + } + + return false +} + +// arePortRangesEqual compares port ranges from the resource model and SDK +func (m *defaultListenerManager) arePortRangesEqual(modelPortRanges []agamodel.PortRange, sdkPortRanges []agatypes.PortRange) bool { + if len(modelPortRanges) != len(sdkPortRanges) { + return false + } + + // Since port ranges are unordered, we need to compare them as sets + sdkPortMap := make(map[string]struct{}) + for _, portRange := range sdkPortRanges { + key := fmt.Sprintf("%d-%d", *portRange.FromPort, *portRange.ToPort) + sdkPortMap[key] = struct{}{} + } + + // Check if all model port ranges exist in the SDK port ranges + for _, portRange := range modelPortRanges { + key := fmt.Sprintf("%d-%d", portRange.FromPort, portRange.ToPort) + if _, exists := sdkPortMap[key]; !exists { + return false + } + } + + return true +} diff --git a/pkg/deploy/aga/listener_manager_mocks.go b/pkg/deploy/aga/listener_manager_mocks.go new file mode 100644 index 0000000000..b6c1d60f6d --- /dev/null +++ b/pkg/deploy/aga/listener_manager_mocks.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga (interfaces: ListenerManager) + +// Package aga is a generated GoMock package. +package aga + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + aga0 "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +// MockListenerManager is a mock of ListenerManager interface. +type MockListenerManager struct { + ctrl *gomock.Controller + recorder *MockListenerManagerMockRecorder +} + +// MockListenerManagerMockRecorder is the mock recorder for MockListenerManager. +type MockListenerManagerMockRecorder struct { + mock *MockListenerManager +} + +// NewMockListenerManager creates a new mock instance. +func NewMockListenerManager(ctrl *gomock.Controller) *MockListenerManager { + mock := &MockListenerManager{ctrl: ctrl} + mock.recorder = &MockListenerManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockListenerManager) EXPECT() *MockListenerManagerMockRecorder { + return m.recorder +} + +// Create mocks base method. +func (m *MockListenerManager) Create(arg0 context.Context, arg1 *aga0.Listener) (aga0.ListenerStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Create", arg0, arg1) + ret0, _ := ret[0].(aga0.ListenerStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockListenerManagerMockRecorder) Create(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockListenerManager)(nil).Create), arg0, arg1) +} + +// Delete mocks base method. +func (m *MockListenerManager) Delete(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockListenerManagerMockRecorder) Delete(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockListenerManager)(nil).Delete), arg0, arg1) +} + +// Update mocks base method. +func (m *MockListenerManager) Update(arg0 context.Context, arg1 *aga0.Listener, arg2 *ListenerResource) (aga0.ListenerStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", arg0, arg1, arg2) + ret0, _ := ret[0].(aga0.ListenerStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. +func (mr *MockListenerManagerMockRecorder) Update(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockListenerManager)(nil).Update), arg0, arg1, arg2) +} diff --git a/pkg/deploy/aga/listener_manager_test.go b/pkg/deploy/aga/listener_manager_test.go new file mode 100644 index 0000000000..a2c56ab178 --- /dev/null +++ b/pkg/deploy/aga/listener_manager_test.go @@ -0,0 +1,424 @@ +package aga + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +// ListenerResource is already defined in types.go, no need to redefine it here + +func Test_defaultListenerManager_buildSDKCreateListenerInput(t *testing.T) { + testAcceleratorARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd-abcd-1234-abcd-1234abcdefgh" + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + resListener *agamodel.Listener + want *globalaccelerator.CreateListenerInput + wantErr bool + }{ + { + name: "Standard TCP listener", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + AcceleratorARN: core.LiteralStringToken(testAcceleratorARN), + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + }, + want: &globalaccelerator.CreateListenerInput{ + AcceleratorArn: aws.String(testAcceleratorARN), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + }, + }, + wantErr: false, + }, + { + name: "UDP listener with client affinity", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-2"), + Spec: agamodel.ListenerSpec{ + AcceleratorARN: core.LiteralStringToken(testAcceleratorARN), + Protocol: agamodel.ProtocolUDP, + ClientAffinity: agamodel.ClientAffinitySourceIP, + PortRanges: []agamodel.PortRange{ + {FromPort: 10000, ToPort: 20000}, + }, + }, + }, + want: &globalaccelerator.CreateListenerInput{ + AcceleratorArn: aws.String(testAcceleratorARN), + Protocol: agatypes.ProtocolUdp, + ClientAffinity: agatypes.ClientAffinitySourceIp, + PortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(10000), ToPort: aws.Int32(20000)}, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create listener manager + m := &defaultListenerManager{ + gaService: nil, // Not needed for this test + logger: logr.Discard(), + } + + // Call the method being tested + got, err := m.buildSDKCreateListenerInput(context.Background(), tt.resListener) + + // Check if error status matches expected + if (err != nil) != tt.wantErr { + t.Errorf("buildSDKCreateListenerInput() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Check if the result matches expected + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_defaultListenerManager_buildSDKUpdateListenerInput(t *testing.T) { + testListenerARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd-abcd-1234-abcd-1234abcdefgh/listener/abcdef1234" + testAcceleratorARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd-abcd-1234-abcd-1234abcdefgh" + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + resListener *agamodel.Listener + sdkListener *ListenerResource + want *globalaccelerator.UpdateListenerInput + wantErr bool + }{ + { + name: "Standard TCP listener update", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + AcceleratorARN: core.LiteralStringToken(testAcceleratorARN), + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + ListenerArn: aws.String(testListenerARN), + }, + }, + want: &globalaccelerator.UpdateListenerInput{ + ListenerArn: aws.String(testListenerARN), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + }, + }, + wantErr: false, + }, + { + name: "UDP listener update with client affinity", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-2"), + Spec: agamodel.ListenerSpec{ + AcceleratorARN: core.LiteralStringToken(testAcceleratorARN), + Protocol: agamodel.ProtocolUDP, + ClientAffinity: agamodel.ClientAffinitySourceIP, + PortRanges: []agamodel.PortRange{ + {FromPort: 10000, ToPort: 20000}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + ListenerArn: aws.String(testListenerARN), + }, + }, + want: &globalaccelerator.UpdateListenerInput{ + ListenerArn: aws.String(testListenerARN), + Protocol: agatypes.ProtocolUdp, + ClientAffinity: agatypes.ClientAffinitySourceIp, + PortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(10000), ToPort: aws.Int32(20000)}, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create listener manager + m := &defaultListenerManager{ + gaService: nil, // Not needed for this test + logger: logr.Discard(), + } + + // Call the method being tested + got, err := m.buildSDKUpdateListenerInput(context.Background(), tt.resListener, tt.sdkListener) + + // Check if error status matches expected + if (err != nil) != tt.wantErr { + t.Errorf("buildSDKUpdateListenerInput() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Check if the result matches expected + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_defaultListenerManager_isSDKListenerSettingsDrifted(t *testing.T) { + tests := []struct { + name string + resListener *agamodel.Listener + sdkListener *ListenerResource + want bool + }{ + { + name: "No drift - exact match", + resListener: &agamodel.Listener{ + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + ClientAffinity: agamodel.ClientAffinityNone, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + ClientAffinity: agatypes.ClientAffinityNone, + PortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + }, + }, + }, + want: false, // No drift + }, + { + name: "Drift - different protocol", + resListener: &agamodel.Listener{ + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + ClientAffinity: agamodel.ClientAffinityNone, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolUdp, // Different protocol + ClientAffinity: agatypes.ClientAffinityNone, + PortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + }, + }, + }, + want: true, // Drift detected + }, + { + name: "Drift - different client affinity", + resListener: &agamodel.Listener{ + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + ClientAffinity: agamodel.ClientAffinitySourceIP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + ClientAffinity: agatypes.ClientAffinityNone, // Different client affinity + PortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + }, + }, + }, + want: true, // Drift detected + }, + { + name: "Drift - different port ranges", + resListener: &agamodel.Listener{ + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + ClientAffinity: agamodel.ClientAffinityNone, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + ClientAffinity: agatypes.ClientAffinityNone, + PortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + // Missing 443 port + }, + }, + }, + want: true, // Drift detected + }, + { + name: "No drift - same ports in different order", + resListener: &agamodel.Listener{ + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + ClientAffinity: agamodel.ClientAffinityNone, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + ClientAffinity: agatypes.ClientAffinityNone, + PortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + }, + }, + }, + want: false, // No drift - port orders don't matter + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create listener manager + m := &defaultListenerManager{ + gaService: nil, // Not needed for this test + logger: logr.Discard(), + } + + got := m.isSDKListenerSettingsDrifted(tt.resListener, tt.sdkListener) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_defaultListenerManager_arePortRangesEqual(t *testing.T) { + tests := []struct { + name string + modelPortRanges []agamodel.PortRange + sdkPortRanges []agatypes.PortRange + want bool + }{ + { + name: "Equal - exact match", + modelPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + sdkPortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + }, + want: true, + }, + { + name: "Equal - different order", + modelPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + sdkPortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + }, + want: true, + }, + { + name: "Not equal - different count", + modelPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + sdkPortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + }, + want: false, + }, + { + name: "Not equal - different range", + modelPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + sdkPortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + {FromPort: aws.Int32(8443), ToPort: aws.Int32(8443)}, + }, + want: false, + }, + { + name: "Equal - empty slices", + modelPortRanges: []agamodel.PortRange{}, + sdkPortRanges: []agatypes.PortRange{}, + want: true, + }, + { + name: "Not equal - one empty, one not", + modelPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + sdkPortRanges: []agatypes.PortRange{}, + want: false, + }, + { + name: "Equal - port ranges with ranges", + modelPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 100}, + {FromPort: 443, ToPort: 450}, + }, + sdkPortRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(100)}, + {FromPort: aws.Int32(443), ToPort: aws.Int32(450)}, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &defaultListenerManager{ + gaService: nil, + logger: logr.Discard(), + } + + got := m.arePortRangesEqual(tt.modelPortRanges, tt.sdkPortRanges) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/deploy/aga/listener_synthesizer.go b/pkg/deploy/aga/listener_synthesizer.go new file mode 100644 index 0000000000..df1c543f8d --- /dev/null +++ b/pkg/deploy/aga/listener_synthesizer.go @@ -0,0 +1,532 @@ +package aga + +import ( + "context" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/pkg/errors" + "k8s.io/apimachinery/pkg/util/sets" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" + "sort" +) + +// NewListenerSynthesizer constructs listenerSynthesizer +func NewListenerSynthesizer(gaClient services.GlobalAccelerator, listenerManager ListenerManager, + logger logr.Logger, stack core.Stack) *listenerSynthesizer { + return &listenerSynthesizer{ + gaClient: gaClient, + listenerManager: listenerManager, + logger: logger, + stack: stack, + } +} + +// listenerSynthesizer is responsible for synthesize Listener resources for a stack. +type listenerSynthesizer struct { + gaClient services.GlobalAccelerator + listenerManager ListenerManager + logger logr.Logger + stack core.Stack +} + +func (s *listenerSynthesizer) Synthesize(ctx context.Context) error { + var resListeners []*agamodel.Listener + s.stack.ListResources(&resListeners) + + // If no listeners are defined in the model, there's nothing to do + if len(resListeners) == 0 { + return nil + } + + // Get the accelerator resource from the stack + var resAccelerators []*agamodel.Accelerator + if err := s.stack.ListResources(&resAccelerators); err != nil { + return err + } + if len(resAccelerators) == 0 { + return errors.New("no accelerator resource found in stack") + } + accelerator := resAccelerators[0] + + // Get the accelerator ARN from the spec token + acceleratorARN, err := accelerator.AcceleratorARN().Resolve(ctx) + if err != nil { + return errors.Wrapf(err, "unable to resolve accelerator ARN for stack %s", s.stack.StackID()) + } + + // Process all listeners for this accelerator + if err := s.synthesizeListenersOnAccelerator(ctx, acceleratorARN, resListeners); err != nil { + return err + } + + return nil +} + +func (s *listenerSynthesizer) PostSynthesize(ctx context.Context) error { + // PostSynthesize is called after all resources in the stack have been synthesized. + // This is a good place to handle any cleanup or verification tasks. + // + // For listeners, we could use this to verify that all expected listeners + // are properly created and configured, but this is already handled in the + // main Synthesize method. + // + // Note: To minimize traffic disruption during reconciliation, we've already: + // 1. Deleted unneeded/conflicting listeners to free up capacity and avoid conflicts + // 2. Updated existing listeners to maintain their ARNs and associated resources + // 3. Created new listeners as needed + // + // This order ensures that we maintain maximum stability across reconciliations + // while also avoiding listener limit errors. + + return nil +} + +func (s *listenerSynthesizer) synthesizeListenersOnAccelerator(ctx context.Context, accARN string, resListeners []*agamodel.Listener) error { + // Get existing listeners for this accelerator + sdkListeners, err := s.findSDKListenersOnAccelerator(ctx, accARN) + if err != nil { + return err + } + + // Match resource listeners with existing SDK listeners + // - matchedResAndSDKListeners: pairs of resource and SDK listeners that will be updated + // - unmatchedResListeners: resource listeners that don't match any SDK listeners and will be created + // - unmatchedSDKListeners: SDK listeners that don't match any resource listeners and will be deleted + matchedResAndSDKListeners, unmatchedResListeners, unmatchedSDKListeners := s.matchResAndSDKListeners(resListeners, sdkListeners) + + // Improved operation order to minimize traffic disruption: + // 1. Delete only conflicting listeners (that would block updates) + // 2. Update matched listeners + // 3. Delete unneeded (non-conflicting) listeners + // 4. Create new listeners + + // STEP 1: Find SDK listeners that have port conflicts with planned updates + var conflictingListeners []*ListenerResource + var nonConflictingListeners []*ListenerResource + + // Track which listeners have port conflicts with our updates + conflictMap := make(map[string][]*ListenerResource) + + // For each update we're planning to do... + for _, pair := range matchedResAndSDKListeners { + var conflicts []*ListenerResource + + // Check against all unmatched SDK listeners for conflicts + for _, sdkListener := range unmatchedSDKListeners { + if s.hasPortRangeConflict(pair.resListener, sdkListener) { + conflicts = append(conflicts, sdkListener) + } + } + + // If there are conflicts, add them to our conflict map + if len(conflicts) > 0 { + conflictMap[pair.resListener.ID()] = conflicts + } + } + + // Build list of conflicting and non-conflicting listeners + listenerIsConflicting := make(map[string]bool) + + // Add all listeners with port conflicts to the conflicting list + for _, conflicts := range conflictMap { + for _, listener := range conflicts { + arn := *listener.Listener.ListenerArn + if !listenerIsConflicting[arn] { + conflictingListeners = append(conflictingListeners, listener) + listenerIsConflicting[arn] = true + } + } + } + + // Sort remaining unmatched listeners into non-conflicting + for _, sdkListener := range unmatchedSDKListeners { + arn := *sdkListener.Listener.ListenerArn + if !listenerIsConflicting[arn] { + nonConflictingListeners = append(nonConflictingListeners, sdkListener) + } + } + + // STEP 2: Execute operations in correct order + + // First, delete ONLY conflicting listeners (those that would block updates) + // TODO: When we implement endpoint groups, for a more comprehensive solution, we might also want to add the ability to + // migrate endpoint groups from these conflicting listeners to non-conflicting ones as much as possible. + for _, listener := range conflictingListeners { + s.logger.Info("Deleting conflicting listener to allow updates", + "listenerARN", *listener.Listener.ListenerArn, + "protocol", listener.Listener.Protocol) + + if err := s.listenerManager.Delete(ctx, *listener.Listener.ListenerArn); err != nil { + s.logger.Error(err, "Failed to delete conflicting listener", + "listenerARN", *listener.Listener.ListenerArn) + return err + } + } + + // Next, update existing matched listeners (now conflict-free) + for _, pair := range matchedResAndSDKListeners { + s.logger.Info("Updating existing listener", + "listenerARN", *pair.sdkListener.Listener.ListenerArn, + "protocol", pair.resListener.Spec.Protocol, + "portRanges", s.portRangesToString(pair.resListener.Spec.PortRanges)) + + listenerStatus, err := s.listenerManager.Update(ctx, pair.resListener, pair.sdkListener) + if err != nil { + s.logger.Error(err, "Failed to update listener", + "listenerARN", *pair.sdkListener.Listener.ListenerArn) + return err + } + pair.resListener.SetStatus(listenerStatus) + } + + // Then, delete non-conflicting but unneeded listeners to free up the space + for _, listener := range nonConflictingListeners { + s.logger.Info("Deleting unneeded listener", + "listenerARN", *listener.Listener.ListenerArn, + "protocol", listener.Listener.Protocol) + + if err := s.listenerManager.Delete(ctx, *listener.Listener.ListenerArn); err != nil { + s.logger.Error(err, "Failed to delete unneeded listener", + "listenerARN", *listener.Listener.ListenerArn) + return err + } + } + + // Finally, create any new listeners needed + for _, resListener := range unmatchedResListeners { + s.logger.Info("Creating new listener", + "protocol", resListener.Spec.Protocol, + "portRanges", s.portRangesToString(resListener.Spec.PortRanges)) + + listenerStatus, err := s.listenerManager.Create(ctx, resListener) + if err != nil { + // If we hit a listener limit error, log it clearly + var apiErr *agatypes.LimitExceededException + if errors.As(err, &apiErr) { + s.logger.Error(err, + "Reached listener limit on accelerator. Tried to create a listener after deleting unmatched ones.") + } + return err + } + resListener.SetStatus(listenerStatus) + } + + return nil +} + +// findSDKListenersOnAccelerator returns all listeners for the given accelerator +func (s *listenerSynthesizer) findSDKListenersOnAccelerator(ctx context.Context, accARN string) ([]*ListenerResource, error) { + // List listeners for the accelerator + listInput := &globalaccelerator.ListListenersInput{ + AcceleratorArn: awssdk.String(accARN), + } + sdkListeners, err := s.gaClient.ListListenersAsList(ctx, listInput) + if err != nil { + var apiErr *agatypes.AcceleratorNotFoundException + if errors.As(err, &apiErr) { + s.logger.Info("Accelerator not found in AWS, skipping listener listing", + "acceleratorARN", accARN) + return nil, nil + } + return nil, errors.Wrapf(err, "failed to list listeners for accelerator %s", accARN) + } + + var listeners []*ListenerResource + for _, listener := range sdkListeners { + // Clone the listener as the range variable address is reused + listenerCopy := listener + listeners = append(listeners, &ListenerResource{ + Listener: &listenerCopy, + }) + } + return listeners, nil +} + +// resAndSDKListenerPair holds a matched pair of resource and SDK listener +type resAndSDKListenerPair struct { + resListener *agamodel.Listener + sdkListener *ListenerResource +} + +// matchResAndSDKListeners matches resource listeners with SDK listeners using a similarity-based algorithm +// that finds the best match based on protocol and port range overlap. +// Returns three groups: +// - matchedResAndSDKListeners: pairs of resource and SDK listeners that will be updated +// - unmatchedResListeners: resource listeners that don't match any SDK listeners and will be created +// - unmatchedSDKListeners: SDK listeners that don't match any resource listeners and will be deleted +func (s *listenerSynthesizer) matchResAndSDKListeners(resListeners []*agamodel.Listener, sdkListeners []*ListenerResource) ( + []resAndSDKListenerPair, []*agamodel.Listener, []*ListenerResource) { + + // First, try to match by exact protocol and port ranges + exactMatches, remainingResListeners, remainingSDKListeners := s.findExactMatches(resListeners, sdkListeners) + + // For remaining listeners, use similarity-based matching + similarityMatches, unmatchedResListeners, unmatchedSDKListeners := s.findSimilarityMatches( + remainingResListeners, remainingSDKListeners) + + // Combine exact and similarity matches + matchedPairs := append(exactMatches, similarityMatches...) + + s.logger.V(1).Info("Matched listeners", + "exactMatches", len(exactMatches), + "similarityMatches", len(similarityMatches), + "unmatchedResListeners", len(unmatchedResListeners), + "unmatchedSDKListeners", len(unmatchedSDKListeners)) + + return matchedPairs, unmatchedResListeners, unmatchedSDKListeners +} + +// findExactMatches matches listeners that have identical protocol and port ranges +func (s *listenerSynthesizer) findExactMatches(resListeners []*agamodel.Listener, sdkListeners []*ListenerResource) ( + []resAndSDKListenerPair, []*agamodel.Listener, []*ListenerResource) { + + var matchedPairs []resAndSDKListenerPair + var unmatchedResListeners []*agamodel.Listener + var unmatchedSDKListeners []*ListenerResource + + // Create maps with protocol+portRanges as key + resListenerByKey := make(map[string]*agamodel.Listener) + sdkListenerByKey := make(map[string]*ListenerResource) + + // Map resource listeners + for _, resListener := range resListeners { + key := s.generateResListenerKey(resListener) + resListenerByKey[key] = resListener + } + + // Map SDK listeners + for _, sdkListener := range sdkListeners { + key := s.generateSDKListenerKey(sdkListener) + sdkListenerByKey[key] = sdkListener + } + + // Find matched and unmatched listeners + resListenerKeys := sets.StringKeySet(resListenerByKey) + sdkListenerKeys := sets.StringKeySet(sdkListenerByKey) + + // Find matches + for _, key := range resListenerKeys.Intersection(sdkListenerKeys).List() { + matchedPairs = append(matchedPairs, resAndSDKListenerPair{ + resListener: resListenerByKey[key], + sdkListener: sdkListenerByKey[key], + }) + } + + // Find unmatched resource listeners + for _, key := range resListenerKeys.Difference(sdkListenerKeys).List() { + unmatchedResListeners = append(unmatchedResListeners, resListenerByKey[key]) + } + + // Find unmatched SDK listeners + for _, key := range sdkListenerKeys.Difference(resListenerKeys).List() { + unmatchedSDKListeners = append(unmatchedSDKListeners, sdkListenerByKey[key]) + } + + return matchedPairs, unmatchedResListeners, unmatchedSDKListeners +} + +// listenerPairScore holds a potential match with its similarity score +type listenerPairScore struct { + resListener *agamodel.Listener + sdkListener *ListenerResource + score int +} + +// findSimilarityMatches matches remaining listeners based on similarity score +func (s *listenerSynthesizer) findSimilarityMatches(resListeners []*agamodel.Listener, sdkListeners []*ListenerResource) ( + []resAndSDKListenerPair, []*agamodel.Listener, []*ListenerResource) { + + // Define minimum similarity threshold - below this, we don't consider it a match + const minSimilarityThreshold = 15 // 15% + + var matchedPairs []resAndSDKListenerPair + + // Return early if either list is empty + if len(resListeners) == 0 || len(sdkListeners) == 0 { + return matchedPairs, resListeners, sdkListeners + } + + // Calculate similarity scores for all possible pairings + var scoredPairs []listenerPairScore + for _, resListener := range resListeners { + for _, sdkListener := range sdkListeners { + // Calculate similarity score for this pair + score := s.calculateSimilarityScore(resListener, sdkListener) + + // Only consider pairs with meaningful similarity (score >= minSimilarityThreshold) + if score >= minSimilarityThreshold { + scoredPairs = append(scoredPairs, listenerPairScore{ + resListener: resListener, + sdkListener: sdkListener, + score: score, + }) + } + } + } + + // Sort pairs by score (highest first) + sort.Slice(scoredPairs, func(i, j int) bool { + return scoredPairs[i].score > scoredPairs[j].score + }) + + // Track which listeners have been matched + matchedResListenerIDs := sets.NewString() + matchedSDKListenerARNs := sets.NewString() + + // Match greedily by highest score first + for _, pair := range scoredPairs { + resID := pair.resListener.ID() + sdkARN := awssdk.ToString(pair.sdkListener.Listener.ListenerArn) + + // Skip if either listener is already matched + if matchedResListenerIDs.Has(resID) || matchedSDKListenerARNs.Has(sdkARN) { + continue + } + + // Add this pair to matches + matchedPairs = append(matchedPairs, resAndSDKListenerPair{ + resListener: pair.resListener, + sdkListener: pair.sdkListener, + }) + + // Mark as matched + matchedResListenerIDs.Insert(resID) + matchedSDKListenerARNs.Insert(sdkARN) + + s.logger.V(1).Info("Matched listeners by similarity", + "resListenerID", resID, + "sdkListenerARN", sdkARN, + "similarityScore", pair.score) + } + + // Collect unmatched resource listeners + var unmatchedResListeners []*agamodel.Listener + for _, resListener := range resListeners { + if !matchedResListenerIDs.Has(resListener.ID()) { + unmatchedResListeners = append(unmatchedResListeners, resListener) + } + } + + // Collect unmatched SDK listeners + var unmatchedSDKListeners []*ListenerResource + for _, sdkListener := range sdkListeners { + if !matchedSDKListenerARNs.Has(awssdk.ToString(sdkListener.Listener.ListenerArn)) { + unmatchedSDKListeners = append(unmatchedSDKListeners, sdkListener) + } + } + + return matchedPairs, unmatchedResListeners, unmatchedSDKListeners +} + +// calculateSimilarityScore calculates how similar two listeners are +// Higher scores indicate better matches +func (s *listenerSynthesizer) calculateSimilarityScore(resListener *agamodel.Listener, sdkListener *ListenerResource) int { + // Start with base score + score := 0 + + // Protocol match is highly valuable - give significant bonus + if string(resListener.Spec.Protocol) == string(sdkListener.Listener.Protocol) { + score += 40 // Strong bonus for protocol match + } + + // Calculate port overlap + resPortSet := s.makeResPortSet(resListener.Spec.PortRanges) + sdkPortSet := s.makeSDKPortSet(sdkListener.Listener.PortRanges) + + // Find common ports (intersection) + commonPorts := 0 + for port := range resPortSet { + if sdkPortSet[port] { + commonPorts++ + } + } + + // Calculate total unique ports (union) + totalPorts := len(resPortSet) + len(sdkPortSet) - commonPorts + + // Jaccard similarity: intersection / union (as a percentage) + if totalPorts > 0 { + score += (commonPorts * 100) / totalPorts + } + + // If client affinity matches and is specified, add bonus points + resClientAffinity := string(resListener.Spec.ClientAffinity) + sdkClientAffinity := string(sdkListener.Listener.ClientAffinity) + + // Only add bonus if both have affinity set and they match + if resClientAffinity != "" && sdkClientAffinity != "" && resClientAffinity == sdkClientAffinity { + score += 10 + } + + return score +} + +// makeResPortSet converts resource model port ranges to a set of individual ports. +func (s *listenerSynthesizer) makeResPortSet(portRanges []agamodel.PortRange) map[int32]bool { + portSet := make(map[int32]bool) + ResPortRangesToSet(portRanges, portSet) + return portSet +} + +// makeSDKPortSet converts SDK port ranges to a set of individual ports. +func (s *listenerSynthesizer) makeSDKPortSet(portRanges []agatypes.PortRange) map[int32]bool { + portSet := make(map[int32]bool) + SDKPortRangesToSet(portRanges, portSet) + return portSet +} + +// generateResListenerKey creates a unique key for a resource listener based on protocol and port ranges +func (s *listenerSynthesizer) generateResListenerKey(listener *agamodel.Listener) string { + protocol := string(listener.Spec.Protocol) + + // Sort port ranges before generating key to ensure consistent matching + sortedPortRanges := make([]agamodel.PortRange, len(listener.Spec.PortRanges)) + copy(sortedPortRanges, listener.Spec.PortRanges) + SortModelPortRanges(sortedPortRanges) + + portRanges := ResPortRangesToString(sortedPortRanges) + return protocol + ":" + portRanges +} + +// generateSDKListenerKey creates a unique key for an SDK listener based on protocol and port ranges +func (s *listenerSynthesizer) generateSDKListenerKey(listener *ListenerResource) string { + protocol := string(listener.Listener.Protocol) + + // Sort port ranges before generating key to ensure consistent matching + sortedPortRanges := make([]agatypes.PortRange, len(listener.Listener.PortRanges)) + copy(sortedPortRanges, listener.Listener.PortRanges) + SortSDKPortRanges(sortedPortRanges) + + portRanges := SDKPortRangesToString(sortedPortRanges) + return protocol + ":" + portRanges +} + +// hasPortRangeConflict checks if there's any overlap between port ranges of two listeners +func (s *listenerSynthesizer) hasPortRangeConflict(resListener *agamodel.Listener, sdkListener *ListenerResource) bool { + // Different protocols can use the same ports without conflict + if string(resListener.Spec.Protocol) != string(sdkListener.Listener.Protocol) { + return false + } + + // Build port sets for both listeners + resPortSet := s.makeResPortSet(resListener.Spec.PortRanges) + sdkPortSet := s.makeSDKPortSet(sdkListener.Listener.PortRanges) + + // Check for any port overlap + for port := range resPortSet { + if sdkPortSet[port] { + return true // Found an overlapping port + } + } + + return false +} + +// portRangesToString serializes port ranges to a string - deprecated, use ResPortRangesToString instead +func (s *listenerSynthesizer) portRangesToString(portRanges []agamodel.PortRange) string { + return ResPortRangesToString(portRanges) +} diff --git a/pkg/deploy/aga/listener_synthesizer_test.go b/pkg/deploy/aga/listener_synthesizer_test.go new file mode 100644 index 0000000000..7c73f25430 --- /dev/null +++ b/pkg/deploy/aga/listener_synthesizer_test.go @@ -0,0 +1,1701 @@ +package aga + +import ( + "sort" + "testing" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +func Test_listenerSynthesizer_hasPortRangeConflict(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + resListener *agamodel.Listener + sdkListener *ListenerResource + want bool + }{ + { + name: "different protocols - no conflict", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + want: false, + }, + { + name: "same protocol, non-overlapping ports - no conflict", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + }, + }, + }, + want: false, + }, + { + name: "same protocol, same ports - conflict", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + want: true, + }, + { + name: "same protocol, overlapping port ranges - conflict", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 100}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(90), ToPort: awssdk.Int32(110)}, + }, + }, + }, + want: true, + }, + { + name: "same protocol, multiple port ranges with one overlap - conflict", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolUDP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + {FromPort: awssdk.Int32(8080), ToPort: awssdk.Int32(8080)}, + }, + }, + }, + want: true, + }, + { + name: "same protocol, adjacent port ranges - no conflict", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(91), ToPort: awssdk.Int32(100)}, + }, + }, + }, + want: false, + }, + { + name: "same protocol, one port at edge of range - conflict", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(90), ToPort: awssdk.Int32(100)}, + }, + }, + }, + want: true, + }, + { + name: "same protocol, complex multiple ranges with overlap - conflict", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + {FromPort: 8000, ToPort: 8010}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(22), ToPort: awssdk.Int32(22)}, + {FromPort: awssdk.Int32(5000), ToPort: awssdk.Int32(5010)}, + {FromPort: awssdk.Int32(8005), ToPort: awssdk.Int32(8015)}, + }, + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &listenerSynthesizer{ + logger: logr.Discard(), + } + got := s.hasPortRangeConflict(tt.resListener, tt.sdkListener) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_listenerSynthesizer_generateResListenerKey(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + listener *agamodel.Listener + want string + }{ + { + name: "TCP listener with single port range", + listener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + want: "TCP:80-80", + }, + { + name: "UDP listener with multiple port ranges - ordered", + listener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolUDP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + }, + want: "UDP:80-80,443-443", + }, + { + name: "TCP listener with multiple port ranges - unordered", + listener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + {FromPort: 80, ToPort: 80}, + }, + }, + }, + want: "TCP:80-80,443-443", // Should be sorted + }, + { + name: "UDP listener with complex port ranges - unordered", + listener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolUDP, + PortRanges: []agamodel.PortRange{ + {FromPort: 8000, ToPort: 8100}, + {FromPort: 443, ToPort: 443}, + {FromPort: 80, ToPort: 80}, + }, + }, + }, + want: "UDP:80-80,443-443,8000-8100", // Should be sorted by FromPort + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &listenerSynthesizer{ + logger: logr.Discard(), + } + got := s.generateResListenerKey(tt.listener) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_listenerSynthesizer_calculateSimilarityScore(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + resListener *agamodel.Listener + sdkListener *ListenerResource + want int + }{ + { + name: "exact match - protocol, ports, and client affinity", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + ClientAffinity: "SOURCE_IP", + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + }, + ClientAffinity: agatypes.ClientAffinitySourceIp, + }, + }, + want: 150, // 40 (protocol) + 100 (full port overlap) + 10 (client affinity) + }, + { + name: "protocol match, complete port overlap, no client affinity", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + ClientAffinity: agatypes.ClientAffinityNone, + }, + }, + want: 140, // 40 (protocol) + 100 (full port overlap) + }, + { + name: "protocol match, no port overlap", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + }, + }, + }, + want: 40, // 40 (protocol) + 0 (no port overlap) + }, + { + name: "protocol mismatch, partial port overlap", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + {FromPort: awssdk.Int32(8080), ToPort: awssdk.Int32(8080)}, + }, + }, + }, + want: 33, // 0 (protocol mismatch) + 33 (1 common port out of 3 total unique ports) + }, + { + name: "protocol match, partial port overlap with ranges", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(85), ToPort: awssdk.Int32(95)}, + }, + }, + }, + want: 77, // 40 (protocol) + 37 (port overlap) + }, + { + name: "protocol mismatch, no port overlap, client affinity match", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + ClientAffinity: "SOURCE_IP", + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + }, + ClientAffinity: agatypes.ClientAffinitySourceIp, + }, + }, + want: 10, // 0 (protocol mismatch) + 0 (no port overlap) + 10 (client affinity match) + }, + { + name: "protocol match, complete port overlap, client affinity mismatch", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + ClientAffinity: "SOURCE_IP", + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + ClientAffinity: agatypes.ClientAffinityNone, + }, + }, + want: 140, // 40 (protocol) + 100 (complete port overlap) + 0 (client affinity mismatch) + }, + { + name: "complex case - protocol match, multiple port ranges with partial overlap", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + {FromPort: 8000, ToPort: 8010}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + {FromPort: awssdk.Int32(8005), ToPort: awssdk.Int32(8015)}, + {FromPort: awssdk.Int32(9000), ToPort: awssdk.Int32(9010)}, + }, + }, + }, + want: 64, // 40 (protocol) + 24 (partial port overlap) + }, + { + name: "empty port ranges", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{}, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{}, + }, + }, + want: 40, // 40 (protocol) + 0 (no ports) + }, + { + name: "large port ranges with partial overlap", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 1000, ToPort: 2000}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(1500), ToPort: awssdk.Int32(2500)}, + }, + }, + }, + want: 73, // 40 (protocol) + 33 (port overlap) + }, + { + name: "nil and empty client affinity - no match bonus", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 80}}, + ClientAffinity: "", // Empty + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{{FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}}, + // ClientAffinity is nil or not set + }, + }, + want: 140, // 40 (protocol) + 100 (full port overlap) + 0 (no client affinity bonus) + }, + { + name: "protocol case sensitivity test (should still match)", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, // Upper case + PortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 80}}, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolTcp, // Title case + PortRanges: []agatypes.PortRange{{FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}}, + }, + }, + want: 140, // 40 (protocol) + 100 (full port overlap) + }, + { + name: "different port ranges but same total ports", + resListener: &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolUDP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 85}, + }, + }, + }, + sdkListener: &ListenerResource{ + Listener: &agatypes.Listener{ + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + {FromPort: awssdk.Int32(81), ToPort: awssdk.Int32(85)}, + }, + }, + }, + want: 140, // 40 (protocol) + 100 (full port overlap - different ranges but same ports) + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &listenerSynthesizer{ + logger: logr.Discard(), + } + got := s.calculateSimilarityScore(tt.resListener, tt.sdkListener) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_listenerSynthesizer_findExactMatches(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + resListeners []*agamodel.Listener + sdkListeners []*ListenerResource + wantMatchedPairs []struct { + resID string + sdkARN string + } + wantUnmatchedResIDs []string + wantUnmatchedSDKARNs []string + }{ + { + name: "empty lists", + resListeners: []*agamodel.Listener{}, + sdkListeners: []*ListenerResource{}, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{}, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "one exact match", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-80", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "one exact match among multiple listeners", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-443"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-53"), + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(53), ToPort: awssdk.Int32(53)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-80", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80", + }, + }, + wantUnmatchedResIDs: []string{"tcp-443"}, + wantUnmatchedSDKARNs: []string{"arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-53"}, + }, + { + name: "multiple exact matches", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "udp-53"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolUDP, + PortRanges: []agamodel.PortRange{ + {FromPort: 53, ToPort: 53}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-53"), + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(53), ToPort: awssdk.Int32(53)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-80", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80", + }, + { + resID: "udp-53", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-53", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "exact match with different port range ordering", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-multi-port"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, // Note the order - 443 first, then 80 + {FromPort: 80, ToPort: 80}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-multi"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, // Different order - 80 first, then 443 + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-multi-port", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-multi", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "no matches", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-80"), + Protocol: agatypes.ProtocolUdp, // Different protocol + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{}, + wantUnmatchedResIDs: []string{"tcp-80"}, + wantUnmatchedSDKARNs: []string{"arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-80"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &listenerSynthesizer{ + logger: logr.Discard(), + } + + // Run the function + matchedPairs, unmatchedResListeners, unmatchedSDKListeners := s.findExactMatches(tt.resListeners, tt.sdkListeners) + + // Collect the actual pairs and IDs for verification + var actualMatchedPairs []struct { + resID string + sdkARN string + } + + for _, pair := range matchedPairs { + actualMatchedPairs = append(actualMatchedPairs, struct { + resID string + sdkARN string + }{ + resID: pair.resListener.ID(), + sdkARN: awssdk.ToString(pair.sdkListener.Listener.ListenerArn), + }) + } + + var actualUnmatchedResIDs []string + for _, listener := range unmatchedResListeners { + actualUnmatchedResIDs = append(actualUnmatchedResIDs, listener.ID()) + } + + var actualUnmatchedSDKARNs []string + for _, listener := range unmatchedSDKListeners { + actualUnmatchedSDKARNs = append(actualUnmatchedSDKARNs, awssdk.ToString(listener.Listener.ListenerArn)) + } + + // Sort all slices to ensure consistent comparison + sort.Slice(actualMatchedPairs, func(i, j int) bool { + if actualMatchedPairs[i].resID != actualMatchedPairs[j].resID { + return actualMatchedPairs[i].resID < actualMatchedPairs[j].resID + } + return actualMatchedPairs[i].sdkARN < actualMatchedPairs[j].sdkARN + }) + + sort.Slice(tt.wantMatchedPairs, func(i, j int) bool { + if tt.wantMatchedPairs[i].resID != tt.wantMatchedPairs[j].resID { + return tt.wantMatchedPairs[i].resID < tt.wantMatchedPairs[j].resID + } + return tt.wantMatchedPairs[i].sdkARN < tt.wantMatchedPairs[j].sdkARN + }) + + sort.Strings(actualUnmatchedResIDs) + sort.Strings(tt.wantUnmatchedResIDs) + sort.Strings(actualUnmatchedSDKARNs) + sort.Strings(tt.wantUnmatchedSDKARNs) + + // Verify matched pairs + assert.Equal(t, len(tt.wantMatchedPairs), len(actualMatchedPairs), "matched pairs count") + for i := range tt.wantMatchedPairs { + if i < len(actualMatchedPairs) { + assert.Equal(t, tt.wantMatchedPairs[i].resID, actualMatchedPairs[i].resID, "matched pair resID at index %d", i) + assert.Equal(t, tt.wantMatchedPairs[i].sdkARN, actualMatchedPairs[i].sdkARN, "matched pair sdkARN at index %d", i) + } + } + + // Handle nil vs empty slices + if len(actualUnmatchedResIDs) == 0 && len(tt.wantUnmatchedResIDs) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched resource listeners + assert.ElementsMatch(t, tt.wantUnmatchedResIDs, actualUnmatchedResIDs, "unmatched resource listeners") + } + + if len(actualUnmatchedSDKARNs) == 0 && len(tt.wantUnmatchedSDKARNs) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched SDK listeners + assert.ElementsMatch(t, tt.wantUnmatchedSDKARNs, actualUnmatchedSDKARNs, "unmatched SDK listeners") + } + }) + } +} + +func Test_listenerSynthesizer_findSimilarityMatches(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + resListeners []*agamodel.Listener + sdkListeners []*ListenerResource + wantMatchedPairs []struct { + resID string + sdkARN string + } + wantUnmatchedResIDs []string + wantUnmatchedSDKARNs []string + }{ + { + name: "empty lists", + resListeners: []*agamodel.Listener{}, + sdkListeners: []*ListenerResource{}, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{}, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "empty resource listeners", + resListeners: []*agamodel.Listener{}, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list123"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{}, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{"arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list123"}, + }, + { + name: "empty sdk listeners", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{}, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{}, + wantUnmatchedResIDs: []string{"listener-1"}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "one exact similarity match", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "listener-1"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list123"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "listener-1", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list123", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "multiple listeners with some matches", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-443"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "udp-53"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolUDP, + PortRanges: []agamodel.PortRange{ + {FromPort: 53, ToPort: 53}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-8080"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(8080), ToPort: awssdk.Int32(8080)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-80", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80", + }, + { + resID: "tcp-443", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-8080", + }, + }, + wantUnmatchedResIDs: []string{"udp-53"}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "complex case with partial similarity matches - greedy algorithm test", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80-100"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 100}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-443"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-90-110"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(90), ToPort: awssdk.Int32(110)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-440-450"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(440), ToPort: awssdk.Int32(450)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-80-100", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-90-110", + }, + { + resID: "tcp-443", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-440-450", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + // The higher similarity will be between tcp-80-100 and tcp-90-110 due to more overlapping ports + // This verifies the greedy algorithm is matching highest scores first + }, + { + name: "no matches below threshold", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-80"), + Protocol: agatypes.ProtocolUdp, // Different protocol, similarity will be low + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(8080), ToPort: awssdk.Int32(8080)}, // Different port too + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{}, + wantUnmatchedResIDs: []string{"tcp-80"}, + wantUnmatchedSDKARNs: []string{"arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-80"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &listenerSynthesizer{ + logger: logr.Discard(), + } + + // Run the function + matchedPairs, unmatchedResListeners, unmatchedSDKListeners := s.findSimilarityMatches(tt.resListeners, tt.sdkListeners) + + // Collect the actual pairs and IDs for verification + var actualMatchedPairs []struct { + resID string + sdkARN string + } + + for _, pair := range matchedPairs { + actualMatchedPairs = append(actualMatchedPairs, struct { + resID string + sdkARN string + }{ + resID: pair.resListener.ID(), + sdkARN: awssdk.ToString(pair.sdkListener.Listener.ListenerArn), + }) + } + + var actualUnmatchedResIDs []string + for _, listener := range unmatchedResListeners { + actualUnmatchedResIDs = append(actualUnmatchedResIDs, listener.ID()) + } + + var actualUnmatchedSDKARNs []string + for _, listener := range unmatchedSDKListeners { + actualUnmatchedSDKARNs = append(actualUnmatchedSDKARNs, awssdk.ToString(listener.Listener.ListenerArn)) + } + + // Sort all slices to ensure consistent comparison + sort.Slice(actualMatchedPairs, func(i, j int) bool { + if actualMatchedPairs[i].resID != actualMatchedPairs[j].resID { + return actualMatchedPairs[i].resID < actualMatchedPairs[j].resID + } + return actualMatchedPairs[i].sdkARN < actualMatchedPairs[j].sdkARN + }) + + sort.Slice(tt.wantMatchedPairs, func(i, j int) bool { + if tt.wantMatchedPairs[i].resID != tt.wantMatchedPairs[j].resID { + return tt.wantMatchedPairs[i].resID < tt.wantMatchedPairs[j].resID + } + return tt.wantMatchedPairs[i].sdkARN < tt.wantMatchedPairs[j].sdkARN + }) + + sort.Strings(actualUnmatchedResIDs) + sort.Strings(tt.wantUnmatchedResIDs) + sort.Strings(actualUnmatchedSDKARNs) + sort.Strings(tt.wantUnmatchedSDKARNs) + + // Verify matched pairs + assert.Equal(t, len(tt.wantMatchedPairs), len(actualMatchedPairs), "matched pairs count") + for i := range tt.wantMatchedPairs { + if i < len(actualMatchedPairs) { + assert.Equal(t, tt.wantMatchedPairs[i].resID, actualMatchedPairs[i].resID, "matched pair resID at index %d", i) + assert.Equal(t, tt.wantMatchedPairs[i].sdkARN, actualMatchedPairs[i].sdkARN, "matched pair sdkARN at index %d", i) + } + } + + // Handle nil vs empty slices + if len(actualUnmatchedResIDs) == 0 && len(tt.wantUnmatchedResIDs) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched resource listeners + assert.ElementsMatch(t, tt.wantUnmatchedResIDs, actualUnmatchedResIDs, "unmatched resource listeners") + } + + if len(actualUnmatchedSDKARNs) == 0 && len(tt.wantUnmatchedSDKARNs) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched SDK listeners + assert.ElementsMatch(t, tt.wantUnmatchedSDKARNs, actualUnmatchedSDKARNs, "unmatched SDK listeners") + } + }) + } +} + +func Test_listenerSynthesizer_matchResAndSDKListeners(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + resListeners []*agamodel.Listener + sdkListeners []*ListenerResource + wantMatchedPairs []struct { + resID string + sdkARN string + } + wantUnmatchedResIDs []string + wantUnmatchedSDKARNs []string + }{ + { + name: "empty lists", + resListeners: []*agamodel.Listener{}, + sdkListeners: []*ListenerResource{}, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{}, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "exact match - should be identified in first pass", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-80", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "similarity match - should be identified in second pass", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80-90"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-85-95"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(85), ToPort: awssdk.Int32(95)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-80-90", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-85-95", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "mix of exact and similarity matches", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-443"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-8080-8090"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 8080, ToPort: 8090}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-8085-8095"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(8085), ToPort: awssdk.Int32(8095)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-80", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80", + }, + { + resID: "tcp-8080-8090", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-8085-8095", + }, + }, + wantUnmatchedResIDs: []string{"tcp-443"}, + wantUnmatchedSDKARNs: []string{}, + }, + { + name: "unmatched listeners - no similarities above threshold", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-53"), + Protocol: agatypes.ProtocolUdp, // Different protocol + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(53), ToPort: awssdk.Int32(53)}, // Different port + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{}, + wantUnmatchedResIDs: []string{"tcp-80"}, + wantUnmatchedSDKARNs: []string{"arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-53"}, + }, + { + name: "complex case with multiple matches of both types", + resListeners: []*agamodel.Listener{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-80"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "udp-53"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolUDP, + PortRanges: []agamodel.PortRange{ + {FromPort: 53, ToPort: 53}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-8080-8090"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 8080, ToPort: 8090}, + }, + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", "tcp-443"), + Spec: agamodel.ListenerSpec{ + Protocol: agamodel.ProtocolTCP, + PortRanges: []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }, + }, + }, + }, + sdkListeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-53"), + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(53), ToPort: awssdk.Int32(53)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-8085-8095"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(8085), ToPort: awssdk.Int32(8095)}, + }, + }, + }, + }, + wantMatchedPairs: []struct { + resID string + sdkARN string + }{ + { + resID: "tcp-80", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-80", + }, + { + resID: "udp-53", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-udp-53", + }, + { + resID: "tcp-8080-8090", + sdkARN: "arn:aws:globalaccelerator::123456789012:accelerator/acc123/listener/list-tcp-8085-8095", + }, + }, + wantUnmatchedResIDs: []string{"tcp-443"}, + wantUnmatchedSDKARNs: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &listenerSynthesizer{ + logger: logr.Discard(), + } + + // Run the function + matchedPairs, unmatchedResListeners, unmatchedSDKListeners := s.matchResAndSDKListeners(tt.resListeners, tt.sdkListeners) + + // Collect the actual pairs and IDs for verification + var actualMatchedPairs []struct { + resID string + sdkARN string + } + + for _, pair := range matchedPairs { + actualMatchedPairs = append(actualMatchedPairs, struct { + resID string + sdkARN string + }{ + resID: pair.resListener.ID(), + sdkARN: awssdk.ToString(pair.sdkListener.Listener.ListenerArn), + }) + } + + var actualUnmatchedResIDs []string + for _, listener := range unmatchedResListeners { + actualUnmatchedResIDs = append(actualUnmatchedResIDs, listener.ID()) + } + + var actualUnmatchedSDKARNs []string + for _, listener := range unmatchedSDKListeners { + actualUnmatchedSDKARNs = append(actualUnmatchedSDKARNs, awssdk.ToString(listener.Listener.ListenerArn)) + } + + // Sort all slices to ensure consistent comparison + sort.Slice(actualMatchedPairs, func(i, j int) bool { + if actualMatchedPairs[i].resID != actualMatchedPairs[j].resID { + return actualMatchedPairs[i].resID < actualMatchedPairs[j].resID + } + return actualMatchedPairs[i].sdkARN < actualMatchedPairs[j].sdkARN + }) + + sort.Slice(tt.wantMatchedPairs, func(i, j int) bool { + if tt.wantMatchedPairs[i].resID != tt.wantMatchedPairs[j].resID { + return tt.wantMatchedPairs[i].resID < tt.wantMatchedPairs[j].resID + } + return tt.wantMatchedPairs[i].sdkARN < tt.wantMatchedPairs[j].sdkARN + }) + + sort.Strings(actualUnmatchedResIDs) + sort.Strings(tt.wantUnmatchedResIDs) + sort.Strings(actualUnmatchedSDKARNs) + sort.Strings(tt.wantUnmatchedSDKARNs) + + // Verify matched pairs + assert.Equal(t, len(tt.wantMatchedPairs), len(actualMatchedPairs), "matched pairs count") + for i := range tt.wantMatchedPairs { + if i < len(actualMatchedPairs) { + assert.Equal(t, tt.wantMatchedPairs[i].resID, actualMatchedPairs[i].resID, "matched pair resID at index %d", i) + assert.Equal(t, tt.wantMatchedPairs[i].sdkARN, actualMatchedPairs[i].sdkARN, "matched pair sdkARN at index %d", i) + } + } + + // Handle nil vs empty slices + if len(actualUnmatchedResIDs) == 0 && len(tt.wantUnmatchedResIDs) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched resource listeners + assert.ElementsMatch(t, tt.wantUnmatchedResIDs, actualUnmatchedResIDs, "unmatched resource listeners") + } + + if len(actualUnmatchedSDKARNs) == 0 && len(tt.wantUnmatchedSDKARNs) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched SDK listeners + assert.ElementsMatch(t, tt.wantUnmatchedSDKARNs, actualUnmatchedSDKARNs, "unmatched SDK listeners") + } + }) + } +} + +func Test_listenerSynthesizer_generateSDKListenerKey(t *testing.T) { + tests := []struct { + name string + listener *ListenerResource + want string + }{ + { + name: "TCP listener with single port range", + listener: &ListenerResource{ + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/1234abcd-abcd-1234-abcd-1234abcdefgh/listener/abcdef1234"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + want: "TCP:80-80", + }, + { + name: "UDP listener with multiple port ranges - ordered", + listener: &ListenerResource{ + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/1234abcd-abcd-1234-abcd-1234abcdefgh/listener/abcdef1234"), + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + }, + }, + }, + want: "UDP:80-80,443-443", + }, + { + name: "TCP listener with multiple port ranges - unordered", + listener: &ListenerResource{ + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/1234abcd-abcd-1234-abcd-1234abcdefgh/listener/abcdef1234"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + want: "TCP:80-80,443-443", // Should be sorted + }, + { + name: "UDP listener with complex port ranges - unordered", + listener: &ListenerResource{ + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:accelerator/1234abcd-abcd-1234-abcd-1234abcdefgh/listener/abcdef1234"), + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(8000), ToPort: awssdk.Int32(8100)}, + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + want: "UDP:80-80,443-443,8000-8100", // Should be sorted by FromPort + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &listenerSynthesizer{ + logger: logr.Discard(), + } + got := s.generateSDKListenerKey(tt.listener) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/deploy/aga/stack_deployer.go b/pkg/deploy/aga/stack_deployer.go index 3bc38e13c2..f27122370c 100644 --- a/pkg/deploy/aga/stack_deployer.go +++ b/pkg/deploy/aga/stack_deployer.go @@ -32,9 +32,9 @@ func NewDefaultStackDeployer(cloud services.Cloud, config config.ControllerConfi // Create actual managers agaTaggingManager := NewDefaultTaggingManager(cloud.GlobalAccelerator(), cloud.RGT(), logger) - acceleratorManager := NewDefaultAcceleratorManager(cloud.GlobalAccelerator(), trackingProvider, agaTaggingManager, config.ExternalManagedTags, logger) + listenerManager := NewDefaultListenerManager(cloud.GlobalAccelerator(), logger) + acceleratorManager := NewDefaultAcceleratorManager(cloud.GlobalAccelerator(), trackingProvider, agaTaggingManager, listenerManager, config.ExternalManagedTags, logger) // TODO: Create other managers when they are implemented - // listenerManager := NewDefaultListenerManager(cloud.GlobalAccelerator(), trackingProvider, agaTaggingManager, config.ExternalManagedTags, logger) // endpointGroupManager := NewDefaultEndpointGroupManager(cloud.GlobalAccelerator(), trackingProvider, agaTaggingManager, config.ExternalManagedTags, logger) // endpointManager := NewDefaultEndpointManager(cloud.GlobalAccelerator(), logger) @@ -48,8 +48,8 @@ func NewDefaultStackDeployer(cloud services.Cloud, config config.ControllerConfi controllerName: controllerName, agaTaggingManager: agaTaggingManager, acceleratorManager: acceleratorManager, + listenerManager: listenerManager, // TODO: Set other managers when implemented - // listenerManager: listenerManager, // endpointGroupManager: endpointGroupManager, // endpointManager: endpointManager, } @@ -70,8 +70,8 @@ type defaultStackDeployer struct { // Actual managers agaTaggingManager TaggingManager acceleratorManager AcceleratorManager + listenerManager ListenerManager // TODO: Add other managers when implemented - // listenerManager ListenerManager // endpointGroupManager EndpointGroupManager // endpointManager EndpointManager } @@ -91,8 +91,8 @@ func (d *defaultStackDeployer) Deploy(ctx context.Context, stack core.Stack, met // Creation order: Accelerator first, then dependent resources synthesizers = append(synthesizers, NewAcceleratorSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.agaTaggingManager, d.acceleratorManager, d.logger, d.featureGates, stack), + NewListenerSynthesizer(d.cloud.GlobalAccelerator(), d.listenerManager, d.logger, stack), // TODO: Add other synthesizers when managers are implemented - // NewListenerSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.agaTaggingManager, d.listenerManager, d.logger, d.featureGates, stack), // NewEndpointGroupSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.agaTaggingManager, d.endpointGroupManager, d.logger, d.featureGates, stack), // NewEndpointSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.endpointManager, d.logger, d.featureGates, stack), ) diff --git a/pkg/deploy/aga/types.go b/pkg/deploy/aga/types.go index a6980f06d8..ae07815bf2 100644 --- a/pkg/deploy/aga/types.go +++ b/pkg/deploy/aga/types.go @@ -9,3 +9,8 @@ type AcceleratorWithTags struct { Accelerator *globalacceleratortypes.Accelerator Tags map[string]string } + +// ListenerResource represents an AWS Global Accelerator Listener. +type ListenerResource struct { + Listener *globalacceleratortypes.Listener +} diff --git a/pkg/deploy/aga/utils.go b/pkg/deploy/aga/utils.go new file mode 100644 index 0000000000..a75a187630 --- /dev/null +++ b/pkg/deploy/aga/utils.go @@ -0,0 +1,117 @@ +package aga + +import ( + "fmt" + "sort" + "strings" + + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +// SortModelPortRanges sorts port ranges by FromPort and then by ToPort +func SortModelPortRanges(portRanges []agamodel.PortRange) { + sort.Slice(portRanges, func(i, j int) bool { + if portRanges[i].FromPort != portRanges[j].FromPort { + return portRanges[i].FromPort < portRanges[j].FromPort + } + return portRanges[i].ToPort < portRanges[j].ToPort + }) +} + +// SortSDKPortRanges sorts port ranges by FromPort and then by ToPort +func SortSDKPortRanges(portRanges []agatypes.PortRange) { + sort.Slice(portRanges, func(i, j int) bool { + if *portRanges[i].FromPort != *portRanges[j].FromPort { + return *portRanges[i].FromPort < *portRanges[j].FromPort + } + return *portRanges[i].ToPort < *portRanges[j].ToPort + }) +} + +// PortRangeCompare is a generic comparison function for port ranges +// It takes two port ranges with their from and to values and compares them +// Returns -1 if the first range should sort before the second +// Returns 0 if they are equal +// Returns 1 if the first range should sort after the second +func PortRangeCompare(fromPort1, toPort1, fromPort2, toPort2 int32) int { + if fromPort1 != fromPort2 { + if fromPort1 < fromPort2 { + return -1 + } + return 1 + } + + if toPort1 != toPort2 { + if toPort1 < toPort2 { + return -1 + } + return 1 + } + + return 0 +} + +// PortRangesToSet adds all ports in a range (inclusive) to the provided portSet map +func PortRangesToSet(fromPort, toPort int32, portSet map[int32]bool) { + for port := fromPort; port <= toPort; port++ { + portSet[port] = true + } +} + +// SDKPortRangesToSet adds all ports from AWS SDK PortRange slices to the provided portSet map +func SDKPortRangesToSet(portRanges []agatypes.PortRange, portSet map[int32]bool) { + for _, pr := range portRanges { + PortRangesToSet(*pr.FromPort, *pr.ToPort, portSet) + } +} + +// ResPortRangesToSet adds all ports from resource model PortRange slices to the provided portSet map +func ResPortRangesToSet(portRanges []agamodel.PortRange, portSet map[int32]bool) { + for _, pr := range portRanges { + PortRangesToSet(pr.FromPort, pr.ToPort, portSet) + } +} + +// GetAWSInt32Value safely gets the value from an AWS SDK Int32 pointer +// Returns the value if pointer is not nil, or defaultValue otherwise +func GetAWSInt32Value(ptr *int32, defaultValue int32) int32 { + if ptr == nil { + return defaultValue + } + return *ptr +} + +// GetAWSStringValue safely gets the value from an AWS SDK String pointer +// Returns the value if pointer is not nil, or defaultValue otherwise +func GetAWSStringValue(ptr *string, defaultValue string) string { + if ptr == nil { + return defaultValue + } + return *ptr +} + +// FormatPortRangeToString converts an individual port range to string format +func FormatPortRangeToString(fromPort, toPort int32) string { + return fmt.Sprintf("%d-%d", fromPort, toPort) +} + +// ModelPortRangesToString converts model port ranges to a standardized string representation +// The port ranges should be sorted before calling this function +func ResPortRangesToString(portRanges []agamodel.PortRange) string { + var parts []string + for _, pr := range portRanges { + parts = append(parts, FormatPortRangeToString(pr.FromPort, pr.ToPort)) + } + return strings.Join(parts, ",") +} + +// SDKPortRangesToString converts SDK port ranges to a standardized string representation +// The port ranges should be sorted before calling this function +func SDKPortRangesToString(portRanges []agatypes.PortRange) string { + var parts []string + for _, pr := range portRanges { + parts = append(parts, FormatPortRangeToString(*pr.FromPort, *pr.ToPort)) + } + return strings.Join(parts, ",") +} diff --git a/pkg/deploy/aga/utils_test.go b/pkg/deploy/aga/utils_test.go new file mode 100644 index 0000000000..fae3a01bb8 --- /dev/null +++ b/pkg/deploy/aga/utils_test.go @@ -0,0 +1,446 @@ +package aga + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/stretchr/testify/assert" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +func TestSortModelPortRanges(t *testing.T) { + tests := []struct { + name string + portRanges []agamodel.PortRange + want []agamodel.PortRange + }{ + { + name: "already sorted by FromPort", + portRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + want: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + { + name: "unsorted by FromPort", + portRanges: []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + {FromPort: 80, ToPort: 80}, + }, + want: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + { + name: "same FromPort, different ToPort", + portRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 100}, + {FromPort: 80, ToPort: 90}, + }, + want: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + {FromPort: 80, ToPort: 100}, + }, + }, + { + name: "empty slice", + portRanges: []agamodel.PortRange{}, + want: []agamodel.PortRange{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + SortModelPortRanges(tt.portRanges) + assert.Equal(t, tt.want, tt.portRanges) + }) + } +} + +func TestSortSDKPortRanges(t *testing.T) { + tests := []struct { + name string + portRanges []agatypes.PortRange + want []agatypes.PortRange + }{ + { + name: "already sorted by FromPort", + portRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + }, + want: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + }, + }, + { + name: "unsorted by FromPort", + portRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + }, + want: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + }, + }, + { + name: "same FromPort, different ToPort", + portRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(100)}, + {FromPort: aws.Int32(80), ToPort: aws.Int32(90)}, + }, + want: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(90)}, + {FromPort: aws.Int32(80), ToPort: aws.Int32(100)}, + }, + }, + { + name: "empty slice", + portRanges: []agatypes.PortRange{}, + want: []agatypes.PortRange{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + SortSDKPortRanges(tt.portRanges) + assert.Equal(t, tt.want, tt.portRanges) + }) + } +} + +func TestPortRangeCompare(t *testing.T) { + tests := []struct { + name string + fromPort1 int32 + toPort1 int32 + fromPort2 int32 + toPort2 int32 + want int + }{ + { + name: "first range starts before second", + fromPort1: 80, + toPort1: 100, + fromPort2: 90, + toPort2: 110, + want: -1, + }, + { + name: "first range starts after second", + fromPort1: 90, + toPort1: 110, + fromPort2: 80, + toPort2: 100, + want: 1, + }, + { + name: "same start, first end before second", + fromPort1: 80, + toPort1: 100, + fromPort2: 80, + toPort2: 110, + want: -1, + }, + { + name: "same start, first end after second", + fromPort1: 80, + toPort1: 110, + fromPort2: 80, + toPort2: 100, + want: 1, + }, + { + name: "identical port ranges", + fromPort1: 80, + toPort1: 100, + fromPort2: 80, + toPort2: 100, + want: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := PortRangeCompare(tt.fromPort1, tt.toPort1, tt.fromPort2, tt.toPort2) + assert.Equal(t, tt.want, result) + }) + } +} + +func TestPortRangesToSet(t *testing.T) { + tests := []struct { + name string + fromPort int32 + toPort int32 + want map[int32]bool + }{ + { + name: "single port", + fromPort: 80, + toPort: 80, + want: map[int32]bool{80: true}, + }, + { + name: "port range", + fromPort: 80, + toPort: 82, + want: map[int32]bool{80: true, 81: true, 82: true}, + }, + { + name: "fromPort > toPort (invalid but shouldn't crash)", + fromPort: 82, + toPort: 80, + want: map[int32]bool{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + portSet := make(map[int32]bool) + PortRangesToSet(tt.fromPort, tt.toPort, portSet) + assert.Equal(t, tt.want, portSet) + }) + } +} + +func TestResPortRangesToSet(t *testing.T) { + tests := []struct { + name string + portRanges []agamodel.PortRange + want map[int32]bool + }{ + { + name: "single port range", + portRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 82}, + }, + want: map[int32]bool{80: true, 81: true, 82: true}, + }, + { + name: "multiple port ranges", + portRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 81}, + {FromPort: 443, ToPort: 444}, + }, + want: map[int32]bool{80: true, 81: true, 443: true, 444: true}, + }, + { + name: "empty slice", + portRanges: []agamodel.PortRange{}, + want: map[int32]bool{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + portSet := make(map[int32]bool) + ResPortRangesToSet(tt.portRanges, portSet) + assert.Equal(t, tt.want, portSet) + }) + } +} + +func TestSDKPortRangesToSet(t *testing.T) { + tests := []struct { + name string + portRanges []agatypes.PortRange + want map[int32]bool + }{ + { + name: "single port range", + portRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(82)}, + }, + want: map[int32]bool{80: true, 81: true, 82: true}, + }, + { + name: "multiple port ranges", + portRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(81)}, + {FromPort: aws.Int32(443), ToPort: aws.Int32(444)}, + }, + want: map[int32]bool{80: true, 81: true, 443: true, 444: true}, + }, + { + name: "empty slice", + portRanges: []agatypes.PortRange{}, + want: map[int32]bool{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + portSet := make(map[int32]bool) + SDKPortRangesToSet(tt.portRanges, portSet) + assert.Equal(t, tt.want, portSet) + }) + } +} + +func TestGetAWSInt32Value(t *testing.T) { + tests := []struct { + name string + ptr *int32 + defaultValue int32 + want int32 + }{ + { + name: "non-nil pointer", + ptr: aws.Int32(42), + defaultValue: 0, + want: 42, + }, + { + name: "nil pointer", + ptr: nil, + defaultValue: 42, + want: 42, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetAWSInt32Value(tt.ptr, tt.defaultValue) + assert.Equal(t, tt.want, result) + }) + } +} + +func TestGetAWSStringValue(t *testing.T) { + tests := []struct { + name string + ptr *string + defaultValue string + want string + }{ + { + name: "non-nil pointer", + ptr: aws.String("hello"), + defaultValue: "", + want: "hello", + }, + { + name: "nil pointer", + ptr: nil, + defaultValue: "default", + want: "default", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetAWSStringValue(tt.ptr, tt.defaultValue) + assert.Equal(t, tt.want, result) + }) + } +} + +func TestResPortRangesToString(t *testing.T) { + tests := []struct { + name string + portRanges []agamodel.PortRange + want string + }{ + { + name: "single port range", + portRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + want: "80-80", + }, + { + name: "multiple port ranges", + portRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + want: "80-80,443-443", + }, + { + name: "empty slice", + portRanges: []agamodel.PortRange{}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ResPortRangesToString(tt.portRanges) + assert.Equal(t, tt.want, result) + }) + } +} + +func TestSDKPortRangesToString(t *testing.T) { + tests := []struct { + name string + portRanges []agatypes.PortRange + want string + }{ + { + name: "single port range", + portRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + }, + want: "80-80", + }, + { + name: "multiple port ranges", + portRanges: []agatypes.PortRange{ + {FromPort: aws.Int32(80), ToPort: aws.Int32(80)}, + {FromPort: aws.Int32(443), ToPort: aws.Int32(443)}, + }, + want: "80-80,443-443", + }, + { + name: "empty slice", + portRanges: []agatypes.PortRange{}, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SDKPortRangesToString(tt.portRanges) + assert.Equal(t, tt.want, result) + }) + } +} + +func TestFormatPortRangeToString(t *testing.T) { + tests := []struct { + name string + fromPort int32 + toPort int32 + want string + }{ + { + name: "single port", + fromPort: 80, + toPort: 80, + want: "80-80", + }, + { + name: "port range", + fromPort: 80, + toPort: 100, + want: "80-100", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FormatPortRangeToString(tt.fromPort, tt.toPort) + assert.Equal(t, tt.want, result) + }) + } +} diff --git a/pkg/model/aga/listener.go b/pkg/model/aga/listener.go new file mode 100644 index 0000000000..f4e25986d8 --- /dev/null +++ b/pkg/model/aga/listener.go @@ -0,0 +1,111 @@ +package aga + +import ( + "context" + "github.com/pkg/errors" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +const ( + // ResourceTypeListener is the resource type for Global Accelerator Listener + ResourceTypeListener = "AWS::GlobalAccelerator::Listener" +) + +var _ core.Resource = &Listener{} + +// Listener represents an AWS Global Accelerator Listener. +type Listener struct { + core.ResourceMeta `json:"-"` + + // desired state of Listener + Spec ListenerSpec `json:"spec"` + + // observed state of Listener + // +optional + Status *ListenerStatus `json:"status,omitempty"` + + // reference to Accelerator resource + Accelerator *Accelerator `json:"-"` +} + +// NewListener constructs new Listener resource. +func NewListener(stack core.Stack, id string, spec ListenerSpec, accelerator *Accelerator) *Listener { + listener := &Listener{ + ResourceMeta: core.NewResourceMeta(stack, ResourceTypeListener, id), + Spec: spec, + Status: nil, + Accelerator: accelerator, + } + stack.AddResource(listener) + listener.registerDependencies(stack) + return listener +} + +// SetStatus sets the Listener's status +func (l *Listener) SetStatus(status ListenerStatus) { + l.Status = &status +} + +// ListenerARN returns The Amazon Resource Name (ARN) of the listener. +func (l *Listener) ListenerARN() core.StringToken { + return core.NewResourceFieldStringToken(l, "status/listenerARN", + func(ctx context.Context, res core.Resource, fieldPath string) (s string, err error) { + listener := res.(*Listener) + if listener.Status == nil { + return "", errors.Errorf("Listener is not fulfilled yet: %v", listener.ID()) + } + return listener.Status.ListenerARN, nil + }, + ) +} + +// register dependencies for Listener. +func (l *Listener) registerDependencies(stack core.Stack) { + // Listener depends on its Accelerator + stack.AddDependency(l, l.Accelerator) +} + +type Protocol string + +const ( + ProtocolTCP Protocol = "TCP" + ProtocolUDP Protocol = "UDP" +) + +type ClientAffinity string + +const ( + ClientAffinitySourceIP ClientAffinity = "SOURCE_IP" + ClientAffinityNone ClientAffinity = "NONE" +) + +// PortRange defines the port range for Global Accelerator listeners. +type PortRange struct { + // FromPort is the first port in the range of ports, inclusive. + FromPort int32 `json:"fromPort"` + + // ToPort is the last port in the range of ports, inclusive. + ToPort int32 `json:"toPort"` +} + +// ListenerSpec defines the desired state of Listener +type ListenerSpec struct { + // AcceleratorARN is the ARN of the accelerator to which the listener belongs + AcceleratorARN core.StringToken `json:"acceleratorARN"` + + // Protocol is the protocol for the connections from clients to the accelerator. + Protocol Protocol `json:"protocol"` + + // PortRanges is the list of port ranges for the connections from clients to the accelerator. + PortRanges []PortRange `json:"portRanges"` + + // ClientAffinity determines how to direct all requests from a specific client to the same endpoint + // +optional + ClientAffinity ClientAffinity `json:"clientAffinity,omitempty"` +} + +// ListenerStatus defines the observed state of Listener +type ListenerStatus struct { + // ListenerARN is the Amazon Resource Name (ARN) of the listener. + ListenerARN string `json:"listenerARN"` +} diff --git a/pkg/shared_utils/aga_utils_test.go b/pkg/shared_utils/aga_utils_test.go deleted file mode 100644 index 34fd686dd2..0000000000 --- a/pkg/shared_utils/aga_utils_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package shared_utils - -import ( - "testing" - - "github.com/spf13/pflag" - "github.com/stretchr/testify/assert" - "sigs.k8s.io/aws-load-balancer-controller/pkg/config" -) - -type mockFeatureGates struct { - enabled bool -} - -func (m *mockFeatureGates) Enabled(feature config.Feature) bool { - if feature == config.AGAController { - return m.enabled - } - return false -} - -func (m *mockFeatureGates) Enable(feature config.Feature) {} -func (m *mockFeatureGates) Disable(feature config.Feature) {} -func (m *mockFeatureGates) BindFlags(fs *pflag.FlagSet) {} - -func Test_IsAGAControllerEnabled(t *testing.T) { - tests := []struct { - name string - featureGate bool - region string - expectResult bool - }{ - { - name: "feature gate disabled", - featureGate: false, - region: "us-west-2", - expectResult: false, - }, - { - name: "feature gate enabled, standard region", - featureGate: true, - region: "us-west-2", - expectResult: true, - }, - { - name: "feature gate enabled, eu region", - featureGate: true, - region: "eu-west-1", - expectResult: true, - }, - { - name: "feature gate enabled, China region", - featureGate: true, - region: "cn-north-1", - expectResult: false, - }, - { - name: "feature gate enabled, GovCloud region", - featureGate: true, - region: "us-gov-west-1", - expectResult: false, - }, - { - name: "feature gate enabled, ap region", - featureGate: true, - region: "ap-southeast-1", - expectResult: true, - }, - { - name: "feature gate enabled, iso region", - featureGate: true, - region: "us-isof-east-1", - expectResult: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockFG := &mockFeatureGates{enabled: tt.featureGate} - result := IsAGAControllerEnabled(mockFG, tt.region) - assert.Equal(t, tt.expectResult, result) - }) - } -} diff --git a/scripts/gen_mocks.sh b/scripts/gen_mocks.sh index 5f0c871154..413cbec8d4 100755 --- a/scripts/gen_mocks.sh +++ b/scripts/gen_mocks.sh @@ -26,6 +26,7 @@ $MOCKGEN -package=networking -destination=./pkg/networking/vpc_info_provider_moc $MOCKGEN -package=networking -destination=./pkg/networking/backend_sg_provider_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking BackendSGProvider $MOCKGEN -package=networking -destination=./pkg/networking/security_group_resolver_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking SecurityGroupResolver $MOCKGEN -package=aga -destination=./pkg/deploy/aga/accelerator_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga AcceleratorManager +$MOCKGEN -package=aga -destination=./pkg/deploy/aga/listener_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga ListenerManager $MOCKGEN -package=aga -destination=./pkg/deploy/aga/tagging_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga TaggingManager $MOCKGEN -package=certs -destination=./pkg/certs/cert_discovery_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/certs CertDiscovery $MOCKGEN -package=elbv2 -destination=./pkg/deploy/elbv2/tagging_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/elbv2 TaggingManager diff --git a/webhooks/aga/globalaccelerator_validator.go b/webhooks/aga/globalaccelerator_validator.go new file mode 100644 index 0000000000..7adc218ccd --- /dev/null +++ b/webhooks/aga/globalaccelerator_validator.go @@ -0,0 +1,124 @@ +package aga + +import ( + "context" + + "github.com/go-logr/logr" + "github.com/pkg/errors" + "k8s.io/apimachinery/pkg/runtime" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + lbcmetrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/lbc" + "sigs.k8s.io/aws-load-balancer-controller/pkg/webhook" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" +) + +const ( + apiPathValidateAGAGlobalAccelerator = "/validate-aga-k8s-aws-v1beta1-globalaccelerator" +) + +// NewGlobalAcceleratorValidator returns a validator for GlobalAccelerator API. +func NewGlobalAcceleratorValidator(logger logr.Logger, metricsCollector lbcmetrics.MetricCollector) *globalAcceleratorValidator { + return &globalAcceleratorValidator{ + logger: logger, + metricsCollector: metricsCollector, + } +} + +var _ webhook.Validator = &globalAcceleratorValidator{} + +type globalAcceleratorValidator struct { + logger logr.Logger + metricsCollector lbcmetrics.MetricCollector +} + +func (v *globalAcceleratorValidator) Prototype(req admission.Request) (runtime.Object, error) { + return &agaapi.GlobalAccelerator{}, nil +} + +func (v *globalAcceleratorValidator) ValidateCreate(_ context.Context, obj runtime.Object) error { + ga := obj.(*agaapi.GlobalAccelerator) + + if err := v.checkForOverlappingPortRanges(ga); err != nil { + v.metricsCollector.ObserveWebhookValidationError(apiPathValidateAGAGlobalAccelerator, "checkForOverlappingPortRanges") + return err + } + + return nil +} + +func (v *globalAcceleratorValidator) ValidateUpdate(_ context.Context, obj runtime.Object, _ runtime.Object) error { + ga := obj.(*agaapi.GlobalAccelerator) + + if err := v.checkForOverlappingPortRanges(ga); err != nil { + v.metricsCollector.ObserveWebhookValidationError(apiPathValidateAGAGlobalAccelerator, "checkForOverlappingPortRanges") + return err + } + + return nil +} + +func (v *globalAcceleratorValidator) ValidateDelete(_ context.Context, _ runtime.Object) error { + return nil +} + +// checkForOverlappingPortRanges checks if there are overlapping port ranges across all listeners +// grouped by protocol +func (v *globalAcceleratorValidator) checkForOverlappingPortRanges(ga *agaapi.GlobalAccelerator) error { + if ga.Spec.Listeners == nil { + return nil + } + + // Group all port ranges by protocol + portRangesByProtocol := make(map[agaapi.GlobalAcceleratorProtocol][]agaapi.PortRange) + + // Process all listeners and collect port ranges by protocol + for _, listener := range *ga.Spec.Listeners { + if listener.PortRanges == nil || len(*listener.PortRanges) == 0 { + continue + } + + // Skip listeners with nil protocol, we will assign protocols based on endpoints + if listener.Protocol == nil { + continue + } + + // Add all port ranges from this listener to the appropriate protocol group + portRangesByProtocol[*listener.Protocol] = append(portRangesByProtocol[*listener.Protocol], *listener.PortRanges...) + } + + // Check each protocol group for overlapping port ranges + for protocol, portRanges := range portRangesByProtocol { + if hasOverlappingRangesInSlice(portRanges) { + return errors.Errorf( + "overlapping port ranges detected for protocol %s, which is not allowed", + protocol) + } + } + + return nil +} + +// hasOverlappingRangesInSlice checks if there are any overlapping ranges within a slice of port ranges +func hasOverlappingRangesInSlice(portRanges []agaapi.PortRange) bool { + for i := 0; i < len(portRanges); i++ { + for j := i + 1; j < len(portRanges); j++ { + if portRangesOverlap(portRanges[i], portRanges[j]) { + return true + } + } + } + return false +} + +// portRangesOverlap checks if two port ranges overlap +func portRangesOverlap(rangeA agaapi.PortRange, rangeB agaapi.PortRange) bool { + // Ranges overlap if start of A is before or at end of B AND end of A is after or at start of B + return rangeA.FromPort <= rangeB.ToPort && rangeA.ToPort >= rangeB.FromPort +} + +// +kubebuilder:webhook:path=/validate-aga-k8s-aws-v1beta1-globalaccelerator,mutating=false,failurePolicy=fail,groups=aga.k8s.aws,resources=globalaccelerators,verbs=create;update,versions=v1beta1,name=vglobalaccelerator.aga.k8s.aws,sideEffects=None,matchPolicy=Equivalent,webhookVersions=v1,admissionReviewVersions=v1beta1 + +func (v *globalAcceleratorValidator) SetupWithManager(mgr ctrl.Manager) { + mgr.GetWebhookServer().Register(apiPathValidateAGAGlobalAccelerator, webhook.ValidatingWebhookForValidator(v, mgr.GetScheme())) +} diff --git a/webhooks/aga/globalaccelerator_validator_test.go b/webhooks/aga/globalaccelerator_validator_test.go new file mode 100644 index 0000000000..fcc48d4a35 --- /dev/null +++ b/webhooks/aga/globalaccelerator_validator_test.go @@ -0,0 +1,928 @@ +package aga + +import ( + "context" + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/log" + "testing" + + "github.com/stretchr/testify/assert" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + lbcmetrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/lbc" +) + +func Test_globalAcceleratorValidator_ValidateCreate(t *testing.T) { + // Protocol references for direct pointer usage + protocolTCP := agaapi.GlobalAcceleratorProtocolTCP + protocolUDP := agaapi.GlobalAcceleratorProtocolUDP + + tests := []struct { + name string + ga *agaapi.GlobalAccelerator + wantErr string + wantMetric bool + }{ + { + name: "valid global accelerator with no listeners", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: nil, + }, + }, + wantErr: "", + wantMetric: false, + }, + { + name: "valid global accelerator with single listener", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "", + wantMetric: false, + }, + { + name: "invalid global accelerator with single listener and overlapping ranges between listeners", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 8080, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "overlapping port ranges detected for protocol TCP, which is not allowed", + wantMetric: true, + }, + { + name: "valid global accelerator with multiple listeners with different protocols and non-overlapping ranges", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + { + Protocol: &protocolUDP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 443, + ToPort: 443, + }, + }, + ClientAffinity: agaapi.ClientAffinitySourceIP, + }, + }, + }, + }, + wantErr: "", + wantMetric: false, + }, + { + name: "valid global accelerator with with multiple listeners with different protocols and overlapping port ranges", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 90, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + { + Protocol: &protocolUDP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 90, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "", + wantMetric: false, + }, + { + name: "valid global accelerator with single listener having multiple non-overlapping port ranges", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8090, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "", + wantMetric: false, + }, + { + name: "valid global accelerator with multiple listeners having multiple non-overlapping port ranges", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + { + Protocol: &protocolUDP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 53, + ToPort: 53, + }, + { + FromPort: 123, + ToPort: 123, + }, + }, + ClientAffinity: agaapi.ClientAffinitySourceIP, + }, + }, + }, + }, + wantErr: "", + wantMetric: false, + }, + { + name: "valid global accelerator with multiple listeners having multiple port ranges of the same protocol but no overlap", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 8080, + ToPort: 8080, + }, + { + FromPort: 8443, + ToPort: 8443, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "", + wantMetric: false, + }, + { + name: "invalid global accelerator with multiple listeners having multiple port ranges with partial overlap", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8000, + ToPort: 9000, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 1000, + ToPort: 2000, + }, + { + FromPort: 8500, + ToPort: 8600, // Overlaps with 8000-9000 in first listener + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "overlapping port ranges detected for protocol TCP, which is not allowed", + wantMetric: true, + }, + { + name: "invalid global accelerator with wide port range overlapping with specific port", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 1000, + ToPort: 2000, // Wide range + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 1500, + ToPort: 1500, // Single port within the wide range + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "overlapping port ranges detected for protocol TCP, which is not allowed", + wantMetric: true, + }, + { + name: "valid global accelerator with touching but not overlapping port ranges", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 1000, + ToPort: 2000, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 2001, // Just after the previous range ends + ToPort: 3000, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "", + wantMetric: false, + }, + { + name: "invalid global accelerator with single listener having overlapping port ranges", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 1000, + ToPort: 2000, + }, + { + FromPort: 1500, // Overlaps with the first range + ToPort: 2500, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "overlapping port ranges detected for protocol TCP, which is not allowed", + wantMetric: true, + }, + { + name: "invalid global accelerator with single listener and overlapping port ranges within listener", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 8080, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 1000, + ToPort: 2000, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantErr: "overlapping port ranges detected for protocol TCP, which is not allowed", + wantMetric: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock dependencies + logger := logr.New(&log.NullLogSink{}) + mockMetricsCollector := lbcmetrics.NewMockCollector() + + // Create the validator + v := NewGlobalAcceleratorValidator(logger, mockMetricsCollector) + + // Run tests for both create and update + t.Run("create", func(t *testing.T) { + err := v.ValidateCreate(context.Background(), tt.ga) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + }) + + t.Run("update", func(t *testing.T) { + err := v.ValidateUpdate(context.Background(), tt.ga, &agaapi.GlobalAccelerator{}) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.NoError(t, err) + } + }) + + // Verify metrics collection + mockCollector := v.metricsCollector.(*lbcmetrics.MockCollector) + if tt.wantMetric { + // Should have 2 invocations, one for create and one for update + assert.Equal(t, 2, len(mockCollector.Invocations[lbcmetrics.MetricWebhookValidationFailure])) + } else { + assert.Equal(t, 0, len(mockCollector.Invocations[lbcmetrics.MetricWebhookValidationFailure])) + } + }) + } +} + +func Test_globalAcceleratorValidator_checkForOverlappingPortRanges(t *testing.T) { + // Protocol references for direct pointer usage + protocolTCP := agaapi.GlobalAcceleratorProtocolTCP + protocolUDP := agaapi.GlobalAcceleratorProtocolUDP + + tests := []struct { + name string + globalAccelerator *agaapi.GlobalAccelerator + wantError bool + errorContains string + }{ + { + name: "no listeners", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: nil, + }, + }, + wantError: false, + }, + { + name: "single listener", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + }, + }, + }, + }, + wantError: false, + }, + { + name: "two listeners with different protocols - no overlap", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + }, + { + Protocol: &protocolUDP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + }, + }, + }, + }, + wantError: false, + }, + { + name: "two TCP listeners with non-overlapping port ranges", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 443, + ToPort: 443, + }, + }, + }, + }, + }, + }, + wantError: false, + }, + { + name: "two TCP listeners with directly overlapping port ranges", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + }, + }, + }, + }, + wantError: true, + errorContains: "overlapping port ranges detected for protocol", + }, + { + name: "overlapping port ranges with nil protocol should be skipped", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: nil, // Will be skipped + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + }, + }, + }, + }, + wantError: false, // No error because nil protocol listeners are skipped + }, + { + name: "multiple port ranges with partial overlap", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + { + FromPort: 200, + ToPort: 300, + }, + }, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 90, + ToPort: 150, + }, + { + FromPort: 400, + ToPort: 500, + }, + }, + }, + }, + }, + }, + wantError: true, + errorContains: "overlapping port ranges detected for protocol", + }, + { + name: "port ranges with second range overlapping first", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 200, + ToPort: 300, + }, + }, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 250, + ToPort: 350, + }, + }, + }, + }, + }, + }, + wantError: true, + errorContains: "overlapping port ranges detected for protocol", + }, + { + name: "port ranges with edge case - touching but not overlapping", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 100, + ToPort: 200, + }, + }, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 201, + ToPort: 300, + }, + }, + }, + }, + }, + }, + wantError: false, + }, + { + name: "example from task description", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 78, // Likely a mistake in the example, but should be caught as overlapping with 80 + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantError: true, + errorContains: "overlapping port ranges detected for protocol", + }, + { + name: "single listener with multiple non-overlapping port ranges", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8090, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantError: false, + }, + { + name: "single listener with overlapping port ranges", + globalAccelerator: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + { + FromPort: 90, // Overlaps with previous range + ToPort: 120, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + }, + wantError: true, + errorContains: "overlapping port ranges detected for protocol", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := logr.New(&log.NullLogSink{}) + + // Create a mock metrics collector + mockMetricsCollector := lbcmetrics.NewMockCollector() + + validator := &globalAcceleratorValidator{ + logger: logger, + metricsCollector: mockMetricsCollector, + } + + err := validator.checkForOverlappingPortRanges(tt.globalAccelerator) + + if tt.wantError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_portRangesOverlap(t *testing.T) { + tests := []struct { + name string + rangeA agaapi.PortRange + rangeB agaapi.PortRange + want bool + }{ + { + name: "exactly matching ranges", + rangeA: agaapi.PortRange{ + FromPort: 80, + ToPort: 80, + }, + rangeB: agaapi.PortRange{ + FromPort: 80, + ToPort: 80, + }, + want: true, + }, + { + name: "completely non-overlapping ranges", + rangeA: agaapi.PortRange{ + FromPort: 80, + ToPort: 90, + }, + rangeB: agaapi.PortRange{ + FromPort: 100, + ToPort: 110, + }, + want: false, + }, + { + name: "A partially overlaps B (lower)", + rangeA: agaapi.PortRange{ + FromPort: 80, + ToPort: 100, + }, + rangeB: agaapi.PortRange{ + FromPort: 90, + ToPort: 110, + }, + want: true, + }, + { + name: "A partially overlaps B (higher)", + rangeA: agaapi.PortRange{ + FromPort: 90, + ToPort: 110, + }, + rangeB: agaapi.PortRange{ + FromPort: 80, + ToPort: 100, + }, + want: true, + }, + { + name: "A completely contains B", + rangeA: agaapi.PortRange{ + FromPort: 80, + ToPort: 120, + }, + rangeB: agaapi.PortRange{ + FromPort: 90, + ToPort: 110, + }, + want: true, + }, + { + name: "B completely contains A", + rangeA: agaapi.PortRange{ + FromPort: 90, + ToPort: 110, + }, + rangeB: agaapi.PortRange{ + FromPort: 80, + ToPort: 120, + }, + want: true, + }, + { + name: "Adjacent ranges (not overlapping)", + rangeA: agaapi.PortRange{ + FromPort: 80, + ToPort: 90, + }, + rangeB: agaapi.PortRange{ + FromPort: 91, + ToPort: 100, + }, + want: false, + }, + { + name: "Touching ranges (should be considered overlap)", + rangeA: agaapi.PortRange{ + FromPort: 80, + ToPort: 90, + }, + rangeB: agaapi.PortRange{ + FromPort: 90, + ToPort: 100, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := portRangesOverlap(tt.rangeA, tt.rangeB) + assert.Equal(t, tt.want, result) + }) + } +}