Skip to content

Commit d0b9d23

Browse files
tomasdembellidylanratcliffe
authored andcommitted
add apigateway model adapter
1 parent 2bdaf6b commit d0b9d23

File tree

9 files changed

+319
-1
lines changed

9 files changed

+319
-1
lines changed

adapters/apigateway-model.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package adapters
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strings"
7+
8+
"github.com/aws/aws-sdk-go-v2/service/apigateway"
9+
"github.com/aws/aws-sdk-go-v2/service/apigateway/types"
10+
"github.com/overmindtech/aws-source/adapterhelpers"
11+
"github.com/overmindtech/sdp-go"
12+
)
13+
14+
func convertGetModelOutputToModel(output *apigateway.GetModelOutput) *types.Model {
15+
return &types.Model{
16+
Id: output.Id,
17+
Name: output.Name,
18+
Description: output.Description,
19+
Schema: output.Schema,
20+
ContentType: output.ContentType,
21+
}
22+
}
23+
24+
func modelOutputMapper(query, scope string, awsItem *types.Model) (*sdp.Item, error) {
25+
attributes, err := adapterhelpers.ToAttributesWithExclude(awsItem, "tags")
26+
if err != nil {
27+
return nil, err
28+
}
29+
30+
restAPIID := strings.Split(query, "/")[0]
31+
32+
err = attributes.Set("UniqueAttribute", fmt.Sprintf("%s/%s", restAPIID, *awsItem.Name))
33+
34+
item := sdp.Item{
35+
Type: "apigateway-model",
36+
UniqueAttribute: "Name",
37+
Attributes: attributes,
38+
Scope: scope,
39+
}
40+
41+
item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{
42+
Query: &sdp.Query{
43+
Type: "apigateway-rest-api",
44+
Method: sdp.QueryMethod_GET,
45+
Query: restAPIID,
46+
Scope: scope,
47+
},
48+
BlastPropagation: &sdp.BlastPropagation{
49+
// They are tightly coupled, so we need to propagate the blast to the linked item
50+
In: true,
51+
Out: true,
52+
},
53+
})
54+
55+
return &item, nil
56+
}
57+
58+
func NewAPIGatewayModelAdapter(client *apigateway.Client, accountID string, region string) *adapterhelpers.GetListAdapter[*types.Model, *apigateway.Client, *apigateway.Options] {
59+
return &adapterhelpers.GetListAdapter[*types.Model, *apigateway.Client, *apigateway.Options]{
60+
ItemType: "apigateway-model",
61+
Client: client,
62+
AccountID: accountID,
63+
Region: region,
64+
AdapterMetadata: modelAdapterMetadata,
65+
GetFunc: func(ctx context.Context, client *apigateway.Client, scope, query string) (*types.Model, error) {
66+
f := strings.Split(query, "/")
67+
if len(f) != 2 {
68+
return nil, &sdp.QueryError{
69+
ErrorType: sdp.QueryError_NOTFOUND,
70+
ErrorString: fmt.Sprintf("query must be in the format of: the rest-api-id/model-name, but found: %s", query),
71+
}
72+
}
73+
out, err := client.GetModel(ctx, &apigateway.GetModelInput{
74+
RestApiId: &f[0],
75+
ModelName: &f[1],
76+
})
77+
if err != nil {
78+
return nil, err
79+
}
80+
return convertGetModelOutputToModel(out), nil
81+
},
82+
DisableList: true,
83+
SearchFunc: func(ctx context.Context, client *apigateway.Client, scope string, query string) ([]*types.Model, error) {
84+
out, err := client.GetModels(ctx, &apigateway.GetModelsInput{
85+
RestApiId: &query,
86+
})
87+
if err != nil {
88+
return nil, err
89+
}
90+
91+
var items []*types.Model
92+
for _, model := range out.Items {
93+
items = append(items, &model)
94+
}
95+
96+
return items, nil
97+
},
98+
ItemMapper: func(query, scope string, awsItem *types.Model) (*sdp.Item, error) {
99+
return modelOutputMapper(query, scope, awsItem)
100+
},
101+
}
102+
}
103+
104+
var modelAdapterMetadata = Metadata.Register(&sdp.AdapterMetadata{
105+
Type: "apigateway-model",
106+
DescriptiveName: "API Gateway Model",
107+
Category: sdp.AdapterCategory_ADAPTER_CATEGORY_CONFIGURATION,
108+
SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{
109+
Get: true,
110+
Search: true,
111+
GetDescription: "Get an API Gateway Model by its rest API ID and model name: rest-api-id/model-name",
112+
SearchDescription: "Search for API Gateway Models by their rest API ID: rest-api-id",
113+
},
114+
})

