Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions rdagent/app/data_science/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
"""Number of failures tolerated before escalating to next timeout level (stage width). Every 'patience' failures, timeout increases by 'runner_timeout_increase_stage'"""
show_hard_limit: bool = True

#### hypothesis critique and rewrite
enable_hypo_critique_rewrite: bool = True
"""Enable hypothesis critique and rewrite stages for improving hypothesis quality"""
enable_scale_check: bool = False

#### mcp in coder
enable_context7: bool = True
"""enable the use of context7 as mcp to search for relevant documents of current implementation errors"""

#### enable runner code change summary
runner_enable_code_change_summary: bool = True

Expand Down
60 changes: 31 additions & 29 deletions rdagent/components/coder/data_science/pipeline/eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# test successfully running.
# (GPT) if it aligns with the spec & rationality of the spec.
import json
import asyncio
import concurrent.futures
import re
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -20,6 +21,7 @@
from rdagent.components.coder.data_science.conf import get_clear_ws_cmd, get_ds_env
from rdagent.components.coder.data_science.share.notebook import NotebookConverter
from rdagent.components.coder.data_science.utils import remove_eda_part
from rdagent.components.mcp import query_context7
from rdagent.core.experiment import FBWorkspace, Task
from rdagent.log import rdagent_logger as logger
from rdagent.scenarios.data_science.test_eval import get_test_eval
Expand Down Expand Up @@ -76,7 +78,7 @@ def __str__(self) -> str:

