Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmd/restrictions.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ var (

dbCidrsToAllow []string
bypassCidrChecks bool
appendMode bool

restrictionsUpdateCmd = &cobra.Command{
Use: "update",
Short: "Update network restrictions",
RunE: func(cmd *cobra.Command, args []string) error {
return update.Run(cmd.Context(), flags.ProjectRef, dbCidrsToAllow, bypassCidrChecks)
return update.Run(cmd.Context(), flags.ProjectRef, dbCidrsToAllow, bypassCidrChecks, appendMode)
},
}

Expand All @@ -38,6 +39,7 @@ func init() {
restrictionsCmd.PersistentFlags().StringVar(&flags.ProjectRef, "project-ref", "", "Project ref of the Supabase project.")
restrictionsUpdateCmd.Flags().StringSliceVar(&dbCidrsToAllow, "db-allow-cidr", []string{}, "CIDR to allow DB connections from.")
restrictionsUpdateCmd.Flags().BoolVar(&bypassCidrChecks, "bypass-cidr-checks", false, "Bypass some of the CIDR validation checks.")
restrictionsUpdateCmd.Flags().BoolVar(&appendMode, "append", false, "Append to existing restrictions instead of replacing them.")
restrictionsCmd.AddCommand(restrictionsGetCmd)
restrictionsCmd.AddCommand(restrictionsUpdateCmd)
rootCmd.AddCommand(restrictionsCmd)
Expand Down
3 changes: 2 additions & 1 deletion internal/restrictions/get/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/go-errors/errors"
"github.com/supabase/cli/internal/utils"
"github.com/supabase/cli/pkg/api"
)

func Run(ctx context.Context, projectRef string) error {
Expand All @@ -19,6 +20,6 @@ func Run(ctx context.Context, projectRef string) error {

fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", resp.JSON200.Config.DbAllowedCidrs)
fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", resp.JSON200.Config.DbAllowedCidrsV6)
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == "applied")
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == api.NetworkRestrictionsResponseStatusApplied)
return nil
}
47 changes: 45 additions & 2 deletions internal/restrictions/update/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import (
"github.com/supabase/cli/pkg/api"
)

func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypassCidrChecks bool) error {
// Run updates the network restriction lists using the provided CIDRs.
func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypassCidrChecks bool, appendMode bool) error {
// 1. separate CIDR to v4 and v6
body := api.V1UpdateNetworkRestrictionsJSONRequestBody{
DbAllowedCidrs: &[]string{},
Expand All @@ -31,6 +32,10 @@ func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypass
}
}

if appendMode {
return ApplyPatch(ctx, projectRef, body)
}

// 2. update restrictions
resp, err := utils.GetSupabase().V1UpdateNetworkRestrictionsWithResponse(ctx, projectRef, body)
if err != nil {
Expand All @@ -42,6 +47,44 @@ func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypass

fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", resp.JSON201.Config.DbAllowedCidrs)
fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", resp.JSON201.Config.DbAllowedCidrsV6)
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON201.Status == "applied")
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON201.Status == api.NetworkRestrictionsResponseStatusApplied)
return nil
}

// ApplyPatch submits a network restriction payload using PATCH (add/remove mode).
func ApplyPatch(ctx context.Context, projectRef string, body api.V1UpdateNetworkRestrictionsJSONRequestBody) error {
patchBody := api.V1PatchNetworkRestrictionsJSONRequestBody{
Add: &struct {
DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"`
DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"`
}{
DbAllowedCidrs: body.DbAllowedCidrs,
DbAllowedCidrsV6: body.DbAllowedCidrsV6,
},
}

resp, err := utils.GetSupabase().V1PatchNetworkRestrictionsWithResponse(ctx, projectRef, patchBody)
if err != nil {
return errors.Errorf("failed to apply network restrictions: %w", err)
}
if resp.JSON200 == nil {
return errors.New("failed to apply network restrictions: " + string(resp.Body))
}

var allowedIPv4, allowedIPv6 []string
if allowed := resp.JSON200.Config.DbAllowedCidrs; allowed != nil {
for _, cidr := range *allowed {
switch cidr.Type {
case api.NetworkRestrictionsV2ResponseConfigDbAllowedCidrsTypeV4:
allowedIPv4 = append(allowedIPv4, cidr.Address)
case api.NetworkRestrictionsV2ResponseConfigDbAllowedCidrsTypeV6:
allowedIPv6 = append(allowedIPv6, cidr.Address)
}
}
}

fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", &allowedIPv4)
fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", &allowedIPv6)
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == api.NetworkRestrictionsV2ResponseStatusApplied)
return nil
}
48 changes: 39 additions & 9 deletions internal/restrictions/update/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,52 @@ func TestUpdateRestrictionsCommand(t *testing.T) {
token := apitest.RandomAccessToken(t)
t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))

