Skip to content

Commit 099cc13

Browse files
committed
feat: respect context
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent 02cc8cb commit 099cc13

File tree

4 files changed

+79
-10
lines changed

4 files changed

+79
-10
lines changed

backend/cpp/llama-cpp/grpc-server.cpp

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,12 @@ class BackendServiceImpl final : public backend::Backend::Service {
822822
}
823823

824824
ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
825+
// Check if context is cancelled before processing result
826+
if (context->IsCancelled()) {
827+
ctx_server.cancel_tasks(task_ids);
828+
return false;
829+
}
830+
825831
json res_json = result->to_json();
826832
if (res_json.is_array()) {
827833
for (const auto & res : res_json) {
@@ -875,13 +881,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
875881
reply.set_message(error_data.value("content", ""));
876882
writer->Write(reply);
877883
return true;
878-
}, [&]() {
879-
// NOTE: we should try to check when the writer is closed here
880-
return false;
884+
}, [&context]() {
885+
// Check if the gRPC context is cancelled
886+
return context->IsCancelled();
881887
});
882888

883889
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
884890

891+
// Check if context was cancelled during processing
892+
if (context->IsCancelled()) {
893+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
894+
}
895+
885896
return grpc::Status::OK;
886897
}
887898

@@ -1145,6 +1156,14 @@ class BackendServiceImpl final : public backend::Backend::Service {
11451156

11461157

11471158
std::cout << "[DEBUG] Waiting for results..." << std::endl;
1159+
1160+
// Check cancellation before waiting for results
1161+
if (context->IsCancelled()) {
1162+
ctx_server.cancel_tasks(task_ids);
1163+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
1164+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1165+
}
1166+
11481167
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
11491168
std::cout << "[DEBUG] Received " << results.size() << " results" << std::endl;
11501169
if (results.size() == 1) {
@@ -1176,13 +1195,20 @@ class BackendServiceImpl final : public backend::Backend::Service {
11761195
}, [&](const json & error_data) {
11771196
std::cout << "[DEBUG] Error in results: " << error_data.value("content", "") << std::endl;
11781197
reply->set_message(error_data.value("content", ""));
1179-
}, [&]() {
1180-
return false;
1198+
}, [&context]() {
1199+
// Check if the gRPC context is cancelled
1200+
// This is checked every HTTP_POLLING_SECONDS (1 second) during receive_multi_results
1201+
return context->IsCancelled();
11811202
});
11821203

11831204
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
11841205
std::cout << "[DEBUG] Predict request completed successfully" << std::endl;
11851206

1207+
// Check if context was cancelled during processing
1208+
if (context->IsCancelled()) {
1209+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1210+
}
1211+
11861212
return grpc::Status::OK;
11871213
}
11881214

@@ -1234,6 +1260,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
12341260
ctx_server.queue_tasks.post(std::move(tasks));
12351261
}
12361262

1263+
// Check cancellation before waiting for results
1264+
if (context->IsCancelled()) {
1265+
ctx_server.cancel_tasks(task_ids);
1266+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
1267+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1268+
}
1269+
12371270
// get the result
12381271
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
12391272
for (auto & res : results) {
@@ -1242,12 +1275,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
12421275
}
12431276
}, [&](const json & error_data) {
12441277
error = true;
1245-
}, [&]() {
1246-
return false;
1278+
}, [&context]() {
1279+
// Check if the gRPC context is cancelled
1280+
return context->IsCancelled();
12471281
});
12481282

12491283
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
12501284

1285+
// Check if context was cancelled during processing
1286+
if (context->IsCancelled()) {
1287+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1288+
}
1289+
12511290
if (error) {
12521291
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
12531292
}
@@ -1325,6 +1364,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
13251364
ctx_server.queue_tasks.post(std::move(tasks));
13261365
}
13271366

1367+
// Check cancellation before waiting for results
1368+
if (context->IsCancelled()) {
1369+
ctx_server.cancel_tasks(task_ids);
1370+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
1371+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1372+
}
1373+
13281374
// Get the results
13291375
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
13301376
for (auto & res : results) {
@@ -1333,12 +1379,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
13331379
}
13341380
}, [&](const json & error_data) {
13351381
error = true;
1336-
}, [&]() {
1337-
return false;
1382+
}, [&context]() {
1383+
// Check if the gRPC context is cancelled
1384+
return context->IsCancelled();
13381385
});
13391386

13401387
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
13411388

1389+
// Check if context was cancelled during processing
1390+
if (context->IsCancelled()) {
1391+
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
1392+
}
1393+
13421394
if (error) {
13431395
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
13441396
}

core/http/endpoints/openai/chat.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
358358
LOOP:
359359
for {
360360
select {
361+
case <-input.Context.Done():
362+
// Context was cancelled (client disconnected or request cancelled)
363+
log.Debug().Msgf("Request context cancelled, stopping stream")
364+
input.Cancel()
365+
break LOOP
361366
case ev := <-responses:
362367
if len(ev.Choices) == 0 {
363368
log.Debug().Msgf("No choices in the response, skipping")

core/http/middleware/request.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
161161
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
162162
ctx.Set("X-Correlation-ID", correlationID)
163163

164-
c1, cancel := context.WithCancel(re.applicationConfig.Context)
164+
//c1, cancel := context.WithCancel(re.applicationConfig.Context)
165+
c1, cancel := context.WithCancel(ctx.Context())
165166
// Add the correlation ID to the new context
166167
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
167168

pkg/grpc/client.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,22 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
178178
}
179179

180180
for {
181+
// Check if context is cancelled before receiving
182+
select {
183+
case <-ctx.Done():
184+
return ctx.Err()
185+
default:
186+
}
187+
181188
reply, err := stream.Recv()
182189
if err == io.EOF {
183190
break
184191
}
185192
if err != nil {
193+
// Check if error is due to context cancellation
194+
if ctx.Err() != nil {
195+
return ctx.Err()
196+
}
186197
fmt.Println("Error", err)
187198

188199
return err

0 commit comments

Comments (0)