
Commit 6e39272

Add verbose attribute to LLM models for debugging purposes.
1 parent 0458ea3 commit 6e39272

File tree

3 files changed, +30 −6 lines changed

ads/llm/langchain/plugins/base.py

Lines changed: 20 additions & 0 deletions
@@ -44,6 +44,26 @@ class BaseLLM(LLM):
     stop: Optional[List[str]] = None
     """Stop words to use when generating. Model output is cut off at the first occurrence of any of these substrings."""
 
+    verbose: int = 0
+    """Verbosity level for debugging purposes.
+    The LLM implementation should print debugging information based on the verbosity level:
+    0 - No debugging information
+    1 - Print the prompt and the response (completion) from the LLM
+    2 - In addition to the prompt and the response (completion), also print the request parameters (payload).
+    """
+
+    def _print_request(self, prompt, params):
+        if self.verbose >= 1:
+            print(f"LLM API Request:\n{prompt}")
+        if self.verbose >= 2:
+            print(f"LLM API Parameters:\n{params}")
+
+    def _print_response(self, completion, response):
+        if self.verbose >= 1:
+            print(f"LLM API Completion:\n{completion}")
+        if self.verbose >= 2:
+            print(f"LLM API Response:\n{response}")
+
 
 class GenerativeAiClientModel(BaseModel):
     client: Any #: :meta private:
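
For reference, the two hooks behave like the self-contained sketch below. VerboseDemo is hypothetical and only mirrors the _print_request/_print_response logic from this commit: level 1 prints the prompt and completion, while level 2 additionally prints the request parameters and the raw response.

# Illustrative sketch only: VerboseDemo is hypothetical; it mirrors the
# verbose hooks on BaseLLM but is not part of the ADS code base.
class VerboseDemo:
    def __init__(self, verbose: int = 0):
        self.verbose = verbose

    def _print_request(self, prompt, params):
        # verbose >= 1 prints the prompt; verbose >= 2 also prints the payload
        if self.verbose >= 1:
            print(f"LLM API Request:\n{prompt}")
        if self.verbose >= 2:
            print(f"LLM API Parameters:\n{params}")

    def _print_response(self, completion, response):
        # verbose >= 1 prints the completion; verbose >= 2 also the raw response
        if self.verbose >= 1:
            print(f"LLM API Completion:\n{completion}")
        if self.verbose >= 2:
            print(f"LLM API Response:\n{response}")


demo = VerboseDemo(verbose=2)
demo._print_request("Translate 'hello'.", {"max_tokens": 16})
demo._print_response("Bonjour", {"generated_text": "Bonjour"})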

ads/llm/langchain/plugins/llm_gen_ai.py

Lines changed: 4 additions & 1 deletion
@@ -146,6 +146,7 @@ def _call(
         """
 
         params = self._invocation_params(stop, **kwargs)
+        self._print_request(prompt, params)
 
         try:
             response = (
@@ -163,7 +164,9 @@ def _call(
             )
             raise
 
-        return self._process_response(response, params.get("num_generations", 1))
+        completion = self._process_response(response, params.get("num_generations", 1))
+        self._print_response(completion, response)
+        return completion
 
     def _process_response(self, response: Any, num_generations: int = 1) -> str:
         if self.task == Task.SUMMARY_TEXT:
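
As a usage sketch, assuming the plugin class defined in llm_gen_ai.py is named GenerativeAI and accepts a compartment_id argument (neither is shown in this diff), the new attribute would be enabled like this:

# Hypothetical usage; the class name GenerativeAI and the compartment_id
# argument are assumptions not confirmed by this diff.
from ads.llm.langchain.plugins.llm_gen_ai import GenerativeAI

llm = GenerativeAI(
    compartment_id="<compartment_ocid>",  # placeholder
    verbose=1,  # print the prompt and completion for each call
)
print(llm("Tell me a joke."))
# verbose=2 would additionally print the invocation parameters and raw response.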

ads/llm/langchain/plugins/llm_md.py

Lines changed: 6 additions & 5 deletions
@@ -94,8 +94,11 @@ def _call(
         """
         params = self._invocation_params(stop, **kwargs)
         body = self._construct_json_body(prompt, params)
+        self._print_request(prompt, params)
         response = self.send_request(data=body, endpoint=self.endpoint)
-        return self._process_response(response)
+        completion = self._process_response(response)
+        self._print_response(completion, response)
+        return completion
 
     def send_request(
         self,
@@ -134,9 +137,7 @@ def send_request(
         request_kwargs["headers"] = header
         request_kwargs["auth"] = self.auth.get("signer")
         timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)
-        response = requests.post(
-            endpoint, timeout=timeout, **request_kwargs, **kwargs
-        )
+        response = requests.post(endpoint, timeout=timeout, **request_kwargs, **kwargs)
 
         try:
             response.raise_for_status()
@@ -205,7 +206,7 @@ def _construct_json_body(self, prompt, params):
         }
 
     def _process_response(self, response_json: dict):
-        return str(response_json.get("generated_text", response_json))
+        return str(response_json.get("generated_text", response_json)) + "\n"
 
 
 class ModelDeploymentVLLM(ModelDeploymentLLM):
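
A similar sketch for the model deployment plugins. The class name ModelDeploymentVLLM and the endpoint attribute appear in this diff, but the constructor call and the endpoint URL format are assumptions:

# Hypothetical usage; only the class name and the endpoint/verbose
# attributes are visible in this diff.
from ads.llm.langchain.plugins.llm_md import ModelDeploymentVLLM

llm = ModelDeploymentVLLM(
    endpoint="https://<md_host>/<md_ocid>/predict",  # placeholder endpoint
    verbose=2,  # print prompt, parameters, completion, and raw response
)
print(llm("What is OCI Data Science?"))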