t.Run("updates v4 and v6 CIDR", func(t *testing.T) {
t.Run("replaces v4 and v6 CIDR", func(t *testing.T) {
// Setup mock api
defer gock.OffAll()
expectedV4 := []string{"12.3.4.5/32", "1.2.3.1/24"}
expectedV6 := []string{"2001:db8:abcd:0012::0/64"}
gock.New(utils.DefaultApiHost).
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
MatchType("json").
JSON(api.NetworkRestrictionsRequest{
DbAllowedCidrs: &[]string{"12.3.4.5/32", "1.2.3.1/24"},
DbAllowedCidrsV6: &[]string{"2001:db8:abcd:0012::0/64"},
DbAllowedCidrs: &expectedV4,
DbAllowedCidrsV6: &expectedV6,
}).
Reply(http.StatusCreated).
JSON(api.NetworkRestrictionsResponse{
Status: api.NetworkRestrictionsResponseStatus("applied"),
})
// Run test
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false)
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false, false)
// Check error
assert.NoError(t, err)
assert.Empty(t, apitest.ListUnmatchedRequests())
})

t.Run("appends v4 and v6 CIDR", func(t *testing.T) {
// Setup mock api
defer gock.OffAll()
addV4 := []string{"12.3.4.5/32", "1.2.3.1/24"}
addV6 := []string{"2001:db8:abcd:0012::0/64"}
gock.New(utils.DefaultApiHost).
Patch("/v1/projects/" + projectRef + "/network-restrictions").
MatchType("json").
JSON(api.NetworkRestrictionsPatchRequest{
Add: &struct {
DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"`
DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"`
}{
DbAllowedCidrs: &addV4,
DbAllowedCidrsV6: &addV6,
},
}).
Reply(http.StatusOK).
JSON(api.NetworkRestrictionsV2Response{
Status: api.NetworkRestrictionsV2ResponseStatus("applied"),
})
// Run test
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "1.2.3.1/24", "2001:db8:abcd:0012::0/64"}, false, true)
// Check error
assert.NoError(t, err)
assert.Empty(t, apitest.ListUnmatchedRequests())
Expand All @@ -53,7 +83,7 @@ func TestUpdateRestrictionsCommand(t *testing.T) {
}).
ReplyError(errNetwork)
// Run test
err := Run(context.Background(), projectRef, []string{}, true)
err := Run(context.Background(), projectRef, []string{}, true, false)
// Check error
assert.ErrorIs(t, err, errNetwork)
assert.Empty(t, apitest.ListUnmatchedRequests())
Expand All @@ -71,7 +101,7 @@ func TestUpdateRestrictionsCommand(t *testing.T) {
}).
Reply(http.StatusServiceUnavailable)
// Run test
err := Run(context.Background(), projectRef, []string{}, true)
err := Run(context.Background(), projectRef, []string{}, true, false)
// Check error
assert.ErrorContains(t, err, "failed to apply network restrictions:")
assert.Empty(t, apitest.ListUnmatchedRequests())
Expand Down Expand Up @@ -99,22 +129,22 @@ func TestValidateCIDR(t *testing.T) {
Status: api.NetworkRestrictionsResponseStatus("applied"),
})
// Run test
err := Run(context.Background(), projectRef, []string{"10.0.0.0/8"}, true)
err := Run(context.Background(), projectRef, []string{"10.0.0.0/8"}, true, false)
// Check error
assert.NoError(t, err)
assert.Empty(t, apitest.ListUnmatchedRequests())
})

t.Run("throws error on private subnet", func(t *testing.T) {
// Run test
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false)
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false, false)
// Check error
assert.ErrorContains(t, err, "private IP provided: 10.0.0.0/8")
})

t.Run("throws error on invalid subnet", func(t *testing.T) {
// Run test
err := Run(context.Background(), projectRef, []string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false)
err := Run(context.Background(), projectRef, []string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false, false)
// Check error
assert.ErrorContains(t, err, "failed to parse IP: 12.3.4.5")
})
Expand Down
Loading