Skip to content

Commit 9055445

Browse files
feat: Introduce SingleStoreChatFactory and SingleStoreEmbeddingsFactory factories that instantiate various flavors of chat and embedding instances. (#94)
* feat: Introduce SingleStoreChat wrapper that uses interchangeably OpenAI or AmazonBedrockConverse protocol. * Provide option for getting 'X-S2-OBO' token for every request. * Provide headers that indicate passthrough Amazon bedrock requests. * Hardcode dummy credentials and region info for ChatBedrockConverse client. * Rename 'region' parameter to 'region_name'. * Expose 'streaming' parameter setting its opposite value to 'disable_streaming' ChatBedrockConverse client. * Set the default value for 'streaming' parameter to True. * Remove the cache option. * Remove unsupported kwargs from Bedrock calls. * Replace composition wrapper with a factory method. * Pass bedrock runtime client as client parameter. * Pass also the 'X-BEDROCK-CONVERSE' headers that indicate that the requests should be handled as passthrough from UMG. * Remove some amazon specific headers, along with validation, remove X-BEDROCK headers as well. * Remove commented out code; set max retries to 1 for Amazon Bedrock models. * Use 'Union' return type to satisfy pre-commit checks for python version 3.9. * Expose also the hostingPlatform for InferenceAPIInfo. * Do not use model prefix, rely on hosting platform. * Introduce SingleStoreEmbeddingsFactory; small fixes. * Fix openai langchain library. * Remove any comments that expose internal implementation details. * Minor fixes.
1 parent a9d6ae8 commit 9055445

File tree

4 files changed

+233
-9
lines changed

4 files changed

+233
-9
lines changed

singlestoredb/ai/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
1+
from .chat import SingleStoreChat # noqa: F401
2+
from .chat import SingleStoreChatFactory # noqa: F401
13
from .chat import SingleStoreChatOpenAI # noqa: F401
24
from .embeddings import SingleStoreEmbeddings # noqa: F401
5+
from .embeddings import SingleStoreEmbeddingsFactory # noqa: F401

singlestoredb/ai/chat.py

Lines changed: 126 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import os
22
from typing import Any
3+
from typing import Callable
4+
from typing import Optional
5+
from typing import Union
36

4-
from singlestoredb.fusion.handlers.utils import get_workspace_manager
7+
import httpx
8+
9+
from singlestoredb import manage_workspaces
510

611
try:
712
from langchain_openai import ChatOpenAI
@@ -11,30 +16,144 @@
1116
'Please install it with `pip install langchain_openai`.',
1217
)
1318

19+
try:
20+
from langchain_aws import ChatBedrockConverse
21+
except ImportError:
22+
raise ImportError(
23+
'Could not import langchain-aws python package. '
24+
'Please install it with `pip install langchain-aws`.',
25+
)
26+
27+
import boto3
28+
from botocore import UNSIGNED
29+
from botocore.config import Config
30+
1431

1532
class SingleStoreChatOpenAI(ChatOpenAI):
    """``ChatOpenAI`` wired to a SingleStore-hosted inference API endpoint."""

    def __init__(self, model_name: str, api_key: Optional[str] = None, **kwargs: Any):
        """Resolve the connection URL for *model_name* and configure the client.

        ``api_key`` overrides the ``SINGLESTOREDB_USER_TOKEN`` environment
        variable when provided; remaining ``kwargs`` pass through to
        ``ChatOpenAI``.
        """
        apis = manage_workspaces().organizations.current.inference_apis
        endpoint_info = apis.get(model_name=model_name)
        if api_key is None:
            api_key = os.environ.get('SINGLESTOREDB_USER_TOKEN')
        super().__init__(
            base_url=endpoint_info.connection_url,
            api_key=api_key,
            model=model_name,
            **kwargs,
        )
2749

2850

2951
class SingleStoreChat(ChatOpenAI):
    """Chat model backed by a SingleStore-hosted inference API (OpenAI protocol)."""

    def __init__(self, model_name: str, api_key: Optional[str] = None, **kwargs: Any):
        """Look up *model_name* in the org's inference APIs and connect to it.

        The bearer token is the explicit ``api_key`` argument if given,
        otherwise the ``SINGLESTOREDB_USER_TOKEN`` environment variable.
        """
        manager = manage_workspaces().organizations.current.inference_apis
        api_info = manager.get(model_name=model_name)
        token = api_key if api_key is not None else os.environ.get(
            'SINGLESTOREDB_USER_TOKEN',
        )
        super().__init__(
            base_url=api_info.connection_url,
            api_key=token,
            model=model_name,
            **kwargs,
        )
