Skip to content

Commit 57aad48

Browse files
committed
bug fixes related to profile attributes and optional attributes
1 parent 0e3a748 commit 57aad48

15 files changed

+486
-29
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# -----------------------------------------------------------------------------
2+
# Copyright (c) 2025, Oracle and/or its affiliates.
3+
#
4+
# Licensed under the Universal Permissive License v 1.0 as shown at
5+
# http://oss.oracle.com/licenses/upl.
6+
# -----------------------------------------------------------------------------
7+
8+
# -----------------------------------------------------------------------------
9+
# async/create_ai_credential.py
10+
#
11+
# Async API to create credential
12+
# -----------------------------------------------------------------------------
13+
14+
import asyncio
15+
import os
16+
17+
import oci
18+
import select_ai
19+
20+
user = os.getenv("SELECT_AI_USER")
21+
password = os.getenv("SELECT_AI_PASSWORD")
22+
dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING")
23+
24+
25+
async def main():
26+
await select_ai.async_connect(user=user, password=password, dsn=dsn)
27+
default_config = oci.config.from_file()
28+
oci.config.validate_config(default_config)
29+
with open(default_config["key_file"]) as fp:
30+
key_contents = fp.read()
31+
credential = {
32+
"credential_name": "my_oci_ai_profile_key",
33+
"user_ocid": default_config["user"],
34+
"tenancy_ocid": default_config["tenancy"],
35+
"private_key": key_contents,
36+
"fingerprint": default_config["fingerprint"],
37+
}
38+
await select_ai.async_create_credential(
39+
credential=credential, replace=True
40+
)
41+
print("Created credential: ", credential["credential_name"])
42+
43+
44+
asyncio.run(main())
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# -----------------------------------------------------------------------------
2+
# Copyright (c) 2025, Oracle and/or its affiliates.
3+
#
4+
# Licensed under the Universal Permissive License v 1.0 as shown at
5+
# http://oss.oracle.com/licenses/upl.
6+
# -----------------------------------------------------------------------------
7+
8+
# -----------------------------------------------------------------------------
9+
# async/create_ai_credential.py
10+
#
11+
# Async API to create credential
12+
# -----------------------------------------------------------------------------
13+
14+
import asyncio
15+
import os
16+
17+
import select_ai
18+
19+
user = os.getenv("SELECT_AI_USER")
20+
password = os.getenv("SELECT_AI_PASSWORD")
21+
dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING")
22+
23+
24+
async def main():
25+
await select_ai.async_connect(user=user, password=password, dsn=dsn)
26+
await select_ai.async_delete_credential(
27+
credential_name="my_oci_ai_profile_key", force=True
28+
)
29+
print("Deleted credential: my_oci_ai_profile_key")
30+
31+
32+
asyncio.run(main())
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# -----------------------------------------------------------------------------
2+
# Copyright (c) 2025, Oracle and/or its affiliates.
3+
#
4+
# Licensed under the Universal Permissive License v 1.0 as shown at
5+
# http://oss.oracle.com/licenses/upl.
6+
# -----------------------------------------------------------------------------
7+
8+
# -----------------------------------------------------------------------------
9+
# async/disable_ai_provider.py
10+
#
11+
# Async API to disable AI provider for database users
12+
# -----------------------------------------------------------------------------
13+
14+
import asyncio
15+
import os
16+
17+
import select_ai
18+
19+
admin_user = os.getenv("SELECT_AI_ADMIN_USER")
20+
password = os.getenv("SELECT_AI_ADMIN_PASSWORD")
21+
dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING")
22+
select_ai_user = os.getenv("SELECT_AI_USER")
23+
24+
25+
async def main():
26+
await select_ai.async_connect(user=admin_user, password=password, dsn=dsn)
27+
await select_ai.async_disable_provider(
28+
users=select_ai_user, provider_endpoint="*.openai.azure.com"
29+
)
30+
print("Disabled AI provider for user: ", select_ai_user)
31+
32+
33+
asyncio.run(main())
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# -----------------------------------------------------------------------------
2+
# Copyright (c) 2025, Oracle and/or its affiliates.
3+
#
4+
# Licensed under the Universal Permissive License v 1.0 as shown at
5+
# http://oss.oracle.com/licenses/upl.
6+
# -----------------------------------------------------------------------------
7+
8+
# -----------------------------------------------------------------------------
9+
# async/enable_ai_provider.py
10+
#
11+
# Async API to enable AI provider for database users
12+
# -----------------------------------------------------------------------------
13+
14+
import asyncio
15+
import os
16+
17+
import select_ai
18+
19+
admin_user = os.getenv("SELECT_AI_ADMIN_USER")
20+
password = os.getenv("SELECT_AI_ADMIN_PASSWORD")
21+
dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING")
22+
select_ai_user = os.getenv("SELECT_AI_USER")
23+
24+
25+
async def main():
26+
await select_ai.async_connect(user=admin_user, password=password, dsn=dsn)
27+
await select_ai.async_enable_provider(
28+
users=select_ai_user, provider_endpoint="*.openai.azure.com"
29+
)
30+
print("Enabled AI provider for user: ", select_ai_user)
31+
32+
33+
asyncio.run(main())

