Skip to content

Commit d5312b2

Browse files
authored
Remove prediction monitoring (#1758)
1 parent 9c2c79a commit d5312b2

File tree

24 files changed

+18
-670
lines changed

24 files changed

+18
-670
lines changed

cli/cmd/lib_realtime_apis.go

Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"fmt"
2222
"io/ioutil"
2323
"net/http"
24-
"sort"
2524
"strconv"
2625
"strings"
2726
"time"
@@ -53,15 +52,6 @@ func realtimeAPITable(realtimeAPI schema.APIResponse, env cliconfig.Environment)
5352

5453
out += t.MustFormat()
5554

56-
if env.Provider != types.LocalProviderType && realtimeAPI.Spec.Monitoring != nil {
57-
switch realtimeAPI.Spec.Monitoring.ModelType {
58-
case userconfig.ClassificationModelType:
59-
out += "\n" + classificationMetricsStr(realtimeAPI.Metrics)
60-
case userconfig.RegressionModelType:
61-
out += "\n" + regressionMetricsStr(realtimeAPI.Metrics)
62-
}
63-
}
64-
6555
if realtimeAPI.DashboardURL != nil && *realtimeAPI.DashboardURL != "" {
6656
out += "\n" + console.Bold("metrics dashboard: ") + *realtimeAPI.DashboardURL + "\n"
6757
}
@@ -169,75 +159,6 @@ func code5XXStr(metrics *metrics.Metrics) string {
169159
return s.Int(metrics.NetworkStats.Code5XX)
170160
}
171161

172-
func regressionMetricsStr(metrics *metrics.Metrics) string {
173-
minStr := "-"
174-
maxStr := "-"
175-
avgStr := "-"
176-
177-
if metrics.RegressionStats != nil {
178-
if metrics.RegressionStats.Min != nil {
179-
minStr = fmt.Sprintf("%.9g", *metrics.RegressionStats.Min)
180-
}
181-
182-
if metrics.RegressionStats.Max != nil {
183-
maxStr = fmt.Sprintf("%.9g", *metrics.RegressionStats.Max)
184-
}
185-
186-
if metrics.RegressionStats.Avg != nil {
187-
avgStr = fmt.Sprintf("%.9g", *metrics.RegressionStats.Avg)
188-
}
189-
}
190-
191-
t := table.Table{
192-
Headers: []table.Header{
193-
{Title: "min", MaxWidth: 10},
194-
{Title: "max", MaxWidth: 10},
195-
{Title: "avg", MaxWidth: 10},
196-
},
197-
Rows: [][]interface{}{{minStr, maxStr, avgStr}},
198-
}
199-
200-
return t.MustFormat()
201-
}
202-
203-
func classificationMetricsStr(metrics *metrics.Metrics) string {
204-
classList := make([]string, 0, len(metrics.ClassDistribution))
205-
for inputName := range metrics.ClassDistribution {
206-
classList = append(classList, inputName)
207-
}
208-
sort.Strings(classList)
209-
210-
rows := make([][]interface{}, len(classList))
211-
for rowNum, className := range classList {
212-
rows[rowNum] = []interface{}{
213-
className,
214-
metrics.ClassDistribution[className],
215-
}
216-
}
217-
218-
if len(classList) == 0 {
219-
rows = append(rows, []interface{}{
220-
"-",
221-
"-",
222-
})
223-
}
224-
225-
t := table.Table{
226-
Headers: []table.Header{
227-
{Title: "class", MaxWidth: 40},
228-
{Title: "count", MaxWidth: 20},
229-
},
230-
Rows: rows,
231-
}
232-
233-
out := t.MustFormat()
234-
235-
if len(classList) == consts.MaxClassesPerMonitoringRequest {
236-
out += fmt.Sprintf("\nlisting at most %d classes, the complete list can be found in your cloudwatch dashboard\n", consts.MaxClassesPerMonitoringRequest)
237-
}
238-
return out
239-
}
240-
241162
func describeModelInput(status *status.Status, predictor *userconfig.Predictor, apiEndpoint string) string {
242163
if status.Updated.Ready+status.Stale.Ready == 0 {
243164
return "the models' metadata schema will be available when the api is live\n"

pkg/consts/consts.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,9 @@ var (
5050
DefaultImageONNXPredictorGPU,
5151
)
5252

53-
MaxClassesPerMonitoringRequest = 20 // cloudwatch.GeMetricData can get up to 100 metrics per request, avoid multiple requests and have room for other stats
54-
DashboardTitle = "# cortex monitoring dashboard"
55-
DefaultMaxReplicaConcurrency = int64(1024)
56-
NeuronCoresPerInf = int64(4)
53+
DashboardTitle = "# cortex monitoring dashboard"
54+
DefaultMaxReplicaConcurrency = int64(1024)
55+
NeuronCoresPerInf = int64(4)
5756
)
5857

5958
func defaultDockerImage(imageName string) string {

pkg/cortex/serve/cortex_internal/lib/api/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,4 @@
1313
# limitations under the License.
1414

1515
from cortex_internal.lib.api.predictor import Predictor
16-
from cortex_internal.lib.api.monitoring import Monitoring
1716
from cortex_internal.lib.api.api import API, get_api, get_spec

pkg/cortex/serve/cortex_internal/lib/api/api.py

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Any, Dict, Optional, Tuple, Union
2121

2222
import datadog
23-
from cortex_internal.lib.api import Monitoring, Predictor
23+
from cortex_internal.lib.api import Predictor
2424
from cortex_internal.lib.exceptions import CortexException
2525
from cortex_internal.lib.storage import LocalStorage, S3, GCS
2626
from cortex_internal.lib.log import logger
@@ -49,10 +49,6 @@ def __init__(
4949
self.name = api_spec["name"]
5050
self.predictor = Predictor(provider, api_spec, model_dir)
5151

52-
self.monitoring = None
53-
if self.api_spec.get("monitoring") is not None:
54-
self.monitoring = Monitoring(**self.api_spec["monitoring"])
55-
5652
if provider != "local":
5753
host_ip = os.environ["HOST_IP"]
5854
datadog.initialize(statsd_host=host_ip, statsd_port="8125")
@@ -65,24 +61,6 @@ def __init__(
6561
def server_side_batching_enabled(self):
6662
return self.api_spec["predictor"].get("server_side_batching") is not None
6763

68-
def get_cached_classes(self):
69-
prefix = os.path.join(self.metadata_root, "classes") + "/"
70-
class_paths, _ = self.storage.search(prefix=prefix)
71-
class_set = set()
72-
for class_path in class_paths:
73-
encoded_class_name = class_path.split("/")[-1]
74-
class_set.add(base64.urlsafe_b64decode(encoded_class_name.encode()).decode())
75-
return class_set
76-
77-
def upload_class(self, class_name: str):
78-
try:
79-
ascii_encoded = class_name.encode("ascii") # cloudwatch only supports ascii
80-
encoded_class_name = base64.urlsafe_b64encode(ascii_encoded)
81-
key = os.path.join(self.metadata_root, "classes", encoded_class_name.decode())
82-
self.storage.put_json("", key)
83-
except Exception as e:
84-
raise ValueError("unable to store class {}".format(class_name)) from e
85-
8664
def metric_dimensions_with_id(self):
8765
return [
8866
{"Name": "APIName", "Value": self.name},
@@ -106,14 +84,6 @@ def post_request_metrics(self, status_code, total_time):
10684
]
10785
self.post_metrics(metrics)
10886

109-
def post_monitoring_metrics(self, prediction_value=None):
110-
if prediction_value is not None:
111-
metrics = [
112-
self.prediction_metrics(self.metric_dimensions(), prediction_value),
113-
self.prediction_metrics(self.metric_dimensions_with_id(), prediction_value),
114-
]
115-
self.post_metrics(metrics)
116-
11787
def post_metrics(self, metrics):
11888
try:
11989
if self.statsd is None:
@@ -168,22 +138,6 @@ def latency_metric(self, dimensions, total_time):
168138
"Value": total_time, # milliseconds
169139
}
170140

171-
def prediction_metrics(self, dimensions, prediction_value):
172-
if self.monitoring.model_type == "classification":
173-
dimensions_with_class = dimensions + [{"Name": "Class", "Value": str(prediction_value)}]
174-
return {
175-
"MetricName": "Prediction",
176-
"Dimensions": dimensions_with_class,
177-
"Unit": "Count",
178-
"Value": 1,
179-
}
180-
else:
181-
return {
182-
"MetricName": "Prediction",
183-
"Dimensions": dimensions,
184-
"Value": float(prediction_value),
185-
}
186-
187141

188142
def get_api(
189143
provider: str,

pkg/cortex/serve/cortex_internal/lib/api/monitoring.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

pkg/cortex/serve/cortex_internal/serve/serve.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from cortex_internal.lib.exceptions import UserRuntimeException
3737
from fastapi import FastAPI
3838
from fastapi.exceptions import RequestValidationError
39-
from starlette.background import BackgroundTasks
4039
from starlette.exceptions import HTTPException as StarletteHTTPException
4140
from starlette.requests import Request
4241
from starlette.responses import JSONResponse, PlainTextResponse, Response
@@ -62,7 +61,6 @@
6261
"dynamic_batcher": None,
6362
"predict_route": None,
6463
"client": None,
65-
"class_set": set(),
6664
}
6765

6866

@@ -191,8 +189,6 @@ async def parse_payload(request: Request, call_next):
191189

192190

193191
def predict(request: Request):
194-
tasks = BackgroundTasks()
195-
api = local_cache["api"]
196192
predictor_impl = local_cache["predictor_impl"]
197193
dynamic_batcher = local_cache["dynamic_batcher"]
198194
kwargs = build_predict_kwargs(request)
@@ -219,26 +215,10 @@ def predict(request: Request):
219215
) from e
220216
response = Response(content=json_string, media_type="application/json")
221217

222-
if local_cache["provider"] not in ["local", "gcp"] and api.monitoring is not None:
223-
try:
224-
predicted_value = api.monitoring.extract_predicted_value(prediction)
225-
api.post_monitoring_metrics(predicted_value)
226-
if (
227-
api.monitoring.model_type == "classification"
228-
and predicted_value not in local_cache["class_set"]
229-
):
230-
tasks.add_task(api.upload_class, class_name=predicted_value)
231-
local_cache["class_set"].add(predicted_value)
232-
except:
233-
logger.warn("unable to record prediction metric", exc_info=True)
234-
235218
if util.has_method(predictor_impl, "post_predict"):
236219
kwargs = build_post_predict_kwargs(prediction, request)
237220
request_thread_pool.submit(predictor_impl.post_predict, **kwargs)
238221

239-
if len(tasks.tasks) > 0:
240-
response.background = tasks
241-
242222
return response
243223

244224

@@ -355,16 +335,6 @@ def start_fn():
355335
logger.exception("failed to start api")
356336
sys.exit(1)
357337

358-
if (
359-
provider != "local"
360-
and api.monitoring is not None
361-
and api.monitoring.model_type == "classification"
362-
):
363-
try:
364-
local_cache["class_set"] = api.get_cached_classes()
365-
except:
366-
logger.warn("an error occurred while attempting to load classes", exc_info=True)
367-
368338
app.add_api_route(local_cache["predict_route"], predict, methods=["POST"])
369339
app.add_api_route(local_cache["predict_route"], get_summary, methods=["GET"])
370340

0 commit comments

Comments (0)