Skip to content

Commit 959b93d

Browse files
feat: support append mode when updating network restrictions (#4420)
* remove confusing log from restrictions get * improve update logic * update test * move append logic behind flag * remove trailing whitespace * fix lint * try another lint fix * update to use PATCH endpoint * fix lint * chore: minor refactor --------- Co-authored-by: Qiao Han <qiao@supabase.io>
1 parent 28425f5 commit 959b93d

File tree

4 files changed

+89
-13
lines changed

4 files changed

+89
-13
lines changed

cmd/restrictions.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ var (
1616

1717
dbCidrsToAllow []string
1818
bypassCidrChecks bool
19+
appendMode bool
1920

2021
restrictionsUpdateCmd = &cobra.Command{
2122
Use: "update",
2223
Short: "Update network restrictions",
2324
RunE: func(cmd *cobra.Command, args []string) error {
24-
return update.Run(cmd.Context(), flags.ProjectRef, dbCidrsToAllow, bypassCidrChecks)
25+
return update.Run(cmd.Context(), flags.ProjectRef, dbCidrsToAllow, bypassCidrChecks, appendMode)
2526
},
2627
}
2728

@@ -38,6 +39,7 @@ func init() {
3839
restrictionsCmd.PersistentFlags().StringVar(&flags.ProjectRef, "project-ref", "", "Project ref of the Supabase project.")
3940
restrictionsUpdateCmd.Flags().StringSliceVar(&dbCidrsToAllow, "db-allow-cidr", []string{}, "CIDR to allow DB connections from.")
4041
restrictionsUpdateCmd.Flags().BoolVar(&bypassCidrChecks, "bypass-cidr-checks", false, "Bypass some of the CIDR validation checks.")
42+
restrictionsUpdateCmd.Flags().BoolVar(&appendMode, "append", false, "Append to existing restrictions instead of replacing them.")
4143
restrictionsCmd.AddCommand(restrictionsGetCmd)
4244
restrictionsCmd.AddCommand(restrictionsUpdateCmd)
4345
rootCmd.AddCommand(restrictionsCmd)

internal/restrictions/get/get.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/go-errors/errors"
88
"github.com/supabase/cli/internal/utils"
9+
"github.com/supabase/cli/pkg/api"
910
)
1011

1112
func Run(ctx context.Context, projectRef string) error {
@@ -19,6 +20,6 @@ func Run(ctx context.Context, projectRef string) error {
1920

2021
fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", resp.JSON200.Config.DbAllowedCidrs)
2122
fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", resp.JSON200.Config.DbAllowedCidrsV6)
22-
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == "applied")
23+
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == api.NetworkRestrictionsResponseStatusApplied)
2324
return nil
2425
}

internal/restrictions/update/update.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ import (
1010
"github.com/supabase/cli/pkg/api"
1111
)
1212

13-
func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypassCidrChecks bool) error {
13+
// Run updates the network restriction lists using the provided CIDRs.
14+
func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypassCidrChecks bool, appendMode bool) error {
1415
// 1. separate CIDR to v4 and v6
1516
body := api.V1UpdateNetworkRestrictionsJSONRequestBody{
1617
DbAllowedCidrs: &[]string{},
@@ -31,6 +32,10 @@ func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypass
3132
}
3233
}
3334

35+
if appendMode {
36+
return ApplyPatch(ctx, projectRef, body)
37+
}
38+
3439
// 2. update restrictions
3540
resp, err := utils.GetSupabase().V1UpdateNetworkRestrictionsWithResponse(ctx, projectRef, body)
3641
if err != nil {
@@ -42,6 +47,44 @@ func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypass
4247

4348
fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", resp.JSON201.Config.DbAllowedCidrs)
4449
fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", resp.JSON201.Config.DbAllowedCidrsV6)
45-
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON201.Status == "applied")
50+
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON201.Status == api.NetworkRestrictionsResponseStatusApplied)
51+
return nil
52+
}
53+
54+
// ApplyPatch submits a network restriction payload using PATCH (add/remove mode).
55+
func ApplyPatch(ctx context.Context, projectRef string, body api.V1UpdateNetworkRestrictionsJSONRequestBody) error {
56+
patchBody := api.V1PatchNetworkRestrictionsJSONRequestBody{
57+
Add: &struct {
58+
DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"`
59+
DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"`
60+
}{
61+
DbAllowedCidrs: body.DbAllowedCidrs,
62+
DbAllowedCidrsV6: body.DbAllowedCidrsV6,
63+
},
64+
}
65+
66+
resp, err := utils.GetSupabase().V1PatchNetworkRestrictionsWithResponse(ctx, projectRef, patchBody)
67+
if err != nil {
68+
return errors.Errorf("failed to apply network restrictions: %w", err)
69+
}
70+
if resp.JSON200 == nil {
71+
return errors.New("failed to apply network restrictions: " + string(resp.Body))
72+
}
73+
74+
var allowedIPv4, allowedIPv6 []string
75+
if allowed := resp.JSON200.Config.DbAllowedCidrs; allowed != nil {
76+
for _, cidr := range *allowed {
77+
switch cidr.Type {
78+
case api.NetworkRestrictionsV2ResponseConfigDbAllowedCidrsTypeV4:
79+
allowedIPv4 = append(allowedIPv4, cidr.Address)
80+
case api.NetworkRestrictionsV2ResponseConfigDbAllowedCidrsTypeV6:
81+
allowedIPv6 = append(allowedIPv6, cidr.Address)
82+
}
83+
}
84+
}
85+
86+
fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", &allowedIPv4)
87+
fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", &allowedIPv6)
88+
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == api.NetworkRestrictionsV2ResponseStatusApplied)
4689
return nil
4790
}

