@@ -517,18 +517,24 @@ func validateTensorFlowPredictor(predictor *userconfig.Predictor, providerType t
517517 }
518518
519519 model := * predictor .Model
520+
521+ awsClientForBucket , err := aws .NewFromClientS3Path (model , awsClient )
522+ if err != nil {
523+ return errors .Wrap (err , userconfig .ModelKey )
524+ }
525+
520526 if strings .HasPrefix (model , "s3://" ) {
521527 model , err := cr .S3PathValidator (model )
522528 if err != nil {
523529 return errors .Wrap (err , userconfig .ModelKey )
524530 }
525531
526532 if strings .HasSuffix (model , ".zip" ) {
527- if ok , err := awsClient .IsS3PathFile (model ); err != nil || ! ok {
533+ if ok , err := awsClientForBucket .IsS3PathFile (model ); err != nil || ! ok {
528534 return errors .Wrap (ErrorS3FileNotFound (model ), userconfig .ModelKey )
529535 }
530536 } else {
531- path , err := getTFServingExportFromS3Path (model , awsClient )
537+ path , err := getTFServingExportFromS3Path (model , awsClientForBucket )
532538 if err != nil {
533539 return errors .Wrap (err , userconfig .ModelKey )
534540 } else if path == "" {
@@ -584,13 +590,18 @@ func validateONNXPredictor(predictor *userconfig.Predictor, providerType types.P
584590 return errors .Wrap (ErrorInvalidONNXModelPath (), userconfig .ModelKey , model )
585591 }
586592
593+ awsClientForBucket , err := aws .NewFromClientS3Path (model , awsClient )
594+ if err != nil {
595+ return errors .Wrap (err , userconfig .ModelKey )
596+ }
597+
587598 if strings .HasPrefix (model , "s3://" ) {
588599 model , err := cr .S3PathValidator (model )
589600 if err != nil {
590601 return errors .Wrap (err , userconfig .ModelKey )
591602 }
592603
593- if ok , err := awsClient .IsS3PathFile (model ); err != nil || ! ok {
604+ if ok , err := awsClientForBucket .IsS3PathFile (model ); err != nil || ! ok {
594605 return errors .Wrap (ErrorS3FileNotFound (model ), userconfig .ModelKey )
595606 }
596607 } else {
@@ -620,8 +631,8 @@ func validateONNXPredictor(predictor *userconfig.Predictor, providerType types.P
620631 return nil
621632}
622633
623- func getTFServingExportFromS3Path (path string , awsClient * aws.Client ) (string , error ) {
624- if isValidTensorFlowS3Directory (path , awsClient ) {
634+ func getTFServingExportFromS3Path (path string , awsClientForBucket * aws.Client ) (string , error ) {
635+ if isValidTensorFlowS3Directory (path , awsClientForBucket ) {
625636 return path , nil
626637 }
627638
@@ -630,7 +641,7 @@ func getTFServingExportFromS3Path(path string, awsClient *aws.Client) (string, e
630641 return "" , err
631642 }
632643
633- objects , err := awsClient .ListS3PathDir (path , false , pointer .Int64 (1000 ))
644+ objects , err := awsClientForBucket .ListS3PathDir (path , false , pointer .Int64 (1000 ))
634645 if err != nil {
635646 return "" , err
636647 } else if len (objects ) == 0 {
@@ -652,7 +663,7 @@ func getTFServingExportFromS3Path(path string, awsClient *aws.Client) (string, e
652663 }
653664
654665 possiblePath := "s3://" + filepath .Join (bucket , filepath .Join (keyParts [:len (keyParts )- 1 ]... ))
655- if version >= highestVersion && isValidTensorFlowS3Directory (possiblePath , awsClient ) {
666+ if version >= highestVersion && isValidTensorFlowS3Directory (possiblePath , awsClientForBucket ) {
656667 highestVersion = version
657668 highestPath = possiblePath
658669 }
@@ -668,15 +679,15 @@ func getTFServingExportFromS3Path(path string, awsClient *aws.Client) (string, e
668679// - variables/
669680// - variables.index
670681// - variables.data-00000-of-00001 (there are a variable number of these files)
671- func isValidTensorFlowS3Directory (path string , awsClient * aws.Client ) bool {
672- if valid , err := awsClient .IsS3PathFile (
682+ func isValidTensorFlowS3Directory (path string , awsClientForBucket * aws.Client ) bool {
683+ if valid , err := awsClientForBucket .IsS3PathFile (
673684 aws .JoinS3Path (path , "saved_model.pb" ),
674685 aws .JoinS3Path (path , "variables/variables.index" ),
675686 ); err != nil || ! valid {
676687 return false
677688 }
678689
679- if valid , err := awsClient .IsS3PathPrefix (
690+ if valid , err := awsClientForBucket .IsS3PathPrefix (
680691 aws .JoinS3Path (path , "variables/variables.data-00000-of" ),
681692 ); err != nil || ! valid {
682693 return false
0 commit comments