diff --git a/src/agents/extensions/memory/__init__.py b/src/agents/extensions/memory/__init__.py index 5d670c4ad..3964bb7ae 100644 --- a/src/agents/extensions/memory/__init__.py +++ b/src/agents/extensions/memory/__init__.py @@ -15,6 +15,7 @@ "RedisSession", "SQLAlchemySession", "AdvancedSQLiteSession", + "AdvancedSQLAlchemySession", ] @@ -60,4 +61,15 @@ def __getattr__(name: str) -> Any: except ModuleNotFoundError as e: raise ImportError(f"Failed to import AdvancedSQLiteSession: {e}") from e + if name == "AdvancedSQLAlchemySession": + try: + from .advanced_sqlalchemy_session import AdvancedSQLAlchemySession # noqa: F401 + + return AdvancedSQLAlchemySession + except ModuleNotFoundError as e: + raise ImportError( + "AdvancedSQLAlchemySession requires the 'sqlalchemy' extra. " + "Install it with: pip install openai-agents[sqlalchemy]" + ) from e + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/agents/extensions/memory/advanced_sqlalchemy_session.py b/src/agents/extensions/memory/advanced_sqlalchemy_session.py new file mode 100644 index 000000000..ac225579f --- /dev/null +++ b/src/agents/extensions/memory/advanced_sqlalchemy_session.py @@ -0,0 +1,1239 @@ +from __future__ import annotations + +import json +import logging +import time +from typing import Any + +from sqlalchemy import ( + TIMESTAMP, + Column, + ForeignKey, + Index, + Integer, + String, + Table, + Text, + UniqueConstraint, + and_, + case, + delete, + func, + insert, + select, + text as sql_text, + update, +) +from sqlalchemy.ext.asyncio import AsyncEngine + +from agents.result import RunResult +from agents.usage import Usage + +from ...items import TResponseInputItem +from .sqlalchemy_session import SQLAlchemySession + + +class AdvancedSQLAlchemySession(SQLAlchemySession): + """SQLAlchemy implementation of the advanced session with branching and usage tracking.""" + + _message_structure: Table + _turn_usage: Table + + def __init__( + self, + session_id: str, + *, + engine: AsyncEngine, + create_tables: bool = False, + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + structure_table: str = "message_structure", + turn_usage_table: str = "turn_usage", + logger: logging.Logger | None = None, + ): + """Initialize the AdvancedSQLAlchemySession.""" + super().__init__( + session_id, + engine=engine, + create_tables=create_tables, + sessions_table=sessions_table, + messages_table=messages_table, + ) + + self._message_structure = Table( + structure_table, + self._metadata, + Column("id", Integer, primary_key=True, autoincrement=True), + Column( + "session_id", + String, + ForeignKey(f"{self._sessions.name}.session_id", ondelete="CASCADE"), + nullable=False, + ), + Column( + "message_id", + Integer, + ForeignKey(f"{self._messages.name}.id", ondelete="CASCADE"), + nullable=False, + ), + Column("branch_id", String, nullable=False, server_default="main"), + Column("message_type", String, nullable=False), + Column("sequence_number", Integer, nullable=False), + Column("user_turn_number", Integer), + Column("branch_turn_number", Integer), + Column("tool_name", String), + Column( + "created_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + ), + sqlite_autoincrement=True, + ) + + Index( + f"idx_{structure_table}_session_seq", + self._message_structure.c.session_id, + self._message_structure.c.sequence_number, + ) + Index( + f"idx_{structure_table}_branch", + self._message_structure.c.session_id, + self._message_structure.c.branch_id, + ) + Index( + f"idx_{structure_table}_branch_turn", + self._message_structure.c.session_id, + self._message_structure.c.branch_id, + self._message_structure.c.user_turn_number, + ) + Index( + f"idx_{structure_table}_branch_seq", + self._message_structure.c.session_id, + self._message_structure.c.branch_id, + self._message_structure.c.sequence_number, + ) + + self._turn_usage = Table( + turn_usage_table, + self._metadata, + Column("id", Integer, primary_key=True, autoincrement=True), + Column( + "session_id", + String, + ForeignKey(f"{self._sessions.name}.session_id", ondelete="CASCADE"), + nullable=False, + ), + Column("branch_id", String, nullable=False, server_default="main"), + Column("user_turn_number", Integer, nullable=False), + Column("requests", Integer, nullable=False, server_default="0"), + Column("input_tokens", Integer, nullable=False, server_default="0"), + Column("output_tokens", Integer, nullable=False, server_default="0"), + Column("total_tokens", Integer, nullable=False, server_default="0"), + Column("input_tokens_details", Text), + Column("output_tokens_details", Text), + Column( + "created_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + ), + UniqueConstraint( + "session_id", + "branch_id", + "user_turn_number", + name=f"uq_{turn_usage_table}_turn", + ), + sqlite_autoincrement=True, + ) + + Index( + f"idx_{turn_usage_table}_session_turn", + self._turn_usage.c.session_id, + self._turn_usage.c.branch_id, + self._turn_usage.c.user_turn_number, + ) + + self._current_branch_id = "main" + self._logger = logger or logging.getLogger(__name__) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add items to the session. + + Args: + items: The items to add to the session + """ + if not items: + return + + await self._ensure_tables() + async with self._lock: + await super().add_items(items) + try: + await self._add_structure_metadata(items) + except Exception as exc: # pragma: no cover - defensive + self._logger.error( + "Failed to add structure metadata for session %s: %s", + self.session_id, + exc, + ) + try: + await self._cleanup_orphaned_messages() + except Exception as cleanup_error: # pragma: no cover - defensive + self._logger.error( + "Failed to cleanup orphaned messages for session %s: %s", + self.session_id, + cleanup_error, + ) + + async def get_items( + self, + limit: int | None = None, + branch_id: str | None = None, + ) -> list[TResponseInputItem]: + """Get items from current or specified branch. + + Args: + limit: Maximum number of items to return. If None, returns all items. + branch_id: Branch to get items from. If None, uses current branch. + + Returns: + List of conversation items from the specified branch. + """ + branch = branch_id or self._current_branch_id + await self._ensure_tables() + + async with self._session_factory() as sess: + if limit is None: + stmt = ( + select(self._messages.c.message_data) + .join( + self._message_structure, + and_( + self._messages.c.id == self._message_structure.c.message_id, + self._message_structure.c.branch_id == branch, + ), + ) + .where(self._messages.c.session_id == self.session_id) + .order_by(self._message_structure.c.sequence_number.asc()) + ) + else: + stmt = ( + select(self._messages.c.message_data) + .join( + self._message_structure, + and_( + self._messages.c.id == self._message_structure.c.message_id, + self._message_structure.c.branch_id == branch, + ), + ) + .where(self._messages.c.session_id == self.session_id) + .order_by(self._message_structure.c.sequence_number.desc()) + .limit(limit) + ) + + result = await sess.execute(stmt) + rows: list[str] = [row[0] for row in result.all()] + + if limit is not None: + rows.reverse() + + items: list[TResponseInputItem] = [] + for raw in rows: + try: + items.append(await self._deserialize_item(raw)) + except json.JSONDecodeError: + # Skip corrupted rows + continue + return items + + async def store_run_usage(self, result: RunResult) -> None: + """Store usage data for the current conversation turn. + + This is designed to be called after `Runner.run()` completes. + Session-level usage can be aggregated from turn data when needed. + + Args: + result: The result from the run + """ + usage = result.context_wrapper.usage + if usage is None: + return + + try: + current_turn = await self._get_current_turn_number() + if current_turn > 0: + await self._update_turn_usage_internal(current_turn, usage) + except Exception as exc: # pragma: no cover - defensive logging + self._logger.error("Failed to store usage for session %s: %s", self.session_id, exc) + + async def _get_next_turn_number(self, branch_id: str) -> int: + """Get the next turn number for a specific branch. + + Args: + branch_id: The branch ID to get the next turn number for. + + Returns: + The next available turn number for the specified branch. + """ + max_turn = await self._get_current_turn_number(branch_id) + return max_turn + 1 + + async def _get_next_branch_turn_number(self, branch_id: str) -> int: + """Get the next branch turn number for a specific branch. + + Args: + branch_id: The branch ID to get the next branch turn number for. + + Returns: + The next available branch turn number for the specified branch. + """ + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = select( + func.coalesce(func.max(self._message_structure.c.branch_turn_number), 0) + ).where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == branch_id, + ) + ) + value = await sess.scalar(stmt) + return int(value or 0) + 1 + + async def _get_current_turn_number(self, branch_id: str | None = None) -> int: + """Get the current turn number for the current branch. + + Returns: + The current turn number for the active branch. + """ + branch = branch_id or self._current_branch_id + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = select( + func.coalesce(func.max(self._message_structure.c.user_turn_number), 0) + ).where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == branch, + ) + ) + value = await sess.scalar(stmt) + return int(value or 0) + + async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None: + """Extract structure metadata with branch-aware turn tracking. + + This method: + - Assigns turn numbers per branch (not globally) + - Assigns explicit sequence numbers for precise ordering + - Links messages to their database IDs for structure tracking + - Handles multiple user messages in a single batch correctly + + Args: + items: The items to add to the session + """ + if not items: + return + + await self._ensure_tables() + + async with self._session_factory() as sess: + async with sess.begin(): + ids_stmt = ( + select(self._messages.c.id) + .where(self._messages.c.session_id == self.session_id) + .order_by(self._messages.c.id.desc()) + .limit(len(items)) + ) + id_rows = await sess.execute(ids_stmt) + message_ids = [row[0] for row in id_rows.all()] + message_ids.reverse() + + if len(message_ids) != len(items): + self._logger.warning( + "Mismatch retrieving message IDs for session %s. Expected %s got %s", + self.session_id, + len(items), + len(message_ids), + ) + return + + seq_stmt = select( + func.coalesce(func.max(self._message_structure.c.sequence_number), 0) + ).where(self._message_structure.c.session_id == self.session_id) + seq_start = await sess.scalar(seq_stmt) + seq_start = int(seq_start or 0) + + turn_stmt = select( + func.coalesce(func.max(self._message_structure.c.user_turn_number), 0), + func.coalesce(func.max(self._message_structure.c.branch_turn_number), 0), + ).where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == self._current_branch_id, + ) + ) + turn_row = await sess.execute(turn_stmt) + turn_values = turn_row.one_or_none() + current_turn = int(turn_values[0]) if turn_values and turn_values[0] else 0 + current_branch_turn = int(turn_values[1]) if turn_values and turn_values[1] else 0 + + structure_payload: list[dict[str, Any]] = [] + user_message_count = 0 + + for offset, (item, message_id) in enumerate(zip(items, message_ids)): + msg_type = self._classify_message_type(item) + tool_name = self._extract_tool_name(item) + + if self._is_user_message(item): + user_message_count += 1 + + turn_value = current_turn + user_message_count + branch_turn_value = current_branch_turn + user_message_count + + structure_payload.append( + { + "session_id": self.session_id, + "message_id": message_id, + "branch_id": self._current_branch_id, + "message_type": msg_type, + "sequence_number": seq_start + offset + 1, + "user_turn_number": turn_value, + "branch_turn_number": branch_turn_value, + "tool_name": tool_name, + } + ) + + if structure_payload: + await sess.execute(insert(self._message_structure), structure_payload) + + async def _cleanup_orphaned_messages(self) -> None: + """Remove messages that exist in agent_messages but not in message_structure. + + This can happen if _add_structure_metadata fails after super().add_items() succeeds. + Used for maintaining data consistency. + """ + await self._ensure_tables() + async with self._session_factory() as sess: + async with sess.begin(): + join_stmt = ( + select(self._messages.c.id) + .select_from( + self._messages.outerjoin( + self._message_structure, + self._messages.c.id == self._message_structure.c.message_id, + ) + ) + .where( + and_( + self._messages.c.session_id == self.session_id, + self._message_structure.c.message_id.is_(None), + ) + ) + ) + result = await sess.execute(join_stmt) + orphan_ids = [row[0] for row in result.all()] + + if orphan_ids: + await sess.execute( + delete(self._messages).where(self._messages.c.id.in_(orphan_ids)) + ) + self._logger.info( + "Cleaned up %s orphaned messages for session %s", + len(orphan_ids), + self.session_id, + ) + + def _classify_message_type(self, item: TResponseInputItem) -> str: + """Classify the type of a message item. + + Args: + item: The message item to classify. + + Returns: + String representing the message type (user, assistant, etc.). + """ + if isinstance(item, dict): + if item.get("role") == "user": + return "user" + elif item.get("role") == "assistant": + return "assistant" + elif item.get("type"): + return str(item.get("type")) + return "other" + + def _extract_tool_name(self, item: TResponseInputItem) -> str | None: + """Extract tool name if this is a tool call/output. + + Args: + item: The message item to extract tool name from. + + Returns: + Tool name if item is a tool call, None otherwise. + """ + if isinstance(item, dict): + item_type = item.get("type") + + if item_type in {"mcp_call", "mcp_approval_request"} and "server_label" in item: + server_label = item.get("server_label") + tool_name = item.get("name") + if tool_name and server_label: + return f"{server_label}.{tool_name}" + if server_label: + return str(server_label) + if tool_name: + return str(tool_name) + + if item_type in { + "computer_call", + "file_search_call", + "web_search_call", + "code_interpreter_call", + }: + return item_type + + if "name" in item and item.get("name") is not None: + return str(item.get("name")) + + return None + + def _is_user_message(self, item: TResponseInputItem) -> bool: + """Check if this is a user message. + + Args: + item: The message item to check. + + Returns: + True if the item is a user message, False otherwise. + """ + return isinstance(item, dict) and item.get("role") == "user" + + async def create_branch_from_turn( + self, + turn_number: int, + branch_name: str | None = None, + ) -> str: + """Create a new branch starting from a specific user message turn. + + Args: + turn_number: The branch turn number of the user message to branch from + branch_name: Optional name for the branch (auto-generated if None) + + Returns: + The branch_id of the newly created branch + + Raises: + ValueError: If turn doesn't exist or doesn't contain a user message + """ + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = ( + select(self._messages.c.message_data) + .join( + self._message_structure, + self._messages.c.id == self._message_structure.c.message_id, + ) + .where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == self._current_branch_id, + self._message_structure.c.branch_turn_number == turn_number, + self._message_structure.c.message_type == "user", + ) + ) + ) + row = await sess.execute(stmt) + message_row = row.first() + + if not message_row: + raise ValueError( + f"Turn {turn_number} does not contain a user message " + f"in branch '{self._current_branch_id}'" + ) + + try: + message_content = json.loads(message_row[0]).get("content", "") + except Exception: # pragma: no cover - defensive + message_content = "Unable to parse content" + + if branch_name is None: + branch_name = f"branch_from_turn_{turn_number}_{int(time.time())}" + + await self._copy_messages_to_new_branch(branch_name, turn_number) + + old_branch = self._current_branch_id + self._current_branch_id = branch_name + self._logger.debug( + "Created branch '%s' from turn %s ('%s') in '%s'", + branch_name, + turn_number, + message_content[:50] + ("..." if len(message_content) > 50 else ""), + old_branch, + ) + return branch_name + + async def create_branch_from_content( + self, + search_term: str, + branch_name: str | None = None, + ) -> str: + """Create branch from the first user turn matching the search term. + + Args: + search_term: Text to search for in user messages. + branch_name: Optional name for the branch (auto-generated if None). + + Returns: + The branch_id of the newly created branch. + + Raises: + ValueError: If no matching turns are found. + """ + matches = await self.find_turns_by_content(search_term) + if not matches: + raise ValueError(f"No user turns found containing '{search_term}'") + + return await self.create_branch_from_turn(matches[0]["turn"], branch_name) + + async def switch_to_branch(self, branch_id: str) -> None: + """Switch to a different branch. + + Args: + branch_id: The branch to switch to. + + Raises: + ValueError: If the branch doesn't exist. + """ + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = ( + select(func.count()) + .select_from(self._message_structure) + .where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == branch_id, + ) + ) + ) + exists = await sess.scalar(stmt) + + if not exists: + raise ValueError(f"Branch '{branch_id}' does not exist") + + old_branch = self._current_branch_id + self._current_branch_id = branch_id + self._logger.info("Switched from branch '%s' to '%s'", old_branch, branch_id) + + async def delete_branch(self, branch_id: str, *, force: bool = False) -> None: + """Delete a branch and all its associated data. + + Args: + branch_id: The branch to delete. + force: If True, allows deleting the current branch (will switch to 'main'). + + Raises: + ValueError: If branch doesn't exist, is 'main', or is current branch without force. + """ + if not branch_id or not branch_id.strip(): + raise ValueError("Branch ID cannot be empty") + + branch_id = branch_id.strip() + + if branch_id == "main": + raise ValueError("Cannot delete the 'main' branch") + + if branch_id == self._current_branch_id: + if not force: + raise ValueError( + f"Cannot delete current branch '{branch_id}'. " + "Use force=True or switch branches first" + ) + await self.switch_to_branch("main") + + await self._ensure_tables() + async with self._lock: + async with self._session_factory() as sess: + async with sess.begin(): + exists_stmt = ( + select(func.count()) + .select_from(self._message_structure) + .where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == branch_id, + ) + ) + ) + exists = await sess.scalar(exists_stmt) + if not exists: + raise ValueError(f"Branch '{branch_id}' does not exist") + + usage_result = await sess.execute( + delete(self._turn_usage).where( + and_( + self._turn_usage.c.session_id == self.session_id, + self._turn_usage.c.branch_id == branch_id, + ) + ) + ) + + structure_result = await sess.execute( + delete(self._message_structure).where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == branch_id, + ) + ) + ) + + self._logger.info( + "Deleted branch '%s': %s message entries, %s usage entries", + branch_id, + structure_result.rowcount if "structure_result" in locals() else 0, + usage_result.rowcount if "usage_result" in locals() else 0, + ) + + async def list_branches(self) -> list[dict[str, Any]]: + """List all branches in this session. + + Returns: + List of dicts with branch info containing: + - 'branch_id': Branch identifier + - 'message_count': Number of messages in branch + - 'user_turns': Number of user turns in branch + - 'is_current': Whether this is the current branch + - 'created_at': When the branch was first created + """ + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = ( + select( + self._message_structure.c.branch_id, + func.count().label("message_count"), + func.sum( + case( + (self._message_structure.c.message_type == "user", 1), + else_=0, + ) + ).label("user_turns"), + func.min(self._message_structure.c.created_at).label("created_at"), + ) + .where(self._message_structure.c.session_id == self.session_id) + .group_by(self._message_structure.c.branch_id) + .order_by(func.min(self._message_structure.c.created_at)) + ) + result = await sess.execute(stmt) + rows = result.all() + + branches: list[dict[str, Any]] = [] + for branch_id, message_count, user_turns, created_at in rows: + branches.append( + { + "branch_id": branch_id, + "message_count": int(message_count or 0), + "user_turns": int(user_turns or 0), + "is_current": branch_id == self._current_branch_id, + "created_at": created_at, + } + ) + return branches + + async def _copy_messages_to_new_branch( + self, + new_branch_id: str, + from_turn_number: int, + ) -> None: + """Copy messages before the branch point to the new branch. + + Args: + new_branch_id: The ID of the new branch to copy messages to. + from_turn_number: The turn number to copy messages up to (exclusive). + """ + await self._ensure_tables() + async with self._lock: + async with self._session_factory() as sess: + async with sess.begin(): + select_stmt = ( + select( + self._message_structure.c.message_id, + self._message_structure.c.message_type, + self._message_structure.c.sequence_number, + self._message_structure.c.user_turn_number, + self._message_structure.c.branch_turn_number, + self._message_structure.c.tool_name, + ) + .where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == self._current_branch_id, + self._message_structure.c.branch_turn_number < from_turn_number, + ) + ) + .order_by(self._message_structure.c.sequence_number) + ) + rows = await sess.execute(select_stmt) + messages_to_copy = rows.all() + + if not messages_to_copy: + return + + seq_stmt = select( + func.coalesce(func.max(self._message_structure.c.sequence_number), 0) + ).where(self._message_structure.c.session_id == self.session_id) + seq_start = await sess.scalar(seq_stmt) + seq_start = int(seq_start or 0) + + payload: list[dict[str, Any]] = [] + for idx, ( + message_id, + message_type, + _, + user_turn_number, + branch_turn_number, + tool_name, + ) in enumerate(messages_to_copy): + payload.append( + { + "session_id": self.session_id, + "message_id": message_id, + "branch_id": new_branch_id, + "message_type": message_type, + "sequence_number": seq_start + idx + 1, + "user_turn_number": user_turn_number, + "branch_turn_number": branch_turn_number, + "tool_name": tool_name, + } + ) + + await sess.execute(insert(self._message_structure), payload) + + async def get_conversation_turns( + self, + branch_id: str | None = None, + ) -> list[dict[str, Any]]: + """Get user turns with content for easy browsing and branching decisions. + + Args: + branch_id: Branch to get turns from (current branch if None). + + Returns: + List of dicts with turn info containing: + - 'turn': Branch turn number + - 'content': User message content (truncated) + - 'full_content': Full user message content + - 'timestamp': When the turn was created + - 'can_branch': Always True (all user messages can branch) + """ + branch = branch_id or self._current_branch_id + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = ( + select( + self._message_structure.c.branch_turn_number, + self._messages.c.message_data, + self._message_structure.c.created_at, + ) + .join( + self._messages, + self._messages.c.id == self._message_structure.c.message_id, + ) + .where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == branch, + self._message_structure.c.message_type == "user", + ) + ) + .order_by(self._message_structure.c.branch_turn_number) + ) + result = await sess.execute(stmt) + rows = result.all() + + turns: list[dict[str, Any]] = [] + for turn_number, message_data, created_at in rows: + try: + content = json.loads(message_data).get("content", "") + except (json.JSONDecodeError, AttributeError): + continue + turns.append( + { + "turn": int(turn_number), + "content": content[:100] + ("..." if len(content) > 100 else ""), + "full_content": content, + "timestamp": created_at, + "can_branch": True, + } + ) + return turns + + async def find_turns_by_content( + self, + search_term: str, + branch_id: str | None = None, + ) -> list[dict[str, Any]]: + """Find user turns containing specific content. + + Args: + search_term: Text to search for in user messages. + branch_id: Branch to search in (current branch if None). + + Returns: + List of matching turns with same format as get_conversation_turns(). + """ + branch = branch_id or self._current_branch_id + pattern = f"%{search_term}%" + await self._ensure_tables() + + async with self._session_factory() as sess: + stmt = ( + select( + self._message_structure.c.branch_turn_number, + self._messages.c.message_data, + self._message_structure.c.created_at, + ) + .join( + self._messages, + self._messages.c.id == self._message_structure.c.message_id, + ) + .where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == branch, + self._message_structure.c.message_type == "user", + self._messages.c.message_data.like(pattern), + ) + ) + .order_by(self._message_structure.c.branch_turn_number) + ) + result = await sess.execute(stmt) + rows = result.all() + + matches: list[dict[str, Any]] = [] + for turn_number, message_data, created_at in rows: + try: + content = json.loads(message_data).get("content", "") + except (json.JSONDecodeError, AttributeError): + continue + matches.append( + { + "turn": int(turn_number), + "content": content, + "full_content": content, + "timestamp": created_at, + "can_branch": True, + } + ) + return matches + + async def get_conversation_by_turns( + self, + branch_id: str | None = None, + ) -> dict[int, list[dict[str, str | None]]]: + """Get conversation grouped by user turns for specified branch. + + Args: + branch_id: Branch to get conversation from (current branch if None). + + Returns: + Dictionary mapping turn numbers to lists of message metadata. + """ + branch = branch_id or self._current_branch_id + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = ( + select( + self._message_structure.c.user_turn_number, + self._message_structure.c.message_type, + self._message_structure.c.tool_name, + ) + .where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == branch, + ) + ) + .order_by(self._message_structure.c.sequence_number) + ) + result = await sess.execute(stmt) + rows = result.all() + + turns: dict[int, list[dict[str, str | None]]] = {} + for turn_number, message_type, tool_name in rows: + key = int(turn_number or 0) + turns.setdefault(key, []).append({"type": message_type, "tool_name": tool_name}) + return turns + + async def get_tool_usage( + self, + branch_id: str | None = None, + ) -> list[tuple[str | None, int, int]]: + """Get all tool usage by turn for specified branch. + + Args: + branch_id: Branch to get tool usage from (current branch if None). + + Returns: + List of tuples containing (tool_name, usage_count, turn_number). + """ + branch = branch_id or self._current_branch_id + tool_types = { + "tool_call", + "function_call", + "computer_call", + "file_search_call", + "web_search_call", + "code_interpreter_call", + "custom_tool_call", + "mcp_call", + "mcp_approval_request", + } + + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = ( + select( + self._message_structure.c.tool_name, + func.count(), + self._message_structure.c.user_turn_number, + ) + .where( + and_( + self._message_structure.c.session_id == self.session_id, + self._message_structure.c.branch_id == branch, + self._message_structure.c.message_type.in_(tool_types), + ) + ) + .group_by( + self._message_structure.c.tool_name, + self._message_structure.c.user_turn_number, + ) + .order_by(self._message_structure.c.user_turn_number) + ) + result = await sess.execute(stmt) + rows = result.all() + + return [ + ( + tool_name, + int(count or 0), + int(user_turn_number or 0), + ) + for tool_name, count, user_turn_number in rows + ] + + async def get_session_usage( + self, + branch_id: str | None = None, + ) -> dict[str, int] | None: + """Get cumulative usage for session or specific branch. + + Args: + branch_id: If provided, only get usage for that branch. If None, get all branches. + + Returns: + Dictionary with usage statistics or None if no usage data found. + """ + await self._ensure_tables() + async with self._session_factory() as sess: + if branch_id: + stmt = select( + func.sum(self._turn_usage.c.requests), + func.sum(self._turn_usage.c.input_tokens), + func.sum(self._turn_usage.c.output_tokens), + func.sum(self._turn_usage.c.total_tokens), + func.count(), + ).where( + and_( + self._turn_usage.c.session_id == self.session_id, + self._turn_usage.c.branch_id == branch_id, + ) + ) + else: + stmt = select( + func.sum(self._turn_usage.c.requests), + func.sum(self._turn_usage.c.input_tokens), + func.sum(self._turn_usage.c.output_tokens), + func.sum(self._turn_usage.c.total_tokens), + func.count(), + ).where(self._turn_usage.c.session_id == self.session_id) + + result = await sess.execute(stmt) + row = result.first() + + if not row or row[0] is None: + return None + + requests, input_tokens, output_tokens, total_tokens, turns = row + return { + "requests": int(requests or 0), + "input_tokens": int(input_tokens or 0), + "output_tokens": int(output_tokens or 0), + "total_tokens": int(total_tokens or 0), + "total_turns": int(turns or 0), + } + + async def get_turn_usage( + self, + user_turn_number: int | None = None, + branch_id: str | None = None, + ) -> list[dict[str, Any]] | dict[str, Any]: + """Get usage statistics by turn with full JSON token details. + + Args: + user_turn_number: Specific turn to get usage for. If None, returns all turns. + branch_id: Branch to get usage from (current branch if None). + + Returns: + Dictionary with usage data for specific turn, or list of dictionaries for all turns. + """ + branch = branch_id or self._current_branch_id + await self._ensure_tables() + async with self._session_factory() as sess: + if user_turn_number is not None: + stmt = select( + self._turn_usage.c.requests, + self._turn_usage.c.input_tokens, + self._turn_usage.c.output_tokens, + self._turn_usage.c.total_tokens, + self._turn_usage.c.input_tokens_details, + self._turn_usage.c.output_tokens_details, + ).where( + and_( + self._turn_usage.c.session_id == self.session_id, + self._turn_usage.c.branch_id == branch, + self._turn_usage.c.user_turn_number == user_turn_number, + ) + ) + result = await sess.execute(stmt) + row = result.first() + if not row: + return {} + return { + "requests": int(row[0] or 0), + "input_tokens": int(row[1] or 0), + "output_tokens": int(row[2] or 0), + "total_tokens": int(row[3] or 0), + "input_tokens_details": self._loads_optional(row[4]), + "output_tokens_details": self._loads_optional(row[5]), + } + + stmt = ( + select( + self._turn_usage.c.user_turn_number, + self._turn_usage.c.requests, + self._turn_usage.c.input_tokens, + self._turn_usage.c.output_tokens, + self._turn_usage.c.total_tokens, + self._turn_usage.c.input_tokens_details, + self._turn_usage.c.output_tokens_details, + ) + .where( + and_( + self._turn_usage.c.session_id == self.session_id, + self._turn_usage.c.branch_id == branch, + ) + ) + .order_by(self._turn_usage.c.user_turn_number) + ) + result = await sess.execute(stmt) + rows = result.all() + + usage_rows: list[dict[str, Any]] = [] + for ( + turn_number, + requests, + input_tokens, + output_tokens, + total_tokens, + input_details, + output_details, + ) in rows: + usage_rows.append( + { + "user_turn_number": int(turn_number or 0), + "requests": int(requests or 0), + "input_tokens": int(input_tokens or 0), + "output_tokens": int(output_tokens or 0), + "total_tokens": int(total_tokens or 0), + "input_tokens_details": self._loads_optional(input_details), + "output_tokens_details": self._loads_optional(output_details), + } + ) + return usage_rows + + async def _update_turn_usage_internal( + self, + user_turn_number: int, + usage_data: Usage, + ) -> None: + """Internal method to update usage for a specific turn with full JSON details. + + Args: + user_turn_number: The turn number to update usage for. + usage_data: The usage data to store. + """ + await self._ensure_tables() + input_details = self._dumps_token_details(getattr(usage_data, "input_tokens_details", None)) + output_details = self._dumps_token_details( + getattr(usage_data, "output_tokens_details", None) + ) + + payload = { + "requests": usage_data.requests or 0, + "input_tokens": usage_data.input_tokens or 0, + "output_tokens": usage_data.output_tokens or 0, + "total_tokens": usage_data.total_tokens or 0, + "input_tokens_details": input_details, + "output_tokens_details": output_details, + } + + async with self._session_factory() as sess: + async with sess.begin(): + update_stmt = ( + update(self._turn_usage) + .where( + and_( + self._turn_usage.c.session_id == self.session_id, + self._turn_usage.c.branch_id == self._current_branch_id, + self._turn_usage.c.user_turn_number == user_turn_number, + ) + ) + .values(**payload) + ) + result = await sess.execute(update_stmt) + + if result.rowcount == 0: + insert_stmt = insert(self._turn_usage).values( + session_id=self.session_id, + branch_id=self._current_branch_id, + user_turn_number=user_turn_number, + **payload, + ) + await sess.execute(insert_stmt) + + def _dumps_token_details(self, details: Any) -> str | None: + """Serialize token detail objects to JSON.""" + if not details: + return None + + for attr in ("model_dump", "dict"): + if hasattr(details, attr): + try: + return json.dumps(getattr(details, attr)()) + except (TypeError, ValueError): + continue + + try: + return json.dumps(details.__dict__) + except (TypeError, ValueError) as exc: # pragma: no cover - defensive + self._logger.warning("Failed to serialize token details: %s", exc) + return None + + def _loads_optional(self, payload: str | None) -> Any: + """Deserialize optional JSON payloads.""" + if not payload: + return None + try: + return json.loads(payload) + except json.JSONDecodeError: + return None diff --git a/tests/extensions/memory/test_advanced_sqlalchemy_session.py b/tests/extensions/memory/test_advanced_sqlalchemy_session.py new file mode 100644 index 000000000..d9b3bd612 --- /dev/null +++ b/tests/extensions/memory/test_advanced_sqlalchemy_session.py @@ -0,0 +1,812 @@ +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any, cast + +import pytest +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + +pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed +from agents import Agent, Runner, TResponseInputItem, function_tool +from agents.extensions.memory import AdvancedSQLAlchemySession +from agents.result import RunResult +from agents.run_context import RunContextWrapper +from agents.usage import Usage +from tests.fake_model import FakeModel +from tests.test_responses import get_text_message + +pytestmark = pytest.mark.asyncio + +DB_URL = "sqlite+aiosqlite:///:memory:" + + +@asynccontextmanager +async def managed_session(session_id: str) -> AsyncIterator[AdvancedSQLAlchemySession]: + """Create an AdvancedSQLAlchemySession and ensure the engine is disposed afterwards.""" + session = AdvancedSQLAlchemySession.from_url( + session_id, + url=DB_URL, + create_tables=True, + ) + try: + yield session + finally: + await session._engine.dispose() + + +@function_tool +async def test_tool(query: str) -> str: + """A test tool for verifying tool tracking.""" + return f"Tool result for: {query}" + + +@pytest.fixture +def agent() -> Agent: + """Fixture for a basic agent with a fake model and tooling.""" + return Agent(name="advanced-sqlalchemy", model=FakeModel(), tools=[test_tool]) + + +@pytest.fixture +def usage_data() -> Usage: + """Fixture providing sample usage data.""" + return Usage( + requests=1, + input_tokens=50, + output_tokens=30, + total_tokens=80, + input_tokens_details=InputTokensDetails(cached_tokens=10), + output_tokens_details=OutputTokensDetails(reasoning_tokens=5), + ) + + +def create_mock_run_result( + usage: Usage | None = None, + agent: Agent | None = None, +) -> RunResult: + """Helper function to create a RunResult carrying usage information.""" + if agent is None: + agent = Agent(name="test", model=FakeModel()) + + if usage is None: + usage = Usage( + requests=1, + input_tokens=50, + output_tokens=30, + total_tokens=80, + input_tokens_details=InputTokensDetails(cached_tokens=10), + output_tokens_details=OutputTokensDetails(reasoning_tokens=5), + ) + + context_wrapper = RunContextWrapper(context=None, usage=usage) + + return RunResult( + input="test input", + new_items=[], + raw_responses=[], + final_output="test output", + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + context_wrapper=context_wrapper, + _last_agent=agent, + ) + + +async def test_advanced_session_basic_functionality(): + async with managed_session("advanced_test") as session: + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + await session.add_items(items) + + retrieved = await session.get_items() + assert len(retrieved) == 2 + assert retrieved[0].get("content") == "Hello" + assert retrieved[1].get("content") == "Hi there!" + + +async def test_message_structure_tracking(): + async with managed_session("structure_test") as session: + items: list[TResponseInputItem] = [ + {"role": "user", "content": "What's 2+2?"}, + {"type": "function_call", "name": "calculator", "arguments": '{"expression": "2+2"}'}, # type: ignore + {"type": "function_call_output", "output": "4"}, # type: ignore + {"role": "assistant", "content": "The answer is 4"}, + {"type": "reasoning", "summary": [{"text": "Simple math", "type": "summary_text"}]}, # type: ignore + ] + await session.add_items(items) + + conversation_turns = await session.get_conversation_by_turns() + assert len(conversation_turns) == 1 + + turn_1_items = conversation_turns[1] + assert len(turn_1_items) == 5 + + item_types = [item["type"] for item in turn_1_items] + assert "user" in item_types + assert "function_call" in item_types + assert "function_call_output" in item_types + assert "assistant" in item_types + assert "reasoning" in item_types + + +async def test_tool_usage_tracking(): + async with managed_session("tools_test") as session: + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Search for cats"}, + {"type": "function_call", "name": "web_search", "arguments": '{"query": "cats"}'}, # type: ignore + {"type": "function_call_output", "output": "Found cat information"}, # type: ignore + {"type": "function_call", "name": "calculator", "arguments": '{"expression": "1+1"}'}, # type: ignore + {"type": "function_call_output", "output": "2"}, # type: ignore + {"role": "assistant", "content": "I found information and calculated 1+1=2"}, + ] + await session.add_items(items) + + tool_usage = await session.get_tool_usage() + assert len(tool_usage) == 2 + + tool_names = {usage[0] for usage in tool_usage} + assert "web_search" in tool_names + assert "calculator" in tool_names + + +async def test_branch_listing_and_search(): + async with managed_session("advanced_listing") as session: + await session.add_items( + [ + {"role": "user", "content": "Initial turn"}, + {"role": "assistant", "content": "Reply"}, + ] + ) + + await session.add_items( + [ + {"role": "user", "content": "Second turn with tool"}, + { + "type": "mcp_call", + "server_label": "search", + "name": "lookup", + "role": "tool", + "content": [], + }, + ] + ) + + branches = await session.list_branches() + assert len(branches) == 1 + assert branches[0]["branch_id"] == "main" + assert branches[0]["user_turns"] == 2 + + turns = await session.get_conversation_turns() + assert [turn["content"] for turn in turns] == ["Initial turn", "Second turn with tool"] + + matches = await session.find_turns_by_content("Second") + assert len(matches) == 1 + assert matches[0]["turn"] == 2 + + tool_usage = await session.get_tool_usage() + assert tool_usage == [("search.lookup", 1, 2)] + + +async def test_branching_functionality(): + async with managed_session("branching_test") as session: + turn_1_items: list[TResponseInputItem] = [ + {"role": "user", "content": "First question"}, + {"role": "assistant", "content": "First answer"}, + ] + await session.add_items(turn_1_items) + + turn_2_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Second question"}, + {"role": "assistant", "content": "Second answer"}, + ] + await session.add_items(turn_2_items) + + turn_3_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Third question"}, + {"role": "assistant", "content": "Third answer"}, + ] + await session.add_items(turn_3_items) + + all_items = await session.get_items() + assert len(all_items) == 6 + + branch_name = await session.create_branch_from_turn(2, "test_branch") + assert branch_name == "test_branch" + assert session._current_branch_id == "test_branch" + + branch_items = await session.get_items() + assert len(branch_items) == 2 + assert branch_items[0].get("content") == "First question" + assert branch_items[1].get("content") == "First answer" + + await session.switch_to_branch("main") + assert session._current_branch_id == "main" + + main_items = await session.get_items() + assert len(main_items) == 6 + + branches = await session.list_branches() + assert len(branches) == 2 + branch_ids = [branch["branch_id"] for branch in branches] + assert "main" in branch_ids + assert "test_branch" in branch_ids + + await session.delete_branch("test_branch") + branches_after_delete = await session.list_branches() + assert len(branches_after_delete) == 1 + assert branches_after_delete[0]["branch_id"] == "main" + + +async def test_get_conversation_turns(): + async with managed_session("conversation_turns_test") as session: + turn_1_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello there"}, + {"role": "assistant", "content": "Hi!"}, + ] + await session.add_items(turn_1_items) + + turn_2_items: list[TResponseInputItem] = [ + {"role": "user", "content": "How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well, thanks!"}, + ] + await session.add_items(turn_2_items) + + turns = await session.get_conversation_turns() + assert len(turns) == 2 + + assert turns[0]["turn"] == 1 + assert turns[0]["content"] == "Hello there" + assert turns[0]["full_content"] == "Hello there" + assert turns[0]["can_branch"] is True + assert "timestamp" in turns[0] + + assert turns[1]["turn"] == 2 + assert turns[1]["content"] == "How are you doing today?" + assert turns[1]["full_content"] == "How are you doing today?" + assert turns[1]["can_branch"] is True + + +async def test_find_turns_by_content(): + async with managed_session("find_turns_test") as session: + turn_1_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Tell me about cats"}, + {"role": "assistant", "content": "Cats are great pets"}, + ] + await session.add_items(turn_1_items) + + turn_2_items: list[TResponseInputItem] = [ + {"role": "user", "content": "What about dogs?"}, + {"role": "assistant", "content": "Dogs are also great pets"}, + ] + await session.add_items(turn_2_items) + + turn_3_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Tell me about cats again"}, + {"role": "assistant", "content": "Cats are wonderful companions"}, + ] + await session.add_items(turn_3_items) + + cat_turns = await session.find_turns_by_content("cats") + assert len(cat_turns) == 2 + assert cat_turns[0]["turn"] == 1 + assert cat_turns[1]["turn"] == 3 + + dog_turns = await session.find_turns_by_content("dogs") + assert len(dog_turns) == 1 + assert dog_turns[0]["turn"] == 2 + + no_turns = await session.find_turns_by_content("elephants") + assert len(no_turns) == 0 + + +async def test_create_branch_from_content(): + async with managed_session("branch_from_content_test") as session: + turn_1_items: list[TResponseInputItem] = [ + {"role": "user", "content": "First question about math"}, + {"role": "assistant", "content": "Math answer"}, + ] + await session.add_items(turn_1_items) + + turn_2_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Second question about science"}, + {"role": "assistant", "content": "Science answer"}, + ] + await session.add_items(turn_2_items) + + turn_3_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Another math question"}, + {"role": "assistant", "content": "Another math answer"}, + ] + await session.add_items(turn_3_items) + + branch_name = await session.create_branch_from_content("math", "math_branch") + assert branch_name == "math_branch" + assert session._current_branch_id == "math_branch" + + branch_items = await session.get_items() + assert len(branch_items) == 0 + + with pytest.raises(ValueError, match="No user turns found containing 'nonexistent'"): + await session.create_branch_from_content("nonexistent", "error_branch") + + +async def test_branch_specific_operations(): + async with managed_session("branch_specific_test") as session: + turn_1_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Main branch question"}, + {"role": "assistant", "content": "Main branch answer"}, + ] + await session.add_items(turn_1_items) + + usage_main = Usage(requests=1, input_tokens=50, output_tokens=30, total_tokens=80) + run_result_main = create_mock_run_result(usage_main) + await session.store_run_usage(run_result_main) + + await session.create_branch_from_turn(1, "test_branch") + + turn_2_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Branch question"}, + {"role": "assistant", "content": "Branch answer"}, + ] + await session.add_items(turn_2_items) + + usage_branch = Usage(requests=1, input_tokens=40, output_tokens=20, total_tokens=60) + run_result_branch = create_mock_run_result(usage_branch) + await session.store_run_usage(run_result_branch) + + main_items = await session.get_items(branch_id="main") + assert len(main_items) == 2 + assert main_items[0].get("content") == "Main branch question" + + current_items = await session.get_items() + assert len(current_items) == 2 + + main_turns = await session.get_conversation_turns(branch_id="main") + assert len(main_turns) == 1 + assert main_turns[0]["content"] == "Main branch question" + + current_turns = await session.get_conversation_turns() + assert len(current_turns) == 1 + + main_usage = await session.get_session_usage(branch_id="main") + assert main_usage is not None + assert main_usage["total_turns"] == 1 + + all_usage = await session.get_session_usage() + assert all_usage is not None + assert all_usage["total_turns"] == 2 + + +async def test_branch_error_handling(): + async with managed_session("branch_error_test") as session: + with pytest.raises(ValueError, match="Turn 5 does not contain a user message"): + await session.create_branch_from_turn(5, "error_branch") + + with pytest.raises(ValueError, match="Branch 'nonexistent' does not exist"): + await session.switch_to_branch("nonexistent") + + with pytest.raises(ValueError, match="Branch 'nonexistent' does not exist"): + await session.delete_branch("nonexistent") + + with pytest.raises(ValueError, match="Cannot delete the 'main' branch"): + await session.delete_branch("main") + + with pytest.raises(ValueError, match="Branch ID cannot be empty"): + await session.delete_branch("") + + with pytest.raises(ValueError, match="Branch ID cannot be empty"): + await session.delete_branch(" ") + + +async def test_branch_deletion_with_force(): + async with managed_session("force_delete_test") as session: + await session.add_items([{"role": "user", "content": "Main question"}]) + await session.add_items([{"role": "user", "content": "Second question"}]) + + await session.create_branch_from_turn(2, "temp_branch") + assert session._current_branch_id == "temp_branch" + + await session.add_items([{"role": "user", "content": "Branch question"}]) + + branches = await session.list_branches() + branch_ids = [branch["branch_id"] for branch in branches] + assert "temp_branch" in branch_ids + + with pytest.raises(ValueError, match="Cannot delete current branch"): + await session.delete_branch("temp_branch") + + await session.delete_branch("temp_branch", force=True) + assert session._current_branch_id == "main" + + branches_after = await session.list_branches() + assert len(branches_after) == 1 + assert branches_after[0]["branch_id"] == "main" + + +async def test_get_items_with_parameters(): + async with managed_session("get_items_params_test") as session: + items: list[TResponseInputItem] = [ + {"role": "user", "content": "First question"}, + {"role": "assistant", "content": "First answer"}, + {"role": "user", "content": "Second question"}, + {"role": "assistant", "content": "Second answer"}, + ] + await session.add_items(items) + + limited_items = await session.get_items(limit=2) + assert len(limited_items) == 2 + assert limited_items[0].get("content") == "Second question" + assert limited_items[1].get("content") == "Second answer" + + main_items = await session.get_items(branch_id="main") + assert len(main_items) == 4 + + all_items = await session.get_items() + assert len(all_items) == 4 + + await session.create_branch_from_turn(2, "test_branch") + + branch_items: list[TResponseInputItem] = [ + {"role": "user", "content": "Branch question"}, + {"role": "assistant", "content": "Branch answer"}, + ] + await session.add_items(branch_items) + + branch_items_result = await session.get_items(branch_id="test_branch") + assert len(branch_items_result) == 4 + + main_items_from_branch = await session.get_items(branch_id="main") + assert len(main_items_from_branch) == 4 + + +async def test_usage_tracking_storage(agent: Agent, usage_data: Usage): + async with managed_session("usage_test") as session: + await session.add_items([{"role": "user", "content": "First turn"}]) + run_result_1 = create_mock_run_result(usage_data) + await session.store_run_usage(run_result_1) + + usage_data_2 = Usage( + requests=2, + input_tokens=75, + output_tokens=45, + total_tokens=120, + input_tokens_details=InputTokensDetails(cached_tokens=20), + output_tokens_details=OutputTokensDetails(reasoning_tokens=15), + ) + + await session.add_items([{"role": "user", "content": "Second turn"}]) + run_result_2 = create_mock_run_result(usage_data_2) + await session.store_run_usage(run_result_2) + + session_usage = await session.get_session_usage() + assert session_usage is not None + assert session_usage["requests"] == 3 + assert session_usage["total_tokens"] == 200 + assert session_usage["input_tokens"] == 125 + assert session_usage["output_tokens"] == 75 + assert session_usage["total_turns"] == 2 + + turn_1_usage = await session.get_turn_usage(1) + assert isinstance(turn_1_usage, dict) + assert turn_1_usage["requests"] == 1 + assert turn_1_usage["total_tokens"] == 80 + assert turn_1_usage["input_tokens_details"]["cached_tokens"] == 10 + assert turn_1_usage["output_tokens_details"]["reasoning_tokens"] == 5 + + turn_2_usage = await session.get_turn_usage(2) + assert isinstance(turn_2_usage, dict) + assert turn_2_usage["requests"] == 2 + assert turn_2_usage["total_tokens"] == 120 + assert turn_2_usage["input_tokens_details"]["cached_tokens"] == 20 + assert turn_2_usage["output_tokens_details"]["reasoning_tokens"] == 15 + + all_turn_usage = await session.get_turn_usage() + assert isinstance(all_turn_usage, list) + assert len(all_turn_usage) == 2 + assert all_turn_usage[0]["user_turn_number"] == 1 + assert all_turn_usage[1]["user_turn_number"] == 2 + + +async def test_runner_integration_with_usage_tracking(agent: Agent): + async with managed_session("integration_test") as session: + + async def store_session_usage(result: Any, session: AdvancedSQLAlchemySession): + try: + await session.store_run_usage(result) + except Exception: + pass + + assert isinstance(agent.model, FakeModel) + fake_model = agent.model + fake_model.set_next_output([get_text_message("San Francisco")]) + + result1 = await Runner.run( + agent, + "What city is the Golden Gate Bridge in?", + session=session, + ) + assert result1.final_output == "San Francisco" + await store_session_usage(result1, session) + + fake_model.set_next_output([get_text_message("California")]) + result2 = await Runner.run( + agent, + "What state is it in?", + session=session, + ) + assert result2.final_output == "California" + await store_session_usage(result2, session) + + conversation_turns = await session.get_conversation_by_turns() + assert len(conversation_turns) == 2 + + session_usage = await session.get_session_usage() + assert session_usage is not None + assert session_usage["total_turns"] == 2 + assert "requests" in session_usage + assert "total_tokens" in session_usage + + +async def test_sequence_ordering(): + async with managed_session("sequence_test") as session: + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Message 1"}, + {"role": "assistant", "content": "Response 1"}, + {"role": "user", "content": "Message 2"}, + {"role": "assistant", "content": "Response 2"}, + ] + await session.add_items(items) + + retrieved = await session.get_items() + assert len(retrieved) == 4 + assert retrieved[0].get("content") == "Message 1" + assert retrieved[1].get("content") == "Response 1" + assert retrieved[2].get("content") == "Message 2" + assert retrieved[3].get("content") == "Response 2" + + +async def test_conversation_structure_with_multiple_turns(): + async with managed_session("multi_turn_test") as session: + turn_1: list[TResponseInputItem] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + await session.add_items(turn_1) + + turn_2: list[TResponseInputItem] = [ + {"role": "user", "content": "How are you?"}, + {"type": "function_call", "name": "mood_check", "arguments": "{}"}, # type: ignore + {"type": "function_call_output", "output": "I'm good"}, # type: ignore + {"role": "assistant", "content": "I'm doing well!"}, + ] + await session.add_items(turn_2) + + turn_3: list[TResponseInputItem] = [ + {"role": "user", "content": "Goodbye"}, + {"role": "assistant", "content": "See you later!"}, + ] + await session.add_items(turn_3) + + conversation_turns = await session.get_conversation_by_turns() + assert len(conversation_turns) == 3 + + assert len(conversation_turns[1]) == 2 + assert conversation_turns[1][0]["type"] == "user" + assert conversation_turns[1][1]["type"] == "assistant" + + assert len(conversation_turns[2]) == 4 + turn_2_types = [item["type"] for item in conversation_turns[2]] + assert "user" in turn_2_types + assert "function_call" in turn_2_types + assert "function_call_output" in turn_2_types + assert "assistant" in turn_2_types + + assert len(conversation_turns[3]) == 2 + + +async def test_empty_session_operations(): + async with managed_session("empty_test") as session: + items = await session.get_items() + assert len(items) == 0 + + conversation = await session.get_conversation_by_turns() + assert len(conversation) == 0 + + tool_usage = await session.get_tool_usage() + assert len(tool_usage) == 0 + + session_usage = await session.get_session_usage() + assert session_usage is None + + turns = await session.get_conversation_turns() + assert len(turns) == 0 + + +async def test_json_serialization_edge_cases(usage_data: Usage): + async with managed_session("json_test") as session: + await session.add_items([{"role": "user", "content": "First test"}]) + run_result_1 = create_mock_run_result(usage_data) + await session.store_run_usage(run_result_1) + + run_result_none = create_mock_run_result(None) + await session.store_run_usage(run_result_none) + + minimal_usage = Usage( + requests=1, + input_tokens=10, + output_tokens=5, + total_tokens=15, + ) + await session.add_items([{"role": "user", "content": "Second test"}]) + run_result_2 = create_mock_run_result(minimal_usage) + await session.store_run_usage(run_result_2) + + turn_1_usage = await session.get_turn_usage(1) + assert isinstance(turn_1_usage, dict) + assert turn_1_usage["requests"] == 1 + assert turn_1_usage["input_tokens_details"]["cached_tokens"] == 10 + + turn_2_usage = await session.get_turn_usage(2) + assert isinstance(turn_2_usage, dict) + assert turn_2_usage["requests"] == 1 + assert turn_2_usage["input_tokens_details"]["cached_tokens"] == 0 + assert turn_2_usage["output_tokens_details"]["reasoning_tokens"] == 0 + + +async def test_session_isolation(): + async with managed_session("session_1") as session1, managed_session("session_2") as session2: + await session1.add_items([{"role": "user", "content": "Session 1 message"}]) + await session2.add_items([{"role": "user", "content": "Session 2 message"}]) + + session1_items = await session1.get_items() + session2_items = await session2.get_items() + + assert len(session1_items) == 1 + assert len(session2_items) == 1 + assert session1_items[0].get("content") == "Session 1 message" + assert session2_items[0].get("content") == "Session 2 message" + + session1_turns = await session1.get_conversation_by_turns() + session2_turns = await session2.get_conversation_by_turns() + + assert len(session1_turns) == 1 + assert len(session2_turns) == 1 + + +async def test_error_handling_in_usage_tracking(usage_data: Usage): + async with managed_session("error_test") as session: + run_result = create_mock_run_result(usage_data) + await session.store_run_usage(run_result) + + await session._engine.dispose() + await session.store_run_usage(run_result) + + +async def test_advanced_tool_name_extraction(): + async with managed_session("advanced_tool_names_test") as session: + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Use various tools"}, + { + "type": "mcp_call", + "server_label": "filesystem", + "name": "read_file", + "arguments": "{}", + }, # type: ignore + { + "type": "mcp_approval_request", + "server_label": "database", + "name": "execute_query", + "arguments": "{}", + }, # type: ignore + {"type": "computer_call", "arguments": "{}"}, # type: ignore + {"type": "file_search_call", "arguments": "{}"}, # type: ignore + {"type": "web_search_call", "arguments": "{}"}, # type: ignore + {"type": "code_interpreter_call", "arguments": "{}"}, # type: ignore + {"type": "function_call", "name": "calculator", "arguments": "{}"}, # type: ignore + {"type": "custom_tool_call", "name": "custom_tool", "arguments": "{}"}, # type: ignore + ] + await session.add_items(items) + + conversation_turns = await session.get_conversation_by_turns() + turn_items = conversation_turns[1] + + tool_items = [item for item in turn_items if item["tool_name"]] + tool_names = [item["tool_name"] for item in tool_items] + + assert "filesystem.read_file" in tool_names + assert "database.execute_query" in tool_names + assert "computer_call" in tool_names + assert "file_search_call" in tool_names + assert "web_search_call" in tool_names + assert "code_interpreter_call" in tool_names + assert "calculator" in tool_names + assert "custom_tool" in tool_names + + +async def test_branch_usage_tracking(): + async with managed_session("branch_usage_test") as session: + await session.add_items([{"role": "user", "content": "Main question"}]) + usage_main = Usage(requests=1, input_tokens=50, output_tokens=30, total_tokens=80) + run_result_main = create_mock_run_result(usage_main) + await session.store_run_usage(run_result_main) + + await session.create_branch_from_turn(1, "usage_branch") + await session.add_items([{"role": "user", "content": "Branch question"}]) + usage_branch = Usage(requests=2, input_tokens=100, output_tokens=60, total_tokens=160) + run_result_branch = create_mock_run_result(usage_branch) + await session.store_run_usage(run_result_branch) + + main_usage = await session.get_session_usage(branch_id="main") + assert main_usage is not None + assert main_usage["requests"] == 1 + assert main_usage["total_tokens"] == 80 + assert main_usage["total_turns"] == 1 + + branch_usage = await session.get_session_usage(branch_id="usage_branch") + assert branch_usage is not None + assert branch_usage["requests"] == 2 + assert branch_usage["total_tokens"] == 160 + assert branch_usage["total_turns"] == 1 + + total_usage = await session.get_session_usage() + assert total_usage is not None + assert total_usage["requests"] == 3 + assert total_usage["total_tokens"] == 240 + assert total_usage["total_turns"] == 2 + + branch_turn_usage = await session.get_turn_usage(branch_id="usage_branch") + assert isinstance(branch_turn_usage, list) + assert len(branch_turn_usage) == 1 + assert branch_turn_usage[0]["requests"] == 2 + + +async def test_tool_name_extraction(): + async with managed_session("tool_names_test") as session: + items: list[TResponseInputItem] = [ + {"role": "user", "content": "Use tools please"}, + {"type": "function_call", "name": "search_web", "arguments": "{}"}, # type: ignore + {"type": "function_call_output", "tool_name": "search_web", "output": "result"}, # type: ignore + {"type": "function_call", "name": "calculator", "arguments": "{}"}, # type: ignore + ] + await session.add_items(items) + + conversation_turns = await session.get_conversation_by_turns() + turn_items = conversation_turns[1] + + tool_items = [item for item in turn_items if item["tool_name"]] + tool_names = [item["tool_name"] for item in tool_items] + + assert "search_web" in tool_names + assert "calculator" in tool_names + + +async def test_tool_execution_integration(agent: Agent): + async with managed_session("tool_integration_test") as session: + fake_model = cast(FakeModel, agent.model) + fake_model.set_next_output( + [ + { # type: ignore + "type": "function_call", + "name": "test_tool", + "arguments": '{"query": "test query"}', + "call_id": "call_123", + } + ] + ) + + fake_model.set_next_output([get_text_message("Tool executed successfully")]) + + result = await Runner.run( + agent, + "Please use the test tool", + session=session, + ) + + assert "Tool result for: test query" in str(result.new_items) + + tool_usage = await session.get_tool_usage() + assert len(tool_usage) > 0