This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 7351124

tikikunhiento09 authored and committed
simpler stop word for efficiency
1 parent 83e0564 commit 7351124

1 file changed: 12 additions, 25 deletions

cpp/tensorrt_llm/nitro/controllers/tensorrtllm.cc

```diff
@@ -119,9 +119,8 @@ GenerationInput::TensorPtr tensorrtllm::getTensorSingleStopWordList(int stopToke
 
 GenerationInput::TensorPtr tensorrtllm::getTensorChatMLStopWordList()
 {
-    std::vector<int32_t> stopWordsTokens = { 28766, 321, 28730, 416, 28766, 28767, 2, 32000, 6, 7, 8, -1, -1, -1,
-        -1, -1}; // Extend with -1 for increased length
-    return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 8}), MemoryType::kGPU);
+    std::vector<int32_t> stopWordsTokens = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1}; // Extend with -1 for increased length
+    return gptSession->getBufferManager().copyFrom(stopWordsTokens, ITensor::makeShape({1, 2, 5}), MemoryType::kGPU);
 }
 
 GenerationInput tensorrtllm::createGenerationInput(std::vector<int32_t> inputIdsHost)
```
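For context: TensorRT-LLM packs stop words into a tensor of shape {batchSize, 2, maxStopWordsLen}, where row 0 holds the token IDs of all stop sequences concatenated and row 1 holds the exclusive end offset of each sequence, both padded with -1. Read that way, the old list encoded three stop sequences (a six-token fragment that looks like the ChatML end marker, plus single tokens 2 and 32000, likely the model's EOS and an added special token), while the new list keeps only the three-token "im_end" fragment and the two single tokens, shrinking the per-step stop check from width 8 to width 5, which is the "simpler stop word" in the commit message. A minimal sketch of a builder for this layout, reusing the gptSession/BufferManager surface visible above (the helper itself is hypothetical, not part of this commit):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not in this commit): build the {1, 2, maxLen}
// stop-words tensor TensorRT-LLM expects. Row 0 = concatenated token IDs,
// row 1 = exclusive end offset of each stop sequence, both padded with -1.
GenerationInput::TensorPtr tensorrtllm::makeStopWordsTensor(
    std::vector<std::vector<int32_t>> const& stopSequences)
{
    std::vector<int32_t> ids;
    std::vector<int32_t> offsets;
    for (auto const& seq : stopSequences)
    {
        ids.insert(ids.end(), seq.begin(), seq.end());
        offsets.push_back(static_cast<int32_t>(ids.size()));
    }
    // Both rows must have the same length; ids is always >= offsets here.
    auto const maxLen = static_cast<int32_t>(ids.size());
    offsets.resize(maxLen, -1);

    std::vector<int32_t> packed = ids;                            // row 0
    packed.insert(packed.end(), offsets.begin(), offsets.end());  // row 1
    return gptSession->getBufferManager().copyFrom(
        packed, ITensor::makeShape({1, 2, maxLen}), MemoryType::kGPU);
}
```

With that helper, the new list above would correspond to makeStopWordsTensor({{321, 28730, 416}, {2}, {32000}}).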
```diff
@@ -148,19 +147,8 @@ GenerationOutput tensorrtllm::createGenerationOutput()
 }
 
 void inferenceThread(std::shared_ptr<inferenceState> inferState, std::vector<int32_t> inputIdsHost,
-    std::function<void(const HttpResponsePtr&)> callback, tensorrtllm* self)
+    std::function<void(const HttpResponsePtr&)> callback, tensorrtllm* self, SamplingConfig samplingConfig, int inputLen, int outputLen)
 {
-    const int inputLen = inputIdsHost.size();
-    const int outputLen = 2048 - inputLen;
-
-    // Create sampling config
-    SamplingConfig samplingConfig{1};
-    samplingConfig.temperature = std::vector{0.0f};
-    samplingConfig.randomSeed = std::vector{static_cast<uint64_t>(42ull)};
-    samplingConfig.topK = std::vector{40};
-    samplingConfig.topP = std::vector{0.0f};
-    samplingConfig.minLength = std::vector{outputLen};
-    samplingConfig.repetitionPenalty = std::vector{1.3f};
 
     // Input preparation
 
```
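This hunk moves the sampling configuration out of the worker and into the caller, so every request can carry its own parameters; the thread launch in the last hunk below relies on std::thread's argument semantics, where each argument is decay-copied into storage owned by the new thread. Passing samplingConfig, inputLen, and outputLen by value therefore gives the detached worker its own copies that outlive the caller's stack frame. A standalone illustration with a stand-in config type (hypothetical names, not from this file):

```cpp
#include <iostream>
#include <thread>

// Stand-in for SamplingConfig, just to show the copy semantics.
struct DemoConfig
{
    float temperature;
};

void worker(DemoConfig cfg, int inputLen, int outputLen)
{
    // cfg, inputLen, and outputLen are this thread's own copies; the
    // caller's originals may already be gone by the time this runs.
    std::cout << cfg.temperature << " " << inputLen << " " << outputLen << "\n";
}

int main()
{
    DemoConfig cfg{0.7f};
    std::thread t(worker, cfg, 128, 1920); // decay-copies every argument
    t.join(); // the production code detaches instead; the copies make either safe
}
```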

```diff
@@ -216,12 +204,11 @@ void tensorrtllm::chat_completion(
 
     nlohmann::json data;
 
-    data["stream"] = completion.stream;
-    data["n_predict"] = completion.max_tokens;
-    data["top_p"] = completion.top_p;
-    data["temperature"] = completion.temperature;
-    data["frequency_penalty"] = completion.frequency_penalty;
+    // data["stream"] = completion.stream;
+    // data["n_predict"] = completion.max_tokens;
     data["presence_penalty"] = completion.presence_penalty;
+
+
     const Json::Value& messages = completion.messages;
 
     // Format the input from user
```
```diff
@@ -261,20 +248,20 @@ void tensorrtllm::chat_completion(
 
     std::vector<int32_t> inputIdsHost = nitro_tokenizer->encode(formatted_input);
     const int inputLen = inputIdsHost.size();
-    const int outputLen = 2048 - inputLen;
+    const int outputLen = completion.max_tokens - inputLen;
 
     // Create sampling config
     SamplingConfig samplingConfig{1};
-    samplingConfig.temperature = std::vector{0.0f};
+    samplingConfig.temperature = std::vector{completion.temperature};
     samplingConfig.randomSeed = std::vector{static_cast<uint64_t>(42ull)};
     samplingConfig.topK = std::vector{40};
-    samplingConfig.topP = std::vector{0.0f};
+    samplingConfig.topP = std::vector{completion.top_p};
     samplingConfig.minLength = std::vector{outputLen};
-    samplingConfig.repetitionPenalty = std::vector{1.3f};
+    samplingConfig.repetitionPenalty = std::vector{completion.frequency_penalty};
 
     // Input preparation
 
-    std::thread infThread(inferenceThread, inferState, inputIdsHost, callback, this);
+    std::thread infThread(inferenceThread, inferState, inputIdsHost, callback, this, samplingConfig, inputLen, outputLen);
     infThread.detach(); // Detach the thread to allow it to run independently
 
     auto chunked_content_provider = [this,inferState](char* pBuffer, std::size_t nBuffSize) -> std::size_t
```
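Two details of the new wiring are worth flagging. First, completion.max_tokens now acts as a total token budget, so a prompt longer than max_tokens drives outputLen non-positive; a defensive clamp, sketched below under that assumption (not part of this commit), avoids the degenerate case. Second, frequency_penalty is fed into repetitionPenalty even though the two conventions differ: OpenAI-style frequency_penalty is additive with 0.0 as neutral, while TensorRT-LLM's repetition penalty is multiplicative with 1.0 as neutral, so a client default of 0.0 lands far from the old hard-coded 1.3f.

```cpp
#include <algorithm>

// Sketch (not in this commit): clamp the generation budget so a long prompt
// cannot drive the output length to zero or below.
inline int computeOutputLen(int maxTokens, int inputLen)
{
    return std::max(1, maxTokens - inputLen);
}

// Usage at the call site above would be:
//   const int outputLen = computeOutputLen(completion.max_tokens, inputLen);
```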
