Skip to content

Commit 8885fcb

Browse files
authored
Escape tilde and copy tf version for local files (#1011)
1 parent 54604e0 commit 8885fcb

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

cli/local/model_cache.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ func CacheModel(modelPath string, awsClient *aws.Client) (*spec.LocalModelCache,
8585
}
8686
} else {
8787
fmt.Println(fmt.Sprintf("caching model %s ...", modelPath))
88-
err := files.CopyDirOverwrite(strings.TrimSuffix(modelPath, "/"), s.EnsureSuffix(modelDir, "/"))
88+
tfModelVersion := filepath.Base(modelPath)
89+
err := files.CopyDirOverwrite(strings.TrimSuffix(modelPath, "/"), s.EnsureSuffix(filepath.Join(modelDir, tfModelVersion), "/"))
8990
if err != nil {
9091
return nil, err
9192
}

pkg/types/spec/validations.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,6 @@ func validateTensorFlowPredictor(predictor *userconfig.Predictor, providerType t
517517
}
518518

519519
model := *predictor.Model
520-
521520
if strings.HasPrefix(model, "s3://") {
522521
model, err := cr.S3PathValidator(model)
523522
if err != nil {
@@ -543,7 +542,16 @@ func validateTensorFlowPredictor(predictor *userconfig.Predictor, providerType t
543542
}
544543

545544
configFileDir := filepath.Dir(projectFiles.GetConfigFilePath())
546-
model := files.RelToAbsPath(*predictor.Model, configFileDir)
545+
546+
var err error
547+
if strings.HasPrefix(*predictor.Model, "~/") {
548+
model, err = files.EscapeTilde(model)
549+
if err != nil {
550+
return err
551+
}
552+
} else {
553+
model = files.RelToAbsPath(*predictor.Model, configFileDir)
554+
}
547555
if strings.HasSuffix(model, ".zip") {
548556
if err := files.CheckFile(model); err != nil {
549557
return errors.Wrap(err, userconfig.ModelKey)
@@ -571,7 +579,7 @@ func validateONNXPredictor(predictor *userconfig.Predictor, providerType types.P
571579
}
572580

573581
model := *predictor.Model
574-
582+
var err error
575583
if !strings.HasSuffix(model, ".onnx") {
576584
return errors.Wrap(ErrorInvalidONNXModelPath(), userconfig.ModelKey, model)
577585
}
@@ -591,7 +599,14 @@ func validateONNXPredictor(predictor *userconfig.Predictor, providerType types.P
591599
}
592600

593601
configFileDir := filepath.Dir(projectFiles.GetConfigFilePath())
594-
model := files.RelToAbsPath(*predictor.Model, configFileDir)
602+
if strings.HasPrefix(*predictor.Model, "~/") {
603+
model, err = files.EscapeTilde(model)
604+
if err != nil {
605+
return err
606+
}
607+
} else {
608+
model = files.RelToAbsPath(*predictor.Model, configFileDir)
609+
}
595610
if err := files.CheckFile(model); err != nil {
596611
return errors.Wrap(err, userconfig.ModelKey)
597612
}

0 commit comments

Comments
 (0)