Skip to content

Commit 48dc43c

Browse files
authored
Remove extra newline after downloading/caching models locally (#1401)
1 parent fb74b7e commit 48dc43c

File tree

1 file changed

+53
-33
lines changed

1 file changed

+53
-33
lines changed

cli/local/model_cache.go

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,100 +41,120 @@ func CacheModels(apiSpec *spec.API, awsClient *aws.Client) ([]*spec.LocalModelCa
4141
modelPaths[i] = modelResource.ModelPath
4242
}
4343

44+
uncachedModelCount := 0
45+
4446
localModelCaches := make([]*spec.LocalModelCache, len(modelPaths))
4547
for i, modelPath := range modelPaths {
4648
var err error
47-
localModelCaches[i], err = CacheModel(modelPath, awsClient)
49+
modelCacheID, err := modelCacheID(modelPath, awsClient)
4850
if err != nil {
4951
if apiSpec.Predictor.ModelPath != nil {
5052
return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelPathKey)
5153
}
5254
return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelsKey, apiSpec.Predictor.Models[i].Name, userconfig.ModelPathKey)
5355
}
54-
localModelCaches[i].TargetPath = apiSpec.Predictor.Models[i].Name
56+
57+
localModelCache := spec.LocalModelCache{
58+
ID: modelCacheID,
59+
HostPath: filepath.Join(_modelCacheDir, modelCacheID),
60+
TargetPath: apiSpec.Predictor.Models[i].Name,
61+
}
62+
63+
if !files.IsFile(filepath.Join(localModelCache.HostPath, "_SUCCESS")) {
64+
err = cacheModel(modelPath, localModelCache, awsClient)
65+
if err != nil {
66+
if apiSpec.Predictor.ModelPath != nil {
67+
return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelPathKey)
68+
}
69+
return nil, errors.Wrap(err, apiSpec.Identify(), userconfig.PredictorKey, userconfig.ModelsKey, apiSpec.Predictor.Models[i].Name, userconfig.ModelPathKey)
70+
}
71+
uncachedModelCount++
72+
}
73+
74+
localModelCaches[i] = &localModelCache
5575
}
5676

57-
if len(localModelCaches) > 0 {
77+
if uncachedModelCount > 0 {
5878
fmt.Println("") // Newline to group all of the model information
5979
}
6080

6181
return localModelCaches, nil
6282
}
6383

64-
func CacheModel(modelPath string, awsClient *aws.Client) (*spec.LocalModelCache, error) {
65-
localModelCache := spec.LocalModelCache{}
66-
var awsClientForBucket *aws.Client
67-
var err error
68-
84+
func modelCacheID(modelPath string, awsClient *aws.Client) (string, error) {
6985
if strings.HasPrefix(modelPath, "s3://") {
70-
awsClientForBucket, err = aws.NewFromClientS3Path(modelPath, awsClient)
86+
awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient)
7187
if err != nil {
72-
return nil, err
88+
return "", err
7389
}
7490
bucket, prefix, err := aws.SplitS3Path(modelPath)
7591
if err != nil {
76-
return nil, err
92+
return "", err
7793
}
7894
hash, err := awsClientForBucket.HashS3Dir(bucket, prefix, nil)
7995
if err != nil {
80-
return nil, err
81-
}
82-
localModelCache.ID = hash
83-
} else {
84-
hash, err := localModelHash(modelPath)
85-
if err != nil {
86-
return nil, err
96+
return "", err
8797
}
88-
localModelCache.ID = hash
98+
return hash, nil
8999
}
90100

91-
modelDir := filepath.Join(_modelCacheDir, localModelCache.ID)
101+
hash, err := localModelHash(modelPath)
102+
if err != nil {
103+
return "", err
104+
}
105+
return hash, nil
106+
}
107+
108+
func cacheModel(modelPath string, localModelCache spec.LocalModelCache, awsClient *aws.Client) error {
109+
modelDir := localModelCache.HostPath
92110

93111
if files.IsFile(filepath.Join(modelDir, "_SUCCESS")) {
94-
localModelCache.HostPath = modelDir
95-
return &localModelCache, nil
112+
return nil
96113
}
97114

98-
err = ResetModelCacheDir(modelDir)
115+
err := ResetModelCacheDir(modelDir)
99116
if err != nil {
100-
return nil, err
117+
return err
101118
}
102119

103120
if strings.HasPrefix(modelPath, "s3://") {
104-
err := downloadModel(modelPath, modelDir, awsClientForBucket)
121+
awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient)
122+
if err != nil {
123+
return err
124+
}
125+
126+
err = downloadModel(modelPath, modelDir, awsClientForBucket)
105127
if err != nil {
106-
return nil, err
128+
return err
107129
}
108130
} else {
109131
if strings.HasSuffix(modelPath, ".zip") {
110132
err := unzipAndValidate(modelPath, modelPath, modelDir)
111133
if err != nil {
112-
return nil, err
134+
return err
113135
}
114136
} else if strings.HasSuffix(modelPath, ".onnx") {
115137
fmt.Println(fmt.Sprintf("○ caching model %s ...", modelPath))
116138
err := files.CopyFileOverwrite(modelPath, filepath.Join(modelDir, filepath.Base(modelPath)))
117139
if err != nil {
118-
return nil, err
140+
return err
119141
}
120142
} else {
121143
fmt.Println(fmt.Sprintf("○ caching model %s ...", modelPath))
122144
tfModelVersion := filepath.Base(modelPath)
123145
err := files.CopyDirOverwrite(strings.TrimSuffix(modelPath, "/"), s.EnsureSuffix(filepath.Join(modelDir, tfModelVersion), "/"))
124146
if err != nil {
125-
return nil, err
147+
return err
126148
}
127149
}
128150
}
129151

130152
err = files.MakeEmptyFile(filepath.Join(modelDir, "_SUCCESS"))
131153
if err != nil {
132-
return nil, err
154+
return err
133155
}
134156

135-
localModelCache.HostPath = modelDir
136-
137-
return &localModelCache, nil
157+
return nil
138158
}
139159

140160
func DeleteCachedModels(apiName string, modelsToDelete []string) error {

0 commit comments

Comments
 (0)