|
5 | 5 | # http://oss.oracle.com/licenses/upl. |
6 | 6 | # ----------------------------------------------------------------------------- |
7 | 7 |
|
8 | | -import json |
9 | 8 | from abc import ABC |
10 | 9 | from dataclasses import dataclass |
11 | 10 | from typing import ( |
12 | 11 | Any, |
13 | 12 | AsyncGenerator, |
14 | 13 | Iterator, |
15 | | - List, |
16 | | - Mapping, |
17 | 14 | Optional, |
18 | 15 | Union, |
19 | 16 | ) |
20 | 17 |
|
21 | 18 | import oracledb |
22 | 19 |
|
23 | | -from select_ai import BaseProfile |
24 | 20 | from select_ai._abc import SelectAIDataClass |
25 | | -from select_ai._enums import StrEnum |
26 | 21 | from select_ai.agent.sql import ( |
27 | 22 | GET_USER_AI_AGENT, |
28 | 23 | GET_USER_AI_AGENT_ATTRIBUTES, |
29 | 24 | LIST_USER_AI_AGENTS, |
30 | 25 | ) |
31 | | -from select_ai.async_profile import AsyncProfile |
32 | 26 | from select_ai.db import async_cursor, cursor |
33 | 27 | from select_ai.errors import AgentNotFoundError |
34 | | -from select_ai.profile import Profile |
35 | 28 |
|
36 | 29 |
|
37 | 30 | @dataclass |
@@ -292,3 +285,227 @@ def set_attribute(self, attribute_name: str, attribute_value: Any) -> None: |
292 | 285 | keyword_parameters=parameters, |
293 | 286 | ) |
294 | 287 | self.attributes = self._get_attributes(agent_name=self.agent_name) |
| 288 | + |
| 289 | + |
| 290 | +class AsyncAgent(BaseAgent): |
| 291 | + """ |
| 292 | + select_ai.agent.AsyncAgent class lets you create, delete, enable, disable |
| 293 | + and list AI agents asynchronously |
| 294 | +
|
| 295 | + :param str agent_name: The name of the AI Agent |
| 296 | + :param str description: Optional description of the AI agent |
| 297 | + :param select_ai.agent.AgentAttributes attributes: AI agent attributes |
| 298 | +
|
| 299 | + """ |
| 300 | + |
| 301 | + @staticmethod |
| 302 | + async def _get_attributes(agent_name: str) -> AgentAttributes: |
| 303 | + async with async_cursor() as cr: |
| 304 | + await cr.execute( |
| 305 | + GET_USER_AI_AGENT_ATTRIBUTES, agent_name=agent_name.upper() |
| 306 | + ) |
| 307 | + attributes = await cr.fetchall() |
| 308 | + if attributes: |
| 309 | + post_processed_attributes = {} |
| 310 | + for k, v in attributes: |
| 311 | + if isinstance(v, oracledb.AsyncLOB): |
| 312 | + post_processed_attributes[k] = await v.read() |
| 313 | + else: |
| 314 | + post_processed_attributes[k] = v |
| 315 | + return AgentAttributes(**post_processed_attributes) |
| 316 | + else: |
| 317 | + raise AgentNotFoundError(agent_name=agent_name) |
| 318 | + |
| 319 | + @staticmethod |
| 320 | + async def _get_description(agent_name: str) -> Union[str, None]: |
| 321 | + async with async_cursor() as cr: |
| 322 | + await cr.execute(GET_USER_AI_AGENT, agent_name=agent_name.upper()) |
| 323 | + agent = await cr.fetchone() |
| 324 | + if agent: |
| 325 | + if agent[1] is not None: |
| 326 | + return await agent[1].read() |
| 327 | + else: |
| 328 | + return None |
| 329 | + else: |
| 330 | + raise AgentNotFoundError(agent_name) |
| 331 | + |
| 332 | + async def create( |
| 333 | + self, enabled: Optional[bool] = True, replace: Optional[bool] = False |
| 334 | + ): |
| 335 | + """ |
| 336 | + Register a new AI Agent within the Select AI framework |
| 337 | +
|
| 338 | + :param bool enabled: Whether the AI Agent should be enabled. |
| 339 | + Default value is True. |
| 340 | +
|
| 341 | + :param bool replace: Whether the AI Agent should be replaced. |
| 342 | + Default value is False. |
| 343 | +
|
| 344 | + """ |
| 345 | + if self.agent_name is None: |
| 346 | + raise AttributeError("Agent must have a name") |
| 347 | + if self.attributes is None: |
| 348 | + raise AttributeError("Agent must have attributes") |
| 349 | + |
| 350 | + parameters = { |
| 351 | + "agent_name": self.agent_name, |
| 352 | + "attributes": self.attributes.json(), |
| 353 | + } |
| 354 | + if self.description: |
| 355 | + parameters["description"] = self.description |
| 356 | + |
| 357 | + if not enabled: |
| 358 | + parameters["status"] = "disabled" |
| 359 | + |
| 360 | + async with async_cursor() as cr: |
| 361 | + try: |
| 362 | + await cr.callproc( |
| 363 | + "DBMS_CLOUD_AI_AGENT.CREATE_AGENT", |
| 364 | + keyword_parameters=parameters, |
| 365 | + ) |
| 366 | + except oracledb.Error as err: |
| 367 | + (err_obj,) = err.args |
| 368 | + if err_obj.code in (20050, 20052) and replace: |
| 369 | + await self.delete(force=True) |
| 370 | + await cr.callproc( |
| 371 | + "DBMS_CLOUD_AI_AGENT.CREATE_AGENT", |
| 372 | + keyword_parameters=parameters, |
| 373 | + ) |
| 374 | + else: |
| 375 | + raise |
| 376 | + |
| 377 | + async def delete(self, force: Optional[bool] = False): |
| 378 | + """ |
| 379 | + Delete AI Agent from the database |
| 380 | +
|
| 381 | + :param bool force: Force the deletion. Default value is False. |
| 382 | +
|
| 383 | + """ |
| 384 | + async with async_cursor() as cr: |
| 385 | + await cr.callproc( |
| 386 | + "DBMS_CLOUD_AI_AGENT.DROP_AGENT", |
| 387 | + keyword_parameters={ |
| 388 | + "agent_name": self.agent_name, |
| 389 | + "force": force, |
| 390 | + }, |
| 391 | + ) |
| 392 | + |
| 393 | + async def disable(self): |
| 394 | + """ |
| 395 | + Disable AI Agent |
| 396 | + """ |
| 397 | + async with async_cursor as cr: |
| 398 | + await cr.callproc( |
| 399 | + "DBMS_CLOUD_AI_AGENT.DISABLE_AGENT", |
| 400 | + keyword_parameters={ |
| 401 | + "agent_name": self.agent_name, |
| 402 | + }, |
| 403 | + ) |
| 404 | + |
| 405 | + async def enable(self): |
| 406 | + """ |
| 407 | + Enable AI Agent |
| 408 | + """ |
| 409 | + async with async_cursor as cr: |
| 410 | + await cr.callproc( |
| 411 | + "DBMS_CLOUD_AI_AGENT.ENABLE_AGENT", |
| 412 | + keyword_parameters={ |
| 413 | + "agent_name": self.agent_name, |
| 414 | + }, |
| 415 | + ) |
| 416 | + |
| 417 | + @classmethod |
| 418 | + async def fetch(cls, agent_name: str) -> "AsyncAgent": |
| 419 | + """ |
| 420 | + Fetch AI Agent attributes from the Database and build a proxy object in |
| 421 | + the Python layer |
| 422 | +
|
| 423 | + :param str agent_name: The name of the AI Agent |
| 424 | +
|
| 425 | + :return: select_ai.agent.Agent |
| 426 | +
|
| 427 | + :raises select_ai.errors.AgentNotFoundError: |
| 428 | + If the AI Agent is not found |
| 429 | +
|
| 430 | + """ |
| 431 | + attributes = await cls._get_attributes(agent_name=agent_name) |
| 432 | + description = await cls._get_description(agent_name=agent_name) |
| 433 | + return cls( |
| 434 | + agent_name=agent_name, |
| 435 | + attributes=attributes, |
| 436 | + description=description, |
| 437 | + ) |
| 438 | + |
| 439 | + @classmethod |
| 440 | + async def list( |
| 441 | + cls, agent_name_pattern: Optional[str] = ".*" |
| 442 | + ) -> AsyncGenerator["AsyncAgent", None]: |
| 443 | + """ |
| 444 | + List AI agents matching a pattern |
| 445 | +
|
| 446 | + :param str agent_name_pattern: Regular expressions can be used |
| 447 | + to specify a pattern. Function REGEXP_LIKE is used to perform the |
| 448 | + match. Default value is ".*" i.e. match all agent names. |
| 449 | +
|
| 450 | + :return: AsyncGenerator[AsyncAgent] |
| 451 | + """ |
| 452 | + async with async_cursor() as cr: |
| 453 | + await cr.execute( |
| 454 | + LIST_USER_AI_AGENTS, |
| 455 | + agent_name_pattern=agent_name_pattern, |
| 456 | + ) |
| 457 | + rows = await cr.fetchall() |
| 458 | + for row in rows: |
| 459 | + agent_name = row[0] |
| 460 | + if row[1]: |
| 461 | + description = await row[1].read() # Oracle.AsyncLOB |
| 462 | + else: |
| 463 | + description = None |
| 464 | + attributes = await cls._get_attributes(agent_name=agent_name) |
| 465 | + yield cls( |
| 466 | + agent_name=agent_name, |
| 467 | + description=description, |
| 468 | + attributes=attributes, |
| 469 | + ) |
| 470 | + |
| 471 | + async def set_attributes(self, attributes: AgentAttributes) -> None: |
| 472 | + """ |
| 473 | + Set AI Agent attributes |
| 474 | +
|
| 475 | + :param select_ai.agent.AgentAttributes attributes: Multiple attributes |
| 476 | + can be specified by passing an AgentAttributes object |
| 477 | + """ |
| 478 | + parameters = { |
| 479 | + "object_name": self.agent_name, |
| 480 | + "object_type": "agent", |
| 481 | + "attributes": attributes.json(), |
| 482 | + } |
| 483 | + async with async_cursor() as cr: |
| 484 | + await cr.callproc( |
| 485 | + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTES", |
| 486 | + keyword_parameters=parameters, |
| 487 | + ) |
| 488 | + self.attributes = await self._get_attributes( |
| 489 | + agent_name=self.agent_name |
| 490 | + ) |
| 491 | + |
| 492 | + async def set_attribute( |
| 493 | + self, attribute_name: str, attribute_value: Any |
| 494 | + ) -> None: |
| 495 | + """ |
| 496 | + Set a single AI Agent attribute specified using name and value |
| 497 | + """ |
| 498 | + parameters = { |
| 499 | + "object_name": self.agent_name, |
| 500 | + "object_type": "agent", |
| 501 | + "attribute_name": attribute_name, |
| 502 | + "attribute_value": attribute_value, |
| 503 | + } |
| 504 | + async with async_cursor() as cr: |
| 505 | + await cr.callproc( |
| 506 | + "DBMS_CLOUD_AI_AGENT.SET_ATTRIBUTE", |
| 507 | + keyword_parameters=parameters, |
| 508 | + ) |
| 509 | + self.attributes = await self._get_attributes( |
| 510 | + agent_name=self.agent_name |
| 511 | + ) |
0 commit comments