diff --git a/.semversioner/0.3.3.json b/.semversioner/0.3.3.json index 99726199c..4b8251685 100644 --- a/.semversioner/0.3.3.json +++ b/.semversioner/0.3.3.json @@ -1,66 +1,66 @@ -{ - "changes": [ - { - "description": "Add entrypoints for incremental indexing", - "type": "patch" - }, - { - "description": "Clean up and organize run index code", - "type": "patch" - }, - { - "description": "Consistent config loading. Resolves #99 and Resolves #1049", - "type": "patch" - }, - { - "description": "Fix circular dependency when running prompt tune api directly", - "type": "patch" - }, - { - "description": "Fix default settings for embedding", - "type": "patch" - }, - { - "description": "Fix img for auto tune", - "type": "patch" - }, - { - "description": "Fix img width", - "type": "patch" - }, - { - "description": "Fixed a bug in prompt tuning process", - "type": "patch" - }, - { - "description": "Refactor text unit build at local search", - "type": "patch" - }, - { - "description": "Update Prompt Tuning docs", - "type": "patch" - }, - { - "description": "Update create_pipeline_config.py", - "type": "patch" - }, - { - "description": "Update prompt tune command in docs", - "type": "patch" - }, - { - "description": "add querying from azure blob storage", - "type": "patch" - }, - { - "description": "fix setting base_dir to full paths when not using file system.", - "type": "patch" - }, - { - "description": "fix strategy config in entity_extraction", - "type": "patch" - } - ], - "created_at": "2024-09-10T19:51:24+00:00", - "version": "0.3.3" +{ + "changes": [ + { + "description": "Add entrypoints for incremental indexing", + "type": "patch" + }, + { + "description": "Clean up and organize run index code", + "type": "patch" + }, + { + "description": "Consistent config loading. 
Resolves #99 and Resolves #1049", + "type": "patch" + }, + { + "description": "Fix circular dependency when running prompt tune api directly", + "type": "patch" + }, + { + "description": "Fix default settings for embedding", + "type": "patch" + }, + { + "description": "Fix img for auto tune", + "type": "patch" + }, + { + "description": "Fix img width", + "type": "patch" + }, + { + "description": "Fixed a bug in prompt tuning process", + "type": "patch" + }, + { + "description": "Refactor text unit build at local search", + "type": "patch" + }, + { + "description": "Update Prompt Tuning docs", + "type": "patch" + }, + { + "description": "Update create_pipeline_config.py", + "type": "patch" + }, + { + "description": "Update prompt tune command in docs", + "type": "patch" + }, + { + "description": "add querying from azure blob storage", + "type": "patch" + }, + { + "description": "fix setting base_dir to full paths when not using file system.", + "type": "patch" + }, + { + "description": "fix strategy config in entity_extraction", + "type": "patch" + } + ], + "created_at": "2024-09-10T19:51:24+00:00", + "version": "0.3.3" } \ No newline at end of file diff --git a/.semversioner/next-release/minor-20251016174711346097.json b/.semversioner/next-release/minor-20251016174711346097.json new file mode 100644 index 000000000..8bbc87b4e --- /dev/null +++ b/.semversioner/next-release/minor-20251016174711346097.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Add comprehensive LLM usage tracking (tokens, calls, retries) for indexing stage with per-workflow breakdown" +} diff --git a/graphrag/index/run/run_pipeline.py b/graphrag/index/run/run_pipeline.py index f652db7ac..dfd825ea4 100644 --- a/graphrag/index/run/run_pipeline.py +++ b/graphrag/index/run/run_pipeline.py @@ -19,6 +19,7 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.pipeline import Pipeline from graphrag.index.typing.pipeline_run_result import PipelineRunResult +from graphrag.index.utils.llm_context import inject_llm_context from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.api import create_cache_from_config, create_storage_from_config from graphrag.utils.storage import load_table_from_storage, write_table_to_storage @@ -116,6 +117,12 @@ async def _run_pipeline( logger.info("Executing pipeline...") for name, workflow_function in pipeline.run(): last_workflow = name + # Set current workflow for LLM usage tracking + context.current_workflow = name + + # Inject pipeline context for LLM usage tracking + inject_llm_context(context) + context.callbacks.workflow_start(name, None) work_time = time.time() result = await workflow_function(config, context) @@ -124,12 +131,48 @@ async def _run_pipeline( workflow=name, result=result.result, state=context.state, errors=None ) context.stats.workflows[name] = {"overall": time.time() - work_time} + + # Log LLM usage for this workflow if available + if name in context.stats.llm_usage_by_workflow: + usage = context.stats.llm_usage_by_workflow[name] + retry_part = ( + f", {usage['retries']} retries" + if usage.get("retries", 0) > 0 + else "" + ) + logger.info( + "Workflow %s LLM usage: %d calls, %d prompt tokens, %d completion tokens%s", + name, + usage["llm_calls"], + usage["prompt_tokens"], + usage["completion_tokens"], + retry_part, + ) + if result.stop: logger.info("Halting pipeline at workflow request") break + # Clear current workflow + context.current_workflow = None context.stats.total_runtime = time.time() - 
start_time logger.info("Indexing pipeline complete.") + + # Log total LLM usage + if context.stats.total_llm_calls > 0: + retry_part = ( + f", {context.stats.total_llm_retries} retries" + if context.stats.total_llm_retries > 0 + else "" + ) + logger.info( + "Total LLM usage: %d calls, %d prompt tokens, %d completion tokens%s", + context.stats.total_llm_calls, + context.stats.total_prompt_tokens, + context.stats.total_completion_tokens, + retry_part, + ) + await _dump_json(context) except Exception as e: diff --git a/graphrag/index/typing/context.py b/graphrag/index/typing/context.py index ef2e1f7ea..c3639ea9c 100644 --- a/graphrag/index/typing/context.py +++ b/graphrag/index/typing/context.py @@ -30,3 +30,66 @@ class PipelineRunContext: "Callbacks to be called during the pipeline run." state: PipelineState "Arbitrary property bag for runtime state, persistent pre-computes, or experimental features." + current_workflow: str | None = None + "Current workflow being executed (for LLM usage tracking)." + + def record_llm_usage( + self, + llm_calls: int = 0, + prompt_tokens: int = 0, + completion_tokens: int = 0, + ) -> None: + """ + Record LLM usage for the current workflow. + + Args + ---- + llm_calls: Number of LLM calls + prompt_tokens: Number of prompt tokens + completion_tokens: Number of completion tokens + """ + if self.current_workflow is None: + return + + # Update totals + self.stats.total_llm_calls += llm_calls + self.stats.total_prompt_tokens += prompt_tokens + self.stats.total_completion_tokens += completion_tokens + + # Update workflow-specific stats + if self.current_workflow not in self.stats.llm_usage_by_workflow: + self.stats.llm_usage_by_workflow[self.current_workflow] = { + "llm_calls": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "retries": 0, + } + + workflow_stats = self.stats.llm_usage_by_workflow[self.current_workflow] + workflow_stats["llm_calls"] += llm_calls + workflow_stats["prompt_tokens"] += prompt_tokens + workflow_stats["completion_tokens"] += completion_tokens + + def record_llm_retries(self, retry_count: int) -> None: + """Record LLM retry attempts for the current workflow. + + Args + ---- + retry_count: Number of retry attempts performed before a success. 
+ """ + if self.current_workflow is None or retry_count <= 0: + return + + # Update totals + self.stats.total_llm_retries += retry_count + + # Update workflow-specific stats + if self.current_workflow not in self.stats.llm_usage_by_workflow: + self.stats.llm_usage_by_workflow[self.current_workflow] = { + "llm_calls": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "retries": 0, + } + workflow_stats = self.stats.llm_usage_by_workflow[self.current_workflow] + workflow_stats["retries"] += retry_count diff --git a/graphrag/index/typing/stats.py b/graphrag/index/typing/stats.py index 271773600..c3298a095 100644 --- a/graphrag/index/typing/stats.py +++ b/graphrag/index/typing/stats.py @@ -23,3 +23,28 @@ class PipelineRunStats: workflows: dict[str, dict[str, float]] = field(default_factory=dict) """A dictionary of workflows.""" + + total_llm_calls: int = field(default=0) + """Total number of LLM calls across all workflows.""" + + total_prompt_tokens: int = field(default=0) + """Total prompt tokens used across all workflows.""" + + total_completion_tokens: int = field(default=0) + """Total completion tokens generated across all workflows.""" + + total_llm_retries: int = field(default=0) + """Total number of LLM retry attempts across all workflows (sum of failed attempts before each success).""" + + llm_usage_by_workflow: dict[str, dict[str, int]] = field(default_factory=dict) + """LLM usage breakdown by workflow. Structure: + { + "extract_graph": { + "llm_calls": 10, + "prompt_tokens": 5000, + "completion_tokens": 2000, + "retries": 3 + }, + ... + } + """ diff --git a/graphrag/index/utils/llm_context.py b/graphrag/index/utils/llm_context.py new file mode 100644 index 000000000..dc8c9d5d0 --- /dev/null +++ b/graphrag/index/utils/llm_context.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Helper utilities for LLM context injection in workflows.""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def inject_llm_context(context: Any) -> None: + """Inject pipeline context into ModelManager for LLM usage tracking. + + This helper function sets up LLM usage tracking for the current workflow. + + Args + ---- + context: The PipelineRunContext containing stats and workflow information. 
+ + Example + ------- + from graphrag.index.utils.llm_context import inject_llm_context + + async def run_workflow(config, context): + inject_llm_context(context) # Enable LLM usage tracking + + Notes + ----- + - This function is idempotent - calling it multiple times is safe + - Failures are logged but don't break workflows + - Only affects LLM models created via ModelManager + """ + from graphrag.language_model.manager import ModelManager + + try: + ModelManager().set_pipeline_context(context) + except Exception as e: # noqa: BLE001 + # Log warning but don't break workflow + logger.warning("Failed to inject LLM context into ModelManager: %s", e) diff --git a/graphrag/index/workflows/extract_graph.py b/graphrag/index/workflows/extract_graph.py index 592502f6d..62f13aa0f 100644 --- a/graphrag/index/workflows/extract_graph.py +++ b/graphrag/index/workflows/extract_graph.py @@ -31,6 +31,7 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: extract_graph") + text_units = await load_table_from_storage("text_units", context.output_storage) extract_graph_llm_settings = config.get_language_model_config( diff --git a/graphrag/language_model/manager.py b/graphrag/language_model/manager.py index bc41235dd..a25a39936 100644 --- a/graphrag/language_model/manager.py +++ b/graphrag/language_model/manager.py @@ -35,8 +35,20 @@ def __init__(self) -> None: if not hasattr(self, "_initialized"): self.chat_models: dict[str, ChatModel] = {} self.embedding_models: dict[str, EmbeddingModel] = {} + self._pipeline_context: Any = None # For LLM usage tracking self._initialized = True + def set_pipeline_context(self, context: Any) -> None: + """Set pipeline context for all models to enable LLM usage tracking.""" + self._pipeline_context = context + # Update existing models that support context injection + for model in self.chat_models.values(): + if hasattr(model, "set_pipeline_context"): + model.set_pipeline_context(context) # type: ignore[attr-defined] + for model in self.embedding_models.values(): + if hasattr(model, "set_pipeline_context"): + model.set_pipeline_context(context) # type: ignore[attr-defined] + @classmethod def get_instance(cls) -> ModelManager: """Return the singleton instance of LLMManager.""" @@ -54,9 +66,11 @@ def register_chat( **chat_kwargs: Additional parameters for instantiation. """ chat_kwargs["name"] = name - self.chat_models[name] = ModelFactory.create_chat_model( - model_type, **chat_kwargs - ) + model = ModelFactory.create_chat_model(model_type, **chat_kwargs) + # Inject pipeline context if available + if self._pipeline_context and hasattr(model, "set_pipeline_context"): + model.set_pipeline_context(self._pipeline_context) # type: ignore[attr-defined] + self.chat_models[name] = model return self.chat_models[name] def register_embedding( @@ -71,9 +85,11 @@ def register_embedding( **embedding_kwargs: Additional parameters for instantiation. 
""" embedding_kwargs["name"] = name - self.embedding_models[name] = ModelFactory.create_embedding_model( - model_type, **embedding_kwargs - ) + model = ModelFactory.create_embedding_model(model_type, **embedding_kwargs) + # Inject pipeline context if available + if self._pipeline_context and hasattr(model, "set_pipeline_context"): + model.set_pipeline_context(self._pipeline_context) # type: ignore[attr-defined] + self.embedding_models[name] = model return self.embedding_models[name] def get_chat_model(self, name: str) -> ChatModel | None: diff --git a/graphrag/language_model/providers/litellm/chat_model.py b/graphrag/language_model/providers/litellm/chat_model.py index 43cb7ece9..c7aaf003f 100644 --- a/graphrag/language_model/providers/litellm/chat_model.py +++ b/graphrag/language_model/providers/litellm/chat_model.py @@ -115,7 +115,7 @@ def _create_completions( model_config: "LanguageModelConfig", cache: "PipelineCache | None", cache_key_prefix: str, -) -> tuple[FixedModelCompletion, AFixedModelCompletion]: +) -> tuple[FixedModelCompletion, AFixedModelCompletion, Any | None]: """Wrap the base litellm completion function with the model configuration and additional features. Wrap the base litellm completion function with instance variables based on the model configuration. @@ -161,8 +161,9 @@ def _create_completions( tpm=tpm, ) + retry_service: Any | None = None if model_config.retry_strategy != "none": - completion, acompletion = with_retries( + completion, acompletion, retry_service = with_retries( sync_fn=completion, async_fn=acompletion, model_config=model_config, @@ -183,7 +184,7 @@ def _create_completions( async_fn=acompletion, ) - return (completion, acompletion) + return (completion, acompletion, retry_service) class LitellmModelOutput(BaseModel): @@ -221,9 +222,19 @@ def __init__( self.name = name self.config = config self.cache = cache.child(self.name) if cache else None - self.completion, self.acompletion = _create_completions( - config, self.cache, "chat" - ) + ( + self.completion, + self.acompletion, + self._retry_service, + ) = _create_completions(config, self.cache, "chat") + self._pipeline_context: Any = None # For LLM usage tracking + + def set_pipeline_context(self, context: Any) -> None: + """Set the pipeline context for LLM usage tracking.""" + self._pipeline_context = context + # Propagate into retry service if available + if self._retry_service is not None: + self._retry_service.set_pipeline_context(context) def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Get model arguments supported by litellm.""" @@ -285,6 +296,16 @@ async def achat( response = await self.acompletion(messages=messages, stream=False, **new_kwargs) # type: ignore + # Record LLM usage if pipeline context is available + if self._pipeline_context is not None and hasattr(response, "usage"): + usage = getattr(response, "usage", None) + if usage: + self._pipeline_context.record_llm_usage( + llm_calls=1, + prompt_tokens=getattr(usage, "prompt_tokens", 0), + completion_tokens=getattr(usage, "completion_tokens", 0), + ) + messages.append({ "role": "assistant", "content": response.choices[0].message.content or "", # type: ignore @@ -335,9 +356,30 @@ async def achat_stream( response = await self.acompletion(messages=messages, stream=True, **new_kwargs) # type: ignore + full_content = "" async for chunk in response: # type: ignore if chunk.choices and chunk.choices[0].delta.content: - yield chunk.choices[0].delta.content + content = chunk.choices[0].delta.content + full_content += content + yield 
content + + # Record LLM usage for streaming (estimate using tokenizer) + if self._pipeline_context is not None and full_content: + from graphrag.tokenizer.get_tokenizer import get_tokenizer + + tokenizer = get_tokenizer(model_config=self.config) + # Calculate prompt text + prompt_text = "\n".join([ + f"{msg['role']}: {msg['content']}" for msg in messages + ]) + prompt_tokens = len(tokenizer.encode(prompt_text)) + completion_tokens = len(tokenizer.encode(full_content)) + + self._pipeline_context.record_llm_usage( + llm_calls=1, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) def chat(self, prompt: str, history: list | None = None, **kwargs: Any) -> "MR": """ diff --git a/graphrag/language_model/providers/litellm/embedding_model.py b/graphrag/language_model/providers/litellm/embedding_model.py index 779222c37..34e506aaa 100644 --- a/graphrag/language_model/providers/litellm/embedding_model.py +++ b/graphrag/language_model/providers/litellm/embedding_model.py @@ -101,7 +101,7 @@ def _create_embeddings( model_config: "LanguageModelConfig", cache: "PipelineCache | None", cache_key_prefix: str, -) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]: +) -> tuple[FixedModelEmbedding, AFixedModelEmbedding, Any | None]: """Wrap the base litellm embedding function with the model configuration and additional features. Wrap the base litellm embedding function with instance variables based on the model configuration. @@ -147,8 +147,9 @@ def _create_embeddings( tpm=tpm, ) + retry_service: Any | None = None if model_config.retry_strategy != "none": - embedding, aembedding = with_retries( + embedding, aembedding, retry_service = with_retries( sync_fn=embedding, async_fn=aembedding, model_config=model_config, @@ -169,7 +170,7 @@ def _create_embeddings( async_fn=aembedding, ) - return (embedding, aembedding) + return (embedding, aembedding, retry_service) class LitellmEmbeddingModel: @@ -185,9 +186,19 @@ def __init__( self.name = name self.config = config self.cache = cache.child(self.name) if cache else None - self.embedding, self.aembedding = _create_embeddings( - config, self.cache, "embeddings" - ) + ( + self.embedding, + self.aembedding, + self._retry_service, + ) = _create_embeddings(config, self.cache, "embeddings") + self._pipeline_context: Any = None # For LLM usage tracking + + def set_pipeline_context(self, context: Any) -> None: + """Set the pipeline context for LLM usage tracking.""" + self._pipeline_context = context + # Propagate into retry service if available + if self._retry_service is not None: + self._retry_service.set_pipeline_context(context) def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]: """Get model arguments supported by litellm.""" @@ -218,6 +229,16 @@ async def aembed_batch( new_kwargs = self._get_kwargs(**kwargs) response = await self.aembedding(input=text_list, **new_kwargs) + # Record LLM usage if pipeline context is available + if self._pipeline_context is not None and hasattr(response, "usage"): + usage = getattr(response, "usage", None) + if usage: + self._pipeline_context.record_llm_usage( + llm_calls=1, + prompt_tokens=getattr(usage, "prompt_tokens", 0), + completion_tokens=0, # embeddings don't have completion tokens + ) + return [emb.get("embedding", []) for emb in response.data] async def aembed(self, text: str, **kwargs: Any) -> list[float]: @@ -235,6 +256,16 @@ async def aembed(self, text: str, **kwargs: Any) -> list[float]: new_kwargs = self._get_kwargs(**kwargs) response = await self.aembedding(input=[text], **new_kwargs) + # 
Record LLM usage if pipeline context is available + if self._pipeline_context is not None and hasattr(response, "usage"): + usage = getattr(response, "usage", None) + if usage: + self._pipeline_context.record_llm_usage( + llm_calls=1, + prompt_tokens=getattr(usage, "prompt_tokens", 0), + completion_tokens=0, # embeddings don't have completion tokens + ) + return ( response.data[0].get("embedding", []) if response.data and response.data[0] diff --git a/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py b/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py index 1279f9e82..fa3952e91 100644 --- a/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py +++ b/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py @@ -22,7 +22,7 @@ def with_retries( sync_fn: LitellmRequestFunc, async_fn: AsyncLitellmRequestFunc, model_config: "LanguageModelConfig", -) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]: +) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc, Any]: """ Wrap the synchronous and asynchronous request functions with retries. @@ -51,4 +51,4 @@ async def _wrapped_with_retries_async( ) -> Any: return await retry_service.aretry(func=async_fn, **kwargs) - return (_wrapped_with_retries, _wrapped_with_retries_async) + return (_wrapped_with_retries, _wrapped_with_retries_async, retry_service) diff --git a/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py b/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py index e008322be..9769ac128 100644 --- a/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/exponential_retry.py @@ -26,6 +26,7 @@ def __init__( jitter: bool = True, **kwargs: Any, ): + super().__init__(**kwargs) if max_retries <= 0: msg = "max_retries must be greater than 0." 
raise ValueError(msg) @@ -42,21 +43,24 @@ def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: """Retry a synchronous function.""" retries = 0 delay = 1.0 # Initial delay in seconds - while True: - try: - return func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + return func(**kwargs) + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay *= self._base_delay logger.exception( - f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay *= self._base_delay - logger.exception( - f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - time.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311 + time.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311 + finally: + self._record_retries(retries) async def aretry( self, @@ -66,18 +70,21 @@ async def aretry( """Retry an asynchronous function.""" retries = 0 delay = 1.0 # Initial delay in seconds - while True: - try: - return await func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + return await func(**kwargs) + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay *= self._base_delay logger.exception( - f"ExponentialRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay *= self._base_delay - logger.exception( - f"ExponentialRetry: Request failed, retrying, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - await asyncio.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311 + await asyncio.sleep(delay + (self._jitter * random.uniform(0, 1))) # noqa: S311 + finally: + self._record_retries(retries) diff --git a/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py b/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py index 97fbdbf9c..0a839fb17 100644 --- a/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/incremental_wait_retry.py @@ -24,6 +24,7 @@ def __init__( max_retries: int = 5, **kwargs: Any, ): + super().__init__(**kwargs) if max_retries <= 0: msg = "max_retries must be greater than 0." 
raise ValueError(msg) @@ -40,21 +41,24 @@ def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: """Retry a synchronous function.""" retries = 0 delay = 0.0 - while True: - try: - return func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + return func(**kwargs) + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay += self._increment logger.exception( - f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay += self._increment - logger.exception( - f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - time.sleep(delay) + time.sleep(delay) + finally: + self._record_retries(retries) async def aretry( self, @@ -64,18 +68,21 @@ async def aretry( """Retry an asynchronous function.""" retries = 0 delay = 0.0 - while True: - try: - return await func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + return await func(**kwargs) + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay += self._increment logger.exception( - f"IncrementalWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay += self._increment - logger.exception( - f"IncrementalWaitRetry: Request failed, retrying after incremental delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - await asyncio.sleep(delay) + await asyncio.sleep(delay) + finally: + self._record_retries(retries) diff --git a/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py b/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py index 088f45421..d4d84f9d4 100644 --- a/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/native_wait_retry.py @@ -21,6 +21,7 @@ def __init__( max_retries: int = 5, **kwargs: Any, ): + super().__init__(**kwargs) if max_retries <= 0: msg = "max_retries must be greater than 0." 
raise ValueError(msg) @@ -30,19 +31,22 @@ def __init__( def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: """Retry a synchronous function.""" retries = 0 - while True: - try: - return func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + return func(**kwargs) + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 logger.exception( - f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - logger.exception( - f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) + finally: + self._record_retries(retries) async def aretry( self, @@ -51,16 +55,19 @@ async def aretry( ) -> Any: """Retry an asynchronous function.""" retries = 0 - while True: - try: - return await func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + return await func(**kwargs) + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 logger.exception( - f"NativeRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - logger.exception( - f"NativeRetry: Request failed, immediately retrying, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) + finally: + self._record_retries(retries) diff --git a/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py b/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py index 603f439d1..b4759b970 100644 --- a/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/random_wait_retry.py @@ -25,6 +25,7 @@ def __init__( max_retries: int = 5, **kwargs: Any, ): + super().__init__(**kwargs) if max_retries <= 0: msg = "max_retries must be greater than 0." 
raise ValueError(msg) @@ -39,21 +40,24 @@ def __init__( def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any: """Retry a synchronous function.""" retries = 0 - while True: - try: - return func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + return func(**kwargs) + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay = random.uniform(0, self._max_retry_wait) # noqa: S311 logger.exception( - f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay = random.uniform(0, self._max_retry_wait) # noqa: S311 - logger.exception( - f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - time.sleep(delay) + time.sleep(delay) + finally: + self._record_retries(retries) async def aretry( self, @@ -62,18 +66,21 @@ async def aretry( ) -> Any: """Retry an asynchronous function.""" retries = 0 - while True: - try: - return await func(**kwargs) - except Exception as e: - if retries >= self._max_retries: + try: + while True: + try: + return await func(**kwargs) + except Exception as e: + if retries >= self._max_retries: + logger.exception( + f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + ) + raise + retries += 1 + delay = random.uniform(0, self._max_retry_wait) # noqa: S311 logger.exception( - f"RandomWaitRetry: Max retries exceeded, retries={retries}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 + f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 ) - raise - retries += 1 - delay = random.uniform(0, self._max_retry_wait) # noqa: S311 - logger.exception( - f"RandomWaitRetry: Request failed, retrying after random delay, retries={retries}, delay={delay}, max_retries={self._max_retries}, exception={e}", # noqa: G004, TRY401 - ) - await asyncio.sleep(delay) + await asyncio.sleep(delay) + finally: + self._record_retries(retries) diff --git a/graphrag/language_model/providers/litellm/services/retry/retry.py b/graphrag/language_model/providers/litellm/services/retry/retry.py index 4f53e598c..360ef9375 100644 --- a/graphrag/language_model/providers/litellm/services/retry/retry.py +++ b/graphrag/language_model/providers/litellm/services/retry/retry.py @@ -9,12 +9,37 @@ class Retry(ABC): - """LiteLLM Retry Abstract Base Class.""" + """LiteLLM Retry Abstract Base Class. + + Added lightweight pipeline context support to allow retry implementations + to record retry counts. + """ - @abstractmethod def __init__(self, /, **kwargs: Any): - msg = "Retry subclasses must implement the __init__ method." - raise NotImplementedError(msg) + self._pipeline_context: Any | None = None + + def set_pipeline_context(self, context: Any) -> None: + """Inject pipeline context (optional). + + The context is expected to expose a `record_llm_retries(int)` method. 
If no context is injected, retry tracking +        is skipped. +        """ +        self._pipeline_context = context + +    def _record_retries(self, retry_count: int) -> None: +        """Record retry attempts to pipeline context (if available). + +        This is a protected method intended for use by subclasses in their +        `finally` blocks to ensure retry counts are recorded regardless of +        success or failure. + +        Args +        ---- +        retry_count: Number of retry attempts performed. + +        """ +        if self._pipeline_context is not None and retry_count > 0: +            self._pipeline_context.record_llm_retries(retry_count)      @abstractmethod     def retry(self, func: Callable[..., Any], **kwargs: Any) -> Any:
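
Taken together, the changes above form a small tracking chain: the pipeline runner sets context.current_workflow and calls inject_llm_context(context), ModelManager pushes that context into each chat/embedding model and retry service, and those components call record_llm_usage / record_llm_retries back on the context. A minimal sketch of that flow, using only the APIs introduced in this diff; the workflow body and the literal token counts are illustrative placeholders, not actual call sites from the codebase:

from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.utils.llm_context import inject_llm_context


async def run_workflow(config, context: PipelineRunContext):
    # The runner assigns context.current_workflow before invoking the workflow;
    # inject_llm_context() then hands the context to ModelManager so every
    # registered chat/embedding model reports its usage automatically.
    inject_llm_context(context)

    # Code that calls a model outside the litellm providers can still record
    # usage by hand (numbers below are placeholders, not real measurements).
    context.record_llm_usage(llm_calls=1, prompt_tokens=120, completion_tokens=40)
    context.record_llm_retries(2)

    # Totals roll up on context.stats; per-workflow numbers accumulate under
    # the workflow name set by the runner.
    per_workflow = context.stats.llm_usage_by_workflow.get(
        context.current_workflow or "", {}
    )
    print(context.stats.total_llm_calls, per_workflow)

Note that record_llm_usage and record_llm_retries are no-ops when current_workflow is unset, which is why the runner clears it again once each workflow finishes.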