Skip to content

Commit f94e026

Browse files
committed
refactor: add interior mutability to tensorrt_llm_backend_t
Make `tensorrt_llm_backend_t` interior-mutable by marking the `inner_` struct as a `mutable` field, so we can make the methods `const`. This makes the pointer accessible from multiple threads on the Rust side without wrapping it in a Mutex. The underlying tensorrt_llm::executor::Executor already contains a mutex.
1 parent fab395b commit f94e026

File tree

3 files changed

+11
-12
lines changed

3 files changed

+11
-12
lines changed

backends/trtllm/csrc/ffi.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ namespace huggingface::tgi::backends::trtllm {
107107

108108
class tensorrt_llm_backend_t {
109109
private:
110-
backend_t inner_;
110+
mutable backend_t inner_;
111111

112112
// m_created_time is a reference point to convert time from c++ time_point
113113
// to rust Instant.
@@ -131,7 +131,7 @@ namespace huggingface::tgi::backends::trtllm {
131131
float_t repetition_penalty,
132132
float_t frequency_penalty,
133133
uint64_t seed
134-
) {
134+
) const {
135135
// This is enabled only if using add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE)
136136
SPDLOG_TRACE(FMT_STRING("[FFI] Submitting {:d} prompt tokens to the executor"));
137137

@@ -152,7 +152,7 @@ namespace huggingface::tgi::backends::trtllm {
152152
}
153153
}
154154

155-
std::unique_ptr<std::vector<generation_step_t>> pull_tokens() noexcept {
155+
std::unique_ptr<std::vector<generation_step_t>> pull_tokens() const noexcept {
156156
if (num_tokens_ready() > 0) [[likely]] {
157157
const auto responses = inner_.pull_tokens();
158158

@@ -176,7 +176,7 @@ namespace huggingface::tgi::backends::trtllm {
176176
}
177177
}
178178

179-
void cancel(request_id_t request_id) noexcept {
179+
void cancel(request_id_t request_id) const noexcept {
180180
SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
181181
inner_.cancel(request_id);
182182
}

backends/trtllm/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ mod ffi {
8383
fn num_tokens_ready(self: &TensorRtLlmBackendImpl) -> usize;
8484

8585
fn submit(
86-
self: Pin<&mut TensorRtLlmBackendImpl>,
86+
self: &TensorRtLlmBackendImpl,
8787
tokens: &[u32],
8888
max_new_tokens: u32,
8989
top_k: u32,
@@ -95,10 +95,10 @@ mod ffi {
9595
) -> Result<u64>;
9696

9797
fn pull_tokens(
98-
self: Pin<&mut TensorRtLlmBackendImpl>,
98+
self: &TensorRtLlmBackendImpl,
9999
) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
100100

101-
fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
101+
fn cancel(self: &TensorRtLlmBackendImpl, request_id: u64);
102102
}
103103
}
104104

backends/trtllm/src/looper.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
8080
fn executor_status_looper(
8181
max_inflight_requests: usize,
8282
tokenizer: Tokenizer,
83-
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
83+
backend: UniquePtr<TensorRtLlmBackendImpl>,
8484
mut backlog: UnboundedReceiver<GenerationContext>,
8585
created_time: Instant,
8686
) {
@@ -111,7 +111,7 @@ fn executor_status_looper(
111111
};
112112

113113
// Submit to the TensorRT-LLM executor for scheduling
114-
match backend.pin_mut().submit(
114+
match backend.submit(
115115
&input_ids.unwrap(), // This is checked beforehand in validate()
116116
stopping_params.max_new_tokens,
117117
top_k,
@@ -143,8 +143,7 @@ fn executor_status_looper(
143143
}
144144

145145
if backend.num_tokens_ready() > 0 {
146-
let mut backend = backend.pin_mut();
147-
match backend.as_mut().pull_tokens() {
146+
match backend.pull_tokens() {
148147
Ok(responses) => {
149148
// Iterate through all the decoded token
150149
for step in responses.deref() {
@@ -183,7 +182,7 @@ fn executor_status_looper(
183182
"Client dropped - removing request {} from tracked requests",
184183
step.request_id
185184
);
186-
backend.as_mut().cancel(step.request_id);
185+
backend.cancel(step.request_id);
187186
let _ = in_flights.remove(&step.request_id);
188187
}
189188
} else {

0 commit comments

Comments
 (0)