Skip to content

Commit 0d3b871

Browse files
gustavocidornelas authored and whoseoyster committed
Completes OPEN-4812 Create model runner for self hosted LLMs
1 parent 447ff85 commit 0d3b871

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

openlayer/model_runners/ll_model_runners.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import datetime
7+
import json
78
import logging
89
import warnings
910
from abc import ABC, abstractmethod
@@ -14,6 +15,7 @@
1415
import openai
1516
import pandas as pd
1617
import pybars
18+
import requests
1719
from tqdm import tqdm
1820

1921
from . import base_model_runner
@@ -335,7 +337,7 @@ def _get_output(self, response: Dict[str, Any]) -> str:
335337

336338
def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
337339
"""Estimates the cost from the response."""
338-
return -1
340+
return 0
339341

340342

341343
class CohereGenerateModelRunner(LLModelRunner):
@@ -409,7 +411,7 @@ def _get_output(self, response: Dict[str, Any]) -> str:
409411

410412
def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
411413
"""Estimates the cost from the response."""
412-
return -1
414+
return 0
413415

414416

415417
class OpenAIChatCompletionRunner(LLModelRunner):
@@ -492,3 +494,81 @@ def _get_cost_estimate(self, response: Dict[str, Any]) -> None:
492494
num_input_tokens * self.COST_PER_TOKEN[model]["input"]
493495
+ num_output_tokens * self.COST_PER_TOKEN[model]["output"]
494496
)
497+
498+
499+
class SelfHostedLLModelRunner(LLModelRunner):
    """Wraps a self-hosted LLM that is served behind an HTTP endpoint.

    Prompts are flattened into a single string and sent as a JSON ``POST``
    request to ``url``, authenticated with a bearer token built from
    ``api_key``. Both ``url`` and ``api_key`` are required keyword arguments.
    """

    # Timeouts (in seconds) for the HTTP calls. Without a timeout, `requests`
    # waits forever on an unresponsive server. Generation can be slow, so the
    # inference timeout is deliberately generous; override on a subclass or
    # instance if the deployment needs a different value.
    CONNECTION_CHECK_TIMEOUT_SECONDS = 30
    REQUEST_TIMEOUT_SECONDS = 300

    # Single-letter speaker tags used when flattening chat messages.
    _ROLE_PREFIXES = {"system": "S", "assistant": "A", "user": "U"}

    def __init__(
        self,
        logger: Optional[logging.Logger] = None,
        **kwargs,
    ):
        """Stores the endpoint configuration and checks that it is reachable.

        Args:
            logger: Optional logger forwarded to the base runner.
            **kwargs: Forwarded to the base runner. Must contain ``url`` (the
                endpoint of the self-hosted model) and ``api_key`` (sent as a
                bearer token on every inference request).

        Raises:
            ValueError: If ``url`` or ``api_key`` is missing or ``None``, or if
                the URL cannot be reached.
        """
        super().__init__(logger, **kwargs)
        if kwargs.get("url") is None:
            raise ValueError(
                "URL must be provided. Please pass it as the keyword argument 'url'"
            )
        if kwargs.get("api_key") is None:
            raise ValueError(
                "API key must be provided for self-hosted LLMs. "
                "Please pass it as the keyword argument 'api_key'"
            )

        self.url = kwargs["url"]
        self.api_key = kwargs["api_key"]
        self._initialize_llm()

    def _initialize_llm(self):
        """Checks that the configured URL can be reached.

        Only connectivity is verified: the API key is not sent and the HTTP
        status code of the reply is not inspected.

        Raises:
            ValueError: If the URL is malformed, unreachable, or does not
                answer within ``CONNECTION_CHECK_TIMEOUT_SECONDS``.
        """
        try:
            requests.get(self.url, timeout=self.CONNECTION_CHECK_TIMEOUT_SECONDS)
        except requests.exceptions.RequestException as err:
            # Chain the original error so the root cause stays in the traceback.
            raise ValueError(
                "URL is invalid. Please pass a valid URL as the "
                f"keyword argument 'url' \n Error message: {err}"
            ) from err

    def _get_llm_input(self, injected_prompt: List[Dict[str, str]]) -> str:
        """Prepares the input for the self-hosted LLM.

        Each message becomes one ``"<tag>: <content> \\n"`` line, where the tag
        is ``S``, ``A`` or ``U`` for the system, assistant and user roles. A
        trailing ``"A:"`` cues the model to answer as the assistant.

        Args:
            injected_prompt: Chat messages, each with ``role`` and ``content``.

        Returns:
            The flattened prompt string.

        Raises:
            ValueError: If a message has a role other than ``system``,
                ``assistant`` or ``user``.
        """
        lines = []
        for message in injected_prompt:
            prefix = self._ROLE_PREFIXES.get(message["role"])
            if prefix is None:
                raise ValueError(
                    "Message role must be either 'system', 'assistant' or 'user'. "
                    f"Got: {message['role']}"
                )
            lines.append(f"{prefix}: {message['content']} \n")
        lines.append("A:")
        return "".join(lines)

    def _make_request(self, llm_input: str) -> Dict[str, Any]:
        """Makes the request to the self-hosted LL model for a given input.

        Args:
            llm_input: The flattened prompt to send to the model.

        Returns:
            The first element of the JSON list returned by the endpoint.

        Raises:
            ValueError: If the endpoint replies with a status code other than
                200, or if the body is not a non-empty JSON list.
            requests.exceptions.RequestException: If the request itself fails
                (connection error or timeout).
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # TODO: use correct input key
        payload = {"inputs": llm_input}
        response = requests.post(
            self.url,
            headers=headers,
            json=payload,
            timeout=self.REQUEST_TIMEOUT_SECONDS,
        )
        if response.status_code != 200:
            raise ValueError(f"Request failed with status code {response.status_code}")

        try:
            return response.json()[0]
        except (ValueError, LookupError, TypeError) as err:
            # ValueError: body is not JSON; LookupError: empty list or a dict;
            # TypeError: a JSON scalar/null that cannot be indexed.
            raise ValueError(
                "Unexpected response format from the self-hosted LLM. "
                "Expected a non-empty JSON list of results."
            ) from err

    def _get_output(self, response: Dict[str, Any]) -> str:
        """Gets the output from the response.

        Args:
            response: The result returned by ``_make_request``.

        Returns:
            The value stored under the ``generated_text`` key.
        """
        # TODO: use correct output key
        return response["generated_text"]

    def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
        """Estimates the cost from the response.

        No cost is computed for self-hosted models, so this always returns 0.
        """
        return 0

openlayer/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class ModelRunnerFactory:
109109
"Anthropic": ll_model_runners.AnthropicModelRunner,
110110
"Cohere": ll_model_runners.CohereGenerateModelRunner,
111111
"OpenAI": ll_model_runners.OpenAIChatCompletionRunner,
112+
"SelfHosted": ll_model_runners.SelfHostedLLModelRunner,
112113
}
113114
_MODEL_RUNNERS = {
114115
tasks.TaskType.TabularClassification.value: traditional_ml_model_runners.ClassificationModelRunner,

0 commit comments

Comments
 (0)