55import sys
66import urllib .parse
77from hashlib import sha256
8- from os import PathLike
8+
99from pathlib import Path
1010from typing import (
1111 Any ,
1212 Callable ,
1313 Dict ,
14+ Generic ,
1415 List ,
1516 Literal ,
1617 Mapping ,
2223 overload ,
2324)
2425
26+ AnyCallable = Callable [..., Any ]
27+
2528import anyio
2629from pydantic .json_schema import JsonSchemaValue
2730
6972)
7073
7174T = TypeVar ('T' )
75+ CT = TypeVar ('CT' , httpx .Client , httpx .AsyncClient )
7276
7377
74- class BaseClient :
78+ class BaseClient ( Generic [ CT ]) :
7579 def __init__ (
7680 self ,
77- client ,
81+ client : Type [ CT ] ,
7882 host : Optional [str ] = None ,
7983 * ,
8084 follow_redirects : bool = True ,
@@ -90,7 +94,7 @@ def __init__(
9094 `kwargs` are passed to the httpx client.
9195 """
9296
93- self ._client = client (
97+ self ._client : CT = client (
9498 base_url = _parse_host (host or os .getenv ('OLLAMA_HOST' )),
9599 follow_redirects = follow_redirects ,
96100 timeout = timeout ,
@@ -111,7 +115,7 @@ def __init__(
111115CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'
112116
113117
114- class Client (BaseClient ):
118+ class Client (BaseClient [ httpx . Client ] ):
115119 def __init__ (self , host : Optional [str ] = None , ** kwargs ) -> None :
116120 super ().__init__ (httpx .Client , host , ** kwargs )
117121
@@ -139,19 +143,10 @@ def _request(
139143 self ,
140144 cls : Type [T ],
141145 * args ,
142- stream : Literal [True ] = True ,
146+ stream : Literal [True ],
143147 ** kwargs ,
144148 ) -> Iterator [T ]: ...
145149
146- @overload
147- def _request (
148- self ,
149- cls : Type [T ],
150- * args ,
151- stream : bool = False ,
152- ** kwargs ,
153- ) -> Union [T , Iterator [T ]]: ...
154-
155150 def _request (
156151 self ,
157152 cls : Type [T ],
@@ -189,7 +184,7 @@ def generate(
189184 system : str = '' ,
190185 template : str = '' ,
191186 context : Optional [Sequence [int ]] = None ,
192- stream : Literal [False ] = False ,
187+ stream : Literal [False ],
193188 think : Optional [bool ] = None ,
194189 raw : bool = False ,
195190 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -272,7 +267,7 @@ def chat(
272267 model : str = '' ,
273268 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
274269 * ,
275- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
270+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
276271 stream : Literal [False ] = False ,
277272 think : Optional [bool ] = None ,
278273 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -286,8 +281,8 @@ def chat(
286281 model : str = '' ,
287282 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
288283 * ,
289- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
290- stream : Literal [True ] = True ,
284+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
285+ stream : Literal [True ],
291286 think : Optional [bool ] = None ,
292287 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
293288 options : Optional [Union [Mapping [str , Any ], Options ]] = None ,
@@ -299,7 +294,7 @@ def chat(
299294 model : str = '' ,
300295 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
301296 * ,
302- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
297+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
303298 stream : bool = False ,
304299 think : Optional [bool ] = None ,
305300 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -414,7 +409,7 @@ def pull(
414409 model : str ,
415410 * ,
416411 insecure : bool = False ,
417- stream : Literal [True ] = True ,
412+ stream : Literal [True ],
418413 ) -> Iterator [ProgressResponse ]: ...
419414
420415 def pull (
@@ -447,7 +442,7 @@ def push(
447442 model : str ,
448443 * ,
449444 insecure : bool = False ,
450- stream : Literal [False ] = False ,
445+ stream : Literal [False ],
451446 ) -> ProgressResponse : ...
452447
453448 @overload
@@ -497,7 +492,7 @@ def create(
497492 parameters : Optional [Union [Mapping [str , Any ], Options ]] = None ,
498493 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
499494 * ,
500- stream : Literal [False ] = False ,
495+ stream : Literal [False ],
501496 ) -> ProgressResponse : ...
502497
503498 @overload
@@ -623,7 +618,7 @@ def ps(self) -> ProcessResponse:
623618 )
624619
625620
626- class AsyncClient (BaseClient ):
621+ class AsyncClient (BaseClient [ httpx . AsyncClient ] ):
627622 def __init__ (self , host : Optional [str ] = None , ** kwargs ) -> None :
628623 super ().__init__ (httpx .AsyncClient , host , ** kwargs )
629624
@@ -783,7 +778,7 @@ async def chat(
783778 model : str = '' ,
784779 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
785780 * ,
786- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
781+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
787782 stream : Literal [False ] = False ,
788783 think : Optional [bool ] = None ,
789784 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -797,7 +792,7 @@ async def chat(
797792 model : str = '' ,
798793 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
799794 * ,
800- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
795+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
801796 stream : Literal [True ] = True ,
802797 think : Optional [bool ] = None ,
803798 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -810,7 +805,7 @@ async def chat(
810805 model : str = '' ,
811806 messages : Optional [Sequence [Union [Mapping [str , Any ], Message ]]] = None ,
812807 * ,
813- tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ,
808+ tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ,
814809 stream : bool = False ,
815810 think : Optional [bool ] = None ,
816811 format : Optional [Union [Literal ['' , 'json' ], JsonSchemaValue ]] = None ,
@@ -1155,21 +1150,11 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
11551150 )
11561151
11571152
1158- def _copy_tools (tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , Callable ]]] = None ) -> Iterator [Tool ]:
1153+ def _copy_tools (tools : Optional [Sequence [Union [Mapping [str , Any ], Tool , AnyCallable ]]] = None ) -> Iterator [Tool ]:
11591154 for unprocessed_tool in tools or []:
11601155 yield convert_function_to_tool (unprocessed_tool ) if callable (unprocessed_tool ) else Tool .model_validate (unprocessed_tool )
11611156
11621157
1163- def _as_path (s : Optional [Union [str , PathLike ]]) -> Union [Path , None ]:
1164- if isinstance (s , (str , Path )):
1165- try :
1166- if (p := Path (s )).exists ():
1167- return p
1168- except Exception :
1169- ...
1170- return None
1171-
1172-
11731158def _parse_host (host : Optional [str ]) -> str :
11741159 """
11751160 >>> _parse_host(None)
0 commit comments