|
15 | 15 | import pandas as pd |
16 | 16 | import pybars |
17 | 17 | import requests |
| 18 | +from google import generativeai |
18 | 19 | from tqdm import tqdm |
19 | 20 |
|
20 | 21 | from .. import constants |
@@ -502,6 +503,84 @@ def _get_cost_estimate(self, response: Dict[str, Any]) -> None: |
502 | 503 | ) |
503 | 504 |
|
504 | 505 |
|
| 506 | +class GoogleGenAIModelRunner(LLModelRunner): |
| 507 | + """Wraps Google's Gen AI models.""" |
| 508 | + |
| 509 | + def __init__( |
| 510 | + self, |
| 511 | + logger: Optional[logging.Logger] = None, |
| 512 | + **kwargs, |
| 513 | + ): |
| 514 | + super().__init__(logger, **kwargs) |
| 515 | + if kwargs.get("google_api_key") is None: |
| 516 | + raise openlayer_exceptions.OpenlayerMissingLlmApiKey( |
| 517 | +                "Please pass your Google API key generated at " |
| 518 | + "https://makersuite.google.com/ as the keyword argument" |
| 519 | + " 'google_api_key'" |
| 520 | + ) |
| 521 | + self.google_api_key = kwargs["google_api_key"] |
| 522 | + |
| 523 | + self._initialize_llm() |
| 524 | + |
| 525 | + self.cost: List[float] = [] |
| 526 | + |
| 527 | + def _initialize_llm(self): |
| 528 | +        """Initializes the Google Gen AI model.""" |
| 529 | + if self.model_config.get("model") is None: |
| 530 | + warnings.warn("No model specified. Defaulting to model 'gemini-pro'.") |
| 531 | + if self.model_config.get("model_parameters") is None: |
| 532 | + warnings.warn("No model parameters specified. Using default parameters.") |
| 533 | + # Check if API key is valid |
| 534 | + try: |
| 535 | + generativeai.configure(api_key=self.google_api_key) |
| 536 | + self.model = generativeai.GenerativeModel( |
| 537 | + self.model_config.get("model", "gemini-pro") |
| 538 | + ) |
| 539 | + except Exception as e: |
| 540 | + raise openlayer_exceptions.OpenlayerInvalidLlmApiKey( |
| 541 | +                "Please pass your Google API key generated at " |
| 542 | + "https://makersuite.google.com/ as the keyword argument" |
| 543 | + f" 'google_api_key' \n Error message: {e}" |
| 544 | + ) from e |
| 545 | + |
| 546 | +    def _get_llm_input( |
| 547 | +        self, injected_prompt: List[Dict[str, str]] |
| 548 | +    ) -> str: |
| 549 | +        """Prepares the input for Google's model as a single prompt string.""" |
| 550 | + llm_input = "" |
| 551 | + for message in injected_prompt: |
| 552 | + if message["role"] == "system": |
| 553 | + llm_input += f"S: {message['content']} \n" |
| 554 | + elif message["role"] == "assistant": |
| 555 | + llm_input += f"A: {message['content']} \n" |
| 556 | + elif message["role"] == "user": |
| 557 | + llm_input += f"U: {message['content']} \n" |
| 558 | + else: |
| 559 | + raise ValueError( |
| 560 | + "Message role must be either 'system', 'assistant' or 'user'. " |
| 561 | + f"Got: {message['role']}" |
| 562 | + ) |
| 563 | + llm_input += "A:" |
| 564 | + return llm_input |
| 565 | + |
| 566 | +    def _make_request(self, llm_input: str) -> Any: |
| 567 | +        """Makes the request to Google's model |
| 568 | +        for a given input.""" |
| 569 | + response = self.model.generate_content( |
| 570 | + contents=llm_input, |
| 571 | + **self.model_config.get("model_parameters", {}), |
| 572 | + ) |
| 573 | + return response |
| 574 | + |
| 575 | +    def _get_output(self, response: Any) -> str: |
| 576 | + """Gets the output from the response.""" |
| 577 | + return response.text |
| 578 | + |
| 579 | +    def _get_cost_estimate(self, response: Any) -> float: |
| 580 | +        """Estimates the cost from the response. Always returns 0.""" |
| 581 | +        return 0.0 |
| 582 | + |
| 583 | + |
505 | 584 | class SelfHostedLLModelRunner(LLModelRunner): |
506 | 585 | """Wraps a self-hosted LLM.""" |
507 | 586 |
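
For reference, the new `_get_llm_input` flattens an OpenAI-style message list into a single prompt string, prefixing each message with `S:`, `A:`, or `U:` according to its role and appending a trailing `A:` so the model replies as the assistant. A quick illustration of the transformation:

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

# _get_llm_input(messages) produces:
# "S: You are a helpful assistant. \nU: What is the capital of France? \nA:"
```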
|
|
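A minimal usage sketch, assuming the base `LLModelRunner.__init__` accepts the model configuration as a `model_config` keyword argument (the base class is not shown in this diff), and using a hypothetical import path:

```python
# Hypothetical import path; the real module layout is not shown in this diff.
from openlayer.model_runners import GoogleGenAIModelRunner

runner = GoogleGenAIModelRunner(
    google_api_key="YOUR_GOOGLE_API_KEY",  # required, else OpenlayerMissingLlmApiKey is raised
    model_config={
        "model": "gemini-pro",  # omitting this warns and defaults to 'gemini-pro'
        "model_parameters": {"temperature": 0.2},  # forwarded as kwargs to generate_content
    },
)
```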