
Commit afa881c

Add support for Hugging Face Serverless Inference
1 parent 41770a7 commit afa881c

File tree

3 files changed: +101 -0 lines changed

    bigcodebench/gen/util/hf_inference_request.py
    bigcodebench/provider/__init__.py
    bigcodebench/provider/hf_inference.py
bigcodebench/gen/util/hf_inference_request.py

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+import time
+
+from huggingface_hub import InferenceClient
+from huggingface_hub.inference._generated.types import TextGenerationOutput
+
+
+def make_request(
+    client: InferenceClient,
+    message: str,
+    model: str,
+    temperature: float,
+    n: int,
+    max_new_tokens: int = 2048,
+) -> TextGenerationOutput:
+    response = client.text_generation(
+        model=model,
+        prompt=message,
+        do_sample=False,
+        max_new_tokens=max_new_tokens,
+    )
+
+    return response
+
+
+def make_auto_request(*args, **kwargs) -> TextGenerationOutput:
+    ret = None
+    while ret is None:
+        try:
+            ret = make_request(*args, **kwargs)
+        except Exception as e:
+            print("Unknown error. Waiting...")
+            print(e)
+            time.sleep(1)
+    return ret
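For reference, a minimal sketch of how this retry helper might be exercised on its own. The client construction mirrors the decoder added below; the model id, prompt, and token variable value are placeholders for illustration, not part of the commit:

import os

from huggingface_hub import InferenceClient

from bigcodebench.gen.util.hf_inference_request import make_auto_request

# Client construction mirrors HuggingFaceInferenceDecoder.__init__ below.
client = InferenceClient(
    provider="hf-inference", api_key=os.getenv("HF_INFERENCE_API_KEY")
)

# Placeholder model id and prompt (assumptions for illustration only).
completion = make_auto_request(
    client,
    message="def fibonacci(n):\n    ",
    model="bigcode/starcoder2-15b",
    temperature=0.0,
    n=1,
    max_new_tokens=256,
)
print(completion)

Note that make_request hard-codes do_sample=False, so the temperature and n arguments are accepted but not forwarded to the endpoint in this revision.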

bigcodebench/provider/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -68,6 +68,19 @@ def make_model(
             tokenizer_name=tokenizer_name,
             tokenizer_legacy=tokenizer_legacy,
         )
+    elif backend == "hf-inference":
+        from bigcodebench.provider.hf_inference import HuggingFaceInferenceDecoder
+
+        return HuggingFaceInferenceDecoder(
+            name=model,
+            subset=subset,
+            split=split,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            direct_completion=direct_completion,
+            instruction_prefix=instruction_prefix,
+            response_prefix=response_prefix,
+        )
     elif backend == "openai":
         from bigcodebench.provider.openai import OpenAIChatDecoder

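A hedged sketch of how the new backend might be selected through make_model. The full make_model signature is not shown in this hunk, so the keyword arguments below are inferred from the values the new branch forwards, and the model id and other values are placeholders:

from bigcodebench.provider import make_model

# Assumed invocation; argument names mirror those forwarded to
# HuggingFaceInferenceDecoder in the branch above. Values are placeholders.
decoder = make_model(
    model="bigcode/starcoder2-15b",
    backend="hf-inference",
    subset="full",
    split="complete",
    temperature=0.0,
    max_new_tokens=1280,
    direct_completion=True,
    instruction_prefix="",  # prompt prefixes omitted for brevity (assumption)
    response_prefix="",
)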

bigcodebench/provider/hf_inference.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+import os
+from typing import List
+from tqdm import tqdm
+
+from huggingface_hub import InferenceClient
+
+from bigcodebench.provider.base import DecoderBase
+from bigcodebench.gen.util.hf_inference_request import make_auto_request
+from bigcodebench.provider.utility import make_raw_chat_prompt
+
+
+class HuggingFaceInferenceDecoder(DecoderBase):
+    def __init__(self, name: str, **kwargs):
+        super().__init__(name, **kwargs)
+        self.client = InferenceClient(
+            provider="hf-inference", api_key=os.getenv("HF_INFERENCE_API_KEY")
+        )
+
+    def codegen(
+        self, prompts: List[str], do_sample: bool = True, num_samples: int = 200
+    ) -> List[str]:
+        if do_sample:
+            assert self.temperature > 0, "Temperature must be positive for sampling"
+
+        all_outputs = []
+
+        for prompt in tqdm(prompts):
+            outputs = []
+            message = (
+                prompt
+                if self.is_direct_completion()
+                else make_raw_chat_prompt(
+                    task_prompt=prompt,
+                    subset=self.subset,
+                    split=self.split,
+                    instruction_prefix=self.instruction_prefix,
+                    response_prefix=self.response_prefix,
+                    tokenizer=None,
+                )
+            )
+            ret = make_auto_request(
+                self.client,
+                message=message,
+                model=self.name,
+                n=num_samples,
+                temperature=self.temperature,
+                max_new_tokens=self.max_new_tokens,
+            )
+            outputs.append(ret)
+            all_outputs.append(outputs)
+        return all_outputs
+
+    def is_direct_completion(self) -> bool:
+        return self.direct_completion
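Finally, a minimal end-to-end sketch. The decoder reads its token from the HF_INFERENCE_API_KEY environment variable; the constructor keyword arguments below mirror those forwarded by make_model and are assumptions about what DecoderBase accepts, and the model id, token value, and prompt are placeholders:

import os

from bigcodebench.provider.hf_inference import HuggingFaceInferenceDecoder

# The constructor reads the API token from this environment variable.
os.environ.setdefault("HF_INFERENCE_API_KEY", "hf_xxx")  # placeholder token

# Keyword arguments mirror the make_model branch; DecoderBase's exact
# signature is not part of this diff, so treat these as assumptions.
decoder = HuggingFaceInferenceDecoder(
    name="bigcode/starcoder2-15b",  # placeholder model id
    subset="full",
    split="complete",
    temperature=0.0,
    max_new_tokens=1280,
    direct_completion=True,
    instruction_prefix="",
    response_prefix="",
)

# Greedy, single-sample generation; codegen returns one list of outputs
# per input prompt.
outputs = decoder.codegen(["def add(a, b):\n    "], do_sample=False, num_samples=1)
print(outputs[0][0])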
