|
27 | 27 | CompletionResponseStreamChunk, |
28 | 28 | Inference, |
29 | 29 | ListOpenAIChatCompletionResponse, |
30 | | - LogProbConfig, |
31 | 30 | Message, |
32 | 31 | OpenAIAssistantMessageParam, |
33 | 32 | OpenAIChatCompletion, |
|
42 | 41 | OpenAIMessageParam, |
43 | 42 | OpenAIResponseFormatParam, |
44 | 43 | Order, |
45 | | - ResponseFormat, |
46 | | - SamplingParams, |
47 | 44 | StopReason, |
48 | | - ToolChoice, |
49 | | - ToolConfig, |
50 | | - ToolDefinition, |
51 | 45 | ToolPromptFormat, |
52 | 46 | ) |
53 | 47 | from llama_stack.apis.models import Model, ModelType |
@@ -185,88 +179,6 @@ async def _get_model(self, model_id: str, expected_model_type: str) -> Model: |
185 | 179 | raise ModelTypeError(model_id, model.model_type, expected_model_type) |
186 | 180 | return model |
187 | 181 |
|
188 | | - async def chat_completion( |
189 | | - self, |
190 | | - model_id: str, |
191 | | - messages: list[Message], |
192 | | - sampling_params: SamplingParams | None = None, |
193 | | - response_format: ResponseFormat | None = None, |
194 | | - tools: list[ToolDefinition] | None = None, |
195 | | - tool_choice: ToolChoice | None = None, |
196 | | - tool_prompt_format: ToolPromptFormat | None = None, |
197 | | - stream: bool | None = False, |
198 | | - logprobs: LogProbConfig | None = None, |
199 | | - tool_config: ToolConfig | None = None, |
200 | | - ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: |
201 | | - logger.debug( |
202 | | - f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", |
203 | | - ) |
204 | | - if sampling_params is None: |
205 | | - sampling_params = SamplingParams() |
206 | | - model = await self._get_model(model_id, ModelType.llm) |
207 | | - if tool_config: |
208 | | - if tool_choice and tool_choice != tool_config.tool_choice: |
209 | | - raise ValueError("tool_choice and tool_config.tool_choice must match") |
210 | | - if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: |
211 | | - raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") |
212 | | - else: |
213 | | - params = {} |
214 | | - if tool_choice: |
215 | | - params["tool_choice"] = tool_choice |
216 | | - if tool_prompt_format: |
217 | | - params["tool_prompt_format"] = tool_prompt_format |
218 | | - tool_config = ToolConfig(**params) |
219 | | - |
220 | | - tools = tools or [] |
221 | | - if tool_config.tool_choice == ToolChoice.none: |
222 | | - tools = [] |
223 | | - elif tool_config.tool_choice == ToolChoice.auto: |
224 | | - pass |
225 | | - elif tool_config.tool_choice == ToolChoice.required: |
226 | | - pass |
227 | | - else: |
228 | | - # verify tool_choice is one of the tools |
229 | | - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] |
230 | | - if tool_config.tool_choice not in tool_names: |
231 | | - raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") |
232 | | - |
233 | | - params = dict( |
234 | | - model_id=model_id, |
235 | | - messages=messages, |
236 | | - sampling_params=sampling_params, |
237 | | - tools=tools, |
238 | | - tool_choice=tool_choice, |
239 | | - tool_prompt_format=tool_prompt_format, |
240 | | - response_format=response_format, |
241 | | - stream=stream, |
242 | | - logprobs=logprobs, |
243 | | - tool_config=tool_config, |
244 | | - ) |
245 | | - provider = await self.routing_table.get_provider_impl(model_id) |
246 | | - prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) |
247 | | - |
248 | | - if stream: |
249 | | - response_stream = await provider.chat_completion(**params) |
250 | | - return self.stream_tokens_and_compute_metrics( |
251 | | - response=response_stream, |
252 | | - prompt_tokens=prompt_tokens, |
253 | | - model=model, |
254 | | - tool_prompt_format=tool_config.tool_prompt_format, |
255 | | - ) |
256 | | - |
257 | | - response = await provider.chat_completion(**params) |
258 | | - metrics = await self.count_tokens_and_compute_metrics( |
259 | | - response=response, |
260 | | - prompt_tokens=prompt_tokens, |
261 | | - model=model, |
262 | | - tool_prompt_format=tool_config.tool_prompt_format, |
263 | | - ) |
264 | | - # these metrics will show up in the client response. |
265 | | - response.metrics = ( |
266 | | - metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics |
267 | | - ) |
268 | | - return response |
269 | | - |
270 | 182 | async def openai_completion( |
271 | 183 | self, |
272 | 184 | model: str, |
|
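Note: with the legacy `chat_completion` entry point removed from `InferenceRouter`, callers would go through the OpenAI-compatible surface instead. Below is a minimal migration sketch, assuming the router still exposes an `openai_chat_completion` method (suggested by the retained `OpenAIChatCompletion` / `OpenAIMessageParam` imports); the `router` variable and the `OpenAIUserMessageParam` usage are illustrative assumptions, not part of this diff:

```python
# Hypothetical migration sketch -- not part of this diff. Assumes the
# OpenAI-compatible method retained on the router; exact names and
# signatures should be checked against the Inference API.
from llama_stack.apis.inference import OpenAIUserMessageParam


async def ask(router, model_id: str, question: str):
    # Before this change (the method removed above):
    #   response = await router.chat_completion(
    #       model_id=model_id,
    #       messages=[UserMessage(content=question)],
    #   )
    # After: OpenAI-style message params and response shapes.
    response = await router.openai_chat_completion(
        model=model_id,
        messages=[OpenAIUserMessageParam(content=question)],
        stream=False,
    )
    return response.choices[0].message.content
```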