Skip to content

Commit e2ff778

Browse files
authored
Path based model format detection (#251)
1 parent a1254dc commit e2ff778

File tree

8 files changed

+46
-39
lines changed

8 files changed

+46
-39
lines changed

docs/apis/apis.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Serve models at scale.
88
- kind: api
99
name: <string> # API name (required)
1010
model: <string> # path to an exported model (e.g. s3://my-bucket/model.zip)
11-
model_format: <string> # model format, must be "tensorflow" or "onnx"
11+
model_format: <string> # model format, must be "tensorflow" or "onnx" (default: "onnx" if model path ends with .onnx, "tensorflow" if model path ends with .zip)
1212
request_handler: <string> # path to the request handler implementation file, relative to the cortex root
1313
compute:
1414
min_replicas: <int> # minimum number of replicas (default: 1)

docs/apis/packaging-models.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ Reference your `model` in an API:
2222
```yaml
2323
- kind: api
2424
name: my-api
25-
model_format: tensorflow
2625
model: s3://my-bucket/model.zip
2726
```
2827
@@ -65,5 +64,4 @@ Reference your `model` in an API:
6564
- kind: api
6665
name: my-api
6766
model: s3://my-bucket/model.onnx
68-
model_format: onnx
6967
```

docs/apis/tutorial.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ Cortex requires a `cortex.yaml` file which defines a `deployment` resource. An `
2626
- kind: api
2727
name: classifier
2828
model: s3://cortex-examples/iris/tensorflow.zip
29-
model_format: tensorflow
3029
```
3130
3231
Cortex is able to read from any S3 bucket that you have access to.

examples/iris/cortex.yaml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,23 @@
44
- kind: api
55
name: tensorflow
66
model: s3://cortex-examples/iris/tensorflow.zip
7-
model_format: tensorflow
87

98
- kind: api
109
name: pytorch
1110
model: s3://cortex-examples/iris/pytorch.onnx
12-
model_format: onnx
1311
request_handler: pytorch/handler.py
1412

1513
- kind: api
1614
name: xgboost
1715
model: s3://cortex-examples/iris/xgboost.onnx
18-
model_format: onnx
1916
request_handler: xgboost/handler.py
2017

2118
- kind: api
2219
name: sklearn
2320
model: s3://cortex-examples/iris/sklearn.onnx
24-
model_format: onnx
2521
request_handler: sklearn/handler.py
2622

2723
- kind: api
2824
name: keras
2925
model: s3://cortex-examples/iris/keras.onnx
30-
model_format: onnx
3126
request_handler: keras/handler.py

examples/iris/keras/irises.json

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
{
2-
"samples": [
3-
[
4-
5.9,
5-
3.0,
6-
5.1,
7-
1.8
8-
],
9-
[
10-
5.6,
11-
2.5,
12-
3.9,
13-
1.1
14-
]
2+
"samples": [
3+
[
4+
5.9,
5+
3.0,
6+
5.1,
7+
1.8
8+
],
9+
[
10+
5.6,
11+
2.5,
12+
3.9,
13+
1.1
1514
]
15+
]
1616
}

examples/iris/xgboost/irises.json

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
{
2-
"samples": [
3-
[
4-
5.9,
5-
3.0,
6-
5.1,
7-
1.8
8-
],
9-
[
10-
5.6,
11-
2.5,
12-
3.9,
13-
1.1
14-
]
2+
"samples": [
3+
[
4+
5.9,
5+
3.0,
6+
5.1,
7+
1.8
8+
],
9+
[
10+
5.6,
11+
2.5,
12+
3.9,
13+
1.1
1514
]
15+
]
1616
}

pkg/operator/api/userconfig/apis.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,20 @@ func (api *API) Validate() error {
119119
if yaml.StartsWithEscapedAtSymbol(api.Model) {
120120
api.ModelFormat = TensorFlowModelFormat
121121
} else {
122-
if api.ModelFormat == UnknownModelFormat {
123-
return errors.Wrap(cr.ErrorMustBeDefined(), Identify(api), ModelFormatKey)
124-
}
125-
126122
if !aws.IsValidS3Path(api.Model) {
127123
return errors.Wrap(ErrorInvalidS3PathOrResourceReference(api.Model), Identify(api), ModelKey)
128124
}
129125

126+
if api.ModelFormat == UnknownModelFormat {
127+
if strings.HasSuffix(api.Model, ".onnx") {
128+
api.ModelFormat = ONNXModelFormat
129+
} else if strings.HasSuffix(api.Model, ".zip") {
130+
api.ModelFormat = TensorFlowModelFormat
131+
} else {
132+
return errors.Wrap(ErrorUnableToInferModelFormat(), Identify(api))
133+
}
134+
}
135+
130136
if ok, err := aws.IsS3PathFileExternal(api.Model); err != nil || !ok {
131137
return errors.Wrap(ErrorExternalNotFound(api.Model), Identify(api), ModelKey)
132138
}

pkg/operator/api/userconfig/errors.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ const (
7878
ErrExtraResourcesWithExternalAPIs
7979
ErrImplDoesNotExist
8080
ErrInvalidS3PathOrResourceReference
81+
ErrUnableToInferModelFormat
8182
ErrExternalNotFound
8283
)
8384

@@ -130,6 +131,7 @@ var errorKinds = []string{
130131
"err_extra_resources_with_external_apis",
131132
"err_impl_does_not_exist",
132133
"err_invalid_s3_path_or_resource_reference",
134+
"err_unable_to_infer_model_format",
133135
"err_external_not_found",
134136
}
135137

@@ -597,6 +599,13 @@ func ErrorExternalNotFound(path string) error {
597599
}
598600
}
599601

602+
func ErrorUnableToInferModelFormat() error {
603+
return Error{
604+
Kind: ErrUnableToInferModelFormat,
605+
message: "unable to infer " + ModelFormatKey + ": path to model should end in .zip for TensorFlow models, .onnx for ONNX models, or the " + ModelFormatKey + " key must be specified",
606+
}
607+
}
608+
600609
func ErrorInvalidS3PathOrResourceReference(provided string) error {
601610
s3ErrMsg := aws.ErrorInvalidS3Path(provided).Error()
602611
return Error{

0 commit comments

Comments
 (0)