4 | 4 | """
5 | 5 |
6 | 6 | import datetime
| 7 | +import json
7 | 8 | import logging
8 | 9 | import warnings
9 | 10 | from abc import ABC, abstractmethod

14 | 15 | import openai
15 | 16 | import pandas as pd
16 | 17 | import pybars
| 18 | +import requests
17 | 19 | from tqdm import tqdm
18 | 20 |
19 | 21 | from . import base_model_runner
@@ -335,7 +337,7 @@ def _get_output(self, response: Dict[str, Any]) -> str:
335 | 337 |
336 | 338 |     def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
337 | 339 |         """Estimates the cost from the response."""
338 | | -        return -1
| 340 | +        return 0
339 | 341 |
340 | 342 |
341 | 343 | class CohereGenerateModelRunner(LLModelRunner):
@@ -409,7 +411,7 @@ def _get_output(self, response: Dict[str, Any]) -> str:
409 | 411 |
410 | 412 |     def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
411 | 413 |         """Estimates the cost from the response."""
412 | | -        return -1
| 414 | +        return 0
413 | 415 |
414 | 416 |
415 | 417 | class OpenAIChatCompletionRunner(LLModelRunner):
@@ -492,3 +494,81 @@ def _get_cost_estimate(self, response: Dict[str, Any]) -> None:
492 | 494 |             num_input_tokens * self.COST_PER_TOKEN[model]["input"]
493 | 495 |             + num_output_tokens * self.COST_PER_TOKEN[model]["output"]
494 | 496 |         )
| 497 | + |
| 498 | + |
| 499 | +class SelfHostedLLModelRunner(LLModelRunner): |
| 500 | +    """Wraps a self-hosted LLM."""
| 501 | +
| 502 | +    def __init__(
| 503 | +        self,
| 504 | +        logger: Optional[logging.Logger] = None,
| 505 | +        **kwargs,
| 506 | +    ):
| 507 | +        super().__init__(logger, **kwargs)
| 508 | +        if kwargs.get("url") is None:
| 509 | +            raise ValueError(
| 510 | +                "URL must be provided. Please pass it as the keyword argument 'url'"
| 511 | +            )
| 512 | +        if kwargs.get("api_key") is None:
| 513 | +            raise ValueError(
| 514 | +                "API key must be provided for self-hosted LLMs. "
| 515 | +                "Please pass it as the keyword argument 'api_key'"
| 516 | +            )
| 517 | +
| 518 | +        self.url = kwargs["url"]
| 519 | +        self.api_key = kwargs["api_key"]
| 520 | +        self._initialize_llm()
| 521 | +
| 522 | +    def _initialize_llm(self):
| 523 | +        """Initializes the self-hosted LL model."""
| 524 | +        # Check that the URL is reachable
| 525 | +        try:
| 526 | +            requests.get(self.url)
| 527 | +        except Exception as e:
| 528 | +            raise ValueError(
| 529 | +                "URL is invalid. Please pass a valid URL as the "
| 530 | +                f"keyword argument 'url' \n Error message: {e}"
| 531 | +            )
| 532 | +
| 533 | +    def _get_llm_input(self, injected_prompt: List[Dict[str, str]]) -> str:
| 534 | +        """Prepares the input for the self-hosted LLM."""
| 535 | +        llm_input = ""
| 536 | +        for message in injected_prompt:
| 537 | +            if message["role"] == "system":
| 538 | +                llm_input += f"S: {message['content']} \n"
| 539 | +            elif message["role"] == "assistant":
| 540 | +                llm_input += f"A: {message['content']} \n"
| 541 | +            elif message["role"] == "user":
| 542 | +                llm_input += f"U: {message['content']} \n"
| 543 | +            else:
| 544 | +                raise ValueError(
| 545 | +                    "Message role must be either 'system', 'assistant' or 'user'. "
| 546 | +                    f"Got: {message['role']}"
| 547 | +                )
| 548 | +        llm_input += "A:"
| 549 | +        return llm_input
| 550 | +
| 551 | +    def _make_request(self, llm_input: str) -> Dict[str, Any]:
| 552 | +        """Makes the request to the self-hosted LL model
| 553 | +        for a given input."""
| 554 | +        headers = {
| 555 | +            "Authorization": f"Bearer {self.api_key}",
| 556 | +            "Content-Type": "application/json",
| 557 | +        }
| 558 | +        # TODO: use correct input key
| 559 | +        data = {"inputs": llm_input}
| 560 | +        response = requests.post(self.url, headers=headers, json=data)
| 561 | +        if response.status_code == 200:
| 562 | +            response_data = response.json()[0]
| 563 | +            return response_data
| 564 | +        else:
| 565 | +            raise ValueError(f"Request failed with status code {response.status_code}")
| 566 | +
| 567 | +    def _get_output(self, response: Dict[str, Any]) -> str:
| 568 | +        """Gets the output from the response."""
| 569 | +        # TODO: use correct output key
| 570 | +        return response["generated_text"]
| 571 | +
| 572 | +    def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
| 573 | +        """Estimates the cost from the response."""
| 574 | +        return 0
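
For reference, below is a minimal sketch of the request/response round trip the new SelfHostedLLModelRunner performs, written against plain requests. The endpoint URL and API key are placeholders, and the "inputs" / "generated_text" keys mirror the provisional choices flagged by the TODOs in this diff rather than a confirmed contract for any particular inference server; the timeout is added here for safety and is not part of the runner itself.

import requests

# Placeholder deployment details -- substitute your own endpoint and key.
URL = "https://my-llm.example.com/generate"
API_KEY = "my-secret-key"

# Chat-style messages are flattened into a single prompt string, mirroring
# SelfHostedLLModelRunner._get_llm_input ("S:" / "A:" / "U:" prefixes, then a
# trailing "A:" to cue the model's answer).
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize Hamlet in one sentence."},
]
prefixes = {"system": "S", "assistant": "A", "user": "U"}
llm_input = "".join(f"{prefixes[m['role']]}: {m['content']} \n" for m in messages) + "A:"

# The payload key ("inputs") and output key ("generated_text") follow the
# TODO-marked defaults in the diff; adjust them to what your server expects.
response = requests.post(
    URL,
    headers={"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"},
    json={"inputs": llm_input},
    timeout=30,
)
response.raise_for_status()
print(response.json()[0]["generated_text"])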