@@ -29,11 +29,12 @@ class KubernetesPatchNetworkPolicyState(Enum):
2929
3030class KubernetesPatchNetworkPolicy (BaseAction ):
3131
32- def __init__ (self , namespace : str , target : str , network_enabled : bool , match_label : tuple , within_cluster : bool ):
32+ def __init__ (self , namespace : str , target : str , ingress_enabled : bool , egress_enabled : bool , match_label : tuple , within_cluster : bool ):
3333 super ().__init__ ()
3434 self .namespace = namespace
3535 self .target = target
36- self .network_enabled = network_enabled
36+ self .ingress_enabled = ingress_enabled
37+ self .egress_enabled = egress_enabled
3738 self .within_cluster = within_cluster
3839 if not isinstance (match_label , dict ) or not "key" in match_label or not "value" in match_label :
3940 raise ValueError ("match_label expected to be key-value pair." )
@@ -53,7 +54,7 @@ def setup(self, **kwargs):
5354 def update (self ) -> py_trees .common .Status : # pylint: disable=too-many-return-statements
5455 if self .current_state == KubernetesPatchNetworkPolicyState .IDLE :
5556 self .current_request = self .network_client .patch_namespaced_network_policy (self .target , body = self .get_network_policy (
56- policy_name = self .target , enable = self .network_enabled , match_label = self .match_label ), namespace = self .namespace , async_req = True )
57+ policy_name = self .target , enable_ingress = self .ingress_enabled , enable_egress = self . egress_enabled , match_label = self .match_label ), namespace = self .namespace , async_req = True )
5758 self .current_state = KubernetesPatchNetworkPolicyState .REQUEST_SENT
5859 self .feedback_message = f"Requested patching '{ self .target } ' in namespace '{ self .namespace } '" # pylint: disable= attribute-defined-outside-init
5960 return py_trees .common .Status .RUNNING
@@ -76,14 +77,16 @@ def update(self) -> py_trees.common.Status: # pylint: disable=too-many-return-s
7677 return py_trees .common .Status .FAILURE
7778 return py_trees .common .Status .FAILURE
7879
79- def get_network_policy (self , policy_name , match_label , enable ):
80+ def get_network_policy (self , policy_name , match_label , enable_ingress , enable_egress ):
8081 body = client .V1NetworkPolicy ()
8182 body .metadata = client .V1ObjectMeta (name = f"{ policy_name } " )
8283 body .spec = client .V1NetworkPolicySpec (pod_selector = client .V1LabelSelector (match_labels = {match_label ["key" ]: match_label ["value" ]}))
83- if enable :
84- body .spec .egress = [client .V1NetworkPolicyEgressRule ()]
84+ if enable_ingress :
8585 body .spec .ingress = [client .V1NetworkPolicyIngressRule ()]
8686 else :
87- body .spec .egress = []
8887 body .spec .ingress = []
88+ if enable_egress :
89+ body .spec .egress = [client .V1NetworkPolicyEgressRule ()]
90+ else :
91+ body .spec .egress = []
8992 return body
0 commit comments