adapters/apigateway-model_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package adapters
2+
3+
import (
4+
"github.com/overmindtech/sdp-go"
5+
"testing"
6+
"time"
7+
8+
"github.com/aws/aws-sdk-go-v2/aws"
9+
"github.com/aws/aws-sdk-go-v2/service/apigateway"
10+
"github.com/aws/aws-sdk-go-v2/service/apigateway/types"
11+
"github.com/overmindtech/aws-source/adapterhelpers"
12+
)
13+
14+
func TestModelOutputMapper(t *testing.T) {
15+
awsItem := &types.Model{
16+
Id: aws.String("model-id"),
17+
Name: aws.String("model-name"),
18+
Description: aws.String("description"),
19+
Schema: aws.String("{\"type\": \"object\"}"),
20+
ContentType: aws.String("application/json"),
21+
}
22+
23+
item, err := modelOutputMapper("rest-api-id/model-name", "scope", awsItem)
24+
if err != nil {
25+
t.Fatalf("unexpected error: %v", err)
26+
}
27+
28+
if err := item.Validate(); err != nil {
29+
t.Error(err)
30+
}
31+
32+
tests := adapterhelpers.QueryTests{
33+
{
34+
ExpectedType: "apigateway-rest-api",
35+
ExpectedMethod: sdp.QueryMethod_GET,
36+
ExpectedQuery: "rest-api-id",
37+
ExpectedScope: "scope",
38+
},
39+
}
40+
41+
tests.Execute(t, item)
42+
}
43+
44+
func TestNewAPIGatewayModelAdapter(t *testing.T) {
45+
config, account, region := adapterhelpers.GetAutoConfig(t)
46+
47+
client := apigateway.NewFromConfig(config)
48+
49+
adapter := NewAPIGatewayModelAdapter(client, account, region)
50+
51+
test := adapterhelpers.E2ETest{
52+
Adapter: adapter,
53+
Timeout: 10 * time.Second,
54+
SkipList: true,
55+
}
56+
57+
test.Run(t)
58+
}

adapters/apigateway-stage.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,20 @@ func stageOutputMapper(query, scope string, awsItem *types.Stage) (*sdp.Item, er
7070
})
7171
}
7272

73+
item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{
74+
Query: &sdp.Query{
75+
Type: "apigateway-rest-api",
76+
Method: sdp.QueryMethod_GET,
77+
Query: restAPIID,
78+
Scope: scope,
79+
},
80+
BlastPropagation: &sdp.BlastPropagation{
81+
// They are tightly coupled, so we need to propagate the blast to the linked item
82+
In: true,
83+
Out: true,
84+
},
85+
})
86+
7387
return &item, nil
7488
}
7589

adapters/apigateway-stage_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ func TestStageOutputMapper(t *testing.T) {
5050
ExpectedQuery: "rest-api-id/deployment-id",
5151
ExpectedScope: "scope",
5252
},
53+
{
54+
ExpectedType: "apigateway-rest-api",
55+
ExpectedMethod: sdp.QueryMethod_GET,
56+
ExpectedQuery: "rest-api-id",
57+
ExpectedScope: "scope",
58+
},
5359
}
5460

5561
tests.Execute(t, item)

adapters/integration/apigateway/apigateway_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ func APIGateway(t *testing.T) {
9494
t.Fatalf("failed to validate APIGateway stage adapter: %v", err)
9595
}
9696

97+
modelSource := adapters.NewAPIGatewayModelAdapter(testClient, accountID, testAWSConfig.Region)
98+
99+
err = modelSource.Validate()
100+
if err != nil {
101+
t.Fatalf("failed to validate APIGateway model adapter: %v", err)
102+
}
103+
97104
// Tests ----------------------------------------------------------------------------------------------------------
98105

99106
scope := adapterhelpers.FormatScope(accountID, testAWSConfig.Region)
@@ -536,5 +543,50 @@ func APIGateway(t *testing.T) {
536543
t.Fatalf("expected stage ID %s, got %s", stageID, stageIDFromSearch)
537544
}
538545

