55
66import json
77import logging
8- from typing import Any , Iterable , Optional , Union
8+ from typing import Any , Iterable , Optional , cast
99
1010from ollama import Client as OllamaClient
1111from typing_extensions import TypedDict , Unpack , override
1212
13- from ..types .content import ContentBlock , Message , Messages
14- from ..types .media import DocumentContent , ImageContent
13+ from ..types .content import ContentBlock , Messages
1514from ..types .models import Model
1615from ..types .streaming import StopReason , StreamEvent
1716from ..types .tools import ToolSpec
@@ -92,35 +91,31 @@ def get_config(self) -> OllamaConfig:
9291 """
9392 return self .config
9493
95- @override
96- def format_request (
97- self , messages : Messages , tool_specs : Optional [list [ToolSpec ]] = None , system_prompt : Optional [str ] = None
98- ) -> dict [str , Any ]:
99- """Format an Ollama chat streaming request.
94+ def _format_request_message_contents (self , role : str , content : ContentBlock ) -> list [dict [str , Any ]]:
95+ """Format Ollama compatible message contents.
96+
97+ Ollama doesn't support an array of contents, so we must flatten everything into separate message blocks.
10098
10199 Args:
102- messages: List of message objects to be processed by the model.
103- tool_specs: List of tool specifications to make available to the model.
104- system_prompt: System prompt to provide context to the model.
100+ role: E.g., user.
101+ content: Content block to format.
105102
106103 Returns:
107- An Ollama chat streaming request .
104+ Ollama formatted message contents .
108105
109106 Raises:
110- TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
111- format.
107+ TypeError: If the content block type cannot be converted to an Ollama-compatible format.
112108 """
109+ if "text" in content :
110+ return [{"role" : role , "content" : content ["text" ]}]
113111
114- def format_message (message : Message , content : ContentBlock ) -> dict [str , Any ]:
115- if "text" in content :
116- return {"role" : message ["role" ], "content" : content ["text" ]}
112+ if "image" in content :
113+ return [{"role" : role , "images" : [content ["image" ]["source" ]["bytes" ]]}]
117114
118- if "image" in content :
119- return {"role" : message ["role" ], "images" : [content ["image" ]["source" ]["bytes" ]]}
120-
121- if "toolUse" in content :
122- return {
123- "role" : "assistant" ,
115+ if "toolUse" in content :
116+ return [
117+ {
118+ "role" : role ,
124119 "tool_calls" : [
125120 {
126121 "function" : {
@@ -130,45 +125,63 @@ def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
130125 }
131126 ],
132127 }
128+ ]
129+
130+ if "toolResult" in content :
131+ return [
132+ formatted_tool_result_content
133+ for tool_result_content in content ["toolResult" ]["content" ]
134+ for formatted_tool_result_content in self ._format_request_message_contents (
135+ "tool" ,
136+ (
137+ {"text" : json .dumps (tool_result_content ["json" ])}
138+ if "json" in tool_result_content
139+ else cast (ContentBlock , tool_result_content )
140+ ),
141+ )
142+ ]
133143
134- if "toolResult" in content :
135- result_content : Union [str , ImageContent , DocumentContent , Any ] = None
136- result_images = []
137- for tool_result_content in content ["toolResult" ]["content" ]:
138- if "text" in tool_result_content :
139- result_content = tool_result_content ["text" ]
140- elif "json" in tool_result_content :
141- result_content = tool_result_content ["json" ]
142- elif "image" in tool_result_content :
143- result_content = "see images"
144- result_images .append (tool_result_content ["image" ]["source" ]["bytes" ])
145- else :
146- result_content = content ["toolResult" ]["content" ]
144+ raise TypeError (f"content_type=<{ next (iter (content ))} > | unsupported type" )
147145
148- return {
149- "role" : "tool" ,
150- "content" : json .dumps (
151- {
152- "name" : content ["toolResult" ]["toolUseId" ],
153- "result" : result_content ,
154- "status" : content ["toolResult" ]["status" ],
155- }
156- ),
157- ** ({"images" : result_images } if result_images else {}),
158- }
146+ def _format_request_messages (self , messages : Messages , system_prompt : Optional [str ] = None ) -> list [dict [str , Any ]]:
147+ """Format an Ollama compatible messages array.
159148
160- raise TypeError (f"content_type=<{ next (iter (content ))} > | unsupported type" )
149+ Args:
150+ messages: List of message objects to be processed by the model.
151+ system_prompt: System prompt to provide context to the model.
161152
162- def format_messages () -> list [dict [str , Any ]]:
163- return [format_message (message , content ) for message in messages for content in message ["content" ]]
153+ Returns:
154+ An Ollama compatible messages array.
155+ """
156+ system_message = [{"role" : "system" , "content" : system_prompt }] if system_prompt else []
164157
165- formatted_messages = format_messages ()
158+ return system_message + [
159+ formatted_message
160+ for message in messages
161+ for content in message ["content" ]
162+ for formatted_message in self ._format_request_message_contents (message ["role" ], content )
163+ ]
166164
165+ @override
166+ def format_request (
167+ self , messages : Messages , tool_specs : Optional [list [ToolSpec ]] = None , system_prompt : Optional [str ] = None
168+ ) -> dict [str , Any ]:
169+ """Format an Ollama chat streaming request.
170+
171+ Args:
172+ messages: List of message objects to be processed by the model.
173+ tool_specs: List of tool specifications to make available to the model.
174+ system_prompt: System prompt to provide context to the model.
175+
176+ Returns:
177+ An Ollama chat streaming request.
178+
179+ Raises:
180+ TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
181+ format.
182+ """
167183 return {
168- "messages" : [
169- * ([{"role" : "system" , "content" : system_prompt }] if system_prompt else []),
170- * formatted_messages ,
171- ],
184+ "messages" : self ._format_request_messages (messages , system_prompt ),
172185 "model" : self .config ["model_id" ],
173186 "options" : {
174187 ** (self .config .get ("options" ) or {}),
@@ -217,52 +230,54 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
217230 RuntimeError: If chunk_type is not recognized.
218231 This error should never be encountered as we control chunk_type in the stream method.
219232 """
220- if event ["chunk_type" ] == "message_start" :
221- return {"messageStart" : {"role" : "assistant" }}
222-
223- if event ["chunk_type" ] == "content_start" :
224- if event ["data_type" ] == "text" :
225- return {"contentBlockStart" : {"start" : {}}}
226-
227- tool_name = event ["data" ].function .name
228- return {"contentBlockStart" : {"start" : {"toolUse" : {"name" : tool_name , "toolUseId" : tool_name }}}}
229-
230- if event ["chunk_type" ] == "content_delta" :
231- if event ["data_type" ] == "text" :
232- return {"contentBlockDelta" : {"delta" : {"text" : event ["data" ]}}}
233-
234- tool_arguments = event ["data" ].function .arguments
235- return {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : json .dumps (tool_arguments )}}}}
236-
237- if event ["chunk_type" ] == "content_stop" :
238- return {"contentBlockStop" : {}}
239-
240- if event ["chunk_type" ] == "message_stop" :
241- reason : StopReason
242- if event ["data" ] == "tool_use" :
243- reason = "tool_use"
244- elif event ["data" ] == "length" :
245- reason = "max_tokens"
246- else :
247- reason = "end_turn"
248-
249- return {"messageStop" : {"stopReason" : reason }}
250-
251- if event ["chunk_type" ] == "metadata" :
252- return {
253- "metadata" : {
254- "usage" : {
255- "inputTokens" : event ["data" ].eval_count ,
256- "outputTokens" : event ["data" ].prompt_eval_count ,
257- "totalTokens" : event ["data" ].eval_count + event ["data" ].prompt_eval_count ,
258- },
259- "metrics" : {
260- "latencyMs" : event ["data" ].total_duration / 1e6 ,
233+ match event ["chunk_type" ]:
234+ case "message_start" :
235+ return {"messageStart" : {"role" : "assistant" }}
236+
237+ case "content_start" :
238+ if event ["data_type" ] == "text" :
239+ return {"contentBlockStart" : {"start" : {}}}
240+
241+ tool_name = event ["data" ].function .name
242+ return {"contentBlockStart" : {"start" : {"toolUse" : {"name" : tool_name , "toolUseId" : tool_name }}}}
243+
244+ case "content_delta" :
245+ if event ["data_type" ] == "text" :
246+ return {"contentBlockDelta" : {"delta" : {"text" : event ["data" ]}}}
247+
248+ tool_arguments = event ["data" ].function .arguments
249+ return {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : json .dumps (tool_arguments )}}}}
250+
251+ case "content_stop" :
252+ return {"contentBlockStop" : {}}
253+
254+ case "message_stop" :
255+ reason : StopReason
256+ if event ["data" ] == "tool_use" :
257+ reason = "tool_use"
258+ elif event ["data" ] == "length" :
259+ reason = "max_tokens"
260+ else :
261+ reason = "end_turn"
262+
263+ return {"messageStop" : {"stopReason" : reason }}
264+
265+ case "metadata" :
266+ return {
267+ "metadata" : {
268+ "usage" : {
269+ "inputTokens" : event ["data" ].eval_count ,
270+ "outputTokens" : event ["data" ].prompt_eval_count ,
271+ "totalTokens" : event ["data" ].eval_count + event ["data" ].prompt_eval_count ,
272+ },
273+ "metrics" : {
274+ "latencyMs" : event ["data" ].total_duration / 1e6 ,
275+ },
261276 },
262- },
263- }
277+ }
264278
265- raise RuntimeError (f"chunk_type=<{ event ['chunk_type' ]} | unknown type" )
279+ case _:
280+ raise RuntimeError (f"chunk_type=<{ event ['chunk_type' ]} | unknown type" )
266281
267282 @override
268283 def stream (self , request : dict [str , Any ]) -> Iterable [dict [str , Any ]]:
0 commit comments