Commit cc51aab

feat: add energy consumption for each request
1 parent 06d9d88 commit cc51aab

4 files changed: 56 additions, 1 deletion

router/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ csv = "1.3.0"
 ureq = "=2.9"
 pyo3 = { workspace = true }
 chrono = "0.4.39"
+nvml-wrapper = "0.11.0"


 [build-dependencies]
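
For context, nvml-wrapper is a safe Rust binding over NVIDIA's NVML. A minimal standalone sketch (not part of this commit) of the calls the router builds on, assuming a GPU and driver that expose the total-energy counter:

// Sketch only: read each GPU's cumulative energy counter via nvml-wrapper.
// total_energy_consumption() reports millijoules since the last driver reload.
use nvml_wrapper::Nvml;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let nvml = Nvml::init()?;
    for i in 0..nvml.device_count()? {
        let device = nvml.device_by_index(i)?;
        println!("GPU {i}: {} mJ since driver load", device.total_energy_consumption()?);
    }
    Ok(())
}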

router/src/chat.rs

Lines changed: 1 addition & 0 deletions
@@ -412,6 +412,7 @@ mod tests {
                 generated_tokens: 10,
                 seed: None,
                 finish_reason: FinishReason::Length,
+                energy_mj: None,
             }),
         });
         if let ChatEvent::Events(events) = events {

router/src/lib.rs

Lines changed: 38 additions & 0 deletions
@@ -23,6 +23,10 @@ use tracing::warn;
 use utoipa::ToSchema;
 use uuid::Uuid;
 use validation::Validation;
+use nvml_wrapper::Nvml;
+use std::sync::OnceLock;
+
+static NVML: OnceLock<Option<Nvml>> = OnceLock::new();

 #[allow(clippy::large_enum_variant)]
 #[derive(Clone)]
@@ -1468,6 +1472,9 @@ pub(crate) struct Details {
     pub best_of_sequences: Option<Vec<BestOfSequence>>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub top_tokens: Vec<Vec<Token>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[schema(nullable = true, example = 152)]
+    pub energy_mj: Option<u64>,
 }

 #[derive(Serialize, ToSchema)]
@@ -1498,6 +1505,9 @@ pub(crate) struct StreamDetails {
     pub seed: Option<u64>,
     #[schema(example = 1)]
     pub input_length: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[schema(nullable = true, example = 152)]
+    pub energy_mj: Option<u64>,
 }

 #[derive(Serialize, ToSchema, Clone)]
@@ -1546,6 +1556,34 @@ impl Default for ModelsInfo {
     }
 }

+pub struct EnergyMonitor;
+
+impl EnergyMonitor {
+    fn nvml() -> Option<&'static Nvml> {
+        NVML.get_or_init(|| Nvml::init().ok()).as_ref()
+    }
+
+    pub fn energy_mj(gpu_index: u32) -> Option<u64> {
+        let nvml = Self::nvml()?;
+        let device = nvml.device_by_index(gpu_index).ok()?;
+        device.total_energy_consumption().ok()
+    }
+
+    pub fn total_energy_mj() -> Option<u64> {
+        let nvml = Self::nvml()?;
+        let count = nvml.device_count().ok()?;
+        let mut total = 0;
+        for i in 0..count {
+            if let Ok(device) = nvml.device_by_index(i) {
+                if let Ok(energy) = device.total_energy_consumption() {
+                    total += energy;
+                }
+            }
+        }
+        Some(total)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
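
As a usage note: EnergyMonitor returns cumulative counters, so callers measure the difference between two samples. A hypothetical standalone caller, mirroring the pattern server.rs adopts below (the measure helper is illustrative, not part of the commit):

// Hypothetical caller: sample the cumulative counter around a workload and
// report the saturating difference. On hosts where NVML is unavailable the
// functions return None and no energy figure is produced.
fn measure<F: FnOnce()>(work: F) -> Option<u64> {
    let start = EnergyMonitor::total_energy_mj();
    work();
    let end = EnergyMonitor::total_energy_mj();
    match (start, end) {
        (Some(s), Some(e)) => Some(e.saturating_sub(s)),
        _ => None,
    }
}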

router/src/server.rs

Lines changed: 16 additions & 1 deletion
@@ -26,7 +26,7 @@ use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
     ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
     ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
-    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
+    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, EnergyMonitor,
 };
 use crate::{ChatTokenizeResponse, JsonSchemaConfig};
 use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@@ -293,6 +293,7 @@ pub(crate) async fn generate_internal(
     span: tracing::Span,
 ) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
+    let start_energy = EnergyMonitor::total_energy_mj();
     metrics::counter!("tgi_request_count").increment(1);

     // Do not long ultra long inputs, like image payloads.
@@ -317,6 +318,12 @@ pub(crate) async fn generate_internal(
         }
         _ => (infer.generate(req).await?, None),
     };
+
+    let end_energy = EnergyMonitor::total_energy_mj();
+    let energy_mj = match (start_energy, end_energy) {
+        (Some(start), Some(end)) => Some(end.saturating_sub(start)),
+        _ => None,
+    };

     // Token details
     let input_length = response._input_length;
@@ -354,6 +361,7 @@ pub(crate) async fn generate_internal(
                 seed: response.generated_text.seed,
                 best_of_sequences,
                 top_tokens: response.top_tokens,
+                energy_mj,
             })
         }
         false => None,
@@ -515,6 +523,7 @@ async fn generate_stream_internal(
     impl Stream<Item = Result<StreamResponse, InferError>>,
 ) {
     let start_time = Instant::now();
+    let start_energy = EnergyMonitor::total_energy_mj();
     metrics::counter!("tgi_request_count").increment(1);

     tracing::debug!("Input: {}", req.inputs);
@@ -590,13 +599,19 @@ async fn generate_stream_internal(
                     queued,
                     top_tokens,
                 } => {
+                    let end_energy = EnergyMonitor::total_energy_mj();
+                    let energy_mj = match (start_energy, end_energy) {
+                        (Some(start), Some(end)) => Some(end.saturating_sub(start)),
+                        _ => None,
+                    };
                     // Token details
                     let details = match details {
                         true => Some(StreamDetails {
                             finish_reason: generated_text.finish_reason,
                             generated_tokens: generated_text.generated_tokens,
                             seed: generated_text.seed,
                             input_length,
+                            energy_mj,
                         }),
                         false => None,
                     };
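
The reported energy_mj is the change in NVML's cumulative counter, summed across all visible GPUs, between the start and end of the request, in millijoules; since the counter is machine-wide, concurrent requests each observe the full machine draw during their own window. An illustrative conversion helper (not in the commit):

// Illustrative only: convert the reported millijoule delta to joules and watt-hours.
fn energy_mj_to_wh(energy_mj: u64) -> (f64, f64) {
    let joules = energy_mj as f64 / 1_000.0; // 1 J = 1000 mJ
    let watt_hours = joules / 3_600.0;       // 1 Wh = 3600 J
    (joules, watt_hours)
}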
