Skip to content

Commit b3aa326

Browse files
authored
Improve AWS instance parsing (#2109)
1 parent 243efaf commit b3aa326

File tree

10 files changed

+9531
-3075
lines changed

10 files changed

+9531
-3075
lines changed

pkg/lib/aws/ec2.go

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,91 @@ import (
3030
s "github.com/cortexlabs/cortex/pkg/lib/strings"
3131
)
3232

33-
// aws instance types take this form: (\w+)([0-9]+)(\w*).(\w+)
34-
// the first group is the instance series, e.g. "m", "t", "g", "inf", ...
35-
// the second group is a version number for that series, e.g. 3, 4, ...
36-
// the third group is optional, and is a set of single-character "flags"
33+
var _digitsRegex = regexp.MustCompile(`[0-9]+`)
34+
35+
type ParsedInstanceType struct {
36+
Family string
37+
Generation int
38+
Capabilities strset.Set
39+
Size string
40+
}
41+
42+
// Checks weather the input is an AWS instance type
43+
func IsValidInstanceType(instanceType string) bool {
44+
return AllInstanceTypes.Has(instanceType)
45+
}
46+
47+
// Checks whether the input is an AWS instance type
48+
func CheckValidInstanceType(instanceType string) error {
49+
if !IsValidInstanceType(instanceType) {
50+
return ErrorInvalidInstanceType(instanceType)
51+
}
52+
return nil
53+
}
54+
55+
// AWS instance types take the form of: [family][generation][capabilities].[size]
56+
// the first group is the instance family, e.g. "m", "t", "g", "inf", ...
57+
// the second group is a generation number for that series, e.g. 3, 4, ...
58+
// the third group is optional, and is a set of single-character capabilities
3759
// "g" represents ARM (graviton), "a" for AMD, "n" for fast networking, "d" for fast storage, etc.
3860
// the fourth and final group (after the dot) is the instance size, e.g. "large"
39-
var _armInstanceCapabilityRegex = regexp.MustCompile(`^\w+[0-9]+\w*g\w*\.\w+$`)
61+
func ParseInstanceType(instanceType string) (ParsedInstanceType, error) {
62+
if err := CheckValidInstanceType(instanceType); err != nil {
63+
return ParsedInstanceType{}, err
64+
}
4065

41-
// instanceType is assumed to be a valid instance type that exists in AWS, e.g. g4dn.xlarge
42-
func IsARMInstance(instanceType string) bool {
43-
if strings.HasPrefix(instanceType, "a") {
44-
return true
66+
parts := strings.Split(instanceType, ".")
67+
if len(parts) != 2 {
68+
return ParsedInstanceType{}, errors.ErrorUnexpected("unexpected invalid instance type: " + instanceType)
69+
}
70+
71+
prefix := parts[0]
72+
size := parts[1]
73+
74+
digitSets := _digitsRegex.FindAllString(prefix, -1)
75+
if len(digitSets) == 0 {
76+
return ParsedInstanceType{}, errors.ErrorUnexpected("unexpected invalid instance type: " + instanceType)
4577
}
46-
return _armInstanceCapabilityRegex.MatchString(instanceType)
78+
79+
prefixParts := _digitsRegex.Split(prefix, -1)
80+
capabilitiesStr := prefixParts[len(prefixParts)-1]
81+
capabilities := strset.FromSlice(strings.Split(capabilitiesStr, ""))
82+
83+
generationStr := digitSets[len(digitSets)-1]
84+
generation, ok := s.ParseInt(generationStr)
85+
if !ok {
86+
return ParsedInstanceType{}, errors.ErrorUnexpected("unexpected invalid instance type: " + instanceType)
87+
}
88+
89+
generationIndex := strings.LastIndex(prefix, generationStr)
90+
if generationIndex == -1 {
91+
return ParsedInstanceType{}, errors.ErrorUnexpected("unexpected invalid instance type: " + instanceType)
92+
}
93+
family := prefix[:generationIndex]
94+
95+
return ParsedInstanceType{
96+
Family: family,
97+
Generation: generation,
98+
Capabilities: capabilities,
99+
Size: size,
100+
}, nil
101+
}
102+
103+
func IsARMInstance(instanceType string) (bool, error) {
104+
parsedType, err := ParseInstanceType(instanceType)
105+
if err != nil {
106+
return false, err
107+
}
108+
109+
if parsedType.Family == "a" {
110+
return true, nil
111+
}
112+
113+
if parsedType.Capabilities.Has("g") {
114+
return true, nil
115+
}
116+
117+
return false, nil
47118
}
48119

49120
func (c *Client) SpotInstancePrice(instanceType string) (float64, error) {

pkg/lib/aws/ec2_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
Copyright 2021 Cortex Labs, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package aws
18+
19+
import (
20+
"fmt"
21+
"testing"
22+
23+
"github.com/cortexlabs/cortex/pkg/lib/sets/strset"
24+
"github.com/stretchr/testify/require"
25+
)
26+
27+
func TestParseInstanceType(t *testing.T) {
28+
var testcases = []struct {
29+
instanceType string
30+
expected ParsedInstanceType
31+
}{
32+
{"t3.small", ParsedInstanceType{"t", 3, strset.New(), "small"}},
33+
{"g4dn.xlarge", ParsedInstanceType{"g", 4, strset.New("d", "n"), "xlarge"}},
34+
{"inf1.24xlarge", ParsedInstanceType{"inf", 1, strset.New(), "24xlarge"}},
35+
{"u-9tb1.metal", ParsedInstanceType{"u-9tb", 1, strset.New(), "metal"}},
36+
}
37+
38+
invalidTypes := []string{
39+
"badtype",
40+
"badtype.large",
41+
"badtype1.large",
42+
"badtype2ad.large",
43+
}
44+
45+
for _, testcase := range testcases {
46+
parsed, err := ParseInstanceType(testcase.instanceType)
47+
require.NoError(t, err)
48+
require.Equal(t, testcase.expected.Family, parsed.Family, fmt.Sprintf("unexpected family for input: %s", testcase.instanceType))
49+
require.Equal(t, testcase.expected.Generation, parsed.Generation, fmt.Sprintf("unexpected generation for input: %s", testcase.instanceType))
50+
require.ElementsMatch(t, testcase.expected.Capabilities.Slice(), parsed.Capabilities.Slice(), fmt.Sprintf("unexpected capabilities for input: %s", testcase.instanceType))
51+
require.Equal(t, testcase.expected.Size, parsed.Size, fmt.Sprintf("unexpected size for input: %s", testcase.instanceType))
52+
}
53+
54+
for _, instanceType := range invalidTypes {
55+
_, err := ParseInstanceType(instanceType)
56+
require.Error(t, err)
57+
}
58+
59+
for instanceType := range AllInstanceTypes {
60+
_, err := ParseInstanceType(instanceType)
61+
require.NoError(t, err)
62+
}
63+
}

pkg/lib/aws/elb.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,18 @@ import (
2828
// https://docs.aws.amazon.com/elasticloadbalancing/latest/network/load-balancer-target-groups.html
2929
var _nlbUnsupportedInstancePrefixes = strset.New("c1", "cc1", "cc2", "cg1", "cg2", "cr1", "g1", "g2", "hi1", "hs1", "m1", "m2", "m3", "t1")
3030

31-
// instanceType must be a valid instance type that exists in AWS, e.g. g4dn.xlarge
32-
func IsInstanceSupportedByNLB(instanceType string) bool {
33-
instancePrefix := strings.Split(instanceType, ".")[0]
34-
return !_nlbUnsupportedInstancePrefixes.Has(instancePrefix)
31+
func IsInstanceSupportedByNLB(instanceType string) (bool, error) {
32+
if err := CheckValidInstanceType(instanceType); err != nil {
33+
return false, err
34+
}
35+
36+
for prefix := range _nlbUnsupportedInstancePrefixes {
37+
if strings.HasPrefix(instanceType, prefix) {
38+
return false, nil
39+
}
40+
}
41+
42+
return true, nil
3543
}
3644

3745
// returns the the first load balancer which has all of the specified tags, or nil if no load balancers match

pkg/lib/aws/errors.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
)
2929

3030
const (
31+
ErrInvalidInstanceType = "aws.invalid_instance_type"
3132
ErrInvalidAWSCredentials = "aws.invalid_aws_credentials"
3233
ErrInvalidS3aPath = "aws.invalid_s3a_path"
3334
ErrInvalidS3Path = "aws.invalid_s3_path"
@@ -92,6 +93,13 @@ func IsErrCode(err error, errorCode string) bool {
9293
return false
9394
}
9495

96+
func ErrorInvalidInstanceType(instanceType string) error {
97+
return errors.WithStack(&errors.Error{
98+
Kind: ErrInvalidInstanceType,
99+
Message: fmt.Sprintf("%s is not an AWS instance type (e.g. m5.large is a valid instance type)", s.UserStr(instanceType)),
100+
})
101+
}
102+
95103
func ErrorInvalidAWSCredentials(awsErr error) error {
96104
awsErrMsg := errors.Message(awsErr)
97105
return errors.WithStack(&errors.Error{

0 commit comments

Comments
 (0)