@@ -23,9 +23,11 @@ import (
2323 "strings"
2424
2525 "github.com/cortexlabs/cortex/pkg/consts"
26+ "github.com/cortexlabs/cortex/pkg/lib/archive"
2627 "github.com/cortexlabs/cortex/pkg/lib/aws"
2728 "github.com/cortexlabs/cortex/pkg/lib/docker"
2829 "github.com/cortexlabs/cortex/pkg/lib/errors"
30+ "github.com/cortexlabs/cortex/pkg/lib/gcp"
2931 s "github.com/cortexlabs/cortex/pkg/lib/strings"
3032 "github.com/cortexlabs/cortex/pkg/types/spec"
3133 "github.com/cortexlabs/cortex/pkg/types/userconfig"
@@ -63,18 +65,18 @@ func (modelCaches ModelCaches) IDs() string {
6365 return strings .Join (ids , ", " )
6466}
6567
// DeployContainers selects a deployment routine based on api.Predictor.Type
// and returns that routine's error unchanged:
//   - userconfig.TensorFlowPredictorType -> deployTensorFlowContainers
//   - userconfig.ONNXPredictorType       -> deployONNXContainer
//   - any other type                     -> deployPythonContainer
//
// In this patch the "-" lines are the previous two-argument form and the "+"
// lines add a gcpClient argument that is forwarded, together with awsClient,
// to whichever routine is selected.
// NOTE(review): getAPIEnv guards both clients with nil checks, so callers
// presumably pass nil for the cloud that is not in use - confirm at call sites.
66- func DeployContainers (api * spec.API , awsClient * aws.Client ) error {
68+ func DeployContainers (api * spec.API , awsClient * aws.Client , gcpClient * gcp. Client ) error {
6769 switch api .Predictor .Type {
6870 case userconfig .TensorFlowPredictorType :
69- return deployTensorFlowContainers (api , awsClient )
71+ return deployTensorFlowContainers (api , awsClient , gcpClient )
7072 case userconfig .ONNXPredictorType :
71- return deployONNXContainer (api , awsClient )
73+ return deployONNXContainer (api , awsClient , gcpClient )
7274 default :
73- return deployPythonContainer (api , awsClient )
75+ return deployPythonContainer (api , awsClient , gcpClient )
7476 }
7577}
7678
77- func getAPIEnv (api * spec.API , awsClient * aws.Client ) []string {
79+ func getAPIEnv (api * spec.API , awsClient * aws.Client , gcpClient * gcp. Client ) []string {
7880 envs := []string {}
7981
8082 for envName , envVal := range api .Predictor .Env {
@@ -92,7 +94,6 @@ func getAPIEnv(api *spec.API, awsClient *aws.Client) []string {
9294 "CORTEX_PROCESSES_PER_REPLICA=" + s .Int32 (api .Predictor .ProcessesPerReplica ),
9395 "CORTEX_THREADS_PER_PROCESS=" + s .Int32 (api .Predictor .ThreadsPerProcess ),
9496 "CORTEX_MAX_REPLICA_CONCURRENCY=" + s .Int32 (api .Predictor .ProcessesPerReplica * api .Predictor .ThreadsPerProcess + 1024 ), // allow a queue of 1024
95- "AWS_REGION=" + awsClient .Region ,
9697 )
9798
9899 if api .Predictor .ModelPath != nil || api .Predictor .Models != nil {
@@ -105,21 +106,29 @@ func getAPIEnv(api *spec.API, awsClient *aws.Client) []string {
105106 }
106107 envs = append (envs , "CORTEX_PYTHON_PATH=" + cortexPythonPath )
107108
108- if awsAccessKeyID := awsClient .AccessKeyID (); awsAccessKeyID != nil {
109- envs = append (envs , "AWS_ACCESS_KEY_ID=" + * awsAccessKeyID )
110- }
109+ if awsClient != nil {
110+ envs = append (envs , "AWS_REGION=" + awsClient .Region )
111111
112- if awsSecretAccessKey := awsClient .SecretAccessKey (); awsSecretAccessKey != nil {
113- envs = append (envs , "AWS_SECRET_ACCESS_KEY =" + * awsSecretAccessKey )
114- }
112+ if awsAccessKeyID := awsClient .AccessKeyID (); awsAccessKeyID != nil {
113+ envs = append (envs , "AWS_ACCESS_KEY_ID =" + * awsAccessKeyID )
114+ }
115115
116- if _ , ok := api .Predictor .Env ["PYTHONDONTWRITEBYTECODE" ]; ! ok {
117- envs = append (envs , "PYTHONDONTWRITEBYTECODE=1" )
116+ if awsSecretAccessKey := awsClient .SecretAccessKey (); awsSecretAccessKey != nil {
117+ envs = append (envs , "AWS_SECRET_ACCESS_KEY=" + * awsSecretAccessKey )
118+ }
119+
120+ if _ , ok := api .Predictor .Env ["PYTHONDONTWRITEBYTECODE" ]; ! ok {
121+ envs = append (envs , "PYTHONDONTWRITEBYTECODE=1" )
122+ }
118123 }
124+ if gcpClient != nil {
125+ envs = append (envs , "GOOGLE_APPLICATION_CREDENTIALS=/var/google_key.json" )
126+ }
127+
119128 return envs
120129}
121130
122- func deployPythonContainer (api * spec.API , awsClient * aws.Client ) error {
131+ func deployPythonContainer (api * spec.API , awsClient * aws.Client , gcpClient * gcp. Client ) error {
123132 portBinding := nat.PortBinding {}
124133 if api .Networking .LocalPort != nil {
125134 portBinding .HostPort = s .Int (* api .Networking .LocalPort )
@@ -176,7 +185,7 @@ func deployPythonContainer(api *spec.API, awsClient *aws.Client) error {
176185 Image : api .Predictor .Image ,
177186 Tty : true ,
178187 Env : append (
179- getAPIEnv (api , awsClient ),
188+ getAPIEnv (api , awsClient , gcpClient ),
180189 ),
181190 ExposedPorts : nat.PortSet {
182191 _defaultPortStr + "/tcp" : struct {}{},
@@ -198,12 +207,23 @@ func deployPythonContainer(api *spec.API, awsClient *aws.Client) error {
198207 return errors .Wrap (err , api .Identify ())
199208 }
200209
210+ if gcpClient != nil {
211+ docker .CopyToContainer (containerInfo .ID , & archive.Input {
212+ Bytes : []archive.BytesInput {
213+ {
214+ Content : gcpClient .CredentialsJSON ,
215+ Dest : "/var/google_key.json" ,
216+ },
217+ },
218+ }, "/" )
219+ }
220+
201221 err = docker .MustDockerClient ().ContainerStart (context .Background (), containerInfo .ID , dockertypes.ContainerStartOptions {})
202222 if err != nil {
203223 if api .Compute .GPU == 0 {
204224 return errors .Wrap (err , api .Identify ())
205225 }
206- err := retryWithNvidiaRuntime (err , containerConfig , hostConfig )
226+ err := retryWithNvidiaRuntime (err , containerConfig , hostConfig , gcpClient )
207227 if err != nil {
208228 return errors .Wrap (err , api .Identify ())
209229 }
@@ -212,7 +232,7 @@ func deployPythonContainer(api *spec.API, awsClient *aws.Client) error {
212232 return nil
213233}
214234
215- func deployONNXContainer (api * spec.API , awsClient * aws.Client ) error {
235+ func deployONNXContainer (api * spec.API , awsClient * aws.Client , gcpClient * gcp. Client ) error {
216236 portBinding := nat.PortBinding {}
217237 if api .Networking .LocalPort != nil {
218238 portBinding .HostPort = s .Int (* api .Networking .LocalPort )
@@ -268,7 +288,7 @@ func deployONNXContainer(api *spec.API, awsClient *aws.Client) error {
268288 Image : api .Predictor .Image ,
269289 Tty : true ,
270290 Env : append (
271- getAPIEnv (api , awsClient ),
291+ getAPIEnv (api , awsClient , gcpClient ),
272292 ),
273293 ExposedPorts : nat.PortSet {
274294 _defaultPortStr + "/tcp" : struct {}{},
@@ -291,12 +311,23 @@ func deployONNXContainer(api *spec.API, awsClient *aws.Client) error {
291311 return errors .Wrap (err , api .Identify ())
292312 }
293313
314+ if gcpClient != nil {
315+ docker .CopyToContainer (containerInfo .ID , & archive.Input {
316+ Bytes : []archive.BytesInput {
317+ {
318+ Content : gcpClient .CredentialsJSON ,
319+ Dest : "/var/google_key.json" ,
320+ },
321+ },
322+ }, "/" )
323+ }
324+
294325 err = docker .MustDockerClient ().ContainerStart (context .Background (), containerInfo .ID , dockertypes.ContainerStartOptions {})
295326 if err != nil {
296327 if api .Compute .GPU == 0 {
297328 return errors .Wrap (err , api .Identify ())
298329 }
299- err := retryWithNvidiaRuntime (err , containerConfig , hostConfig )
330+ err := retryWithNvidiaRuntime (err , containerConfig , hostConfig , gcpClient )
300331 if err != nil {
301332 return errors .Wrap (err , api .Identify ())
302333 }
@@ -305,7 +336,7 @@ func deployONNXContainer(api *spec.API, awsClient *aws.Client) error {
305336 return nil
306337}
307338
308- func deployTensorFlowContainers (api * spec.API , awsClient * aws.Client ) error {
339+ func deployTensorFlowContainers (api * spec.API , awsClient * aws.Client , gcpClient * gcp. Client ) error {
309340 serveResources := container.Resources {}
310341 apiResources := container.Resources {}
311342
@@ -400,13 +431,12 @@ func deployTensorFlowContainers(api *spec.API, awsClient *aws.Client) error {
400431 }
401432 return errors .Wrap (err , api .Identify ())
402433 }
403-
404434 err = docker .MustDockerClient ().ContainerStart (context .Background (), containerCreateRequest .ID , dockertypes.ContainerStartOptions {})
405435 if err != nil {
406436 if api .Compute .GPU == 0 {
407437 return errors .Wrap (err , api .Identify ())
408438 }
409- err := retryWithNvidiaRuntime (err , serveContainerConfig , serveHostConfig )
439+ err := retryWithNvidiaRuntime (err , serveContainerConfig , serveHostConfig , nil )
410440 if err != nil {
411441 return errors .Wrap (err , api .Identify ())
412442 }
@@ -446,7 +476,7 @@ func deployTensorFlowContainers(api *spec.API, awsClient *aws.Client) error {
446476 Image : api .Predictor .Image ,
447477 Tty : true ,
448478 Env : append (
449- getAPIEnv (api , awsClient ),
479+ getAPIEnv (api , awsClient , gcpClient ),
450480 "CORTEX_TF_BASE_SERVING_PORT=" + _tfServingPortStr ,
451481 "CORTEX_TF_SERVING_HOST=" + tfContainerHost ,
452482 ),
@@ -471,6 +501,17 @@ func deployTensorFlowContainers(api *spec.API, awsClient *aws.Client) error {
471501 return errors .Wrap (err , api .Identify ())
472502 }
473503
504+ if gcpClient != nil {
505+ docker .CopyToContainer (containerCreateRequest .ID , & archive.Input {
506+ Bytes : []archive.BytesInput {
507+ {
508+ Content : gcpClient .CredentialsJSON ,
509+ Dest : "/var/google_key.json" ,
510+ },
511+ },
512+ }, "/" )
513+ }
514+
474515 err = docker .MustDockerClient ().ContainerStart (context .Background (), containerCreateRequest .ID , dockertypes.ContainerStartOptions {})
475516 if err != nil {
476517 return errors .Wrap (err , api .Identify ())
@@ -480,7 +521,7 @@ func deployTensorFlowContainers(api *spec.API, awsClient *aws.Client) error {
480521}
481522
482523// Retries deploying a container that requires a GPU using the nvidia runtime; returns the original error if it isn't relevant, nil if the retry succeeded, and a new error if a retry was attempted but failed
483- func retryWithNvidiaRuntime (err error , containerConfig * container.Config , hostConfig * container.HostConfig ) error {
524+ func retryWithNvidiaRuntime (err error , containerConfig * container.Config , hostConfig * container.HostConfig , gcpClient * gcp. Client ) error {
484525 // error message if device driver may look like 'could not select device driver "" with capabilities: [[gpu]]'
485526 if ! (strings .Contains (err .Error (), "could not select device driver" ) && strings .Contains (err .Error (), "gpu" )) {
486527 return err
@@ -494,6 +535,16 @@ func retryWithNvidiaRuntime(err error, containerConfig *container.Config, hostCo
494535 if err != nil {
495536 return errors .Wrap (err , "failed to request a GPU" )
496537 }
538+ if gcpClient != nil {
539+ docker .CopyToContainer (containerCreateRequest .ID , & archive.Input {
540+ Bytes : []archive.BytesInput {
541+ {
542+ Content : gcpClient .CredentialsJSON ,
543+ Dest : "/var/google_key.json" ,
544+ },
545+ },
546+ }, "/" )
547+ }
497548 err = docker .MustDockerClient ().ContainerStart (context .Background (), containerCreateRequest .ID , dockertypes.ContainerStartOptions {})
498549 if err != nil {
499550 return errors .Wrap (err , "failed to run a container using nvidia runtime; it is recommended to use the latest Docker Engine (https://docs.docker.com/engine/install/) with nvidia-container-runtime or nvidia-container-toolkit (https://docs.docker.com/config/containers/resource_constraints/#gpu)" )
0 commit comments