Skip to content

Commit 2fa8af8

Browse files
gustavocidornelas authored and whoseoyster committed
Add support for Google's GenAI models (e.g., gemini-pro)
1 parent b85376e commit 2fa8af8

File tree

3 files changed

+83
-2
lines changed

3 files changed

+83
-2
lines changed

openlayer/model_runners/ll_model_runners.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pandas as pd
1616
import pybars
1717
import requests
18+
from google import generativeai
1819
from tqdm import tqdm
1920

2021
from .. import constants
@@ -502,6 +503,84 @@ def _get_cost_estimate(self, response: Dict[str, Any]) -> None:
502503
)
503504

504505

506+
class GoogleGenAIModelRunner(LLModelRunner):
    """Wraps Google's generative AI models (e.g., 'gemini-pro').

    Requires a Google API key (generated at https://makersuite.google.com/),
    passed as the keyword argument ``google_api_key``.
    """

    def __init__(
        self,
        logger: Optional[logging.Logger] = None,
        **kwargs,
    ):
        super().__init__(logger, **kwargs)
        if kwargs.get("google_api_key") is None:
            raise openlayer_exceptions.OpenlayerMissingLlmApiKey(
                "Please pass your Google API key generated with "
                "https://makersuite.google.com/ as the keyword argument"
                " 'google_api_key'"
            )
        self.google_api_key = kwargs["google_api_key"]

        self._initialize_llm()

        # Per-request cost estimates; always 0.0 for Google models for now
        # (see _get_cost_estimate).
        self.cost: List[float] = []

    def _initialize_llm(self):
        """Initializes Google's generative AI model.

        Raises
        ------
        OpenlayerInvalidLlmApiKey
            If configuring the client or creating the model fails (e.g., an
            invalid API key).
        """
        if self.model_config.get("model") is None:
            warnings.warn("No model specified. Defaulting to model 'gemini-pro'.")
        if self.model_config.get("model_parameters") is None:
            warnings.warn("No model parameters specified. Using default parameters.")
        # Check if API key is valid. Note: `or` (not dict.get's default) is
        # used below so that an explicit `"model": None` entry — the case the
        # warning above detects — still falls back to 'gemini-pro'.
        try:
            generativeai.configure(api_key=self.google_api_key)
            self.model = generativeai.GenerativeModel(
                self.model_config.get("model") or "gemini-pro"
            )
        except Exception as e:
            raise openlayer_exceptions.OpenlayerInvalidLlmApiKey(
                "Please pass your Google API key generated with "
                "https://makersuite.google.com/ as the keyword argument"
                f" 'google_api_key' \n Error message: {e}"
            ) from e

    def _get_llm_input(self, injected_prompt: List[Dict[str, str]]) -> str:
        """Prepares the input for Google's model.

        Flattens the chat-style message list into a single text prompt,
        prefixing each message with "S:" (system), "A:" (assistant), or
        "U:" (user), and ending with a trailing "A:" so the model completes
        the assistant turn.

        Raises
        ------
        ValueError
            If a message has a role other than 'system', 'assistant', or
            'user'.
        """
        llm_input = ""
        for message in injected_prompt:
            if message["role"] == "system":
                llm_input += f"S: {message['content']} \n"
            elif message["role"] == "assistant":
                llm_input += f"A: {message['content']} \n"
            elif message["role"] == "user":
                llm_input += f"U: {message['content']} \n"
            else:
                raise ValueError(
                    "Message role must be either 'system', 'assistant' or 'user'. "
                    f"Got: {message['role']}"
                )
        llm_input += "A:"
        return llm_input

    def _make_request(self, llm_input: str) -> Dict[str, Any]:
        """Make the request to Google's model for a given input."""
        # `or {}` guards against an explicit `"model_parameters": None` entry,
        # which dict.get's default would not catch and `**None` would raise on.
        # NOTE(review): parameters are forwarded as keyword arguments to
        # generate_content; confirm they match its signature (e.g., sampling
        # options may need to be wrapped in `generation_config`).
        response = self.model.generate_content(
            contents=llm_input,
            **(self.model_config.get("model_parameters") or {}),
        )
        return response

    def _get_output(self, response: Dict[str, Any]) -> str:
        """Gets the output text from the model response."""
        return response.text

    def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
        """Estimates the cost from the response.

        Cost estimation is not implemented for Google models; always
        returns 0.
        """
        return 0
582+
583+
505584
class SelfHostedLLModelRunner(LLModelRunner):
506585
"""Wraps a self-hosted LLM."""
507586

openlayer/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class ModelRunnerFactory:
111111
"OpenAI": ll_model_runners.OpenAIChatCompletionRunner,
112112
"SelfHosted": ll_model_runners.SelfHostedLLModelRunner,
113113
"HuggingFace": ll_model_runners.HuggingFaceModelRunner,
114+
"Google": ll_model_runners.GoogleGenAIModelRunner,
114115
}
115116
_MODEL_RUNNERS = {
116117
tasks.TaskType.TabularClassification.value: traditional_ml_model_runners.ClassificationModelRunner,
@@ -176,8 +177,8 @@ def _create_ll_model_runner(
176177
raise exceptions.OpenlayerUnsupportedLlmProvider(
177178
provider=model_provider,
178179
message="\nCurrently, the supported providers are: 'OpenAI', 'Cohere',"
179-
" 'Anthropic', 'SelfHosted', 'HuggingFace'. Reach out if you'd like us"
180-
" to support your use case.",
180+
" 'Anthropic', 'SelfHosted', 'HuggingFace', and 'Google'."
181+
" Reach out if you'd like us to support your use case.",
181182
)
182183

183184
model_runner_class = ModelRunnerFactory._LL_MODEL_RUNNERS[task_type.value][

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ packages =
4141
install_requires =
4242
anthropic
4343
cohere
44+
google-generativeai
4445
marshmallow
4546
marshmallow_oneofschema
4647
openai>=1.0.0

0 commit comments

Comments
 (0)