|
11 | 11 | class OpenAIChatDecoder(DecoderBase): |
12 | 12 | def __init__(self, name: str, base_url=None, **kwargs) -> None: |
13 | 13 | super().__init__(name, **kwargs) |
14 | | - self.client = openai.OpenAI( |
15 | | - api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=base_url |
16 | | - ) |
17 | | - |
18 | | - # def codegen( |
19 | | - # self, prompts: List[str], do_sample: bool = True, num_samples: int = 200 |
20 | | - # ) -> List[str]: |
21 | | - # if do_sample: |
22 | | - # assert self.temperature > 0, "Temperature must be positive for sampling" |
23 | | - # all_outputs = [] |
24 | | - # for prompt in tqdm(prompts): |
25 | | - # outputs = [] |
26 | | - # message = make_raw_chat_prompt( |
27 | | - # task_prompt=prompt, |
28 | | - # subset=self.subset, |
29 | | - # split=self.split, |
30 | | - # instruction_prefix=self.instruction_prefix, |
31 | | - # response_prefix=self.response_prefix, |
32 | | - # tokenizer=None, |
33 | | - # ) |
34 | | - # ret = make_auto_request( |
35 | | - # self.client, |
36 | | - # message=message, |
37 | | - # model=self.name, |
38 | | - # max_tokens=self.max_new_tokens, |
39 | | - # temperature=self.temperature, |
40 | | - # n=num_samples, |
41 | | - # ) |
42 | | - # for item in ret.choices: |
43 | | - # outputs.append(item.message.content) |
44 | | - # all_outputs.append(outputs) |
45 | | - # return all_outputs |
46 | | - |
47 | | - # def is_direct_completion(self) -> bool: |
48 | | - # return False |
| 14 | + self.base_url = base_url |
49 | 15 |
|
50 | 16 | def codegen( |
51 | 17 | self, prompts: List[str], do_sample: bool = True, num_samples: int = 200 |
|
0 commit comments