68+
69+
70+
def SingleStoreChatFactory(
    model_name: str,
    api_key: Optional[str] = None,
    streaming: bool = True,
    http_client: Optional[httpx.Client] = None,
    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
    **kwargs: Any,
) -> Union[ChatOpenAI, ChatBedrockConverse]:
    """Return a chat model instance (ChatOpenAI or ChatBedrockConverse).

    Looks up *model_name* in the organization's inference APIs and, based on
    the reported hosting platform, returns either a ``ChatBedrockConverse``
    (Amazon) or a ``ChatOpenAI`` client pointed at the SingleStore endpoint.

    Args:
        model_name: Name of the hosted model to connect to.
        api_key: Explicit bearer token; falls back to the
            ``SINGLESTOREDB_USER_TOKEN`` environment variable when omitted.
        streaming: Whether streamed responses are enabled (mapped to
            ``disable_streaming`` for Bedrock models).
        http_client: Optional httpx client. For the OpenAI path it is passed
            through directly; for the Bedrock path only its timeout settings
            are mirrored onto the botocore config.
        obo_token_getter: Optional callable returning an on-behalf-of token,
            injected as the ``X-S2-OBO`` header on every Bedrock request.
        **kwargs: Extra keyword arguments forwarded to the model class.

    Returns:
        A configured ``ChatOpenAI`` or ``ChatBedrockConverse`` instance.
    """
    inference_api_manager = (
        manage_workspaces().organizations.current.inference_apis
    )
    info = inference_api_manager.get(model_name=model_name)
    token = (
        api_key
        if api_key is not None
        else os.environ.get('SINGLESTOREDB_USER_TOKEN')
    )

    if info.hosting_platform == 'Amazon':
        # Requests are proxied through SingleStore, so they are sent
        # unsigned; authentication happens via the injected bearer/OBO
        # headers below.
        cfg_kwargs = {
            'signature_version': UNSIGNED,
            'retries': {'max_attempts': 1, 'mode': 'standard'},
        }
        if http_client is not None and http_client.timeout is not None:
            # Fix: ``httpx.Client.timeout`` is an ``httpx.Timeout`` object,
            # not a number, while botocore expects plain seconds. Unpack
            # the individual read/connect values instead of passing the
            # object through (which would fail when the request is sent).
            timeout = http_client.timeout
            if timeout.read is not None:
                cfg_kwargs['read_timeout'] = timeout.read
            if timeout.connect is not None:
                cfg_kwargs['connect_timeout'] = timeout.connect

        cfg = Config(**cfg_kwargs)
        # Dummy credentials and region: the endpoint is the SingleStore
        # proxy, which routes on the injected headers rather than SigV4.
        client = boto3.client(
            'bedrock-runtime',
            endpoint_url=info.connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            config=cfg,
        )

        def _inject_headers(request: Any, **_ignored: Any) -> None:
            """Inject dynamic auth/OBO headers prior to Bedrock sending."""
            if obo_token_getter is not None:
                obo_val = obo_token_getter()
                if obo_val:
                    request.headers['X-S2-OBO'] = obo_val
            if token:
                request.headers['Authorization'] = f'Bearer {token}'
            # Strip SigV4 artifacts so the proxy does not try to verify them.
            request.headers.pop('X-Amz-Date', None)
            request.headers.pop('X-Amz-Security-Token', None)

        # Register the header hook first on every operation the Converse
        # API surface may issue.
        emitter = client._endpoint._event_emitter
        for event_name in (
            'before-send.bedrock-runtime.Converse',
            'before-send.bedrock-runtime.ConverseStream',
            'before-send.bedrock-runtime.InvokeModel',
            'before-send.bedrock-runtime.InvokeModelWithResponseStream',
        ):
            emitter.register_first(event_name, _inject_headers)

        return ChatBedrockConverse(
            model_id=model_name,
            endpoint_url=info.connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            disable_streaming=not streaming,
            client=client,
            **kwargs,
        )

    # OpenAI / Azure OpenAI path
    openai_kwargs = dict(
        base_url=info.connection_url,
        api_key=token,
        model=model_name,
        streaming=streaming,
    )
    if http_client is not None:
        openai_kwargs['http_client'] = http_client
    return ChatOpenAI(
        **openai_kwargs,
        **kwargs,
    )

singlestoredb/ai/embeddings.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import os
22
from typing import Any
3+
from typing import Callable
4+
from typing import Optional
5+
from typing import Union
36

4-
from singlestoredb.fusion.handlers.utils import get_workspace_manager
7+
import httpx
8+
9+
from singlestoredb import manage_workspaces
510

611
try:
712
from langchain_openai import OpenAIEmbeddings
@@ -11,12 +16,24 @@
1116
'Please install it with `pip install langchain_openai`.',
1217
)
1318

19+
try:
20+
from langchain_aws import BedrockEmbeddings
21+
except ImportError:
22+
raise ImportError(
23+
'Could not import langchain-aws python package. '
24+
'Please install it with `pip install langchain-aws`.',
25+
)
26+
27+
import boto3
28+
from botocore import UNSIGNED
29+
from botocore.config import Config
30+
1431

1532
class SingleStoreEmbeddings(OpenAIEmbeddings):
1633