samples/delete_ai_credential.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# -----------------------------------------------------------------------------
2+
# Copyright (c) 2025, Oracle and/or its affiliates.
3+
#
4+
# Licensed under the Universal Permissive License v 1.0 as shown at
5+
# http://oss.oracle.com/licenses/upl.
6+
# -----------------------------------------------------------------------------
7+
8+
# -----------------------------------------------------------------------------
9+
# delete_ai_credential.py
10+
#
11+
# Create a Database credential storing OCI Gen AI's credentials
12+
# -----------------------------------------------------------------------------
13+
import os
14+
15+
import select_ai
16+
17+
user = os.getenv("SELECT_AI_USER")
18+
password = os.getenv("SELECT_AI_PASSWORD")
19+
dsn = os.getenv("SELECT_AI_DB_CONNECT_STRING")
20+
21+
select_ai.connect(user=user, password=password, dsn=dsn)
22+
select_ai.delete_credential(
23+
credential_name="my_oci_ai_profile_key", force=True
24+
)
25+
print("Deleted credential: my_oci_ai_profile_key")

src/select_ai/__init__.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66
# -----------------------------------------------------------------------------
77

88
from .action import Action
9-
from .admin import (
10-
create_credential,
11-
disable_provider,
12-
enable_provider,
13-
)
149
from .async_profile import AsyncProfile
1510
from .base_profile import BaseProfile, ProfileAttributes
1611
from .conversation import (
1712
AsyncConversation,
1813
Conversation,
1914
ConversationAttributes,
2015
)
16+
from .credential import (
17+
async_create_credential,
18+
async_delete_credential,
19+
create_credential,
20+
delete_credential,
21+
)
2122
from .db import (
2223
async_connect,
2324
async_cursor,
@@ -39,6 +40,10 @@
3940
OCIGenAIProvider,
4041
OpenAIProvider,
4142
Provider,
43+
async_disable_provider,
44+
async_enable_provider,
45+
disable_provider,
46+
enable_provider,
4247
)
4348
from .synthetic_data import (
4449
SyntheticDataAttributes,

src/select_ai/async_profile.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ async def _init_profile(self):
5555
:return: None
5656
:raises: oracledb.DatabaseError
5757
"""
58-
if self.profile_name is not None:
58+
if self.profile_name:
5959
profile_exists = False
6060
try:
6161
saved_attributes = await self._get_attributes(
@@ -75,7 +75,7 @@ async def _init_profile(self):
7575
profile_name=self.profile_name
7676
)
7777
except ProfileNotFoundError:
78-
if self.attributes is None:
78+
if self.attributes is None and self.description is None:
7979
raise
8080
else:
8181
if self.attributes is None:
@@ -91,10 +91,13 @@ async def _init_profile(self):
9191
await self.create(
9292
replace=self.replace, description=self.description
9393
)
94+
else: # profile name is None:
95+
if self.attributes is not None or self.description is not None:
96+
raise ValueError("'profile_name' cannot be empty or None")
9497
return self
9598

9699
@staticmethod
97-
async def _get_profile_description(profile_name) -> str:
100+
async def _get_profile_description(profile_name) -> Union[str, None]:
98101
"""Get description of profile from USER_CLOUD_AI_PROFILES
99102
100103
:param str profile_name: Name of profile
@@ -110,7 +113,10 @@ async def _get_profile_description(profile_name) -> str:
110113
)
111114
profile = await cr.fetchone()
112115
if profile:
113-
return await profile[1].read()
116+
if profile[1] is not None:
117+
return await profile[1].read()
118+
else:
119+
return None
114120
else:
115121
raise ProfileNotFoundError(profile_name)
116122