if self.error_message is not None:
# Check if error_message contains Context7 documentation results
if "### API Documentation Reference:" in self.error_message:
if "### Relevant Documentation Reference:" in self.error_message:
base_str += f"-------------------Error Analysis & Documentation Search Results ------------------\n{self.error_message}\n"
else:
base_str += f"-------------------Error Message------------------\n{self.error_message}\n"
Expand Down Expand Up @@ -270,8 +272,8 @@ def evaluate(
else:
eda_output = implementation.file_dict.get("EDA.md", None)

# extract enable_mcp_documentation_search from data science configuration
enable_mcp_documentation_search = DS_RD_SETTING.enable_mcp_documentation_search
# extract enable_context7 from setting
enable_context7 = DS_RD_SETTING.enable_context7

queried_similar_successful_knowledge = (
queried_knowledge.task_to_similar_task_successful_knowledge[target_task.get_task_information()]
Expand All @@ -282,7 +284,7 @@ def evaluate(
system_prompt = T(".prompts:pipeline_eval.system").r(
is_sub_enabled=test_eval.is_sub_enabled(self.scen.competition),
debug_mode=DS_RD_SETTING.sample_data_by_LLM,
enable_mcp_documentation_search=enable_mcp_documentation_search,
enable_context7=enable_context7,
mle_check=DS_RD_SETTING.sample_data_by_LLM,
queried_similar_successful_knowledge=queried_similar_successful_knowledge,
)
Expand All @@ -303,33 +305,33 @@ def evaluate(
init_kwargs_update_func=PipelineSingleFeedback.val_and_update_init_dict,
)

# judge whether we should perform documentation search
do_documentation_search = enable_mcp_documentation_search and wfb.requires_documentation_search

if do_documentation_search:
# Use MCPAgent for clean, user-friendly interface
if enable_context7 and wfb.requires_documentation_search is True:
try:
# Create agent targeting Context7 service - model config comes from mcp_config.json
doc_agent = DocAgent()

# Synchronous query - perfect for evaluation context
if wfb.error_message: # Type safety check
context7_result = doc_agent.query(query=wfb.error_message)

if context7_result:
logger.info("Context7: Documentation search completed successfully")
wfb.error_message += f"\n\n### API Documentation Reference:\nThe following API documentation was retrieved based on the error. This provides factual information about API changes or parameter specifications only:\n\n{context7_result}"
else:
logger.warning("Context7: Documentation search failed or no results found")
else:
logger.warning("Context7: No error message to search for")

# TODO: confirm what exception will be raised when timeout
# except concurrent.futures.TimeoutError:
# logger.error("Context7: Query timed out after 180 seconds")
def run_context7_sync():
"""Run Context7 query in a new event loop"""
# Create new event loop to avoid conflicts with existing loop
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
query_context7(error_message=wfb.error_message, full_code=implementation.all_codes)
)
finally:
new_loop.close()

# Execute in thread pool to avoid event loop conflicts
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_context7_sync)
context7_result = future.result(timeout=120) # 120s timeout, sufficient time for retry mechanism

if context7_result:
logger.info("Context7: Documentation search completed successfully")
wfb.error_message += f"\n\n### API Documentation Reference:\nThe following API documentation was retrieved based on the error. This provides factual information about API changes or parameter specifications only:\n\n{context7_result}"
else:
logger.warning("Context7: Documentation search failed or no results found")
except Exception as e:
error_msg = str(e) if str(e) else type(e).__name__
logger.error(f"Context7: Query failed - {error_msg}")
logger.error(f"Context7: Query failed - {str(e)}")

if score_ret_code != 0 and wfb.final_decision is True:
wfb.final_decision = False
Expand Down
13 changes: 7 additions & 6 deletions rdagent/components/coder/data_science/pipeline/prompts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,12 @@ pipeline_eval:
- Notes:
- Model performance is not evaluated in this step; focus solely on successful execution.
- Warnings are acceptable if they do not interfere with successful code execution.
- **Environment Constraint**: The coding environment is fixed and pre-configured. No package installation or modification is allowed. Code must use only existing pre-installed packages.
- If the code execute successfully:
- Proceed to Step 2.
- Proceed to Step 2 and overlook the remaining steps in Step 1.
- If the code does not execute successfully:
- Set the "final_decision" to false.
{% if enable_mcp_documentation_search %}
{% if enable_context7 %}
- Given that my package/environment is fixed and unchangeable, first you should go through the code and the execution output; if the problem could be solved by looking up the official documentation to confirm feature/API availability, compatible usage, or official alternatives in the fixed environment, set the "requires_documentation_search" to true.
{% endif %}
- Write complete analysis in the "execution" field.
Expand Down Expand Up @@ -314,14 +315,14 @@ pipeline_eval:
Please respond with your feedback in the following JSON format without anything else.
```json
{
{% if enable_mcp_documentation_search %}
{% if enable_context7 %}
"requires_documentation_search": <true/false>,
{% endif %}"execution": "Describe whether the code executed successfully. Include any errors or issues encountered, and append all error messages and full traceback details without summarizing or omitting any information. If errors occurred, analyze the root causes: (1) Are they fundamental algorithmic/approach issues, or (2) Implementation details that can be easily fixed, or (3) Environment/dependency problems?",
{% endif %}
"execution": "Describe whether the code executed successfully. Include any errors or issues encountered, and append all error messages and full traceback details without summarizing or omitting any information. If errors occurred, analyze the root causes: (1) Are they fundamental algorithmic/approach issues, or (2) Implementation details that can be easily fixed, or (3) Environment/dependency problems?",
"return_checking": "Examine the generated files by cross-referencing the code logic and stdout output. Verify: (1) Format matches required submission format (index, column names, CSV content); (2) **File generation authenticity**: Is the file genuinely produced by successful model execution, or is it a result of exception handling/fallback mechanisms? Cite specific code sections and stdout evidence.",
"code": "Begin explicitly with [Code analysis] or [Evaluation error]. Provide structured analysis: (1) **Technical Appropriateness**: Does the chosen approach (algorithms, data processing, validation strategy) match this problem's data characteristics and competition requirements? (2) **Effective Components**: What specific parts work well and why are they effective for this problem type? (3) **Issues & Improvements**: Identify concrete problems and suggest actionable improvement directions (without providing actual code). (4) **Code Quality**: Assess readability, structure, and adherence to specifications.",
{% if enable_mcp_documentation_search %}
"error_message": "If the code execution has problems, extract the error information in the following format, otherwise set to empty string: ### TRACEBACK: <full relevant traceback extracted from execution output> ### SUPPLEMENTARY_INFO: <only if TRACEBACK is unclear - copy exact code fragments: import statements, variable=value assignments, function calls with parameters as they appear in code>",
{% endif %}"final_decision": <true/false>
"final_decision": <true/false>
}
```

Expand Down
8 changes: 8 additions & 0 deletions rdagent/components/mcp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""MCP (Model Context Protocol) integration for RD-Agent.

This module provides context7 functionality for documentation search.
"""

from .context7 import query_context7

__all__ = ["query_context7"]
225 changes: 225 additions & 0 deletions rdagent/components/mcp/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
"""MCP cache management module.

Provides general caching functionality for MCP tools and query result caching.
Reuses RD-Agent's existing SQLite cache system with permanent caching strategy.
"""

import hashlib
from typing import Any, Dict, Optional

from rdagent.log import rdagent_logger as logger
from rdagent.oai.backend.base import SQliteLazyCache
from rdagent.oai.llm_conf import LLM_SETTINGS


class MCPCache:
    """MCP cache manager based on existing SQLite cache system.

    Uses permanent caching strategy, consistent with LITELLM: entries never
    expire and live in the shared SQLite prompt-cache file.
    """

    def __init__(self):
        """Initialize cache manager.

        Uses permanent caching without expiration time.
        """
        # Reuse the shared SQLite prompt cache so MCP entries live alongside LITELLM's.
        self._cache = SQliteLazyCache(cache_location=LLM_SETTINGS.prompt_cache_path)
        # In-memory hit/miss counters; reset whenever a new instance is created.
        self._stats = {"tools_hits": 0, "tools_misses": 0, "query_hits": 0, "query_misses": 0}

    @staticmethod
    def _query_cache_key(error_message: str) -> str:
        """Build the SQLite key for a query result: 'mcp_query:' + md5 of the message.

        md5 is used only as a stable, short key — not for security.
        """
        return f"mcp_query:{hashlib.md5(error_message.encode('utf-8')).hexdigest()}"

    def _get_cached_result(self, cache_key: str) -> Optional[str]:
        """Get result from SQLite cache."""
        return self._cache.chat_get(cache_key)

    def _set_cached_result(self, cache_key: str, result: str) -> None:
        """Set SQLite cache result."""
        self._cache.chat_set(cache_key, result)

    def get_tools(self, mcp_url: str) -> Optional[Any]:
        """Get cached tools.

        Args:
            mcp_url: MCP service URL

        Returns:
            Cached tools list, returns None if cache miss
        """
        # Tool object serialization is complex, temporarily not implementing tool caching
        self._stats["tools_misses"] += 1
        logger.info(f"Tools cache miss for URL: {mcp_url} (tools caching disabled)")
        return None

    def set_tools(self, mcp_url: str, tools: Any) -> None:
        """Set tools cache.

        Args:
            mcp_url: MCP service URL
            tools: Tools list to cache (currently unused)
        """
        # Temporarily not caching tool objects as they contain complex objects that are difficult to serialize
        logger.info(f"Tools caching skipped for URL: {mcp_url} (complex objects)")

    def get_query_result(self, error_message: str) -> Optional[str]:
        """Get cached query result.

        Args:
            error_message: Error message

        Returns:
            Cached query result, returns None if cache miss
        """
        cache_key = self._query_cache_key(error_message)
        cached_result = self._get_cached_result(cache_key)

        # Empty string counts as a miss — clear_query_cache() "deletes" by storing "".
        if cached_result:
            self._stats["query_hits"] += 1
            logger.info(f"Query cache hit for key: {cache_key[-8:]}...")
            return cached_result

        self._stats["query_misses"] += 1
        logger.info(f"Query cache miss for key: {cache_key[-8:]}...")
        return None

    def set_query_result(self, error_message: str, result: str) -> None:
        """Set query result cache.

        Args:
            error_message: Error message
            result: Query result
        """
        cache_key = self._query_cache_key(error_message)
        self._set_cached_result(cache_key, result)
        logger.info(f"Query result cached for key: {cache_key[-8:]}...")

    def clear_cache(self) -> int:
        """Clear all MCP cache.

        Returns:
            Number of cleared entries (always 0 — see note below).
        """
        cleared_count = 0

        # Clear all cache keys with mcp_ prefix
        # Note: This requires traversing the entire database, performance may be poor
        logger.warning("Clearing all MCP cache entries...")

        # Due to SQLite interface limitations, we cannot directly traverse keys, so provide hints
        logger.info(
            "To completely clear MCP cache, please delete the SQLite cache file or use clear_mcp_cache_by_pattern()"
        )

        return cleared_count

    def clear_query_cache(self, error_message: Optional[str] = None) -> None:
        """Clear query cache.

        Args:
            error_message: If specified, only clear cache for specific error message; otherwise clear all query cache
        """
        if error_message:
            # Clear cache for specific query
            cache_key = self._query_cache_key(error_message)
            # SQLite has no direct delete method, we set to None to "delete"
            self._set_cached_result(cache_key, "")  # Set to empty string to indicate deletion
            logger.info(f"Cleared cache for specific query: {cache_key[-8:]}...")
        else:
            logger.info("To clear all query cache, please use clear_all_mcp_cache() or delete the cache file")

    def get_cache_info(self) -> Dict[str, Any]:
        """Get cache information (file location, stats, and backing-store type)."""
        stats = self.get_cache_stats()
        cache_file = getattr(self._cache, "cache_location", "unknown")

        info = {"cache_file": cache_file, "stats": stats, "cache_type": "SQLite (shared with LITELLM)"}

        logger.info(f"Cache info: {info}")
        return info

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics (hit/miss counts and hit rates for tools and queries)."""
        total_tools = self._stats["tools_hits"] + self._stats["tools_misses"]
        total_queries = self._stats["query_hits"] + self._stats["query_misses"]

        # max(..., 1) avoids ZeroDivisionError before any lookups have happened.
        return {
            "tools_cache": {
                "hits": self._stats["tools_hits"],
                "misses": self._stats["tools_misses"],
                "hit_rate": self._stats["tools_hits"] / max(total_tools, 1),
                "size": "N/A (SQLite)",
            },
            "query_cache": {
                "hits": self._stats["query_hits"],
                "misses": self._stats["query_misses"],
                "hit_rate": self._stats["query_hits"] / max(total_queries, 1),
                "size": "N/A (SQLite)",
            },
        }

    def log_cache_stats(self) -> None:
        """Log cache statistics to log."""
        stats = self.get_cache_stats()
        logger.info(
            f"Cache stats - Tools: {stats['tools_cache']['hits']}/{stats['tools_cache']['hits'] + stats['tools_cache']['misses']} hits "
            f"({stats['tools_cache']['hit_rate']:.2%}), "
            f"Queries: {stats['query_cache']['hits']}/{stats['query_cache']['hits'] + stats['query_cache']['misses']} hits "
            f"({stats['query_cache']['hit_rate']:.2%})"
        )


# Lazily-created process-wide cache singleton.
_global_cache: Optional[MCPCache] = None


def get_mcp_cache() -> MCPCache:
    """Get global MCP cache instance.

    Returns:
        MCP cache instance (permanent cache)
    """
    global _global_cache
    if _global_cache is not None:
        return _global_cache
    _global_cache = MCPCache()
    return _global_cache


def clear_mcp_cache_by_file() -> bool:
    """Clear all cache by deleting the SQLite cache file.

    Note: This will clear all cache, including LITELLM cache, since both
    share the same SQLite file (LLM_SETTINGS.prompt_cache_path)!

    Returns:
        True if the cache file was deleted or did not exist, False on failure.
    """
    import os

    global _global_cache

    # LLM_SETTINGS is already imported at module level; no local re-import needed.
    cache_file = LLM_SETTINGS.prompt_cache_path
    if not os.path.exists(cache_file):
        logger.info(f"Cache file does not exist: {cache_file}")
        return True

    try:
        os.remove(cache_file)
    except OSError as e:
        # Narrow to OSError: that is what os.remove raises on failure.
        logger.error(f"Failed to delete cache file {cache_file}: {e}")
        return False

    logger.info(f"Successfully deleted cache file: {cache_file}")
    # Drop the singleton so the next get_mcp_cache() reopens a fresh cache file.
    _global_cache = None
    return True


def get_cache_file_info() -> dict:
    """Get cache file information.

    Returns:
        dict with keys: ``file_path``, ``exists``, ``size_mb`` (rounded to 2
        decimals, 0 when missing), ``modified_time`` (epoch seconds or None).
    """
    import os

    # LLM_SETTINGS is already imported at module level; no local re-import needed.
    cache_file = LLM_SETTINGS.prompt_cache_path

    # EAFP: a single os.stat avoids the exists()/stat() race on the file.
    try:
        stat = os.stat(cache_file)
    except OSError:
        info = {"file_path": cache_file, "exists": False, "size_mb": 0, "modified_time": None}
    else:
        size_mb = stat.st_size / (1024 * 1024)
        info = {"file_path": cache_file, "exists": True, "size_mb": round(size_mb, 2), "modified_time": stat.st_mtime}

    logger.info(f"Cache file info: {info}")
    return info
Loading