diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp
index 42241242d41..3dfeb91a9e6 100644
--- a/cpp/tensorrt_llm/common/envUtils.cpp
+++ b/cpp/tensorrt_llm/common/envUtils.cpp
@@ -318,6 +318,36 @@ std::string getEnvNixlInterface()
     return nixlInterface;
 }
 
+// Returns the name of the NIXL backend to use for KV-cache transfer.
+//
+// TRTLLM_NIXL_KVCACHE_BACKEND is read on the first call only (guarded by std::call_once) and the
+// result is cached in a function-local static, so later changes to the environment are not seen.
+// An unset or empty variable selects "UCX". Any other value is returned verbatim; checking it
+// against the set of supported backends is left to the caller.
+std::string getEnvNixlBackend()
+{
+    static std::once_flag flag;
+    static std::string nixlBackend;
+
+    std::call_once(flag,
+        [&]()
+        {
+            char const* nixl_backend = std::getenv("TRTLLM_NIXL_KVCACHE_BACKEND");
+            // `VAR=` exports an empty string rather than unsetting VAR; handle it like the unset
+            // case so an empty name is not reported downstream as an unsupported backend.
+            if (nixl_backend && nixl_backend[0] != '\0')
+            {
+                nixlBackend = nixl_backend;
+            }
+            else
+            {
+                // Default to UCX if not specified
+                nixlBackend = "UCX";
+            }
+        });
+    return nixlBackend;
+}
+
 bool getEnvDisaggLayerwise()
 {
     static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
diff --git a/cpp/tensorrt_llm/common/envUtils.h b/cpp/tensorrt_llm/common/envUtils.h
index 7296c84ef1d..6142781f6ac 100644
--- a/cpp/tensorrt_llm/common/envUtils.h
+++ b/cpp/tensorrt_llm/common/envUtils.h
@@ -88,6 +88,9 @@ std::string getEnvUCXInterface();
 
 std::string getEnvNixlInterface();
 
+// NIXL backend for KV-cache transfer, from TRTLLM_NIXL_KVCACHE_BACKEND; "UCX" when unset or empty.
+std::string getEnvNixlBackend();
+
 bool getEnvDisaggLayerwise();
 
 bool getEnvParallelCacheSend();
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp
index 0cb93bc0a1a..eb3ecb6c05f 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp
+++ b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <set>
 #include
 #include
 #include
@@ -345,15 +346,28 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
         mRawAgent = std::make_unique(config.mName, std::move(nixlConfig));
     }
 
+    std::string nixlBackend = common::getEnvNixlBackend();
+    // List of supported backends - extend this list as new backends are added
+    static const std::set<std::string> kSUPPORTED_BACKENDS = {"UCX"};
+
+    // An unrecognized name is reported and replaced by UCX instead of aborting agent construction.
+    if (kSUPPORTED_BACKENDS.find(nixlBackend) == kSUPPORTED_BACKENDS.end())
+    {
+        TLLM_LOG_ERROR("Unsupported NIXL backend: %s, fallback to UCX", nixlBackend.c_str());
+        nixlBackend = "UCX";
+    }
+
+    TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent using NIXL backend: %s", nixlBackend.c_str());
+
     nixl_b_params_t init1;
     nixl_mem_list_t mems1;
-    status = mRawAgent->getPluginParams("UCX", mems1, init1);
+    status = mRawAgent->getPluginParams(nixlBackend.c_str(), mems1, init1);
     TLLM_CHECK(status == NIXL_SUCCESS);
 
-    status = mRawAgent->createBackend("UCX", init1, mRawBackend);
+    status = mRawAgent->createBackend(nixlBackend.c_str(), init1, mRawBackend);
     if (status != NIXL_SUCCESS || !mRawBackend)
     {
-        TLLM_THROW("Failed to create NIXL backend");
+        TLLM_THROW("Failed to create NIXL backend: %s", nixlBackend.c_str());
     }
     mExtraParams.backends.push_back(mRawBackend);
     TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent mAddress: %s", mAddress.c_str());