diff --git a/pyproject.toml b/pyproject.toml index 4ea6d52..80c2df6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,38 +13,87 @@ readme = "README.md" [tool.poetry.dependencies] python = ">=3.11,<4.0" -tqdm = "^4.66.4" -python-dotenv = "^1.0.1" -pandas = "^2.2.3" -black = { version = "^24.3.0", optional = true } + +# Sorted primary dependencies +accelerate = { version = "^1.3.0", optional = true } # working with 1.6.0 +aiohttp = { version = "^3.11.18", optional = false } +anthropic = { version = "^0.45.2", optional = true } +black = { version = "^24.3.0", optional = true } # working with 24.10.0 +cli-test-helpers = { version = "^4.0.0", optional = true } # working with 4.1.0 +google-cloud-aiplatform = { version = "^1.71.1", optional = true } +google-genai = { version = "^1.11.0", optional = true } # working with 1.13.0 +google-generativeai = { version = "^0.8.4", optional = false } +huggingface-hub = { version = "^0.28.0", optional = true } # working with 0.28.1 +ipykernel = { version = "^6.29.4", optional = true } # working with 6.29.5 isort = { version = "^5.13.2", optional = true } -pre-commit = { version = "^3.7.0", optional = true } -pytest = { version = "^8.1.1", optional = true } -pytest-asyncio = { version = "^0.23.6", optional = true } -pytest-cov = { version = "^5.0.0", optional = true } -ipykernel = { version = "^6.29.4", optional = true } -mkdocs-material = { version = "^9.5.26", optional = true } -mkdocstrings-python = { version = "^1.10.3", optional = true } mkdocs-gen-files = { version = "^0.5.0", optional = true } -mkdocs-literate-nav = { version = "^0.6.1", optional = true } -mkdocs-section-index = { version = "^0.3.9", optional = true } +mkdocs-jupyter = { version = "^0.24.7", optional = true } # working with 0.24.8 +mkdocs-literate-nav = { version = "^0.6.1", optional = true } # working with 0.6.2 +mkdocs-material = { version = "^9.6.12", optional = false } mkdocs-same-dir = { version = "^0.1.3", optional = true } -mkdocs-jupyter = { version = "^0.24.7", optional = true } -cli-test-helpers = { version = "^4.0.0", optional = true } -vertexai = { version ="^1.71.1", optional = true } -google-cloud-aiplatform = { version = "^1.71.1", optional = true } -google-generativeai = { version = "^0.8.4", optional = true } -google-genai = { version = "^0.7.0", optional = true } -openai = { version = "^1.60.0", optional = true } -pillow = { version = "^11.1.0", optional = true } -ollama = { version = "^0.4.7", optional = true } -huggingface-hub = { version = "^0.28.0", optional = true } +mkdocs-section-index = { version = "^0.3.10", optional = false } +mkdocstrings-python = { version = "^1.10.3", optional = true } # working with 1.16.10 +ollama = { version = "^0.4.8", optional = false } +openai = { version = "^1.77.0", optional = false } +pandas = "^2.2.3" +pillow = { version = "^11.2.1", optional = false } +pre-commit = { version = "^3.7.0", optional = true } # working with 3.8.0 +pytest = { version = "^8.1.1", optional = true } # working with 8.3.5 +pytest-asyncio = { version = "^0.23.6", optional = true } # working with 0.23.8 +pytest-cov = { version = "^5.0.0", optional = true } +python-dotenv = "^1.0.1" # working with 1.1.0 quart = { version = "^0.20.0", optional = true } -transformers = { version = "^4.48.1", optional = true } -torch = { version = "^2.6.0", optional = true } -accelerate = { version = "^1.3.0", optional = true } -aiohttp = { version = "^3.11.11", optional = true } -anthropic = { version = "^0.45.2", optional = true } +torch = { version = "^2.7.0", optional 
= false } +tqdm = "^4.67.1" # working with 4.67.1 +transformers = { version = "^4.50.3", optional = true } # working with 4.50.3 +urllib3 = { version = "^2.4.0", optional = false } +vertexai = { version ="^1.71.1", optional = true } + +# Sorted secondary dependencies +# These are not directly used in the code, but are required by the primary dependencies +# They are specified here, because the automatic dependency resolution of pip +# can be slow and sometimes times-out without resolving them all. +beautifulsoup4 = { version = "^4.13.4", optional = false } # Secondary +certifi = { version = "^2025.4.26", optional = false } # Secondary +charset-normalizer = { version = "^3.4.2", optional = false } # Secondary +debugpy = { version = "^1.8.14", optional = false } # Secondary +frozenlist = { version = "^1.6.0", optional = false } # Secondary +google-api-core = { version = "^2.25.0rc0", optional = false } # Secondary +google-api-python-client = { version = "^2.166.0", optional = false } # Secondary +google-auth = { version = "^2.40.0", optional = false } # Secondary +google-auth-httplib2 = { version = "^0.2.0", optional = false } # Secondary +googleapis-common-protos = { version = "^1.70.0", optional = false } # Secondary +griffe = { version = "^1.7.3", optional = false } # Secondary +h11 = { version = "^0.16.0", optional = false } # Secondary +httpcore = { version = "^1.0.9", optional = false } # Secondary +httplib2 = { version = "^0.22.0", optional = false } # Secondary +identify = { version = "^2.6.10", optional = false } # Secondary +ipython = { version = "^9.2.0", optional = false } # Secondary +jsonschema-specifications = { version = "^2025.4.1", optional = false } # Secondary +jupytext = { version = "^1.17.1", optional = false } # Secondary +Markdown = { version = "^3.8", optional = false } # Secondary +multidict = { version = "^6.4.3", optional = false } # Secondary +mypy_extensions = { version = "^1.1.0", optional = false } # Secondary +numpy = { version = "^2.2.5", optional = false } # Secondary +packaging = { version = "^25.0", optional = false } # Secondary +pydantic = { version = "^2.11.4", optional = false } # Secondary +pydantic_core = { version = "^2.33.2", optional = false } # Secondary +pymdown-extensions = { version = "^10.15", optional = false } # Secondary +pyparsing = { version = "^3.2.3", optional = false } # Secondary +pytest-random-order = { version = "^1.1.1", optional = false } # Secondary +pyzmq = { version = "^26.4.0", optional = false } # Secondary +rsa = { version = "^4.9.1", optional = false } # Secondary +ruff = { version = "^0.11.6", optional = false } # Secondary +setuptools = { version = "^80.3.1", optional = false } # Secondary +shapely = { version = "^2.1.0", optional = false } # Secondary +snakeviz = { version = "^2.2.2", optional = false } # Secondary +soupsieve = { version = "^2.7", optional = false } # Secondary +sympy = { version = "^1.14.0", optional = false } # Secondary +typing_extensions = { version = "^4.13.2", optional = false } # Secondary +virtualenv = { version = "^20.31.1", optional = false } # Secondary +websockets = { version = "^15.0.1", optional = false } # Secondary +yarl = { version = "^1.20.0", optional = false } # Secondary + [tool.poetry.extras] all = [ @@ -66,7 +115,7 @@ all = [ "cli-test-helpers", "vertexai", "google-cloud-aiplatform", - "google-generativeai", + # "google-generativeai", "google-genai", "openai", "pillow", @@ -97,8 +146,20 @@ dev = [ "mkdocs-jupyter", "cli-test-helpers", ] -gemini = ["vertexai", 
"google-cloud-aiplatform", "google-generativeai", "google-genai", "pillow"] -vertexai = ["vertexai", "google-cloud-aiplatform", "google-generativeai", "google-genai", "pillow"] +gemini = [ + "vertexai", + "google-cloud-aiplatform", + # "google-generativeai", # TODO: deprecated - to be removed + "google-genai", + "pillow" +] +vertexai = [ + "vertexai", + "google-cloud-aiplatform", + # "google-generativeai", # TODO: deprecated - to be removed + "google-genai", + "pillow" +] azure_openai = ["openai", "pillow"] openai = ["openai", "pillow"] ollama = ["ollama"] diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index b345f5f..ba3918f 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -1,16 +1,23 @@ import logging from typing import Any -import google.generativeai as genai -from google.generativeai import GenerativeModel -from google.generativeai.types import GenerationConfig, HarmBlockThreshold, HarmCategory +from google.genai import Client +from google.genai.types import ( + GenerateContentConfig, + HarmBlockThreshold, + HarmCategory, + SafetySetting, + ThinkingConfig, +) from prompto.apis.base import AsyncAPI from prompto.apis.gemini.gemini_utils import ( - convert_dict_to_input, + convert_history_dict_to_content, gemini_chat_roles, + parse_parts, process_response, process_safety_attributes, + process_thoughts, ) from prompto.settings import Settings from prompto.utils import ( @@ -41,6 +48,11 @@ "finish_reason": "block_reason: OTHER", } +# See https://ai.google.dev/gemini-api/docs/thinking#set-budget +# for more details +MIN_THINKING_BUDGET = 0 +MAX_THINKING_BUDGET = 24576 + class GeminiAPI(AsyncAPI): """ @@ -62,6 +74,7 @@ def __init__( **kwargs: Any, ): super().__init__(settings=settings, log_file=log_file, *args, **kwargs) + self._clients: dict[str, Client] = {} @staticmethod def check_environment_variables() -> list[Exception]: @@ -171,15 +184,43 @@ def check_prompt_dict(prompt_dict: dict) -> list[Exception]: # if generation_config is provided, check that it can create a valid GenerationConfig object if "parameters" in prompt_dict: try: - GenerationConfig(**prompt_dict["parameters"]) + GenerateContentConfig(**prompt_dict["parameters"]) except Exception as err: issues.append(Exception(f"Invalid generation_config parameter: {err}")) return issues + def _get_client(self, model_name) -> Client: + """ + Method to get the client for the Gemini API. A separate client is created for each model name, to allow for + model-specific API keys to be used. + + The client is created only once per model name and stored in the clients dictionary. + If the client is already created, it is returned from the dictionary. + + Parameters + ---------- + model_name : str + The name of the model to use + + Returns + ------- + Client + A client for the Gemini API + """ + # If the Client does not exist, create it + if model_name not in self._clients: + api_key = get_environment_variable( + env_variable=API_KEY_VAR_NAME, model_name=model_name + ) + self._clients[model_name] = Client(api_key=api_key) + + # Return the client for the model name + return self._clients[model_name] + async def _obtain_model_inputs( self, prompt_dict: dict, system_instruction: str | None = None - ) -> tuple[str, str, GenerativeModel, dict, dict, list | None]: + ) -> tuple[str, str, Client, GenerateContentConfig, list | None]: """ Async method to obtain the model inputs from the prompt dictionary. 
@@ -193,26 +234,19 @@ async def _obtain_model_inputs( Returns ------- - tuple[str, str, dict, dict, list | None] - A tuple containing the prompt, model name, GenerativeModel instance, - safety settings, the generation config, and list of multimedia parts - (if passed) to use for querying the model + tuple[str, str, Client, GenerateContentConfig, list | None] + A tuple containing: + - the prompt, + - model name, + - Client instance, + - GenerateContentConfig instance (which incorporates the safety settings), + - (optional) list of multimedia parts (if passed) to use for querying the model or None """ prompt = prompt_dict["prompt"] # obtain model name model_name = prompt_dict["model_name"] - api_key = get_environment_variable( - env_variable=API_KEY_VAR_NAME, model_name=model_name - ) - - # configure the API key - genai.configure(api_key=api_key) - - # create the model instance - model = GenerativeModel( - model_name=model_name, system_instruction=system_instruction - ) + client = self._get_client(model_name) # define safety settings safety_filter = prompt_dict.get("safety_filter", None) @@ -221,33 +255,81 @@ async def _obtain_model_inputs( # explicitly set the safety settings if safety_filter == "none": - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + ] elif safety_filter == "few": - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + ] elif safety_filter in ["default", "some"]: - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + 
category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + ] elif safety_filter == "most": - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + ] else: raise ValueError( f"safety_filter '{safety_filter}' not recognised. Must be one of: " @@ -255,15 +337,51 @@ async def _obtain_model_inputs( ) # get parameters dict (if any) - generation_config = prompt_dict.get("parameters", None) - if generation_config is None: - generation_config = {} - if type(generation_config) is not dict: + generation_config_params = prompt_dict.get("parameters", None) + if generation_config_params is None: + generation_config_params = {} + if type(generation_config_params) is not dict: raise TypeError( - f"parameters must be a dictionary, not {type(generation_config)}" + f"parameters must be a dictionary, not {type(generation_config_params)}" + ) + + # Derive the required ThinkingConfig from the parameters + # `pop` removes the key from the dictionary + include_thoughts = generation_config_params.pop("include_thoughts", None) + thinking_budget = generation_config_params.pop("thinking_budget", None) + + if include_thoughts is None and thinking_budget is None: + thinking_config = None + else: + if not isinstance(include_thoughts, bool | None): + err_msg = "If include_thoughts is set, it must be a boolean" + raise ValueError(err_msg) + + try: + assert isinstance(thinking_budget, int | None) + if thinking_budget is not None: + assert isinstance(thinking_budget, int) + # The ThinkingConfig constructor does not seem to check that the + # thinking_budget is in the valid range (0, 24576), so we do it here + assert MIN_THINKING_BUDGET <= thinking_budget <= MAX_THINKING_BUDGET + + except AssertionError as ae: + err_msg = f"if thinking_budget is set, it must be an integer between {MIN_THINKING_BUDGET} and {MAX_THINKING_BUDGET}" + raise ValueError(err_msg) from ae + + thinking_config = ThinkingConfig( + include_thoughts=include_thoughts, + thinking_budget=thinking_budget, ) - return prompt, model_name, model, safety_settings, generation_config + gen_content_config = GenerateContentConfig( + **generation_config_params, + safety_settings=safety_settings, + system_instruction=system_instruction, + thinking_config=thinking_config, + ) + + return prompt, model_name, client, gen_content_config, None async def _query_string(self, prompt_dict: dict, index: int | str): """ @@ -271,20 +389,20 @@ async 
def _query_string(self, prompt_dict: dict, index: int | str): (prompt_dict["prompt"] is a string), i.e. single-turn completion or chat. """ - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) try: - response = await model.generate_content_async( + response = await client.aio.models.generate_content( + model=model_name, contents=prompt, - generation_config=generation_config, - safety_settings=safety_settings, - stream=False, + config=generation_config, ) response_text = process_response(response) + thinking_text = process_thoughts(response) safety_attributes = process_safety_attributes(response) log_success_response_query( @@ -297,6 +415,7 @@ async def _query_string(self, prompt_dict: dict, index: int | str): prompt_dict["response"] = response_text prompt_dict["safety_attributes"] = safety_attributes + prompt_dict["thinking_text"] = thinking_text return prompt_dict except IndexError as err: error_as_string = ( @@ -317,16 +436,18 @@ async def _query_string(self, prompt_dict: dict, index: int | str): log_file=self.log_file, log_message=log_message, log=True ) response_text = "" + thinking_text = [] try: if len(response.candidates) == 0: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES else: safety_attributes = process_safety_attributes(response) - except: + except Exception: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES prompt_dict["response"] = response_text prompt_dict["safety_attributes"] = safety_attributes + prompt_dict["thinking_text"] = thinking_text return prompt_dict except Exception as err: @@ -352,30 +473,36 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): (prompt_dict["prompt"] is a list of strings to sequentially send to the model), i.e. multi-turn chat with history. 
""" - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) - chat = model.start_chat(history=[]) + # chat = client.start_chat(history=[]) + chat = client.aio.chats.create( + model=model_name, + config=generation_config, + history=[], + ) response_list = [] safety_attributes_list = [] + thinking_list = [] try: for message_index, message in enumerate(prompt): # send the messages sequentially # run the predict method in a separate thread using run_in_executor - response = await chat.send_message_async( - content=message, - generation_config=generation_config, - safety_settings=safety_settings, - stream=False, + response = await chat.send_message( + message=message, + config=generation_config, ) response_text = process_response(response) + thinking_text = process_thoughts(response) safety_attributes = process_safety_attributes(response) response_list.append(response_text) safety_attributes_list.append(safety_attributes) + thinking_list.append(thinking_text) log_success_response_chat( index=index, @@ -392,6 +519,7 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): ) prompt_dict["response"] = response_list + prompt_dict["thinking_text"] = thinking_list prompt_dict["safety_attributes"] = safety_attributes_list return prompt_dict except IndexError as err: @@ -416,15 +544,17 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): log_file=self.log_file, log_message=log_message, log=True ) response_text = response_list + [""] + thinking_text = thinking_list + [[]] try: if len(response.candidates) == 0: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES else: safety_attributes = process_safety_attributes(response) - except: + except Exception: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES prompt_dict["response"] = response_text + prompt_dict["thinking_text"] = thinking_text prompt_dict["safety_attributes"] = safety_attributes return prompt_dict except Exception as err: @@ -455,46 +585,53 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: i.e. multi-turn chat with history. 
""" if prompt_dict["prompt"][0]["role"] == "system": - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=prompt_dict["prompt"][0]["parts"], ) ) - chat = model.start_chat( - history=[ - convert_dict_to_input( - content_dict=x, media_folder=self.settings.media_folder - ) - for x in prompt[1:-1] - ] - ) + # Used to skip the system message in the prompt history + first_user_idx = 1 else: - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) - chat = model.start_chat( - history=[ - convert_dict_to_input( - content_dict=x, media_folder=self.settings.media_folder - ) - for x in prompt[:-1] - ] - ) + first_user_idx = 0 + + chat = client.aio.chats.create( + model=model_name, + config=generation_config, + history=[ + convert_history_dict_to_content( + content_dict=x, + media_folder=self.settings.media_folder, + client=client, + ) + for x in prompt[first_user_idx:-1] + ], + ) try: - response = await chat.send_message_async( - content=convert_dict_to_input( - content_dict=prompt[-1], media_folder=self.settings.media_folder - ), - generation_config=generation_config, - safety_settings=safety_settings, - stream=False, + # No need to send the generation_config again, as it is no different + # from the one used to create the chat + msg_to_send = parse_parts( + prompt[-1]["parts"], + media_folder=self.settings.media_folder, + client=client, ) + assert ( + len(msg_to_send) == 1 + ), "Only one message is allowed in the last message" + msg_to_send = msg_to_send[0] + + response = await chat.send_message(message=msg_to_send) + response_text = process_response(response) + thinking_text = process_thoughts(response) safety_attributes = process_safety_attributes(response) log_success_response_query( @@ -506,6 +643,7 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: ) prompt_dict["response"] = response_text + prompt_dict["thinking_text"] = thinking_text prompt_dict["safety_attributes"] = safety_attributes return prompt_dict except IndexError as err: @@ -527,15 +665,17 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: log_file=self.log_file, log_message=log_message, log=True ) response_text = "" + thinking_text = [] try: if len(response.candidates) == 0: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES else: safety_attributes = process_safety_attributes(response) - except: + except Exception: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES prompt_dict["response"] = response_text + prompt_dict["thinking_text"] = thinking_text prompt_dict["safety_attributes"] = safety_attributes return prompt_dict except Exception as err: diff --git a/src/prompto/apis/gemini/gemini_media.py b/src/prompto/apis/gemini/gemini_media.py index 3ebecb6..8b0f9dd 100644 --- a/src/prompto/apis/gemini/gemini_media.py +++ b/src/prompto/apis/gemini/gemini_media.py @@ -1,14 +1,12 @@ import asyncio import base64 -import json import logging -import os -import time -import google.generativeai as genai import tqdm -from dotenv import load_dotenv +from google import genai +from prompto.apis.gemini.gemini import GeminiAPI +from prompto.settings import Settings from prompto.utils import compute_sha256_base64 # initialise logging @@ -25,23 +23,26 @@ def remote_file_hash_base64(remote_file): Convert a 
remote file's SHA256 hash (stored as a hex-encoded UTF-8 bytes object) to a base64-encoded string. """ - hex_str = remote_file.sha256_hash.decode("utf-8") + hex_str = remote_file.sha256_hash raw_bytes = bytes.fromhex(hex_str) return base64.b64encode(raw_bytes).decode("utf-8") -async def wait_for_processing(file_obj, poll_interval=1): +async def wait_for_processing(file_obj, client: genai.Client, poll_interval=1): """ Poll until the file is no longer in the 'PROCESSING' state. Returns the updated file object. """ while file_obj.state.name == "PROCESSING": await asyncio.sleep(poll_interval) - file_obj = genai.get_file(file_obj.name) + # We need to re-fetch the file object to get the updated state. + file_obj = client.files.get(name=file_obj.name) return file_obj -async def upload_single_file(local_file_path, already_uploaded_files): +async def _upload_single_file( + local_file_path, already_uploaded_files, client: genai.Client +): """ Upload the file at 'file_path' if it hasn't been uploaded yet. If a file with the same SHA256 (base64-encoded) hash exists, returns its name. @@ -71,48 +72,43 @@ async def upload_single_file(local_file_path, already_uploaded_files): return already_uploaded_files[local_hash], local_file_path # Upload the file if it hasn't been found. - # Use asyncio.to_thread to run the blocking upload_file function in a separate thread. logger.info(f"Uploading {local_file_path} to Gemini API") - file_obj = await asyncio.to_thread(genai.upload_file, local_file_path) - file_obj = await wait_for_processing(file_obj) + + file_obj = await client.aio.files.upload(file=local_file_path) + file_obj = await wait_for_processing(file_obj, client=client) if file_obj.state.name == "FAILED": err_msg = ( f"Failure uploaded file '{file_obj.name}'. Error: {file_obj.error_message}" ) raise ValueError(err_msg) - # logger.info( - # f"Uploaded file '{file_obj.name}' with hash '{local_hash}' to Gemini API" - # ) + + logger.info( + f"Uploaded file '{file_obj.name}' with hash '{local_hash}' to Gemini API" + ) already_uploaded_files[local_hash] = file_obj.name return file_obj.name, local_file_path -def _init_genai(): - load_dotenv(dotenv_path=".env") - # TODO: check if this can be refactored to a common function - GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") - if GEMINI_API_KEY is None: - raise ValueError("GEMINI_API_KEY is not set") - - genai.configure(api_key=GEMINI_API_KEY) - - -def _get_previously_uploaded_files(): +async def _get_previously_uploaded_files(client: genai.Client): + raw_files = await client.aio.files.list() uploaded_files = { - remote_file_hash_base64(remote_file): remote_file.name - for remote_file in genai.list_files() + remote_file.sha256_hash: remote_file.name for remote_file in raw_files } logger.info(f"Found {len(uploaded_files)} files already uploaded at Gemini API") return uploaded_files -def list_uploaded_files(): +def list_uploaded_files(settings: Settings): """ List all previously uploaded files to the Gemini API. """ - _init_genai() - uploaded_files = _get_previously_uploaded_files() + gemini_api = GeminiAPI(settings=settings, log_file=None) + # TODO: We need a model name, because our API caters for different API keys + # for different models. Maybe our API is too complicated.... 
+ default_model_name = "default" + client = gemini_api._get_client(default_model_name) + uploaded_files = asyncio.run(_get_previously_uploaded_files(client)) for file_hash, file_name in uploaded_files.items(): msg = f"File Name: {file_name}, File Hash: {file_hash}" @@ -120,26 +116,29 @@ def list_uploaded_files(): logger.info("All uploaded files listed.") -def delete_uploaded_files(): +def delete_uploaded_files(settings: Settings): """ Delete all previously uploaded files from the Gemini API. """ - _init_genai() - uploaded_files = _get_previously_uploaded_files() - return asyncio.run(_delete_uploaded_files_async(uploaded_files)) + gemini_api = GeminiAPI(settings=settings, log_file=None) + # TODO: We need a model name, because our API caters for different API keys + # for different models. Maybe our API is too complicated.... + default_model_name = "default" + client = gemini_api._get_client(default_model_name) + + # This just uses the synchronous API. Using the async API did not + # seem reliable: in particular, `client.aio.files.delete()` did not always + # appear to actually delete the files (even after repeatedly polling the file). + # This is not an important function in prompto and the delete action is + # reasonably quick, so we can live with this simple solution. + for remote_file in client.files.list(): + client.files.delete(name=remote_file.name) - -async def _delete_uploaded_files_async(uploaded_files): - tasks = [] - for file_name in uploaded_files.values(): - logger.info(f"Preparing to delete file: {file_name}") - tasks.append(asyncio.to_thread(genai.delete_file, file_name)) - - await tqdm.asyncio.tqdm.gather(*tasks) logger.info("All uploaded files deleted.") -def upload_media_files(files_to_upload: set[str]): +def upload_media_files(files_to_upload: set[str], settings: Settings): """ Upload media files to the Gemini API. @@ -153,22 +152,24 @@ def upload_media_files(files_to_upload: set[str]): dict[str, str] Dictionary mapping local file paths to their corresponding uploaded filenames. 
""" - _init_genai() - return asyncio.run(upload_media_files_async(files_to_upload)) + # _init_genai() + return asyncio.run(upload_media_files_async(files_to_upload, settings)) + + +async def upload_media_files_async(files_to_upload: set[str], settings: Settings): + logger.info("Start retrieving previously uploaded files") + gemini_api = GeminiAPI(settings=settings, log_file=None) + client = gemini_api._get_client("default") + uploaded_files = await _get_previously_uploaded_files(client) -async def upload_media_files_async(files_to_upload: set[str]): - start_time = time.time() - logger.info(f"Start retrieving previously uploaded files ") - uploaded_files = _get_previously_uploaded_files() - next_time = time.time() - logger.info(f"Retrieved list of previously uploaded files") + logger.info("Retrieved list of previously uploaded files") # Upload files asynchronously tasks = [] for file_path in files_to_upload: logger.info(f"checking if {file_path} needs to be uploaded") - tasks.append(upload_single_file(file_path, uploaded_files)) + tasks.append(_upload_single_file(file_path, uploaded_files, client)) remote_local_pairs = await tqdm.asyncio.tqdm.gather(*tasks) diff --git a/src/prompto/apis/gemini/gemini_utils.py b/src/prompto/apis/gemini/gemini_utils.py index c9a3c17..c3cde41 100644 --- a/src/prompto/apis/gemini/gemini_utils.py +++ b/src/prompto/apis/gemini/gemini_utils.py @@ -1,12 +1,13 @@ import os -import google.generativeai as genai import PIL.Image +from google.genai import types +from google.genai.client import Client gemini_chat_roles = set(["user", "model"]) -def parse_parts_value(part: dict | str, media_folder: str) -> any: +def parse_parts_value(part: dict | str, media_folder: str, client: Client) -> any: """ Parse part dictionary and create a dictionary input for Gemini API. If part is a string, a dictionary to represent a text object is returned. @@ -27,7 +28,7 @@ def parse_parts_value(part: dict | str, media_folder: str) -> any: Multimedia data object """ if isinstance(part, str): - return part + return types.Part.from_text(text=part) # read multimedia type media_type = part.get("type") @@ -54,14 +55,16 @@ def parse_parts_value(part: dict | str, media_folder: str) -> any: ) else: try: - return genai.get_file(name=uploaded_filename) + return client.aio.files.get(name=uploaded_filename) except Exception as err: raise ValueError( f"Failed to get file: {media} due to error: {type(err).__name__} - {err}" ) -def parse_parts(parts: list[dict | str] | dict | str, media_folder: str) -> list[any]: +def parse_parts( + parts: list[dict | str] | dict | str, media_folder: str, client: Client +) -> list[any]: """ Parse parts data and create a list of multimedia data objects. If parts is a single dictionary, a list with a single multimedia data object is returned. @@ -83,10 +86,14 @@ def parse_parts(parts: list[dict | str] | dict | str, media_folder: str) -> list if isinstance(parts, dict) or isinstance(parts, str): parts = [parts] - return [parse_parts_value(p, media_folder=media_folder) for p in parts] + return [ + parse_parts_value(p, media_folder=media_folder, client=client) for p in parts + ] -def convert_dict_to_input(content_dict: dict, media_folder: str) -> dict: +def convert_history_dict_to_content( + content_dict: dict, media_folder: str, client: Client +) -> types.Content: """ Convert dictionary to an input that can be used by the Gemini API. The output is a dictionary with keys "role" and "parts". 
@@ -112,13 +119,14 @@ def convert_dict_to_input(content_dict: dict, media_folder: str) -> dict: if "parts" not in content_dict: raise KeyError("parts key is missing in content dictionary") - return { - "role": content_dict["role"], - "parts": parse_parts( + return types.Content( + role=content_dict["role"], + parts=parse_parts( content_dict["parts"], media_folder=media_folder, + client=client, ), - } + ) def process_response(response: dict) -> str: @@ -135,8 +143,47 @@ def process_response(response: dict) -> str: str The processed response text as a string """ - response_text = response.candidates[0].content.parts[0].text - return response_text + answer, _ = _process_answer_and_thoughts(response) + return answer + + +def process_thoughts(response: dict) -> list[str]: + """ + Helper function to process the thoughts from Gemini API. + + Parameters + ---------- + response : dict + The response from the Gemini API as a dictionary + + Returns + ------- + list[str] + The thoughts as a list of strings + """ + _, thoughts = _process_answer_and_thoughts(response) + return thoughts + + +def _process_answer_and_thoughts(response: dict) -> tuple[str, list[str]]: + # This is a helper function that implements process_response and process_thoughts. + # + # These two functions are only trivially different. This implementation is potentially + # inefficient as it encourages the caller to call both functions. + # However the pattern of `process_response` is widely used throughout the codebase. So + # it is better to keep the same pattern. + thoughts = [] + answers = [] + for candidate in response.candidates: + for part in candidate.content.parts: + if part.thought: + thoughts.append(part.text) + else: + answers.append(part.text) + + assert len(answers) == 1, "There should be only one answer" + + return answers[0], thoughts def process_safety_attributes(response: dict) -> dict: diff --git a/src/prompto/upload_media.py b/src/prompto/upload_media.py index 4621c1c..e4e137c 100644 --- a/src/prompto/upload_media.py +++ b/src/prompto/upload_media.py @@ -1,12 +1,11 @@ import argparse -import asyncio import json import logging import os import prompto.apis.gemini.gemini_media as gemini_media -from prompto.apis import ASYNC_APIS from prompto.scripts.run_experiment import load_env_file +from prompto.settings import Settings # initialise logging logger = logging.getLogger(__name__) @@ -224,12 +223,14 @@ def upload_media_parse_args(): def do_delete_existing_files(args): - gemini_media.delete_uploaded_files() + settings = _create_settings() + gemini_media.delete_uploaded_files(settings) return def do_list_uploaded_files(args): - gemini_media.list_uploaded_files() + settings = _create_settings() + gemini_media.list_uploaded_files(settings) return @@ -249,10 +250,29 @@ def _do_upload_media_from_args(args): args : argparse.Namespace """ _resolve_output_file_location(args) - asyncio.run(do_upload_media(args.file, args.media_folder, args.output_file)) + do_upload_media(args.file, args.media_folder, args.output_file) -async def do_upload_media(input_file, media_folder, output_file): +def _create_settings(): + """ + Create a dummy settings object for the Gemini API. + This is used to create a client object for the API. + For now, we just create a temporary directory for the data folder. + """ + # TODO: A better solution would be to create an option in the + # Settings constructor to not create the directories. + # But for now we'll just pass it a temporary directory. 
+ import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + data_folder = os.path.join(temp_dir, "data") + os.makedirs(data_folder, exist_ok=True) + dummy_settings = Settings(data_folder=data_folder) + + return dummy_settings + + +def do_upload_media(input_file, media_folder, output_file): """ Upload media files to the relevant API. The media files are uploaded and the experiment file is updated with the uploaded filenames. @@ -276,7 +296,8 @@ async def do_upload_media(input_file, media_folder, output_file): # If in future we support other bulk upload to other APIs, we will need to # refactor here - uploaded_files = await gemini_media.upload_media_files_async(files_to_upload) + settings = _create_settings() + uploaded_files = gemini_media.upload_media_files(files_to_upload, settings) update_experiment_file( prompt_dict_list, diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index 60846b3..10c33d6 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -2,8 +2,19 @@ import pytest import regex as re -from google.generativeai import GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory +from google.genai.client import AsyncClient, Client +from google.genai.types import ( + Candidate, + Content, + FinishReason, + GenerateContentConfig, + GenerateContentResponse, + HarmBlockThreshold, + HarmCategory, + Part, + SafetySetting, + ThinkingConfig, +) from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings @@ -62,12 +73,108 @@ def prompt_dict_history_no_system(): } -DEFAULT_SAFETY_SETTINGS = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, -} +@pytest.fixture +def non_thinking_response(): + + # Some example responses, with and without thinking included. 
+ return GenerateContentResponse( + candidates=[ + Candidate( + content=Content( + parts=[ + Part( + video_metadata=None, + thought=None, + inline_data=None, + code_execution_result=None, + executable_code=None, + file_data=None, + function_call=None, + function_response=None, + text="A spontaneous answer", + ), + ], + role="model", + ), + citation_metadata=None, + finish_message=None, + token_count=None, + finish_reason=FinishReason.STOP, + avg_logprobs=None, + grounding_metadata=None, + index=0, + logprobs_result=None, + safety_ratings=None, + ) + ] + ) + + +@pytest.fixture +def thinking_response(): + return GenerateContentResponse( + candidates=[ + Candidate( + content=Content( + parts=[ + Part( + video_metadata=None, + thought=True, + inline_data=None, + code_execution_result=None, + executable_code=None, + file_data=None, + function_call=None, + function_response=None, + text="Some thinking", + ), + Part( + video_metadata=None, + thought=None, + inline_data=None, + code_execution_result=None, + executable_code=None, + file_data=None, + function_call=None, + function_response=None, + text="A thought out answer", + ), + ], + role="model", + ), + citation_metadata=None, + finish_message=None, + token_count=None, + finish_reason=FinishReason.STOP, + avg_logprobs=None, + grounding_metadata=None, + index=0, + logprobs_result=None, + safety_ratings=None, + ) + ] + ) + + +DEFAULT_SAFETY_SETTINGS = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), +] + TYPE_ERROR_MSG = ( "if api == 'gemini', then the prompt must be a str, list[str], or " @@ -256,93 +363,97 @@ def test_gemini_check_prompt_dict(temporary_data_folders, monkeypatch): raise test_case[0] # set the GEMINI_API_KEY environment variable - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") - # error if the model-specific environment variable is not set - test_case = GeminiAPI.check_prompt_dict( - { - "api": "gemini", - "model_name": "gemini_model_name", - "prompt": "test prompt", - } - ) - assert len(test_case) == 1 - with pytest.raises( - Warning, - match=re.escape( - "Environment variable 'GEMINI_API_KEY_gemini_model_name' is not set" - ), - ): - raise test_case[0] - - # unset the GEMINI_API_KEY environment variable and - # set the model-specific environment variable - monkeypatch.delenv("GEMINI_API_KEY") - monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") - test_case = GeminiAPI.check_prompt_dict( - { - "api": "gemini", - "model_name": "gemini_model_name", - "prompt": "test prompt", - } - ) - assert len(test_case) == 1 - with pytest.raises( - Warning, match=re.escape("Environment variable 'GEMINI_API_KEY' is not set") - ): - raise test_case[0] - - # full passes - # set both environment variables - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") - assert ( - GeminiAPI.check_prompt_dict( + with monkeypatch.context() as m1: + m1.setenv("GEMINI_API_KEY", "DUMMY") + # error if the model-specific environment variable is not set + test_case = GeminiAPI.check_prompt_dict( { "api": "gemini", "model_name": "gemini_model_name", "prompt": "test prompt", } ) - == [] - ) - 
assert ( - GeminiAPI.check_prompt_dict( + assert len(test_case) == 1 + with pytest.raises( + Warning, + match=re.escape( + "Environment variable 'GEMINI_API_KEY_gemini_model_name' is not set" + ), + ): + raise test_case[0] + + with monkeypatch.context() as m2: + # unset the GEMINI_API_KEY environment variable and + # set the model-specific environment variable + m2.delenv("GEMINI_API_KEY", raising=False) + m2.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") + test_case = GeminiAPI.check_prompt_dict( { "api": "gemini", "model_name": "gemini_model_name", - "prompt": ["prompt 1", "prompt 2"], + "prompt": "test prompt", } ) - == [] - ) - assert ( - GeminiAPI.check_prompt_dict( - { - "api": "gemini", - "model_name": "gemini_model_name", - "prompt": [ - {"role": "system", "parts": "system prompt"}, - {"role": "user", "parts": "user message 1"}, - {"role": "model", "parts": "model message"}, - {"role": "user", "parts": "user message 2"}, - ], - } + assert len(test_case) == 1 + with pytest.raises( + Warning, match=re.escape("Environment variable 'GEMINI_API_KEY' is not set") + ): + raise test_case[0] + + # full passes + # set both environment variables + with monkeypatch.context() as m3: + m3.setenv("GEMINI_API_KEY", "DUMMY") + m3.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") + assert ( + GeminiAPI.check_prompt_dict( + { + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": "test prompt", + } + ) + == [] ) - == [] - ) - assert ( - GeminiAPI.check_prompt_dict( - { - "api": "gemini", - "model_name": "gemini_model_name", - "prompt": [ - {"role": "user", "parts": "user message 1"}, - {"role": "model", "parts": "model message"}, - {"role": "user", "parts": "user message 2"}, - ], - } + assert ( + GeminiAPI.check_prompt_dict( + { + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": ["prompt 1", "prompt 2"], + } + ) + == [] + ) + assert ( + GeminiAPI.check_prompt_dict( + { + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": [ + {"role": "system", "parts": "system prompt"}, + {"role": "user", "parts": "user message 1"}, + {"role": "model", "parts": "model message"}, + {"role": "user", "parts": "user message 2"}, + ], + } + ) + == [] + ) + assert ( + GeminiAPI.check_prompt_dict( + { + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": [ + {"role": "user", "parts": "user message 1"}, + {"role": "model", "parts": "model message"}, + {"role": "user", "parts": "user message 2"}, + ], + } + ) + == [] ) - == [] - ) @pytest.mark.asyncio @@ -366,11 +477,14 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" - assert isinstance(test_case[2], GenerativeModel) - assert test_case[2]._model_name == "models/gemini_model_name" - assert test_case[2]._system_instruction is None - assert isinstance(test_case[3], dict) - assert test_case[4] == {"temperature": 1, "max_output_tokens": 100} + assert isinstance(test_case[2], Client) + assert isinstance(test_case[2].aio, AsyncClient) + assert isinstance(test_case[3], GenerateContentConfig) + assert test_case[3].system_instruction is None + assert test_case[3].temperature == 1 + assert test_case[3].max_output_tokens == 100 + assert test_case[3].safety_settings == DEFAULT_SAFETY_SETTINGS + assert test_case[4] is None # test for case where no parameters in prompt_dict test_case = await gemini_api._obtain_model_inputs( @@ -385,11 +499,14 @@ async def 
test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" - assert isinstance(test_case[2], GenerativeModel) - assert test_case[2]._model_name == "models/gemini_model_name" - assert test_case[2]._system_instruction is None - assert isinstance(test_case[3], dict) - assert test_case[4] == {} + assert isinstance(test_case[2], Client) + assert isinstance(test_case[2].aio, AsyncClient) + assert isinstance(test_case[3], GenerateContentConfig) + assert test_case[3].system_instruction is None + assert test_case[3].temperature is None + assert test_case[3].max_output_tokens is None + assert test_case[3].safety_settings == DEFAULT_SAFETY_SETTINGS + assert test_case[4] is None # test for case where system_instruction is provided test_case = await gemini_api._obtain_model_inputs( @@ -405,11 +522,12 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" - assert isinstance(test_case[2], GenerativeModel) - assert test_case[2]._model_name == "models/gemini_model_name" - assert test_case[2]._system_instruction is not None - assert isinstance(test_case[3], dict) - assert test_case[4] == {} + assert isinstance(test_case[2], Client) + assert isinstance(test_case[2].aio, AsyncClient) + assert isinstance(test_case[3], GenerateContentConfig) + assert test_case[3].system_instruction is not None + assert test_case[3].safety_settings == DEFAULT_SAFETY_SETTINGS + assert test_case[4] is None # test error catching when parameters are not a dictionary with pytest.raises( @@ -446,6 +564,74 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): ) +@pytest.mark.asyncio +async def test_gemini_obtain_model_inputs_thinking_config( + temporary_data_folders, monkeypatch +): + settings = Settings(data_folder="data") + log_file = "log.txt" + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") + gemini_api = GeminiAPI(settings=settings, log_file=log_file) + + # These test cases are very similar to the test above, so we will not repeat assertions for all + # attributes - only those that are relevant to the thinking config + + # Case 1: test with *NO* thinking config parameters provided + test_case = await gemini_api._obtain_model_inputs( + { + "id": "gemini_id", + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": "test prompt", + "parameters": {"temperature": 1, "max_output_tokens": 100}, + } + ) + assert isinstance(test_case, tuple) + assert isinstance(test_case[3], GenerateContentConfig) + assert test_case[3].thinking_config is None + + # Case 2: test with thinking config parameters provided + test_case = await gemini_api._obtain_model_inputs( + { + "id": "gemini_id", + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": "test prompt", + "parameters": { + "temperature": 1, + "max_output_tokens": 100, + "thinking_budget": 1234, + "include_thoughts": True, + }, + } + ) + + assert isinstance(test_case, tuple) + assert isinstance(test_case[3], GenerateContentConfig) + assert isinstance(test_case[3].thinking_config, ThinkingConfig) + assert test_case[3].thinking_config.thinking_budget == 1234 + assert test_case[3].thinking_config.include_thoughts is True + + # Case 3: test with invalid thinking config parameters provided + + with pytest.raises(ValueError): + # thinking_budget is out of range + test_case = await 
gemini_api._obtain_model_inputs( + { + "id": "gemini_id", + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": "test prompt", + "parameters": { + "temperature": 1, + "max_output_tokens": 100, + "thinking_budget": 12345678901234567890, # out of range + "include_thoughts": True, + }, + } + ) + + @pytest.mark.asyncio async def test_gemini_obtain_model_inputs_safety_filters( temporary_data_folders, monkeypatch @@ -455,9 +641,15 @@ async def test_gemini_obtain_model_inputs_safety_filters( monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) - valid_safety_filter_choices = ["none", "few", "default", "some", "most"] + valid_safety_filter_choices = { + "none": HarmBlockThreshold.BLOCK_NONE, + "few": HarmBlockThreshold.BLOCK_ONLY_HIGH, + "default": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "some": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "most": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + } - for safety_filter in valid_safety_filter_choices: + for safety_filter, expected_threshold in valid_safety_filter_choices.items(): test_case = await gemini_api._obtain_model_inputs( { "id": "gemini_id", @@ -472,11 +664,19 @@ async def test_gemini_obtain_model_inputs_safety_filters( assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" - assert isinstance(test_case[2], GenerativeModel) - assert test_case[2]._model_name == "models/gemini_model_name" - assert test_case[2]._system_instruction is None - assert isinstance(test_case[3], dict) - assert test_case[4] == {"temperature": 1, "max_output_tokens": 100} + assert isinstance(test_case[2], Client) + assert isinstance(test_case[2].aio, AsyncClient) + assert isinstance(test_case[3], GenerateContentConfig) + assert test_case[3].system_instruction is None + assert test_case[3].temperature == 1 + assert test_case[3].max_output_tokens == 100 + assert test_case[4] is None + + # Assert that the safety settings contain the expected threshold for all categories + assert all( + safety_set.threshold == expected_threshold + for safety_set in test_case[3].safety_settings + ) # test error if safety filter is not recognised with pytest.raises( diff --git a/tests/apis/gemini/test_gemini_chat_input.py b/tests/apis/gemini/test_gemini_chat_input.py index b621c11..b10a620 100644 --- a/tests/apis/gemini/test_gemini_chat_input.py +++ b/tests/apis/gemini/test_gemini_chat_input.py @@ -2,21 +2,29 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from google.generativeai import GenerativeModel +from google.genai.chats import AsyncChat, AsyncChats +from google.genai.types import GenerateContentConfig +import prompto.utils from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings from ...conftest import CopyingAsyncMock -from .test_gemini import DEFAULT_SAFETY_SETTINGS, prompt_dict_chat +from .test_gemini import ( + DEFAULT_SAFETY_SETTINGS, + non_thinking_response, + prompt_dict_chat, + thinking_response, +) pytest_plugins = ("pytest_asyncio",) @pytest.mark.asyncio async def test_gemini_query_chat_no_env_var( - prompt_dict_chat, temporary_data_folders, caplog + prompt_dict_chat, temporary_data_folders, caplog, monkeypatch ): + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" @@ -34,40 +42,40 @@ async def test_gemini_query_chat_no_env_var( @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +@patch.object( + AsyncChat, + 
"send_message", + new_callable=CopyingAsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_chat( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_chat, temporary_data_folders, + non_thinking_response, + thinking_response, monkeypatch, caplog, ): + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) - # mock the response from the API - # NOTE: The actual response from the API is a - # google.generativeai.types.AsyncGenerateContentResponse object - # not a string value, but for the purpose of this test, we are using a string value - # and testing that this is the input to the process_response function + # Mock the response from the API gemini_api_sequence_responses = [ - "response Messages object 1", - "response Messages object 2", + thinking_response, + non_thinking_response, ] mock_gemini_call.side_effect = gemini_api_sequence_responses # mock the process_response function - process_response_sequence_responses = ["response text 1", "response text 2"] - mock_process_response.side_effect = process_response_sequence_responses + process_response_sequence_responses = [ + "A thought out answer", + "A spontaneous answer", + ] # make sure that the input prompt_dict does not have a response key assert "response" not in prompt_dict_chat.keys() @@ -75,28 +83,33 @@ async def test_gemini_query_chat( # call the _query_chat method prompt_dict = await gemini_api._query_chat(prompt_dict_chat, index=0) + expected_thinking_text = [["Some thinking"], []] + # assert that the response key is added to the prompt_dict assert "response" in prompt_dict.keys() + assert "thinking_text" in prompt_dict.keys() + assert isinstance(prompt_dict["thinking_text"], list) + assert prompt_dict["thinking_text"] == expected_thinking_text assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) mock_gemini_call.assert_awaited_with( - content=prompt_dict_chat["prompt"][1], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][1], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) - assert mock_process_response.call_count == 2 - mock_process_response.assert_any_call(gemini_api_sequence_responses[0]) - mock_process_response.assert_called_with(gemini_api_sequence_responses[1]) - assert mock_process_safety_attr.call_count == 2 mock_process_safety_attr.assert_any_call(gemini_api_sequence_responses[0]) mock_process_safety_attr.assert_called_with(gemini_api_sequence_responses[1]) @@ -125,35 +138,42 @@ async def test_gemini_query_chat( @pytest.mark.asyncio -@patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) +@patch.object( + AsyncChats, + "create", + new_callable=AsyncMock, +) @patch( 
"prompto.apis.gemini.gemini.GeminiAPI._obtain_model_inputs", new_callable=AsyncMock ) async def test_gemini_query_history_check_chat_init( mock_obtain_model_inputs, - mock_start_chat, + mock_chat_create, prompt_dict_chat, temporary_data_folders, monkeypatch, caplog, ): + monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) mock_obtain_model_inputs.return_value = ( prompt_dict_chat["prompt"], prompt_dict_chat["model_name"], - GenerativeModel( - model_name=prompt_dict_chat["model_name"], system_instruction=None + gemini_api._get_client("gemini_model_name"), + GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, ), - DEFAULT_SAFETY_SETTINGS, prompt_dict_chat["parameters"], ) - # error will be raised as we've mocked the start_chat method + # error will be raised as we've mocked the Chats.create method # which leads to an error when the method is called on the mocked object with pytest.raises(Exception): await gemini_api._query_chat(prompt_dict_chat, index=0) @@ -161,20 +181,31 @@ async def test_gemini_query_history_check_chat_init( mock_obtain_model_inputs.assert_called_once_with( prompt_dict=prompt_dict_chat, system_instruction=None ) - mock_start_chat.assert_called_once_with(history=[]) + mock_chat_create.assert_called_once_with( + model="gemini_model_name", + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), + history=[], + ) @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +@patch.object( + AsyncChat, + "send_message", + new_callable=CopyingAsyncMock, ) async def test_gemini_query_chat_index_error_1( mock_gemini_call, prompt_dict_chat, temporary_data_folders, monkeypatch, caplog ): + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) # mock index error response from the API @@ -192,10 +223,12 @@ async def test_gemini_query_chat_index_error_1( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( @@ -212,16 +245,18 @@ async def test_gemini_query_chat_index_error_1( @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +@patch.object( + AsyncChat, + "send_message", + new_callable=CopyingAsyncMock, ) async def test_gemini_query_chat_error_1( mock_gemini_call, prompt_dict_chat, temporary_data_folders, monkeypatch, caplog ): + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) # mock error response from the API @@ -234,10 +269,12 @@ 
async def test_gemini_query_chat_error_1( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( @@ -251,69 +288,78 @@ async def test_gemini_query_chat_error_1( @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +@patch.object( + AsyncChat, + "send_message", + new_callable=CopyingAsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_chat_index_error_2( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_chat, temporary_data_folders, + non_thinking_response, monkeypatch, caplog, ): + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) # mock error response from the API from second response gemini_api_sequence_responses = [ - "response Messages object 1", + non_thinking_response, IndexError("Test error"), ] mock_gemini_call.side_effect = gemini_api_sequence_responses - # mock the process_response function - mock_process_response.return_value = "response text 1" - # make sure that the input prompt_dict does not have a response key assert "response" not in prompt_dict_chat.keys() # call the _query_chat method prompt_dict = await gemini_api._query_chat(prompt_dict_chat, index=0) + expected_answer = "A spontaneous answer" + # assert that the response key is added to the prompt_dict assert "response" in prompt_dict.keys() + assert "thinking_text" in prompt_dict.keys() + assert isinstance(prompt_dict["thinking_text"], list) + assert prompt_dict["thinking_text"] == [[], []] assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 + mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) + mock_gemini_call.assert_awaited_with( - content=prompt_dict_chat["prompt"][1], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][1], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) - mock_process_response.assert_called_once_with(gemini_api_sequence_responses[0]) - mock_process_safety_attr.assert_called_once_with(gemini_api_sequence_responses[0]) + mock_process_safety_attr.assert_any_call(gemini_api_sequence_responses[0]) expected_log_message_1 = ( f"Response received for model Gemini ({prompt_dict_chat['model_name']}) " "(i=0, id=gemini_id, message=1/2)\n" f"Prompt: {prompt_dict_chat['prompt'][0][:50]}...\n" - f"Response: {mock_process_response.return_value[:50]}...\n" + f"Response: 
{expected_answer[:50]}...\n" ) assert expected_log_message_1 in caplog.text @@ -321,73 +367,78 @@ async def test_gemini_query_chat_index_error_2( f"Error with model Gemini ({prompt_dict_chat['model_name']}) " "(i=0, id=gemini_id, message=2/2)\n" f"Prompt: {prompt_dict_chat['prompt'][1][:50]}...\n" - f"Responses so far: {[mock_process_response.return_value]}...\n" + f"Responses so far: {[expected_answer]}...\n" "Error: Response is empty and blocked (IndexError - Test error)" ) assert expected_log_message_2 in caplog.text # assert that the response value is the first response value and an empty string - assert prompt_dict["response"] == [mock_process_response.return_value, ""] + assert prompt_dict["response"] == [expected_answer, ""] @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +@patch.object( + AsyncChat, + "send_message", + new_callable=CopyingAsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_chat_error_2( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_chat, temporary_data_folders, + non_thinking_response, monkeypatch, caplog, ): + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) + expected_answer = "A spontaneous answer" # mock error response from the API from second response gemini_api_sequence_responses = [ - "response Messages object 1", + non_thinking_response, Exception("Test error"), ] mock_gemini_call.side_effect = gemini_api_sequence_responses - # mock the process_response function - mock_process_response.return_value = "response text 1" - # raise error if the API call fails with pytest.raises(Exception, match="Test error"): await gemini_api._query_chat(prompt_dict_chat, index=0) assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 + mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) + mock_gemini_call.assert_awaited_with( - content=prompt_dict_chat["prompt"][1], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][1], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) - mock_process_response.assert_called_once_with(gemini_api_sequence_responses[0]) mock_process_safety_attr.assert_called_once_with(gemini_api_sequence_responses[0]) expected_log_message_1 = ( f"Response received for model Gemini ({prompt_dict_chat['model_name']}) " "(i=0, id=gemini_id, message=1/2)\n" f"Prompt: {prompt_dict_chat['prompt'][0][:50]}...\n" - f"Response: {mock_process_response.return_value[:50]}...\n" + f"Response: {expected_answer[:50]}...\n" ) assert expected_log_message_1 in caplog.text @@ -395,7 +446,7 @@ async def test_gemini_query_chat_error_2( f"Error with model Gemini ({prompt_dict_chat['model_name']}) " "(i=0, id=gemini_id, message=2/2)\n" f"Prompt: 
{prompt_dict_chat['prompt'][1][:50]}...\n"
-        f"Responses so far: {[mock_process_response.return_value]}...\n"
+        f"Responses so far: {[expected_answer]}...\n"
         "Error: Exception - Test error"
     )
     assert expected_log_message_2 in caplog.text
diff --git a/tests/apis/gemini/test_gemini_history_input.py b/tests/apis/gemini/test_gemini_history_input.py
index 5b0f8fe..c422ce9 100644
--- a/tests/apis/gemini/test_gemini_history_input.py
+++ b/tests/apis/gemini/test_gemini_history_input.py
@@ -2,16 +2,18 @@
 from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
-from google.generativeai import GenerativeModel
+from google.genai.chats import AsyncChat, AsyncChats, Chat
+from google.genai.types import Content, GenerateContentConfig, Part
 
 from prompto.apis.gemini import GeminiAPI
 from prompto.settings import Settings
 
-from .test_gemini import (
-    DEFAULT_SAFETY_SETTINGS,
-    prompt_dict_history,
-    prompt_dict_history_no_system,
-)
+from .test_gemini import non_thinking_response  # noqa: F401
+from .test_gemini import prompt_dict_history  # noqa: F401
+from .test_gemini import prompt_dict_history_no_system  # noqa: F401
+from .test_gemini import DEFAULT_SAFETY_SETTINGS
 
 pytest_plugins = ("pytest_asyncio",)
 
@@ -37,15 +39,18 @@ async def test_gemini_query_history_no_env_var(
 
 
 @pytest.mark.asyncio
-@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
-@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock)
+@patch.object(
+    AsyncChat,
+    "send_message",
+    new_callable=AsyncMock,
+)
 @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock)
 async def test_gemini_query_history(
     mock_process_safety_attr,
-    mock_process_response,
     mock_gemini_call,
     prompt_dict_history,
     temporary_data_folders,
+    non_thinking_response,
     monkeypatch,
     caplog,
 ):
@@ -55,15 +60,8 @@ async def test_gemini_query_history(
     monkeypatch.setenv("GEMINI_API_KEY", "DUMMY")
     gemini_api = GeminiAPI(settings=settings, log_file=log_file)
 
-    # mock the response from the API
-    # NOTE: The actual response from the API is a
-    # google.generativeai.types.AsyncGenerateContentResponse object
-    # not a string value, but for the purpose of this test, we are using a string value
-    # and testing that this is the input to the process_response function
-    mock_gemini_call.return_value = "response Messages object"
-
-    # mock the process_response function
-    mock_process_response.return_value = "response text"
+    # Mock the response from the API
+    mock_gemini_call.return_value = non_thinking_response
 
     # make sure that the input prompt_dict does not have a response key
     assert "response" not in prompt_dict_history.keys()
@@ -71,35 +69,42 @@ async def test_gemini_query_history(
     # call the _query_history method
     prompt_dict = await gemini_api._query_history(prompt_dict_history, index=0)
 
+    expected_answer = "A spontaneous answer"
+
     # assert that the response key is added to the prompt_dict
     assert "response" in prompt_dict.keys()
+    # assert "thinking_text" is in prompt_dict
+    assert "thinking_text" in prompt_dict.keys()
+    assert isinstance(prompt_dict["thinking_text"], list)
+    assert prompt_dict["thinking_text"] == []
 
     mock_gemini_call.assert_called_once()
     mock_gemini_call.assert_awaited_once()
+
     mock_gemini_call.assert_awaited_once_with(
-        content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]},
-        generation_config=prompt_dict_history["parameters"],
-        safety_settings=DEFAULT_SAFETY_SETTINGS,
-        stream=False,
+
message=Part(text=prompt_dict_history["prompt"][1]["parts"]), ) - mock_process_response.assert_called_once_with(mock_gemini_call.return_value) mock_process_safety_attr.assert_called_once_with(mock_gemini_call.return_value) # assert that the response value is the return value of the process_response function - assert prompt_dict["response"] == mock_process_response.return_value + assert prompt_dict["response"] == expected_answer expected_log_message = ( f"Response received for model Gemini ({prompt_dict_history['model_name']}) " "(i=0, id=gemini_id)\n" f"Prompt: {prompt_dict_history['prompt'][:50]}...\n" - f"Response: {mock_process_response.return_value[:50]}...\n" + f"Response: {expected_answer[:50]}...\n" ) assert expected_log_message in caplog.text @pytest.mark.asyncio -@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) +@patch.object( + AsyncChat, + "send_message", + new_callable=AsyncMock, +) async def test_gemini_query_history_error( mock_gemini_call, prompt_dict_history, temporary_data_folders, monkeypatch, caplog ): @@ -118,11 +123,9 @@ async def test_gemini_query_history_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + mock_gemini_call.assert_awaited_once_with( - content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]}, - generation_config=prompt_dict_history["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=Part(text=prompt_dict_history["prompt"][1]["parts"]), ) expected_log_message = ( @@ -135,7 +138,11 @@ async def test_gemini_query_history_error( @pytest.mark.asyncio -@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) +@patch.object( + AsyncChat, + "send_message", + new_callable=AsyncMock, +) async def test_gemini_query_history_index_error( mock_gemini_call, prompt_dict_history, temporary_data_folders, monkeypatch, caplog ): @@ -159,11 +166,9 @@ async def test_gemini_query_history_index_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + mock_gemini_call.assert_awaited_once_with( - content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]}, - generation_config=prompt_dict_history["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=Part(text=prompt_dict_history["prompt"][1]["parts"]), ) expected_log_message = ( @@ -179,7 +184,11 @@ async def test_gemini_query_history_index_error( @pytest.mark.asyncio -@patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) +@patch.object( + AsyncChats, + "create", + new_callable=AsyncMock, +) @patch( "prompto.apis.gemini.gemini.GeminiAPI._obtain_model_inputs", new_callable=AsyncMock ) @@ -197,14 +206,19 @@ async def test_gemini_query_history_check_chat_init( monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) + mock_generate_content_config = ( + GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), + ) + mock_obtain_model_inputs.return_value = ( prompt_dict_history["prompt"], prompt_dict_history["model_name"], - GenerativeModel( - model_name=prompt_dict_history["model_name"], - system_instruction=prompt_dict_history["prompt"][0]["parts"], - ), - DEFAULT_SAFETY_SETTINGS, + gemini_api._get_client("gemini_model_name"), + mock_generate_content_config, prompt_dict_history["parameters"], ) @@ -217,19 +231,24 @@ async def 
test_gemini_query_history_check_chat_init(
         prompt_dict=prompt_dict_history,
         system_instruction=prompt_dict_history["prompt"][0]["parts"],
     )
-    mock_start_chat.assert_called_once_with(history=[])
+    mock_start_chat.assert_called_once_with(
+        model="gemini_model_name", config=mock_generate_content_config, history=[]
+    )
 
 
 @pytest.mark.asyncio
-@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
-@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock)
+@patch.object(
+    AsyncChat,
+    "send_message",
+    new_callable=AsyncMock,
+)
 @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock)
 async def test_gemini_query_history_no_system(
     mock_process_safety_attr,
-    mock_process_response,
     mock_gemini_call,
     prompt_dict_history_no_system,
     temporary_data_folders,
+    non_thinking_response,
     monkeypatch,
     caplog,
 ):
@@ -243,10 +262,11 @@ async def test_gemini_query_history_no_system(
-    # NOTE: The actual response from the API is a gemini.types.message.Message object
-    # not a string value, but for the purpose of this test, we are using a string value
-    # and testing that this is the input to the process_response function
-    mock_gemini_call.return_value = "response Messages object"
+    # Mock the response from the API
+    mock_gemini_call.return_value = non_thinking_response
 
-    # mock the process_response function
-    mock_process_response.return_value = "response text"
+    expected_answer = "A spontaneous answer"
 
     # make sure that the input prompt_dict does not have a response key
     assert "response" not in prompt_dict_history_no_system.keys()
@@ -261,33 +281,31 @@ async def test_gemini_query_history_no_system(
 
     mock_gemini_call.assert_called_once()
     mock_gemini_call.assert_awaited_once()
+
     mock_gemini_call.assert_awaited_once_with(
-        content={
-            "role": "user",
-            "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]],
-        },
-        generation_config=prompt_dict_history_no_system["parameters"],
-        safety_settings=DEFAULT_SAFETY_SETTINGS,
-        stream=False,
+
message=Part(text=prompt_dict_history_no_system["prompt"][2]["parts"]), ) expected_log_message = ( @@ -330,7 +343,11 @@ async def test_gemini_query_history_error_no_system( @pytest.mark.asyncio -@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) +@patch.object( + AsyncChat, + "send_message", + new_callable=AsyncMock, +) async def test_gemini_query_history_index_error_no_system( mock_gemini_call, prompt_dict_history_no_system, @@ -360,14 +377,9 @@ async def test_gemini_query_history_index_error_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + mock_gemini_call.assert_awaited_once_with( - content={ - "role": "user", - "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], - }, - generation_config=prompt_dict_history_no_system["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=Part(text=prompt_dict_history_no_system["prompt"][2]["parts"]), ) expected_log_message = ( @@ -383,7 +395,11 @@ async def test_gemini_query_history_index_error_no_system( @pytest.mark.asyncio -@patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) +@patch.object( + AsyncChats, + "create", + new_callable=AsyncMock, +) @patch( "prompto.apis.gemini.gemini.GeminiAPI._obtain_model_inputs", new_callable=AsyncMock ) @@ -401,14 +417,19 @@ async def test_gemini_query_history_no_system_check_chat_init( monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) + mock_generate_content_config = ( + GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), + ) + mock_obtain_model_inputs.return_value = ( prompt_dict_history_no_system["prompt"], prompt_dict_history_no_system["model_name"], - GenerativeModel( - model_name=prompt_dict_history_no_system["model_name"], - system_instruction=None, - ), - DEFAULT_SAFETY_SETTINGS, + gemini_api._get_client("gemini_model_name"), + mock_generate_content_config, prompt_dict_history_no_system["parameters"], ) @@ -421,14 +442,16 @@ async def test_gemini_query_history_no_system_check_chat_init( prompt_dict=prompt_dict_history_no_system, system_instruction=None ) mock_start_chat.assert_called_once_with( + model="gemini_model_name", + config=mock_generate_content_config, history=[ - { - "role": "user", - "parts": [prompt_dict_history_no_system["prompt"][0]["parts"]], - }, - { - "role": "model", - "parts": [prompt_dict_history_no_system["prompt"][1]["parts"]], - }, - ] + Content( + role="user", + parts=[Part(text=prompt_dict_history_no_system["prompt"][0]["parts"])], + ), + Content( + role="model", + parts=[Part(text=prompt_dict_history_no_system["prompt"][1]["parts"])], + ), + ], ) diff --git a/tests/apis/gemini/test_gemini_image_input.py b/tests/apis/gemini/test_gemini_image_input.py index d25d1d0..33d01d3 100644 --- a/tests/apis/gemini/test_gemini_image_input.py +++ b/tests/apis/gemini/test_gemini_image_input.py @@ -1,6 +1,9 @@ import os +from unittest.mock import patch import pytest +from google.genai import Client +from google.genai.types import Part from PIL import Image from prompto.apis.gemini.gemini_utils import parse_parts_value @@ -9,15 +12,27 @@ def test_parse_parts_value_text(): part = "text" media_folder = "media" - result = parse_parts_value(part, media_folder) - assert result == part + # There is no uploaded media in the prompt_dict_chat, hence the + # client is not required, and we can pass `None`. 
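+    # parse_parts_value is expected to wrap the plain string in a google.genai Part (asserted below)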
+    mock_client = None
+    actual_result = parse_parts_value(part, media_folder, mock_client)
+    expected_result = Part(text="text")
+    assert actual_result == expected_result
 
 
 def test_parse_parts_value_image():
+    # This is a plain string that happens to use the keyword "image",
+    # but it is not a key within a dictionary.
+    # This test simply asserts that the string is handled correctly
+    # as a string.
     part = "image"
     media_folder = "media"
 
-    result = parse_parts_value(part, media_folder)
-    assert result == part
+    # There is no uploaded media in this part, hence the
+    # client is not required, and we can pass `None`.
+    mock_client = None
+    actual_result = parse_parts_value(part, media_folder, mock_client)
+    expected_result = Part(text="image")
+    assert actual_result == expected_result
 
 
 def test_parse_parts_value_image_dict(tmp_path):
@@ -27,6 +42,7 @@ def test_parse_parts_value_image_dict(tmp_path):
     # The image should be loaded from the local file.
     part = {"type": "image", "media": "pantani_giro.jpg"}
     media_folder = tmp_path / "media"
+    mock_client = None
 
     media_folder.mkdir(parents=True, exist_ok=True)
 
@@ -35,7 +51,7 @@ def test_parse_parts_value_image_dict(tmp_path):
     image = Image.new("RGB", (100, 100), color="red")
     image.save(image_path)
 
-    actual_result = parse_parts_value(part, str(media_folder))
+    actual_result = parse_parts_value(part, str(media_folder), mock_client)
 
     # Assert the result
     assert actual_result.mode == "RGB"
@@ -46,10 +62,11 @@ def test_parse_parts_value_video_not_uploaded():
     part = {"type": "video", "media": "pantani_giro.mp4"}
     media_folder = "media"
+    mock_client = None
 
     # Because the video is not uploaded, we expect a ValueError
     with pytest.raises(ValueError) as excinfo:
-        parse_parts_value(part, media_folder)
+        parse_parts_value(part, media_folder, mock_client)
 
     print(excinfo)
     assert "not uploaded" in str(excinfo.value)
@@ -65,23 +82,23 @@ def test_parse_parts_value_video_uploaded(monkeypatch):
     # parameter for the parse_parts_value function
     media_folder = "media"
 
-    # Mock the google.generativeai get_file function
-    # We don't want to call the real get_file function and it would be tricky to
-    # assert the binary data returned by the function.
+    # Mock the `google.genai.Client().files.get` function
+    # The real `get` function returns the binary contents of the file,
+    # which would be tricky to assert.
# Instead, we will just return the uploaded_filename - def mock_get_file(name): - return name - - # Replace the original get_file function with the mock - # ***It is important that the import statement used here is exactly the same as - # the one in the gemini_utils.py file*** - import google.generativeai as genai - - monkeypatch.setattr(genai, "get_file", mock_get_file) - - # Assert that the mock function was called with the expected argument - assert genai.get_file("check mocked function") == "check mocked function" - - expected_result = "file/123456" - actual_result = parse_parts_value(part, media_folder) - assert actual_result == expected_result + mock_get_file_no_op = lambda name: name + + # Replace the original `get` function with the mock + with monkeypatch.context() as m: + # Mock the get function + client = Client(api_key="DUMMY") + m.setattr(client.aio.files, "get", mock_get_file_no_op) + # Assert that the mock function was called with the expected argument + assert ( + client.aio.files.get(name="check mocked function") + == "check mocked function" + ) + + expected_result = "file/123456" + actual_result = parse_parts_value(part, media_folder, client) + assert actual_result == expected_result diff --git a/tests/apis/gemini/test_gemini_media.py b/tests/apis/gemini/test_gemini_media.py new file mode 100644 index 0000000..ada2d2c --- /dev/null +++ b/tests/apis/gemini/test_gemini_media.py @@ -0,0 +1,317 @@ +import logging +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +from google import genai +from google.genai.types import FileState + +from prompto.apis.gemini.gemini import GeminiAPI +from prompto.apis.gemini.gemini_media import ( + delete_uploaded_files, + list_uploaded_files, + remote_file_hash_base64, + upload_media_files, + wait_for_processing, +) +from prompto.upload_media import _create_settings + + +def test_remote_file_hash_base64(): + + # The example hashes are for the strings "hash1", "hash2", "hash3" eg: + # >>> "hash1".encode("utf-8").hex() + # '6861736831' + test_cases = [ + ( + Mock(dummy_name="file1", sha256_hash="6861736831"), + "aGFzaDE=", + ), + ( + Mock(dummy_name="file2", sha256_hash="6861736832"), + "aGFzaDI=", + ), + ( + Mock(dummy_name="file3", sha256_hash="6861736833"), + "aGFzaDM=", + ), + ] + + # We need to juggle with the mock names, because we can't set them + # directly in the constructor. 
See these docs for more details: + # https://docs.python.org/3/library/unittest.mock.html#mock-names-and-the-name-attribute + for mock_file, expected_hash in test_cases: + mock_file.configure_mock(name=mock_file.dummy_name) + mock_file.__str__ = Mock(return_value=mock_file.dummy_name) + + actual_hash = remote_file_hash_base64(mock_file) + assert ( + actual_hash == expected_hash + ), f"Expected {expected_hash}, but got {actual_hash}" + + +@pytest.mark.asyncio +@patch.object( + genai.files.Files, + "get", + new_callable=Mock, +) +async def test_wait_for_processing(mock_file_get, monkeypatch): + with monkeypatch.context() as m: + m.setenv("GEMINI_API_KEY", "DUMMY") + dummy_settings = _create_settings() + gemini_api = GeminiAPI(settings=dummy_settings, log_file=None) + client = gemini_api._get_client("default") + + # These mocks represent the same file, but at different states/points in time + starting_file = Mock( + dummy_name="file1", state=FileState.PROCESSING, sha256_hash="aGFzaDE=" + ) + + side_effects = [ + Mock(name="file1", state=FileState.PROCESSING, sha256_hash="aGFzaDE="), + Mock(name="file1", state=FileState.PROCESSING, sha256_hash="aGFzaDE="), + Mock(name="file1", state=FileState.ACTIVE, sha256_hash="aGFzaDE="), + # We should never to this, but including it to differentiate + # between the function completing because it picked up on the + # previous file==ACTIVE (correct), or if the function completed + # because it ran out of side effects (incorrect) + Mock(name="file1", state=FileState.PROCESSING, sha256_hash="aGFzaDE="), + ] + + mock_file_get.side_effect = side_effects + + # Call the function to test + await wait_for_processing(starting_file, client, poll_interval=0) + + # Check that the `get` method was called exactly 3 times + assert mock_file_get.call_count == 3 + + +@patch.object( + genai.files.AsyncFiles, + "list", + new_callable=AsyncMock, +) +@patch( + "prompto.apis.gemini.gemini_media.compute_sha256_base64", + new_callable=MagicMock, +) +def test_upload_media_files_already_uploaded( + mock_compute_sha256_base64, mock_list_files, monkeypatch, caplog +): + caplog.set_level(logging.INFO) + + uploaded_file = Mock(dummy_name="remote_uploaded/file1", sha256_hash="aGFzaDE=") + uploaded_file.configure_mock(name=uploaded_file.dummy_name) + uploaded_file.__str__ = Mock(return_value=uploaded_file.dummy_name) + + local_file_path = "dummy_local_path/file1.txt" + expected_log_msg = ( + "File 'dummy_local_path/file1.txt' already uploaded as 'remote_uploaded/file1'" + ) + + with monkeypatch.context() as m: + m.setenv("GEMINI_API_KEY", "DUMMY") + + mock_compute_sha256_base64.return_value = "aGFzaDE=" + # return_value is a list of a single mock remote file + mock_list_files.return_value = [uploaded_file] + dummy_settings = None + + # Pass a list of local file paths to the function + actual_uploads = upload_media_files([local_file_path], dummy_settings) + + # actual_uploads is a dict of local and remote file names + assert local_file_path in actual_uploads + assert actual_uploads[local_file_path] == "remote_uploaded/file1" + + # Check the log message + assert expected_log_msg in caplog.text + + +@patch( + "prompto.apis.gemini.gemini_media._get_previously_uploaded_files", + new_callable=AsyncMock, +) +@patch.object( + genai.files.AsyncFiles, + "upload", + new_callable=AsyncMock, +) +@patch( + "prompto.apis.gemini.gemini_media.compute_sha256_base64", + new_callable=MagicMock, +) +def test_upload_media_files_new_file( + mock_compute_sha256_base64, + mock_files_upload, + mock_previous_files, 
+ monkeypatch, + caplog, +): + """ + Test the upload_media_files function when the file is not already uploaded, but there are already + other files uploaded.""" + caplog.set_level(logging.INFO) + + pre_uploaded_file = Mock( + dummy_name="remote_uploaded/file1", + sha256_hash=Mock(decode=lambda _: "6861736831"), + ) + pre_uploaded_file.configure_mock(name=pre_uploaded_file.dummy_name) + pre_uploaded_file.__str__ = Mock(return_value=pre_uploaded_file.dummy_name) + + previous_files_dict = { + "hash1": pre_uploaded_file, + } + + local_file_path = "dummy_local_path/file2.txt" + expected_log_msgs = [ + "Uploading dummy_local_path/file2.txt to Gemini API", + "Uploaded file 'remote_uploaded/file2' with hash 'hash2' to Gemini API", + ] + + new_file = Mock( + dummy_name="remote_uploaded/file2", + sha256_hash=Mock(decode=lambda _: "hash2"), + ) + new_file.configure_mock(name=new_file.dummy_name) + new_file.__str__ = Mock(return_value=new_file.name) + + with monkeypatch.context() as m: + m.setenv("GEMINI_API_KEY", "DUMMY") + + mock_compute_sha256_base64.return_value = "hash2" + mock_previous_files.return_value = previous_files_dict + mock_files_upload.return_value = new_file + + dummy_settings = None + actual_uploads = upload_media_files([local_file_path], dummy_settings) + + print(actual_uploads) + + # actual_uploads is a dict of local and remote file names + assert local_file_path in actual_uploads + assert actual_uploads[local_file_path] == "remote_uploaded/file2" + + # Check that the previously uploaded file is not in the actual_uploads dict + assert pre_uploaded_file.dummy_name not in actual_uploads + + # Check the log message + assert all(msg in caplog.text for msg in expected_log_msgs) + + +@patch.object( + genai.files.AsyncFiles, + "list", + new_callable=AsyncMock, +) +def test_list_uploaded_files(mock_list_files, caplog, monkeypatch): + caplog.set_level(logging.INFO) + + # Case 1: No files uploaded + case_1 = { + "return_value": [], + "expected_log_msgs": ["Found 0 files already uploaded at Gemini API"], + } + + # Case 2: Three files uploaded + # The example hashes are for the strings "hash1", "hash2", "hash3" eg: + # >>> "hash1".encode("utf-8").hex() + # '6861736831' + case_2 = { + "return_value": [ + Mock(dummy_name="file1", sha256_hash="aGFzaDE="), + Mock(dummy_name="file2", sha256_hash="aGFzaDI="), + Mock(dummy_name="file3", sha256_hash="aGFzaDM="), + ], + "expected_log_msgs": [ + "Found 3 files already uploaded at Gemini API", + "File Name: file1, File Hash: aGFzaDE=", + "File Name: file2, File Hash: aGFzaDI=", + "File Name: file3, File Hash: aGFzaDM=", + ], + } + + expected_final_log_msg = "All uploaded files listed." + + with monkeypatch.context() as m: + m.setenv("GEMINI_API_KEY", "DUMMY") + + for case_dict in [case_1, case_2]: + + mocked_list_value = case_dict["return_value"] + + # We need to juggle with the mock names, because we can't set them + # directly in the constructor. 
See these docs for more details:
+            # https://docs.python.org/3/library/unittest.mock.html#mock-names-and-the-name-attribute
+            for mock_file in mocked_list_value:
+                mock_file.configure_mock(name=mock_file.dummy_name)
+                mock_file.__str__ = Mock(return_value=mock_file.dummy_name)
+
+            mock_list_files.return_value = mocked_list_value
+
+            expected_total_in_log_msg = case_dict["expected_log_msgs"]
+
+            dummy_settings = _create_settings()
+            # Call the function to test
+            list_uploaded_files(dummy_settings)
+
+            # There is no return value from the function, so we need to check the
+            # log messages
+            for msg in expected_total_in_log_msg:
+                assert msg in caplog.text
+            assert expected_final_log_msg in caplog.text
+
+
+@patch.object(
+    genai.files.Files,
+    "delete",
+    new_callable=Mock,
+)
+@patch.object(
+    genai.files.Files,
+    "list",
+    new_callable=Mock,
+)
+def test_delete_uploaded_files(mock_list_files, mock_delete, caplog, monkeypatch):
+
+    caplog.set_level(logging.INFO)
+
+    # Case 1: No files uploaded
+    case_1 = []
+
+    # Case 2: Three files uploaded
+    # The example hashes are for the strings "hash1", "hash2", "hash3" eg:
+    # >>> "hash1".encode("utf-8").hex()
+    # '6861736831'
+    case_2 = [
+        Mock(dummy_name="file1", sha256_hash="aGFzaDE="),
+        Mock(dummy_name="file2", sha256_hash="aGFzaDI="),
+        Mock(dummy_name="file3", sha256_hash="aGFzaDM="),
+    ]
+    expected_final_log_msg = "All uploaded files deleted."
+
+    with monkeypatch.context() as m:
+        m.setenv("GEMINI_API_KEY", "DUMMY")
+
+        for mocked_list_value in [case_1, case_2]:
+
+            # We need to juggle with the mock names, because we can't set them
+            # directly in the constructor. See these docs for more details:
+            # https://docs.python.org/3/library/unittest.mock.html#mock-names-and-the-name-attribute
+            for mock_file in mocked_list_value:
+                mock_file.configure_mock(name=mock_file.dummy_name)
+                mock_file.__str__ = Mock(return_value=mock_file.dummy_name)
+
+            mock_list_files.return_value = mocked_list_value
+
+            dummy_settings = _create_settings()
+            # Call the function to test
+            delete_uploaded_files(dummy_settings)
+
+            # There is no return value from the function, so we need to check
+            # that the delete function was called the expected number of times
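+            # Note: call_count accumulates across the two cases; the check still
+            # holds because case_1 is empty and so triggers no delete calls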
+ assert mock_delete.call_count == len(mocked_list_value) + assert expected_final_log_msg in caplog.text diff --git a/tests/apis/gemini/test_gemini_string_input.py b/tests/apis/gemini/test_gemini_string_input.py index 0554f50..2d2fbc7 100644 --- a/tests/apis/gemini/test_gemini_string_input.py +++ b/tests/apis/gemini/test_gemini_string_input.py @@ -2,23 +2,30 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from google.genai.chats import AsyncChat +from google.genai.client import AsyncClient +from google.genai.models import AsyncModels +from google.genai.types import GenerateContentConfig from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings -from .test_gemini import DEFAULT_SAFETY_SETTINGS, prompt_dict_string +from .test_gemini import ( + DEFAULT_SAFETY_SETTINGS, + non_thinking_response, + prompt_dict_string, +) pytest_plugins = ("pytest_asyncio",) @pytest.mark.asyncio async def test_gemini_query_string_no_env_var( - prompt_dict_string, temporary_data_folders, caplog + prompt_dict_string, temporary_data_folders, caplog, monkeypatch ): caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - gemini_api = GeminiAPI(settings=settings, log_file=log_file) # raise error if no environment variable is set with pytest.raises( @@ -28,21 +35,23 @@ async def test_gemini_query_string_no_env_var( "environment variable is set." ), ): + gemini_api = GeminiAPI(settings=settings, log_file=log_file) await gemini_api._query_string(prompt_dict_string, index=0) @pytest.mark.asyncio -@patch( - "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock +@patch.object( + AsyncModels, + "generate_content", + new_callable=AsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_string( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_string, temporary_data_folders, + non_thinking_response, monkeypatch, caplog, ): @@ -52,15 +61,8 @@ async def test_gemini_query_string( monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) - # mock the response from the API - # NOTE: The actual response from the API is a - # google.generativeai.types.AsyncGenerateContentResponse object - # not a string value, but for the purpose of this test, we are using a string value - # and testing that this is the input to the process_response function - mock_gemini_call.return_value = "response Messages object" - - # mock the process_response function - mock_process_response.return_value = "response text" + # Mock the response from the API + mock_gemini_call.return_value = non_thinking_response # make sure that the input prompt_dict does not have a response key assert "response" not in prompt_dict_string.keys() @@ -68,36 +70,46 @@ async def test_gemini_query_string( # call the _query_string method prompt_dict = await gemini_api._query_string(prompt_dict_string, index=0) + expected_answer = "A spontaneous answer" + # assert that the response key is added to the prompt_dict assert "response" in prompt_dict.keys() + # assert "thinking_text" is in prompt_dict + assert "thinking_text" in prompt_dict.keys() + assert isinstance(prompt_dict["thinking_text"], list) + assert prompt_dict["thinking_text"] == [] mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - 
contents=prompt_dict_string["prompt"], - generation_config=prompt_dict_string["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + model="gemini_model_name", + contents="test prompt", + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) - mock_process_response.assert_called_once_with(mock_gemini_call.return_value) mock_process_safety_attr.assert_called_once_with(mock_gemini_call.return_value) # assert that the response value is the return value of the process_response function - assert prompt_dict["response"] == mock_process_response.return_value + assert prompt_dict["response"] == expected_answer expected_log_message = ( f"Response received for model Gemini ({prompt_dict_string['model_name']}) " "(i=0, id=gemini_id)\n" f"Prompt: {prompt_dict_string['prompt'][:50]}...\n" - f"Response: {mock_process_response.return_value[:50]}...\n" + f"Response: {expected_answer[:50]}...\n" ) assert expected_log_message in caplog.text @pytest.mark.asyncio -@patch( - "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock +@patch.object( + AsyncModels, + "generate_content", + new_callable=AsyncMock, ) async def test_gemini_query_string__index_error( mock_gemini_call, prompt_dict_string, temporary_data_folders, monkeypatch, caplog @@ -123,10 +135,13 @@ async def test_gemini_query_string__index_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=prompt_dict_string["prompt"], - generation_config=prompt_dict_string["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + model="gemini_model_name", + contents="test prompt", + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( @@ -142,8 +157,10 @@ async def test_gemini_query_string__index_error( @pytest.mark.asyncio -@patch( - "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock +@patch.object( + AsyncModels, + "generate_content", + new_callable=AsyncMock, ) async def test_gemini_query_string_error( mock_gemini_call, prompt_dict_string, temporary_data_folders, monkeypatch, caplog @@ -164,10 +181,13 @@ async def test_gemini_query_string_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=prompt_dict_string["prompt"], - generation_config=prompt_dict_string["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + model="gemini_model_name", + contents="test prompt", + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( diff --git a/tests/apis/gemini/test_gemini_utils.py b/tests/apis/gemini/test_gemini_utils.py new file mode 100644 index 0000000..4d81059 --- /dev/null +++ b/tests/apis/gemini/test_gemini_utils.py @@ -0,0 +1,67 @@ +import pytest +from google.genai.types import ( + Candidate, + Content, + FinishReason, + GenerateContentResponse, + Part, +) + +from prompto.apis.gemini.gemini_utils import ( + convert_history_dict_to_content, + process_response, + process_thoughts, +) + +from .test_gemini import non_thinking_response, prompt_dict_history, thinking_response + + +def test_convert_history_dict_to_content(prompt_dict_history): + + media_folder = "media_folder" + mock_client = None + + expected_result_list = [ 
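+        # assumed mapping: each history dict becomes a Content with a single text Part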
+ Content(parts=[Part(text="test system prompt")], role="system"), + Content(parts=[Part(text="user message")], role="user"), + ] + + # There is no uploaded media in the prompt_dict_chat, hence the + # client is not required, and we can pass `None`. + prompt_list = prompt_dict_history["prompt"] + + for content_dict, expected_result in zip(prompt_list, expected_result_list): + actual_result = convert_history_dict_to_content( + content_dict, media_folder, mock_client + ) + assert ( + actual_result == expected_result + ), f"Expected {expected_result}, but got {actual_result}" + + +def test_process_response(thinking_response, non_thinking_response): + """ + Test the process_response function. + """ + # Each test case is a tuple of (response, expected_answer, expected_thoughts) + test_cases = [ + (non_thinking_response, "A spontaneous answer", []), + (thinking_response, "A thought out answer", ["Some thinking"]), + ] + + for response, expected_answer, expected_thoughts in test_cases: + # Call the function with the test case + actual_answer = process_response(response) + actual_thoughts = process_thoughts(response) + + # Assert the expected response + assert ( + actual_answer == expected_answer + ), f"Expected {expected_answer}, but got {actual_answer}" + + assert isinstance( + actual_thoughts, list + ), f"Expected a list, but got {type(actual_thoughts)}" + assert ( + actual_thoughts == expected_thoughts + ), f"Expected {expected_thoughts}, but got {actual_thoughts}"
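+
+
+# NOTE: `thinking_response` and `non_thinking_response` are fixtures defined in
+# tests/apis/gemini/test_gemini.py. Judging from the assertions above, they are
+# assumed to provide google.genai GenerateContentResponse objects shaped roughly
+# like the sketch below (illustrative only, not the actual fixture bodies):
+#
+#     GenerateContentResponse(
+#         candidates=[
+#             Candidate(
+#                 content=Content(
+#                     role="model",
+#                     parts=[
+#                         Part(text="Some thinking", thought=True),
+#                         Part(text="A thought out answer"),
+#                     ],
+#                 ),
+#                 finish_reason=FinishReason.STOP,
+#             )
+#         ]
+#     )
+#
+# with `non_thinking_response` holding a single plain text Part ("A spontaneous answer")
+# and no thought parts.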