Skip to content

Commit c819e17

Browse files
authored
fix semaphore calls (#1012)
1 parent ba60756 commit c819e17

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/lighteval/models/endpoints/inference_providers_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def __init__(self, config: InferenceProvidersModelConfig) -> None:
116116
self.API_RETRY_SLEEP = 3
117117
self.API_RETRY_MULTIPLIER = 2
118118
self.pairwise_tokenization = False
119-
self.semaphore = asyncio.Semaphore(config.parallel_calls_count) # Limit concurrent API calls
119+
self.parallel_calls_count = config.parallel_calls_count
120120

121121
self.client = AsyncInferenceClient(
122122
provider=self.provider,
@@ -179,13 +179,16 @@ async def __call_api_parallel(
179179
):
180180
results = []
181181

182+
# Initialize semaphore for the current event loop
183+
semaphore = asyncio.Semaphore(self.parallel_calls_count)
184+
182185
num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
183186
assert len(prompts) == len(num_sampless), (
184187
f"Length of prompts and max_new_tokenss should be the same but are {len(prompts)}, {len(num_sampless)}"
185188
)
186189

187190
async def bounded_api_call(prompt, num_samples):
188-
async with self.semaphore:
191+
async with semaphore:
189192
return await self.__call_api(prompt, num_samples)
190193

191194
tasks = [bounded_api_call(prompt, num_samples) for prompt, num_samples in zip(prompts, num_sampless)]

0 commit comments

Comments
 (0)