From e4fe69ebeb15abab01e2f3f7843beec056c3ed35 Mon Sep 17 00:00:00 2001 From: June Feng Date: Thu, 16 Oct 2025 10:22:37 -0700 Subject: [PATCH 1/4] Add LLM usage and retry tracking for indexing stage --- graphrag/index/run/run_pipeline.py | 40 ++++++++++++ graphrag/index/typing/context.py | 63 +++++++++++++++++++ graphrag/index/typing/stats.py | 25 ++++++++ graphrag/index/utils/llm_context.py | 40 ++++++++++++ graphrag/index/workflows/extract_graph.py | 1 + graphrag/language_model/manager.py | 38 +++++++++-- .../providers/litellm/chat_model.py | 57 ++++++++++++++--- .../providers/litellm/embedding_model.py | 46 ++++++++++++-- .../litellm/request_wrappers/with_retries.py | 4 +- .../services/retry/exponential_retry.py | 61 ++++++++++-------- .../services/retry/incremental_wait_retry.py | 61 ++++++++++-------- .../services/retry/native_wait_retry.py | 53 +++++++++------- .../services/retry/random_wait_retry.py | 61 ++++++++++-------- .../providers/litellm/services/retry/retry.py | 37 +++++++++-- 14 files changed, 462 insertions(+), 125 deletions(-) create mode 100644 graphrag/index/utils/llm_context.py diff --git a/graphrag/index/run/run_pipeline.py b/graphrag/index/run/run_pipeline.py index f652db7acd..017f7c1660 100644 --- a/graphrag/index/run/run_pipeline.py +++ b/graphrag/index/run/run_pipeline.py @@ -19,6 +19,7 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.pipeline import Pipeline from graphrag.index.typing.pipeline_run_result import PipelineRunResult +from graphrag.index.utils.llm_context import inject_llm_context from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.api import create_cache_from_config, create_storage_from_config from graphrag.utils.storage import load_table_from_storage, write_table_to_storage @@ -116,6 +117,12 @@ async def _run_pipeline( logger.info("Executing pipeline...") for name, workflow_function in pipeline.run(): last_workflow = name + # Set current workflow for LLM usage tracking + context.current_workflow = name + + # Inject pipeline context for LLM usage tracking + inject_llm_context(context) + context.callbacks.workflow_start(name, None) work_time = time.time() result = await workflow_function(config, context) @@ -124,12 +131,45 @@ async def _run_pipeline( workflow=name, result=result.result, state=context.state, errors=None ) context.stats.workflows[name] = {"overall": time.time() - work_time} + + # Log LLM usage for this workflow if available + if name in context.stats.llm_usage_by_workflow: + usage = context.stats.llm_usage_by_workflow[name] + retry_part = ( + f", {usage['retries']} retries" if usage.get("retries", 0) > 0 else "" + ) + logger.info( + f"Workflow {name} LLM usage: " + f"{usage['llm_calls']} calls, " + f"{usage['prompt_tokens']} prompt tokens, " + f"{usage['completion_tokens']} completion tokens" + f"{retry_part}" + ) + if result.stop: logger.info("Halting pipeline at workflow request") break + # Clear current workflow + context.current_workflow = None context.stats.total_runtime = time.time() - start_time logger.info("Indexing pipeline complete.") + + # Log total LLM usage + if context.stats.total_llm_calls > 0: + retry_part = ( + f", {context.stats.total_llm_retries} retries" + if context.stats.total_llm_retries > 0 + else "" + ) + logger.info( + f"Total LLM usage: " + f"{context.stats.total_llm_calls} calls, " + f"{context.stats.total_prompt_tokens} prompt tokens, " + f"{context.stats.total_completion_tokens} completion tokens" + f"{retry_part}" + ) + await 
_dump_json(context) except Exception as e: diff --git a/graphrag/index/typing/context.py b/graphrag/index/typing/context.py index ef2e1f7ea5..c3639ea9ce 100644 --- a/graphrag/index/typing/context.py +++ b/graphrag/index/typing/context.py @@ -30,3 +30,66 @@ class PipelineRunContext: "Callbacks to be called during the pipeline run." state: PipelineState "Arbitrary property bag for runtime state, persistent pre-computes, or experimental features." + current_workflow: str | None = None + "Current workflow being executed (for LLM usage tracking)." + + def record_llm_usage( + self, + llm_calls: int = 0, + prompt_tokens: int = 0, + completion_tokens: int = 0, + ) -> None: + """ + Record LLM usage for the current workflow. + + Args + ---- + llm_calls: Number of LLM calls + prompt_tokens: Number of prompt tokens + completion_tokens: Number of completion tokens + """ + if self.current_workflow is None: + return + + # Update totals + self.stats.total_llm_calls += llm_calls + self.stats.total_prompt_tokens += prompt_tokens + self.stats.total_completion_tokens += completion_tokens + + # Update workflow-specific stats + if self.current_workflow not in self.stats.llm_usage_by_workflow: + self.stats.llm_usage_by_workflow[self.current_workflow] = { + "llm_calls": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "retries": 0, + } + + workflow_stats = self.stats.llm_usage_by_workflow[self.current_workflow] + workflow_stats["llm_calls"] += llm_calls + workflow_stats["prompt_tokens"] += prompt_tokens + workflow_stats["completion_tokens"] += completion_tokens + + def record_llm_retries(self, retry_count: int) -> None: + """Record LLM retry attempts for the current workflow. + + Args + ---- + retry_count: Number of retry attempts performed before a success. + """ + if self.current_workflow is None or retry_count <= 0: + return + + # Update totals + self.stats.total_llm_retries += retry_count + + # Update workflow-specific stats + if self.current_workflow not in self.stats.llm_usage_by_workflow: + self.stats.llm_usage_by_workflow[self.current_workflow] = { + "llm_calls": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "retries": 0, + } + workflow_stats = self.stats.llm_usage_by_workflow[self.current_workflow] + workflow_stats["retries"] += retry_count diff --git a/graphrag/index/typing/stats.py b/graphrag/index/typing/stats.py index 271773600f..c3298a0957 100644 --- a/graphrag/index/typing/stats.py +++ b/graphrag/index/typing/stats.py @@ -23,3 +23,28 @@ class PipelineRunStats: workflows: dict[str, dict[str, float]] = field(default_factory=dict) """A dictionary of workflows.""" + + total_llm_calls: int = field(default=0) + """Total number of LLM calls across all workflows.""" + + total_prompt_tokens: int = field(default=0) + """Total prompt tokens used across all workflows.""" + + total_completion_tokens: int = field(default=0) + """Total completion tokens generated across all workflows.""" + + total_llm_retries: int = field(default=0) + """Total number of LLM retry attempts across all workflows (sum of failed attempts before each success).""" + + llm_usage_by_workflow: dict[str, dict[str, int]] = field(default_factory=dict) + """LLM usage breakdown by workflow. Structure: + { + "extract_graph": { + "llm_calls": 10, + "prompt_tokens": 5000, + "completion_tokens": 2000, + "retries": 3 + }, + ... 
+ } + """ diff --git a/graphrag/index/utils/llm_context.py b/graphrag/index/utils/llm_context.py new file mode 100644 index 0000000000..765b273151 --- /dev/null +++ b/graphrag/index/utils/llm_context.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Helper utilities for LLM context injection in workflows.""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def inject_llm_context(context: Any) -> None: + """Inject pipeline context into ModelManager for LLM usage tracking. + + This helper function sets up LLM usage tracking for the current workflow. + + Args + ---- + context: The PipelineRunContext containing stats and workflow information. + + Example + ------- + from graphrag.index.utils.llm_context import inject_llm_context + + async def run_workflow(config, context): + inject_llm_context(context) # Enable LLM usage tracking + + Notes + ----- + - This function is idempotent - calling it multiple times is safe + - Failures are logged but don't break workflows + - Only affects LLM models created via ModelManager + """ + from graphrag.language_model.manager import ModelManager + + try: + ModelManager().set_pipeline_context(context) + except Exception as e: + # Log warning but don't break workflow + logger.warning("Failed to inject LLM context into ModelManager: %s", e) diff --git a/graphrag/index/workflows/extract_graph.py b/graphrag/index/workflows/extract_graph.py index 592502f6da..aff0c07e63 100644 --- a/graphrag/index/workflows/extract_graph.py +++ b/graphrag/index/workflows/extract_graph.py @@ -31,6 +31,7 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: extract_graph") + text_units = await load_table_from_storage("text_units", context.output_storage) extract_graph_llm_settings = config.get_language_model_config( diff --git a/graphrag/language_model/manager.py b/graphrag/language_model/manager.py index bc41235dda..baafcafe48 100644 --- a/graphrag/language_model/manager.py +++ b/graphrag/language_model/manager.py @@ -35,8 +35,24 @@ def __init__(self) -> None: if not hasattr(self, "_initialized"): self.chat_models: dict[str, ChatModel] = {} self.embedding_models: dict[str, EmbeddingModel] = {} + self._pipeline_context: Any = None # For LLM usage tracking self._initialized = True + def set_pipeline_context(self, context: Any) -> None: + """Set pipeline context for all models to enable LLM usage tracking.""" + self._pipeline_context = context + # Update existing models that support context injection + for model in self.chat_models.values(): + try: + getattr(model, "set_pipeline_context")(context) + except AttributeError: + pass + for model in self.embedding_models.values(): + try: + getattr(model, "set_pipeline_context")(context) + except AttributeError: + pass + @classmethod def get_instance(cls) -> ModelManager: """Return the singleton instance of LLMManager.""" @@ -54,9 +70,14 @@ def register_chat( **chat_kwargs: Additional parameters for instantiation. 
""" chat_kwargs["name"] = name - self.chat_models[name] = ModelFactory.create_chat_model( - model_type, **chat_kwargs - ) + model = ModelFactory.create_chat_model(model_type, **chat_kwargs) + # Inject pipeline context if available + if self._pipeline_context: + try: + getattr(model, "set_pipeline_context")(self._pipeline_context) + except AttributeError: + pass # Model doesn't support context injection + self.chat_models[name] = model return self.chat_models[name] def register_embedding( @@ -71,9 +92,14 @@ def register_embedding( **embedding_kwargs: Additional parameters for instantiation. """ embedding_kwargs["name"] = name - self.embedding_models[name] = ModelFactory.create_embedding_model( - model_type, **embedding_kwargs - ) + model = ModelFactory.create_embedding_model(model_type, **embedding_kwargs) + # Inject pipeline context if available + if self._pipeline_context: + try: + getattr(model, "set_pipeline_context")(self._pipeline_context) + except AttributeError: + pass # Model doesn't support context injection + self.embedding_models[name] = model return self.embedding_models[name] def get_chat_model(self, name: str) -> ChatModel | None: diff --git a/graphrag/language_model/providers/litellm/chat_model.py b/graphrag/language_model/providers/litellm/chat_model.py index 43cb7ece98..31555d7616 100644 --- a/graphrag/language_model/providers/litellm/chat_model.py +++ b/graphrag/language_model/providers/litellm/chat_model.py @@ -115,7 +115,7 @@ def _create_completions( model_config: "LanguageModelConfig", cache: "PipelineCache | None", cache_key_prefix: str, -) -> tuple[FixedModelCompletion, AFixedModelCompletion]: +) -> tuple[FixedModelCompletion, AFixedModelCompletion, Any | None]: """Wrap the base litellm completion function with the model configuration and additional features. Wrap the base litellm completion function with instance variables based on the model configuration. 
@@ -161,8 +161,9 @@ def _create_completions( tpm=tpm, ) + retry_service: Any | None = None if model_config.retry_strategy != "none": - completion, acompletion = with_retries( + completion, acompletion, retry_service = with_retries( sync_fn=completion, async_fn=acompletion, model_config=model_config, @@ -183,7 +184,7 @@ def _create_completions( async_fn=acompletion, ) - return (completion, acompletion) + return (completion, acompletion, retry_service) class LitellmModelOutput(BaseModel): @@ -221,9 +222,22 @@ def __init__( self.name = name self.config = config self.cache = cache.child(self.name) if cache else None - self.completion, self.acompletion = _create_completions( - config, self.cache, "chat" - ) + ( + self.completion, + self.acompletion, + self._retry_service, + ) = _create_completions(config, self.cache, "chat") + self._pipeline_context: Any = None # For LLM usage tracking + + def set_pipeline_context(self, context: Any) -> None: + """Set the pipeline context for LLM usage tracking.""" + self._pipeline_context = context + # Propagate into retry service if available + try: + if self._retry_service is not None: + getattr(self._retry_service, "set_pipeline_context")(context) + except AttributeError: + pass def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Get model arguments supported by litellm.""" @@ -285,6 +299,16 @@ async def achat( response = await self.acompletion(messages=messages, stream=False, **new_kwargs) # type: ignore + # Record LLM usage if pipeline context is available + if self._pipeline_context is not None and hasattr(response, "usage"): + usage = getattr(response, "usage", None) + if usage: + self._pipeline_context.record_llm_usage( + llm_calls=1, + prompt_tokens=getattr(usage, "prompt_tokens", 0), + completion_tokens=getattr(usage, "completion_tokens", 0), + ) + messages.append({ "role": "assistant", "content": response.choices[0].message.content or "", # type: ignore @@ -335,9 +359,28 @@ async def achat_stream( response = await self.acompletion(messages=messages, stream=True, **new_kwargs) # type: ignore + full_content = "" async for chunk in response: # type: ignore if chunk.choices and chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content + content = chunk.choices[0].delta.content + full_content += content + yield content + + # Record LLM usage for streaming (estimate using tokenizer) + if self._pipeline_context is not None and full_content: + from graphrag.tokenizer.get_tokenizer import get_tokenizer + + tokenizer = get_tokenizer(model_config=self.config) + # Calculate prompt text + prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + prompt_tokens = len(tokenizer.encode(prompt_text)) + completion_tokens = len(tokenizer.encode(full_content)) + + self._pipeline_context.record_llm_usage( + llm_calls=1, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) def chat(self, prompt: str, history: list | None = None, **kwargs: Any) -> "MR": """ diff --git a/graphrag/language_model/providers/litellm/embedding_model.py b/graphrag/language_model/providers/litellm/embedding_model.py index 779222c37e..9ec20f48ed 100644 --- a/graphrag/language_model/providers/litellm/embedding_model.py +++ b/graphrag/language_model/providers/litellm/embedding_model.py @@ -101,7 +101,7 @@ def _create_embeddings( model_config: "LanguageModelConfig", cache: "PipelineCache | None", cache_key_prefix: str, -) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]: +) -> tuple[FixedModelEmbedding, AFixedModelEmbedding, Any | 
None]: """Wrap the base litellm embedding function with the model configuration and additional features. Wrap the base litellm embedding function with instance variables based on the model configuration. @@ -147,8 +147,9 @@ def _create_embeddings( tpm=tpm, ) + retry_service: Any | None = None if model_config.retry_strategy != "none": - embedding, aembedding = with_retries( + embedding, aembedding, retry_service = with_retries( sync_fn=embedding, async_fn=aembedding, model_config=model_config, @@ -169,7 +170,7 @@ def _create_embeddings( async_fn=aembedding, ) - return (embedding, aembedding) + return (embedding, aembedding, retry_service) class LitellmEmbeddingModel: @@ -185,9 +186,22 @@ def __init__( self.name = name self.config = config self.cache = cache.child(self.name) if cache else None - self.embedding, self.aembedding = _create_embeddings( - config, self.cache, "embeddings" - ) + ( + self.embedding, + self.aembedding, + self._retry_service, + ) = _create_embeddings(config, self.cache, "embeddings") + self._pipeline_context: Any = None # For LLM usage tracking + + def set_pipeline_context(self, context: Any) -> None: + """Set the pipeline context for LLM usage tracking.""" + self._pipeline_context = context + # Propagate into retry service if available + try: + if self._retry_service is not None: + getattr(self._retry_service, "set_pipeline_context")(context) + except AttributeError: + pass def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Get model arguments supported by litellm.""" @@ -218,6 +232,16 @@ async def aembed_batch( new_kwargs = self._get_kwargs(**kwargs) response = await self.aembedding(input=text_list, **new_kwargs) + # Record LLM usage if pipeline context is available + if self._pipeline_context is not None and hasattr(response, "usage"): + usage = getattr(response, "usage", None) + if usage: + self._pipeline_context.record_llm_usage( + llm_calls=1, + prompt_tokens=getattr(usage, "prompt_tokens", 0), + completion_tokens=0, # embeddings don't have completion tokens + ) + return [emb.get("embedding", []) for emb in response.data] async def aembed(self, text: str, **kwargs: Any) -> list[float]: @@ -235,6 +259,16 @@ async def aembed(self, text: str, **kwargs: Any) -> list[float]: new_kwargs = self._get_kwargs(**kwargs) response = await self.aembedding(input=[text], **new_kwargs) + # Record LLM usage if pipeline context is available + if self._pipeline_context is not None and hasattr(response, "usage"): + usage = getattr(response, "usage", None) + if usage: + self._pipeline_context.record_llm_usage( + llm_calls=1, + prompt_tokens=getattr(usage, "prompt_tokens", 0), + completion_tokens=0, # embeddings don't have completion tokens + ) + return ( response.data[0].get("embedding", []) if response.data and response.data[0] diff --git a/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py b/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py index 1279f9e820..fa3952e912 100644 --- a/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py +++ b/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py @@ -22,7 +22,7 @@ def with_retries( sync_fn: LitellmRequestFunc, async_fn: AsyncLitellmRequestFunc, model_config: "LanguageModelConfig", -) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]: +) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc, Any]: """ Wrap the synchronous and asynchronous request functions with retries. 
@@ -51,4 +51,4 @@ async def _wrapped_with_retries_async( ) -> Any: return await retry_service.aretry(func=async_fn, **kwargs) - return (_wrapped_with_retries, _wrapped_with_retries_async) + return (_wrapped_with_retries, _wrapped_with_retries_async, retry_service) diff --git a/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py b/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py index e008322be0..36bb6c77f2 100644 --- a/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py @@ -26,6 +26,7 @@ def __init__( jitter: bool = True, **kwargs: Any, ): + super().__init__(**kwargs) if max_retries <= 0: msg = "max_retries must be greater than 0." raise ValueError(msg) @@ -42,21 +43,25 @@ def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: """Retry a synchronous function.""" retries = 0 delay = 1.0 # Initial delay in seconds - while True: - try: - return func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + result = func(**kwargs) + return result + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay *= self._base_delay logger.exception( - f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay *= self._base_delay - logger.exception( - f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - time.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311 + time.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311 + finally: + self._record_retries(retries) async def aretry( self, @@ -66,18 +71,22 @@ async def aretry( """Retry an asynchronous function.""" retries = 0 delay = 1.0 # Initial delay in seconds - while True: - try: - return await func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + result = await func(**kwargs) + return result + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay *= self._base_delay logger.exception( - f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay *= self._base_delay - logger.exception( - f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - await asyncio.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311 + await asyncio.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311 + finally: + self._record_retries(retries) diff --git 
a/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py b/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py index 97fbdbf9c9..97258e368e 100644 --- a/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py @@ -24,6 +24,7 @@ def __init__( max_retries: int = 5, **kwargs: Any, ): + super().__init__(**kwargs) if max_retries <= 0: msg = "max_retries must be greater than 0." raise ValueError(msg) @@ -40,21 +41,25 @@ def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: """Retry a synchronous function.""" retries = 0 delay = 0.0 - while True: - try: - return func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + result = func(**kwargs) + return result + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay += self._increment logger.exception( - f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay += self._increment - logger.exception( - f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - time.sleep(delay) + time.sleep(delay) + finally: + self._record_retries(retries) async def aretry( self, @@ -64,18 +69,22 @@ async def aretry( """Retry an asynchronous function.""" retries = 0 delay = 0.0 - while True: - try: - return await func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + result = await func(**kwargs) + return result + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay += self._increment logger.exception( - f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay += self._increment - logger.exception( - f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - await asyncio.sleep(delay) + await asyncio.sleep(delay) + finally: + self._record_retries(retries) diff --git a/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py b/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py index 088f454213..295f3b3832 100644 --- a/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py @@ -21,6 +21,7 @@ def __init__( max_retries: int = 5, **kwargs: Any, 
): + super().__init__(**kwargs) if max_retries <= 0: msg = "max_retries must be greater than 0." raise ValueError(msg) @@ -30,19 +31,23 @@ def __init__( def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: """Retry a synchronous function.""" retries = 0 - while True: - try: - return func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + result = func(**kwargs) + return result + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 logger.exception( - f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - logger.exception( - f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) + finally: + self._record_retries(retries) async def aretry( self, @@ -51,16 +56,20 @@ async def aretry( ) -> Any: """Retry an asynchronous function.""" retries = 0 - while True: - try: - return await func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + result = await func(**kwargs) + return result + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 logger.exception( - f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - logger.exception( - f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) + finally: + self._record_retries(retries) diff --git a/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py b/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py index 603f439d1f..e931ddd01d 100644 --- a/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py @@ -25,6 +25,7 @@ def __init__( max_retries: int = 5, **kwargs: Any, ): + super().__init__(**kwargs) if max_retries <= 0: msg = "max_retries must be greater than 0." 
raise ValueError(msg) @@ -39,21 +40,25 @@ def __init__( def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: """Retry a synchronous function.""" retries = 0 - while True: - try: - return func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + result = func(**kwargs) + return result + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay = random.uniform(0, self._max_retry_wait) # noqa: S311 logger.exception( - f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay = random.uniform(0, self._max_retry_wait) # noqa: S311 - logger.exception( - f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - time.sleep(delay) + time.sleep(delay) + finally: + self._record_retries(retries) async def aretry( self, @@ -62,18 +67,22 @@ async def aretry( ) -> Any: """Retry an asynchronous function.""" retries = 0 - while True: - try: - return await func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + result = await func(**kwargs) + return result + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay = random.uniform(0, self._max_retry_wait) # noqa: S311 logger.exception( - f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay = random.uniform(0, self._max_retry_wait) # noqa: S311 - logger.exception( - f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - await asyncio.sleep(delay) + await asyncio.sleep(delay) + finally: + self._record_retries(retries) diff --git a/graphrag/language_model/providers/litellm/services/retry/retry.py b/graphrag/language_model/providers/litellm/services/retry/retry.py index 4f53e598c6..d1195ba30e 100644 --- a/graphrag/language_model/providers/litellm/services/retry/retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/retry.py @@ -9,12 +9,41 @@ class Retry(ABC): - """LiteLLM Retry Abstract Base Class.""" + """LiteLLM Retry Abstract Base Class. + + Added lightweight pipeline context support to allow retry implementations + to record retry counts. + """ - @abstractmethod def __init__(self, /, **kwargs: Any): - msg = "Retry subclasses must implement the __init__ method." - raise NotImplementedError(msg) + self._pipeline_context: Any | None = None + + def set_pipeline_context(self, context: Any) -> None: + """Inject pipeline context (optional). 
+ + The context is expected to expose a `record_llm_retries(int)` method. If it does not, retry tracking + is silently skipped. + """ + self._pipeline_context = context + + def _record_retries(self, retry_count: int) -> None: + """Record retry attempts to pipeline context (if available). + + This is a protected method intended for use by subclasses in their + finally blocks to ensure retry counts are recorded regardless of + success or failure. + + Args + ---- + retry_count: Number of retry attempts performed. + + """ + if self._pipeline_context is not None and retry_count > 0: + try: + self._pipeline_context.record_llm_retries(retry_count) + except AttributeError: + # Context doesn't support retry tracking, skip silently + pass @abstractmethod def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: From 3021b3dad820ffac79e86cf2590ff70a90ae2555 Mon Sep 17 00:00:00 2001 From: June Feng Date: Thu, 16 Oct 2025 10:47:48 -0700 Subject: [PATCH 2/4] chore: Add semversioner changeset for LLM tracking feature --- .semversioner/next-release/minor-20251016174711346097.json | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .semversioner/next-release/minor-20251016174711346097.json diff --git a/.semversioner/next-release/minor-20251016174711346097.json b/.semversioner/next-release/minor-20251016174711346097.json new file mode 100644 index 0000000000..8bbc87b4ec --- /dev/null +++ b/.semversioner/next-release/minor-20251016174711346097.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Add comprehensive LLM usage tracking (tokens, calls, retries) for indexing stage with per-workflow breakdown" +} From 5f143a477dfbad3a8146ff3f4ea5021748ce9944 Mon Sep 17 00:00:00 2001 From: June Feng Date: Thu, 16 Oct 2025 11:40:23 -0700 Subject: [PATCH 3/4] code cleanup and fixing lint errors --- .semversioner/0.3.3.json | 130 +++++++++--------- graphrag/index/run/run_pipeline.py | 37 ++--- graphrag/index/utils/llm_context.py | 14 +- graphrag/index/workflows/extract_graph.py | 2 +- graphrag/language_model/manager.py | 20 +-- graphrag/language_model/protocol/base.py | 8 ++ .../providers/litellm/chat_model.py | 11 +- .../providers/litellm/embedding_model.py | 7 +- .../services/retry/exponential_retry.py | 6 +- .../services/retry/incremental_wait_retry.py | 6 +- .../services/retry/native_wait_retry.py | 6 +- .../services/retry/random_wait_retry.py | 6 +- .../providers/litellm/services/retry/retry.py | 6 +- 13 files changed, 121 insertions(+), 138 deletions(-) diff --git a/.semversioner/0.3.3.json b/.semversioner/0.3.3.json index 99726199c8..4b82516855 100644 --- a/.semversioner/0.3.3.json +++ b/.semversioner/0.3.3.json @@ -1,66 +1,66 @@ -{ - "changes": [ - { - "description": "Add entrypoints for incremental indexing", - "type": "patch" - }, - { - "description": "Clean up and organize run index code", - "type": "patch" - }, - { - "description": "Consistent config loading. 
Resolves #99 and Resolves #1049", - "type": "patch" - }, - { - "description": "Fix circular dependency when running prompt tune api directly", - "type": "patch" - }, - { - "description": "Fix default settings for embedding", - "type": "patch" - }, - { - "description": "Fix img for auto tune", - "type": "patch" - }, - { - "description": "Fix img width", - "type": "patch" - }, - { - "description": "Fixed a bug in prompt tuning process", - "type": "patch" - }, - { - "description": "Refactor text unit build at local search", - "type": "patch" - }, - { - "description": "Update Prompt Tuning docs", - "type": "patch" - }, - { - "description": "Update create_pipeline_config.py", - "type": "patch" - }, - { - "description": "Update prompt tune command in docs", - "type": "patch" - }, - { - "description": "add querying from azure blob storage", - "type": "patch" - }, - { - "description": "fix setting base_dir to full paths when not using file system.", - "type": "patch" - }, - { - "description": "fix strategy config in entity_extraction", - "type": "patch" - } - ], - "created_at": "2024-09-10T19:51:24+00:00", - "version": "0.3.3" +{ + "changes": [ + { + "description": "Add entrypoints for incremental indexing", + "type": "patch" + }, + { + "description": "Clean up and organize run index code", + "type": "patch" + }, + { + "description": "Consistent config loading. Resolves #99 and Resolves #1049", + "type": "patch" + }, + { + "description": "Fix circular dependency when running prompt tune api directly", + "type": "patch" + }, + { + "description": "Fix default settings for embedding", + "type": "patch" + }, + { + "description": "Fix img for auto tune", + "type": "patch" + }, + { + "description": "Fix img width", + "type": "patch" + }, + { + "description": "Fixed a bug in prompt tuning process", + "type": "patch" + }, + { + "description": "Refactor text unit build at local search", + "type": "patch" + }, + { + "description": "Update Prompt Tuning docs", + "type": "patch" + }, + { + "description": "Update create_pipeline_config.py", + "type": "patch" + }, + { + "description": "Update prompt tune command in docs", + "type": "patch" + }, + { + "description": "add querying from azure blob storage", + "type": "patch" + }, + { + "description": "fix setting base_dir to full paths when not using file system.", + "type": "patch" + }, + { + "description": "fix strategy config in entity_extraction", + "type": "patch" + } + ], + "created_at": "2024-09-10T19:51:24+00:00", + "version": "0.3.3" } \ No newline at end of file diff --git a/graphrag/index/run/run_pipeline.py b/graphrag/index/run/run_pipeline.py index 017f7c1660..dfd825ea4f 100644 --- a/graphrag/index/run/run_pipeline.py +++ b/graphrag/index/run/run_pipeline.py @@ -119,10 +119,10 @@ async def _run_pipeline( last_workflow = name # Set current workflow for LLM usage tracking context.current_workflow = name - + # Inject pipeline context for LLM usage tracking inject_llm_context(context) - + context.callbacks.workflow_start(name, None) work_time = time.time() result = await workflow_function(config, context) @@ -131,21 +131,24 @@ async def _run_pipeline( workflow=name, result=result.result, state=context.state, errors=None ) context.stats.workflows[name] = {"overall": time.time() - work_time} - + # Log LLM usage for this workflow if available if name in context.stats.llm_usage_by_workflow: usage = context.stats.llm_usage_by_workflow[name] retry_part = ( - f", {usage['retries']} retries" if usage.get("retries", 0) > 0 else "" + f", {usage['retries']} retries" 
+ if usage.get("retries", 0) > 0 + else "" ) logger.info( - f"Workflow {name} LLM usage: " - f"{usage['llm_calls']} calls, " - f"{usage['prompt_tokens']} prompt tokens, " - f"{usage['completion_tokens']} completion tokens" - f"{retry_part}" + "Workflow %s LLM usage: %d calls, %d prompt tokens, %d completion tokens%s", + name, + usage["llm_calls"], + usage["prompt_tokens"], + usage["completion_tokens"], + retry_part, ) - + if result.stop: logger.info("Halting pipeline at workflow request") break @@ -154,7 +157,7 @@ async def _run_pipeline( context.current_workflow = None context.stats.total_runtime = time.time() - start_time logger.info("Indexing pipeline complete.") - + # Log total LLM usage if context.stats.total_llm_calls > 0: retry_part = ( @@ -163,13 +166,13 @@ async def _run_pipeline( else "" ) logger.info( - f"Total LLM usage: " - f"{context.stats.total_llm_calls} calls, " - f"{context.stats.total_prompt_tokens} prompt tokens, " - f"{context.stats.total_completion_tokens} completion tokens" - f"{retry_part}" + "Total LLM usage: %d calls, %d prompt tokens, %d completion tokens%s", + context.stats.total_llm_calls, + context.stats.total_prompt_tokens, + context.stats.total_completion_tokens, + retry_part, ) - + await _dump_json(context) except Exception as e: diff --git a/graphrag/index/utils/llm_context.py b/graphrag/index/utils/llm_context.py index 765b273151..dc8c9d5d03 100644 --- a/graphrag/index/utils/llm_context.py +++ b/graphrag/index/utils/llm_context.py @@ -11,20 +11,20 @@ def inject_llm_context(context: Any) -> None: """Inject pipeline context into ModelManager for LLM usage tracking. - + This helper function sets up LLM usage tracking for the current workflow. - + Args ---- context: The PipelineRunContext containing stats and workflow information. 
- + Example ------- from graphrag.index.utils.llm_context import inject_llm_context - + async def run_workflow(config, context): inject_llm_context(context) # Enable LLM usage tracking - + Notes ----- - This function is idempotent - calling it multiple times is safe @@ -32,9 +32,9 @@ async def run_workflow(config, context): - Only affects LLM models created via ModelManager """ from graphrag.language_model.manager import ModelManager - + try: ModelManager().set_pipeline_context(context) - except Exception as e: + except Exception as e: # noqa: BLE001 # Log warning but don't break workflow logger.warning("Failed to inject LLM context into ModelManager: %s", e) diff --git a/graphrag/index/workflows/extract_graph.py b/graphrag/index/workflows/extract_graph.py index aff0c07e63..62f13aa0f8 100644 --- a/graphrag/index/workflows/extract_graph.py +++ b/graphrag/index/workflows/extract_graph.py @@ -31,7 +31,7 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: extract_graph") - + text_units = await load_table_from_storage("text_units", context.output_storage) extract_graph_llm_settings = config.get_language_model_config( diff --git a/graphrag/language_model/manager.py b/graphrag/language_model/manager.py index baafcafe48..4b19bc6b4d 100644 --- a/graphrag/language_model/manager.py +++ b/graphrag/language_model/manager.py @@ -43,15 +43,9 @@ def set_pipeline_context(self, context: Any) -> None: self._pipeline_context = context # Update existing models that support context injection for model in self.chat_models.values(): - try: - getattr(model, "set_pipeline_context")(context) - except AttributeError: - pass + model.set_pipeline_context(context) for model in self.embedding_models.values(): - try: - getattr(model, "set_pipeline_context")(context) - except AttributeError: - pass + model.set_pipeline_context(context) @classmethod def get_instance(cls) -> ModelManager: @@ -73,10 +67,7 @@ def register_chat( model = ModelFactory.create_chat_model(model_type, **chat_kwargs) # Inject pipeline context if available if self._pipeline_context: - try: - getattr(model, "set_pipeline_context")(self._pipeline_context) - except AttributeError: - pass # Model doesn't support context injection + model.set_pipeline_context(self._pipeline_context) self.chat_models[name] = model return self.chat_models[name] @@ -95,10 +86,7 @@ def register_embedding( model = ModelFactory.create_embedding_model(model_type, **embedding_kwargs) # Inject pipeline context if available if self._pipeline_context: - try: - getattr(model, "set_pipeline_context")(self._pipeline_context) - except AttributeError: - pass # Model doesn't support context injection + model.set_pipeline_context(self._pipeline_context) self.embedding_models[name] = model return self.embedding_models[name] diff --git a/graphrag/language_model/protocol/base.py b/graphrag/language_model/protocol/base.py index 74cd38746e..f4c47b4760 100644 --- a/graphrag/language_model/protocol/base.py +++ b/graphrag/language_model/protocol/base.py @@ -24,6 +24,10 @@ class EmbeddingModel(Protocol): config: LanguageModelConfig """Passthrough of the config used to create the model instance.""" + def set_pipeline_context(self, context: Any) -> None: + """Set pipeline context for LLM usage tracking (optional).""" + ... 
+ async def aembed_batch( self, text_list: list[str], **kwargs: Any ) -> list[list[float]]: @@ -94,6 +98,10 @@ class ChatModel(Protocol): config: LanguageModelConfig """Passthrough of the config used to create the model instance.""" + def set_pipeline_context(self, context: Any) -> None: + """Set pipeline context for LLM usage tracking (optional).""" + ... + async def achat( self, prompt: str, history: list | None = None, **kwargs: Any ) -> ModelResponse: diff --git a/graphrag/language_model/providers/litellm/chat_model.py b/graphrag/language_model/providers/litellm/chat_model.py index 31555d7616..c7aaf003fd 100644 --- a/graphrag/language_model/providers/litellm/chat_model.py +++ b/graphrag/language_model/providers/litellm/chat_model.py @@ -233,11 +233,8 @@ def set_pipeline_context(self, context: Any) -> None: """Set the pipeline context for LLM usage tracking.""" self._pipeline_context = context # Propagate into retry service if available - try: - if self._retry_service is not None: - getattr(self._retry_service, "set_pipeline_context")(context) - except AttributeError: - pass + if self._retry_service is not None: + self._retry_service.set_pipeline_context(context) def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Get model arguments supported by litellm.""" @@ -372,7 +369,9 @@ async def achat_stream( tokenizer = get_tokenizer(model_config=self.config) # Calculate prompt text - prompt_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + prompt_text = "\n".join([ + f"{msg['role']}: {msg['content']}" for msg in messages + ]) prompt_tokens = len(tokenizer.encode(prompt_text)) completion_tokens = len(tokenizer.encode(full_content)) diff --git a/graphrag/language_model/providers/litellm/embedding_model.py b/graphrag/language_model/providers/litellm/embedding_model.py index 9ec20f48ed..34e506aaa6 100644 --- a/graphrag/language_model/providers/litellm/embedding_model.py +++ b/graphrag/language_model/providers/litellm/embedding_model.py @@ -197,11 +197,8 @@ def set_pipeline_context(self, context: Any) -> None: """Set the pipeline context for LLM usage tracking.""" self._pipeline_context = context # Propagate into retry service if available - try: - if self._retry_service is not None: - getattr(self._retry_service, "set_pipeline_context")(context) - except AttributeError: - pass + if self._retry_service is not None: + self._retry_service.set_pipeline_context(context) def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Get model arguments supported by litellm.""" diff --git a/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py b/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py index 36bb6c77f2..9769ac128f 100644 --- a/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py @@ -46,8 +46,7 @@ def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: try: while True: try: - result = func(**kwargs) - return result + return func(**kwargs) except Exception as e: if retries >= self._max_retries: logger.exception( @@ -74,8 +73,7 @@ async def aretry( try: while True: try: - result = await func(**kwargs) - return result + return await func(**kwargs) except Exception as e: if retries >= self._max_retries: logger.exception( diff --git a/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py b/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py 
index 97258e368e..0a839fb170 100644 --- a/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py @@ -44,8 +44,7 @@ def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: try: while True: try: - result = func(**kwargs) - return result + return func(**kwargs) except Exception as e: if retries >= self._max_retries: logger.exception( @@ -72,8 +71,7 @@ async def aretry( try: while True: try: - result = await func(**kwargs) - return result + return await func(**kwargs) except Exception as e: if retries >= self._max_retries: logger.exception( diff --git a/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py b/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py index 295f3b3832..d4d84f9d45 100644 --- a/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py @@ -34,8 +34,7 @@ def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: try: while True: try: - result = func(**kwargs) - return result + return func(**kwargs) except Exception as e: if retries >= self._max_retries: logger.exception( @@ -59,8 +58,7 @@ async def aretry( try: while True: try: - result = await func(**kwargs) - return result + return await func(**kwargs) except Exception as e: if retries >= self._max_retries: logger.exception( diff --git a/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py b/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py index e931ddd01d..b4759b9700 100644 --- a/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py @@ -43,8 +43,7 @@ def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: try: while True: try: - result = func(**kwargs) - return result + return func(**kwargs) except Exception as e: if retries >= self._max_retries: logger.exception( @@ -70,8 +69,7 @@ async def aretry( try: while True: try: - result = await func(**kwargs) - return result + return await func(**kwargs) except Exception as e: if retries >= self._max_retries: logger.exception( diff --git a/graphrag/language_model/providers/litellm/services/retry/retry.py b/graphrag/language_model/providers/litellm/services/retry/retry.py index d1195ba30e..360ef9375b 100644 --- a/graphrag/language_model/providers/litellm/services/retry/retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/retry.py @@ -39,11 +39,7 @@ def _record_retries(self, retry_count: int) -> None: """ if self._pipeline_context is not None and retry_count > 0: - try: - self._pipeline_context.record_llm_retries(retry_count) - except AttributeError: - # Context doesn't support retry tracking, skip silently - pass + self._pipeline_context.record_llm_retries(retry_count) @abstractmethod def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: From e720edd620ef9a9c27c57cd840a9ee1bd3351f75 Mon Sep 17 00:00:00 2001 From: June Feng Date: Thu, 16 Oct 2025 11:53:13 -0700 Subject: [PATCH 4/4] fix hasattr access with linting requirements --- graphrag/language_model/manager.py | 14 ++++++++------ graphrag/language_model/protocol/base.py | 8 -------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/graphrag/language_model/manager.py b/graphrag/language_model/manager.py index 4b19bc6b4d..a25a399366 
100644 --- a/graphrag/language_model/manager.py +++ b/graphrag/language_model/manager.py @@ -43,9 +43,11 @@ def set_pipeline_context(self, context: Any) -> None: self._pipeline_context = context # Update existing models that support context injection for model in self.chat_models.values(): - model.set_pipeline_context(context) + if hasattr(model, "set_pipeline_context"): + model.set_pipeline_context(context) # type: ignore[attr-defined] for model in self.embedding_models.values(): - model.set_pipeline_context(context) + if hasattr(model, "set_pipeline_context"): + model.set_pipeline_context(context) # type: ignore[attr-defined] @classmethod def get_instance(cls) -> ModelManager: @@ -66,8 +68,8 @@ def register_chat( chat_kwargs["name"] = name model = ModelFactory.create_chat_model(model_type, **chat_kwargs) # Inject pipeline context if available - if self._pipeline_context: - model.set_pipeline_context(self._pipeline_context) + if self._pipeline_context and hasattr(model, "set_pipeline_context"): + model.set_pipeline_context(self._pipeline_context) # type: ignore[attr-defined] self.chat_models[name] = model return self.chat_models[name] @@ -85,8 +87,8 @@ def register_embedding( embedding_kwargs["name"] = name model = ModelFactory.create_embedding_model(model_type, **embedding_kwargs) # Inject pipeline context if available - if self._pipeline_context: - model.set_pipeline_context(self._pipeline_context) + if self._pipeline_context and hasattr(model, "set_pipeline_context"): + model.set_pipeline_context(self._pipeline_context) # type: ignore[attr-defined] self.embedding_models[name] = model return self.embedding_models[name] diff --git a/graphrag/language_model/protocol/base.py b/graphrag/language_model/protocol/base.py index f4c47b4760..74cd38746e 100644 --- a/graphrag/language_model/protocol/base.py +++ b/graphrag/language_model/protocol/base.py @@ -24,10 +24,6 @@ class EmbeddingModel(Protocol): config: LanguageModelConfig """Passthrough of the config used to create the model instance.""" - def set_pipeline_context(self, context: Any) -> None: - """Set pipeline context for LLM usage tracking (optional).""" - ... - async def aembed_batch( self, text_list: list[str], **kwargs: Any ) -> list[list[float]]: @@ -98,10 +94,6 @@ class ChatModel(Protocol): config: LanguageModelConfig """Passthrough of the config used to create the model instance.""" - def set_pipeline_context(self, context: Any) -> None: - """Set pipeline context for LLM usage tracking (optional).""" - ... - async def achat( self, prompt: str, history: list | None = None, **kwargs: Any ) -> ModelResponse:
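A minimal usage sketch (not part of the patch above): reading the new counters after an indexing run. It assumes a PipelineRunContext instance named context is still in scope once _run_pipeline has finished, and summarize_llm_usage is a hypothetical helper; the field and key names mirror the additions to graphrag/index/typing/stats.py in PATCH 1/4.

# Hypothetical helper, not part of the patch: prints the totals and the
# per-workflow breakdown recorded via PipelineRunContext.record_llm_usage /
# record_llm_retries during an indexing run.
from graphrag.index.typing.context import PipelineRunContext


def summarize_llm_usage(context: PipelineRunContext) -> None:
    stats = context.stats
    print(
        f"Total LLM usage: {stats.total_llm_calls} calls, "
        f"{stats.total_prompt_tokens} prompt tokens, "
        f"{stats.total_completion_tokens} completion tokens, "
        f"{stats.total_llm_retries} retries"
    )
    for workflow, usage in stats.llm_usage_by_workflow.items():
        print(
            f"  {workflow}: {usage['llm_calls']} calls, "
            f"{usage['prompt_tokens']} prompt tokens, "
            f"{usage['completion_tokens']} completion tokens, "
            f"{usage['retries']} retries"
        )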