130 changes: 65 additions & 65 deletions .semversioner/0.3.3.json
@@ -1,66 +1,66 @@
{
"changes": [
{
"description": "Add entrypoints for incremental indexing",
"type": "patch"
},
{
"description": "Clean up and organize run index code",
"type": "patch"
},
{
"description": "Consistent config loading. Resolves #99 and Resolves #1049",
"type": "patch"
},
{
"description": "Fix circular dependency when running prompt tune api directly",
"type": "patch"
},
{
"description": "Fix default settings for embedding",
"type": "patch"
},
{
"description": "Fix img for auto tune",
"type": "patch"
},
{
"description": "Fix img width",
"type": "patch"
},
{
"description": "Fixed a bug in prompt tuning process",
"type": "patch"
},
{
"description": "Refactor text unit build at local search",
"type": "patch"
},
{
"description": "Update Prompt Tuning docs",
"type": "patch"
},
{
"description": "Update create_pipeline_config.py",
"type": "patch"
},
{
"description": "Update prompt tune command in docs",
"type": "patch"
},
{
"description": "add querying from azure blob storage",
"type": "patch"
},
{
"description": "fix setting base_dir to full paths when not using file system.",
"type": "patch"
},
{
"description": "fix strategy config in entity_extraction",
"type": "patch"
}
],
"created_at": "2024-09-10T19:51:24+00:00",
"version": "0.3.3"
{
"changes": [
{
"description": "Add entrypoints for incremental indexing",
"type": "patch"
},
{
"description": "Clean up and organize run index code",
"type": "patch"
},
{
"description": "Consistent config loading. Resolves #99 and Resolves #1049",
"type": "patch"
},
{
"description": "Fix circular dependency when running prompt tune api directly",
"type": "patch"
},
{
"description": "Fix default settings for embedding",
"type": "patch"
},
{
"description": "Fix img for auto tune",
"type": "patch"
},
{
"description": "Fix img width",
"type": "patch"
},
{
"description": "Fixed a bug in prompt tuning process",
"type": "patch"
},
{
"description": "Refactor text unit build at local search",
"type": "patch"
},
{
"description": "Update Prompt Tuning docs",
"type": "patch"
},
{
"description": "Update create_pipeline_config.py",
"type": "patch"
},
{
"description": "Update prompt tune command in docs",
"type": "patch"
},
{
"description": "add querying from azure blob storage",
"type": "patch"
},
{
"description": "fix setting base_dir to full paths when not using file system.",
"type": "patch"
},
{
"description": "fix strategy config in entity_extraction",
"type": "patch"
}
],
"created_at": "2024-09-10T19:51:24+00:00",
"version": "0.3.3"
}
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20251016174711346097.json
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Add comprehensive LLM usage tracking (tokens, calls, retries) for indexing stage with per-workflow breakdown"
}
43 changes: 43 additions & 0 deletions graphrag/index/run/run_pipeline.py
@@ -19,6 +19,7 @@
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.index.utils.llm_context import inject_llm_context
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.api import create_cache_from_config, create_storage_from_config
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
@@ -116,6 +117,12 @@ async def _run_pipeline(
logger.info("Executing pipeline...")
for name, workflow_function in pipeline.run():
last_workflow = name
# Set current workflow for LLM usage tracking
context.current_workflow = name

# Inject pipeline context for LLM usage tracking
inject_llm_context(context)

context.callbacks.workflow_start(name, None)
work_time = time.time()
result = await workflow_function(config, context)
@@ -124,12 +131,48 @@
workflow=name, result=result.result, state=context.state, errors=None
)
context.stats.workflows[name] = {"overall": time.time() - work_time}

# Log LLM usage for this workflow if available
if name in context.stats.llm_usage_by_workflow:
usage = context.stats.llm_usage_by_workflow[name]
retry_part = (
f", {usage['retries']} retries"
if usage.get("retries", 0) > 0
else ""
)
logger.info(
"Workflow %s LLM usage: %d calls, %d prompt tokens, %d completion tokens%s",
name,
usage["llm_calls"],
usage["prompt_tokens"],
usage["completion_tokens"],
retry_part,
)

if result.stop:
logger.info("Halting pipeline at workflow request")
break

# Clear current workflow
context.current_workflow = None
context.stats.total_runtime = time.time() - start_time
logger.info("Indexing pipeline complete.")

# Log total LLM usage
if context.stats.total_llm_calls > 0:
retry_part = (
f", {context.stats.total_llm_retries} retries"
if context.stats.total_llm_retries > 0
else ""
)
logger.info(
"Total LLM usage: %d calls, %d prompt tokens, %d completion tokens%s",
context.stats.total_llm_calls,
context.stats.total_prompt_tokens,
context.stats.total_completion_tokens,
retry_part,
)

await _dump_json(context)

except Exception as e:
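Taken together, the `_run_pipeline` changes thread usage tracking through the existing workflow loop. A condensed sketch of the loop after this change, keeping only the lines relevant to tracking (the callbacks, timing, and error handling shown above are elided):

```python
for name, workflow_function in pipeline.run():
    context.current_workflow = name   # field added in context.py below
    inject_llm_context(context)       # idempotent; see llm_context.py below

    result = await workflow_function(config, context)

    # Per-workflow usage is populated via record_llm_usage() during the workflow.
    usage = context.stats.llm_usage_by_workflow.get(name)
    if usage is not None:
        logger.info("Workflow %s LLM usage: %s", name, usage)

    if result.stop:
        break

context.current_workflow = None
# Totals accumulate on context.stats (total_llm_calls, total_prompt_tokens, ...).
```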
63 changes: 63 additions & 0 deletions graphrag/index/typing/context.py
@@ -30,3 +30,66 @@ class PipelineRunContext:
"Callbacks to be called during the pipeline run."
state: PipelineState
"Arbitrary property bag for runtime state, persistent pre-computes, or experimental features."
current_workflow: str | None = None
"Current workflow being executed (for LLM usage tracking)."

def record_llm_usage(
self,
llm_calls: int = 0,
prompt_tokens: int = 0,
completion_tokens: int = 0,
) -> None:
"""
Record LLM usage for the current workflow.

Args
----
llm_calls: Number of LLM calls
prompt_tokens: Number of prompt tokens
completion_tokens: Number of completion tokens
"""
if self.current_workflow is None:
return

# Update totals
self.stats.total_llm_calls += llm_calls
self.stats.total_prompt_tokens += prompt_tokens
self.stats.total_completion_tokens += completion_tokens

# Update workflow-specific stats
if self.current_workflow not in self.stats.llm_usage_by_workflow:
self.stats.llm_usage_by_workflow[self.current_workflow] = {
"llm_calls": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"retries": 0,
}

workflow_stats = self.stats.llm_usage_by_workflow[self.current_workflow]
workflow_stats["llm_calls"] += llm_calls
workflow_stats["prompt_tokens"] += prompt_tokens
workflow_stats["completion_tokens"] += completion_tokens

def record_llm_retries(self, retry_count: int) -> None:
"""Record LLM retry attempts for the current workflow.

Args
----
retry_count: Number of retry attempts performed before a success.
"""
if self.current_workflow is None or retry_count <= 0:
return

# Update totals
self.stats.total_llm_retries += retry_count

# Update workflow-specific stats
if self.current_workflow not in self.stats.llm_usage_by_workflow:
self.stats.llm_usage_by_workflow[self.current_workflow] = {
"llm_calls": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"retries": 0,
}
workflow_stats = self.stats.llm_usage_by_workflow[self.current_workflow]
workflow_stats["retries"] += retry_count
25 changes: 25 additions & 0 deletions graphrag/index/typing/stats.py
@@ -23,3 +23,28 @@ class PipelineRunStats:

workflows: dict[str, dict[str, float]] = field(default_factory=dict)
"""A dictionary of workflows."""

total_llm_calls: int = field(default=0)
"""Total number of LLM calls across all workflows."""

total_prompt_tokens: int = field(default=0)
"""Total prompt tokens used across all workflows."""

total_completion_tokens: int = field(default=0)
"""Total completion tokens generated across all workflows."""

total_llm_retries: int = field(default=0)
"""Total number of LLM retry attempts across all workflows (sum of failed attempts before each success)."""

llm_usage_by_workflow: dict[str, dict[str, int]] = field(default_factory=dict)
"""LLM usage breakdown by workflow. Structure:
{
"extract_graph": {
"llm_calls": 10,
"prompt_tokens": 5000,
"completion_tokens": 2000,
"retries": 3
},
...
}
"""
40 changes: 40 additions & 0 deletions graphrag/index/utils/llm_context.py
@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Helper utilities for LLM context injection in workflows."""

import logging
from typing import Any

logger = logging.getLogger(__name__)


def inject_llm_context(context: Any) -> None:
"""Inject pipeline context into ModelManager for LLM usage tracking.

This helper function sets up LLM usage tracking for the current workflow.

Args
----
context: The PipelineRunContext containing stats and workflow information.

Example
-------
from graphrag.index.utils.llm_context import inject_llm_context

async def run_workflow(config, context):
inject_llm_context(context) # Enable LLM usage tracking

Notes
-----
- This function is idempotent - calling it multiple times is safe
- Failures are logged but don't break workflows
- Only affects LLM models created via ModelManager
"""
from graphrag.language_model.manager import ModelManager

try:
ModelManager().set_pipeline_context(context)
except Exception as e: # noqa: BLE001
# Log warning but don't break workflow
logger.warning("Failed to inject LLM context into ModelManager: %s", e)
1 change: 1 addition & 0 deletions graphrag/index/workflows/extract_graph.py
@@ -31,6 +31,7 @@ async def run_workflow(
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
logger.info("Workflow started: extract_graph")

text_units = await load_table_from_storage("text_units", context.output_storage)

extract_graph_llm_settings = config.get_language_model_config(