From 4ba49f41de1bd0fff3cd705d9bab9759dc97e2f2 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Mon, 10 Nov 2025 16:17:49 +0100
Subject: [PATCH] PoC router/proxy server

---
 common/arg.cpp                |   8 +-
 common/common.h               |   2 -
 tests/test-quantize-stats.cpp |   2 +-
 tools/server/router.h         | 301 ++++++++++++++++++++++++++++++++++
 tools/server/server.cpp       | 271 +++++++++---------------------
 tools/server/utils.hpp        |  71 ++++++++
 6 files changed, 450 insertions(+), 205 deletions(-)
 create mode 100644 tools/server/router.h

diff --git a/common/arg.cpp b/common/arg.cpp
index 430ab45dfe26e..5ec3b5ae0c0b9 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -404,8 +404,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }
 
     // handle model and download
-    {
-        auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline);
+    if (!params.model.path.empty() || !params.model.url.empty() || !params.model.hf_repo.empty()) {
+        auto res = common_params_handle_model(params.model, params.hf_token, "", params.offline);
         if (params.no_mmproj) {
             params.mmproj = {};
         } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -2073,9 +2073,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-m", "--model"}, "FNAME",
         ex == LLAMA_EXAMPLE_EXPORT_LORA
             ? std::string("model path from which to load base model")
-            : string_format(
+            : std::string(
                 "model path (default: `models/$filename` with filename from `--hf-file` "
-                "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH
+                "or `--model-url` if set, otherwise empty)"
             ),
         [](common_params & params, const std::string & value) {
             params.model.path = value;
diff --git a/common/common.h b/common/common.h
index f42c083faa254..6500e96ab1705 100644
--- a/common/common.h
+++ b/common/common.h
@@ -28,8 +28,6 @@
         fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
     } while(0)
 
-#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
-
 struct common_adapter_lora_info {
     std::string path;
     float scale;
diff --git a/tests/test-quantize-stats.cpp b/tests/test-quantize-stats.cpp
index a284a1f0c5e31..0cb29cd41b3da 100644
--- a/tests/test-quantize-stats.cpp
+++ b/tests/test-quantize-stats.cpp
@@ -23,7 +23,7 @@
 #endif
 
 struct quantize_stats_params {
-    std::string model = DEFAULT_MODEL_PATH;
+    std::string model = "";
     bool verbose = false;
     bool per_layer_stats = false;
     bool print_histogram = false;
diff --git a/tools/server/router.h b/tools/server/router.h
new file mode 100644
index 0000000000000..79faf3409c28c
--- /dev/null
+++ b/tools/server/router.h
@@ -0,0 +1,301 @@
+#pragma once
+
+#include "utils.hpp"
+#include "download.h"
+
+#include <spawn.h>
+#include <sys/wait.h>
+
+#if defined(__APPLE__) && defined(__MACH__)
+// macOS: use _NSGetExecutablePath to get the executable path
+#include <mach-o/dyld.h>
+#include <limits.h>
+#endif
+
+using router_callback_t = std::function<void(int)>;
+
+static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
+    // skip GH copilot requests when using default port
+    if (req.path == "/v1/health") {
+        return;
+    }
+
+    // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
+
+    SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
+
+    SRV_DBG("request:  %s\n", req.body.c_str());
+    SRV_DBG("response: %s\n", res.body.c_str());
+}
+
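+// build an httplib server with the shared configuration (SSL when available, logging,
+// exception/error handlers, timeouts, thread pool); used by both router and model server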
+static std::unique_ptr<httplib::Server> create_http_server(const common_params & params) {
+    std::unique_ptr<httplib::Server> svr;
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+        LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
+        svr.reset(
+            new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
+        );
+    } else {
+        LOG_INF("Running without SSL\n");
+        svr.reset(new httplib::Server());
+    }
+#else
+    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+        LOG_ERR("Server is built without SSL support\n");
+        return nullptr;
+    }
+    svr.reset(new httplib::Server());
+#endif
+
+    svr->set_default_headers({{"Server", "llama.cpp"}});
+    svr->set_logger(log_server_request);
+
+    svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
+        std::string message;
+        try {
+            std::rethrow_exception(ep);
+        } catch (const std::exception & e) {
+            message = e.what();
+        } catch (...) {
+            message = "Unknown Exception";
+        }
+
+        try {
+            json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
+            LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
+            res_error(res, formatted_error);
+        } catch (const std::exception & e) {
+            LOG_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
+        }
+    });
+
+    svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
+        if (res.status == 404) {
+            res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
+        }
+        // for other error codes, we skip processing here because it's already done by res_error()
+    });
+
+    // set timeouts and change hostname and port
+    svr->set_read_timeout (params.timeout_read);
+    svr->set_write_timeout(params.timeout_write);
+
+    int n_threads_http = params.n_threads_http;
+    if (n_threads_http < 1) {
+        // +2 threads for monitoring endpoints
+        n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
+    }
+    svr->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
+
+    return svr;
+}
+
+struct server_instance {
+    pid_t pid;
+    int   port;
+};
+
+namespace router {
+
+std::function<void(int)> shutdown_handler;
+std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
+
+inline void signal_handler(int signal) {
+    if (router::is_terminating.test_and_set()) {
+        // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
+        // this is for better developer experience, we can remove when the server is stable enough
+        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
+        exit(1);
+    }
+
+    router::shutdown_handler(signal);
+}
+
+// https://gist.github.com/Jacob-Tate/7b326a086cf3f9d46e32315841101109
+static std::filesystem::path get_abs_exe_path() {
+    #if defined(_MSC_VER)
+        wchar_t path[FILENAME_MAX] = { 0 };
+        GetModuleFileNameW(nullptr, path, FILENAME_MAX);
+        return std::filesystem::path(path);
+    #elif defined(__APPLE__) && defined(__MACH__)
+        char small_path[PATH_MAX];
+        uint32_t size = sizeof(small_path);
+
+        if (_NSGetExecutablePath(small_path, &size) == 0) {
+            // resolve any symlinks to get absolute path
+            try {
+                return std::filesystem::canonical(std::filesystem::path(small_path));
+            } catch (...) {
+                return std::filesystem::path(small_path);
+            }
+        } else {
+            // buffer was too small, allocate required size and call again
+            std::vector<char> buf(size);
+            if (_NSGetExecutablePath(buf.data(), &size) == 0) {
+                try {
+                    return std::filesystem::canonical(std::filesystem::path(buf.data()));
+                } catch (...) {
+                    return std::filesystem::path(buf.data());
+                }
+            }
+            return std::filesystem::path(std::string(buf.data(), (size > 0) ? size : 0));
+        }
+    #else
+        char path[FILENAME_MAX];
+        ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
+        return std::filesystem::path(std::string(path, (count > 0) ? count : 0));
+    #endif
+}
+
+static int create_router_server(common_params params, char ** envp) {
+    std::unique_ptr<httplib::Server> svr = create_http_server(params);
+
+    std::mutex m;
+    std::map<std::string, server_instance> instances;
+
+    auto add_instance = [&](const std::string & id, server_instance && inst) {
+        std::lock_guard<std::mutex> lock(m);
+        instances.emplace(id, std::move(inst));
+        LOG_INF("added instance id=%s, pid=%d, port=%d\n", id.c_str(), inst.pid, inst.port);
+    };
+
+    auto remove_instance = [&](const std::string & id) {
+        std::lock_guard<std::mutex> lock(m);
+        instances.erase(id);
+        LOG_INF("removed instance id=%s\n", id.c_str());
+    };
+
+    auto create_instance = [&](const std::string & id, const common_params &) {
+        server_instance inst;
+        inst.port = rand() % 10000 + 20000; // random port between 20000 and 29999
+
+        pid_t pid = 0;
+        {
+            // Prepare arguments (pass original or custom ones) using mutable storage for argv
+            std::filesystem::path exe_path = get_abs_exe_path();
+            std::string path = exe_path.string();
+
+            std::vector<std::string> arg_strs;
+            arg_strs.push_back(path);
+            arg_strs.push_back("-hf");
+            arg_strs.push_back(id);
+            arg_strs.push_back("--port");
+            arg_strs.push_back(std::to_string(inst.port));
+
+            std::vector<char *> child_argv;
+            child_argv.reserve(arg_strs.size() + 1);
+            for (auto & s : arg_strs) {
+                child_argv.push_back(const_cast<char *>(s.c_str()));
+            }
+            child_argv.push_back(nullptr);
+
+            LOG_INF("spawning instance %s with hf=%s on port %d\n", path.c_str(), id.c_str(), inst.port);
+            if (posix_spawn(&pid, path.c_str(), NULL, NULL, child_argv.data(), envp) != 0) {
+                perror("posix_spawn");
+                exit(1);
+            } else {
+                inst.pid = pid;
+            }
+        }
+        add_instance(id, std::move(inst));
+
+        std::thread th([id, pid, &remove_instance]() {
+            int status = 0;
+            waitpid(pid, &status, 0);
+            SRV_INF("instance with pid %d exited with status %d\n", pid, status);
+            remove_instance(id);
+        });
+        if (th.joinable()) {
+            th.detach(); // for testing
+        } else {
+            SRV_ERR("failed to detach thread for instance pid %d\n", inst.pid);
+        }
+        return 0;
+    };
+
+    // just PoC, non-OAI compat
+    svr->Get("/models", [&instances](const httplib::Request &, httplib::Response & res) {
+        auto models = common_list_cached_models();
+        json models_json = json::array();
+        for (const auto & model : models) {
+            models_json.push_back(json {
+                {"model",  model.to_string()},
+                {"loaded", instances.find(model.to_string()) != instances.end()}, // TODO: non-thread-safe here
+            });
+        }
+        res.set_content(safe_json_to_str(json {{"models", models_json}}), MIMETYPE_JSON);
+        res.status = 200;
+    });
+
+    svr->Post("/models/load", [&params, &create_instance](const httplib::Request & req, httplib::Response & res) {
+        const json body = json::parse(req.body);
+        const std::string model_str = json_value(body, "model", std::string());
+        if (model_str.empty()) {
+            res_error(res, format_error_response("model field is required", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+        if (create_instance(model_str, params) == 0) {
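+            // NOTE: create_instance() only confirms the spawn; the child may still be
+            //       downloading or loading the model, hence the "loading" status below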
res.set_content(safe_json_to_str(json {{"status", "loading"}, {"model", model_str}}), MIMETYPE_JSON); + res.status = 200; + } else { + res_error(res, format_error_response("failed to create model instance", ERROR_TYPE_SERVER)); + } + }); + + svr->set_error_handler([&instances](const httplib::Request & req, httplib::Response & res) { + bool is_unhandled = req.matched_route.empty(); + if (is_unhandled && req.method == "POST") { + // proxy to the right instance based on HF model id + const json body = json::parse(req.body); + const std::string model_str = json_value(body, "model", std::string()); + const auto it = instances.find(model_str); + if (it != instances.end()) { + const server_instance & inst = it->second; + + // TODO: support streaming and other methods + httplib::Client cli("127.0.0.1", inst.port); + auto cli_res = cli.Post( + req.path, + req.headers, + req.body, + MIMETYPE_JSON + ); + res.status = cli_res->status; + res.set_content(cli_res->body, cli_res->get_header_value("Content-Type")); + } + } + }); + + // run the HTTP server in a thread + svr->bind_to_port(params.hostname, params.port); + std::thread t([&]() { svr->listen_after_bind(); }); + svr->wait_until_ready(); + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = router::signal_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (router::signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + router::shutdown_handler = [&](int) { + svr->stop(); + for (const auto & inst : instances) { + LOG_INF("terminating instance id=%s, pid=%d\n", inst.first.c_str(), inst.second.pid); + kill(inst.second.pid, SIGTERM); + } + }; + t.join(); + + exit(0); +} + +} // namespace router diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 6bd4be3cc17c4..51287ce150b31 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,5 +1,6 @@ #include "chat.h" #include "utils.hpp" +#include "router.h" #include "arg.h" #include "common.h" @@ -10,9 +11,6 @@ #include "speculative.h" #include "mtmd.h" -// mime type for sending response -#define MIMETYPE_JSON "application/json; charset=utf-8" - // auto generated files (see README.md for details) #include "index.html.gz.hpp" #include "loading.html.hpp" @@ -76,18 +74,6 @@ enum oaicompat_type { OAICOMPAT_TYPE_EMBEDDING, }; -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error - ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error -}; - static bool server_task_type_need_embd(server_task_type task_type) { switch (task_type) { case SERVER_TASK_TYPE_EMBEDDING: @@ -1232,51 +1218,6 @@ struct server_task_result_rerank : server_task_result { } }; -// this function maybe used outside of server_task_result_error -static json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; 
-            break;
-        case ERROR_TYPE_AUTHENTICATION:
-            type_str = "authentication_error";
-            code = 401;
-            break;
-        case ERROR_TYPE_NOT_FOUND:
-            type_str = "not_found_error";
-            code = 404;
-            break;
-        case ERROR_TYPE_SERVER:
-            type_str = "server_error";
-            code = 500;
-            break;
-        case ERROR_TYPE_PERMISSION:
-            type_str = "permission_error";
-            code = 403;
-            break;
-        case ERROR_TYPE_NOT_SUPPORTED:
-            type_str = "not_supported_error";
-            code = 501;
-            break;
-        case ERROR_TYPE_UNAVAILABLE:
-            type_str = "unavailable_error";
-            code = 503;
-            break;
-        case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
-            type_str = "exceed_context_size_error";
-            code = 400;
-            break;
-    }
-    return json {
-        {"code", code},
-        {"message", message},
-        {"type", type_str},
-    };
-}
-
 struct server_task_result_error : server_task_result {
     int index = 0;
     error_type err_type = ERROR_TYPE_SERVER;
@@ -4418,20 +4359,6 @@ struct server_context {
     }
 };
 
-static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
-    // skip GH copilot requests when using default port
-    if (req.path == "/v1/health") {
-        return;
-    }
-
-    // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
-
-    SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
-
-    SRV_DBG("request:  %s\n", req.body.c_str());
-    SRV_DBG("response: %s\n", res.body.c_str());
-}
-
 std::function<void(int)> shutdown_handler;
 std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
 
@@ -4446,102 +4373,23 @@ inline void signal_handler(int signal) {
     shutdown_handler(signal);
 }
 
-int main(int argc, char ** argv) {
-    // own arguments required by this example
-    common_params params;
-
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
-        return 1;
-    }
-
-    // TODO: should we have a separate n_parallel parameter for the server?
-    //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
-    // TODO: this is a common configuration that is suitable for most local use cases
-    //       however, overriding the parameters is a bit confusing - figure out something more intuitive
-    if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
-        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);
-
-        params.n_parallel = 4;
-        params.kv_unified = true;
-    }
-
-    common_init();
-
-    // struct that contains llama context and inference
-    server_context ctx_server;
-
+static int create_model_server(common_params params) {
     llama_backend_init();
     llama_numa_init(params.numa);
+
-    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
-    LOG_INF("\n");
-    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
-    LOG_INF("\n");
+    LOG_INF("%s: starting model server on port %d\n", __func__, params.port);
 
-    std::unique_ptr<httplib::Server> svr;
-#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
-    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
-        LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
-        svr.reset(
-            new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
-        );
-    } else {
-        LOG_INF("Running without SSL\n");
-        svr.reset(new httplib::Server());
-    }
-#else
-    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
-        LOG_ERR("Server is built without SSL support\n");
+    std::unique_ptr<httplib::Server> svr = create_http_server(params);
+    if (!svr) {
+        LOG_ERR("%s: failed to create HTTP server\n", __func__);
         return 1;
     }
-    svr.reset(new httplib::Server());
-#endif
-
-    std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
-
-    svr->set_default_headers({{"Server", "llama.cpp"}});
-    svr->set_logger(log_server_request);
-
-    auto res_error = [](httplib::Response & res, const json & error_data) {
-        json final_response {{"error", error_data}};
-        res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
-        res.status = json_value(error_data, "code", 500);
-    };
-
-    auto res_ok = [](httplib::Response & res, const json & data) {
-        res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
-        res.status = 200;
-    };
-
-    svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
-        std::string message;
-        try {
-            std::rethrow_exception(ep);
-        } catch (const std::exception & e) {
-            message = e.what();
-        } catch (...) {
-            message = "Unknown Exception";
-        }
-
-        try {
-            json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
-            LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
-            res_error(res, formatted_error);
-        } catch (const std::exception & e) {
-            LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
-        }
-    });
-
-    svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
-        if (res.status == 404) {
-            res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
-        }
-        // for other error codes, we skip processing here because it's already done by res_error()
-    });
+    // struct that contains llama context and inference
+    server_context ctx_server;
 
-    // set timeouts and change hostname and port
-    svr->set_read_timeout (params.timeout_read);
-    svr->set_write_timeout(params.timeout_write);
+    std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
 
     std::unordered_map<std::string, std::string> log_data;
 
@@ -4562,7 +4410,7 @@ int main(int argc, char ** argv) {
     // Middlewares
     //
 
-    auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
+    auto middleware_validate_api_key = [&params](const httplib::Request & req, httplib::Response & res) {
         static const std::unordered_set<std::string> public_endpoints = {
             "/health",
             "/v1/health",
@@ -4600,7 +4448,7 @@ int main(int argc, char ** argv) {
         return false;
     };
 
-    auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
+    auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) {
         server_state current_state = state.load();
         if (current_state == SERVER_STATE_LOADING_MODEL) {
             auto tmp = string_split(req.path, '.');
@@ -4788,7 +4636,7 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
-    const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_save = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
@@ -4820,7 +4668,7 @@ int main(int argc, char ** argv) {
         res_ok(res, result->to_json());
     };
 
-    const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_restore = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
@@ -4853,7 +4701,7 @@ int main(int argc, char ** argv) {
         res_ok(res, result->to_json());
     };
 
-    const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
+    const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
        int task_id = ctx_server.queue_tasks.get_new_id();
         {
             server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
@@ -4876,7 +4724,7 @@ int main(int argc, char ** argv) {
         res_ok(res, result->to_json());
     };
 
-    const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_slots_action = [&params, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
         if (params.slot_save_path.empty()) {
             res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -4905,7 +4753,7 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_props = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    const auto handle_props = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
         json default_generation_settings_for_props;
 
         {
@@ -4947,7 +4795,7 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.endpoint_props) {
             res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -4960,7 +4808,7 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };
 
-    const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    const auto handle_api_show = [&ctx_server](const httplib::Request &, httplib::Response & res) {
         bool has_mtmd = ctx_server.mctx != nullptr;
         json data = {
             {
@@ -4991,7 +4839,7 @@ int main(int argc, char ** argv) {
 
     // handle completion-like requests (completion, chat, infill)
     // we can optionally provide a custom format for partial results and final results
-    const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
+    const auto handle_completions_impl = [&ctx_server](
         server_task_type type,
         json & data,
         const std::vector<raw_buffer> & files,
@@ -5139,7 +4987,7 @@ int main(int argc, char ** argv) {
             OAICOMPAT_TYPE_COMPLETION);
     };
 
-    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         // check model compatibility
         std::string err;
         if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
@@ -5238,7 +5086,7 @@ int main(int argc, char ** argv) {
     };
 
     // same with handle_chat_completions, but without inference part
-    const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_apply_template = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
         std::vector<raw_buffer> files; // dummy, unused
         json data = oaicompat_chat_params_parse(
@@ -5248,7 +5096,7 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
     };
 
-    const auto handle_models = [&params, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) {
+    const auto handle_models = [&params, &ctx_server, &state](const httplib::Request &, httplib::Response & res) {
         server_state current_state = state.load();
         json model_meta = nullptr;
         if (current_state == SERVER_STATE_READY) {
@@ -5293,7 +5141,7 @@ int main(int argc, char ** argv) {
         res_ok(res, models);
     };
 
-    const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         json tokens_response = json::array();
@@ -5334,7 +5182,7 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         std::string content;
@@ -5347,7 +5195,7 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
+    const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
         if (!ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -5457,7 +5305,7 @@ int main(int argc, char ** argv) {
         handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
     };
 
-    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -5665,12 +5513,6 @@ int main(int argc, char ** argv) {
     //
     // Start the server
     //
-    if (params.n_threads_http < 1) {
-        // +2 threads for monitoring endpoints
-        params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
-    }
-    log_data["n_threads_http"] = std::to_string(params.n_threads_http);
-    svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
 
     // clean up function, to be called before exit
     auto clean_up = [&svr, &ctx_server]() {
@@ -5747,19 +5589,19 @@ int main(int argc, char ** argv) {
         ctx_server.queue_tasks.terminate();
     };
 
-#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
-    struct sigaction sigint_action;
-    sigint_action.sa_handler = signal_handler;
-    sigemptyset (&sigint_action.sa_mask);
-    sigint_action.sa_flags = 0;
-    sigaction(SIGINT,  &sigint_action, NULL);
-    sigaction(SIGTERM, &sigint_action, NULL);
-#elif defined (_WIN32)
-    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
-        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
-    };
-    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
-#endif
+// #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+//     struct sigaction sigint_action;
+//     sigint_action.sa_handler = signal_handler;
+//     sigemptyset (&sigint_action.sa_mask);
+//     sigint_action.sa_flags = 0;
+//     sigaction(SIGINT,  &sigint_action, NULL);
+//     sigaction(SIGTERM, &sigint_action, NULL);
+// #elif defined (_WIN32)
+//     auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+//         return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
+//     };
+//     SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+// #endif
 
     LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__,
string_format("unix://%s", params.hostname.c_str()).c_str() : @@ -5774,3 +5616,36 @@ int main(int argc, char ** argv) { return 0; } + +int main(int argc, char ** argv, char ** envp) { + // own arguments required by this example + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { + return 1; + } + + // TODO: should we have a separate n_parallel parameter for the server? + // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 + // TODO: this is a common configuration that is suitable for most local use cases + // however, overriding the parameters is a bit confusing - figure out something more intuitive + if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { + LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); + + params.n_parallel = 4; + params.kv_unified = true; + } + + common_init(); + + LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + + if (params.model.path.empty()) { + return router::create_router_server(params, envp); + } else { + return create_model_server(params); + } +} diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 2bce2f4a47af9..4478454632873 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -31,6 +31,9 @@ #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" +// mime type for sending response +#define MIMETYPE_JSON "application/json; charset=utf-8" + using json = nlohmann::ordered_json; #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) @@ -48,6 +51,63 @@ using json = nlohmann::ordered_json; #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define QUE_DBG(fmt, ...) 
LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error + ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error +}; + +// this function maybe used outside of server_task_result_error +static json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + case ERROR_TYPE_EXCEED_CONTEXT_SIZE: + type_str = "exceed_context_size_error"; + code = 400; + break; + } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; +} + using raw_buffer = std::vector; template @@ -1555,3 +1615,14 @@ static server_tokens format_rerank(const struct llama_model * model, const struc return result; } + +static void res_error(httplib::Response & res, const json & error_data) { + json final_response {{"error", error_data}}; + res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); + res.status = json_value(error_data, "code", 500); +} + +static void res_ok(httplib::Response & res, const json & data) { + res.set_content(safe_json_to_str(data), MIMETYPE_JSON); + res.status = 200; +}