@@ -186,6 +192,12 @@ async def set_attributes(self, attributes: ProfileAttributes):
186192
attributes
187193
:return: None
188194
"""
195+
if not isinstance(attributes, ProfileAttributes):
196+
raise TypeError(
197+
"'attributes' must be an object of type "
198+
"select_ai.ProfileAttributes"
199+
)
200+
189201
self.attributes = attributes
190202
parameters = {
191203
"profile_name": self.profile_name,

src/select_ai/base_profile.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,9 @@ class ProfileAttributes(SelectAIDataClass):
7373
vector_index_name: Optional[str] = None
7474

7575
def __post_init__(self):
76-
if not isinstance(self.provider, Provider):
76+
if self.provider and not isinstance(self.provider, Provider):
7777
raise ValueError(
78-
f"The arg `provider` must be an object of "
79-
f"type select_ai.Provider"
78+
f"'provider' must be an object of " f"type select_ai.Provider"
8079
)
8180

8281
def json(self, exclude_null=True):
@@ -166,6 +165,11 @@ def __init__(
166165
):
167166
"""Initialize a base profile"""
168167
self.profile_name = profile_name
168+
if attributes and not isinstance(attributes, ProfileAttributes):
169+
raise TypeError(
170+
"'attributes' must be an object of type "
171+
"select_ai.ProfileAttributes"
172+
)
169173
self.attributes = attributes
170174
self.description = description
171175
self.merge = merge

src/select_ai/conversation.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ def get_attributes(self) -> ConversationAttributes:
129129
attributes = cr.fetchone()
130130
if attributes:
131131
conversation_title = attributes[0]
132-
description = attributes[1].read() # Oracle.LOB
132+
if attributes[1]:
133+
description = attributes[1].read() # Oracle.LOB
134+
else:
135+
description = None
133136
retention_days = attributes[2]
134137
return ConversationAttributes(
135138
title=conversation_title,
@@ -154,7 +157,10 @@ def list(cls) -> Iterator["Conversation"]:
154157
for row in cr.fetchall():
155158
conversation_id = row[0]
156159
conversation_title = row[1]
157-
description = row[2].read() # Oracle.LOB
160+
if row[2]:
161+
description = row[2].read() # Oracle.LOB
162+
else:
163+
description = None
158164
retention_days = row[3]
159165
attributes = ConversationAttributes(
160166
title=conversation_title,
@@ -224,7 +230,10 @@ async def get_attributes(self) -> ConversationAttributes:
224230
attributes = await cr.fetchone()
225231
if attributes:
226232
conversation_title = attributes[0]
227-
description = await attributes[1].read() # Oracle.AsyncLOB
233+
if attributes[1]:
234+
description = await attributes[1].read() # Oracle.AsyncLOB
235+
else:
236+
description = None
228237
retention_days = attributes[2]
229238
return ConversationAttributes(
230239
title=conversation_title,
@@ -250,7 +259,10 @@ async def list(cls) -> AsyncGenerator["AsyncConversation", None]:
250259
for row in rows:
251260
conversation_id = row[0]
252261
conversation_title = row[1]
253-
description = await row[2].read() # Oracle.AsyncLOB
262+
if row[2]:
263+
description = await row[2].read() # Oracle.AsyncLOB
264+
else:
265+
description = None
254266
retention_days = row[3]
255267
attributes = ConversationAttributes(
256268
title=conversation_title,

0 commit comments

Comments
 (0)