internal/restrictions/update/update_test.go

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,52 @@ func TestUpdateRestrictionsCommand(t *testing.T) {
1919
token := apitest.RandomAccessToken(t)
2020
t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
2121

22-
t.Run("updates v4 and v6 CIDR", func(t *testing.T) {
22+
t.Run("replaces v4 and v6 CIDR", func(t *testing.T) {
2323
// Setup mock api
2424
defer gock.OffAll()
25+
expectedV4 := []string{"12.3.4.5/32", "1.2.3.1/24"}
26+
expectedV6 := []string{"2001:db8:abcd:0012::0/64"}
2527
gock.New(utils.DefaultApiHost).
2628
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
2729
MatchType("json").
2830
JSON(api.NetworkRestrictionsRequest{
29-
DbAllowedCidrs: &[]string{"12.3.4.5/32", "1.2.3.1/24"},
30-
DbAllowedCidrsV6: &[]string{"2001:db8:abcd:0012::0/64"},
31+
DbAllowedCidrs: &expectedV4,
32+
DbAllowedCidrsV6: &expectedV6,
3133
}).
3234
Reply(http.StatusCreated).
3335
JSON(api.NetworkRestrictionsResponse{
3436
Status: api.NetworkRestrictionsResponseStatus("applied"),
3537
})
3638
// Run test
37-
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false)
39+
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false, false)
40+
// Check error
41+
assert.NoError(t, err)
42+
assert.Empty(t, apitest.ListUnmatchedRequests())
43+
})
44+
45+
t.Run("appends v4 and v6 CIDR", func(t *testing.T) {
46+
// Setup mock api
47+
defer gock.OffAll()
48+
addV4 := []string{"12.3.4.5/32", "1.2.3.1/24"}
49+
addV6 := []string{"2001:db8:abcd:0012::0/64"}
50+
gock.New(utils.DefaultApiHost).
51+
Patch("/v1/projects/" + projectRef + "/network-restrictions").
52+
MatchType("json").
53+
JSON(api.NetworkRestrictionsPatchRequest{
54+
Add: &struct {
55+
DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"`
56+
DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"`
57+
}{
58+
DbAllowedCidrs: &addV4,
59+
DbAllowedCidrsV6: &addV6,
60+
},
61+
}).
62+
Reply(http.StatusOK).
63+
JSON(api.NetworkRestrictionsV2Response{
64+
Status: api.NetworkRestrictionsV2ResponseStatus("applied"),
65+
})
66+
// Run test
67+
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "1.2.3.1/24", "2001:db8:abcd:0012::0/64"}, false, true)
3868
// Check error
3969
assert.NoError(t, err)
4070
assert.Empty(t, apitest.ListUnmatchedRequests())
@@ -53,7 +83,7 @@ func TestUpdateRestrictionsCommand(t *testing.T) {
5383
}).
5484
ReplyError(errNetwork)
5585
// Run test
56-
err := Run(context.Background(), projectRef, []string{}, true)
86+
err := Run(context.Background(), projectRef, []string{}, true, false)
5787
// Check error
5888
assert.ErrorIs(t, err, errNetwork)
5989
assert.Empty(t, apitest.ListUnmatchedRequests())
@@ -71,7 +101,7 @@ func TestUpdateRestrictionsCommand(t *testing.T) {
71101
}).
72102
Reply(http.StatusServiceUnavailable)
73103
// Run test
74-
err := Run(context.Background(), projectRef, []string{}, true)
104+
err := Run(context.Background(), projectRef, []string{}, true, false)
75105
// Check error
76106
assert.ErrorContains(t, err, "failed to apply network restrictions:")
77107
assert.Empty(t, apitest.ListUnmatchedRequests())
@@ -99,22 +129,22 @@ func TestValidateCIDR(t *testing.T) {
99129
Status: api.NetworkRestrictionsResponseStatus("applied"),
100130
})
101131
// Run test
102-
err := Run(context.Background(), projectRef, []string{"10.0.0.0/8"}, true)
132+
err := Run(context.Background(), projectRef, []string{"10.0.0.0/8"}, true, false)
103133
// Check error
104134
assert.NoError(t, err)
105135
assert.Empty(t, apitest.ListUnmatchedRequests())
106136
})
107137

108138
t.Run("throws error on private subnet", func(t *testing.T) {
109139
// Run test
110-
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false)
140+
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false, false)
111141
// Check error
112142
assert.ErrorContains(t, err, "private IP provided: 10.0.0.0/8")
113143
})
114144

115145
t.Run("throws error on invalid subnet", func(t *testing.T) {
116146
// Run test
117-
err := Run(context.Background(), projectRef, []string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false)
147+
err := Run(context.Background(), projectRef, []string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false, false)
118148
// Check error
119149
assert.ErrorContains(t, err, "failed to parse IP: 12.3.4.5")
120150
})

0 commit comments

Comments
 (0)