Skip to content

Commit 531a7ab

Browse files
committed
Added async AI agent support
1 parent 08700fd commit 531a7ab

File tree

13 files changed

+1644
-16
lines changed

13 files changed

+1644
-16
lines changed

src/select_ai/agent/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
# -----------------------------------------------------------------------------
77

88

9-
from .core import Agent, AgentAttributes
10-
from .task import Task, TaskAttributes
11-
from .team import Team, TeamAttributes
9+
from .core import Agent, AgentAttributes, AsyncAgent
10+
from .task import AsyncTask, Task, TaskAttributes
11+
from .team import AsyncTeam, Team, TeamAttributes
1212
from .tool import (
13+
AsyncTool,
1314
EmailNotificationToolParams,
1415
HTTPToolParams,
1516
HumanToolParams,

src/select_ai/agent/core.py

Lines changed: 224 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,26 @@
55
# http://oss.oracle.com/licenses/upl.
66
# -----------------------------------------------------------------------------
77

8-
import json
98
from abc import ABC
109
from dataclasses import dataclass
1110
from typing import (
1211
Any,
1312
AsyncGenerator,
1413
Iterator,
15-
List,
16-
Mapping,
1714
Optional,
1815
Union,
1916
)
2017

2118
import oracledb
2219

23-
from select_ai import BaseProfile
2420
from select_ai._abc import SelectAIDataClass
25-
from select_ai._enums import StrEnum
2621
from select_ai.agent.sql import (
2722
GET_USER_AI_AGENT,
2823
GET_USER_AI_AGENT_ATTRIBUTES,
2924
LIST_USER_AI_AGENTS,
3025
)
31-
from select_ai.async_profile import AsyncProfile
3226
from select_ai.db import async_cursor, cursor
3327
from select_ai.errors import AgentNotFoundError
34-
from select_ai.profile import Profile
3528

3629

3730
@dataclass
@@ -292,3 +285,227 @@ def set_attribute(self, attribute_name: str, attribute_value: Any) -> None:
292285
keyword_parameters=parameters,
293286
)
294287
self.attributes = self._get_attributes(agent_name=self.agent_name)
288+
289+
290+
class AsyncAgent(BaseAgent):
291+
"""
292+
select_ai.agent.AsyncAgent class lets you create, delete, enable, disable
293+
and list AI agents asynchronously
294+
295+
:param str agent_name: The name of the AI Agent
296+
:param str description: Optional description of the AI agent
297+
:param select_ai.agent.AgentAttributes attributes: AI agent attributes
298+
299+
"""
300+
301+
@staticmethod
302+
async def _get_attributes(agent_name: str) -> AgentAttributes:
303+
async with async_cursor() as cr:
304+
await cr.execute(
305+
GET_USER_AI_AGENT_ATTRIBUTES, agent_name=agent_name.upper()
306+
)
307+
attributes = await cr.fetchall()
308+
if attributes:
309+
post_processed_attributes = {}
310+
for k, v in attributes:
311+
if isinstance(v, oracledb.AsyncLOB):
312+
post_processed_attributes[k] = await v.read()
313+
else:
314+
post_processed_attributes[k] = v
315+
return AgentAttributes(**post_processed_attributes)
316+
else:
317+
raise AgentNotFoundError(agent_name=agent_name)
318+
319+
@staticmethod
320+
async def _get_description(agent_name: str) -> Union[str, None]:
321+
async with async_cursor() as cr:
322+
await cr.execute(GET_USER_AI_AGENT, agent_name=agent_name.upper())
323+
agent = await cr.fetchone()
324+
if agent:
325+
if agent[1] is not None:
326+
return await agent[1].read()
327+
else:
328+
return None
329+
else:
330+
raise AgentNotFoundError(agent_name)
331+
332+
async def create(
333+
self, enabled: Optional[bool] = True, replace: Optional[bool] = False
334+
):
335+
"""
336+
Register a new AI Agent within the Select AI framework
337+
338+
:param bool enabled: Whether the AI Agent should be enabled.
339+
Default value is True.
340+
341+
:param bool replace: Whether the AI Agent should be replaced.
342+
Default value is False.
343+
344+
"""
345+
if self.agent_name is None:
346+
raise AttributeError("Agent must have a name")
347+
if self.attributes is None:
348+
raise AttributeError("Agent must have attributes")
349+
350+
parameters = {
351+
"agent_name": self.agent_name,
352+
"attributes": self.attributes.json(),
353+
}
354+
if self.description:
355+
parameters["description"] = self.description
356+
357+
if not enabled:
358+
parameters["status"] = "disabled"
359+
360+
async with async_cursor() as cr:
361+
try:
362+
await cr.callproc(
363+
"DBMS_CLOUD_AI_AGENT.CREATE_AGENT",
364+
keyword_parameters=parameters,
365+
)
366+
except oracledb.Error as err:
367+
(err_obj,) = err.args
368+
if err_obj.code in (20050, 20052) and replace:
369+
await self.delete(force=True)
370+
await cr.callproc(
371+
"DBMS_CLOUD_AI_AGENT.CREATE_AGENT",
372+
keyword_parameters=parameters,
373+
)
374+
else:
375+
raise
376+
377+
async def delete(self, force: Optional[bool] = False):
378+
"""
379+
Delete AI Agent from the database
380+
381+
:param bool force: Force the deletion. Default value is False.
382+
383+
"""
384+
async with async_cursor() as cr:
385+
await cr.callproc(
386+
"DBMS_CLOUD_AI_AGENT.DROP_AGENT",
387+
keyword_parameters={
388+
"agent_name": self.agent_name,
389+
"force": force,
390+
},
391+
)
392+
393+
async def disable(self):
394+
"""
395+
Disable AI Agent
396+
"""
397+
async with async_cursor as cr:
398+
await cr.callproc(
399+
"DBMS_CLOUD_AI_AGENT.DISABLE_AGENT",
400+
keyword_parameters={
401+
"agent_name": self.agent_name,
402+
},
403+
)
404+
405+
async def enable(self):
406+
"""
407+
Enable AI Agent
408+
"""
409+
async with async_cursor as cr:
410+
await cr.callproc(
411+
"DBMS_CLOUD_AI_AGENT.ENABLE_AGENT",
412+
keyword_parameters={
413+
"agent_name": self.agent_name,
414+
},
415+
)
416+
417+
@classmethod
418+
async def fetch(cls, agent_name: str) -> "AsyncAgent":
419+
"""
420+
Fetch AI Agent attributes from the Database and build a proxy object in
421+
the Python layer
422+
423+
:param str agent_name: The name of the AI Agent
424+
425+
:return: select_ai.agent.Agent
426+
427+
:raises select_ai.errors.AgentNotFoundError:
428+
If the AI Agent is not found
429+
430+
"""
431+
attributes = await cls._get_attributes(agent_name=agent_name)
432+
description = await cls._get_description(agent_name=agent_name)
433+
return cls(
434+
agent_name=agent_name,
435+
attributes=attributes,
436+
description=description,
437+
)
438+
439+
@classmethod
440+
async def list(
441+
cls, agent_name_pattern: Optional[str] = ".*"
442+
) -> AsyncGenerator["AsyncAgent", None]:
443+
"""
444+
List AI agents matching a pattern
445+
446+
:param str agent_name_pattern: Regular expressions can be used
447+
to specify a pattern. Function REGEXP_LIKE is used to perform the
448+
match. Default value is ".*" i.e. match all agent names.
449+
450+
:return: AsyncGenerator[AsyncAgent]
451+
"""
452+
async with async_cursor() as cr:
453+
await cr.execute(
454+
LIST_USER_AI_AGENTS,
455+
agent_name_pattern=agent_name_pattern,
456+
)
457+
rows = await cr.fetchall()
458+
for row in rows:
459+
agent_name = row[0]
460+
if row[1]:
461+
description = await row[1].read() # Oracle.AsyncLOB
462+
else:
463+
description = None
464+
attributes = await cls._get_attributes(agent_name=agent_name)
465+
yield cls(
466+
agent_name=agent_name,
467+
description=description,
468+
attributes=attributes,
469+
)
470+
471+
async def set_attributes(self, attributes: AgentAttributes) -> None:
472+
"""
473+
Set AI Agent attributes
474+
475+
:param select_ai.agent.AgentAttributes attributes: Multiple attributes
476+
can be specified by passing an AgentAttributes object
477+
"""
478+
parameters = {
479+
"object_name": self.agent_name,
480+
"object_type": "agent",
481+
"attributes": attributes.json(),
482+
}
483+
async with async_cursor() as cr:
484+
await cr.callproc(
485+
"DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTES",
486+
keyword_parameters=parameters,
487+
)
488+
self.attributes = await self._get_attributes(
489+
agent_name=self.agent_name
490+
)
491+
492+
async def set_attribute(
493+
self, attribute_name: str, attribute_value: Any
494+
) -> None:
495+
"""
496+
Set a single AI Agent attribute specified using name and value
497+
"""
498+
parameters = {
499+
"object_name": self.agent_name,
500+
"object_type": "agent",
501+
"attribute_name": attribute_name,
502+
"attribute_value": attribute_value,
503+
}
504+
async with async_cursor() as cr:
505+
await cr.callproc(
506+
"DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTE",
507+
keyword_parameters=parameters,
508+
)
509+
self.attributes = await self._get_attributes(
510+
agent_name=self.agent_name
511+
)

0 commit comments

Comments
 (0)