@@ -42,7 +42,7 @@ def start(args):
4242 assert_api_version ()
4343 storage = S3 (bucket = os .environ ["CORTEX_BUCKET" ], region = os .environ ["AWS_REGION" ])
4444 try :
45- raw_api_spec = get_spec (args .cache_dir , args .spec )
45+ raw_api_spec = get_spec (storage , args .cache_dir , args .spec )
4646 api = API (storage = storage , cache_dir = args .cache_dir , ** raw_api_spec )
4747 client = api .predictor .initialize_client (args )
4848 cx_logger ().info ("loading the predictor from {}" .format (api .predictor .path ))
@@ -122,7 +122,7 @@ def after_request(response):
122122 try :
123123 api .post_latency_metrics (response .status_code , g .start_time )
124124
125- if api .tracker is not None :
125+ if int ( response . status_code / 100 ) == 2 and api .tracker is not None :
126126 predicted_value = api .tracker .extract_predicted_value (prediction )
127127 api .post_tracker_metrics (predicted_value )
128128 if predicted_value is not None and predicted_value not in local_cache ["class_set" ]:
@@ -164,10 +164,10 @@ def assert_api_version():
164164 )
165165
166166
167- def get_spec (cache_dir , s3_path ):
167+ def get_spec (storage , cache_dir , s3_path ):
168168 local_spec_path = os .path .join (cache_dir , "api_spec.msgpack" )
169- bucket , key = S3 .deconstruct_s3_path (s3_path )
170- S3 ( bucket , client_config = {}) .download_file (key , local_spec_path )
169+ _ , key = S3 .deconstruct_s3_path (s3_path )
170+ storage .download_file (key , local_spec_path )
171171 return util .read_msgpack (local_spec_path )
172172
173173
0 commit comments