Skip to content

Commit 4ad6d53

Browse files
authored
Allow local models to be loaded (#1075)
1 parent 1bf7c9b commit 4ad6d53

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

cli/local/model_cache.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ import (
3535

3636
func CacheModel(modelPath string, awsClient *aws.Client) (*spec.LocalModelCache, error) {
3737
localModelCache := spec.LocalModelCache{}
38-
39-
awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient)
40-
if err != nil {
41-
return nil, err
42-
}
38+
var awsClientForBucket *aws.Client
39+
var err error
4340

4441
if strings.HasPrefix(modelPath, "s3://") {
42+
awsClientForBucket, err = aws.NewFromClientS3Path(modelPath, awsClient)
43+
if err != nil {
44+
return nil, err
45+
}
4546
bucket, prefix, err := aws.SplitS3Path(modelPath)
4647
if err != nil {
4748
return nil, err

pkg/types/spec/validations.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -523,12 +523,12 @@ func validateTensorFlowPredictor(predictor *userconfig.Predictor, providerType t
523523

524524
model := *predictor.Model
525525

526-
awsClientForBucket, err := aws.NewFromClientS3Path(model, awsClient)
527-
if err != nil {
528-
return errors.Wrap(err, userconfig.ModelKey)
529-
}
530-
531526
if strings.HasPrefix(model, "s3://") {
527+
awsClientForBucket, err := aws.NewFromClientS3Path(model, awsClient)
528+
if err != nil {
529+
return errors.Wrap(err, userconfig.ModelKey)
530+
}
531+
532532
model, err := cr.S3PathValidator(model)
533533
if err != nil {
534534
return errors.Wrap(err, userconfig.ModelKey)
@@ -595,12 +595,12 @@ func validateONNXPredictor(predictor *userconfig.Predictor, providerType types.P
595595
return errors.Wrap(ErrorInvalidONNXModelPath(), userconfig.ModelKey, model)
596596
}
597597

598-
awsClientForBucket, err := aws.NewFromClientS3Path(model, awsClient)
599-
if err != nil {
600-
return errors.Wrap(err, userconfig.ModelKey)
601-
}
602-
603598
if strings.HasPrefix(model, "s3://") {
599+
awsClientForBucket, err := aws.NewFromClientS3Path(model, awsClient)
600+
if err != nil {
601+
return errors.Wrap(err, userconfig.ModelKey)
602+
}
603+
604604
model, err := cr.S3PathValidator(model)
605605
if err != nil {
606606
return errors.Wrap(err, userconfig.ModelKey)

0 commit comments

Comments
 (0)