@@ -41,100 +41,120 @@ func CacheModels(apiSpec *spec.API, awsClient *aws.Client) ([]*spec.LocalModelCa
4141 modelPaths [i ] = modelResource .ModelPath
4242 }
4343
44+ uncachedModelCount := 0
45+
4446 localModelCaches := make ([]* spec.LocalModelCache , len (modelPaths ))
4547 for i , modelPath := range modelPaths {
4648 var err error
47- localModelCaches [ i ] , err = CacheModel (modelPath , awsClient )
49+ modelCacheID , err := modelCacheID (modelPath , awsClient )
4850 if err != nil {
4951 if apiSpec .Predictor .ModelPath != nil {
5052 return nil , errors .Wrap (err , apiSpec .Identify (), userconfig .PredictorKey , userconfig .ModelPathKey )
5153 }
5254 return nil , errors .Wrap (err , apiSpec .Identify (), userconfig .PredictorKey , userconfig .ModelsKey , apiSpec .Predictor .Models [i ].Name , userconfig .ModelPathKey )
5355 }
54- localModelCaches [i ].TargetPath = apiSpec .Predictor .Models [i ].Name
56+
57+ localModelCache := spec.LocalModelCache {
58+ ID : modelCacheID ,
59+ HostPath : filepath .Join (_modelCacheDir , modelCacheID ),
60+ TargetPath : apiSpec .Predictor .Models [i ].Name ,
61+ }
62+
63+ if ! files .IsFile (filepath .Join (localModelCache .HostPath , "_SUCCESS" )) {
64+ err = cacheModel (modelPath , localModelCache , awsClient )
65+ if err != nil {
66+ if apiSpec .Predictor .ModelPath != nil {
67+ return nil , errors .Wrap (err , apiSpec .Identify (), userconfig .PredictorKey , userconfig .ModelPathKey )
68+ }
69+ return nil , errors .Wrap (err , apiSpec .Identify (), userconfig .PredictorKey , userconfig .ModelsKey , apiSpec .Predictor .Models [i ].Name , userconfig .ModelPathKey )
70+ }
71+ uncachedModelCount ++
72+ }
73+
74+ localModelCaches [i ] = & localModelCache
5575 }
5676
57- if len ( localModelCaches ) > 0 {
77+ if uncachedModelCount > 0 {
5878 fmt .Println ("" ) // Newline to group all of the model information
5979 }
6080
6181 return localModelCaches , nil
6282}
6383
64- func CacheModel (modelPath string , awsClient * aws.Client ) (* spec.LocalModelCache , error ) {
65- localModelCache := spec.LocalModelCache {}
66- var awsClientForBucket * aws.Client
67- var err error
68-
84+ func modelCacheID (modelPath string , awsClient * aws.Client ) (string , error ) {
6985 if strings .HasPrefix (modelPath , "s3://" ) {
70- awsClientForBucket , err = aws .NewFromClientS3Path (modelPath , awsClient )
86+ awsClientForBucket , err : = aws .NewFromClientS3Path (modelPath , awsClient )
7187 if err != nil {
72- return nil , err
88+ return "" , err
7389 }
7490 bucket , prefix , err := aws .SplitS3Path (modelPath )
7591 if err != nil {
76- return nil , err
92+ return "" , err
7793 }
7894 hash , err := awsClientForBucket .HashS3Dir (bucket , prefix , nil )
7995 if err != nil {
80- return nil , err
81- }
82- localModelCache .ID = hash
83- } else {
84- hash , err := localModelHash (modelPath )
85- if err != nil {
86- return nil , err
96+ return "" , err
8797 }
88- localModelCache . ID = hash
98+ return hash , nil
8999 }
90100
91- modelDir := filepath .Join (_modelCacheDir , localModelCache .ID )
101+ hash , err := localModelHash (modelPath )
102+ if err != nil {
103+ return "" , err
104+ }
105+ return hash , nil
106+ }
107+
108+ func cacheModel (modelPath string , localModelCache spec.LocalModelCache , awsClient * aws.Client ) error {
109+ modelDir := localModelCache .HostPath
92110
93111 if files .IsFile (filepath .Join (modelDir , "_SUCCESS" )) {
94- localModelCache .HostPath = modelDir
95- return & localModelCache , nil
112+ return nil
96113 }
97114
98- err = ResetModelCacheDir (modelDir )
115+ err : = ResetModelCacheDir (modelDir )
99116 if err != nil {
100- return nil , err
117+ return err
101118 }
102119
103120 if strings .HasPrefix (modelPath , "s3://" ) {
104- err := downloadModel (modelPath , modelDir , awsClientForBucket )
121+ awsClientForBucket , err := aws .NewFromClientS3Path (modelPath , awsClient )
122+ if err != nil {
123+ return err
124+ }
125+
126+ err = downloadModel (modelPath , modelDir , awsClientForBucket )
105127 if err != nil {
106- return nil , err
128+ return err
107129 }
108130 } else {
109131 if strings .HasSuffix (modelPath , ".zip" ) {
110132 err := unzipAndValidate (modelPath , modelPath , modelDir )
111133 if err != nil {
112- return nil , err
134+ return err
113135 }
114136 } else if strings .HasSuffix (modelPath , ".onnx" ) {
115137 fmt .Println (fmt .Sprintf ("○ caching model %s ..." , modelPath ))
116138 err := files .CopyFileOverwrite (modelPath , filepath .Join (modelDir , filepath .Base (modelPath )))
117139 if err != nil {
118- return nil , err
140+ return err
119141 }
120142 } else {
121143 fmt .Println (fmt .Sprintf ("○ caching model %s ..." , modelPath ))
122144 tfModelVersion := filepath .Base (modelPath )
123145 err := files .CopyDirOverwrite (strings .TrimSuffix (modelPath , "/" ), s .EnsureSuffix (filepath .Join (modelDir , tfModelVersion ), "/" ))
124146 if err != nil {
125- return nil , err
147+ return err
126148 }
127149 }
128150 }
129151
130152 err = files .MakeEmptyFile (filepath .Join (modelDir , "_SUCCESS" ))
131153 if err != nil {
132- return nil , err
154+ return err
133155 }
134156
135- localModelCache .HostPath = modelDir
136-
137- return & localModelCache , nil
157+ return nil
138158}
139159
140160func DeleteCachedModels (apiName string , modelsToDelete []string ) error {
0 commit comments