Skip to content

Commit 4a8a89b

Browse files
committed
Initial commit
1 parent 656b7de commit 4a8a89b

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed

dsp/modules/ibm_watsonx.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from typing import Any

from dsp.modules.lm import LM

# Optional dependency guard: `ibm-watsonx-ai` may not be installed. If the
# import fails, `Model` stays undefined and instantiating `Watsonx` will
# raise a NameError at construction time rather than at module import time.
try:
    from ibm_watsonx_ai.foundation_models import Model
except ImportError:
    # NOTE(review): this binds the exception *class* (not a caught instance)
    # to a module-level name — presumably mirroring the ERRORS convention of
    # sibling dsp.modules providers; confirm against those modules.
    ERRORS = ImportError
class Watsonx(LM):
    """Wrapper around the IBM watsonx.ai foundation-models API.

    Supports prompting requests to Watsonx-hosted models through the
    `ibm_watsonx_ai` SDK. This requires the following parameters:

    Args:
        model (str): the type of model to use from IBM Watsonx AI.
        credentials ([dict]): credentials to Watson Machine Learning instance.
        project_id (str): ID of the Watson Studio project.
        **kwargs: Additional arguments to pass to the API provider. This is initialized with default values for relevant
            text generation parameters needed for communicating with Watsonx API, such as:
            - decoding_method
            - max_new_tokens
            - min_new_tokens
            - stop_sequences
            - repetition_penalty
    """

    def __init__(self, model, credentials, project_id, **kwargs):
        """Parameters

        model : str
            Which pre-trained model from Watsonx.ai to use?
            Choices are [
                `mistralai/mixtral-8x7b-instruct-v01`,
                `ibm/granite-13b-instruct-v2`,
                `meta-llama/llama-3-70b-instruct`]
        credentials : [dict]
            Credentials to Watson Machine Learning instance.
        project_id : str
            ID of the Watson Studio project.
        **kwargs: dict
            Additional arguments to pass to the API provider.
        """
        # NOTE(review): the base class LM is never initialized here
        # (no super().__init__ call) — confirm LM.__init__ has no required
        # setup beyond the attributes assigned below.
        self.model = model
        self.credentials = credentials
        self.project_id = project_id
        self.provider = "ibm"
        self.model_type = "instruct"
        # Defaults for Watsonx text-generation parameters; caller-supplied
        # kwargs override any of them.
        self.kwargs = {
            "temperature": 0,
            "decoding_method": "greedy",
            "max_new_tokens": 150,
            "min_new_tokens": 0,
            "stop_sequences": [],
            "repetition_penalty": 1,
            "num_generations": 1,
            **kwargs,
        }

        self.client = Model(
            model_id=self.model,
            params=self.kwargs,
            credentials=self.credentials,
            project_id=self.project_id,
        )

        # Record of every request/response pair made through this instance.
        self.history: list[dict[str, Any]] = []

    def basic_request(self, prompt: str, **kwargs) -> Any:
        """Send `prompt` to the Watsonx client and record the exchange in history.

        Per-call kwargs override the instance defaults for this request only.
        Returns the raw SDK response.
        """
        raw_kwargs = kwargs
        kwargs = {**self.kwargs, **kwargs}

        response = self.client.generate(prompt, params={**kwargs})

        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

    def request(self, prompt: str, **kwargs) -> Any:
        """Wrap `prompt` in the model-specific chat template, then delegate to basic_request."""
        # Handles the specific prompting for each supported model and the
        # retrieval of completions from IBM Watsonx AI.
        if self.model == "mistralai/mixtral-8x7b-instruct-v01":
            # Mistral instruct format closes the instruction with [/INST];
            # the previous "</INST>" is not a recognized control token.
            prompt = "<s>[INST]" + prompt + "[/INST]"
        elif self.model == "meta-llama/llama-3-70b-instruct":
            # NOTE(review): this places the user prompt in the *system* slot
            # and never opens an assistant header — verify against Meta's
            # Llama-3 chat template before relying on completion quality.
            prompt = (
                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
                + prompt
                + "<|eot_id|><|start_header_id|>user<|end_header_id|>"
            )

        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Retrieves completions from Watsonx.

        Args:
            prompt (str): prompt to send to Watsonx
            only_completed (bool, optional): return only completed responses and ignores completion due to length.
                Defaults to True. Only True is currently supported.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities.
                Defaults to False. Only False is currently supported.
            **kwargs: Additional arguments to pass

        Raises:
            ValueError: if `only_completed` is False or `return_sorted` is True
                (neither mode is implemented yet).

        Returns:
            list[dict[str, Any]]: list of completion choices
        """
        if only_completed is False:
            raise ValueError("only_completed is True for now")

        if return_sorted:
            raise ValueError("return_sorted is False for now")

        response = self.request(prompt, **kwargs)

        return [result["generated_text"] for result in response["results"]]

0 commit comments

Comments
 (0)