diff --git a/README.md b/README.md index f8d1fa5..351da16 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ Released under the Universal Permissive License v1.0 as shown at . [contributing]: https://github.com/oracle/python-select-ai/blob/main/CONTRIBUTING.md -[documentation]: https://docs.oracle.com/en/cloud/paas/autonomous-database/serverless/pysai/ +[documentation]: https://oracle.github.io/python-select-ai/ [ghdiscussions]: https://github.com/oracle/python-select-ai/discussions [ghissues]: https://github.com/oracle/python-select-ai/issues [samples]: https://github.com/oracle/python-select-ai/tree/main/samples diff --git a/pyproject.toml b/pyproject.toml index 637806b..7758495 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ test = [ Homepage = "https://github.com/oracle/python-select-ai" Repository = "https://github.com/oracle/python-select-ai" Issues = "https://github.com/oracle/python-select-ai/issues" -Documentation = "https://docs.oracle.com/en/cloud/paas/autonomous-database/serverless/pysai/" +Documentation = "https://oracle.github.io/python-select-ai/" [tool.setuptools.packages.find] where = ["src"] diff --git a/samples/agent/task_create.py b/samples/agent/task_create.py index b71e37b..b0bd1de 100644 --- a/samples/agent/task_create.py +++ b/samples/agent/task_create.py @@ -25,7 +25,7 @@ task = Task( task_name="ANALYZE_MOVIE_TASK", - description="Movie task involving a human", + description="Search for movies in the database", attributes=TaskAttributes( instruction="Help the user with their request about movies. " "User question: {query}. " diff --git a/src/select_ai/_abc.py b/src/select_ai/_abc.py index b3a875b..08c9957 100644 --- a/src/select_ai/_abc.py +++ b/src/select_ai/_abc.py @@ -27,6 +27,17 @@ def _bool(value: Any) -> bool: raise ValueError(f"Invalid boolean value: {value}") +def _is_json(field, value) -> bool: + if field.type in ( + typing.List[Mapping], + typing.Optional[Mapping], + typing.Optional[List[str]], + typing.Optional[List[typing.Mapping]], + ) and isinstance(value, (str, bytes, bytearray)): + return True + return False + + @dataclass class SelectAIDataClass(ABC): """SelectAIDataClass is an abstract container for all data @@ -65,13 +76,7 @@ def __post_init__(self): setattr(self, field.name, _bool(value)) elif field.type is typing.Optional[float]: setattr(self, field.name, float(value)) - elif field.type is typing.Optional[Mapping] and isinstance( - value, (str, bytes, bytearray) - ): - setattr(self, field.name, json.loads(value)) - elif field.type is typing.Optional[ - List[typing.Mapping] - ] and isinstance(value, (str, bytes, bytearray)): + elif _is_json(field, value): setattr(self, field.name, json.loads(value)) else: setattr(self, field.name, value) diff --git a/src/select_ai/agent/__init__.py b/src/select_ai/agent/__init__.py index 343d096..e6f4bd4 100644 --- a/src/select_ai/agent/__init__.py +++ b/src/select_ai/agent/__init__.py @@ -6,10 +6,11 @@ # ----------------------------------------------------------------------------- -from .core import Agent, AgentAttributes -from .task import Task, TaskAttributes -from .team import Team, TeamAttributes +from .core import Agent, AgentAttributes, AsyncAgent +from .task import AsyncTask, Task, TaskAttributes +from .team import AsyncTeam, Team, TeamAttributes from .tool import ( + AsyncTool, EmailNotificationToolParams, HTTPToolParams, HumanToolParams, diff --git a/src/select_ai/agent/core.py b/src/select_ai/agent/core.py index b3171ac..7d5b216 100644 --- a/src/select_ai/agent/core.py +++ b/src/select_ai/agent/core.py @@ -5,33 +5,26 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -import json from abc import ABC from dataclasses import dataclass from typing import ( Any, AsyncGenerator, Iterator, - List, - Mapping, Optional, Union, ) import oracledb -from select_ai import BaseProfile from select_ai._abc import SelectAIDataClass -from select_ai._enums import StrEnum from select_ai.agent.sql import ( GET_USER_AI_AGENT, GET_USER_AI_AGENT_ATTRIBUTES, LIST_USER_AI_AGENTS, ) -from select_ai.async_profile import AsyncProfile from select_ai.db import async_cursor, cursor from select_ai.errors import AgentNotFoundError -from select_ai.profile import Profile @dataclass @@ -292,3 +285,227 @@ def set_attribute(self, attribute_name: str, attribute_value: Any) -> None: keyword_parameters=parameters, ) self.attributes = self._get_attributes(agent_name=self.agent_name) + + +class AsyncAgent(BaseAgent): + """ + select_ai.agent.AsyncAgent class lets you create, delete, enable, disable + and list AI agents asynchronously + + :param str agent_name: The name of the AI Agent + :param str description: Optional description of the AI agent + :param select_ai.agent.AgentAttributes attributes: AI agent attributes + + """ + + @staticmethod + async def _get_attributes(agent_name: str) -> AgentAttributes: + async with async_cursor() as cr: + await cr.execute( + GET_USER_AI_AGENT_ATTRIBUTES, agent_name=agent_name.upper() + ) + attributes = await cr.fetchall() + if attributes: + post_processed_attributes = {} + for k, v in attributes: + if isinstance(v, oracledb.AsyncLOB): + post_processed_attributes[k] = await v.read() + else: + post_processed_attributes[k] = v + return AgentAttributes(**post_processed_attributes) + else: + raise AgentNotFoundError(agent_name=agent_name) + + @staticmethod + async def _get_description(agent_name: str) -> Union[str, None]: + async with async_cursor() as cr: + await cr.execute(GET_USER_AI_AGENT, agent_name=agent_name.upper()) + agent = await cr.fetchone() + if agent: + if agent[1] is not None: + return await agent[1].read() + else: + return None + else: + raise AgentNotFoundError(agent_name) + + async def create( + self, enabled: Optional[bool] = True, replace: Optional[bool] = False + ): + """ + Register a new AI Agent within the Select AI framework + + :param bool enabled: Whether the AI Agent should be enabled. + Default value is True. + + :param bool replace: Whether the AI Agent should be replaced. + Default value is False. + + """ + if self.agent_name is None: + raise AttributeError("Agent must have a name") + if self.attributes is None: + raise AttributeError("Agent must have attributes") + + parameters = { + "agent_name": self.agent_name, + "attributes": self.attributes.json(), + } + if self.description: + parameters["description"] = self.description + + if not enabled: + parameters["status"] = "disabled" + + async with async_cursor() as cr: + try: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.CREATE_AGENT", + keyword_parameters=parameters, + ) + except oracledb.Error as err: + (err_obj,) = err.args + if err_obj.code in (20050, 20052) and replace: + await self.delete(force=True) + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.CREATE_AGENT", + keyword_parameters=parameters, + ) + else: + raise + + async def delete(self, force: Optional[bool] = False): + """ + Delete AI Agent from the database + + :param bool force: Force the deletion. Default value is False. + + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.DROP_AGENT", + keyword_parameters={ + "agent_name": self.agent_name, + "force": force, + }, + ) + + async def disable(self): + """ + Disable AI Agent + """ + async with async_cursor as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.DISABLE_AGENT", + keyword_parameters={ + "agent_name": self.agent_name, + }, + ) + + async def enable(self): + """ + Enable AI Agent + """ + async with async_cursor as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.ENABLE_AGENT", + keyword_parameters={ + "agent_name": self.agent_name, + }, + ) + + @classmethod + async def fetch(cls, agent_name: str) -> "AsyncAgent": + """ + Fetch AI Agent attributes from the Database and build a proxy object in + the Python layer + + :param str agent_name: The name of the AI Agent + + :return: select_ai.agent.Agent + + :raises select_ai.errors.AgentNotFoundError: + If the AI Agent is not found + + """ + attributes = await cls._get_attributes(agent_name=agent_name) + description = await cls._get_description(agent_name=agent_name) + return cls( + agent_name=agent_name, + attributes=attributes, + description=description, + ) + + @classmethod + async def list( + cls, agent_name_pattern: Optional[str] = ".*" + ) -> AsyncGenerator["AsyncAgent", None]: + """ + List AI agents matching a pattern + + :param str agent_name_pattern: Regular expressions can be used + to specify a pattern. Function REGEXP_LIKE is used to perform the + match. Default value is ".*" i.e. match all agent names. + + :return: AsyncGenerator[AsyncAgent] + """ + async with async_cursor() as cr: + await cr.execute( + LIST_USER_AI_AGENTS, + agent_name_pattern=agent_name_pattern, + ) + rows = await cr.fetchall() + for row in rows: + agent_name = row[0] + if row[1]: + description = await row[1].read() # Oracle.AsyncLOB + else: + description = None + attributes = await cls._get_attributes(agent_name=agent_name) + yield cls( + agent_name=agent_name, + description=description, + attributes=attributes, + ) + + async def set_attributes(self, attributes: AgentAttributes) -> None: + """ + Set AI Agent attributes + + :param select_ai.agent.AgentAttributes attributes: Multiple attributes + can be specified by passing an AgentAttributes object + """ + parameters = { + "object_name": self.agent_name, + "object_type": "agent", + "attributes": attributes.json(), + } + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTES", + keyword_parameters=parameters, + ) + self.attributes = await self._get_attributes( + agent_name=self.agent_name + ) + + async def set_attribute( + self, attribute_name: str, attribute_value: Any + ) -> None: + """ + Set a single AI Agent attribute specified using name and value + """ + parameters = { + "object_name": self.agent_name, + "object_type": "agent", + "attribute_name": attribute_name, + "attribute_value": attribute_value, + } + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTE", + keyword_parameters=parameters, + ) + self.attributes = await self._get_attributes( + agent_name=self.agent_name + ) diff --git a/src/select_ai/agent/sql.py b/src/select_ai/agent/sql.py index a46a795..b56cf8c 100644 --- a/src/select_ai/agent/sql.py +++ b/src/select_ai/agent/sql.py @@ -62,19 +62,21 @@ GET_USER_AI_AGENT_TEAM = """ -SELECT t.tool_name, t.description +SELECT t.agent_team_name as team_name, t.description FROM USER_AI_AGENT_TEAMS t -WHERE t.team_name = :team_name +WHERE t.agent_team_name = :team_name """ + GET_USER_AI_AGENT_TEAM_ATTRIBUTES = """ SELECT attribute_name, attribute_value FROM USER_AI_AGENT_TEAM_ATTRIBUTES -WHERE team_name = :team_name +WHERE agent_team_name = :team_name """ + LIST_USER_AI_AGENT_TEAMS = """ -SELECT t.tool_name, t.description +SELECT t.AGENT_TEAM_NAME as team_name, description FROM USER_AI_AGENT_TEAMS t -WHERE REGEXP_LIKE(t.team_name, :team_name_pattern, 'i') +WHERE REGEXP_LIKE(t.AGENT_TEAM_NAME, :team_name_pattern, 'i') """ diff --git a/src/select_ai/agent/task.py b/src/select_ai/agent/task.py index 534dd35..2fe8262 100644 --- a/src/select_ai/agent/task.py +++ b/src/select_ai/agent/task.py @@ -298,3 +298,224 @@ def set_attribute(self, attribute_name: str, attribute_value: Any): "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTE", keyword_parameters=parameters, ) + + +class AsyncTask(BaseTask): + """ + select_ai.agent.AsyncTask class lets you create, delete, enable, disable and + list AI Tasks asynchronously + + :param str task_name: The name of the AI task + :param str description: Optional description of the AI task + :param select_ai.agent.TaskAttributes attributes: AI task attributes + + """ + + @staticmethod + async def _get_attributes(task_name: str) -> TaskAttributes: + async with async_cursor() as cr: + await cr.execute( + GET_USER_AI_AGENT_TASK_ATTRIBUTES, task_name=task_name.upper() + ) + attributes = await cr.fetchall() + if attributes: + post_processed_attributes = {} + for k, v in attributes: + if isinstance(v, oracledb.AsyncLOB): + post_processed_attributes[k] = await v.read() + else: + post_processed_attributes[k] = v + return TaskAttributes(**post_processed_attributes) + else: + raise AgentTaskNotFoundError(task_name=task_name) + + @staticmethod + async def _get_description(task_name: str) -> Union[str, None]: + async with async_cursor() as cr: + await cr.execute( + GET_USER_AI_AGENT_TASK, task_name=task_name.upper() + ) + task = await cr.fetchone() + if task: + if task[1] is not None: + return await task[1].read() + else: + return None + else: + raise AgentTaskNotFoundError(task_name) + + async def create( + self, enabled: Optional[bool] = True, replace: Optional[bool] = False + ): + """ + Create a task that a Select AI agent can include in its + reasoning process + + :param bool enabled: Whether the AI Task should be enabled. + Default value is True. + + :param bool replace: Whether the AI Task should be replaced. + Default value is False. + + """ + if self.task_name is None: + raise AttributeError("Task must have a name") + if self.attributes is None: + raise AttributeError("Task must have attributes") + + parameters = { + "task_name": self.task_name, + "attributes": self.attributes.json(), + } + + if self.description: + parameters["description"] = self.description + + if not enabled: + parameters["status"] = "disabled" + + async with async_cursor() as cr: + try: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.CREATE_TASK", + keyword_parameters=parameters, + ) + except oracledb.Error as err: + (err_obj,) = err.args + if err_obj.code in (20051, 20052) and replace: + await self.delete(force=True) + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.CREATE_TASK", + keyword_parameters=parameters, + ) + else: + raise + + async def delete(self, force: bool = False): + """ + Delete AI Task from the database + + :param bool force: Force the deletion. Default value is False. + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.DROP_TASK", + keyword_parameters={ + "task_name": self.task_name, + "force": force, + }, + ) + + async def disable(self): + """ + Disable AI Task + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.DISABLE_TASK", + keyword_parameters={ + "task_name": self.task_name, + }, + ) + + async def enable(self): + """ + Enable AI Task + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.ENABLE_TASK", + keyword_parameters={ + "task_name": self.task_name, + }, + ) + + @classmethod + async def list( + cls, task_name_pattern: Optional[str] = ".*" + ) -> AsyncGenerator["AsyncTask", None]: + """List AI Tasks + + :param str task_name_pattern: Regular expressions can be used + to specify a pattern. Function REGEXP_LIKE is used to perform the + match. Default value is ".*" i.e. match all tasks. + + :return: AsyncGenerator[Task] + """ + async with async_cursor() as cr: + await cr.execute( + LIST_USER_AI_AGENT_TASKS, + task_name_pattern=task_name_pattern, + ) + rows = await cr.fetchall() + for row in rows: + task_name = row[0] + if row[1]: + description = await row[1].read() # Oracle.AsyncLOB + else: + description = None + attributes = await cls._get_attributes(task_name=task_name) + yield cls( + task_name=task_name, + description=description, + attributes=attributes, + ) + + @classmethod + async def fetch(cls, task_name: str) -> "AsyncTask": + """ + Fetch AI Task attributes from the Database and build a proxy object in + the Python layer + + :param str task_name: The name of the AI Task + + :return: select_ai.agent.Task + + :raises select_ai.errors.AgentTaskNotFoundError: + If the AI Task is not found + """ + attributes = await cls._get_attributes(task_name=task_name) + description = await cls._get_description(task_name=task_name) + return cls( + task_name=task_name, + description=description, + attributes=attributes, + ) + + async def set_attributes(self, attributes: TaskAttributes): + """ + Set AI Task attributes + + :param select_ai.agent.TaskAttributes attributes: Multiple attributes + can be specified by passing a TaskAttributes object + """ + parameters = { + "object_name": self.task_name, + "object_type": "task", + "attributes": attributes.json(), + } + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTES", + keyword_parameters=parameters, + ) + + async def set_attribute(self, attribute_name: str, attribute_value: Any): + """ + Set a single AI Task attribute specified using name and value + + :param str attribute_name: The name of the AI Task attribute + :param str attribute_value: The value of the AI Task attribute + + """ + parameters = { + "object_name": self.task_name, + "object_type": "task", + "attribute_name": attribute_name, + "attribute_value": attribute_value, + } + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTE", + keyword_parameters=parameters, + ) diff --git a/src/select_ai/agent/team.py b/src/select_ai/agent/team.py index c5ae938..a128a35 100644 --- a/src/select_ai/agent/team.py +++ b/src/select_ai/agent/team.py @@ -110,7 +110,7 @@ def _get_attributes(team_name: str) -> TeamAttributes: @staticmethod def _get_description(team_name: str) -> Union[str, None]: with cursor() as cr: - cr.execute(GET_USER_AI_AGENT_TEAM, task_name=team_name.upper()) + cr.execute(GET_USER_AI_AGENT_TEAM, team_name=team_name.upper()) team = cr.fetchone() if team: if team[1] is not None: @@ -328,3 +328,263 @@ def set_attribute(self, attribute_name: str, attribute_value: Any) -> None: "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTE", keyword_parameters=parameters, ) + + +class AsyncTeam(BaseTeam): + """ + A Team of AI agents work together to accomplish tasks + select_ai.agent.Team class lets you create, delete, enable, disable and + list AI Tasks. + + :param str team_name: The name of the AI team + :param str description: Optional description of the AI team + :param select_ai.agent.TeamAttributes attributes: AI team attributes + + """ + + @staticmethod + async def _get_attributes(team_name: str) -> TeamAttributes: + async with async_cursor() as cr: + await cr.execute( + GET_USER_AI_AGENT_TEAM_ATTRIBUTES, team_name=team_name.upper() + ) + attributes = await cr.fetchall() + if attributes: + post_processed_attributes = {} + for k, v in attributes: + if isinstance(v, oracledb.AsyncLOB): + post_processed_attributes[k] = await v.read() + else: + post_processed_attributes[k] = v + return TeamAttributes(**post_processed_attributes) + else: + raise AgentTeamNotFoundError(team_name=team_name) + + @staticmethod + async def _get_description(team_name: str) -> Union[str, None]: + async with async_cursor() as cr: + await cr.execute( + GET_USER_AI_AGENT_TEAM, team_name=team_name.upper() + ) + team = await cr.fetchone() + if team: + if team[1] is not None: + return await team[1].read() + else: + return None + else: + raise AgentTeamNotFoundError(team_name=team_name) + + async def create( + self, enabled: Optional[bool] = True, replace: Optional[bool] = False + ): + """ + Create a team of AI agents that work together to accomplish tasks. + + :param bool enabled: Whether the AI agent team should be enabled. + Default value is True. + + :param bool replace: Whether the AI agent team should be replaced. + Default value is False. + + """ + if self.team_name is None: + raise AttributeError("Team must have a name") + if self.attributes is None: + raise AttributeError("Team must have attributes") + + parameters = { + "team_name": self.team_name, + "attributes": self.attributes.json(), + } + if self.description: + parameters["description"] = self.description + + if not enabled: + parameters["status"] = "disabled" + + async with async_cursor() as cr: + try: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.CREATE_TEAM", + keyword_parameters=parameters, + ) + except oracledb.Error as err: + (err_obj,) = err.args + if err_obj.code in (20053, 20052) and replace: + await self.delete(force=True) + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.CREATE_TEAM", + keyword_parameters=parameters, + ) + else: + raise + + async def delete(self, force: Optional[bool] = False): + """ + Delete an AI agent team from the database + + :param bool force: Force the deletion. Default value is False. + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.DROP_TEAM", + keyword_parameters={ + "team_name": self.team_name, + "force": force, + }, + ) + + async def disable(self): + """ + Disable the AI agent team + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.DISABLE_TEAM", + keyword_parameters={ + "team_name": self.team_name, + }, + ) + + async def enable(self): + """ + Enable the AI agent team + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.ENABLE_TEAM", + keyword_parameters={ + "team_name": self.team_name, + }, + ) + + @classmethod + async def fetch(cls, team_name: str) -> "AsyncTeam": + """ + Fetch AI Team attributes from the Database and build a proxy object in + the Python layer + + :param str team_name: The name of the AI Team + + :return: select_ai.agent.Team + + :raises select_ai.errors.AgentTeamNotFoundError: + If the AI Team is not found + """ + attributes = await cls._get_attributes(team_name) + description = await cls._get_description(team_name) + return cls( + team_name=team_name, + attributes=attributes, + description=description, + ) + + @classmethod + async def list( + cls, team_name_pattern: Optional[str] = ".*" + ) -> AsyncGenerator["AsyncTeam", None]: + """ + List AI Agent Teams + + :param str team_name_pattern: Regular expressions can be used + to specify a pattern. Function REGEXP_LIKE is used to perform the + match. Default value is ".*" i.e. match all teams. + + :return: Iterator[Team] + + """ + async with async_cursor() as cr: + await cr.execute( + LIST_USER_AI_AGENT_TEAMS, + team_name_pattern=team_name_pattern, + ) + rows = await cr.fetchall() + for row in rows: + team_name = row[0] + if row[1]: + description = await row[1].read() # Oracle.AsyncLOB + else: + description = None + attributes = await cls._get_attributes(team_name=team_name) + yield cls( + team_name=team_name, + description=description, + attributes=attributes, + ) + + async def run(self, prompt: str = None, params: Mapping = None): + """ + Start a new AI agent team or resume a paused one that is waiting + for human input. If you provide an existing process ID and the + associated team process is in the WAITING_FOR_HUMAN state, the + function resumes the workflow using the input you provide as + the human response + + :param str prompt: Optional prompt for the user. If the task is + in the RUNNING state, the input acts as a placeholder for the + {query} in the task instruction. If the task is in the + WAITING_FOR_HUMAN state, the input serves as the human response. + + :param Mapping[str, str] params: Optional parameters for the task. + Currently, the following parameters are supported: + + - conversation_id: Identifies the conversation session associated + with the agent team + + - variables: key-value pairs that provide additional input to the agent team. + + """ + parameters = { + "team_name": self.team_name, + } + if prompt: + parameters["user_prompt"] = prompt + if params: + parameters["params"] = json.dumps(params) + + async with async_cursor() as cr: + data = await cr.callfunc( + "DBMS_CLOUD_AI_AGENT.RUN_TEAM", + oracledb.DB_TYPE_CLOB, + keyword_parameters=parameters, + ) + if data is not None: + result = await data.read() + else: + result = None + return result + + async def set_attributes(self, attributes: TeamAttributes) -> None: + """ + Set the attributes of the AI Agent team + """ + parameters = { + "object_name": self.team_name, + "object_type": "team", + "attributes": attributes.json(), + } + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTES", + keyword_parameters=parameters, + ) + + async def set_attribute( + self, attribute_name: str, attribute_value: Any + ) -> None: + """ + Set the attribute of the AI Agent team specified by + `attribute_name` and `attribute_value`. + """ + parameters = { + "object_name": self.team_name, + "object_type": "team", + "attribute_name": attribute_name, + "attribute_value": attribute_value, + } + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTE", + keyword_parameters=parameters, + ) diff --git a/src/select_ai/agent/tool.py b/src/select_ai/agent/tool.py index 0992bd2..cd85071 100644 --- a/src/select_ai/agent/tool.py +++ b/src/select_ai/agent/tool.py @@ -688,3 +688,442 @@ def set_attribute(self, attribute_name: str, attribute_value: Any) -> None: "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTE", keyword_parameters=parameters, ) + + +class AsyncTool(_BaseTool): + + @staticmethod + async def _get_attributes(tool_name: str) -> ToolAttributes: + """Get attributes of an AI tool + + :return: select_ai.agent.ToolAttributes + :raises: AgentToolNotFoundError + """ + async with async_cursor() as cr: + await cr.execute( + GET_USER_AI_AGENT_TOOL_ATTRIBUTES, tool_name=tool_name.upper() + ) + attributes = await cr.fetchall() + if attributes: + post_processed_attributes = {} + for k, v in attributes: + if isinstance(v, oracledb.AsyncLOB): + post_processed_attributes[k] = await v.read() + else: + post_processed_attributes[k] = v + return ToolAttributes.create(**post_processed_attributes) + else: + raise AgentToolNotFoundError(tool_name=tool_name) + + @staticmethod + async def _get_description(tool_name: str) -> Union[str, None]: + async with async_cursor() as cr: + await cr.execute( + GET_USER_AI_AGENT_TOOL, tool_name=tool_name.upper() + ) + tool = await cr.fetchone() + if tool: + if tool[1] is not None: + return await tool[1].read() + else: + return None + else: + raise AgentToolNotFoundError(tool_name=tool_name) + + async def create( + self, enabled: Optional[bool] = True, replace: Optional[bool] = False + ): + if self.tool_name is None: + raise AttributeError("Tool must have a name") + if self.attributes is None: + raise AttributeError("Tool must have attributes") + + parameters = { + "tool_name": self.tool_name, + "attributes": self.attributes.json(), + } + if self.description: + parameters["description"] = self.description + + if not enabled: + parameters["status"] = "disabled" + + async with async_cursor() as cr: + try: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.CREATE_TOOL", + keyword_parameters=parameters, + ) + except oracledb.Error as err: + (err_obj,) = err.args + if err_obj.code in (20050, 20052) and replace: + await self.delete(force=True) + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.CREATE_TOOL", + keyword_parameters=parameters, + ) + else: + raise + + @classmethod + async def create_built_in_tool( + cls, + tool_name: str, + tool_params: ToolParams, + tool_type: ToolType, + description: Optional[str] = None, + replace: Optional[bool] = False, + ) -> "AsyncTool": + """ + Register a built-in tool + + :param str tool_name: The name of the tool + :param select_ai.agent.ToolParams tool_params: + Parameters required by built-in tool + :param select_ai.agent.ToolType tool_type: The built-in tool type + :param str description: Description of the tool + :param bool replace: Whether to replace the existing tool. + Default value is False + + :return: select_ai.agent.Tool + """ + if not isinstance(tool_params, ToolParams): + raise TypeError( + "'tool_params' must be an object of " + "type select_ai.agent.ToolParams" + ) + attributes = ToolAttributes( + tool_params=tool_params, tool_type=tool_type + ) + tool = cls( + tool_name=tool_name, attributes=attributes, description=description + ) + await tool.create(replace=replace) + return tool + + @classmethod + async def create_email_notification_tool( + cls, + tool_name: str, + credential_name: str, + recipient: str, + sender: str, + smtp_host: str, + description: Optional[str], + replace: bool = False, + ) -> "AsyncTool": + """ + Register an email notification tool + + :param str tool_name: The name of the tool + :param str credential_name: The name of the credential + :param str recipient: The recipient of the email + :param str sender: The sender of the email + :param str smtp_host: The SMTP host of the email server + :param str description: The description of the tool + :param bool replace: Whether to replace the existing tool. + Default value is False + + :return: select_ai.agent.Tool + + """ + email_notification_tool_params = EmailNotificationToolParams( + credential_name=credential_name, + recipient=recipient, + sender=sender, + smtp_host=smtp_host, + ) + return await cls.create_built_in_tool( + tool_name=tool_name, + tool_type=ToolType.EMAIL, + tool_params=email_notification_tool_params, + description=description, + replace=replace, + ) + + @classmethod + async def create_http_tool( + cls, + tool_name: str, + credential_name: str, + endpoint: str, + description: Optional[str] = None, + replace: bool = False, + ) -> "AsyncTool": + http_tool_params = HTTPToolParams( + credential_name=credential_name, endpoint=endpoint + ) + return await cls.create_built_in_tool( + tool_name=tool_name, + tool_type=ToolType.HTTP, + tool_params=http_tool_params, + description=description, + replace=replace, + ) + + @classmethod + async def create_pl_sql_tool( + cls, + tool_name: str, + function: str, + description: Optional[str] = None, + replace: bool = False, + ) -> "AsyncTool": + """ + Create a custom tool to invoke PL/SQL procedure or function + + :param str tool_name: The name of the tool + :param str function: The name of the PL/SQL procedure or function + :param str description: The description of the tool + :param bool replace: Whether to replace existing tool. Default value + is False + + """ + tool_attributes = ToolAttributes(function=function) + tool = cls( + tool_name=tool_name, + attributes=tool_attributes, + description=description, + ) + await tool.create(replace=replace) + return tool + + @classmethod + async def create_rag_tool( + cls, + tool_name: str, + profile_name: str, + description: Optional[str] = None, + replace: bool = False, + ) -> "AsyncTool": + """ + Register a RAG tool, which will use a VectorIndex linked AI Profile + + :param str tool_name: The name of the tool + :param str profile_name: The name of the profile to + use for Vector Index based RAG + :param str description: The description of the tool + :param bool replace: Whether to replace existing tool. Default value + is False + """ + tool_params = RAGToolParams(profile_name=profile_name) + return await cls.create_built_in_tool( + tool_name=tool_name, + tool_type=ToolType.RAG, + tool_params=tool_params, + description=description, + replace=replace, + ) + + @classmethod + async def create_sql_tool( + cls, + tool_name: str, + profile_name: str, + description: Optional[str] = None, + replace: bool = False, + ) -> "AsyncTool": + """ + Register a SQL tool to perform natural language to SQL translation + + :param str tool_name: The name of the tool + :param str profile_name: The name of the profile to use for SQL + translation + :param str description: The description of the tool + :param bool replace: Whether to replace existing tool. Default value + is False + """ + tool_params = SQLToolParams(profile_name=profile_name) + return await cls.create_built_in_tool( + tool_name=tool_name, + tool_type=ToolType.SQL, + tool_params=tool_params, + description=description, + replace=replace, + ) + + @classmethod + async def create_slack_notification_tool( + cls, + tool_name: str, + credential_name: str, + slack_channel: str, + description: Optional[str] = None, + replace: bool = False, + ) -> "AsyncTool": + """ + Register a Slack notification tool + + :param str tool_name: The name of the Slack notification tool + :param str credential_name: The name of the Slack credential + :param str slack_channel: The name of the Slack channel + :param str description: The description of the Slack notification tool + :param bool replace: Whether to replace existing tool. Default value + is False + + """ + slack_notification_tool_params = SlackNotificationToolParams( + credential_name=credential_name, + slack_channel=slack_channel, + ) + return await cls.create_built_in_tool( + tool_name=tool_name, + tool_type=ToolType.SLACK, + tool_params=slack_notification_tool_params, + description=description, + replace=replace, + ) + + @classmethod + async def create_websearch_tool( + cls, + tool_name: str, + credential_name: str, + description: Optional[str], + replace: bool = False, + ) -> "AsyncTool": + """ + Register a built-in websearch tool to search information + on the web + + :param str tool_name: The name of the tool + :param str credential_name: The name of the credential object + storing OpenAI credentials + :param str description: The description of the tool + :param bool replace: Whether to replace the existing tool + + """ + web_search_tool_params = WebSearchToolParams( + credential_name=credential_name, + ) + return await cls.create_built_in_tool( + tool_name=tool_name, + tool_type=ToolType.WEBSEARCH, + tool_params=web_search_tool_params, + description=description, + replace=replace, + ) + + async def delete(self, force: bool = False): + """ + Delete AI Tool from the database + + :param bool force: Force the deletion. Default value is False. + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.DROP_TOOL", + keyword_parameters={ + "tool_name": self.tool_name, + "force": force, + }, + ) + + async def disable(self): + """ + Disable AI Tool + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.DISABLE_TOOL", + keyword_parameters={ + "tool_name": self.tool_name, + }, + ) + + async def enable(self): + """ + Enable AI Tool + """ + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.ENABLE_TOOL", + keyword_parameters={ + "tool_name": self.tool_name, + }, + ) + + @classmethod + async def fetch(cls, tool_name: str) -> "AsyncTool": + """ + Fetch AI Tool attributes from the Database and build a proxy object in + the Python layer + + :param str tool_name: The name of the AI Task + + :return: select_ai.agent.Tool + + :raises select_ai.errors.AgentToolNotFoundError: + If the AI Tool is not found + + """ + attributes = await cls._get_attributes(tool_name) + description = await cls._get_description(tool_name) + return cls( + tool_name=tool_name, attributes=attributes, description=description + ) + + @classmethod + async def list( + cls, tool_name_pattern: str = ".*" + ) -> AsyncGenerator["AsyncTool", None]: + """List AI Tools + + :param str tool_name_pattern: Regular expressions can be used + to specify a pattern. Function REGEXP_LIKE is used to perform the + match. Default value is ".*" i.e. match all tool name. + + :return: Iterator[Tool] + """ + async with async_cursor() as cr: + await cr.execute( + LIST_USER_AI_AGENT_TOOLS, + tool_name_pattern=tool_name_pattern, + ) + rows = await cr.fetchall() + for row in rows: + tool_name = row[0] + if row[1]: + description = await row[1].read() # Oracle.AsyncLOB + else: + description = None + attributes = await cls._get_attributes(tool_name=tool_name) + yield cls( + tool_name=tool_name, + description=description, + attributes=attributes, + ) + + async def set_attributes(self, attributes: ToolAttributes) -> None: + """ + Set the attributes of the AI Agent tool + """ + parameters = { + "object_name": self.tool_name, + "object_type": "tool", + "attributes": attributes.json(), + } + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTES", + keyword_parameters=parameters, + ) + + async def set_attribute( + self, attribute_name: str, attribute_value: Any + ) -> None: + """ + Set the attribute of the AI Agent tool specified by + `attribute_name` and `attribute_value`. + """ + parameters = { + "object_name": self.tool_name, + "object_type": "tool", + "attribute_name": attribute_name, + "attribute_value": attribute_value, + } + async with async_cursor() as cr: + await cr.callproc( + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTE", + keyword_parameters=parameters, + ) diff --git a/src/select_ai/version.py b/src/select_ai/version.py index 3c1bdbd..e773f7a 100644 --- a/src/select_ai/version.py +++ b/src/select_ai/version.py @@ -5,4 +5,4 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -__version__ = "1.2.0rc1" +__version__ = "1.2.0" diff --git a/tests/agents/conftest.py b/tests/agents/conftest.py new file mode 100644 index 0000000..386ab8f --- /dev/null +++ b/tests/agents/conftest.py @@ -0,0 +1,43 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import pytest +import select_ai + + +@pytest.fixture(scope="module") +def provider(): + return select_ai.OCIGenAIProvider( + region="us-chicago-1", + oci_apiformat="GENERIC", + model="meta.llama-4-maverick-17b-128e-instruct-fp8", + ) + + +@pytest.fixture(scope="module") +def profile_attributes(provider, oci_credential): + return select_ai.ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[{"owner": "SH"}], + provider=provider, + ) + + +@pytest.fixture(scope="module") +def rag_profile_attributes(provider, oci_credential): + return select_ai.ProfileAttributes( + credential_name=oci_credential["credential_name"], + provider=provider, + ) + + +@pytest.fixture(scope="module") +def vector_index_attributes(provider, oci_credential): + return select_ai.OracleVectorIndexAttributes( + object_storage_credential_name=oci_credential["credential_name"], + location="https://objectstorage.us-ashburn-1.oraclecloud.com/n/dwcsdev/b/conda-environment/o/tenant1-pdb3/graph", + ) diff --git a/tests/agents/test_3000_tools.py b/tests/agents/test_3000_tools.py new file mode 100644 index 0000000..de60134 --- /dev/null +++ b/tests/agents/test_3000_tools.py @@ -0,0 +1,186 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3000 - Module for testing select_ai agent tools +""" + +import uuid + +import pytest +import select_ai +from select_ai.agent import Tool + +PYSAI_3000_PROFILE_NAME = f"PYSAI_3000_{uuid.uuid4().hex.upper()}" +PYSAI_3000_SQL_TOOL_NAME = f"PYSAI_3000_SQL_TOOL_{uuid.uuid4().hex.upper()}" +PYSAI_3000_SQL_TOOL_DESCRIPTION = f"SQL Tool for Python 3000" + +PYSAI_3000_RAG_PROFILE_NAME = f"PYSAI_3000_RAG_{uuid.uuid4().hex.upper()}" +PYSAI_3000_RAG_VECTOR_INDEX_NAME = ( + f"PYSAI_3000_RAG_VECTOR_{uuid.uuid4().hex.upper()}" +) +PYSAI_3000_RAG_TOOL_NAME = f"PYSAI_3000_RAG_TOOL_{uuid.uuid4().hex.upper()}" +PYSAI_3000_RAG_TOOL_DESCRIPTION = f"RAG Tool for Python 3000" + +PYSAI_3000_PL_SQL_TOOL_NAME = ( + f"PYSAI_3000_PL_SQL_TOOL_{uuid.uuid4().hex.upper()}" +) +PYSAI_3000_PL_SQL_TOOL_DESCRIPTION = f"PL/SQL Tool for Python 3000" +PYSAI_3000_PL_SQL_FUNC_NAME = ( + f"PYSAI_3000_PL_SQL_FUNC_{uuid.uuid4().hex.upper()}" +) + + +@pytest.fixture(scope="module") +def python_gen_ai_profile(profile_attributes): + profile = select_ai.Profile( + profile_name=PYSAI_3000_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, + ) + yield profile + profile.delete(force=True) + + +@pytest.fixture(scope="module") +def python_gen_rag_ai_profile(rag_profile_attributes): + profile = select_ai.Profile( + profile_name=PYSAI_3000_RAG_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=rag_profile_attributes, + ) + yield profile + profile.delete(force=True) + + +@pytest.fixture(scope="module") +def sql_tool(python_gen_ai_profile): + sql_tool = select_ai.agent.Tool.create_sql_tool( + tool_name=PYSAI_3000_SQL_TOOL_NAME, + description=PYSAI_3000_SQL_TOOL_DESCRIPTION, + profile_name=PYSAI_3000_PROFILE_NAME, + replace=True, + ) + yield sql_tool + sql_tool.delete(force=True) + + +@pytest.fixture(scope="module") +def vector_index(vector_index_attributes, python_gen_rag_ai_profile): + vector_index = select_ai.VectorIndex( + index_name=PYSAI_3000_RAG_VECTOR_INDEX_NAME, + attributes=vector_index_attributes, + description="Test vector index", + profile=python_gen_rag_ai_profile, + ) + vector_index.create(replace=True) + yield vector_index + vector_index.delete(force=True) + + +@pytest.fixture(scope="module") +def rag_tool(vector_index): + sql_tool = select_ai.agent.Tool.create_rag_tool( + tool_name=PYSAI_3000_RAG_TOOL_NAME, + description=PYSAI_3000_RAG_TOOL_DESCRIPTION, + profile_name=PYSAI_3000_RAG_PROFILE_NAME, + replace=True, + ) + yield sql_tool + sql_tool.delete(force=True) + + +@pytest.fixture(scope="module") +def pl_sql_function(): + create_function = f""" + CREATE OR REPLACE FUNCTION {PYSAI_3000_PL_SQL_FUNC_NAME} (p_birth_date IN DATE) + RETURN NUMBER + IS + v_age NUMBER; + BEGIN + -- Calculate the difference in years + v_age := TRUNC(MONTHS_BETWEEN(SYSDATE, p_birth_date) / 12); + + RETURN v_age; + END CALCULATE_AGE; + """ + with select_ai.cursor() as cr: + cr.execute(create_function) + yield create_function + with select_ai.cursor() as cr: + cr.execute(f"DROP FUNCTION {PYSAI_3000_PL_SQL_FUNC_NAME}") + + +@pytest.fixture(scope="module") +def pl_sql_tool(pl_sql_function): + pl_sql_tool = select_ai.agent.Tool.create_pl_sql_tool( + tool_name=PYSAI_3000_PL_SQL_TOOL_NAME, + function=PYSAI_3000_PL_SQL_FUNC_NAME, + description=PYSAI_3000_PL_SQL_TOOL_DESCRIPTION, + ) + yield pl_sql_tool + pl_sql_tool.delete(force=True) + + +def test_3000(sql_tool): + """test SQL tool creation and parameter validation""" + assert ( + sql_tool.attributes.tool_params.profile_name == PYSAI_3000_PROFILE_NAME + ) + assert sql_tool.tool_name == PYSAI_3000_SQL_TOOL_NAME + assert sql_tool.description == PYSAI_3000_SQL_TOOL_DESCRIPTION + assert isinstance( + sql_tool.attributes.tool_params, select_ai.agent.SQLToolParams + ) + + +def test_3001(rag_tool): + """test RAG tool creation and parameter validation""" + assert ( + rag_tool.attributes.tool_params.profile_name + == PYSAI_3000_RAG_PROFILE_NAME + ) + assert rag_tool.tool_name == PYSAI_3000_RAG_TOOL_NAME + assert rag_tool.description == PYSAI_3000_RAG_TOOL_DESCRIPTION + assert isinstance( + rag_tool.attributes.tool_params, select_ai.agent.RAGToolParams + ) + + +def test_3002(pl_sql_tool): + """test PL SQL tool creation and parameter validation""" + assert pl_sql_tool.tool_name == PYSAI_3000_PL_SQL_TOOL_NAME + assert pl_sql_tool.description == PYSAI_3000_PL_SQL_TOOL_DESCRIPTION + assert pl_sql_tool.attributes.function == PYSAI_3000_PL_SQL_FUNC_NAME + + +def test_3003(): + """list tools""" + tools = list(select_ai.agent.Tool.list()) + tool_names = set(tool.tool_name for tool in tools) + assert PYSAI_3000_RAG_TOOL_NAME in tool_names + assert PYSAI_3000_SQL_TOOL_NAME in tool_names + assert PYSAI_3000_PL_SQL_TOOL_NAME in tool_names + + +def test_3004(): + """list tools matching a REGEX pattern""" + tools = list(select_ai.agent.Tool.list(tool_name_pattern="^PYSAI_3000")) + tool_names = set(tool.tool_name for tool in tools) + assert PYSAI_3000_RAG_TOOL_NAME in tool_names + assert PYSAI_3000_SQL_TOOL_NAME in tool_names + assert PYSAI_3000_PL_SQL_TOOL_NAME in tool_names + + +def test_3005(): + """fetch tool""" + sql_tool = select_ai.agent.Tool.fetch(tool_name=PYSAI_3000_SQL_TOOL_NAME) + assert sql_tool.tool_name == PYSAI_3000_SQL_TOOL_NAME + assert sql_tool.description == PYSAI_3000_SQL_TOOL_DESCRIPTION + assert isinstance( + sql_tool.attributes.tool_params, select_ai.agent.SQLToolParams + ) diff --git a/tests/agents/test_3100_tasks.py b/tests/agents/test_3100_tasks.py new file mode 100644 index 0000000..ecaa174 --- /dev/null +++ b/tests/agents/test_3100_tasks.py @@ -0,0 +1,70 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3100 - Module for testing select_ai agent tasks +""" + +import uuid + +import pytest +import select_ai +from select_ai.agent import Task, TaskAttributes + +PYSAI_3100_TASK_NAME = f"PYSAI_3100_{uuid.uuid4().hex.upper()}" +PYSAI_3100_SQL_TASK_DESCRIPTION = "PYSAI_3100_SQL_TASK_DESCRIPTION" + + +@pytest.fixture(scope="module") +def task_attributes(): + return TaskAttributes( + instruction="Help the user with their request about movies. " + "User question: {query}. " + "You can use SQL tool to search the data from database", + tools=["MOVIE_SQL_TOOL"], + enable_human_tool=False, + ) + + +@pytest.fixture(scope="module") +def task(task_attributes): + task = Task( + task_name=PYSAI_3100_TASK_NAME, + description=PYSAI_3100_SQL_TASK_DESCRIPTION, + attributes=task_attributes, + ) + task.create() + yield task + task.delete(force=True) + + +def test_3100(task, task_attributes): + """simple task creation""" + assert task.task_name == PYSAI_3100_TASK_NAME + assert task.attributes == task_attributes + assert task.description == PYSAI_3100_SQL_TASK_DESCRIPTION + + +@pytest.mark.parametrize("task_name_pattern", [None, "^PYSAI_3100_"]) +def test_3101(task_name_pattern): + """task list""" + if task_name_pattern: + tasks = list(select_ai.agent.Task.list(task_name_pattern)) + else: + tasks = list(select_ai.agent.Task.list()) + task_names = set(task.task_name for task in tasks) + task_descriptions = set(task.description for task in tasks) + assert PYSAI_3100_TASK_NAME in task_names + assert PYSAI_3100_SQL_TASK_DESCRIPTION in task_descriptions + + +def test_3102(task_attributes): + """task fetch""" + task = select_ai.agent.Task.fetch(PYSAI_3100_TASK_NAME) + assert task.task_name == PYSAI_3100_TASK_NAME + assert task.attributes == task_attributes + assert task.description == PYSAI_3100_SQL_TASK_DESCRIPTION diff --git a/tests/agents/test_3200_agents.py b/tests/agents/test_3200_agents.py new file mode 100644 index 0000000..0a70b2e --- /dev/null +++ b/tests/agents/test_3200_agents.py @@ -0,0 +1,71 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3200 - Module for testing select_ai agents +""" +import uuid + +import pytest +import select_ai +from select_ai.agent import Agent, AgentAttributes + +PYSAI_3200_AGENT_NAME = f"PYSAI_3200_AGENT_{uuid.uuid4().hex.upper()}" +PYSAI_3200_AGENT_DESCRIPTION = "PYSAI_3200_AGENT_DESCRIPTION" +PYSAI_3200_PROFILE_NAME = f"PYSAI_3200_PROFILE_{uuid.uuid4().hex.upper()}" + + +@pytest.fixture(scope="module") +def python_gen_ai_profile(profile_attributes): + profile = select_ai.Profile( + profile_name=PYSAI_3200_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, + ) + yield profile + profile.delete(force=True) + + +@pytest.fixture(scope="module") +def agent_attributes(): + agent_attributes = AgentAttributes( + profile_name=PYSAI_3200_PROFILE_NAME, + role="You are an AI Movie Analyst. " + "Your can help answer a variety of questions related to movies. ", + enable_human_tool=False, + ) + return agent_attributes + + +@pytest.fixture(scope="module") +def agent(python_gen_ai_profile, agent_attributes): + agent = Agent( + agent_name=PYSAI_3200_AGENT_NAME, + attributes=agent_attributes, + description=PYSAI_3200_AGENT_DESCRIPTION, + ) + agent.create(enabled=True, replace=True) + yield agent + agent.delete(force=True) + + +def test_3200(agent, agent_attributes): + assert agent.agent_name == PYSAI_3200_AGENT_NAME + assert agent.attributes == agent_attributes + assert agent.description == PYSAI_3200_AGENT_DESCRIPTION + + +@pytest.mark.parametrize("agent_name_pattern", [None, "^PYSAI_3200_AGENT_"]) +def test_3201(agent_name_pattern): + if agent_name_pattern: + agents = list(select_ai.agent.Agent.list(agent_name_pattern)) + else: + agents = list(select_ai.agent.Agent.list()) + agent_names = set(agent.agent_name for agent in agents) + agent_descriptions = set(agent.description for agent in agents) + assert PYSAI_3200_AGENT_NAME in agent_names + assert PYSAI_3200_AGENT_DESCRIPTION in agent_descriptions diff --git a/tests/agents/test_3300_teams.py b/tests/agents/test_3300_teams.py new file mode 100644 index 0000000..8a61634 --- /dev/null +++ b/tests/agents/test_3300_teams.py @@ -0,0 +1,134 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3300 - Module for testing select_ai agent teams +""" + +import uuid + +import pytest +import select_ai +from select_ai.agent import ( + Agent, + AgentAttributes, + Task, + TaskAttributes, + Team, + TeamAttributes, +) + +PYSAI_3300_AGENT_NAME = f"PYSAI_3300_AGENT_{uuid.uuid4().hex.upper()}" +PYSAI_3300_AGENT_DESCRIPTION = "PYSAI_3300_AGENT_DESCRIPTION" +PYSAI_3300_PROFILE_NAME = f"PYSAI_3300_PROFILE_{uuid.uuid4().hex.upper()}" +PYSAI_3300_TASK_NAME = f"PYSAI_3300_{uuid.uuid4().hex.upper()}" +PYSAI_3300_TASK_DESCRIPTION = "PYSAI_3100_SQL_TASK_DESCRIPTION" +PYSAI_3300_TEAM_NAME = f"PYSAI_3300_TEAM_{uuid.uuid4().hex.upper()}" +PYSAI_3300_TEAM_DESCRIPTION = "PYSAI_3300_TEAM_DESCRIPTION" + + +@pytest.fixture(scope="module") +def python_gen_ai_profile(profile_attributes): + profile = select_ai.Profile( + profile_name=PYSAI_3300_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, + ) + yield profile + profile.delete(force=True) + + +@pytest.fixture(scope="module") +def task_attributes(): + return TaskAttributes( + instruction="Help the user with their request about movies. " + "User question: {query}. ", + enable_human_tool=False, + ) + + +@pytest.fixture(scope="module") +def task(task_attributes): + task = Task( + task_name=PYSAI_3300_TASK_NAME, + description=PYSAI_3300_TASK_DESCRIPTION, + attributes=task_attributes, + ) + task.create() + yield task + task.delete(force=True) + + +@pytest.fixture(scope="module") +def agent(python_gen_ai_profile): + agent = Agent( + agent_name=PYSAI_3300_AGENT_NAME, + description=PYSAI_3300_AGENT_DESCRIPTION, + attributes=AgentAttributes( + profile_name=PYSAI_3300_PROFILE_NAME, + role="You are an AI Movie Analyst. " + "Your can help answer a variety of questions related to movies. ", + enable_human_tool=False, + ), + ) + agent.create(enabled=True, replace=True) + yield agent + agent.delete(force=True) + + +@pytest.fixture(scope="module") +def team_attributes(agent, task): + return TeamAttributes( + agents=[{"name": agent.agent_name, "task": task.task_name}], + process="sequential", + ) + + +@pytest.fixture(scope="module") +def team(team_attributes): + team = Team( + team_name=PYSAI_3300_TEAM_NAME, + description=PYSAI_3300_TEAM_DESCRIPTION, + attributes=team_attributes, + ) + team.create() + yield team + team.delete(force=True) + + +def test_3300(team, team_attributes): + assert team.team_name == PYSAI_3300_TEAM_NAME + assert team.description == PYSAI_3300_TEAM_DESCRIPTION + assert team.attributes == team_attributes + + +@pytest.mark.parametrize("team_name_pattern", [None, "^PYSAI_3300_TEAM_"]) +def test_3301(team_name_pattern): + if team_name_pattern: + teams = list(Team.list(team_name_pattern)) + else: + teams = list(Team.list()) + team_names = set(team.team_name for team in teams) + team_descriptions = set(team.description for team in teams) + assert PYSAI_3300_TEAM_NAME in team_names + assert PYSAI_3300_TEAM_DESCRIPTION in team_descriptions + + +def test_3302(team_attributes): + team = Team.fetch(team_name=PYSAI_3300_TEAM_NAME) + assert team.team_name == PYSAI_3300_TEAM_NAME + assert team.description == PYSAI_3300_TEAM_DESCRIPTION + assert team.attributes == team_attributes + + +def test_3303(team): + response = team.run( + prompt="In the movie Titanic, was there enough space for Jack ? ", + params={"conversation_id": str(uuid.uuid4())}, + ) + assert isinstance(response, str) + assert len(response) > 0 diff --git a/tests/agents/test_3400_async_tools.py b/tests/agents/test_3400_async_tools.py new file mode 100644 index 0000000..3620a8b --- /dev/null +++ b/tests/agents/test_3400_async_tools.py @@ -0,0 +1,188 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3400 - Module for testing select_ai agent async tools +""" + +import uuid + +import pytest +import select_ai +from select_ai.agent import AsyncTool + +PYSAI_3400_PROFILE_NAME = f"PYSAI_3400_{uuid.uuid4().hex.upper()}" +PYSAI_3400_SQL_TOOL_NAME = f"PYSAI_3400_SQL_TOOL_{uuid.uuid4().hex.upper()}" +PYSAI_3400_SQL_TOOL_DESCRIPTION = f"SQL Tool for Python 3000" + +PYSAI_3400_RAG_PROFILE_NAME = f"PYSAI_3400_RAG_{uuid.uuid4().hex.upper()}" +PYSAI_3400_RAG_VECTOR_INDEX_NAME = ( + f"PYSAI_3400_RAG_VECTOR_{uuid.uuid4().hex.upper()}" +) +PYSAI_3400_RAG_TOOL_NAME = f"PYSAI_3400_RAG_TOOL_{uuid.uuid4().hex.upper()}" +PYSAI_3400_RAG_TOOL_DESCRIPTION = f"RAG Tool for Python 3000" + +PYSAI_3400_PL_SQL_TOOL_NAME = ( + f"PYSAI_3400_PL_SQL_TOOL_{uuid.uuid4().hex.upper()}" +) +PYSAI_3400_PL_SQL_TOOL_DESCRIPTION = f"PL/SQL Tool for Python 3000" +PYSAI_3400_PL_SQL_FUNC_NAME = ( + f"PYSAI_3400_PL_SQL_FUNC_{uuid.uuid4().hex.upper()}" +) + + +@pytest.fixture(scope="module") +async def python_gen_ai_profile(profile_attributes): + profile = await select_ai.AsyncProfile( + profile_name=PYSAI_3400_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, + ) + yield profile + await profile.delete(force=True) + + +@pytest.fixture(scope="module") +async def python_gen_rag_ai_profile(rag_profile_attributes): + profile = await select_ai.AsyncProfile( + profile_name=PYSAI_3400_RAG_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=rag_profile_attributes, + ) + yield profile + await profile.delete(force=True) + + +@pytest.fixture(scope="module") +async def sql_tool(python_gen_ai_profile): + sql_tool = await AsyncTool.create_sql_tool( + tool_name=PYSAI_3400_SQL_TOOL_NAME, + description=PYSAI_3400_SQL_TOOL_DESCRIPTION, + profile_name=PYSAI_3400_PROFILE_NAME, + replace=True, + ) + yield sql_tool + await sql_tool.delete(force=True) + + +@pytest.fixture(scope="module") +async def vector_index(vector_index_attributes, python_gen_rag_ai_profile): + vector_index = select_ai.AsyncVectorIndex( + index_name=PYSAI_3400_RAG_VECTOR_INDEX_NAME, + attributes=vector_index_attributes, + description="Test vector index", + profile=python_gen_rag_ai_profile, + ) + await vector_index.create(replace=True) + yield vector_index + await vector_index.delete(force=True) + + +@pytest.fixture(scope="module") +async def rag_tool(vector_index): + sql_tool = await AsyncTool.create_rag_tool( + tool_name=PYSAI_3400_RAG_TOOL_NAME, + description=PYSAI_3400_RAG_TOOL_DESCRIPTION, + profile_name=PYSAI_3400_RAG_PROFILE_NAME, + replace=True, + ) + yield sql_tool + await sql_tool.delete(force=True) + + +@pytest.fixture(scope="module") +async def pl_sql_function(): + create_function = f""" + CREATE OR REPLACE FUNCTION {PYSAI_3400_PL_SQL_FUNC_NAME} (p_birth_date IN DATE) + RETURN NUMBER + IS + v_age NUMBER; + BEGIN + -- Calculate the difference in years + v_age := TRUNC(MONTHS_BETWEEN(SYSDATE, p_birth_date) / 12); + + RETURN v_age; + END CALCULATE_AGE; + """ + async with select_ai.async_cursor() as cr: + await cr.execute(create_function) + yield create_function + async with select_ai.async_cursor() as cr: + await cr.execute(f"DROP FUNCTION {PYSAI_3400_PL_SQL_FUNC_NAME}") + + +@pytest.fixture(scope="module") +async def pl_sql_tool(pl_sql_function): + pl_sql_tool = await AsyncTool.create_pl_sql_tool( + tool_name=PYSAI_3400_PL_SQL_TOOL_NAME, + function=PYSAI_3400_PL_SQL_FUNC_NAME, + description=PYSAI_3400_PL_SQL_TOOL_DESCRIPTION, + ) + yield pl_sql_tool + await pl_sql_tool.delete(force=True) + + +def test_3400(sql_tool): + """test SQL tool creation and parameter validation""" + assert ( + sql_tool.attributes.tool_params.profile_name == PYSAI_3400_PROFILE_NAME + ) + assert sql_tool.tool_name == PYSAI_3400_SQL_TOOL_NAME + assert sql_tool.description == PYSAI_3400_SQL_TOOL_DESCRIPTION + assert isinstance( + sql_tool.attributes.tool_params, select_ai.agent.SQLToolParams + ) + + +def test_3401(rag_tool): + """test RAG tool creation and parameter validation""" + assert ( + rag_tool.attributes.tool_params.profile_name + == PYSAI_3400_RAG_PROFILE_NAME + ) + assert rag_tool.tool_name == PYSAI_3400_RAG_TOOL_NAME + assert rag_tool.description == PYSAI_3400_RAG_TOOL_DESCRIPTION + assert isinstance( + rag_tool.attributes.tool_params, select_ai.agent.RAGToolParams + ) + + +def test_3402(pl_sql_tool): + """test PL SQL tool creation and parameter validation""" + assert pl_sql_tool.tool_name == PYSAI_3400_PL_SQL_TOOL_NAME + assert pl_sql_tool.description == PYSAI_3400_PL_SQL_TOOL_DESCRIPTION + assert pl_sql_tool.attributes.function == PYSAI_3400_PL_SQL_FUNC_NAME + + +async def test_3403(): + """list tools""" + tools = [tool async for tool in AsyncTool.list()] + tool_names = set(tool.tool_name for tool in tools) + assert PYSAI_3400_RAG_TOOL_NAME in tool_names + assert PYSAI_3400_SQL_TOOL_NAME in tool_names + assert PYSAI_3400_PL_SQL_TOOL_NAME in tool_names + + +async def test_3404(): + """list tools matching a REGEX pattern""" + tools = [ + tool async for tool in AsyncTool.list(tool_name_pattern="^PYSAI_3400") + ] + tool_names = set(tool.tool_name for tool in tools) + assert PYSAI_3400_RAG_TOOL_NAME in tool_names + assert PYSAI_3400_SQL_TOOL_NAME in tool_names + assert PYSAI_3400_PL_SQL_TOOL_NAME in tool_names + + +async def test_3405(): + """fetch tool""" + sql_tool = await AsyncTool.fetch(tool_name=PYSAI_3400_SQL_TOOL_NAME) + assert sql_tool.tool_name == PYSAI_3400_SQL_TOOL_NAME + assert sql_tool.description == PYSAI_3400_SQL_TOOL_DESCRIPTION + assert isinstance( + sql_tool.attributes.tool_params, select_ai.agent.SQLToolParams + ) diff --git a/tests/agents/test_3500_async_tasks.py b/tests/agents/test_3500_async_tasks.py new file mode 100644 index 0000000..131720d --- /dev/null +++ b/tests/agents/test_3500_async_tasks.py @@ -0,0 +1,73 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3500 - Module for testing select_ai agent async tasks +""" + +import uuid + +import pytest +import select_ai +from select_ai.agent import AsyncTask, TaskAttributes + +PYSAI_3500_TASK_NAME = f"PYSAI_3500_TASK_{uuid.uuid4().hex.upper()}" +PYSAI_3500_SQL_TASK_DESCRIPTION = "PYSAI_3500_SQL_TASK_DESCRIPTION" + + +@pytest.fixture(scope="module") +def task_attributes(): + return TaskAttributes( + instruction="Help the user with their request about movies. " + "User question: {query}. " + "You can use SQL tool to search the data from database", + tools=["MOVIE_SQL_TOOL"], + enable_human_tool=False, + ) + + +@pytest.fixture(scope="module") +async def task(task_attributes): + task = AsyncTask( + task_name=PYSAI_3500_TASK_NAME, + description=PYSAI_3500_SQL_TASK_DESCRIPTION, + attributes=task_attributes, + ) + await task.create() + yield task + await task.delete(force=True) + + +async def test_3500(task, task_attributes): + """simple task creation""" + assert task.task_name == PYSAI_3500_TASK_NAME + assert task.attributes == task_attributes + assert task.description == PYSAI_3500_SQL_TASK_DESCRIPTION + + +@pytest.mark.parametrize("task_name_pattern", [None, "^PYSAI_3500_"]) +async def test_3501(task_name_pattern): + """task list""" + if task_name_pattern: + tasks = [ + task + async for task in select_ai.agent.AsyncTask.list(task_name_pattern) + ] + else: + tasks = [task async for task in select_ai.agent.AsyncTask.list()] + task_names = set(task.task_name for task in tasks) + task_descriptions = set(task.description for task in tasks) + assert PYSAI_3500_TASK_NAME in task_names + assert PYSAI_3500_SQL_TASK_DESCRIPTION in task_descriptions + + +async def test_3502(task_attributes): + """task fetch""" + task = await select_ai.agent.AsyncTask.fetch(PYSAI_3500_TASK_NAME) + assert task.task_name == PYSAI_3500_TASK_NAME + assert task.attributes == task_attributes + assert task.description == PYSAI_3500_SQL_TASK_DESCRIPTION diff --git a/tests/agents/test_3600_async_agents.py b/tests/agents/test_3600_async_agents.py new file mode 100644 index 0000000..17d48e6 --- /dev/null +++ b/tests/agents/test_3600_async_agents.py @@ -0,0 +1,77 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3600 - Module for testing select_ai async agents +""" + +import uuid + +import pytest +import select_ai +from select_ai.agent import AgentAttributes, AsyncAgent + +PYSAI_3600_AGENT_NAME = f"PYSAI_3600_AGENT_{uuid.uuid4().hex.upper()}" +PYSAI_3600_AGENT_DESCRIPTION = "PYSAI_3600_AGENT_DESCRIPTION" +PYSAI_3600_PROFILE_NAME = f"PYSAI_3600_PROFILE_{uuid.uuid4().hex.upper()}" + + +@pytest.fixture(scope="module") +async def async_python_gen_ai_profile(profile_attributes): + profile = await select_ai.AsyncProfile( + profile_name=PYSAI_3600_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, + ) + yield profile + await profile.delete(force=True) + + +@pytest.fixture(scope="module") +def agent_attributes(): + agent_attributes = AgentAttributes( + profile_name=PYSAI_3600_PROFILE_NAME, + role="You are an AI Movie Analyst. " + "Your can help answer a variety of questions related to movies. ", + enable_human_tool=False, + ) + return agent_attributes + + +@pytest.fixture(scope="module") +async def agent(async_python_gen_ai_profile, agent_attributes): + agent = AsyncAgent( + agent_name=PYSAI_3600_AGENT_NAME, + attributes=agent_attributes, + description=PYSAI_3600_AGENT_DESCRIPTION, + ) + await agent.create(enabled=True, replace=True) + yield agent + await agent.delete(force=True) + + +async def test_3200(agent, agent_attributes): + assert agent.agent_name == PYSAI_3600_AGENT_NAME + assert agent.attributes == agent_attributes + assert agent.description == PYSAI_3600_AGENT_DESCRIPTION + + +@pytest.mark.parametrize("agent_name_pattern", [None, "^PYSAI_3600_AGENT_"]) +async def test_3201(agent_name_pattern): + if agent_name_pattern: + agents = [ + agent + async for agent in select_ai.agent.AsyncAgent.list( + agent_name_pattern + ) + ] + else: + agents = [agent async for agent in select_ai.agent.AsyncAgent.list()] + agent_names = set(agent.agent_name for agent in agents) + agent_descriptions = set(agent.description for agent in agents) + assert PYSAI_3600_AGENT_NAME in agent_names + assert PYSAI_3600_AGENT_DESCRIPTION in agent_descriptions diff --git a/tests/agents/test_3700_async_teams.py b/tests/agents/test_3700_async_teams.py new file mode 100644 index 0000000..cc8bd3e --- /dev/null +++ b/tests/agents/test_3700_async_teams.py @@ -0,0 +1,134 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3700 - Module for testing select_ai agent async teams +""" + +import uuid + +import pytest +import select_ai +from select_ai.agent import ( + AgentAttributes, + AsyncAgent, + AsyncTask, + AsyncTeam, + TaskAttributes, + TeamAttributes, +) + +PYSAI_3700_AGENT_NAME = f"PYSAI_3700_AGENT_{uuid.uuid4().hex.upper()}" +PYSAI_3700_AGENT_DESCRIPTION = "PYSAI_3700_AGENT_DESCRIPTION" +PYSAI_3700_PROFILE_NAME = f"PYSAI_3700_PROFILE_{uuid.uuid4().hex.upper()}" +PYSAI_3700_TASK_NAME = f"PYSAI_3700_{uuid.uuid4().hex.upper()}" +PYSAI_3700_TASK_DESCRIPTION = "PYSAI_3100_SQL_TASK_DESCRIPTION" +PYSAI_3700_TEAM_NAME = f"PYSAI_3700_TEAM_{uuid.uuid4().hex.upper()}" +PYSAI_3700_TEAM_DESCRIPTION = "PYSAI_3700_TEAM_DESCRIPTION" + + +@pytest.fixture(scope="module") +async def python_gen_ai_profile(profile_attributes): + profile = await select_ai.AsyncProfile( + profile_name=PYSAI_3700_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, + ) + yield profile + await profile.delete(force=True) + + +@pytest.fixture(scope="module") +def task_attributes(): + return TaskAttributes( + instruction="Help the user with their request about movies. " + "User question: {query}. ", + enable_human_tool=False, + ) + + +@pytest.fixture(scope="module") +async def task(task_attributes): + task = AsyncTask( + task_name=PYSAI_3700_TASK_NAME, + description=PYSAI_3700_TASK_DESCRIPTION, + attributes=task_attributes, + ) + await task.create() + yield task + await task.delete(force=True) + + +@pytest.fixture(scope="module") +async def agent(python_gen_ai_profile): + agent = AsyncAgent( + agent_name=PYSAI_3700_AGENT_NAME, + description=PYSAI_3700_AGENT_DESCRIPTION, + attributes=AgentAttributes( + profile_name=PYSAI_3700_PROFILE_NAME, + role="You are an AI Movie Analyst. " + "Your can help answer a variety of questions related to movies. ", + enable_human_tool=False, + ), + ) + await agent.create(enabled=True, replace=True) + yield agent + await agent.delete(force=True) + + +@pytest.fixture(scope="module") +def team_attributes(agent, task): + return TeamAttributes( + agents=[{"name": agent.agent_name, "task": task.task_name}], + process="sequential", + ) + + +@pytest.fixture(scope="module") +async def team(team_attributes): + team = AsyncTeam( + team_name=PYSAI_3700_TEAM_NAME, + description=PYSAI_3700_TEAM_DESCRIPTION, + attributes=team_attributes, + ) + await team.create() + yield team + await team.delete(force=True) + + +def test_3300(team, team_attributes): + assert team.team_name == PYSAI_3700_TEAM_NAME + assert team.description == PYSAI_3700_TEAM_DESCRIPTION + assert team.attributes == team_attributes + + +@pytest.mark.parametrize("team_name_pattern", [None, "^PYSAI_3700_TEAM_"]) +async def test_3301(team_name_pattern): + if team_name_pattern: + teams = [team async for team in AsyncTeam.list(team_name_pattern)] + else: + teams = [team async for team in select_ai.agent.AsyncTeam.list()] + team_names = set(team.team_name for team in teams) + team_descriptions = set(team.description for team in teams) + assert PYSAI_3700_TEAM_NAME in team_names + assert PYSAI_3700_TEAM_DESCRIPTION in team_descriptions + + +async def test_3302(team_attributes): + team = await AsyncTeam.fetch(team_name=PYSAI_3700_TEAM_NAME) + assert team.team_name == PYSAI_3700_TEAM_NAME + assert team.description == PYSAI_3700_TEAM_DESCRIPTION + assert team.attributes == team_attributes + + +async def test_3303(team): + response = await team.run( + prompt="In the movie Titanic, was there enough space for Jack ? ", + params={"conversation_id": str(uuid.uuid4())}, + ) + assert isinstance(response, str) + assert len(response) > 0 diff --git a/tests/create_schema.py b/tests/create_schema.py index b61b62f..fbd63df 100644 --- a/tests/create_schema.py +++ b/tests/create_schema.py @@ -31,6 +31,31 @@ ) """ +CREATE_DIRECTOR_DDL = """ + CREATE TABLE Director ( + director_id INT PRIMARY KEY, + name VARCHAR(10) +) +""" + +CREATE_MOVIE_DDL = """ +CREATE TABLE Movie ( + movie_id INT PRIMARY KEY, + title VARCHAR(100), + release_date DATE, + genre VARCHAR(50), + director_id INT, + FOREIGN KEY (director_id) REFERENCES Director(director_id) +) +""" + +CREATE_ACTOR_DDL = """ + CREATE TABLE Actor ( + actor_id INT PRIMARY KEY, + name VARCHAR(100) +) +""" + INSERT_PEOPLE_DML = """ INSERT INTO people (id, name, age, height, hometown) VALUES (: 1, :2, :3, :4, :5) @@ -62,15 +87,22 @@ def test_create_schema(connection, cursor): - for tbl in ("gymnast", "people"): + for tbl in ("gymnast", "people", "director", "movie", "actor"): try: cursor.execute(f"DROP TABLE {tbl} CASCADE CONSTRAINTS") print(f"Dropped table {tbl}") except oracledb.Error: print(f"Table {tbl} does not exist, skipping") - cursor.execute(CREATE_PEOPLE_DDL) - cursor.execute(CREATE_GYMNAST_DDL) + for ddl in ( + CREATE_PEOPLE_DDL, + CREATE_GYMNAST_DDL, + CREATE_DIRECTOR_DDL, + CREATE_MOVIE_DDL, + CREATE_ACTOR_DDL, + ): + cursor.execute(ddl) + cursor.executemany(INSERT_PEOPLE_DML, PEOPLE_DATA) cursor.executemany(INSERT_GYMNAST_DML, GYMNAST_DATA) connection.commit() diff --git a/tests/profiles/test_1200_profile.py b/tests/profiles/test_1200_profile.py index 16c6429..0bd945d 100644 --- a/tests/profiles/test_1200_profile.py +++ b/tests/profiles/test_1200_profile.py @@ -145,7 +145,11 @@ def test_1208(oci_credential): profile = Profile(PYSAI_1200_PROFILE) profile_attrs = ProfileAttributes( credential_name=oci_credential["credential_name"], - provider=select_ai.OCIGenAIProvider(), + provider=select_ai.OCIGenAIProvider( + model="meta.llama-4-maverick-17b-128e-instruct-fp8", + region="us-chicago-1", + oci_apiformat="GENERIC", + ), object_list=[{"owner": "ADMIN", "name": "gymnasts"}], comments=True, ) @@ -155,6 +159,7 @@ def test_1208(oci_credential): ] assert profile.attributes.comments is True fetched_attributes = profile.get_attributes() + print(fetched_attributes.provider) assert fetched_attributes == profile_attrs diff --git a/tests/profiles/test_1300_profile_async.py b/tests/profiles/test_1300_profile_async.py index 9a9cc4a..ada04a0 100644 --- a/tests/profiles/test_1300_profile_async.py +++ b/tests/profiles/test_1300_profile_async.py @@ -154,7 +154,11 @@ async def test_1308(oci_credential): profile = await AsyncProfile(PYSAI_ASYNC_1300_PROFILE) profile_attrs = ProfileAttributes( credential_name=oci_credential["credential_name"], - provider=select_ai.OCIGenAIProvider(), + provider=select_ai.OCIGenAIProvider( + model="meta.llama-4-maverick-17b-128e-instruct-fp8", + region="us-chicago-1", + oci_apiformat="GENERIC", + ), object_list=[{"owner": "ADMIN", "name": "gymnasts"}], comments=True, )