55import sys
66import urllib .parse
77from hashlib import sha256
8- from os import PathLike
98from pathlib import Path
109from typing import (
1110 Any ,
1211 Callable ,
1312 Dict ,
13+ Generic ,
1414 List ,
1515 Literal ,
1616 Mapping ,
6969)
7070
7171T = TypeVar ('T' )
72+ CT = TypeVar ('CT' , httpx .Client , httpx .AsyncClient )
73+ AnyCallable = Callable [..., Any ]
7274
7375
74- class BaseClient :
76+ class BaseClient ( Generic [ CT ]) :
7577 def __init__ (
7678 self ,
77- client ,
79+ client : Type [ CT ] ,
7880 host : Optional [str ] = None ,
7981 * ,
8082 follow_redirects : bool = True ,
@@ -90,7 +92,7 @@ def __init__(
9092 `kwargs` are passed to the httpx client.
9193 """
9294
93- self ._client = client (
95+ self ._client : CT = client (
9496 base_url = _parse_host (host or os .getenv ('OLLAMA_HOST' )),
9597 follow_redirects = follow_redirects ,
9698 timeout = timeout ,
@@ -111,7 +113,7 @@ def __init__(
111113CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
112114
113115
114- class Client (BaseClient ):
116+ class Client (BaseClient [ httpx . Client ] ):
115117 def __init__ (self , host : Optional [str ] = None , ** kwargs ) -> None :
116118 super ().__init__ (httpx .Client , host , ** kwargs )
117119
@@ -139,19 +141,10 @@ def _request(
139141 self ,
140142 cls : Type [T ],
141143 * args ,
142- stream : Literal [True ] = True ,
144+ stream : Literal [True ],
143145 ** kwargs ,
144146 ) -> Iterator [T ]: ...
145147
146- @overload
147- def _request (
148- self ,
149- cls : Type [T ],
150- * args ,
151- stream : bool = False ,
152- ** kwargs ,
153- ) -> Union [T , Iterator [T ]]: ...
154-
155148 def _request (
156149 self ,
157150 cls : Type [T ],
@@ -189,7 +182,7 @@ def generate(
189182 system : str = '' ,
190183 template : str = '' ,
191184 context : Optional [Sequence [int ]] = None ,
192- stream : Literal [False ] = False ,
185+ stream : Literal [False ],
193186 think : Optional [bool ] = None ,
194187 raw : bool = False ,
195188 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -272,7 +265,7 @@ def chat(
272265 model : str = '' ,
273266 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
274267 * ,
275- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
268+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
276269 stream : Literal [False ] = False ,
277270 think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
278271 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -286,8 +279,8 @@ def chat(
286279 model : str = '' ,
287280 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
288281 * ,
289- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
290- stream : Literal [True ] = True ,
282+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
283+ stream : Literal [True ],
291284 think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
292285 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
293286 options : Optional [Union [Mapping [str , Any ], Options ]] = None ,
@@ -299,7 +292,7 @@ def chat(
299292 model : str = '' ,
300293 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
301294 * ,
302- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
295+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
303296 stream : bool = False ,
304297 think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
305298 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -414,7 +407,7 @@ def pull(
414407 model : str ,
415408 * ,
416409 insecure : bool = False ,
417- stream : Literal [True ] = True ,
410+ stream : Literal [True ],
418411 ) -> Iterator [ProgressResponse ]: ...
419412
420413 def pull (
@@ -447,7 +440,7 @@ def push(
447440 model : str ,
448441 * ,
449442 insecure : bool = False ,
450- stream : Literal [False ] = False ,
443+ stream : Literal [False ],
451444 ) -> ProgressResponse : ...
452445
453446 @overload
@@ -497,7 +490,7 @@ def create(
497490 parameters : Optional [Union [Mapping [str , Any ], Options ]] = None ,
498491 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
499492 * ,
500- stream : Literal [False ] = False ,
493+ stream : Literal [False ],
501494 ) -> ProgressResponse : ...
502495
503496 @overload
@@ -623,7 +616,7 @@ def ps(self) -> ProcessResponse:
623616 )
624617
625618
626- class AsyncClient (BaseClient ):
619+ class AsyncClient (BaseClient [ httpx . AsyncClient ] ):
627620 def __init__ (self , host : Optional [str ] = None , ** kwargs ) -> None :
628621 super ().__init__ (httpx .AsyncClient , host , ** kwargs )
629622
@@ -783,7 +776,7 @@ async def chat(
783776 model : str = '' ,
784777 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
785778 * ,
786- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
779+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
787780 stream : Literal [False ] = False ,
788781 think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
789782 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -797,7 +790,7 @@ async def chat(
797790 model : str = '' ,
798791 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
799792 * ,
800- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
793+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
801794 stream : Literal [True ] = True ,
802795 think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
803796 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -810,7 +803,7 @@ async def chat(
810803 model : str = '' ,
811804 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
812805 * ,
813- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
806+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
814807 stream : bool = False ,
815808 think : Optional [Union [bool , Literal ['low' , 'medium' , 'high' ]]] = None ,
816809 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -1155,21 +1148,11 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
11551148 )
11561149
11571150
1158- def _copy_tools (tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ) -> Iterator [Tool ]:
1151+ def _copy_tools (tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ) -> Iterator [Tool ]:
11591152 for unprocessed_tool in tools or []:
11601153 yield convert_function_to_tool (unprocessed_tool ) if callable (unprocessed_tool ) else Tool .model_validate (unprocessed_tool )
11611154
11621155
1163- def _as_path (s : Optional [Union [str , PathLike ]]) -> Union [Path , None ]:
1164- if isinstance (s , (str , Path )):
1165- try :
1166- if (p := Path (s )).exists ():
1167- return p
1168- except Exception :
1169- ...
1170- return None
1171-
1172-
11731156def _parse_host (host : Optional [str ]) -> str :
11741157 """
11751158 >>> _parse_host(None)
0 commit comments