Skip to content

Commit b1f0dbf

Browse files
authored
Watsonx module
Wrapper around Watsonx.ai API
1 parent 6ef5e37 commit b1f0dbf

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

dsp/modules/watsonx.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from typing import Any

from dsp.modules.lm import LM

# Sentinel recording whether the optional IBM watsonx.ai SDK is importable.
# Stays False when the SDK is present; rebound to the Exception class on
# ImportError. NOTE(review): a bool-named variable ending up holding a class
# is unusual — callers presumably only truth-test it; confirm downstream use.
ibm_watsonx_ai_api_error = False

try:
    import ibm_watsonx_ai  # noqa: F401
    from ibm_watsonx_ai.foundation_models import Model  # type: ignore

except ImportError:
    # SDK not installed: record the failure instead of raising at import time,
    # so the rest of the package can still be imported without the IBM extra.
    ibm_watsonx_ai_api_error = Exception
13+
14+
15+
class Watsonx(LM):
    """Wrapper around Watsonx AI's API.

    The constructor initializes the base class LM to support prompting requests to Watsonx models.
    This requires the following parameters:

    Args:
        model (str): the type of model to use from IBM Watsonx AI.
        credentials (dict): credentials to Watson Machine Learning instance.
        project_id (str): ID of the Watson Studio project.
        **kwargs: Additional arguments to pass to the API provider. This is initialized with default
            values for relevant text generation parameters needed for communicating with the Watsonx
            API, such as:
            - decoding_method
            - max_new_tokens
            - min_new_tokens
            - stop_sequences
            - repetition_penalty
    """

    def __init__(self, model, credentials, project_id, **kwargs):
        """Initialize the Watsonx client and default generation parameters.

        Parameters
        ----------
        model : str
            Which pre-trained model from Watsonx.ai to use?
            Choices are [
                `mistralai/mixtral-8x7b-instruct-v01`,
                `ibm/granite-13b-instruct-v2`,
                `meta-llama/llama-3-70b-instruct`]
        credentials : dict
            Credentials to Watson Machine Learning instance.
        project_id : str
            ID of the Watson Studio project.
        **kwargs : dict
            Additional arguments to pass to the API provider; these override
            the defaults below.
        """
        # NOTE(review): super().__init__() is deliberately not called here —
        # this class sets kwargs/history/provider itself; confirm LM has no
        # other required initialization.
        self.model = model
        self.credentials = credentials
        self.project_id = project_id
        self.provider = "ibm"
        self.model_type = "instruct"
        # Default text-generation parameters; caller-supplied kwargs win.
        self.kwargs = {
            "temperature": 0,
            "decoding_method": "greedy",
            "max_new_tokens": 150,
            "min_new_tokens": 0,
            "stop_sequences": [],
            "repetition_penalty": 1,
            "num_generations": 1,
            **kwargs,
        }

        self.client = Model(
            model_id=self.model,
            params=self.kwargs,
            credentials=self.credentials,
            project_id=self.project_id,
        )

        # Record of every request/response pair, newest last.
        self.history: list[dict[str, Any]] = []

    def basic_request(self, prompt: str, **kwargs) -> Any:
        """Send *prompt* to Watsonx and record the exchange in ``self.history``.

        Args:
            prompt (str): fully formatted prompt to send.
            **kwargs: per-call generation parameters; merged over ``self.kwargs``.

        Returns:
            Any: the raw response dict from ``Model.generate``.
        """
        raw_kwargs = kwargs
        kwargs = {**self.kwargs, **kwargs}

        response = self.client.generate(prompt, params={**kwargs})

        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

    def request(self, prompt: str, **kwargs) -> Any:
        """Wrap *prompt* in the model-specific chat template, then delegate.

        Handles the specific prompting for each supported model and the
        retrieval of completions from IBM Watsonx AI.
        """
        if self.model == "mistralai/mixtral-8x7b-instruct-v01":
            # Mistral instruct format: <s>[INST] ... [/INST]
            # (was "</INST>", which is not a valid Mistral control token).
            prompt = "<s>[INST]" + prompt + "[/INST]"
        elif self.model == "meta-llama/llama-3-70b-instruct":
            # NOTE(review): this places the prompt inside the *system* header
            # and opens an empty user header — looks inverted relative to the
            # documented Llama-3 template; preserved as-is, confirm intent.
            prompt = (
                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
                + prompt
                + "<|eot_id|><|start_header_id|>user<|end_header_id|>"
            )

        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[str]:
        """Retrieves completions from Watsonx.

        Args:
            prompt (str): prompt to send to Watsonx
            only_completed (bool, optional): return only completed responses and ignores completion
                due to length. Defaults to True; only True is currently supported.
            return_sorted (bool, optional): sort the completion choices using the returned
                probabilities. Defaults to False; only False is currently supported.
            **kwargs: Additional arguments to pass to the request.

        Returns:
            list[str]: generated text of each completion choice.

        Raises:
            ValueError: if ``only_completed`` is False or ``return_sorted`` is True.
        """
        if only_completed is False:
            raise ValueError("only_completed=False is not supported for now")

        if return_sorted:
            raise ValueError("return_sorted=True is not supported for now")

        response = self.request(prompt, **kwargs)

        return [result["generated_text"] for result in response["results"]]

0 commit comments

Comments
 (0)