Skip to content

Commit e6016a8

Browse files
Merge pull request #924 from andreswagner/watsonx
feature(dspy): Added support for IBM Watsonx.ai
2 parents b5c561f + cbd17dc commit e6016a8

File tree

4 files changed

+139
-0
lines changed

4 files changed

+139
-0
lines changed

dsp/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222
from .pyserini import *
2323
from .sbert import *
2424
from .sentence_vectorizer import *
25+
from .watsonx import *

dsp/modules/lm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def inspect_history(self, n: int = 1, skip: int = 0):
5555
printed.append((prompt, x["response"].text))
5656
elif provider == "mistral":
5757
printed.append((prompt, x['response'].choices))
58+
elif provider == "ibm":
59+
printed.append((prompt, x))
5860
else:
5961
printed.append((prompt, x["response"]["choices"]))
6062

@@ -84,6 +86,8 @@ def inspect_history(self, n: int = 1, skip: int = 0):
8486
text = choices[0].parts[0].text
8587
elif provider == "mistral":
8688
text = choices[0].message.content
89+
elif provider == "ibm":
90+
text = choices
8791
else:
8892
text = choices[0]["text"]
8993
printing_value += self.print_green(text, end="")

dsp/modules/watsonx.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from typing import Any
2+
3+
from dsp.modules.lm import LM
4+
5+
ibm_watsonx_ai_api_error = False
6+
7+
try:
8+
import ibm_watsonx_ai # noqa: F401
9+
from ibm_watsonx_ai.foundation_models import Model # type: ignore
10+
11+
except ImportError:
12+
ibm_watsonx_ai_api_error = Exception
13+
14+
15+
class Watsonx(LM):
    """Wrapper around IBM Watsonx.ai's foundation-model API.

    The constructor initializes the base class LM to support prompting requests
    to Watsonx models. This requires the following parameters:

    Args:
        model (str): the type of model to use from IBM Watsonx AI.
        credentials (dict): credentials for the Watson Machine Learning instance.
        project_id (str): ID of the Watson Studio project.
        **kwargs: Additional arguments to pass to the API provider. This is
            initialized with default values for relevant text generation
            parameters needed for communicating with the Watsonx API, such as:
            - decoding_method
            - max_new_tokens
            - min_new_tokens
            - stop_sequences
            - repetition_penalty
    """

    def __init__(self, model, credentials, project_id, **kwargs):
        """Initialize the Watsonx client and default generation parameters.

        Parameters
        ----------
        model : str
            Which pre-trained model from Watsonx.ai to use.
            Choices include [
                `mistralai/mixtral-8x7b-instruct-v01`,
                `ibm/granite-13b-instruct-v2`,
                `meta-llama/llama-3-70b-instruct`]
        credentials : dict
            Credentials for the Watson Machine Learning instance.
        project_id : str
            ID of the Watson Studio project.
        **kwargs : dict
            Additional arguments to pass to the API provider; these override
            the defaults below.
        """
        self.model = model
        self.credentials = credentials
        self.project_id = project_id
        self.provider = "ibm"
        self.model_type = "instruct"
        # Default generation parameters; caller-supplied kwargs take precedence.
        self.kwargs = {
            "temperature": 0,
            "decoding_method": "greedy",
            "max_new_tokens": 150,
            "min_new_tokens": 0,
            "stop_sequences": [],
            "repetition_penalty": 1,
            "num_generations": 1,
            **kwargs,
        }

        self.client = Model(
            model_id=self.model,
            params=self.kwargs,
            credentials=self.credentials,
            project_id=self.project_id,
        )

        # Per-call log of (prompt, response, kwargs) dicts, appended by basic_request.
        self.history: list[dict[str, Any]] = []

    def basic_request(self, prompt: str, **kwargs) -> Any:
        """Send `prompt` to the Watsonx generate endpoint and record the exchange.

        Returns the raw provider response (a dict with a "results" list).
        """
        raw_kwargs = kwargs
        # Merge per-call kwargs over the instance defaults.
        kwargs = {**self.kwargs, **kwargs}

        response = self.client.generate(prompt, params={**kwargs})

        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

    def request(self, prompt: str, **kwargs) -> Any:
        """Apply the model-specific prompt template, then retrieve completions.

        Handles the specific prompting for each supported model and the
        retrieval of completions from IBM Watsonx AI.
        """
        if self.model == "mistralai/mixtral-8x7b-instruct-v01":
            # BUGFIX: the instruction block must be closed with "[/INST]".
            # The original used the XML-style "</INST>", which is not the
            # delimiter the Mixtral instruct models were trained on.
            prompt = "<s>[INST]" + prompt + "[/INST]"
        elif self.model == "meta-llama/llama-3-70b-instruct":
            # NOTE(review): this places the whole prompt in the *system* turn
            # and opens an empty user turn with no closing assistant header —
            # confirm this matches the intended Llama-3 chat template.
            prompt = (
                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
                + prompt
                + "<|eot_id|><|start_header_id|>user<|end_header_id|>"
            )

        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[str]:
        """Retrieves completions from Watsonx.

        Args:
            prompt (str): prompt to send to Watsonx
            only_completed (bool, optional): return only completed responses and
                ignore completions truncated by length. Defaults to True
                (and must currently be True).
            return_sorted (bool, optional): sort the completion choices using
                the returned probabilities. Defaults to False
                (and must currently be False).
            **kwargs: Additional arguments to pass to the request.

        Returns:
            list[str]: the generated texts, one per completion choice.

        Raises:
            ValueError: if `only_completed` is False or `return_sorted` is True
                (neither mode is supported yet).
        """
        if only_completed is False:
            raise ValueError("only_completed is True for now")

        if return_sorted:
            raise ValueError("return_sorted is False for now")

        response = self.request(prompt, **kwargs)

        # The provider response nests generations under "results"; each result
        # carries its text in "generated_text".
        return [result["generated_text"] for result in response["results"]]

dspy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,7 @@
4242
AWSAnthropic = dsp.AWSAnthropic
4343
AWSMeta = dsp.AWSMeta
4444

45+
Watsonx = dsp.Watsonx
46+
4547
configure = settings.configure
4648
context = settings.context

0 commit comments

Comments
 (0)