546+
// Search models by restApiID
547+
models, err := modelSource.Search(ctx, scope, restApiID, true)
548+
if err != nil {
549+
t.Fatalf("failed to search APIGateway models: %v", err)
550+
}
551+
552+
if len(models) == 0 {
553+
t.Fatalf("no models found")
554+
}
555+
556+
modelUniqueAttribute := models[0].GetUniqueAttribute()
557+
558+
modelID, err := integration.GetUniqueAttributeValueBySignificantAttribute(
559+
modelUniqueAttribute,
560+
"Name",
561+
"testModel",
562+
models,
563+
true,
564+
)
565+
if err != nil {
566+
t.Fatalf("failed to get model ID: %v", err)
567+
}
568+
569+
// Get model
570+
query = fmt.Sprintf("%s/testModel", restApiID)
571+
model, err := modelSource.Get(ctx, scope, query, true)
572+
if err != nil {
573+
t.Fatalf("failed to get APIGateway model: %v", err)
574+
}
575+
576+
modelIDFromGet, err := integration.GetUniqueAttributeValueBySignificantAttribute(
577+
modelUniqueAttribute,
578+
"Name",
579+
"testModel",
580+
[]*sdp.Item{model},
581+
true,
582+
)
583+
if err != nil {
584+
t.Fatalf("failed to get model ID from get: %v", err)
585+
}
586+
587+
if modelID != modelIDFromGet {
588+
t.Fatalf("expected model ID %s, got %s", modelID, modelIDFromGet)
589+
}
590+
539591
t.Log("APIGateway integration test completed")
540592
}

adapters/integration/apigateway/create.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,34 @@ func createStage(ctx context.Context, logger *slog.Logger, client *apigateway.Cl
284284

285285
return nil
286286
}
287+
288+
func createModel(ctx context.Context, logger *slog.Logger, client *apigateway.Client, restAPIID string) error {
289+
modelName := "testModel"
290+
291+
// check if a model with the same testID already exists
292+
err := findModelByName(ctx, client, restAPIID, modelName)
293+
if err != nil {
294+
if errors.As(err, new(integration.NotFoundError)) {
295+
logger.InfoContext(ctx, "Creating model")
296+
} else {
297+
return err
298+
}
299+
}
300+
301+
if err == nil {
302+
logger.InfoContext(ctx, "Model already exists")
303+
return nil
304+
}
305+
306+
_, err = client.CreateModel(ctx, &apigateway.CreateModelInput{
307+
RestApiId: &restAPIID,
308+
Name: &modelName,
309+
Schema: adapterhelpers.PtrString("{}"),
310+
ContentType: adapterhelpers.PtrString("application/json"),
311+
})
312+
if err != nil {
313+
return err
314+
}
315+
316+
return nil
317+
}

adapters/integration/apigateway/find.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,45 @@ func findStageByName(ctx context.Context, client *apigateway.Client, restAPIID,
184184
name,
185185
))
186186
}
187+
188+
return err
189+
}
190+
191+
if result == nil {
192+
return integration.NewNotFoundError(integration.ResourceName(
193+
integration.APIGateway,
194+
stageSrc,
195+
name,
196+
))
197+
}
198+
199+
return nil
200+
}
201+
202+
func findModelByName(ctx context.Context, client *apigateway.Client, restAPIID, name string) error {
203+
result, err := client.GetModel(ctx, &apigateway.GetModelInput{
204+
RestApiId: &restAPIID,
205+
ModelName: &name,
206+
})
207+
if err != nil {
208+
var notFoundErr *types.NotFoundException
209+
if errors.As(err, &notFoundErr) {
210+
return integration.NewNotFoundError(integration.ResourceName(
211+
integration.APIGateway,
212+
stageSrc,
213+
name,
214+
))
215+
}
216+
217+
return err
187218
}
188219

189220
if result == nil {
190-
return integration.NewNotFoundError(name)
221+
return integration.NewNotFoundError(integration.ResourceName(
222+
integration.APIGateway,
223+
stageSrc,
224+
name,
225+
))
191226
}
192227

193228
return nil

adapters/integration/apigateway/setup.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ const (
1818
authorizerSrc = "authorizer"
1919
deploymentSrc = "deployment"
2020
stageSrc = "stage"
21+
modelSrc = "model"
2122
)
2223

2324
func setup(ctx context.Context, logger *slog.Logger, client *apigateway.Client) error {
@@ -83,5 +84,11 @@ func setup(ctx context.Context, logger *slog.Logger, client *apigateway.Client)
8384
return err
8485
}
8586

87+
// Create Model
88+
err = createModel(ctx, logger, client, *restApiID)
89+
if err != nil {
90+
return err
91+
}
92+
8693
return nil
8794
}

proc/proc.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,7 @@ func InitializeAwsSourceEngine(ctx context.Context, ec *discovery.EngineConfig,
487487
adapters.NewAPIGatewayAuthorizerAdapter(apigatewayClient, *callerID.Account, cfg.Region),
488488
adapters.NewAPIGatewayDeploymentAdapter(apigatewayClient, *callerID.Account, cfg.Region),
489489
adapters.NewAPIGatewayStageAdapter(apigatewayClient, *callerID.Account, cfg.Region),
490+
adapters.NewAPIGatewayModelAdapter(apigatewayClient, *callerID.Account, cfg.Region),
490491

491492
// SSM
492493
adapters.NewSSMParameterAdapter(ssmClient, *callerID.Account, cfg.Region),

0 commit comments

Comments
 (0)