 import pytest
 from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
 from docker.errors import NotFound
-from loguru import logger
-from test_model import TEST_CONFIGS
-from text_generation import AsyncClient
-from text_generation.types import Response
+import logging
+from huggingface_hub import AsyncInferenceClient, TextGenerationOutput
+import huggingface_hub
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    stream=sys.stdout,
+)
+logger = logging.getLogger(__file__)
 
 # Use the latest image from the local docker build
 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
 DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
-HF_TOKEN = os.getenv("HF_TOKEN", None)
+HF_TOKEN = huggingface_hub.get_token()
 
 assert (
     HF_TOKEN is not None
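
The token lookup change above is worth a note: `os.getenv("HF_TOKEN")` only sees the environment, while `huggingface_hub.get_token()`, to my knowledge, also falls back to the token cached by `huggingface-cli login`. A minimal sketch of the difference:

```python
import os
import huggingface_hub

# Old behaviour: only the environment variable.
token_env = os.getenv("HF_TOKEN", None)

# New behaviour: the env var if set, otherwise the token saved by
# `huggingface-cli login` (e.g. ~/.cache/huggingface/token).
token_any = huggingface_hub.get_token()
```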
@@ -48,12 +54,6 @@
     "cap_add": ["sys_nice"],
 }
 
-logger.add(
-    sys.stderr,
-    format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
-    level="INFO",
-)
-
 
 def stream_container_logs(container, test_name):
     """Stream container logs in a separate thread."""
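
The deleted `logger.add(...)` sink used loguru's `{field}` template with color tags, which the stdlib `logging` module does not understand; `basicConfig` interpolates `%`-style fields instead. A rough stdlib equivalent of the old sink, colors dropped (a sketch, not part of the diff):

```python
import logging
import sys

logging.basicConfig(
    level=logging.INFO,
    # %-style counterparts of loguru's {time}/{level}/{name}/{function}/{line}/{message}
    format="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    stream=sys.stdout,
)
logging.getLogger(__name__).info("server started")
```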
@@ -69,9 +69,15 @@ def stream_container_logs(container, test_name):
         logger.error(f"Error streaming container logs: {str(e)}")
 
 
+class TestClient(AsyncInferenceClient):
+    def __init__(self, service_name: str, base_url: str):
+        super().__init__(model=base_url)
+        self.service_name = service_name
+
+
 class LauncherHandle:
-    def __init__(self, port: int):
-        self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
+    def __init__(self, service_name: str, port: int):
+        self.client = TestClient(service_name, f"http://localhost:{port}")
 
     def _inner_health(self):
         raise NotImplementedError
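
`AsyncInferenceClient` accepts the endpoint URL through its `model` argument, so `TestClient` simply pins the base URL and carries a `service_name` for identification. A usage sketch against a locally launched server (the URL and prompt are made up for illustration):

```python
import asyncio

async def smoke():
    # Hypothetical endpoint; the launcher fixtures pick a free port at runtime.
    client = TestClient("tgi-gaudi-smoke", "http://localhost:8080")
    text = await client.text_generation("What is Deep Learning?", max_new_tokens=16)
    print(text)

asyncio.run(smoke())
```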
@@ -87,7 +93,7 @@ async def health(self, timeout: int = 60):
                 raise RuntimeError("Launcher crashed")
 
             try:
-                await self.client.generate("test")
+                await self.client.text_generation("test", max_new_tokens=1)
                 elapsed = time.time() - start_time
                 logger.info(f"Health check passed after {elapsed:.1f}s")
                 return
@@ -111,7 +117,8 @@ async def health(self, timeout: int = 60):
 
 class ContainerLauncherHandle(LauncherHandle):
     def __init__(self, docker_client, container_name, port: int):
-        super(ContainerLauncherHandle, self).__init__(port)
+        service_name = container_name  # Use container name as service name
+        super(ContainerLauncherHandle, self).__init__(service_name, port)
         self.docker_client = docker_client
         self.container_name = container_name
 
@@ -132,7 +139,8 @@ def _inner_health(self) -> bool:
 
 class ProcessLauncherHandle(LauncherHandle):
     def __init__(self, process, port: int):
-        super(ProcessLauncherHandle, self).__init__(port)
+        service_name = "process"  # Use generic name for process launcher
+        super(ProcessLauncherHandle, self).__init__(service_name, port)
         self.process = process
 
     def _inner_health(self) -> bool:
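
The two handles differ only in how `_inner_health` is implemented, and both bodies fall outside these hunks. Plausible implementations, inferred from the attributes stored above (treat these as assumptions, not the diff's code):

```python
# Container variant: ask the Docker daemon whether the container is still up.
def _inner_health(self) -> bool:
    container = self.docker_client.containers.get(self.container_name)
    return container.status in ("running", "created")

# Process variant: Popen.poll() returns None while the process is alive.
def _inner_health(self) -> bool:
    return self.process.poll() is None
```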
@@ -151,11 +159,13 @@ def data_volume():
 
 
 @pytest.fixture(scope="module")
-def launcher(data_volume):
+def gaudi_launcher():
     @contextlib.contextmanager
     def docker_launcher(
         model_id: str,
         test_name: str,
+        tgi_args: List[str] = None,
+        env_config: dict = None,
     ):
         logger.info(
             f"Starting docker launcher for model {model_id} and test {test_name}"
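
Dropping the `TEST_CONFIGS` lookup means each test now supplies its own launcher arguments. A hypothetical caller, assuming the fixture yields `docker_launcher` (the model id, flags, and env values below are illustrative only):

```python
def test_llama_smoke(gaudi_launcher):
    with gaudi_launcher(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        test_name="llama3-smoke",
        tgi_args=["--max-input-tokens", "512", "--max-total-tokens", "1024"],
        env_config={"PT_HPU_LAZY_MODE": "1"},
    ) as handle:
        ...  # handle wraps the running container
```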
@@ -183,32 +193,40 @@ def get_free_port():
             )
             container.stop()
             container.wait()
+            container.remove()
+            logger.info(f"Removed existing container {container_name}")
         except NotFound:
             pass
         except Exception as e:
             logger.error(f"Error handling existing container: {str(e)}")
 
-        model_name = next(
-            name for name, cfg in TEST_CONFIGS.items() if cfg["model_id"] == model_id
-        )
-
-        tgi_args = TEST_CONFIGS[model_name]["args"].copy()
+        if tgi_args is None:
+            tgi_args = []
+        else:
+            tgi_args = tgi_args.copy()
 
         env = BASE_ENV.copy()
 
         # Add model_id to env
         env["MODEL_ID"] = model_id
 
-        # Add env config that is definied in the fixture parameter
-        if "env_config" in TEST_CONFIGS[model_name]:
-            env.update(TEST_CONFIGS[model_name]["env_config"].copy())
+        # Add env config that is defined in the fixture parameter
+        if env_config is not None:
+            env.update(env_config.copy())
 
-        volumes = [f"{DOCKER_VOLUME}:/data"]
+        volumes = []
+        if DOCKER_VOLUME:
+            volumes = [f"{DOCKER_VOLUME}:/data"]
         logger.debug(f"Using volume {volumes}")
 
         try:
+            logger.debug(f"Using command {tgi_args}")
             logger.info(f"Creating container with name {container_name}")
 
+            logger.debug(f"Using environment {env}")
+            logger.debug(f"Using volumes {volumes}")
+            logger.debug(f"HABANA_RUN_ARGS {HABANA_RUN_ARGS}")
+
             # Log equivalent docker run command for debugging, this is not actually executed
             container = client.containers.run(
                 DOCKER_IMAGE,
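
The `containers.run(...)` call is cut off by the hunk boundary. Based on the names logged just above, the docker SDK call plausibly takes a shape like the following (every keyword here is an assumption, not read from the diff):

```python
container = client.containers.run(
    DOCKER_IMAGE,
    command=tgi_args,        # TGI launcher flags
    name=container_name,
    environment=env,         # includes MODEL_ID and any env_config entries
    volumes=volumes,         # optional /data mount
    detach=True,             # return immediately so logs can be streamed
    **HABANA_RUN_ARGS,       # e.g. cap_add=["sys_nice"] and device settings
)
```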
@@ -271,15 +289,16 @@ def get_free_port():
 
 
 @pytest.fixture(scope="module")
-def generate_load():
+def gaudi_generate_load():
     async def generate_load_inner(
-        client: AsyncClient, prompt: str, max_new_tokens: int, n: int
-    ) -> List[Response]:
+        client: AsyncInferenceClient, prompt: str, max_new_tokens: int, n: int
+    ) -> List[TextGenerationOutput]:
         try:
             futures = [
-                client.generate(
+                client.text_generation(
                     prompt,
                     max_new_tokens=max_new_tokens,
+                    details=True,
                     decoder_input_details=True,
                 )
                 for _ in range(n)
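
The comprehension only builds the coroutines; the rest of `generate_load_inner` falls outside the hunk. One natural way the body could conclude, assuming the usual `asyncio.gather` pattern (a sketch, not the diff's code):

```python
import asyncio
from typing import List
from huggingface_hub import TextGenerationOutput

async def gather_generations(futures) -> List[TextGenerationOutput]:
    # Run all n text_generation calls concurrently; connection errors
    # propagate to the surrounding try/except.
    return await asyncio.gather(*futures)
```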