@@ -1197,18 +1197,26 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
11971197
11981198 var modelWrapError func (error ) error
11991199 var modelResources []userconfig.ModelResource
1200+ var modelFileResources []userconfig.ModelResource
12001201
12011202 if hasSingleModel {
12021203 modelWrapError = func (err error ) error {
1203- return errors .Wrap (err , userconfig .ModelsPathKey )
1204+ return errors .Wrap (err , userconfig .ModelsKey , userconfig . ModelsPathKey )
12041205 }
1205- modelResources = []userconfig.ModelResource {
1206- {
1207- Name : consts .SingleModelName ,
1208- Path : * predictor .Models .Path ,
1209- },
1206+ modelResource := userconfig.ModelResource {
1207+ Name : consts .SingleModelName ,
1208+ Path : * predictor .Models .Path ,
1209+ }
1210+
1211+ if strings .HasSuffix (* predictor .Models .Path , ".onnx" ) && provider != types .LocalProviderType {
1212+ if err := validateONNXModelFilePath (* predictor .Models .Path , projectFiles .ProjectDir (), awsClient , gcpClient ); err != nil {
1213+ return modelWrapError (err )
1214+ }
1215+ modelFileResources = append (modelFileResources , modelResource )
1216+ } else {
1217+ modelResources = append (modelResources , modelResource )
1218+ * predictor .Models .Path = s .EnsureSuffix (* predictor .Models .Path , "/" )
12101219 }
1211- * predictor .Models .Path = s .EnsureSuffix (* predictor .Models .Path , "/" )
12121220 }
12131221 if hasMultiModels {
12141222 if len (predictor .Models .Paths ) > 0 {
@@ -1225,8 +1233,15 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
12251233 path .Name ,
12261234 )
12271235 }
1228- (* path ).Path = s .EnsureSuffix ((* path ).Path , "/" )
1229- modelResources = append (modelResources , * path )
1236+ if strings .HasSuffix ((* path ).Path , ".onnx" ) && provider != types .LocalProviderType {
1237+ if err := validateONNXModelFilePath ((* path ).Path , projectFiles .ProjectDir (), awsClient , gcpClient ); err != nil {
1238+ return errors .Wrap (modelWrapError (err ), path .Name )
1239+ }
1240+ modelFileResources = append (modelFileResources , * path )
1241+ } else {
1242+ (* path ).Path = s .EnsureSuffix ((* path ).Path , "/" )
1243+ modelResources = append (modelResources , * path )
1244+ }
12301245 }
12311246 }
12321247
@@ -1249,6 +1264,23 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
12491264 return modelWrapError (err )
12501265 }
12511266
1267+ for _ , modelFileResource := range modelFileResources {
1268+ s3Path := strings .HasPrefix (modelFileResource .Path , "s3://" )
1269+ gcsPath := strings .HasPrefix (modelFileResource .Path , "gs://" )
1270+ localPath := ! s3Path && ! gcsPath
1271+
1272+ * models = append (* models , CuratedModelResource {
1273+ ModelResource : & userconfig.ModelResource {
1274+ Name : modelFileResource .Name ,
1275+ Path : modelFileResource .Path ,
1276+ },
1277+ S3Path : s3Path ,
1278+ GCSPath : gcsPath ,
1279+ LocalPath : localPath ,
1280+ IsFilePath : true ,
1281+ })
1282+ }
1283+
12521284 if hasMultiModels {
12531285 for _ , model := range * models {
12541286 if model .Name == consts .SingleModelName {
@@ -1264,6 +1296,58 @@ func validateONNXPredictor(api *userconfig.API, models *[]CuratedModelResource,
12641296 return nil
12651297}
12661298
1299+ func validateONNXModelFilePath (modelPath string , projectDir string , awsClient * aws.Client , gcpClient * gcp.Client ) error {
1300+ s3Path := strings .HasPrefix (modelPath , "s3://" )
1301+ gcsPath := strings .HasPrefix (modelPath , "gs://" )
1302+ localPath := ! s3Path && ! gcsPath
1303+
1304+ if s3Path {
1305+ awsClientForBucket , err := aws .NewFromClientS3Path (modelPath , awsClient )
1306+ if err != nil {
1307+ return err
1308+ }
1309+
1310+ bucket , modelPrefix , err := aws .SplitS3Path (modelPath )
1311+ if err != nil {
1312+ return err
1313+ }
1314+
1315+ isS3File , err := awsClientForBucket .IsS3File (bucket , modelPrefix )
1316+ if err != nil {
1317+ return err
1318+ }
1319+
1320+ if ! isS3File {
1321+ return ErrorInvalidONNXModelFilePath (modelPrefix )
1322+ }
1323+ }
1324+
1325+ if gcsPath {
1326+ bucket , modelPrefix , err := gcp .SplitGCSPath (modelPath )
1327+ if err != nil {
1328+ return err
1329+ }
1330+
1331+ isGCSFile , err := gcpClient .IsGCSFile (bucket , modelPrefix )
1332+ if err != nil {
1333+ return err
1334+ }
1335+
1336+ if ! isGCSFile {
1337+ return ErrorInvalidONNXModelFilePath (modelPrefix )
1338+ }
1339+ }
1340+
1341+ if localPath {
1342+ expandedLocalPath := files .RelToAbsPath (modelPath , projectDir )
1343+ if err := files .CheckFile (expandedLocalPath ); err != nil {
1344+ return err
1345+ }
1346+ }
1347+
1348+ return nil
1349+ }
1350+
12671351func validatePythonPath (predictor * userconfig.Predictor , projectFiles ProjectFiles ) error {
12681352 if ! projectFiles .HasDir (* predictor .PythonPath ) {
12691353 return ErrorPythonPathNotFound (* predictor .PythonPath )
0 commit comments