1734
def __init__(self, model_name: str, **kwargs: Any):
1835
inference_api_manger = (
19-
get_workspace_manager().organizations.current.inference_apis
36+
manage_workspaces().organizations.current.inference_apis
2037
)
2138
info = inference_api_manger.get(model_name=model_name)
2239
super().__init__(
@@ -25,3 +42,84 @@ def __init__(self, model_name: str, **kwargs: Any):
2542
model=model_name,
2643
**kwargs,
2744
)
45+
46+
47+
def SingleStoreEmbeddingsFactory(
    model_name: str,
    api_key: Optional[str] = None,
    http_client: Optional[httpx.Client] = None,
    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
    **kwargs: Any,
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
    """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).

    Looks up *model_name* in the organization's inference APIs and, based on
    the reported hosting platform, returns either a ``BedrockEmbeddings``
    (Amazon) or an ``OpenAIEmbeddings`` client pointed at the SingleStore
    endpoint.

    Args:
        model_name: Name of the hosted embeddings model to connect to.
        api_key: Explicit bearer token; falls back to the
            ``SINGLESTOREDB_USER_TOKEN`` environment variable when omitted.
        http_client: Optional httpx client. For the OpenAI path it is passed
            through directly; for the Bedrock path only its timeout settings
            are mirrored onto the botocore config.
        obo_token_getter: Optional callable returning an on-behalf-of token,
            injected as the ``X-S2-OBO`` header on every Bedrock request.
        **kwargs: Extra keyword arguments forwarded to the model class.

    Returns:
        A configured ``OpenAIEmbeddings`` or ``BedrockEmbeddings`` instance.
    """
    inference_api_manager = (
        manage_workspaces().organizations.current.inference_apis
    )
    info = inference_api_manager.get(model_name=model_name)
    token = (
        api_key
        if api_key is not None
        else os.environ.get('SINGLESTOREDB_USER_TOKEN')
    )

    if info.hosting_platform == 'Amazon':
        # Requests are proxied through SingleStore, so they are sent
        # unsigned; authentication happens via the injected bearer/OBO
        # headers below.
        cfg_kwargs = {
            'signature_version': UNSIGNED,
            'retries': {'max_attempts': 1, 'mode': 'standard'},
        }
        if http_client is not None and http_client.timeout is not None:
            # Fix: ``httpx.Client.timeout`` is an ``httpx.Timeout`` object,
            # not a number, while botocore expects plain seconds. Unpack
            # the individual read/connect values instead of passing the
            # object through (which would fail when the request is sent).
            timeout = http_client.timeout
            if timeout.read is not None:
                cfg_kwargs['read_timeout'] = timeout.read
            if timeout.connect is not None:
                cfg_kwargs['connect_timeout'] = timeout.connect

        cfg = Config(**cfg_kwargs)
        # Dummy credentials and region: the endpoint is the SingleStore
        # proxy, which routes on the injected headers rather than SigV4.
        client = boto3.client(
            'bedrock-runtime',
            endpoint_url=info.connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            config=cfg,
        )

        def _inject_headers(request: Any, **_ignored: Any) -> None:
            """Inject dynamic auth/OBO headers prior to Bedrock sending."""
            if obo_token_getter is not None:
                obo_val = obo_token_getter()
                if obo_val:
                    request.headers['X-S2-OBO'] = obo_val
            if token:
                request.headers['Authorization'] = f'Bearer {token}'
            # Strip SigV4 artifacts so the proxy does not try to verify them.
            request.headers.pop('X-Amz-Date', None)
            request.headers.pop('X-Amz-Security-Token', None)

        # Register the header hook first on both invoke operations used by
        # the embeddings API surface.
        emitter = client._endpoint._event_emitter
        for event_name in (
            'before-send.bedrock-runtime.InvokeModel',
            'before-send.bedrock-runtime.InvokeModelWithResponseStream',
        ):
            emitter.register_first(event_name, _inject_headers)

        return BedrockEmbeddings(
            model_id=model_name,
            endpoint_url=info.connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            client=client,
            **kwargs,
        )

    # OpenAI / Azure OpenAI path
    openai_kwargs = dict(
        base_url=info.connection_url,
        api_key=token,
        model=model_name,
    )
    if http_client is not None:
        openai_kwargs['http_client'] = http_client
    return OpenAIEmbeddings(
        **openai_kwargs,
        **kwargs,
    )

singlestoredb/management/inference_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class InferenceAPIInfo(object):
2323
name: str
2424
connection_url: str
2525
project_id: str
26+
hosting_platform: str
2627

2728
def __init__(
2829
self,
@@ -31,12 +32,14 @@ def __init__(
3132
name: str,
3233
connection_url: str,
3334
project_id: str,
35+
hosting_platform: str,
3436
):
3537
self.service_id = service_id
3638
self.connection_url = connection_url
3739
self.model_name = model_name
3840
self.name = name
3941
self.project_id = project_id
42+
self.hosting_platform = hosting_platform
4043

4144
@classmethod
4245
def from_dict(
@@ -62,6 +65,7 @@ def from_dict(
6265
model_name=obj['modelName'],
6366
name=obj['name'],
6467
connection_url=obj['connectionURL'],
68+
hosting_platform=obj['hostingPlatform'],
6569
)
6670
return out
6771

0 commit comments

Comments
 (0)