Skip to content

Commit 5908001

Browse files
DragonFive authored and maxiaolong.maxwell committed
feat: add rec proto, service and utils for rec framework [2/6].
1 parent 2e2a304 commit 5908001

16 files changed

+693
-58
lines changed

xllm/api_service/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ cc_library(
88
api_service_impl.h
99
call.h
1010
completion_service_impl.h
11+
rec_completion_service_impl.h
1112
chat_service_impl.h
1213
embedding_service_impl.h
1314
image_generation_service_impl.h
@@ -23,6 +24,7 @@ cc_library(
2324
api_service.cpp
2425
call.cpp
2526
completion_service_impl.cpp
27+
rec_completion_service_impl.cpp
2628
chat_service_impl.cpp
2729
embedding_service_impl.cpp
2830
image_generation_service_impl.cpp

xllm/api_service/api_service.cpp

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@ limitations under the License.
2727
#include "core/common/metrics.h"
2828
#include "core/runtime/dit_master.h"
2929
#include "core/runtime/llm_master.h"
30+
// TODO: add the following include in the next PR.
31+
// #include "core/runtime/rec_master.h"
3032
#include "core/runtime/vlm_master.h"
3133
#include "core/util/closure_guard.h"
3234
#include "embedding.pb.h"
3335
#include "image_generation.pb.h"
3436
#include "models.pb.h"
37+
#include "rec_completion_service_impl.h"
3538
#include "service_impl_factory.h"
3639
#include "xllm_metrics.h"
3740
namespace xllm {
@@ -70,6 +73,11 @@ APIService::APIService(Master* master,
7073
image_generation_service_impl_ =
7174
std::make_unique<ImageGenerationServiceImpl>(
7275
dynamic_cast<DiTMaster*>(master), model_names);
76+
} else if (FLAGS_backend == "rec") {
77+
// TODO: delete this in the next PR.
78+
using RecMaster = LLMMaster;
79+
rec_completion_service_impl_ = std::make_unique<RecCompletionServiceImpl>(
80+
dynamic_cast<RecMaster*>(master), model_names);
7381
}
7482
models_service_impl_ =
7583
ServiceImplFactory<ModelsServiceImpl>::create_service_impl(
@@ -80,13 +88,6 @@ void APIService::Completions(::google::protobuf::RpcController* controller,
8088
const proto::CompletionRequest* request,
8189
proto::CompletionResponse* response,
8290
::google::protobuf::Closure* done) {
83-
// TODO with xllm-service
84-
}
85-
86-
void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
87-
const proto::HttpRequest* request,
88-
proto::HttpResponse* response,
89-
::google::protobuf::Closure* done) {
9091
xllm::ClosureGuard done_guard(
9192
done,
9293
std::bind(request_in_metric, nullptr),
@@ -95,66 +96,89 @@ void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
9596
LOG(ERROR) << "brpc request | respose | controller is null";
9697
return;
9798
}
99+
auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
98100

99-
auto arena = response->GetArena();
101+
if (FLAGS_backend == "llm" || FLAGS_backend == "vlm") {
102+
CHECK(completion_service_impl_) << " completion service is invalid.";
103+
std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
104+
ctrl,
105+
done_guard.release(),
106+
const_cast<proto::CompletionRequest*>(request),
107+
response);
108+
completion_service_impl_->process_async(call);
109+
} else if (FLAGS_backend == "rec") {
110+
CHECK(rec_completion_service_impl_)
111+
<< " rec completion service is invalid.";
112+
std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
113+
ctrl,
114+
done_guard.release(),
115+
const_cast<proto::CompletionRequest*>(request),
116+
response);
117+
rec_completion_service_impl_->process_async(call);
118+
}
119+
}
120+
121+
namespace {
122+
template <typename Call, typename Service>
123+
void CommonCompletionsImpl(std::unique_ptr<Service>& service,
124+
xllm::ClosureGuard& guard,
125+
::google::protobuf::Arena* arena,
126+
brpc::Controller* ctrl) {
100127
auto req_pb =
101-
google::protobuf::Arena::CreateMessage<proto::CompletionRequest>(arena);
128+
google::protobuf::Arena::CreateMessage<typename Call::ReqType>(arena);
102129
auto resp_pb =
103-
google::protobuf::Arena::CreateMessage<proto::CompletionResponse>(arena);
130+
google::protobuf::Arena::CreateMessage<typename Call::ResType>(arena);
104131

105-
auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
106132
std::string error;
107133
json2pb::Json2PbOptions options;
108134
butil::IOBuf& buf = ctrl->request_attachment();
109135
butil::IOBufAsZeroCopyInputStream iobuf_stream(buf);
110136
auto st = json2pb::JsonToProtoMessage(&iobuf_stream, req_pb, options, &error);
111137
if (!st) {
112138
ctrl->SetFailed(error);
113-
LOG(ERROR) << "parse json to proto failed: " << error;
139+
LOG(ERROR) << "parse json to proto failed: " << buf.to_string();
114140
return;
115141
}
116142

117-
std::shared_ptr<Call> call = std::make_shared<CompletionCall>(
118-
ctrl, done_guard.release(), req_pb, resp_pb);
119-
completion_service_impl_->process_async(call);
143+
auto call = std::make_shared<Call>(ctrl, guard.release(), req_pb, resp_pb);
144+
service->process_async(call);
120145
}
146+
} // namespace
121147

122-
void APIService::ChatCompletions(::google::protobuf::RpcController* controller,
123-
const proto::ChatRequest* request,
124-
proto::ChatResponse* response,
148+
void APIService::CompletionsHttp(::google::protobuf::RpcController* controller,
149+
const proto::HttpRequest* request,
150+
proto::HttpResponse* response,
125151
::google::protobuf::Closure* done) {
126-
// TODO with xllm-service
127-
}
128-
129-
namespace {
130-
template <typename ChatCall, typename Service>
131-
void ChatCompletionsImpl(std::unique_ptr<Service>& service,
132-
xllm::ClosureGuard& guard,
133-
::google::protobuf::Arena* arena,
134-
brpc::Controller* ctrl) {
135-
auto req_pb =
136-
google::protobuf::Arena::CreateMessage<typename ChatCall::ReqType>(arena);
137-
auto resp_pb =
138-
google::protobuf::Arena::CreateMessage<typename ChatCall::ResType>(arena);
152+
xllm::ClosureGuard done_guard(
153+
done,
154+
std::bind(request_in_metric, nullptr),
155+
std::bind(request_out_metric, (void*)controller));
156+
if (!request || !response || !controller) {
157+
LOG(ERROR) << "brpc request | respose | controller is null";
158+
return;
159+
}
139160

140-
std::string attachment = std::move(ctrl->request_attachment().to_string());
141-
std::string error;
161+
auto arena = response->GetArena();
162+
auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
142163

143-
google::protobuf::util::JsonParseOptions options;
144-
options.ignore_unknown_fields = true;
145-
auto json_status =
146-
google::protobuf::util::JsonStringToMessage(attachment, req_pb, options);
147-
if (!json_status.ok()) {
148-
ctrl->SetFailed(json_status.ToString());
149-
LOG(ERROR) << "parse json to proto failed: " << json_status.ToString();
150-
return;
164+
if (FLAGS_backend == "llm" || FLAGS_backend == "vlm") {
165+
CHECK(completion_service_impl_) << " completion service is invalid.";
166+
CommonCompletionsImpl<CompletionCall, CompletionServiceImpl>(
167+
completion_service_impl_, done_guard, arena, ctrl);
168+
} else if (FLAGS_backend == "rec") {
169+
CHECK(rec_completion_service_impl_)
170+
<< " rec completion service is invalid.";
171+
CommonCompletionsImpl<CompletionCall, RecCompletionServiceImpl>(
172+
rec_completion_service_impl_, done_guard, arena, ctrl);
151173
}
174+
}
152175

153-
auto call = std::make_shared<ChatCall>(
154-
ctrl, guard.release(), req_pb, resp_pb, arena != nullptr /*use_arena*/);
155-
service->process_async(call);
176+
void APIService::ChatCompletions(::google::protobuf::RpcController* controller,
177+
const proto::ChatRequest* request,
178+
proto::ChatResponse* response,
179+
::google::protobuf::Closure* done) {
180+
// TODO with xllm-service
156181
}
157-
} // namespace
158182

159183
void APIService::ChatCompletionsHttp(
160184
::google::protobuf::RpcController* controller,
@@ -175,12 +199,11 @@ void APIService::ChatCompletionsHttp(
175199
if (FLAGS_backend == "llm") {
176200
auto arena = response->GetArena();
177201
CHECK(chat_service_impl_) << " chat service is invalid.";
178-
ChatCompletionsImpl<ChatCall, ChatServiceImpl>(
202+
CommonCompletionsImpl<ChatCall, ChatServiceImpl>(
179203
chat_service_impl_, done_guard, arena, ctrl);
180204
} else if (FLAGS_backend == "vlm") {
181205
CHECK(mm_chat_service_impl_) << " mm chat service is invalid.";
182-
// TODO: fix me - temporarily using heap allocation instead of arena
183-
ChatCompletionsImpl<MMChatCall, MMChatServiceImpl>(
206+
CommonCompletionsImpl<MMChatCall, MMChatServiceImpl>(
184207
mm_chat_service_impl_, done_guard, nullptr, ctrl);
185208
}
186209
}

xllm/api_service/api_service.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ class APIService : public proto::XllmAPIService {
124124
std::unique_ptr<ModelsServiceImpl> models_service_impl_;
125125
std::unique_ptr<ImageGenerationServiceImpl> image_generation_service_impl_;
126126
std::unique_ptr<RerankServiceImpl> rerank_service_impl_;
127+
std::unique_ptr<RecCompletionServiceImpl> rec_completion_service_impl_;
127128
};
128129

129130
} // namespace xllm

xllm/api_service/completion_service_impl.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626

2727
#include "common/instance_name.h"
2828
#include "completion.pb.h"
29+
#include "core/framework/request/mm_data.h"
2930
#include "core/framework/request/request_output.h"
3031
#include "core/runtime/llm_master.h"
3132
#include "core/util/utils.h"
@@ -126,6 +127,7 @@ bool send_result_to_client_brpc(std::shared_ptr<CompletionCall> call,
126127
response.set_created(created_time);
127128
response.set_model(model);
128129

130+
// add choices into response
129131
response.mutable_choices()->Reserve(req_output.outputs.size());
130132
for (const auto& output : req_output.outputs) {
131133
auto* choice = response.add_choices();
@@ -137,6 +139,7 @@ bool send_result_to_client_brpc(std::shared_ptr<CompletionCall> call,
137139
}
138140
}
139141

142+
// add usage statistics
140143
if (req_output.usage.has_value()) {
141144
const auto& usage = req_output.usage.value();
142145
auto* proto_usage = response.mutable_usage();
@@ -163,6 +166,7 @@ CompletionServiceImpl::CompletionServiceImpl(
163166
void CompletionServiceImpl::process_async_impl(
164167
std::shared_ptr<CompletionCall> call) {
165168
const auto& rpc_request = call->request();
169+
166170
// check if model is supported
167171
const auto& model = rpc_request.model();
168172
if (unlikely(!models_.contains(model))) {
@@ -196,20 +200,20 @@ void CompletionServiceImpl::process_async_impl(
196200
request_params.decode_address = rpc_request.routing().decode_name();
197201
}
198202

203+
// schedule the request
199204
auto saved_streaming = request_params.streaming;
200205
auto saved_request_id = request_params.request_id;
201-
// schedule the request
202206
master_->handle_request(
203-
std::move(rpc_request.prompt()),
207+
std::move(call->request().prompt()),
204208
std::move(prompt_tokens),
205209
std::move(request_params),
206210
call.get(),
207211
[call,
208212
model,
209213
master = master_,
210-
stream = std::move(saved_streaming),
211-
include_usage = include_usage,
212-
request_id = std::move(saved_request_id),
214+
stream = saved_streaming,
215+
include_usage,
216+
request_id = saved_request_id,
213217
created_time = absl::ToUnixSeconds(absl::Now())](
214218
const RequestOutput& req_output) -> bool {
215219
if (req_output.status.has_value()) {

0 commit comments

Comments
 (0)