From d9d6c65c6c8947afea0eacd7975b4ff8355c1a64 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Wed, 23 Apr 2025 13:35:37 +0100 Subject: [PATCH 01/21] WIP: update tests with new SDK --- pyproject.toml | 22 ++++-- tests/apis/gemini/test_gemini.py | 34 +++++++-- tests/apis/gemini/test_gemini_chat_input.py | 56 +++++++++++---- .../apis/gemini/test_gemini_history_input.py | 72 ++++++++++++++----- tests/apis/gemini/test_gemini_image_input.py | 33 +++++---- tests/apis/gemini/test_gemini_string_input.py | 38 +++++++--- 6 files changed, 192 insertions(+), 63 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4ea6d527..188915a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,8 @@ mkdocs-jupyter = { version = "^0.24.7", optional = true } cli-test-helpers = { version = "^4.0.0", optional = true } vertexai = { version ="^1.71.1", optional = true } google-cloud-aiplatform = { version = "^1.71.1", optional = true } -google-generativeai = { version = "^0.8.4", optional = true } -google-genai = { version = "^0.7.0", optional = true } +# google-generativeai = { version = "^0.8.4", optional = true } # TODO: deprecated - to be removed +google-genai = { version = "^1.11.0", optional = true } openai = { version = "^1.60.0", optional = true } pillow = { version = "^11.1.0", optional = true } ollama = { version = "^0.4.7", optional = true } @@ -66,7 +66,7 @@ all = [ "cli-test-helpers", "vertexai", "google-cloud-aiplatform", - "google-generativeai", + # "google-generativeai", "google-genai", "openai", "pillow", @@ -97,8 +97,20 @@ dev = [ "mkdocs-jupyter", "cli-test-helpers", ] -gemini = ["vertexai", "google-cloud-aiplatform", "google-generativeai", "google-genai", "pillow"] -vertexai = ["vertexai", "google-cloud-aiplatform", "google-generativeai", "google-genai", "pillow"] +gemini = [ + "vertexai", + "google-cloud-aiplatform", + # "google-generativeai", # TODO: deprecated - to be removed + "google-genai", + "pillow" +] +vertexai = [ + "vertexai", + "google-cloud-aiplatform", + "google-generativeai", + "google-genai", + "pillow" +] azure_openai = ["openai", "pillow"] openai = ["openai", "pillow"] ollama = ["ollama"] diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index 60846b3e..199eb613 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -2,12 +2,16 @@ import pytest import regex as re -from google.generativeai import GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory +from google.genai.client import AsyncClient +from google.genai.types import HarmBlockThreshold, HarmCategory from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings +# from google.generativeai import GenerativeModel +# from google.generativeai.types import HarmBlockThreshold, HarmCategory + + pytest_plugins = ("pytest_asyncio",) @@ -69,6 +73,7 @@ def prompt_dict_history_no_system(): HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, } + TYPE_ERROR_MSG = ( "if api == 'gemini', then the prompt must be a str, list[str], or " "list[dict[str,str]] where the dictionary contains the keys 'role' and " @@ -366,7 +371,11 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" - assert isinstance(test_case[2], GenerativeModel) + # TODO: For now assume that the most sensible thing for the `_obtain_model_inputs` to return + # here is the 
AsyncClient instance. It may be that returning nothing is the sensible thing to do.
+ # in which case we should update `assert len(test_case) == 4` and update the indexes.
+ # assert isinstance(test_case[2], GenerativeModel)
+ assert isinstance(test_case[2], AsyncClient)
assert test_case[2]._model_name == "models/gemini_model_name"
assert test_case[2]._system_instruction is None
assert isinstance(test_case[3], dict)
@@ -385,7 +394,11 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch):
assert len(test_case) == 5
assert test_case[0] == "test prompt"
assert test_case[1] == "gemini_model_name"
- assert isinstance(test_case[2], GenerativeModel)
+ # TODO: For now assume that the most sensible thing for the `_obtain_model_inputs` to return
+ # here is the AsyncClient instance. It may be that returning nothing is the sensible thing to do.
+ # in which case we should update `assert len(test_case) == 4` and update the indexes.
+ # assert isinstance(test_case[2], GenerativeModel)
+ assert isinstance(test_case[2], AsyncClient)
assert test_case[2]._model_name == "models/gemini_model_name"
assert test_case[2]._system_instruction is None
assert isinstance(test_case[3], dict)
@@ -405,7 +418,12 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch):
assert len(test_case) == 5
assert test_case[0] == "test prompt"
assert test_case[1] == "gemini_model_name"
- assert isinstance(test_case[2], GenerativeModel)
+
+ # TODO: For now assume that the most sensible thing for the `_obtain_model_inputs` to return
+ # here is the AsyncClient instance. It may be that returning nothing is the sensible thing to do.
+ # in which case we should update `assert len(test_case) == 4` and update the indexes.
+ # assert isinstance(test_case[2], GenerativeModel)
+ assert isinstance(test_case[2], AsyncClient)
assert test_case[2]._model_name == "models/gemini_model_name"
assert test_case[2]._system_instruction is not None
assert isinstance(test_case[3], dict)
@@ -472,7 +490,11 @@ async def test_gemini_obtain_model_inputs_safety_filters(
assert len(test_case) == 5
assert test_case[0] == "test prompt"
assert test_case[1] == "gemini_model_name"
- assert isinstance(test_case[2], GenerativeModel)
+ # TODO: For now assume that the most sensible thing for the `_obtain_model_inputs` to return
+ # here is the AsyncClient instance. It may be that returning nothing is the sensible thing to do.
+ # in which case we should update `assert len(test_case) == 4` and update the indexes.
+ # assert isinstance(test_case[2], GenerativeModel) + assert isinstance(test_case[2], AsyncClient) assert test_case[2]._model_name == "models/gemini_model_name" assert test_case[2]._system_instruction is None assert isinstance(test_case[3], dict) diff --git a/tests/apis/gemini/test_gemini_chat_input.py b/tests/apis/gemini/test_gemini_chat_input.py index b621c112..6139c4f8 100644 --- a/tests/apis/gemini/test_gemini_chat_input.py +++ b/tests/apis/gemini/test_gemini_chat_input.py @@ -2,7 +2,9 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from google.generativeai import GenerativeModel + +# from google.generativeai import GenerativeModel +from google.genai.chats import AsyncChats, Chat from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings @@ -34,8 +36,13 @@ async def test_gemini_query_chat_no_env_var( @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# @patch( +# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# ) +@patch.object( + Chat, + "send_message", + new_callable=CopyingAsyncMock, ) @patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) @@ -125,7 +132,12 @@ async def test_gemini_query_chat( @pytest.mark.asyncio -@patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) +# @patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) +@patch.object( + AsyncChats, + "create", + new_callable=AsyncMock, +) @patch( "prompto.apis.gemini.gemini.GeminiAPI._obtain_model_inputs", new_callable=AsyncMock ) @@ -165,8 +177,13 @@ async def test_gemini_query_history_check_chat_init( @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# @patch( +# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# ) +@patch.object( + Chat, + "send_message", + new_callable=CopyingAsyncMock, ) async def test_gemini_query_chat_index_error_1( mock_gemini_call, prompt_dict_chat, temporary_data_folders, monkeypatch, caplog @@ -212,8 +229,13 @@ async def test_gemini_query_chat_index_error_1( @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# @patch( +# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# ) +@patch.object( + Chat, + "send_message", + new_callable=CopyingAsyncMock, ) async def test_gemini_query_chat_error_1( mock_gemini_call, prompt_dict_chat, temporary_data_folders, monkeypatch, caplog @@ -251,8 +273,13 @@ async def test_gemini_query_chat_error_1( @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# @patch( +# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# ) +@patch.object( + Chat, + "send_message", + new_callable=CopyingAsyncMock, ) @patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) @@ -331,8 +358,13 @@ async def test_gemini_query_chat_index_error_2( @pytest.mark.asyncio -@patch( - "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# @patch( +# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock +# ) +@patch.object( + Chat, + "send_message", + 
new_callable=CopyingAsyncMock,
)
@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock)
@patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock)
diff --git a/tests/apis/gemini/test_gemini_history_input.py b/tests/apis/gemini/test_gemini_history_input.py
index 5b0f8fe5..f555e2a7 100644
--- a/tests/apis/gemini/test_gemini_history_input.py
+++ b/tests/apis/gemini/test_gemini_history_input.py
@@ -2,16 +2,16 @@
from unittest.mock import AsyncMock, Mock, patch
import pytest
-from google.generativeai import GenerativeModel
+
+# from google.generativeai import GenerativeModel
+from google.genai.chats import AsyncChats, Chat
from prompto.apis.gemini import GeminiAPI
from prompto.settings import Settings
-from .test_gemini import (
- DEFAULT_SAFETY_SETTINGS,
- prompt_dict_history,
- prompt_dict_history_no_system,
-)
+from .test_gemini import prompt_dict_history  # noqa: F401
+from .test_gemini import prompt_dict_history_no_system  # noqa: F401
+from .test_gemini import DEFAULT_SAFETY_SETTINGS
pytest_plugins = ("pytest_asyncio",)
@@ -37,7 +37,12 @@ async def test_gemini_query_history_no_env_var(
@pytest.mark.asyncio
-@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
+# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
+@patch.object(
+ Chat,
+ "send_message",
+ new_callable=AsyncMock,
+)
@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock)
@patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock)
async def test_gemini_query_history(
@@ -55,9 +60,9 @@ async def test_gemini_query_history(
monkeypatch.setenv("GEMINI_API_KEY", "DUMMY")
gemini_api = GeminiAPI(settings=settings, log_file=log_file)
- # mock the response from the API
+ # Mock the response from the API
# NOTE: The actual response from the API is a
- # google.generativeai.types.AsyncGenerateContentResponse object
+ # google.genai.types.GenerateContentResponse object
# not a string value, but for the purpose of this test, we are using a string value
# and testing that this is the input to the process_response function
mock_gemini_call.return_value = "response Messages object"
@@ -99,7 +104,12 @@ async def test_gemini_query_history(
@pytest.mark.asyncio
-@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
+# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
+@patch.object(
+ Chat,
+ "send_message",
+ new_callable=AsyncMock,
+)
async def test_gemini_query_history_error(
mock_gemini_call, prompt_dict_history, temporary_data_folders, monkeypatch, caplog
):
@@ -135,7 +145,12 @@ async def test_gemini_query_history_error(
@pytest.mark.asyncio
-@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
+# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
+@patch.object(
+ Chat,
+ "send_message",
+ new_callable=AsyncMock,
+)
async def test_gemini_query_history_index_error(
mock_gemini_call, prompt_dict_history, temporary_data_folders, monkeypatch, caplog
):
@@ -179,7 +194,12 @@ async def test_gemini_query_history_index_error(
@pytest.mark.asyncio
-@patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock)
+# @patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock)
+@patch.object(
+ AsyncChats,
+ "create",
+ new_callable=AsyncMock,
+)
@patch(
"prompto.apis.gemini.gemini.GeminiAPI._obtain_model_inputs", new_callable=AsyncMock
)
@@ -221,7
+241,12 @@ async def test_gemini_query_history_check_chat_init( @pytest.mark.asyncio -@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) +# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) +@patch.object( + Chat, + "send_message", + new_callable=AsyncMock, +) @patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_history_no_system( @@ -287,7 +312,12 @@ async def test_gemini_query_history_no_system( @pytest.mark.asyncio -@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) +# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) +@patch.object( + Chat, + "send_message", + new_callable=AsyncMock, +) async def test_gemini_query_history_error_no_system( mock_gemini_call, prompt_dict_history_no_system, @@ -330,7 +360,12 @@ async def test_gemini_query_history_error_no_system( @pytest.mark.asyncio -@patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) +# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) +@patch.object( + Chat, + "send_message", + new_callable=AsyncMock, +) async def test_gemini_query_history_index_error_no_system( mock_gemini_call, prompt_dict_history_no_system, @@ -383,7 +418,12 @@ async def test_gemini_query_history_index_error_no_system( @pytest.mark.asyncio -@patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) +# @patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) +@patch.object( + AsyncChats, + "create", + new_callable=AsyncMock, +) @patch( "prompto.apis.gemini.gemini.GeminiAPI._obtain_model_inputs", new_callable=AsyncMock ) diff --git a/tests/apis/gemini/test_gemini_image_input.py b/tests/apis/gemini/test_gemini_image_input.py index d25d1d09..989f5cbc 100644 --- a/tests/apis/gemini/test_gemini_image_input.py +++ b/tests/apis/gemini/test_gemini_image_input.py @@ -1,4 +1,5 @@ import os +from unittest.mock import patch import pytest from PIL import Image @@ -55,6 +56,7 @@ def test_parse_parts_value_video_not_uploaded(): assert "not uploaded" in str(excinfo.value) +# @patch("google.genai.client.files.get") def test_parse_parts_value_video_uploaded(monkeypatch): part = { "type": "video", @@ -65,23 +67,24 @@ def test_parse_parts_value_video_uploaded(monkeypatch): # parameter for the parse_parts_value function media_folder = "media" - # Mock the google.generativeai get_file function - # We don't want to call the real get_file function and it would be tricky to - # assert the binary data returned by the function. + # Mock the `google.genai.Client().files.get`` function + # The real `get` function returns the binary contents of the file + # which would be tricky to assert. 
# Instead, we will just return the uploaded_filename - def mock_get_file(name): - return name + mock_get_file_no_op = lambda name: name # Replace the original get_file function with the mock # ***It is important that the import statement used here is exactly the same as # the one in the gemini_utils.py file*** - import google.generativeai as genai - - monkeypatch.setattr(genai, "get_file", mock_get_file) - - # Assert that the mock function was called with the expected argument - assert genai.get_file("check mocked function") == "check mocked function" - - expected_result = "file/123456" - actual_result = parse_parts_value(part, media_folder) - assert actual_result == expected_result + import google.genai as genai + + with monkeypatch.context() as m: + # Mock the get_file function + client = genai.Client(api_key="DUMMY") + m.setattr(client.files, "get", mock_get_file_no_op) + # Assert that the mock function was called with the expected argument + assert client.files.get(name="check mocked function") == "check mocked function" + + expected_result = "file/123456" + actual_result = parse_parts_value(part, media_folder) + assert actual_result == expected_result diff --git a/tests/apis/gemini/test_gemini_string_input.py b/tests/apis/gemini/test_gemini_string_input.py index 0554f50c..26f9a05a 100644 --- a/tests/apis/gemini/test_gemini_string_input.py +++ b/tests/apis/gemini/test_gemini_string_input.py @@ -2,6 +2,8 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from google.genai.client import AsyncClient +from google.genai.models import AsyncModels from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings @@ -32,8 +34,13 @@ async def test_gemini_query_string_no_env_var( @pytest.mark.asyncio -@patch( - "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock +# @patch( +# "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock +# ) +@patch.object( + AsyncModels, + "generate_content", + new_callable=AsyncMock, ) @patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) @@ -52,11 +59,14 @@ async def test_gemini_query_string( monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) - # mock the response from the API - # NOTE: The actual response from the API is a - # google.generativeai.types.AsyncGenerateContentResponse object + # Mock the response from the API + # NOTE: The actual response from the API is a (probably) + # google.genai.types.GenerateContentResponse object # not a string value, but for the purpose of this test, we are using a string value # and testing that this is the input to the process_response function + # TODO: Check if there is a difference in the return type of + # `google.genai.client.aio.models.generate_content`` and + # `google.genai.client.models.generate_content` mock_gemini_call.return_value = "response Messages object" # mock the process_response function @@ -96,8 +106,13 @@ async def test_gemini_query_string( @pytest.mark.asyncio -@patch( - "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock +# @patch( +# "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock +# ) +@patch.object( + AsyncModels, + "generate_content", + new_callable=AsyncMock, ) async def test_gemini_query_string__index_error( mock_gemini_call, prompt_dict_string, temporary_data_folders, monkeypatch, 
caplog @@ -142,8 +157,13 @@ async def test_gemini_query_string__index_error( @pytest.mark.asyncio -@patch( - "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock +# @patch( +# "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock +# ) +@patch.object( + AsyncModels, + "generate_content", + new_callable=AsyncMock, ) async def test_gemini_query_string_error( mock_gemini_call, prompt_dict_string, temporary_data_folders, monkeypatch, caplog From 5de42f70ae333a28c4de72b075c3141797d28f86 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Thu, 24 Apr 2025 16:47:56 +0100 Subject: [PATCH 02/21] More updates to tests with new SDK --- tests/apis/gemini/test_gemini.py | 111 +++++++++++++++++++------------ 1 file changed, 68 insertions(+), 43 deletions(-) diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index 199eb613..65e5f047 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -2,8 +2,13 @@ import pytest import regex as re -from google.genai.client import AsyncClient -from google.genai.types import HarmBlockThreshold, HarmCategory +from google.genai.client import AsyncClient, Client +from google.genai.types import ( + GenerateContentConfig, + HarmBlockThreshold, + HarmCategory, + SafetySetting, +) from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings @@ -66,12 +71,24 @@ def prompt_dict_history_no_system(): } -DEFAULT_SAFETY_SETTINGS = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, -} +DEFAULT_SAFETY_SETTINGS = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), +] TYPE_ERROR_MSG = ( @@ -371,15 +388,14 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" - # TODO: For now assume that the most sensible thing for the `_obtain_model_inputs` to return - # here is the AsyncClient instance. It may be that returning nothing is the sensible thing to do. - # in which case we should update `assert len(test_case) == 4` and update the indexes. 
- # assert isinstance(test_case[2], GenerativeModel)
- assert isinstance(test_case[2], AsyncClient)
- assert test_case[2]._model_name == "models/gemini_model_name"
- assert test_case[2]._system_instruction is None
- assert isinstance(test_case[3], dict)
- assert test_case[4] == {"temperature": 1, "max_output_tokens": 100}
+ assert isinstance(test_case[2], Client)
+ assert isinstance(test_case[2].aio, AsyncClient)
+ assert isinstance(test_case[3], GenerateContentConfig)
+ assert test_case[3].system_instruction is None
+ assert test_case[3].temperature == 1
+ assert test_case[3].max_output_tokens == 100
+ assert test_case[3].safety_settings == DEFAULT_SAFETY_SETTINGS
+ assert test_case[4] is None
# test for case where no parameters in prompt_dict
test_case = await gemini_api._obtain_model_inputs(
@@ -398,11 +414,14 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch):
# TODO: For now assume that the most sensible thing for the `_obtain_model_inputs` to return
# here is the AsyncClient instance. It may be that returning nothing is the sensible thing to do.
# in which case we should update `assert len(test_case) == 4` and update the indexes.
# assert isinstance(test_case[2], GenerativeModel)
- assert isinstance(test_case[2], AsyncClient)
- assert test_case[2]._model_name == "models/gemini_model_name"
- assert test_case[2]._system_instruction is None
- assert isinstance(test_case[3], dict)
- assert test_case[4] == {}
+ assert isinstance(test_case[2], Client)
+ assert isinstance(test_case[2].aio, AsyncClient)
+ assert isinstance(test_case[3], GenerateContentConfig)
+ assert test_case[3].system_instruction is None
+ assert test_case[3].temperature is None
+ assert test_case[3].max_output_tokens is None
+ assert test_case[3].safety_settings == DEFAULT_SAFETY_SETTINGS
+ assert test_case[4] is None
# test for case where system_instruction is provided
test_case = await gemini_api._obtain_model_inputs(
@@ -418,16 +437,12 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch):
assert len(test_case) == 5
assert test_case[0] == "test prompt"
assert test_case[1] == "gemini_model_name"
-
- # TODO: For now assume that the most sensible thing for the `_obtain_model_inputs` to return
- # here is the AsyncClient instance. It may be that returning nothing is the sensible thing to do.
- # in which case we should update `assert len(test_case) == 4` and update the indexes.
- # assert isinstance(test_case[2], GenerativeModel)
- assert isinstance(test_case[2], AsyncClient)
- assert test_case[2]._model_name == "models/gemini_model_name"
- assert test_case[2]._system_instruction is not None
- assert isinstance(test_case[3], dict)
- assert test_case[4] == {}
+ assert isinstance(test_case[2], Client)
+ assert isinstance(test_case[2].aio, AsyncClient)
+ assert isinstance(test_case[3], GenerateContentConfig)
+ assert test_case[3].system_instruction is not None
+ assert test_case[3].safety_settings == DEFAULT_SAFETY_SETTINGS
+ assert test_case[4] is None
# test error catching when parameters are not a dictionary
with pytest.raises(
@@ -473,9 +488,15 @@ async def test_gemini_obtain_model_inputs_safety_filters(
monkeypatch.setenv("GEMINI_API_KEY", "DUMMY")
gemini_api = GeminiAPI(settings=settings, log_file=log_file)
- valid_safety_filter_choices = ["none", "few", "default", "some", "most"]
+ valid_safety_filter_choices = {
+ "none": HarmBlockThreshold.BLOCK_NONE,
+ "few": HarmBlockThreshold.BLOCK_ONLY_HIGH,
+ "default": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+ "some": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+ "most": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+ }
- for safety_filter in valid_safety_filter_choices:
+ for safety_filter, expected_threshold in valid_safety_filter_choices.items():
test_case = await gemini_api._obtain_model_inputs(
{
"id": "gemini_id",
@@ -490,15 +511,19 @@ async def test_gemini_obtain_model_inputs_safety_filters(
assert len(test_case) == 5
assert test_case[0] == "test prompt"
assert test_case[1] == "gemini_model_name"
- # TODO: For now assume that the most sensible thing for the `_obtain_model_inputs` to return
- # here is the AsyncClient instance. It may be that returning nothing is the sensible thing to do.
- # in which case we should update `assert len(test_case) == 4` and update the indexes.
- # assert isinstance(test_case[2], GenerativeModel) - assert isinstance(test_case[2], AsyncClient) - assert test_case[2]._model_name == "models/gemini_model_name" - assert test_case[2]._system_instruction is None - assert isinstance(test_case[3], dict) - assert test_case[4] == {"temperature": 1, "max_output_tokens": 100} + assert isinstance(test_case[2], Client) + assert isinstance(test_case[2].aio, AsyncClient) + assert isinstance(test_case[3], GenerateContentConfig) + assert test_case[3].system_instruction is None + assert test_case[3].temperature == 1 + assert test_case[3].max_output_tokens == 100 + assert test_case[4] is None + + # Assert that the safety settings contain the expected threshold for all categories + assert all( + safety_set.threshold == expected_threshold + for safety_set in test_case[3].safety_settings + ) # test error if safety filter is not recognised with pytest.raises( From 283a56447e1ffa9f6a255946e37c142cb4a870c9 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Thu, 24 Apr 2025 16:49:12 +0100 Subject: [PATCH 03/21] Update GeminiAPI class string & chat methods with new SDK --- src/prompto/apis/gemini/gemini.py | 216 ++++++++++++------ tests/apis/gemini/test_gemini_chat_input.py | 159 +++++++++---- tests/apis/gemini/test_gemini_string_input.py | 34 ++- 3 files changed, 288 insertions(+), 121 deletions(-) diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index b345f5f6..5b355e09 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -1,9 +1,17 @@ import logging from typing import Any -import google.generativeai as genai -from google.generativeai import GenerativeModel -from google.generativeai.types import GenerationConfig, HarmBlockThreshold, HarmCategory +from google.genai import Client + +# import google.generativeai as genai +# from google.generativeai import GenerativeModel +# from google.generativeai.types import GenerationConfig, HarmBlockThreshold, HarmCategory +from google.genai.types import ( + GenerateContentConfig, + HarmBlockThreshold, + HarmCategory, + SafetySetting, +) from prompto.apis.base import AsyncAPI from prompto.apis.gemini.gemini_utils import ( @@ -54,6 +62,8 @@ class GeminiAPI(AsyncAPI): The path to the log file """ + _clients: dict[str, Client] = {} + def __init__( self, settings: Settings, @@ -177,9 +187,37 @@ def check_prompt_dict(prompt_dict: dict) -> list[Exception]: return issues + def _get_client(self, model_name) -> Client: + """ + Method to get the client for the Gemini API. A separate client is created for each model name, to allow for + model-specific API keys to be used. + + The client is created only once per model name and stored in the clients dictionary. + If the client is already created, it is returned from the dictionary. 
+ + Parameters + ---------- + model_name : str + The name of the model to use + + Returns + ------- + Client + A client for the Gemini API + """ + # If the Client does not exist, create it + if model_name not in self._clients: + api_key = get_environment_variable( + env_variable=API_KEY_VAR_NAME, model_name=model_name + ) + self._clients[model_name] = Client(api_key=api_key) + + # Return the client for the model name + return self._clients[model_name] + async def _obtain_model_inputs( self, prompt_dict: dict, system_instruction: str | None = None - ) -> tuple[str, str, GenerativeModel, dict, dict, list | None]: + ) -> tuple[str, str, Client, GenerateContentConfig, list | None]: """ Async method to obtain the model inputs from the prompt dictionary. @@ -193,26 +231,19 @@ async def _obtain_model_inputs( Returns ------- - tuple[str, str, dict, dict, list | None] - A tuple containing the prompt, model name, GenerativeModel instance, - safety settings, the generation config, and list of multimedia parts - (if passed) to use for querying the model + tuple[str, str, Client, GenerateContentConfig, list | None] + A tuple containing: + - the prompt, + - model name, + - Client instance, + - GenerateContentConfig instance (which incorporates the safety settings), + - (optional) list of multimedia parts (if passed) to use for querying the model or None """ prompt = prompt_dict["prompt"] # obtain model name model_name = prompt_dict["model_name"] - api_key = get_environment_variable( - env_variable=API_KEY_VAR_NAME, model_name=model_name - ) - - # configure the API key - genai.configure(api_key=api_key) - - # create the model instance - model = GenerativeModel( - model_name=model_name, system_instruction=system_instruction - ) + client = self._get_client(model_name) # define safety settings safety_filter = prompt_dict.get("safety_filter", None) @@ -221,33 +252,82 @@ async def _obtain_model_inputs( # explicitly set the safety settings if safety_filter == "none": - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, - } + SafetySetting() + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + ] elif safety_filter == "few": - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + 
threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + ] elif safety_filter in ["default", "some"]: - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + ] elif safety_filter == "most": - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + ] else: raise ValueError( f"safety_filter '{safety_filter}' not recognised. Must be one of: " @@ -255,15 +335,21 @@ async def _obtain_model_inputs( ) # get parameters dict (if any) - generation_config = prompt_dict.get("parameters", None) - if generation_config is None: - generation_config = {} - if type(generation_config) is not dict: + generation_config_params = prompt_dict.get("parameters", None) + if generation_config_params is None: + generation_config_params = {} + if type(generation_config_params) is not dict: raise TypeError( - f"parameters must be a dictionary, not {type(generation_config)}" + f"parameters must be a dictionary, not {type(generation_config_params)}" ) - return prompt, model_name, model, safety_settings, generation_config + gen_content_config = GenerateContentConfig( + **generation_config_params, + safety_settings=safety_settings, + system_instruction=system_instruction, + ) + + return prompt, model_name, client, gen_content_config, None async def _query_string(self, prompt_dict: dict, index: int | str): """ @@ -271,18 +357,17 @@ async def _query_string(self, prompt_dict: dict, index: int | str): (prompt_dict["prompt"] is a string), i.e. single-turn completion or chat. 
""" - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) try: - response = await model.generate_content_async( + response = await client.aio.models.generate_content( + model=model_name, contents=prompt, generation_config=generation_config, - safety_settings=safety_settings, - stream=False, ) response_text = process_response(response) safety_attributes = process_safety_attributes(response) @@ -352,24 +437,27 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): (prompt_dict["prompt"] is a list of strings to sequentially send to the model), i.e. multi-turn chat with history. """ - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) - chat = model.start_chat(history=[]) + # chat = client.start_chat(history=[]) + chat = client.aio.chats.create( + model=model_name, + config=generation_config, + history=[], + ) response_list = [] safety_attributes_list = [] try: for message_index, message in enumerate(prompt): # send the messages sequentially # run the predict method in a separate thread using run_in_executor - response = await chat.send_message_async( - content=message, - generation_config=generation_config, - safety_settings=safety_settings, - stream=False, + response = await chat.send_message( + message=message, + config=generation_config, ) response_text = process_response(response) safety_attributes = process_safety_attributes(response) @@ -455,13 +543,13 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: i.e. multi-turn chat with history. 
""" if prompt_dict["prompt"][0]["role"] == "system": - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, safety_settings, generation_config = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=prompt_dict["prompt"][0]["parts"], ) ) - chat = model.start_chat( + chat = client.start_chat( history=[ convert_dict_to_input( content_dict=x, media_folder=self.settings.media_folder @@ -470,12 +558,12 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: ] ) else: - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, safety_settings, generation_config = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) - chat = model.start_chat( + chat = client.start_chat( history=[ convert_dict_to_input( content_dict=x, media_folder=self.settings.media_folder diff --git a/tests/apis/gemini/test_gemini_chat_input.py b/tests/apis/gemini/test_gemini_chat_input.py index 6139c4f8..d16d05c5 100644 --- a/tests/apis/gemini/test_gemini_chat_input.py +++ b/tests/apis/gemini/test_gemini_chat_input.py @@ -4,7 +4,8 @@ import pytest # from google.generativeai import GenerativeModel -from google.genai.chats import AsyncChats, Chat +from google.genai.chats import AsyncChat, AsyncChats +from google.genai.types import GenerateContentConfig from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings @@ -40,7 +41,7 @@ async def test_gemini_query_chat_no_env_var( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -63,7 +64,7 @@ async def test_gemini_query_chat( # mock the response from the API # NOTE: The actual response from the API is a - # google.generativeai.types.AsyncGenerateContentResponse object + # google.genai.types.GenerateContentResponse object # not a string value, but for the purpose of this test, we are using a string value # and testing that this is the input to the process_response function gemini_api_sequence_responses = [ @@ -88,16 +89,20 @@ async def test_gemini_query_chat( assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) mock_gemini_call.assert_awaited_with( - content=prompt_dict_chat["prompt"][1], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][1], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) assert mock_process_response.call_count == 2 @@ -143,7 +148,7 @@ async def test_gemini_query_chat( ) async def test_gemini_query_history_check_chat_init( mock_obtain_model_inputs, - mock_start_chat, + mock_chat_create, prompt_dict_chat, temporary_data_folders, monkeypatch, @@ -158,14 +163,16 @@ async def test_gemini_query_history_check_chat_init( mock_obtain_model_inputs.return_value = ( prompt_dict_chat["prompt"], prompt_dict_chat["model_name"], - GenerativeModel( - model_name=prompt_dict_chat["model_name"], system_instruction=None + 
gemini_api._get_client("gemini_model_name"), + GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, ), - DEFAULT_SAFETY_SETTINGS, prompt_dict_chat["parameters"], ) - # error will be raised as we've mocked the start_chat method + # error will be raised as we've mocked the Chats.create method # which leads to an error when the method is called on the mocked object with pytest.raises(Exception): await gemini_api._query_chat(prompt_dict_chat, index=0) @@ -173,7 +180,15 @@ async def test_gemini_query_history_check_chat_init( mock_obtain_model_inputs.assert_called_once_with( prompt_dict=prompt_dict_chat, system_instruction=None ) - mock_start_chat.assert_called_once_with(history=[]) + mock_chat_create.assert_called_once_with( + model="gemini_model_name", + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), + history=[], + ) @pytest.mark.asyncio @@ -181,7 +196,7 @@ async def test_gemini_query_history_check_chat_init( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -208,11 +223,19 @@ async def test_gemini_query_chat_index_error_1( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + # mock_gemini_call.assert_any_await( + # content=prompt_dict_chat["prompt"][0], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( @@ -233,7 +256,7 @@ async def test_gemini_query_chat_index_error_1( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -255,11 +278,19 @@ async def test_gemini_query_chat_error_1( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + # mock_gemini_call.assert_any_await( + # content=prompt_dict_chat["prompt"][0], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( @@ -277,7 +308,7 @@ async def test_gemini_query_chat_error_1( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -320,17 +351,36 @@ async def test_gemini_query_chat_index_error_2( assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 + # mock_gemini_call.assert_any_await( + # content=prompt_dict_chat["prompt"][0], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + 
mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) + + # mock_gemini_call.assert_awaited_with( + # content=prompt_dict_chat["prompt"][1], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_awaited_with( - content=prompt_dict_chat["prompt"][1], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][1], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) mock_process_response.assert_called_once_with(gemini_api_sequence_responses[0]) @@ -362,7 +412,7 @@ async def test_gemini_query_chat_index_error_2( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -399,17 +449,36 @@ async def test_gemini_query_chat_error_2( assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 + # mock_gemini_call.assert_any_await( + # content=prompt_dict_chat["prompt"][0], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) + + # mock_gemini_call.assert_awaited_with( + # content=prompt_dict_chat["prompt"][1], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_awaited_with( - content=prompt_dict_chat["prompt"][1], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][1], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) mock_process_response.assert_called_once_with(gemini_api_sequence_responses[0]) diff --git a/tests/apis/gemini/test_gemini_string_input.py b/tests/apis/gemini/test_gemini_string_input.py index 26f9a05a..6aa6b962 100644 --- a/tests/apis/gemini/test_gemini_string_input.py +++ b/tests/apis/gemini/test_gemini_string_input.py @@ -4,6 +4,7 @@ import pytest from google.genai.client import AsyncClient from google.genai.models import AsyncModels +from google.genai.types import GenerateContentConfig from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings @@ -84,10 +85,13 @@ async def test_gemini_query_string( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=prompt_dict_string["prompt"], - generation_config=prompt_dict_string["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + model="gemini_model_name", + contents="test prompt", + generation_config=GenerateContentConfig( + temperature=1.0, + 
max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) mock_process_response.assert_called_once_with(mock_gemini_call.return_value) @@ -138,10 +142,13 @@ async def test_gemini_query_string__index_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=prompt_dict_string["prompt"], - generation_config=prompt_dict_string["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + model="gemini_model_name", + contents="test prompt", + generation_config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( @@ -184,10 +191,13 @@ async def test_gemini_query_string_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=prompt_dict_string["prompt"], - generation_config=prompt_dict_string["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + model="gemini_model_name", + contents="test prompt", + generation_config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( From ef9664535f7513f060fd26d75da872a4978ac29f Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Thu, 24 Apr 2025 16:49:12 +0100 Subject: [PATCH 04/21] Update GeminiAPI class string & chat methods with new SDK --- src/prompto/apis/gemini/gemini.py | 282 ++++++++++++------ src/prompto/apis/gemini/gemini_utils.py | 46 ++- tests/apis/gemini/test_gemini.py | 4 - tests/apis/gemini/test_gemini_chat_input.py | 162 +++++++--- .../apis/gemini/test_gemini_history_input.py | 181 +++++++---- tests/apis/gemini/test_gemini_image_input.py | 47 ++- tests/apis/gemini/test_gemini_string_input.py | 34 ++- tests/apis/gemini/test_gemini_utils.py | 39 +++ 8 files changed, 560 insertions(+), 235 deletions(-) create mode 100644 tests/apis/gemini/test_gemini_utils.py diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index b345f5f6..4819713b 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -1,14 +1,23 @@ import logging from typing import Any -import google.generativeai as genai -from google.generativeai import GenerativeModel -from google.generativeai.types import GenerationConfig, HarmBlockThreshold, HarmCategory +from google.genai import Client + +# import google.generativeai as genai +# from google.generativeai import GenerativeModel +# from google.generativeai.types import GenerationConfig, HarmBlockThreshold, HarmCategory +from google.genai.types import ( + GenerateContentConfig, + HarmBlockThreshold, + HarmCategory, + SafetySetting, +) from prompto.apis.base import AsyncAPI from prompto.apis.gemini.gemini_utils import ( - convert_dict_to_input, + convert_history_dict_to_content, gemini_chat_roles, + parse_parts, process_response, process_safety_attributes, ) @@ -54,6 +63,8 @@ class GeminiAPI(AsyncAPI): The path to the log file """ + _clients: dict[str, Client] = {} + def __init__( self, settings: Settings, @@ -177,9 +188,37 @@ def check_prompt_dict(prompt_dict: dict) -> list[Exception]: return issues + def _get_client(self, model_name) -> Client: + """ + Method to get the client for the Gemini API. A separate client is created for each model name, to allow for + model-specific API keys to be used. + + The client is created only once per model name and stored in the clients dictionary. 
+ If the client is already created, it is returned from the dictionary. + + Parameters + ---------- + model_name : str + The name of the model to use + + Returns + ------- + Client + A client for the Gemini API + """ + # If the Client does not exist, create it + if model_name not in self._clients: + api_key = get_environment_variable( + env_variable=API_KEY_VAR_NAME, model_name=model_name + ) + self._clients[model_name] = Client(api_key=api_key) + + # Return the client for the model name + return self._clients[model_name] + async def _obtain_model_inputs( self, prompt_dict: dict, system_instruction: str | None = None - ) -> tuple[str, str, GenerativeModel, dict, dict, list | None]: + ) -> tuple[str, str, Client, GenerateContentConfig, list | None]: """ Async method to obtain the model inputs from the prompt dictionary. @@ -193,26 +232,19 @@ async def _obtain_model_inputs( Returns ------- - tuple[str, str, dict, dict, list | None] - A tuple containing the prompt, model name, GenerativeModel instance, - safety settings, the generation config, and list of multimedia parts - (if passed) to use for querying the model + tuple[str, str, Client, GenerateContentConfig, list | None] + A tuple containing: + - the prompt, + - model name, + - Client instance, + - GenerateContentConfig instance (which incorporates the safety settings), + - (optional) list of multimedia parts (if passed) to use for querying the model or None """ prompt = prompt_dict["prompt"] # obtain model name model_name = prompt_dict["model_name"] - api_key = get_environment_variable( - env_variable=API_KEY_VAR_NAME, model_name=model_name - ) - - # configure the API key - genai.configure(api_key=api_key) - - # create the model instance - model = GenerativeModel( - model_name=model_name, system_instruction=system_instruction - ) + client = self._get_client(model_name) # define safety settings safety_filter = prompt_dict.get("safety_filter", None) @@ -221,33 +253,81 @@ async def _obtain_model_inputs( # explicitly set the safety settings if safety_filter == "none": - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_NONE, + ), + ] elif safety_filter == "few": - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + SafetySetting( + 
category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH, + ), + ] elif safety_filter in ["default", "some"]: - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + ), + ] elif safety_filter == "most": - safety_settings = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - } + safety_settings = [ + SafetySetting( + category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + SafetySetting( + category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ), + ] else: raise ValueError( f"safety_filter '{safety_filter}' not recognised. Must be one of: " @@ -255,15 +335,21 @@ async def _obtain_model_inputs( ) # get parameters dict (if any) - generation_config = prompt_dict.get("parameters", None) - if generation_config is None: - generation_config = {} - if type(generation_config) is not dict: + generation_config_params = prompt_dict.get("parameters", None) + if generation_config_params is None: + generation_config_params = {} + if type(generation_config_params) is not dict: raise TypeError( - f"parameters must be a dictionary, not {type(generation_config)}" + f"parameters must be a dictionary, not {type(generation_config_params)}" ) - return prompt, model_name, model, safety_settings, generation_config + gen_content_config = GenerateContentConfig( + **generation_config_params, + safety_settings=safety_settings, + system_instruction=system_instruction, + ) + + return prompt, model_name, client, gen_content_config, None async def _query_string(self, prompt_dict: dict, index: int | str): """ @@ -271,18 +357,17 @@ async def _query_string(self, prompt_dict: dict, index: int | str): (prompt_dict["prompt"] is a string), i.e. single-turn completion or chat. 
""" - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) try: - response = await model.generate_content_async( + response = await client.aio.models.generate_content( + model=model_name, contents=prompt, - generation_config=generation_config, - safety_settings=safety_settings, - stream=False, + config=generation_config, ) response_text = process_response(response) safety_attributes = process_safety_attributes(response) @@ -352,24 +437,27 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): (prompt_dict["prompt"] is a list of strings to sequentially send to the model), i.e. multi-turn chat with history. """ - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) - chat = model.start_chat(history=[]) + # chat = client.start_chat(history=[]) + chat = client.aio.chats.create( + model=model_name, + config=generation_config, + history=[], + ) response_list = [] safety_attributes_list = [] try: for message_index, message in enumerate(prompt): # send the messages sequentially # run the predict method in a separate thread using run_in_executor - response = await chat.send_message_async( - content=message, - generation_config=generation_config, - safety_settings=safety_settings, - stream=False, + response = await chat.send_message( + message=message, + config=generation_config, ) response_text = process_response(response) safety_attributes = process_safety_attributes(response) @@ -455,43 +543,63 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: i.e. multi-turn chat with history. 
""" if prompt_dict["prompt"][0]["role"] == "system": - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=prompt_dict["prompt"][0]["parts"], ) ) - chat = model.start_chat( - history=[ - convert_dict_to_input( - content_dict=x, media_folder=self.settings.media_folder - ) - for x in prompt[1:-1] - ] - ) + # Used to skip the system message in the prompt history + first_user_idx = 1 else: - prompt, model_name, model, safety_settings, generation_config = ( + prompt, model_name, client, generation_config, _ = ( await self._obtain_model_inputs( prompt_dict=prompt_dict, system_instruction=None ) ) - chat = model.start_chat( - history=[ - convert_dict_to_input( - content_dict=x, media_folder=self.settings.media_folder - ) - for x in prompt[:-1] - ] - ) + first_user_idx = 0 + + chat = client.aio.chats.create( + model=model_name, + config=generation_config, + history=[ + convert_history_dict_to_content( + content_dict=x, + media_folder=self.settings.media_folder, + client=client, + ) + for x in prompt[first_user_idx:-1] + ], + ) try: - response = await chat.send_message_async( - content=convert_dict_to_input( - content_dict=prompt[-1], media_folder=self.settings.media_folder - ), - generation_config=generation_config, - safety_settings=safety_settings, - stream=False, + # No need to send the generation_config again, as it is no different + # from the one used to create the chat + last_msg = prompt[-1] + print(f"whole prompt: {prompt}") + print(f"last_msg: {last_msg}") + # msg_to_send = convert_dict_to_input( + # content_dict=prompt[-1], media_folder=self.settings.media_folder + # ) + + msg_to_send = parse_parts( + prompt[-1]["parts"], + media_folder=self.settings.media_folder, + client=client, + ) + + assert ( + len(msg_to_send) == 1 + ), "Only one message is allowed in the last message" + msg_to_send = msg_to_send[0] + + print(f"msg_to_send: {msg_to_send}") + + response = await chat.send_message( + # message=convert_dict_to_input( + # content_dict=prompt[-1], media_folder=self.settings.media_folder + # ), + message=msg_to_send ) response_text = process_response(response) diff --git a/src/prompto/apis/gemini/gemini_utils.py b/src/prompto/apis/gemini/gemini_utils.py index c9a3c17a..0acb102f 100644 --- a/src/prompto/apis/gemini/gemini_utils.py +++ b/src/prompto/apis/gemini/gemini_utils.py @@ -1,12 +1,13 @@ import os -import google.generativeai as genai import PIL.Image +from google.genai import types +from google.genai.client import Client gemini_chat_roles = set(["user", "model"]) -def parse_parts_value(part: dict | str, media_folder: str) -> any: +def parse_parts_value(part: dict | str, media_folder: str, client: Client) -> any: """ Parse part dictionary and create a dictionary input for Gemini API. If part is a string, a dictionary to represent a text object is returned. 
@@ -27,7 +28,9 @@ def parse_parts_value(part: dict | str, media_folder: str) -> any: Multimedia data object """ if isinstance(part, str): - return part + # return part + print(f"Part is a string: {part}") + return types.Part.from_text(text=part) # read multimedia type media_type = part.get("type") @@ -54,14 +57,17 @@ def parse_parts_value(part: dict | str, media_folder: str) -> any: ) else: try: - return genai.get_file(name=uploaded_filename) + # return genai.get_file(name=uploaded_filename) + return client.aio.files.get(name=uploaded_filename) except Exception as err: raise ValueError( f"Failed to get file: {media} due to error: {type(err).__name__} - {err}" ) -def parse_parts(parts: list[dict | str] | dict | str, media_folder: str) -> list[any]: +def parse_parts( + parts: list[dict | str] | dict | str, media_folder: str, client: Client +) -> list[any]: """ Parse parts data and create a list of multimedia data objects. If parts is a single dictionary, a list with a single multimedia data object is returned. @@ -83,10 +89,14 @@ def parse_parts(parts: list[dict | str] | dict | str, media_folder: str) -> list if isinstance(parts, dict) or isinstance(parts, str): parts = [parts] - return [parse_parts_value(p, media_folder=media_folder) for p in parts] + return [ + parse_parts_value(p, media_folder=media_folder, client=client) for p in parts + ] -def convert_dict_to_input(content_dict: dict, media_folder: str) -> dict: +def convert_history_dict_to_content( + content_dict: dict, media_folder: str, client: Client +) -> types.Content: """ Convert dictionary to an input that can be used by the Gemini API. The output is a dictionary with keys "role" and "parts". @@ -112,13 +122,27 @@ def convert_dict_to_input(content_dict: dict, media_folder: str) -> dict: if "parts" not in content_dict: raise KeyError("parts key is missing in content dictionary") - return { - "role": content_dict["role"], - "parts": parse_parts( + # return parse_parts( + # content_dict["parts"], + # media_folder=media_folder, + # ) + + return types.Content( + role=content_dict["role"], + parts=parse_parts( content_dict["parts"], media_folder=media_folder, + client=client, ), - } + ) + + # return { + # "role": content_dict["role"], + # "parts": parse_parts( + # content_dict["parts"], + # media_folder=media_folder, + # ) + # } def process_response(response: dict) -> str: diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index 65e5f047..3d5454ac 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -410,10 +410,6 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): assert len(test_case) == 5 assert test_case[0] == "test prompt" assert test_case[1] == "gemini_model_name" - # TODO: For now assume that the most sensible thing for the `_obtain_model_inputs` tp return - # here is the AsyncClient instance. It may be that retuning nothing is the sensible thing to do. - # in which case we should update `assert len(test_case) == 4` and update the indexes. 
- # assert isinstance(test_case[2], GenerativeModel) assert isinstance(test_case[2], Client) assert isinstance(test_case[2].aio, AsyncClient) assert isinstance(test_case[3], GenerateContentConfig) diff --git a/tests/apis/gemini/test_gemini_chat_input.py b/tests/apis/gemini/test_gemini_chat_input.py index 6139c4f8..76664da7 100644 --- a/tests/apis/gemini/test_gemini_chat_input.py +++ b/tests/apis/gemini/test_gemini_chat_input.py @@ -4,7 +4,8 @@ import pytest # from google.generativeai import GenerativeModel -from google.genai.chats import AsyncChats, Chat +from google.genai.chats import AsyncChat, AsyncChats +from google.genai.types import GenerateContentConfig from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings @@ -15,6 +16,9 @@ pytest_plugins = ("pytest_asyncio",) +# TODO: FIX THIS. This test passes when executed alone, but fails when executed with all tests +# This is probably due to the environment variable being monkeypatched somewhere without being +# reset / properly scoped. @pytest.mark.asyncio async def test_gemini_query_chat_no_env_var( prompt_dict_chat, temporary_data_folders, caplog @@ -40,7 +44,7 @@ async def test_gemini_query_chat_no_env_var( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -63,7 +67,7 @@ async def test_gemini_query_chat( # mock the response from the API # NOTE: The actual response from the API is a - # google.generativeai.types.AsyncGenerateContentResponse object + # google.genai.types.GenerateContentResponse object # not a string value, but for the purpose of this test, we are using a string value # and testing that this is the input to the process_response function gemini_api_sequence_responses = [ @@ -88,16 +92,20 @@ async def test_gemini_query_chat( assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) mock_gemini_call.assert_awaited_with( - content=prompt_dict_chat["prompt"][1], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][1], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) assert mock_process_response.call_count == 2 @@ -143,7 +151,7 @@ async def test_gemini_query_chat( ) async def test_gemini_query_history_check_chat_init( mock_obtain_model_inputs, - mock_start_chat, + mock_chat_create, prompt_dict_chat, temporary_data_folders, monkeypatch, @@ -158,14 +166,16 @@ async def test_gemini_query_history_check_chat_init( mock_obtain_model_inputs.return_value = ( prompt_dict_chat["prompt"], prompt_dict_chat["model_name"], - GenerativeModel( - model_name=prompt_dict_chat["model_name"], system_instruction=None + gemini_api._get_client("gemini_model_name"), + GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, ), - DEFAULT_SAFETY_SETTINGS, prompt_dict_chat["parameters"], ) - # error will be raised as we've mocked the start_chat method + # error will be raised as we've mocked the Chats.create method # 
which leads to an error when the method is called on the mocked object with pytest.raises(Exception): await gemini_api._query_chat(prompt_dict_chat, index=0) @@ -173,7 +183,15 @@ async def test_gemini_query_history_check_chat_init( mock_obtain_model_inputs.assert_called_once_with( prompt_dict=prompt_dict_chat, system_instruction=None ) - mock_start_chat.assert_called_once_with(history=[]) + mock_chat_create.assert_called_once_with( + model="gemini_model_name", + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), + history=[], + ) @pytest.mark.asyncio @@ -181,7 +199,7 @@ async def test_gemini_query_history_check_chat_init( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -208,11 +226,19 @@ async def test_gemini_query_chat_index_error_1( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + # mock_gemini_call.assert_any_await( + # content=prompt_dict_chat["prompt"][0], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( @@ -233,7 +259,7 @@ async def test_gemini_query_chat_index_error_1( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -255,11 +281,19 @@ async def test_gemini_query_chat_error_1( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + # mock_gemini_call.assert_any_await( + # content=prompt_dict_chat["prompt"][0], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( @@ -277,7 +311,7 @@ async def test_gemini_query_chat_error_1( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -320,17 +354,36 @@ async def test_gemini_query_chat_index_error_2( assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 + # mock_gemini_call.assert_any_await( + # content=prompt_dict_chat["prompt"][0], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) + + # 
mock_gemini_call.assert_awaited_with( + # content=prompt_dict_chat["prompt"][1], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_awaited_with( - content=prompt_dict_chat["prompt"][1], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][1], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) mock_process_response.assert_called_once_with(gemini_api_sequence_responses[0]) @@ -362,7 +415,7 @@ async def test_gemini_query_chat_index_error_2( # "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock # ) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=CopyingAsyncMock, ) @@ -399,17 +452,36 @@ async def test_gemini_query_chat_error_2( assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 + # mock_gemini_call.assert_any_await( + # content=prompt_dict_chat["prompt"][0], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_any_await( - content=prompt_dict_chat["prompt"][0], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][0], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) + + # mock_gemini_call.assert_awaited_with( + # content=prompt_dict_chat["prompt"][1], + # generation_config=prompt_dict_chat["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_awaited_with( - content=prompt_dict_chat["prompt"][1], - generation_config=prompt_dict_chat["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=prompt_dict_chat["prompt"][1], + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) mock_process_response.assert_called_once_with(gemini_api_sequence_responses[0]) diff --git a/tests/apis/gemini/test_gemini_history_input.py b/tests/apis/gemini/test_gemini_history_input.py index f555e2a7..748af1d2 100644 --- a/tests/apis/gemini/test_gemini_history_input.py +++ b/tests/apis/gemini/test_gemini_history_input.py @@ -4,7 +4,8 @@ import pytest # from google.generativeai import GenerativeModel -from google.genai.chats import AsyncChats, Chat +from google.genai.chats import AsyncChat, AsyncChats, Chat +from google.genai.types import Content, GenerateContentConfig, Part from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings @@ -39,7 +40,7 @@ async def test_gemini_query_history_no_env_var( @pytest.mark.asyncio # @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=AsyncMock, ) @@ -81,11 +82,15 @@ async def test_gemini_query_history( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + # mock_gemini_call.assert_awaited_once_with( + # content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]}, + # generation_config=prompt_dict_history["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_awaited_once_with( - content={"role": "user", "parts": 
[prompt_dict_history["prompt"][1]["parts"]]},
-        generation_config=prompt_dict_history["parameters"],
-        safety_settings=DEFAULT_SAFETY_SETTINGS,
-        stream=False,
+        message=Part(text=prompt_dict_history["prompt"][1]["parts"]),
     )
 
     mock_process_response.assert_called_once_with(mock_gemini_call.return_value)
@@ -106,7 +111,7 @@ async def test_gemini_query_history(
 @pytest.mark.asyncio
 # @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
 @patch.object(
-    Chat,
+    AsyncChat,
     "send_message",
     new_callable=AsyncMock,
 )
@@ -128,11 +133,15 @@ async def test_gemini_query_history_error(
     mock_gemini_call.assert_called_once()
     mock_gemini_call.assert_awaited_once()
 
+    # mock_gemini_call.assert_awaited_once_with(
+    #     content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]},
+    #     generation_config=prompt_dict_history["parameters"],
+    #     safety_settings=DEFAULT_SAFETY_SETTINGS,
+    #     stream=False,
+    # )
+
     mock_gemini_call.assert_awaited_once_with(
-        content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]},
-        generation_config=prompt_dict_history["parameters"],
-        safety_settings=DEFAULT_SAFETY_SETTINGS,
-        stream=False,
+        message=Part(text=prompt_dict_history["prompt"][1]["parts"]),
     )
 
     expected_log_message = (
@@ -147,7 +156,7 @@ async def test_gemini_query_history_error(
 @pytest.mark.asyncio
 # @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock)
 @patch.object(
-    Chat,
+    AsyncChat,
     "send_message",
     new_callable=AsyncMock,
 )
@@ -174,11 +183,15 @@ async def test_gemini_query_history_index_error(
     mock_gemini_call.assert_called_once()
     mock_gemini_call.assert_awaited_once()
 
+    # mock_gemini_call.assert_awaited_once_with(
+    #     content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]},
+    #     generation_config=prompt_dict_history["parameters"],
+    #     safety_settings=DEFAULT_SAFETY_SETTINGS,
+    #     stream=False,
+    # )
+
     mock_gemini_call.assert_awaited_once_with(
-        content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]},
-        generation_config=prompt_dict_history["parameters"],
-        safety_settings=DEFAULT_SAFETY_SETTINGS,
-        stream=False,
+        message=Part(text=prompt_dict_history["prompt"][1]["parts"]),
     )
 
     expected_log_message = (
@@ -217,14 +230,30 @@ async def test_gemini_query_history_check_chat_init(
     monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY")
     gemini_api = GeminiAPI(settings=settings, log_file=log_file)
 
+    # mock_obtain_model_inputs.return_value = (
+    #     prompt_dict_history["prompt"],
+    #     prompt_dict_history["model_name"],
+    #     GenerativeModel(
+    #         model_name=prompt_dict_history["model_name"],
+    #         system_instruction=prompt_dict_history["prompt"][0]["parts"],
+    #     ),
+    #     DEFAULT_SAFETY_SETTINGS,
+    #     prompt_dict_history["parameters"],
+    # )
+
+    mock_generate_content_config = (
+        GenerateContentConfig(
+            temperature=1.0,
+            max_output_tokens=100,
+            safety_settings=DEFAULT_SAFETY_SETTINGS,
+        )
+    )
+
     mock_obtain_model_inputs.return_value = (
         prompt_dict_history["prompt"],
         prompt_dict_history["model_name"],
-        GenerativeModel(
-            model_name=prompt_dict_history["model_name"],
-            system_instruction=prompt_dict_history["prompt"][0]["parts"],
-        ),
-        DEFAULT_SAFETY_SETTINGS,
+        gemini_api._get_client("gemini_model_name"),
+        mock_generate_content_config,
         prompt_dict_history["parameters"],
     )
 
@@ -237,13 +266,15 @@ async def test_gemini_query_history_check_chat_init(
         prompt_dict=prompt_dict_history,
         system_instruction=prompt_dict_history["prompt"][0]["parts"],
     )
-    
mock_start_chat.assert_called_once_with(history=[]) + mock_start_chat.assert_called_once_with( + model="gemini_model_name", config=mock_generate_content_config, history=[] + ) @pytest.mark.asyncio # @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=AsyncMock, ) @@ -286,14 +317,18 @@ async def test_gemini_query_history_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + # mock_gemini_call.assert_awaited_once_with( + # content={ + # "role": "user", + # "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], + # }, + # generation_config=prompt_dict_history_no_system["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_awaited_once_with( - content={ - "role": "user", - "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], - }, - generation_config=prompt_dict_history_no_system["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=Part(text=prompt_dict_history_no_system["prompt"][2]["parts"]) ) mock_process_response.assert_called_once_with(mock_gemini_call.return_value) @@ -314,7 +349,7 @@ async def test_gemini_query_history_no_system( @pytest.mark.asyncio # @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=AsyncMock, ) @@ -340,14 +375,18 @@ async def test_gemini_query_history_error_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + # mock_gemini_call.assert_awaited_once_with( + # content={ + # "role": "user", + # "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], + # }, + # generation_config=prompt_dict_history_no_system["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_awaited_once_with( - content={ - "role": "user", - "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], - }, - generation_config=prompt_dict_history_no_system["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=Part(text=prompt_dict_history_no_system["prompt"][2]["parts"]), ) expected_log_message = ( @@ -362,7 +401,7 @@ async def test_gemini_query_history_error_no_system( @pytest.mark.asyncio # @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( - Chat, + AsyncChat, "send_message", new_callable=AsyncMock, ) @@ -395,14 +434,18 @@ async def test_gemini_query_history_index_error_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() + # mock_gemini_call.assert_awaited_once_with( + # content={ + # "role": "user", + # "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], + # }, + # generation_config=prompt_dict_history_no_system["parameters"], + # safety_settings=DEFAULT_SAFETY_SETTINGS, + # stream=False, + # ) + mock_gemini_call.assert_awaited_once_with( - content={ - "role": "user", - "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], - }, - generation_config=prompt_dict_history_no_system["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + message=Part(text=prompt_dict_history_no_system["prompt"][2]["parts"]), ) expected_log_message = ( @@ -441,14 +484,30 @@ async def test_gemini_query_history_no_system_check_chat_init( monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") gemini_api = 
GeminiAPI(settings=settings, log_file=log_file)
 
+    # mock_obtain_model_inputs.return_value = (
+    #     prompt_dict_history_no_system["prompt"],
+    #     prompt_dict_history_no_system["model_name"],
+    #     GenerativeModel(
+    #         model_name=prompt_dict_history_no_system["model_name"],
+    #         system_instruction=None,
+    #     ),
+    #     DEFAULT_SAFETY_SETTINGS,
+    #     prompt_dict_history_no_system["parameters"],
+    # )
+
+    mock_generate_content_config = (
+        GenerateContentConfig(
+            temperature=1.0,
+            max_output_tokens=100,
+            safety_settings=DEFAULT_SAFETY_SETTINGS,
+        )
+    )
+
     mock_obtain_model_inputs.return_value = (
         prompt_dict_history_no_system["prompt"],
         prompt_dict_history_no_system["model_name"],
-        GenerativeModel(
-            model_name=prompt_dict_history_no_system["model_name"],
-            system_instruction=None,
-        ),
-        DEFAULT_SAFETY_SETTINGS,
+        gemini_api._get_client("gemini_model_name"),
+        mock_generate_content_config,
         prompt_dict_history_no_system["parameters"],
     )
 
@@ -461,14 +520,16 @@ async def test_gemini_query_history_no_system_check_chat_init(
         prompt_dict=prompt_dict_history_no_system, system_instruction=None
     )
     mock_start_chat.assert_called_once_with(
+        model="gemini_model_name",
+        config=mock_generate_content_config,
         history=[
-            {
-                "role": "user",
-                "parts": [prompt_dict_history_no_system["prompt"][0]["parts"]],
-            },
-            {
-                "role": "model",
-                "parts": [prompt_dict_history_no_system["prompt"][1]["parts"]],
-            },
-        ]
+            Content(
+                role="user",
+                parts=[Part(text=prompt_dict_history_no_system["prompt"][0]["parts"])],
+            ),
+            Content(
+                role="model",
+                parts=[Part(text=prompt_dict_history_no_system["prompt"][1]["parts"])],
+            ),
+        ],
     )
diff --git a/tests/apis/gemini/test_gemini_image_input.py b/tests/apis/gemini/test_gemini_image_input.py
index 989f5cbc..e7c0adec 100644
--- a/tests/apis/gemini/test_gemini_image_input.py
+++ b/tests/apis/gemini/test_gemini_image_input.py
@@ -2,6 +2,8 @@
 from unittest.mock import patch
 
 import pytest
+from google.genai import Client
+from google.genai.types import Part
 from PIL import Image
 
 from prompto.apis.gemini.gemini_utils import parse_parts_value
@@ -10,15 +12,27 @@
 def test_parse_parts_value_text():
     part = "text"
     media_folder = "media"
-    result = parse_parts_value(part, media_folder)
-    assert result == part
+    # There is no uploaded media in the prompt_dict_chat, hence the
+    # client is not required, and we can pass `None`.
+    mock_client = None
+    actual_result = parse_parts_value(part, media_folder, mock_client)
+    expected_result = Part(text="text")
+    assert actual_result == expected_result
 
 
 def test_parse_parts_value_image():
+    # This is a string, which happens to use a keyword "image",
+    # but it is not a key within dictionary.
+    # This test simply asserts that the string is handled correctly
+    # as a string.
    part = "image"
     media_folder = "media"
-    result = parse_parts_value(part, media_folder)
-    assert result == part
+    # There is no uploaded media in the prompt_dict_chat, hence the
+    # client is not required, and we can pass `None`.
+    mock_client = None
+    actual_result = parse_parts_value(part, media_folder, mock_client)
+    expected_result = Part(text="image")
+    assert actual_result == expected_result
 
 
 def test_parse_parts_value_image_dict(tmp_path):
@@ -28,6 +42,7 @@ def test_parse_parts_value_image_dict(tmp_path):
     # The image should be loaded from the local file. 
part = {"type": "image", "media": "pantani_giro.jpg"} media_folder = tmp_path / "media" + mock_client = None media_folder.mkdir(parents=True, exist_ok=True) @@ -36,7 +51,7 @@ def test_parse_parts_value_image_dict(tmp_path): image = Image.new("RGB", (100, 100), color="red") image.save(image_path) - actual_result = parse_parts_value(part, str(media_folder)) + actual_result = parse_parts_value(part, str(media_folder), mock_client) # Assert the result assert actual_result.mode == "RGB" @@ -47,10 +62,11 @@ def test_parse_parts_value_image_dict(tmp_path): def test_parse_parts_value_video_not_uploaded(): part = {"type": "video", "media": "pantani_giro.mp4"} media_folder = "media" + mock_client = None # Because the video is not uploaded, we expect a ValueError with pytest.raises(ValueError) as excinfo: - parse_parts_value(part, media_folder) + parse_parts_value(part, media_folder, mock_client) print(excinfo) assert "not uploaded" in str(excinfo.value) @@ -73,18 +89,17 @@ def test_parse_parts_value_video_uploaded(monkeypatch): # Instead, we will just return the uploaded_filename mock_get_file_no_op = lambda name: name - # Replace the original get_file function with the mock - # ***It is important that the import statement used here is exactly the same as - # the one in the gemini_utils.py file*** - import google.genai as genai - + # Replace the original `get` function with the mock with monkeypatch.context() as m: - # Mock the get_file function - client = genai.Client(api_key="DUMMY") - m.setattr(client.files, "get", mock_get_file_no_op) + # Mock the get function + client = Client(api_key="DUMMY") + m.setattr(client.aio.files, "get", mock_get_file_no_op) # Assert that the mock function was called with the expected argument - assert client.files.get(name="check mocked function") == "check mocked function" + assert ( + client.aio.files.get(name="check mocked function") + == "check mocked function" + ) expected_result = "file/123456" - actual_result = parse_parts_value(part, media_folder) + actual_result = parse_parts_value(part, media_folder, client) assert actual_result == expected_result diff --git a/tests/apis/gemini/test_gemini_string_input.py b/tests/apis/gemini/test_gemini_string_input.py index 26f9a05a..98660b32 100644 --- a/tests/apis/gemini/test_gemini_string_input.py +++ b/tests/apis/gemini/test_gemini_string_input.py @@ -4,6 +4,7 @@ import pytest from google.genai.client import AsyncClient from google.genai.models import AsyncModels +from google.genai.types import GenerateContentConfig from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings @@ -84,10 +85,13 @@ async def test_gemini_query_string( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=prompt_dict_string["prompt"], - generation_config=prompt_dict_string["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + model="gemini_model_name", + contents="test prompt", + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) mock_process_response.assert_called_once_with(mock_gemini_call.return_value) @@ -138,10 +142,13 @@ async def test_gemini_query_string__index_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=prompt_dict_string["prompt"], - generation_config=prompt_dict_string["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + 
model="gemini_model_name", + contents="test prompt", + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( @@ -184,10 +191,13 @@ async def test_gemini_query_string_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() mock_gemini_call.assert_awaited_once_with( - contents=prompt_dict_string["prompt"], - generation_config=prompt_dict_string["parameters"], - safety_settings=DEFAULT_SAFETY_SETTINGS, - stream=False, + model="gemini_model_name", + contents="test prompt", + config=GenerateContentConfig( + temperature=1.0, + max_output_tokens=100, + safety_settings=DEFAULT_SAFETY_SETTINGS, + ), ) expected_log_message = ( diff --git a/tests/apis/gemini/test_gemini_utils.py b/tests/apis/gemini/test_gemini_utils.py new file mode 100644 index 00000000..b453c623 --- /dev/null +++ b/tests/apis/gemini/test_gemini_utils.py @@ -0,0 +1,39 @@ +import pytest +from google.genai.types import Content, Part + +from prompto.apis.gemini.gemini_utils import convert_history_dict_to_content + +from .test_gemini import prompt_dict_chat, prompt_dict_history + + +@pytest.mark.xfail(reason="Test not implemented") +def test_process_response(): + pytest.fail("Test not implemented") + + +@pytest.mark.xfail(reason="Test not implemented") +def test_process_safety_attributes(): + pytest.fail("Test not implemented") + + +def test_convert_history_dict_to_content(prompt_dict_history): + + media_folder = "media_folder" + mock_client = None + + expected_result_list = [ + Content(parts=[Part(text="test system prompt")], role="system"), + Content(parts=[Part(text="user message")], role="user"), + ] + + # There is no uploaded media in the prompt_dict_chat, hence the + # client is not required, and we can pass `None`. 
+ prompt_list = prompt_dict_history["prompt"] + + for content_dict, expected_result in zip(prompt_list, expected_result_list): + actual_result = convert_history_dict_to_content( + content_dict, media_folder, mock_client + ) + assert ( + actual_result == expected_result + ), f"Expected {expected_result}, but got {actual_result}" From 32cf50de59725a8508a5882b79670cfc7d44b871 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Mon, 28 Apr 2025 16:33:25 +0100 Subject: [PATCH 05/21] Update gemini_media.py --- src/prompto/apis/gemini/gemini_media.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prompto/apis/gemini/gemini_media.py b/src/prompto/apis/gemini/gemini_media.py index 3ebecb6f..44971563 100644 --- a/src/prompto/apis/gemini/gemini_media.py +++ b/src/prompto/apis/gemini/gemini_media.py @@ -5,7 +5,7 @@ import os import time -import google.generativeai as genai +import google.genai import tqdm from dotenv import load_dotenv From c80273f6267500a263f3685a6ec9b837bd420910 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Tue, 29 Apr 2025 15:29:35 +0100 Subject: [PATCH 06/21] fix bug in '_get_client' method --- src/prompto/apis/gemini/gemini.py | 16 +- tests/apis/gemini/test_gemini.py | 152 +++++++++--------- tests/apis/gemini/test_gemini_chat_input.py | 46 +++++- tests/apis/gemini/test_gemini_string_input.py | 16 +- 4 files changed, 145 insertions(+), 85 deletions(-) diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index 4819713b..9163f116 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -1,4 +1,5 @@ import logging +import os from typing import Any from google.genai import Client @@ -63,8 +64,6 @@ class GeminiAPI(AsyncAPI): The path to the log file """ - _clients: dict[str, Client] = {} - def __init__( self, settings: Settings, @@ -73,6 +72,7 @@ def __init__( **kwargs: Any, ): super().__init__(settings=settings, log_file=log_file, *args, **kwargs) + self._clients: dict[str, Client] = {} @staticmethod def check_environment_variables() -> list[Exception]: @@ -207,12 +207,24 @@ def _get_client(self, model_name) -> Client: A client for the Gemini API """ # If the Client does not exist, create it + # already_created = True + # api_key = "NO VALUE HAS BEEN SET YET" + # print(f"GeminiAPI: {model_name=}") + # print(f"GeminiAPI: {self._clients=}") + # print(f"GeminiAPI: {self._clients.get(model_name, "not found")=}") if model_name not in self._clients: + # already_created = False api_key = get_environment_variable( env_variable=API_KEY_VAR_NAME, model_name=model_name ) + # print(f"Creating client for {model_name} with {api_key=}") self._clients[model_name] = Client(api_key=api_key) + # print(f"Client for {model_name} already created: {already_created}") + # print(f"{api_key=}") + # for env_var_name, env_var_val in os.environ.items(): + # print(f"{env_var_name}={env_var_val}") + # Return the client for the model name return self._clients[model_name] diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index 3d5454ac..80206a5b 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -278,93 +278,97 @@ def test_gemini_check_prompt_dict(temporary_data_folders, monkeypatch): raise test_case[0] # set the GEMINI_API_KEY environment variable - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") - # error if the model-specific environment variable is not set - test_case = GeminiAPI.check_prompt_dict( - { - "api": "gemini", - "model_name": "gemini_model_name", - 
"prompt": "test prompt", - } - ) - assert len(test_case) == 1 - with pytest.raises( - Warning, - match=re.escape( - "Environment variable 'GEMINI_API_KEY_gemini_model_name' is not set" - ), - ): - raise test_case[0] - - # unset the GEMINI_API_KEY environment variable and - # set the model-specific environment variable - monkeypatch.delenv("GEMINI_API_KEY") - monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") - test_case = GeminiAPI.check_prompt_dict( - { - "api": "gemini", - "model_name": "gemini_model_name", - "prompt": "test prompt", - } - ) - assert len(test_case) == 1 - with pytest.raises( - Warning, match=re.escape("Environment variable 'GEMINI_API_KEY' is not set") - ): - raise test_case[0] - - # full passes - # set both environment variables - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") - assert ( - GeminiAPI.check_prompt_dict( + with monkeypatch.context() as m1: + m1.setenv("GEMINI_API_KEY", "DUMMY") + # error if the model-specific environment variable is not set + test_case = GeminiAPI.check_prompt_dict( { "api": "gemini", "model_name": "gemini_model_name", "prompt": "test prompt", } ) - == [] - ) - assert ( - GeminiAPI.check_prompt_dict( + assert len(test_case) == 1 + with pytest.raises( + Warning, + match=re.escape( + "Environment variable 'GEMINI_API_KEY_gemini_model_name' is not set" + ), + ): + raise test_case[0] + + with monkeypatch.context() as m2: + # unset the GEMINI_API_KEY environment variable and + # set the model-specific environment variable + m2.delenv("GEMINI_API_KEY", raising=False) + m2.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") + test_case = GeminiAPI.check_prompt_dict( { "api": "gemini", "model_name": "gemini_model_name", - "prompt": ["prompt 1", "prompt 2"], + "prompt": "test prompt", } ) - == [] - ) - assert ( - GeminiAPI.check_prompt_dict( - { - "api": "gemini", - "model_name": "gemini_model_name", - "prompt": [ - {"role": "system", "parts": "system prompt"}, - {"role": "user", "parts": "user message 1"}, - {"role": "model", "parts": "model message"}, - {"role": "user", "parts": "user message 2"}, - ], - } + assert len(test_case) == 1 + with pytest.raises( + Warning, match=re.escape("Environment variable 'GEMINI_API_KEY' is not set") + ): + raise test_case[0] + + # full passes + # set both environment variables + with monkeypatch.context() as m3: + m3.setenv("GEMINI_API_KEY", "DUMMY") + m3.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") + assert ( + GeminiAPI.check_prompt_dict( + { + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": "test prompt", + } + ) + == [] ) - == [] - ) - assert ( - GeminiAPI.check_prompt_dict( - { - "api": "gemini", - "model_name": "gemini_model_name", - "prompt": [ - {"role": "user", "parts": "user message 1"}, - {"role": "model", "parts": "model message"}, - {"role": "user", "parts": "user message 2"}, - ], - } + assert ( + GeminiAPI.check_prompt_dict( + { + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": ["prompt 1", "prompt 2"], + } + ) + == [] + ) + assert ( + GeminiAPI.check_prompt_dict( + { + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": [ + {"role": "system", "parts": "system prompt"}, + {"role": "user", "parts": "user message 1"}, + {"role": "model", "parts": "model message"}, + {"role": "user", "parts": "user message 2"}, + ], + } + ) + == [] + ) + assert ( + GeminiAPI.check_prompt_dict( + { + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": [ + {"role": "user", "parts": "user message 1"}, + {"role": "model", "parts": "model 
message"}, + {"role": "user", "parts": "user message 2"}, + ], + } + ) + == [] ) - == [] - ) @pytest.mark.asyncio diff --git a/tests/apis/gemini/test_gemini_chat_input.py b/tests/apis/gemini/test_gemini_chat_input.py index 76664da7..e339d8b0 100644 --- a/tests/apis/gemini/test_gemini_chat_input.py +++ b/tests/apis/gemini/test_gemini_chat_input.py @@ -1,4 +1,6 @@ import logging +from copy import deepcopy +from importlib import reload from unittest.mock import AsyncMock, Mock, patch import pytest @@ -7,7 +9,10 @@ from google.genai.chats import AsyncChat, AsyncChats from google.genai.types import GenerateContentConfig +import prompto.utils from prompto.apis.gemini import GeminiAPI + +# import prompto.apis.gemini as prompto_gemini from prompto.settings import Settings from ...conftest import CopyingAsyncMock @@ -21,8 +26,31 @@ # reset / properly scoped. @pytest.mark.asyncio async def test_gemini_query_chat_no_env_var( - prompt_dict_chat, temporary_data_folders, caplog + prompt_dict_chat, temporary_data_folders, caplog, monkeypatch ): + # with monkeypatch.context() as m: + # reload(prompto.utils.os) + # import os + # # reload(prompto.apis.gemini.gemini.os) + # # reload(prompto.apis.gemini.gemini_utils.os) + + # if "GEMINI_API_KEY" in os.environ: + # m.delenv("GEMINI_API_KEY", raising=False) + # m.delitem(os.environ, "GEMINI_API_KEY") + # print("GEMINI_API_KEY deleted from os.environ with monkeypatch") + # else: + # print("GEMINI_API_KEY not in os.environ") + + # if "GEMINI_API_KEY_gemini_model_name" in os.environ: + # m.delenv("GEMINI_API_KEY_gemini_model_name", raising=False) + # m.delitem(os.environ, "GEMINI_API_KEY_gemini_model_name") + # print("GEMINI_API_KEY_gemini_model_name deleted from os.environ with monkeypatch") + # else: + # print("GEMINI_API_KEY_gemini_model_name not in os.environ") + + # monkeypatch.delenv("GEMINI_API_KEY", raising=False) + # monkeypatch.delenv("GEMINI_API_KEY_gemini_model_name", raising=False) + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" @@ -59,10 +87,10 @@ async def test_gemini_query_chat( monkeypatch, caplog, ): + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) # mock the response from the API @@ -157,10 +185,11 @@ async def test_gemini_query_history_check_chat_init( monkeypatch, caplog, ): + monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) mock_obtain_model_inputs.return_value = ( @@ -206,10 +235,11 @@ async def test_gemini_query_history_check_chat_init( async def test_gemini_query_chat_index_error_1( mock_gemini_call, prompt_dict_chat, temporary_data_folders, monkeypatch, caplog ): + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) # mock index error response from the API @@ -266,10 +296,10 @@ async def test_gemini_query_chat_index_error_1( async def test_gemini_query_chat_error_1( mock_gemini_call, prompt_dict_chat, temporary_data_folders, monkeypatch, caplog ): + 
monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) # mock error response from the API @@ -326,10 +356,11 @@ async def test_gemini_query_chat_index_error_2( monkeypatch, caplog, ): + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) # mock error response from the API from second response @@ -430,10 +461,11 @@ async def test_gemini_query_chat_error_2( monkeypatch, caplog, ): + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) # mock error response from the API from second response diff --git a/tests/apis/gemini/test_gemini_string_input.py b/tests/apis/gemini/test_gemini_string_input.py index 98660b32..82aeedb1 100644 --- a/tests/apis/gemini/test_gemini_string_input.py +++ b/tests/apis/gemini/test_gemini_string_input.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from google.genai.chats import AsyncChat from google.genai.client import AsyncClient from google.genai.models import AsyncModels from google.genai.types import GenerateContentConfig @@ -15,13 +16,23 @@ @pytest.mark.asyncio +# @patch( +# AsyncChat, +# "send_message", +# new_callable=AsyncMock, +# ) async def test_gemini_query_string_no_env_var( - prompt_dict_string, temporary_data_folders, caplog + prompt_dict_string, temporary_data_folders, caplog, monkeypatch ): + + # monkeypatch.delenv("GEMINI_API_KEY", raising=False) + # monkeypatch.delenv("GEMINI_API_KEY_gemini_model_name", raising=False) + caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - gemini_api = GeminiAPI(settings=settings, log_file=log_file) + + # mock_send_message.return_value = "response Messages object" # raise error if no environment variable is set with pytest.raises( @@ -31,6 +42,7 @@ async def test_gemini_query_string_no_env_var( "environment variable is set." 
), ): + gemini_api = GeminiAPI(settings=settings, log_file=log_file) await gemini_api._query_string(prompt_dict_string, index=0) From 60ec15be40c6135588d8b93af87e2e8085886f31 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Tue, 29 Apr 2025 17:24:53 +0100 Subject: [PATCH 07/21] fix bug in '_get_client' method --- src/prompto/apis/gemini/gemini.py | 12 --- tests/apis/gemini/test_gemini.py | 4 - tests/apis/gemini/test_gemini_chat_input.py | 85 ------------------- tests/apis/gemini/test_gemini_string_input.py | 23 ----- 4 files changed, 124 deletions(-) diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index 9163f116..6506c003 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -207,24 +207,12 @@ def _get_client(self, model_name) -> Client: A client for the Gemini API """ # If the Client does not exist, create it - # already_created = True - # api_key = "NO VALUE HAS BEEN SET YET" - # print(f"GeminiAPI: {model_name=}") - # print(f"GeminiAPI: {self._clients=}") - # print(f"GeminiAPI: {self._clients.get(model_name, "not found")=}") if model_name not in self._clients: - # already_created = False api_key = get_environment_variable( env_variable=API_KEY_VAR_NAME, model_name=model_name ) - # print(f"Creating client for {model_name} with {api_key=}") self._clients[model_name] = Client(api_key=api_key) - # print(f"Client for {model_name} already created: {already_created}") - # print(f"{api_key=}") - # for env_var_name, env_var_val in os.environ.items(): - # print(f"{env_var_name}={env_var_val}") - # Return the client for the model name return self._clients[model_name] diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index 80206a5b..8117ee1a 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -13,10 +13,6 @@ from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings -# from google.generativeai import GenerativeModel -# from google.generativeai.types import HarmBlockThreshold, HarmCategory - - pytest_plugins = ("pytest_asyncio",) diff --git a/tests/apis/gemini/test_gemini_chat_input.py b/tests/apis/gemini/test_gemini_chat_input.py index e339d8b0..7c9a363c 100644 --- a/tests/apis/gemini/test_gemini_chat_input.py +++ b/tests/apis/gemini/test_gemini_chat_input.py @@ -1,18 +1,12 @@ import logging -from copy import deepcopy -from importlib import reload from unittest.mock import AsyncMock, Mock, patch import pytest - -# from google.generativeai import GenerativeModel from google.genai.chats import AsyncChat, AsyncChats from google.genai.types import GenerateContentConfig import prompto.utils from prompto.apis.gemini import GeminiAPI - -# import prompto.apis.gemini as prompto_gemini from prompto.settings import Settings from ...conftest import CopyingAsyncMock @@ -21,35 +15,10 @@ pytest_plugins = ("pytest_asyncio",) -# TODO: FIX THIS. This test passes when executed alone, but fails when executed with all tests -# This is probably due to the environment variable being monkeypatched somewhere without being -# reset / properly scoped. 
@pytest.mark.asyncio async def test_gemini_query_chat_no_env_var( prompt_dict_chat, temporary_data_folders, caplog, monkeypatch ): - # with monkeypatch.context() as m: - # reload(prompto.utils.os) - # import os - # # reload(prompto.apis.gemini.gemini.os) - # # reload(prompto.apis.gemini.gemini_utils.os) - - # if "GEMINI_API_KEY" in os.environ: - # m.delenv("GEMINI_API_KEY", raising=False) - # m.delitem(os.environ, "GEMINI_API_KEY") - # print("GEMINI_API_KEY deleted from os.environ with monkeypatch") - # else: - # print("GEMINI_API_KEY not in os.environ") - - # if "GEMINI_API_KEY_gemini_model_name" in os.environ: - # m.delenv("GEMINI_API_KEY_gemini_model_name", raising=False) - # m.delitem(os.environ, "GEMINI_API_KEY_gemini_model_name") - # print("GEMINI_API_KEY_gemini_model_name deleted from os.environ with monkeypatch") - # else: - # print("GEMINI_API_KEY_gemini_model_name not in os.environ") - - # monkeypatch.delenv("GEMINI_API_KEY", raising=False) - # monkeypatch.delenv("GEMINI_API_KEY_gemini_model_name", raising=False) caplog.set_level(logging.INFO) settings = Settings(data_folder="data") @@ -68,9 +37,6 @@ async def test_gemini_query_chat_no_env_var( @pytest.mark.asyncio -# @patch( -# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock -# ) @patch.object( AsyncChat, "send_message", @@ -168,7 +134,6 @@ async def test_gemini_query_chat( @pytest.mark.asyncio -# @patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) @patch.object( AsyncChats, "create", @@ -224,9 +189,6 @@ async def test_gemini_query_history_check_chat_init( @pytest.mark.asyncio -# @patch( -# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock -# ) @patch.object( AsyncChat, "send_message", @@ -256,12 +218,6 @@ async def test_gemini_query_chat_index_error_1( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() - # mock_gemini_call.assert_any_await( - # content=prompt_dict_chat["prompt"][0], - # generation_config=prompt_dict_chat["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_any_await( message=prompt_dict_chat["prompt"][0], config=GenerateContentConfig( @@ -285,9 +241,6 @@ async def test_gemini_query_chat_index_error_1( @pytest.mark.asyncio -# @patch( -# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock -# ) @patch.object( AsyncChat, "send_message", @@ -311,12 +264,6 @@ async def test_gemini_query_chat_error_1( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() - # mock_gemini_call.assert_any_await( - # content=prompt_dict_chat["prompt"][0], - # generation_config=prompt_dict_chat["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_any_await( message=prompt_dict_chat["prompt"][0], config=GenerateContentConfig( @@ -337,9 +284,6 @@ async def test_gemini_query_chat_error_1( @pytest.mark.asyncio -# @patch( -# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock -# ) @patch.object( AsyncChat, "send_message", @@ -385,12 +329,6 @@ async def test_gemini_query_chat_index_error_2( assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 - # mock_gemini_call.assert_any_await( - # content=prompt_dict_chat["prompt"][0], - # generation_config=prompt_dict_chat["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_any_await( 
message=prompt_dict_chat["prompt"][0], @@ -401,13 +339,6 @@ async def test_gemini_query_chat_index_error_2( ), ) - # mock_gemini_call.assert_awaited_with( - # content=prompt_dict_chat["prompt"][1], - # generation_config=prompt_dict_chat["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) - mock_gemini_call.assert_awaited_with( message=prompt_dict_chat["prompt"][1], config=GenerateContentConfig( @@ -442,9 +373,6 @@ async def test_gemini_query_chat_index_error_2( @pytest.mark.asyncio -# @patch( -# "google.generativeai.ChatSession.send_message_async", new_callable=CopyingAsyncMock -# ) @patch.object( AsyncChat, "send_message", @@ -484,12 +412,6 @@ async def test_gemini_query_chat_error_2( assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 - # mock_gemini_call.assert_any_await( - # content=prompt_dict_chat["prompt"][0], - # generation_config=prompt_dict_chat["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_any_await( message=prompt_dict_chat["prompt"][0], @@ -500,13 +422,6 @@ async def test_gemini_query_chat_error_2( ), ) - # mock_gemini_call.assert_awaited_with( - # content=prompt_dict_chat["prompt"][1], - # generation_config=prompt_dict_chat["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) - mock_gemini_call.assert_awaited_with( message=prompt_dict_chat["prompt"][1], config=GenerateContentConfig( diff --git a/tests/apis/gemini/test_gemini_string_input.py b/tests/apis/gemini/test_gemini_string_input.py index 82aeedb1..da5a8500 100644 --- a/tests/apis/gemini/test_gemini_string_input.py +++ b/tests/apis/gemini/test_gemini_string_input.py @@ -16,24 +16,13 @@ @pytest.mark.asyncio -# @patch( -# AsyncChat, -# "send_message", -# new_callable=AsyncMock, -# ) async def test_gemini_query_string_no_env_var( prompt_dict_string, temporary_data_folders, caplog, monkeypatch ): - - # monkeypatch.delenv("GEMINI_API_KEY", raising=False) - # monkeypatch.delenv("GEMINI_API_KEY_gemini_model_name", raising=False) - caplog.set_level(logging.INFO) settings = Settings(data_folder="data") log_file = "log.txt" - # mock_send_message.return_value = "response Messages object" - # raise error if no environment variable is set with pytest.raises( KeyError, @@ -47,9 +36,6 @@ async def test_gemini_query_string_no_env_var( @pytest.mark.asyncio -# @patch( -# "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock -# ) @patch.object( AsyncModels, "generate_content", @@ -77,9 +63,6 @@ async def test_gemini_query_string( # google.genai.types.GenerateContentResponse object # not a string value, but for the purpose of this test, we are using a string value # and testing that this is the input to the process_response function - # TODO: Check if there is a difference in the return type of - # `google.genai.client.aio.models.generate_content`` and - # `google.genai.client.models.generate_content` mock_gemini_call.return_value = "response Messages object" # mock the process_response function @@ -122,9 +105,6 @@ async def test_gemini_query_string( @pytest.mark.asyncio -# @patch( -# "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock -# ) @patch.object( AsyncModels, "generate_content", @@ -176,9 +156,6 @@ async def test_gemini_query_string__index_error( @pytest.mark.asyncio -# @patch( -# "google.generativeai.GenerativeModel.generate_content_async", new_callable=AsyncMock -# ) @patch.object( AsyncModels, "generate_content", 
From d561e070e32811736b7aa795e5c45df22e6ee817 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Wed, 30 Apr 2025 12:09:22 +0100 Subject: [PATCH 08/21] linting and copilot review changes --- src/prompto/apis/gemini/gemini.py | 17 +---------------- src/prompto/apis/gemini/gemini_media.py | 17 +++++++---------- src/prompto/apis/gemini/gemini_utils.py | 2 -- 3 files changed, 8 insertions(+), 28 deletions(-) diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index 6506c003..938380e7 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -1,5 +1,4 @@ import logging -import os from typing import Any from google.genai import Client @@ -575,13 +574,6 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: try: # No need to send the generation_config again, as it is no different # from the one used to create the chat - last_msg = prompt[-1] - print(f"whole prompt: {prompt}") - print(f"last_msg: {last_msg}") - # msg_to_send = convert_dict_to_input( - # content_dict=prompt[-1], media_folder=self.settings.media_folder - # ) - msg_to_send = parse_parts( prompt[-1]["parts"], media_folder=self.settings.media_folder, @@ -593,14 +585,7 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: ), "Only one message is allowed in the last message" msg_to_send = msg_to_send[0] - print(f"msg_to_send: {msg_to_send}") - - response = await chat.send_message( - # message=convert_dict_to_input( - # content_dict=prompt[-1], media_folder=self.settings.media_folder - # ), - message=msg_to_send - ) + response = await chat.send_message(message=msg_to_send) response_text = process_response(response) safety_attributes = process_safety_attributes(response) diff --git a/src/prompto/apis/gemini/gemini_media.py b/src/prompto/apis/gemini/gemini_media.py index 44971563..e4c51f99 100644 --- a/src/prompto/apis/gemini/gemini_media.py +++ b/src/prompto/apis/gemini/gemini_media.py @@ -1,13 +1,11 @@ import asyncio import base64 -import json import logging import os -import time -import google.genai import tqdm from dotenv import load_dotenv +from google import genai from prompto.utils import compute_sha256_base64 @@ -81,9 +79,10 @@ async def upload_single_file(local_file_path, already_uploaded_files): f"Failure uploaded file '{file_obj.name}'. 
Error: {file_obj.error_message}" ) raise ValueError(err_msg) - # logger.info( - # f"Uploaded file '{file_obj.name}' with hash '{local_hash}' to Gemini API" - # ) + + logger.info( + f"Uploaded file '{file_obj.name}' with hash '{local_hash}' to Gemini API" + ) already_uploaded_files[local_hash] = file_obj.name return file_obj.name, local_file_path @@ -158,11 +157,9 @@ def upload_media_files(files_to_upload: set[str]): async def upload_media_files_async(files_to_upload: set[str]): - start_time = time.time() - logger.info(f"Start retrieving previously uploaded files ") + logger.info("Start retrieving previously uploaded files") uploaded_files = _get_previously_uploaded_files() - next_time = time.time() - logger.info(f"Retrieved list of previously uploaded files") + logger.info("Retrieved list of previously uploaded files") # Upload files asynchronously tasks = [] diff --git a/src/prompto/apis/gemini/gemini_utils.py b/src/prompto/apis/gemini/gemini_utils.py index 0acb102f..089b46d3 100644 --- a/src/prompto/apis/gemini/gemini_utils.py +++ b/src/prompto/apis/gemini/gemini_utils.py @@ -28,8 +28,6 @@ def parse_parts_value(part: dict | str, media_folder: str, client: Client) -> an Multimedia data object """ if isinstance(part, str): - # return part - print(f"Part is a string: {part}") return types.Part.from_text(text=part) # read multimedia type From 7a2af561fad833a18681c231d692ad5e78b06032 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Thu, 1 May 2025 15:09:38 +0100 Subject: [PATCH 09/21] WIP: migrate gemini media functions --- src/prompto/apis/gemini/gemini_media.py | 173 ++++++++++--- src/prompto/upload_media.py | 33 ++- tests/apis/gemini/test_gemini_media.py | 327 ++++++++++++++++++++++++ tests/apis/gemini/test_gemini_utils.py | 2 +- 4 files changed, 496 insertions(+), 39 deletions(-) create mode 100644 tests/apis/gemini/test_gemini_media.py diff --git a/src/prompto/apis/gemini/gemini_media.py b/src/prompto/apis/gemini/gemini_media.py index e4c51f99..814ae624 100644 --- a/src/prompto/apis/gemini/gemini_media.py +++ b/src/prompto/apis/gemini/gemini_media.py @@ -2,11 +2,15 @@ import base64 import logging import os +import tempfile +from time import sleep import tqdm from dotenv import load_dotenv from google import genai +from prompto.apis.gemini.gemini import GeminiAPI +from prompto.settings import Settings from prompto.utils import compute_sha256_base64 # initialise logging @@ -23,23 +27,33 @@ def remote_file_hash_base64(remote_file): Convert a remote file's SHA256 hash (stored as a hex-encoded UTF-8 bytes object) to a base64-encoded string. """ - hex_str = remote_file.sha256_hash.decode("utf-8") + # hex_str = remote_file.sha256_hash.decode("utf-8") + hex_str = remote_file.sha256_hash raw_bytes = bytes.fromhex(hex_str) return base64.b64encode(raw_bytes).decode("utf-8") -async def wait_for_processing(file_obj, poll_interval=1): +async def wait_for_processing(file_obj, client: genai.Client, poll_interval=1): """ Poll until the file is no longer in the 'PROCESSING' state. Returns the updated file object. """ + # print(f"File {file_obj.name} is in state {file_obj.state.name}") + while file_obj.state.name == "PROCESSING": await asyncio.sleep(poll_interval) - file_obj = genai.get_file(file_obj.name) + # We need to re-fetch the file object to get the updated state. 
+ file_obj = client.files.get(name=file_obj.name) + # print(f"File {file_obj.name} is in state {file_obj.state.name}") + # print(f"{file_obj.error=}") + # print(f"{file_obj.update_time=}") + # print(f"{file_obj.create_time=}") return file_obj -async def upload_single_file(local_file_path, already_uploaded_files): +async def _upload_single_file( + local_file_path, already_uploaded_files, client: genai.Client +): """ Upload the file at 'file_path' if it hasn't been uploaded yet. If a file with the same SHA256 (base64-encoded) hash exists, returns its name. @@ -61,6 +75,8 @@ async def upload_single_file(local_file_path, already_uploaded_files): results later.) """ local_hash = compute_sha256_base64(local_file_path) + print(f"local_file_path: {local_file_path}") + print(f"local_hash: {local_hash}") if local_hash in already_uploaded_files: logger.info( @@ -71,8 +87,10 @@ async def upload_single_file(local_file_path, already_uploaded_files): # Upload the file if it hasn't been found. # Use asyncio.to_thread to run the blocking upload_file function in a separate thread. logger.info(f"Uploading {local_file_path} to Gemini API") - file_obj = await asyncio.to_thread(genai.upload_file, local_file_path) - file_obj = await wait_for_processing(file_obj) + + # file_obj = await asyncio.to_thread(genai.upload_file, local_file_path) + file_obj = await client.aio.files.upload(file=local_file_path) + file_obj = await wait_for_processing(file_obj, client=client) if file_obj.state.name == "FAILED": err_msg = ( @@ -87,31 +105,50 @@ async def upload_single_file(local_file_path, already_uploaded_files): return file_obj.name, local_file_path -def _init_genai(): - load_dotenv(dotenv_path=".env") - # TODO: check if this can be refactored to a common function - GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") - if GEMINI_API_KEY is None: - raise ValueError("GEMINI_API_KEY is not set") +# def _init_genai(): +# load_dotenv(dotenv_path=".env") +# # TODO: check if this can be refactored to a common function +# GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") +# if GEMINI_API_KEY is None: +# raise ValueError("GEMINI_API_KEY is not set") - genai.configure(api_key=GEMINI_API_KEY) +# genai.configure(api_key=GEMINI_API_KEY) -def _get_previously_uploaded_files(): +async def _get_previously_uploaded_files(client: genai.Client): + raw_files = await client.aio.files.list() uploaded_files = { - remote_file_hash_base64(remote_file): remote_file.name - for remote_file in genai.list_files() + remote_file.sha256_hash: remote_file.name for remote_file in raw_files } logger.info(f"Found {len(uploaded_files)} files already uploaded at Gemini API") return uploaded_files -def list_uploaded_files(): +def list_uploaded_files(settings: Settings): """ List all previously uploaded files to the Gemini API. """ - _init_genai() - uploaded_files = _get_previously_uploaded_files() + # _init_genai() + + # Settings are not used in this function, but we need to + # create a dummy settings object to pass to the GeminiAPI + # TODO: + # Also, we don't need a directory, but Settings constructor + # insists on creating these directories locally. + # A better solution would be to create an option in the + # Settings constructor to not create the directories. + # But for now we'll just pass it a temporary directory. 
+ # with tempfile.TemporaryDirectory() as temp_dir: + # data_folder = os.path.join(temp_dir, "data") + # os.makedirs(data_folder, exist_ok=True) + # dummy_settings = Settings(data_folder=data_folder) + + genmini_api = GeminiAPI(settings=settings, log_file=None) + # TODO: We need a model name, because our API caters for different API keys + # for different models. Maybe our API to complicated.... + default_model_name = "default" + client = genmini_api._get_client(default_model_name) + uploaded_files = asyncio.run(_get_previously_uploaded_files(client)) for file_hash, file_name in uploaded_files.items(): msg = f"File Name: {file_name}, File Hash: {file_hash}" @@ -119,26 +156,92 @@ def list_uploaded_files(): logger.info("All uploaded files listed.") -def delete_uploaded_files(): +def delete_uploaded_files(settings: Settings): """ Delete all previously uploaded files from the Gemini API. """ - _init_genai() - uploaded_files = _get_previously_uploaded_files() - return asyncio.run(_delete_uploaded_files_async(uploaded_files)) + # _init_genai() + + # with tempfile.TemporaryDirectory() as temp_dir: + # data_folder = os.path.join(temp_dir, "data") + # os.makedirs(data_folder, exist_ok=True) + # dummy_settings = Settings(data_folder=data_folder) + + genmini_api = GeminiAPI(settings=settings, log_file=None) + # TODO: We need a model name, because our API caters for different API keys + # for different models. Maybe our API to complicated.... + default_model_name = "default" + client = genmini_api._get_client(default_model_name) + + # uploaded_files = asyncio.run(_get_previously_uploaded_files(client)) + for remote_file in client.files.list(): + # file_name = file_name.name + client.files.delete(name=remote_file.name) + # _delete_single_uploaded_file(file_name, client) + # return asyncio.run(_delete_uploaded_files_async(uploaded_files, client)) + logger.info("All uploaded files deleted.") -async def _delete_uploaded_files_async(uploaded_files): - tasks = [] +def _delete_single_uploaded_file(file_name: str, client: genai.Client): + """ + Delete a single uploaded file from the Gemini API. + """ + print(f"Deleting file {file_name}") + file = client.files.get(name=file_name) + client.files.delete(name=file_name) + # indx = 0 + + # The delete function is non-blocking (even the sync version) + # and returns immediately. So we need to poll the file object + # to see if it is still exists. + # The only reliable way to check if the file is deleted is to + # try and get it again and see if it raises an error. + while True: + # We need to re-fetch the file object to get the updated state. 
+ try: + file = client.files.get(name=file.name) + print(f"File {file.name} is in state {file.state.name}") + print(f"{file.error=}") + print(f"{file.update_time=}") + print(f"{file.create_time=}") + print(f"{indx=}") + # client.files.delete(name=file_name) + indx += 1 + # if indx > 10: + # break + except genai.errors.ClientError as e: + # print(f"ClientError: {e}" + print(f"File {file.name} deleted") + break + # if file.state.name == "PROCESSING": + sleep(1) + + +async def _delete_uploaded_files_async(uploaded_files, client: genai.Client): + tasks_set = set() for file_name in uploaded_files.values(): logger.info(f"Preparing to delete file: {file_name}") - tasks.append(asyncio.to_thread(genai.delete_file, file_name)) + # tasks.append(asyncio.to_thread(genai.delete_file, file_name)) - await tqdm.asyncio.tqdm.gather(*tasks) + # file = await client.aio.files.get(name=file_name) + # # task = client.aio.files.delete(name=file_name) + # task = asyncio.to_thread(client.aio.files.delete(name=file_name)) + tasks_set.add(_delete_single_uploaded_file(file_name, client)) + + # await tqdm.asyncio.tqdm.gather(*tasks) + await asyncio.gather(*tasks_set, return_exceptions=True) logger.info("All uploaded files deleted.") + # async with asyncio.TaskGroup() as tg: + # for file_name in uploaded_files.values(): + # logger.info(f"Preparing to delete file: {file_name}") + # # tasks.append(asyncio.to_thread(genai.delete_file, file_name)) + # tg.create_task(client.aio.files.delete(name=file_name)) + + # logger.info("All uploaded files deleted.") + -def upload_media_files(files_to_upload: set[str]): +def upload_media_files(files_to_upload: set[str], settings: Settings): """ Upload media files to the Gemini API. @@ -152,20 +255,24 @@ def upload_media_files(files_to_upload: set[str]): dict[str, str] Dictionary mapping local file paths to their corresponding uploaded filenames. 
""" - _init_genai() - return asyncio.run(upload_media_files_async(files_to_upload)) + # _init_genai() + return asyncio.run(upload_media_files_async(files_to_upload, settings)) -async def upload_media_files_async(files_to_upload: set[str]): +async def upload_media_files_async(files_to_upload: set[str], settings: Settings): logger.info("Start retrieving previously uploaded files") - uploaded_files = _get_previously_uploaded_files() + gemini_api = GeminiAPI(settings=settings, log_file=None) + client = gemini_api._get_client("default") + + uploaded_files = await _get_previously_uploaded_files(client) + logger.info("Retrieved list of previously uploaded files") # Upload files asynchronously tasks = [] for file_path in files_to_upload: logger.info(f"checking if {file_path} needs to be uploaded") - tasks.append(upload_single_file(file_path, uploaded_files)) + tasks.append(_upload_single_file(file_path, uploaded_files, client)) remote_local_pairs = await tqdm.asyncio.tqdm.gather(*tasks) diff --git a/src/prompto/upload_media.py b/src/prompto/upload_media.py index 4621c1cf..03b4c68b 100644 --- a/src/prompto/upload_media.py +++ b/src/prompto/upload_media.py @@ -7,6 +7,7 @@ import prompto.apis.gemini.gemini_media as gemini_media from prompto.apis import ASYNC_APIS from prompto.scripts.run_experiment import load_env_file +from prompto.settings import Settings # initialise logging logger = logging.getLogger(__name__) @@ -224,12 +225,14 @@ def upload_media_parse_args(): def do_delete_existing_files(args): - gemini_media.delete_uploaded_files() + settings = _create_settings() + gemini_media.delete_uploaded_files(settings) return def do_list_uploaded_files(args): - gemini_media.list_uploaded_files() + settings = _create_settings() + gemini_media.list_uploaded_files(settings) return @@ -249,10 +252,29 @@ def _do_upload_media_from_args(args): args : argparse.Namespace """ _resolve_output_file_location(args) - asyncio.run(do_upload_media(args.file, args.media_folder, args.output_file)) + do_upload_media(args.file, args.media_folder, args.output_file) -async def do_upload_media(input_file, media_folder, output_file): +def _create_settings(): + """ + Create a dummy settings object for the Gemini API. + This is used to create a client object for the API. + For now, we just create a temporary directory for the data folder. + """ + # A better solution would be to create an option in the + # Settings constructor to not create the directories. + # But for now we'll just pass it a temporary directory. + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + data_folder = os.path.join(temp_dir, "data") + os.makedirs(data_folder, exist_ok=True) + dummy_settings = Settings(data_folder=data_folder) + + return dummy_settings + + +def do_upload_media(input_file, media_folder, output_file): """ Upload media files to the relevant API. The media files are uploaded and the experiment file is updated with the uploaded filenames. 
@@ -276,7 +298,8 @@ async def do_upload_media(input_file, media_folder, output_file): # If in future we support other bulk upload to other APIs, we will need to # refactor here - uploaded_files = await gemini_media.upload_media_files_async(files_to_upload) + settings = _create_settings() + uploaded_files = gemini_media.upload_media_files(files_to_upload, settings) update_experiment_file( prompt_dict_list, diff --git a/tests/apis/gemini/test_gemini_media.py b/tests/apis/gemini/test_gemini_media.py new file mode 100644 index 00000000..c818bbda --- /dev/null +++ b/tests/apis/gemini/test_gemini_media.py @@ -0,0 +1,327 @@ +import logging +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +from google import genai +from google.genai.types import FileState + +from prompto.apis.gemini.gemini import GeminiAPI +from prompto.apis.gemini.gemini_media import ( + delete_uploaded_files, + list_uploaded_files, + remote_file_hash_base64, + upload_media_files, + wait_for_processing, +) +from prompto.upload_media import _create_settings + + +def test_remote_file_hash_base64(): + + # The example hashes are for the strings "hash1", "hash2", "hash3" eg: + # >>> "hash1".encode("utf-8").hex() + # '6861736831' + test_cases = [ + ( + Mock(dummy_name="file1", sha256_hash="6861736831"), + "aGFzaDE=", + ), + ( + Mock(dummy_name="file2", sha256_hash="6861736832"), + "aGFzaDI=", + ), + ( + Mock(dummy_name="file3", sha256_hash="6861736833"), + "aGFzaDM=", + ), + ] + + # We need to juggle with the mock names, because we can't set them + # directly in the constructor. See these docs for more details: + # https://docs.python.org/3/library/unittest.mock.html#mock-names-and-the-name-attribute + for mock_file, expected_hash in test_cases: + mock_file.configure_mock(name=mock_file.dummy_name) + mock_file.__str__ = Mock(return_value=mock_file.dummy_name) + + actual_hash = remote_file_hash_base64(mock_file) + assert ( + actual_hash == expected_hash + ), f"Expected {expected_hash}, but got {actual_hash}" + + +@pytest.mark.asyncio +@patch.object( + genai.files.Files, + "get", + new_callable=Mock, +) +async def test_wait_for_processing(mock_file_get, monkeypatch): + with monkeypatch.context() as m: + m.setenv("GEMINI_API_KEY", "DUMMY") + dummy_settings = _create_settings() + gemini_api = GeminiAPI(settings=dummy_settings, log_file=None) + client = gemini_api._get_client("default") + + # These mocks represent the same file, but at different states/points in time + starting_file = Mock( + dummy_name="file1", state=FileState.PROCESSING, sha256_hash="aGFzaDE=" + ) + + side_effects = [ + Mock(name="file1", state=FileState.PROCESSING, sha256_hash="aGFzaDE="), + Mock(name="file1", state=FileState.PROCESSING, sha256_hash="aGFzaDE="), + Mock(name="file1", state=FileState.ACTIVE, sha256_hash="aGFzaDE="), + # We should never to this, but including it to differentiate + # between the function completing because it picked up on the + # previous file==ACTIVE (correct), or if the function completed + # because it ran out of side effects (incorrect) + Mock(name="file1", state=FileState.PROCESSING, sha256_hash="aGFzaDE="), + ] + + mock_file_get.side_effect = side_effects + + # Call the function to test + await wait_for_processing(starting_file, client, poll_interval=0) + + # Check that the `get` method was called exactly 3 times + assert mock_file_get.call_count == 3 + + +@patch.object( + genai.files.AsyncFiles, + "list", + new_callable=AsyncMock, +) +@patch( + "prompto.apis.gemini.gemini_media.compute_sha256_base64", + 
new_callable=MagicMock, +) +def test_upload_media_files_already_uploaded( + mock_compute_sha256_base64, mock_list_files, monkeypatch, caplog +): + caplog.set_level(logging.INFO) + + uploaded_file = Mock(dummy_name="remote_uploaded/file1", sha256_hash="aGFzaDE=") + uploaded_file.configure_mock(name=uploaded_file.dummy_name) + uploaded_file.__str__ = Mock(return_value=uploaded_file.dummy_name) + + local_file_path = "dummy_local_path/file1.txt" + expected_log_msg = ( + "File 'dummy_local_path/file1.txt' already uploaded as 'remote_uploaded/file1'" + ) + + with monkeypatch.context() as m: + m.setenv("GEMINI_API_KEY", "DUMMY") + + mock_compute_sha256_base64.return_value = "aGFzaDE=" + # return_value is a list of a single mock remote file + mock_list_files.return_value = [uploaded_file] + dummy_settings = None + + # Pass a list of local file paths to the function + actual_uploads = upload_media_files([local_file_path], dummy_settings) + + # actual_uploads is a dict of local and remote file names + assert local_file_path in actual_uploads + assert actual_uploads[local_file_path] == "remote_uploaded/file1" + + # Check the log message + assert expected_log_msg in caplog.text + + +# def test_upload_media_files(): +# pytest.fail("Test not implemented") + + +# @pytest.mark.asyncio +@patch( + "prompto.apis.gemini.gemini_media._get_previously_uploaded_files", + new_callable=AsyncMock, +) +@patch.object( + genai.files.AsyncFiles, + "upload", + new_callable=AsyncMock, +) +@patch( + "prompto.apis.gemini.gemini_media.compute_sha256_base64", + new_callable=MagicMock, +) +def test_upload_media_files_new_file( + mock_compute_sha256_base64, + mock_files_upload, + mock_previous_files, + monkeypatch, + caplog, +): + """ + Test the upload_media_files function when the file is not already uploaded, but there are already + other files uploaded.""" + caplog.set_level(logging.INFO) + + pre_uploaded_file = Mock( + dummy_name="remote_uploaded/file1", + sha256_hash=Mock(decode=lambda _: "6861736831"), + ) + pre_uploaded_file.configure_mock(name=pre_uploaded_file.dummy_name) + pre_uploaded_file.__str__ = Mock(return_value=pre_uploaded_file.dummy_name) + + previous_files_dict = { + "hash1": pre_uploaded_file, + } + + local_file_path = "dummy_local_path/file2.txt" + expected_log_msgs = [ + "Uploading dummy_local_path/file2.txt to Gemini API", + "Uploaded file 'remote_uploaded/file2' with hash 'hash2' to Gemini API", + ] + + new_file = Mock( + dummy_name="remote_uploaded/file2", + sha256_hash=Mock(decode=lambda _: "hash2"), + ) + new_file.configure_mock(name=new_file.dummy_name) + new_file.__str__ = Mock(return_value=new_file.name) + + with monkeypatch.context() as m: + m.setenv("GEMINI_API_KEY", "DUMMY") + + mock_compute_sha256_base64.return_value = "hash2" + mock_previous_files.return_value = previous_files_dict + mock_files_upload.return_value = new_file + + dummy_settings = None + actual_uploads = upload_media_files([local_file_path], dummy_settings) + + print(actual_uploads) + + # actual_uploads is a dict of local and remote file names + assert local_file_path in actual_uploads + assert actual_uploads[local_file_path] == "remote_uploaded/file2" + + # Check that the previously uploaded file is not in the actual_uploads dict + assert pre_uploaded_file.dummy_name not in actual_uploads + + # Check the log message + assert all(msg in caplog.text for msg in expected_log_msgs) + + +# def test__init_genai(): +# # Is this still required, or is it superseded by the Client object +# pytest.fail("Test not implemented") + + 
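# Editorial aside, not part of the patch: a quick sanity check of the hash fixtures used in
# this file. remote_file_hash_base64 converts a hex-encoded SHA256 digest to base64, so the
# hex digest "6861736831" (the bytes of the string "hash1") must map to "aGFzaDE=", which is
# why the mocks pair those two values:
# >>> import base64
# >>> base64.b64encode(bytes.fromhex("6861736831")).decode("utf-8")
# 'aGFzaDE='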
+@patch.object( + genai.files.AsyncFiles, + "list", + new_callable=AsyncMock, +) +def test_list_uploaded_files(mock_list_files, caplog, monkeypatch): + caplog.set_level(logging.INFO) + + # Case 1: No files uploaded + case_1 = { + "return_value": [], + "expected_log_msgs": ["Found 0 files already uploaded at Gemini API"], + } + + # Case 2: Three files uploaded + # The example hashes are for the strings "hash1", "hash2", "hash3" eg: + # >>> "hash1".encode("utf-8").hex() + # '6861736831' + case_2 = { + "return_value": [ + Mock(dummy_name="file1", sha256_hash="aGFzaDE="), + Mock(dummy_name="file2", sha256_hash="aGFzaDI="), + Mock(dummy_name="file3", sha256_hash="aGFzaDM="), + ], + "expected_log_msgs": [ + "Found 3 files already uploaded at Gemini API", + "File Name: file1, File Hash: aGFzaDE=", + "File Name: file2, File Hash: aGFzaDI=", + "File Name: file3, File Hash: aGFzaDM=", + ], + } + + expected_final_log_msg = "All uploaded files listed." + + with monkeypatch.context() as m: + m.setenv("GEMINI_API_KEY", "DUMMY") + + for case_dict in [case_1, case_2]: + + mocked_list_value = case_dict["return_value"] + + # We need to juggle with the mock names, because we can't set them + # directly in the constructor. See these docs for more details: + # https://docs.python.org/3/library/unittest.mock.html#mock-names-and-the-name-attribute + for mock_file in mocked_list_value: + mock_file.configure_mock(name=mock_file.dummy_name) + mock_file.__str__ = Mock(return_value=mock_file.dummy_name) + + mock_list_files.return_value = mocked_list_value + + expected_total_in_log_msg = case_dict["expected_log_msgs"] + + dummy_settings = _create_settings() + # Call the function to test + list_uploaded_files(dummy_settings) + + # There is no return value from the function, so we need to check the + # log messages + for msg in expected_total_in_log_msg: + assert msg in caplog.text + assert expected_final_log_msg in caplog.text + + +@patch.object( + genai.files.Files, + "delete", + new_callable=Mock, +) +@patch.object( + genai.files.Files, + "list", + new_callable=Mock, +) +def test_delete_uploaded_files(mock_list_files, mock_delete, caplog, monkeypatch): + + caplog.set_level(logging.INFO) + + # Case 1: No files uploaded + case_1 = [] + + # Case 2: Three files uploaded + # The example hashes are for the strings "hash1", "hash2", "hash3" eg: + # >>> "hash1".encode("utf-8").hex() + # '6861736831' + case_2 = [ + Mock(dummy_name="file1", sha256_hash="aGFzaDE="), + Mock(dummy_name="file2", sha256_hash="aGFzaDI="), + Mock(dummy_name="file3", sha256_hash="aGFzaDM="), + ] + expected_final_log_msg = "All uploaded files deleted." + + with monkeypatch.context() as m: + m.setenv("GEMINI_API_KEY", "DUMMY") + + for mocked_list_value in [case_1, case_2]: + + # We need to juggle with the mock names, because we can't set them + # directly in the constructor. See these docs for more details: + # https://docs.python.org/3/library/unittest.mock.html#mock-names-and-the-name-attribute + for mock_file in mocked_list_value: + mock_file.configure_mock(name=mock_file.dummy_name) + mock_file.__str__ = Mock(return_value=mock_file.dummy_name) + + mock_list_files.return_value = mocked_list_value + + dummy_settings = _create_settings() + # Call the function to test + delete_uploaded_files(dummy_settings) + + # There is no return value from the function, so we need to check the + # that the delete function was called the expected number of times + # Add 1 to force it to fail for now. 
+ assert mock_delete.call_count == len(mocked_list_value) + assert expected_final_log_msg in caplog.text diff --git a/tests/apis/gemini/test_gemini_utils.py b/tests/apis/gemini/test_gemini_utils.py index b453c623..9005ed2b 100644 --- a/tests/apis/gemini/test_gemini_utils.py +++ b/tests/apis/gemini/test_gemini_utils.py @@ -3,7 +3,7 @@ from prompto.apis.gemini.gemini_utils import convert_history_dict_to_content -from .test_gemini import prompt_dict_chat, prompt_dict_history +from .test_gemini import prompt_dict_history @pytest.mark.xfail(reason="Test not implemented") From 6e66327f435ad87327e7100aab311cfdf7ecb7f3 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Tue, 6 May 2025 11:20:37 +0100 Subject: [PATCH 10/21] tweak dependency versions --- pyproject.toml | 47 ++++++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 188915a1..4728088d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,38 +13,47 @@ readme = "README.md" [tool.poetry.dependencies] python = ">=3.11,<4.0" -tqdm = "^4.66.4" -python-dotenv = "^1.0.1" +# tqdm = "^4.66.4" # working with 4.67.1 +tqdm = "^4.67.1" # working with 4.67.1 +python-dotenv = "^1.0.1" # working with 1.1.0 pandas = "^2.2.3" -black = { version = "^24.3.0", optional = true } +black = { version = "^24.3.0", optional = true } # working with 24.10.0 isort = { version = "^5.13.2", optional = true } -pre-commit = { version = "^3.7.0", optional = true } -pytest = { version = "^8.1.1", optional = true } -pytest-asyncio = { version = "^0.23.6", optional = true } +pre-commit = { version = "^3.7.0", optional = true } # working with 3.8.0 +pytest = { version = "^8.1.1", optional = true } # working with 8.3.5 +pytest-asyncio = { version = "^0.23.6", optional = true } # working with 0.23.8 pytest-cov = { version = "^5.0.0", optional = true } -ipykernel = { version = "^6.29.4", optional = true } -mkdocs-material = { version = "^9.5.26", optional = true } -mkdocstrings-python = { version = "^1.10.3", optional = true } +ipykernel = { version = "^6.29.4", optional = true } # working with 6.29.5 +mkdocs-material = { version = "^9.5.26", optional = true } # working with 9.6.11 +mkdocstrings-python = { version = "^1.10.3", optional = true } # working with 1.16.10 mkdocs-gen-files = { version = "^0.5.0", optional = true } -mkdocs-literate-nav = { version = "^0.6.1", optional = true } +mkdocs-literate-nav = { version = "^0.6.1", optional = true } # working with 0.6.2 mkdocs-section-index = { version = "^0.3.9", optional = true } mkdocs-same-dir = { version = "^0.1.3", optional = true } -mkdocs-jupyter = { version = "^0.24.7", optional = true } -cli-test-helpers = { version = "^4.0.0", optional = true } +mkdocs-jupyter = { version = "^0.24.7", optional = true } # working with 0.24.8 +cli-test-helpers = { version = "^4.0.0", optional = true } # working with 4.1.0 vertexai = { version ="^1.71.1", optional = true } google-cloud-aiplatform = { version = "^1.71.1", optional = true } # google-generativeai = { version = "^0.8.4", optional = true } # TODO: deprecated - to be removed -google-genai = { version = "^1.11.0", optional = true } -openai = { version = "^1.60.0", optional = true } +google-genai = { version = "^1.11.0", optional = true } # working with 1.13.0 +openai = { version = "^1.60.0", optional = true } # working with 1.70.0 pillow = { version = "^11.1.0", optional = true } ollama = { version = "^0.4.7", optional = true } -huggingface-hub = { version = "^0.28.0", optional = true } 
+huggingface-hub = { version = "^0.28.0", optional = true } # working with 0.28.1 quart = { version = "^0.20.0", optional = true } -transformers = { version = "^4.48.1", optional = true } +# transformers = { version = "^4.48.1", optional = true } # working with 4.50.3 +transformers = { version = "^4.50.3", optional = true } # working with 4.50.3 torch = { version = "^2.6.0", optional = true } -accelerate = { version = "^1.3.0", optional = true } -aiohttp = { version = "^3.11.11", optional = true } +accelerate = { version = "^1.3.0", optional = true } # working with 1.6.0 +aiohttp = { version = "^3.11.11", optional = true } # working with 3.11.16 anthropic = { version = "^0.45.2", optional = true } +# In theory these could be derived from the dependencies above, but +# we want to be explicit about them, so that pip doesn't spend ages resolving +# them. +urllib3 = { version = "^2.3.0", optional = true } +shapely = { version = "^2.1.0", optional = true } + + [tool.poetry.extras] all = [ @@ -107,7 +116,7 @@ gemini = [ vertexai = [ "vertexai", "google-cloud-aiplatform", - "google-generativeai", + # "google-generativeai", # TODO: deprecated - to be removed "google-genai", "pillow" ] From 487716d84ebd016b57b273ff8bfdbd79713c61da Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Tue, 6 May 2025 11:32:32 +0100 Subject: [PATCH 11/21] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4728088d..acccdd82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,8 +50,8 @@ anthropic = { version = "^0.45.2", optional = true } # In theory these could be derived from the dependencies above, but # we want to be explicit about them, so that pip doesn't spend ages resolving # them. 
-urllib3 = { version = "^2.3.0", optional = true } -shapely = { version = "^2.1.0", optional = true } +urllib3 = { version = "^2.3.0", optional = false } +shapely = { version = "^2.1.0", optional = false } From 9fa2ad149e5116edef16019857336145826cc600 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Tue, 6 May 2025 15:58:29 +0100 Subject: [PATCH 12/21] WIP: attempt tighter deps version restrictions --- pyproject.toml | 68 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index acccdd82..d5bee3d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,11 +24,11 @@ pytest = { version = "^8.1.1", optional = true } # working with 8.3.5 pytest-asyncio = { version = "^0.23.6", optional = true } # working with 0.23.8 pytest-cov = { version = "^5.0.0", optional = true } ipykernel = { version = "^6.29.4", optional = true } # working with 6.29.5 -mkdocs-material = { version = "^9.5.26", optional = true } # working with 9.6.11 +# mkdocs-material = { version = "^9.5.26", optional = true } # working with 9.6.11 mkdocstrings-python = { version = "^1.10.3", optional = true } # working with 1.16.10 mkdocs-gen-files = { version = "^0.5.0", optional = true } mkdocs-literate-nav = { version = "^0.6.1", optional = true } # working with 0.6.2 -mkdocs-section-index = { version = "^0.3.9", optional = true } +# mkdocs-section-index = { version = "^0.3.9", optional = true } mkdocs-same-dir = { version = "^0.1.3", optional = true } mkdocs-jupyter = { version = "^0.24.7", optional = true } # working with 0.24.8 cli-test-helpers = { version = "^4.0.0", optional = true } # working with 4.1.0 @@ -36,23 +36,75 @@ vertexai = { version ="^1.71.1", optional = true } google-cloud-aiplatform = { version = "^1.71.1", optional = true } # google-generativeai = { version = "^0.8.4", optional = true } # TODO: deprecated - to be removed google-genai = { version = "^1.11.0", optional = true } # working with 1.13.0 -openai = { version = "^1.60.0", optional = true } # working with 1.70.0 -pillow = { version = "^11.1.0", optional = true } -ollama = { version = "^0.4.7", optional = true } +# openai = { version = "^1.60.0", optional = true } # working with 1.70.0 +# pillow = { version = "^11.1.0", optional = true } +# ollama = { version = "^0.4.7", optional = true } huggingface-hub = { version = "^0.28.0", optional = true } # working with 0.28.1 quart = { version = "^0.20.0", optional = true } # transformers = { version = "^4.48.1", optional = true } # working with 4.50.3 transformers = { version = "^4.50.3", optional = true } # working with 4.50.3 -torch = { version = "^2.6.0", optional = true } +# torch = { version = "^2.6.0", optional = true } accelerate = { version = "^1.3.0", optional = true } # working with 1.6.0 -aiohttp = { version = "^3.11.11", optional = true } # working with 3.11.16 +# aiohttp = { version = "^3.11.11", optional = true } # working with 3.11.16 anthropic = { version = "^0.45.2", optional = true } # In theory these could be derived from the dependencies above, but # we want to be explicit about them, so that pip doesn't spend ages resolving # them. 
-urllib3 = { version = "^2.3.0", optional = false } +# urllib3 = { version = "^2.3.0", optional = false } shapely = { version = "^2.1.0", optional = false } +# Paste in from pip list --format=freeze + +aiohttp = { version = "^3.11.18", optional = false } +beautifulsoup4 = { version = "^4.13.4", optional = false } +certifi = { version = "^2025.4.26", optional = false } +charset-normalizer = { version = "^3.4.2", optional = false } +debugpy = { version = "^1.8.14", optional = false } +frozenlist = { version = "^1.6.0", optional = false } +google-api-core = { version = "^2.25.0rc0", optional = false } +google-auth = { version = "^2.40.0", optional = false } +google-api-python-client = { version = "^2.166.0", optional = false } +# google-auth = { version = "^2.38.0", optional = false } +google-auth-httplib2 = { version = "^0.2.0", optional = false } +google-generativeai = { version = "^0.8.4", optional = false } +googleapis-common-protos = { version = "^1.70.0", optional = false } +griffe = { version = "^1.7.3", optional = false } +h11 = { version = "^0.16.0", optional = false } +httpcore = { version = "^1.0.9", optional = false } +httplib2 = { version = "^0.22.0", optional = false } +identify = { version = "^2.6.10", optional = false } +ipython = { version = "^9.2.0", optional = false } +jsonschema-specifications = { version = "^2025.4.1", optional = false } +jupytext = { version = "^1.17.1", optional = false } +Markdown = { version = "^3.8", optional = false } +mkdocs-material = { version = "^9.6.12", optional = false } +mkdocs-section-index = { version = "^0.3.10", optional = false } +multidict = { version = "^6.4.3", optional = false } +mypy_extensions = { version = "^1.1.0", optional = false } +numpy = { version = "^2.2.5", optional = false } +ollama = { version = "^0.4.8", optional = false } +openai = { version = "^1.77.0", optional = false } +packaging = { version = "^25.0", optional = false } +pillow = { version = "^11.2.1", optional = false } +pydantic = { version = "^2.11.4", optional = false } +pydantic_core = { version = "^2.33.2", optional = false } +pymdown-extensions = { version = "^10.15", optional = false } +pyparsing = { version = "^3.2.3", optional = false } +pytest-random-order = { version = "^1.1.1", optional = false } +pyzmq = { version = "^26.4.0", optional = false } +rsa = { version = "^4.9.1", optional = false } +ruff = { version = "^0.11.6", optional = false } +setuptools = { version = "^80.3.1", optional = false } +snakeviz = { version = "^2.2.2", optional = false } +soupsieve = { version = "^2.7", optional = false } +sympy = { version = "^1.14.0", optional = false } +torch = { version = "^2.7.0", optional = false } +typing_extensions = { version = "^4.13.2", optional = false } +urllib3 = { version = "^2.4.0", optional = false } +# urllib3 = { version = "^2.3.0", optional = false } +virtualenv = { version = "^20.31.1", optional = false } +websockets = { version = "^15.0.1", optional = false } +yarl = { version = "^1.20.0", optional = false } [tool.poetry.extras] From 68b5de1c035ba578afb449a07cdc3310c793aefb Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Tue, 6 May 2025 17:02:01 +0100 Subject: [PATCH 13/21] Sorted dependencies in pyproject.toml --- pyproject.toml | 148 +++++++++++++++++++++++-------------------------- 1 file changed, 68 insertions(+), 80 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d5bee3d3..80c2df6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,98 +13,86 @@ readme = "README.md" [tool.poetry.dependencies] python = 
">=3.11,<4.0" -# tqdm = "^4.66.4" # working with 4.67.1 -tqdm = "^4.67.1" # working with 4.67.1 -python-dotenv = "^1.0.1" # working with 1.1.0 -pandas = "^2.2.3" + +# Sorted primary dependencies +accelerate = { version = "^1.3.0", optional = true } # working with 1.6.0 +aiohttp = { version = "^3.11.18", optional = false } +anthropic = { version = "^0.45.2", optional = true } black = { version = "^24.3.0", optional = true } # working with 24.10.0 -isort = { version = "^5.13.2", optional = true } -pre-commit = { version = "^3.7.0", optional = true } # working with 3.8.0 -pytest = { version = "^8.1.1", optional = true } # working with 8.3.5 -pytest-asyncio = { version = "^0.23.6", optional = true } # working with 0.23.8 -pytest-cov = { version = "^5.0.0", optional = true } -ipykernel = { version = "^6.29.4", optional = true } # working with 6.29.5 -# mkdocs-material = { version = "^9.5.26", optional = true } # working with 9.6.11 -mkdocstrings-python = { version = "^1.10.3", optional = true } # working with 1.16.10 -mkdocs-gen-files = { version = "^0.5.0", optional = true } -mkdocs-literate-nav = { version = "^0.6.1", optional = true } # working with 0.6.2 -# mkdocs-section-index = { version = "^0.3.9", optional = true } -mkdocs-same-dir = { version = "^0.1.3", optional = true } -mkdocs-jupyter = { version = "^0.24.7", optional = true } # working with 0.24.8 cli-test-helpers = { version = "^4.0.0", optional = true } # working with 4.1.0 -vertexai = { version ="^1.71.1", optional = true } google-cloud-aiplatform = { version = "^1.71.1", optional = true } -# google-generativeai = { version = "^0.8.4", optional = true } # TODO: deprecated - to be removed google-genai = { version = "^1.11.0", optional = true } # working with 1.13.0 -# openai = { version = "^1.60.0", optional = true } # working with 1.70.0 -# pillow = { version = "^11.1.0", optional = true } -# ollama = { version = "^0.4.7", optional = true } -huggingface-hub = { version = "^0.28.0", optional = true } # working with 0.28.1 -quart = { version = "^0.20.0", optional = true } -# transformers = { version = "^4.48.1", optional = true } # working with 4.50.3 -transformers = { version = "^4.50.3", optional = true } # working with 4.50.3 -# torch = { version = "^2.6.0", optional = true } -accelerate = { version = "^1.3.0", optional = true } # working with 1.6.0 -# aiohttp = { version = "^3.11.11", optional = true } # working with 3.11.16 -anthropic = { version = "^0.45.2", optional = true } -# In theory these could be derived from the dependencies above, but -# we want to be explicit about them, so that pip doesn't spend ages resolving -# them. 
-# urllib3 = { version = "^2.3.0", optional = false } -shapely = { version = "^2.1.0", optional = false } - -# Paste in from pip list --format=freeze - -aiohttp = { version = "^3.11.18", optional = false } -beautifulsoup4 = { version = "^4.13.4", optional = false } -certifi = { version = "^2025.4.26", optional = false } -charset-normalizer = { version = "^3.4.2", optional = false } -debugpy = { version = "^1.8.14", optional = false } -frozenlist = { version = "^1.6.0", optional = false } -google-api-core = { version = "^2.25.0rc0", optional = false } -google-auth = { version = "^2.40.0", optional = false } -google-api-python-client = { version = "^2.166.0", optional = false } -# google-auth = { version = "^2.38.0", optional = false } -google-auth-httplib2 = { version = "^0.2.0", optional = false } google-generativeai = { version = "^0.8.4", optional = false } -googleapis-common-protos = { version = "^1.70.0", optional = false } -griffe = { version = "^1.7.3", optional = false } -h11 = { version = "^0.16.0", optional = false } -httpcore = { version = "^1.0.9", optional = false } -httplib2 = { version = "^0.22.0", optional = false } -identify = { version = "^2.6.10", optional = false } -ipython = { version = "^9.2.0", optional = false } -jsonschema-specifications = { version = "^2025.4.1", optional = false } -jupytext = { version = "^1.17.1", optional = false } -Markdown = { version = "^3.8", optional = false } +huggingface-hub = { version = "^0.28.0", optional = true } # working with 0.28.1 +ipykernel = { version = "^6.29.4", optional = true } # working with 6.29.5 +isort = { version = "^5.13.2", optional = true } +mkdocs-gen-files = { version = "^0.5.0", optional = true } +mkdocs-jupyter = { version = "^0.24.7", optional = true } # working with 0.24.8 +mkdocs-literate-nav = { version = "^0.6.1", optional = true } # working with 0.6.2 mkdocs-material = { version = "^9.6.12", optional = false } +mkdocs-same-dir = { version = "^0.1.3", optional = true } mkdocs-section-index = { version = "^0.3.10", optional = false } -multidict = { version = "^6.4.3", optional = false } -mypy_extensions = { version = "^1.1.0", optional = false } -numpy = { version = "^2.2.5", optional = false } +mkdocstrings-python = { version = "^1.10.3", optional = true } # working with 1.16.10 ollama = { version = "^0.4.8", optional = false } openai = { version = "^1.77.0", optional = false } -packaging = { version = "^25.0", optional = false } +pandas = "^2.2.3" pillow = { version = "^11.2.1", optional = false } -pydantic = { version = "^2.11.4", optional = false } -pydantic_core = { version = "^2.33.2", optional = false } -pymdown-extensions = { version = "^10.15", optional = false } -pyparsing = { version = "^3.2.3", optional = false } -pytest-random-order = { version = "^1.1.1", optional = false } -pyzmq = { version = "^26.4.0", optional = false } -rsa = { version = "^4.9.1", optional = false } -ruff = { version = "^0.11.6", optional = false } -setuptools = { version = "^80.3.1", optional = false } -snakeviz = { version = "^2.2.2", optional = false } -soupsieve = { version = "^2.7", optional = false } -sympy = { version = "^1.14.0", optional = false } +pre-commit = { version = "^3.7.0", optional = true } # working with 3.8.0 +pytest = { version = "^8.1.1", optional = true } # working with 8.3.5 +pytest-asyncio = { version = "^0.23.6", optional = true } # working with 0.23.8 +pytest-cov = { version = "^5.0.0", optional = true } +python-dotenv = "^1.0.1" # working with 1.1.0 +quart = { version = "^0.20.0", optional = 
true } torch = { version = "^2.7.0", optional = false } -typing_extensions = { version = "^4.13.2", optional = false } +tqdm = "^4.67.1" # working with 4.67.1 +transformers = { version = "^4.50.3", optional = true } # working with 4.50.3 urllib3 = { version = "^2.4.0", optional = false } -# urllib3 = { version = "^2.3.0", optional = false } -virtualenv = { version = "^20.31.1", optional = false } -websockets = { version = "^15.0.1", optional = false } -yarl = { version = "^1.20.0", optional = false } +vertexai = { version ="^1.71.1", optional = true } + +# Sorted secondary dependencies +# These are not directly used in the code, but are required by the primary dependencies +# They are specified here, because the automatic dependency resolution of pip +# can be slow and sometimes times-out without resolving them all. +beautifulsoup4 = { version = "^4.13.4", optional = false } # Secondary +certifi = { version = "^2025.4.26", optional = false } # Secondary +charset-normalizer = { version = "^3.4.2", optional = false } # Secondary +debugpy = { version = "^1.8.14", optional = false } # Secondary +frozenlist = { version = "^1.6.0", optional = false } # Secondary +google-api-core = { version = "^2.25.0rc0", optional = false } # Secondary +google-api-python-client = { version = "^2.166.0", optional = false } # Secondary +google-auth = { version = "^2.40.0", optional = false } # Secondary +google-auth-httplib2 = { version = "^0.2.0", optional = false } # Secondary +googleapis-common-protos = { version = "^1.70.0", optional = false } # Secondary +griffe = { version = "^1.7.3", optional = false } # Secondary +h11 = { version = "^0.16.0", optional = false } # Secondary +httpcore = { version = "^1.0.9", optional = false } # Secondary +httplib2 = { version = "^0.22.0", optional = false } # Secondary +identify = { version = "^2.6.10", optional = false } # Secondary +ipython = { version = "^9.2.0", optional = false } # Secondary +jsonschema-specifications = { version = "^2025.4.1", optional = false } # Secondary +jupytext = { version = "^1.17.1", optional = false } # Secondary +Markdown = { version = "^3.8", optional = false } # Secondary +multidict = { version = "^6.4.3", optional = false } # Secondary +mypy_extensions = { version = "^1.1.0", optional = false } # Secondary +numpy = { version = "^2.2.5", optional = false } # Secondary +packaging = { version = "^25.0", optional = false } # Secondary +pydantic = { version = "^2.11.4", optional = false } # Secondary +pydantic_core = { version = "^2.33.2", optional = false } # Secondary +pymdown-extensions = { version = "^10.15", optional = false } # Secondary +pyparsing = { version = "^3.2.3", optional = false } # Secondary +pytest-random-order = { version = "^1.1.1", optional = false } # Secondary +pyzmq = { version = "^26.4.0", optional = false } # Secondary +rsa = { version = "^4.9.1", optional = false } # Secondary +ruff = { version = "^0.11.6", optional = false } # Secondary +setuptools = { version = "^80.3.1", optional = false } # Secondary +shapely = { version = "^2.1.0", optional = false } # Secondary +snakeviz = { version = "^2.2.2", optional = false } # Secondary +soupsieve = { version = "^2.7", optional = false } # Secondary +sympy = { version = "^1.14.0", optional = false } # Secondary +typing_extensions = { version = "^4.13.2", optional = false } # Secondary +virtualenv = { version = "^20.31.1", optional = false } # Secondary +websockets = { version = "^15.0.1", optional = false } # Secondary +yarl = { version = "^1.20.0", optional = false } # 
Secondary [tool.poetry.extras] From 0ef2373aba111810824099181a534f5580b98b3f Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Tue, 6 May 2025 19:58:35 +0100 Subject: [PATCH 14/21] WIP: tidy up and linting --- src/prompto/apis/gemini/gemini.py | 12 +- src/prompto/apis/gemini/gemini_media.py | 115 ++---------------- src/prompto/apis/gemini/gemini_utils.py | 14 --- src/prompto/upload_media.py | 4 +- tests/apis/gemini/test_gemini_chat_input.py | 4 +- .../apis/gemini/test_gemini_history_input.py | 77 +----------- tests/apis/gemini/test_gemini_image_input.py | 1 - tests/apis/gemini/test_gemini_media.py | 10 -- tests/apis/gemini/test_gemini_utils.py | 10 -- 9 files changed, 15 insertions(+), 232 deletions(-) diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index 938380e7..b55d213b 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -2,10 +2,6 @@ from typing import Any from google.genai import Client - -# import google.generativeai as genai -# from google.generativeai import GenerativeModel -# from google.generativeai.types import GenerationConfig, HarmBlockThreshold, HarmCategory from google.genai.types import ( GenerateContentConfig, HarmBlockThreshold, @@ -181,7 +177,7 @@ def check_prompt_dict(prompt_dict: dict) -> list[Exception]: # if generation_config is provided, check that it can create a valid GenerationConfig object if "parameters" in prompt_dict: try: - GenerationConfig(**prompt_dict["parameters"]) + GenerateContentConfig(**prompt_dict["parameters"]) except Exception as err: issues.append(Exception(f"Invalid generation_config parameter: {err}")) @@ -406,7 +402,7 @@ async def _query_string(self, prompt_dict: dict, index: int | str): safety_attributes = BLOCKED_SAFETY_ATTRIBUTES else: safety_attributes = process_safety_attributes(response) - except: + except Exception: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES prompt_dict["response"] = response_text @@ -508,7 +504,7 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): safety_attributes = BLOCKED_SAFETY_ATTRIBUTES else: safety_attributes = process_safety_attributes(response) - except: + except Exception: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES prompt_dict["response"] = response_text @@ -625,7 +621,7 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES else: safety_attributes = process_safety_attributes(response) - except: + except Exception: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES prompt_dict["response"] = response_text diff --git a/src/prompto/apis/gemini/gemini_media.py b/src/prompto/apis/gemini/gemini_media.py index 814ae624..7fd08964 100644 --- a/src/prompto/apis/gemini/gemini_media.py +++ b/src/prompto/apis/gemini/gemini_media.py @@ -1,12 +1,8 @@ import asyncio import base64 import logging -import os -import tempfile -from time import sleep import tqdm -from dotenv import load_dotenv from google import genai from prompto.apis.gemini.gemini import GeminiAPI @@ -27,7 +23,6 @@ def remote_file_hash_base64(remote_file): Convert a remote file's SHA256 hash (stored as a hex-encoded UTF-8 bytes object) to a base64-encoded string. """ - # hex_str = remote_file.sha256_hash.decode("utf-8") hex_str = remote_file.sha256_hash raw_bytes = bytes.fromhex(hex_str) return base64.b64encode(raw_bytes).decode("utf-8") @@ -38,16 +33,10 @@ async def wait_for_processing(file_obj, client: genai.Client, poll_interval=1): Poll until the file is no longer in the 'PROCESSING' state. 
Returns the updated file object. """ - # print(f"File {file_obj.name} is in state {file_obj.state.name}") - while file_obj.state.name == "PROCESSING": await asyncio.sleep(poll_interval) # We need to re-fetch the file object to get the updated state. file_obj = client.files.get(name=file_obj.name) - # print(f"File {file_obj.name} is in state {file_obj.state.name}") - # print(f"{file_obj.error=}") - # print(f"{file_obj.update_time=}") - # print(f"{file_obj.create_time=}") return file_obj @@ -85,10 +74,8 @@ async def _upload_single_file( return already_uploaded_files[local_hash], local_file_path # Upload the file if it hasn't been found. - # Use asyncio.to_thread to run the blocking upload_file function in a separate thread. logger.info(f"Uploading {local_file_path} to Gemini API") - # file_obj = await asyncio.to_thread(genai.upload_file, local_file_path) file_obj = await client.aio.files.upload(file=local_file_path) file_obj = await wait_for_processing(file_obj, client=client) @@ -105,16 +92,6 @@ async def _upload_single_file( return file_obj.name, local_file_path -# def _init_genai(): -# load_dotenv(dotenv_path=".env") -# # TODO: check if this can be refactored to a common function -# GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") -# if GEMINI_API_KEY is None: -# raise ValueError("GEMINI_API_KEY is not set") - -# genai.configure(api_key=GEMINI_API_KEY) - - async def _get_previously_uploaded_files(client: genai.Client): raw_files = await client.aio.files.list() uploaded_files = { @@ -128,24 +105,9 @@ def list_uploaded_files(settings: Settings): """ List all previously uploaded files to the Gemini API. """ - # _init_genai() - - # Settings are not used in this function, but we need to - # create a dummy settings object to pass to the GeminiAPI - # TODO: - # Also, we don't need a directory, but Settings constructor - # insists on creating these directories locally. - # A better solution would be to create an option in the - # Settings constructor to not create the directories. - # But for now we'll just pass it a temporary directory. - # with tempfile.TemporaryDirectory() as temp_dir: - # data_folder = os.path.join(temp_dir, "data") - # os.makedirs(data_folder, exist_ok=True) - # dummy_settings = Settings(data_folder=data_folder) - genmini_api = GeminiAPI(settings=settings, log_file=None) # TODO: We need a model name, because our API caters for different API keys - # for different models. Maybe our API to complicated.... + # for different models. Maybe our API is too complicated.... default_model_name = "default" client = genmini_api._get_client(default_model_name) uploaded_files = asyncio.run(_get_previously_uploaded_files(client)) @@ -160,86 +122,23 @@ def delete_uploaded_files(settings: Settings): """ Delete all previously uploaded files from the Gemini API. """ - # _init_genai() - - # with tempfile.TemporaryDirectory() as temp_dir: - # data_folder = os.path.join(temp_dir, "data") - # os.makedirs(data_folder, exist_ok=True) - # dummy_settings = Settings(data_folder=data_folder) - genmini_api = GeminiAPI(settings=settings, log_file=None) # TODO: We need a model name, because our API caters for different API keys # for different models. Maybe our API to complicated.... default_model_name = "default" client = genmini_api._get_client(default_model_name) - # uploaded_files = asyncio.run(_get_previously_uploaded_files(client)) + # This just using the synchronous API. Using the async API did not + # seem reliable. 
In particular `client.aio.files.delete()` did not appear + # to always actually deleting the files (even after repeatedly polling the file) + # This is not an important function in prompto and delete action is reasonably + # quick, so we can live with this simple solution. + # `` for remote_file in client.files.list(): - # file_name = file_name.name client.files.delete(name=remote_file.name) - # _delete_single_uploaded_file(file_name, client) - # return asyncio.run(_delete_uploaded_files_async(uploaded_files, client)) - logger.info("All uploaded files deleted.") - -def _delete_single_uploaded_file(file_name: str, client: genai.Client): - """ - Delete a single uploaded file from the Gemini API. - """ - print(f"Deleting file {file_name}") - file = client.files.get(name=file_name) - client.files.delete(name=file_name) - # indx = 0 - - # The delete function is non-blocking (even the sync version) - # and returns immediately. So we need to poll the file object - # to see if it is still exists. - # The only reliable way to check if the file is deleted is to - # try and get it again and see if it raises an error. - while True: - # We need to re-fetch the file object to get the updated state. - try: - file = client.files.get(name=file.name) - print(f"File {file.name} is in state {file.state.name}") - print(f"{file.error=}") - print(f"{file.update_time=}") - print(f"{file.create_time=}") - print(f"{indx=}") - # client.files.delete(name=file_name) - indx += 1 - # if indx > 10: - # break - except genai.errors.ClientError as e: - # print(f"ClientError: {e}" - print(f"File {file.name} deleted") - break - # if file.state.name == "PROCESSING": - sleep(1) - - -async def _delete_uploaded_files_async(uploaded_files, client: genai.Client): - tasks_set = set() - for file_name in uploaded_files.values(): - logger.info(f"Preparing to delete file: {file_name}") - # tasks.append(asyncio.to_thread(genai.delete_file, file_name)) - - # file = await client.aio.files.get(name=file_name) - # # task = client.aio.files.delete(name=file_name) - # task = asyncio.to_thread(client.aio.files.delete(name=file_name)) - tasks_set.add(_delete_single_uploaded_file(file_name, client)) - - # await tqdm.asyncio.tqdm.gather(*tasks) - await asyncio.gather(*tasks_set, return_exceptions=True) logger.info("All uploaded files deleted.") - # async with asyncio.TaskGroup() as tg: - # for file_name in uploaded_files.values(): - # logger.info(f"Preparing to delete file: {file_name}") - # # tasks.append(asyncio.to_thread(genai.delete_file, file_name)) - # tg.create_task(client.aio.files.delete(name=file_name)) - - # logger.info("All uploaded files deleted.") - def upload_media_files(files_to_upload: set[str], settings: Settings): """ diff --git a/src/prompto/apis/gemini/gemini_utils.py b/src/prompto/apis/gemini/gemini_utils.py index 089b46d3..09389ea3 100644 --- a/src/prompto/apis/gemini/gemini_utils.py +++ b/src/prompto/apis/gemini/gemini_utils.py @@ -55,7 +55,6 @@ def parse_parts_value(part: dict | str, media_folder: str, client: Client) -> an ) else: try: - # return genai.get_file(name=uploaded_filename) return client.aio.files.get(name=uploaded_filename) except Exception as err: raise ValueError( @@ -120,11 +119,6 @@ def convert_history_dict_to_content( if "parts" not in content_dict: raise KeyError("parts key is missing in content dictionary") - # return parse_parts( - # content_dict["parts"], - # media_folder=media_folder, - # ) - return types.Content( role=content_dict["role"], parts=parse_parts( @@ -134,14 +128,6 @@ def 
convert_history_dict_to_content( ), ) - # return { - # "role": content_dict["role"], - # "parts": parse_parts( - # content_dict["parts"], - # media_folder=media_folder, - # ) - # } - def process_response(response: dict) -> str: """ diff --git a/src/prompto/upload_media.py b/src/prompto/upload_media.py index 03b4c68b..e4e137cf 100644 --- a/src/prompto/upload_media.py +++ b/src/prompto/upload_media.py @@ -1,11 +1,9 @@ import argparse -import asyncio import json import logging import os import prompto.apis.gemini.gemini_media as gemini_media -from prompto.apis import ASYNC_APIS from prompto.scripts.run_experiment import load_env_file from prompto.settings import Settings @@ -261,7 +259,7 @@ def _create_settings(): This is used to create a client object for the API. For now, we just create a temporary directory for the data folder. """ - # A better solution would be to create an option in the + # TODO: A better solution would be to create an option in the # Settings constructor to not create the directories. # But for now we'll just pass it a temporary directory. import tempfile diff --git a/tests/apis/gemini/test_gemini_chat_input.py b/tests/apis/gemini/test_gemini_chat_input.py index 7c9a363c..6000c78b 100644 --- a/tests/apis/gemini/test_gemini_chat_input.py +++ b/tests/apis/gemini/test_gemini_chat_input.py @@ -59,9 +59,9 @@ async def test_gemini_query_chat( log_file = "log.txt" gemini_api = GeminiAPI(settings=settings, log_file=log_file) - # mock the response from the API + # Mock the response from the API # NOTE: The actual response from the API is a - # google.genai.types.GenerateContentResponse object + # `google.genai.types.GenerateContentResponse` object # not a string value, but for the purpose of this test, we are using a string value # and testing that this is the input to the process_response function gemini_api_sequence_responses = [ diff --git a/tests/apis/gemini/test_gemini_history_input.py b/tests/apis/gemini/test_gemini_history_input.py index 748af1d2..d5db7bb7 100644 --- a/tests/apis/gemini/test_gemini_history_input.py +++ b/tests/apis/gemini/test_gemini_history_input.py @@ -38,7 +38,6 @@ async def test_gemini_query_history_no_env_var( @pytest.mark.asyncio -# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( AsyncChat, "send_message", @@ -63,7 +62,7 @@ async def test_gemini_query_history( # Mock the response from the API # NOTE: The actual response from the API is a - # google.genai.types.GenerateContentResponse object + # `google.genai.types.GenerateContentResponse`` object # not a string value, but for the purpose of this test, we are using a string value # and testing that this is the input to the process_response function mock_gemini_call.return_value = "response Messages object" @@ -82,12 +81,6 @@ async def test_gemini_query_history( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() - # mock_gemini_call.assert_awaited_once_with( - # content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]}, - # generation_config=prompt_dict_history["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_awaited_once_with( message=Part(text=prompt_dict_history["prompt"][1]["parts"]), @@ -109,7 +102,6 @@ async def test_gemini_query_history( @pytest.mark.asyncio -# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( AsyncChat, "send_message", @@ -133,12 +125,6 @@ async def 
test_gemini_query_history_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() - # mock_gemini_call.assert_awaited_once_with( - # content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]}, - # generation_config=prompt_dict_history["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_awaited_once_with( message=Part(text=prompt_dict_history["prompt"][1]["parts"]), @@ -154,7 +140,6 @@ async def test_gemini_query_history_error( @pytest.mark.asyncio -# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( AsyncChat, "send_message", @@ -183,12 +168,6 @@ async def test_gemini_query_history_index_error( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() - # mock_gemini_call.assert_awaited_once_with( - # content={"role": "user", "parts": [prompt_dict_history["prompt"][1]["parts"]]}, - # generation_config=prompt_dict_history["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_awaited_once_with( message=Part(text=prompt_dict_history["prompt"][1]["parts"]), @@ -207,7 +186,6 @@ async def test_gemini_query_history_index_error( @pytest.mark.asyncio -# @patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) @patch.object( AsyncChats, "create", @@ -230,17 +208,6 @@ async def test_gemini_query_history_check_chat_init( monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) - # mock_obtain_model_inputs.return_value = ( - # prompt_dict_history["prompt"], - # prompt_dict_history["model_name"], - # GenerativeModel( - # model_name=prompt_dict_history["model_name"], - # system_instruction=prompt_dict_history["prompt"][0]["parts"], - # ), - # DEFAULT_SAFETY_SETTINGS, - # prompt_dict_history["parameters"], - # ) - mock_generate_content_config = ( GenerateContentConfig( temperature=1.0, @@ -272,7 +239,6 @@ async def test_gemini_query_history_check_chat_init( @pytest.mark.asyncio -# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( AsyncChat, "send_message", @@ -317,15 +283,6 @@ async def test_gemini_query_history_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() - # mock_gemini_call.assert_awaited_once_with( - # content={ - # "role": "user", - # "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], - # }, - # generation_config=prompt_dict_history_no_system["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_awaited_once_with( message=Part(text=prompt_dict_history_no_system["prompt"][2]["parts"]) @@ -347,7 +304,6 @@ async def test_gemini_query_history_no_system( @pytest.mark.asyncio -# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( AsyncChat, "send_message", @@ -375,15 +331,6 @@ async def test_gemini_query_history_error_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() - # mock_gemini_call.assert_awaited_once_with( - # content={ - # "role": "user", - # "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], - # }, - # generation_config=prompt_dict_history_no_system["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_awaited_once_with( 
message=Part(text=prompt_dict_history_no_system["prompt"][2]["parts"]), @@ -399,7 +346,6 @@ async def test_gemini_query_history_error_no_system( @pytest.mark.asyncio -# @patch("google.generativeai.ChatSession.send_message_async", new_callable=AsyncMock) @patch.object( AsyncChat, "send_message", @@ -434,15 +380,6 @@ async def test_gemini_query_history_index_error_no_system( mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() - # mock_gemini_call.assert_awaited_once_with( - # content={ - # "role": "user", - # "parts": [prompt_dict_history_no_system["prompt"][2]["parts"]], - # }, - # generation_config=prompt_dict_history_no_system["parameters"], - # safety_settings=DEFAULT_SAFETY_SETTINGS, - # stream=False, - # ) mock_gemini_call.assert_awaited_once_with( message=Part(text=prompt_dict_history_no_system["prompt"][2]["parts"]), @@ -461,7 +398,6 @@ async def test_gemini_query_history_index_error_no_system( @pytest.mark.asyncio -# @patch("google.generativeai.GenerativeModel.start_chat", new_callable=Mock) @patch.object( AsyncChats, "create", @@ -484,17 +420,6 @@ async def test_gemini_query_history_no_system_check_chat_init( monkeypatch.setenv("GEMINI_API_KEY_gemini_model_name", "DUMMY") gemini_api = GeminiAPI(settings=settings, log_file=log_file) - # mock_obtain_model_inputs.return_value = ( - # prompt_dict_history_no_system["prompt"], - # prompt_dict_history_no_system["model_name"], - # GenerativeModel( - # model_name=prompt_dict_history_no_system["model_name"], - # system_instruction=None, - # ), - # DEFAULT_SAFETY_SETTINGS, - # prompt_dict_history_no_system["parameters"], - # ) - mock_generate_content_config = ( GenerateContentConfig( temperature=1.0, diff --git a/tests/apis/gemini/test_gemini_image_input.py b/tests/apis/gemini/test_gemini_image_input.py index e7c0adec..33d01d35 100644 --- a/tests/apis/gemini/test_gemini_image_input.py +++ b/tests/apis/gemini/test_gemini_image_input.py @@ -72,7 +72,6 @@ def test_parse_parts_value_video_not_uploaded(): assert "not uploaded" in str(excinfo.value) -# @patch("google.genai.client.files.get") def test_parse_parts_value_video_uploaded(monkeypatch): part = { "type": "video", diff --git a/tests/apis/gemini/test_gemini_media.py b/tests/apis/gemini/test_gemini_media.py index c818bbda..ada2d2c8 100644 --- a/tests/apis/gemini/test_gemini_media.py +++ b/tests/apis/gemini/test_gemini_media.py @@ -129,11 +129,6 @@ def test_upload_media_files_already_uploaded( assert expected_log_msg in caplog.text -# def test_upload_media_files(): -# pytest.fail("Test not implemented") - - -# @pytest.mark.asyncio @patch( "prompto.apis.gemini.gemini_media._get_previously_uploaded_files", new_callable=AsyncMock, @@ -206,11 +201,6 @@ def test_upload_media_files_new_file( assert all(msg in caplog.text for msg in expected_log_msgs) -# def test__init_genai(): -# # Is this still required, or is it superseded by the Client object -# pytest.fail("Test not implemented") - - @patch.object( genai.files.AsyncFiles, "list", diff --git a/tests/apis/gemini/test_gemini_utils.py b/tests/apis/gemini/test_gemini_utils.py index 9005ed2b..6e98fe5e 100644 --- a/tests/apis/gemini/test_gemini_utils.py +++ b/tests/apis/gemini/test_gemini_utils.py @@ -6,16 +6,6 @@ from .test_gemini import prompt_dict_history -@pytest.mark.xfail(reason="Test not implemented") -def test_process_response(): - pytest.fail("Test not implemented") - - -@pytest.mark.xfail(reason="Test not implemented") -def test_process_safety_attributes(): - pytest.fail("Test not implemented") - - def 
test_convert_history_dict_to_content(prompt_dict_history): media_folder = "media_folder" From 6c975bce87f9655f3dc661f3fb7ee5ee24ac2843 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Wed, 7 May 2025 10:23:47 +0100 Subject: [PATCH 15/21] Tweak to comment --- tests/apis/gemini/test_gemini_string_input.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/apis/gemini/test_gemini_string_input.py b/tests/apis/gemini/test_gemini_string_input.py index da5a8500..08d1af75 100644 --- a/tests/apis/gemini/test_gemini_string_input.py +++ b/tests/apis/gemini/test_gemini_string_input.py @@ -60,7 +60,7 @@ async def test_gemini_query_string( # Mock the response from the API # NOTE: The actual response from the API is a (probably) - # google.genai.types.GenerateContentResponse object + # google.genai.types.GenerateContentResponse object (or a promise of it), # not a string value, but for the purpose of this test, we are using a string value # and testing that this is the input to the process_response function mock_gemini_call.return_value = "response Messages object" From 5c5f4c4c1880bb6e560377fc15744bbe5ca7f83c Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Thu, 15 May 2025 15:14:43 +0100 Subject: [PATCH 16/21] failing test to demonstrate the thinking params --- src/prompto/apis/gemini/gemini.py | 17 ++++++++ tests/apis/gemini/test_gemini.py | 69 +++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index b55d213b..e5a1a702 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -7,6 +7,7 @@ HarmBlockThreshold, HarmCategory, SafetySetting, + ThinkingConfig, ) from prompto.apis.base import AsyncAPI @@ -338,10 +339,26 @@ async def _obtain_model_inputs( f"parameters must be a dictionary, not {type(generation_config_params)}" ) + # Derive the required ThinkingConfig from the parameters + # TBC - how do we get these values from the prompt_dict? 
+ + # Placeholder values for now + include_thoughts = False + thinking_budget = 9999 + + if include_thoughts is None and thinking_budget is None: + thinking_config = None + else: + thinking_config = ThinkingConfig( + include_thoughts=include_thoughts, + thinking_budget=thinking_budget, + ) + gen_content_config = GenerateContentConfig( **generation_config_params, safety_settings=safety_settings, system_instruction=system_instruction, + thinking_config=thinking_config, ) return prompt, model_name, client, gen_content_config, None diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index 8117ee1a..b6263b18 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -8,6 +8,7 @@ HarmBlockThreshold, HarmCategory, SafetySetting, + ThinkingConfig, ) from prompto.apis.gemini import GeminiAPI @@ -475,6 +476,74 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): ) +@pytest.mark.xfail( + reason="This cannot work until we agree on the schema for the thinking config parameters" +) +@pytest.mark.asyncio +async def test_gemini_obtain_model_inputs_thinking_config( + temporary_data_folders, monkeypatch +): + settings = Settings(data_folder="data") + log_file = "log.txt" + monkeypatch.setenv("GEMINI_API_KEY", "DUMMY") + gemini_api = GeminiAPI(settings=settings, log_file=log_file) + + # These test cases are very similar to the test above, so we will not repeat assertions for all + # attributes - only those that are relevant to the thinking config + + # Case 1: test with *NO* thinking config parameters provided + test_case = await gemini_api._obtain_model_inputs( + { + "id": "gemini_id", + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": "test prompt", + "parameters": {"temperature": 1, "max_output_tokens": 100}, + } + ) + assert isinstance(test_case, tuple) + assert isinstance(test_case[3], GenerateContentConfig) + assert test_case[3].thinking_config is None + + # Case 2: test with thinking config parameters provided + # There are two possible ways we could provide the thinking config parameters. 
+ # We need to select from one of these to options: + dummy_prompt_dicts = [ + # Either within the parameters dictionary + { + "id": "gemini_id", + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": "test prompt", + "parameters": { + "temperature": 1, + "max_output_tokens": 100, + "thinking_budget": 1234, + "include_thoughts": True, + }, + }, + # OR as top-level keys within the prompt dictionary + { + "id": "gemini_id", + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": "test prompt", + "thinking_budget": 1234, + "include_thoughts": True, + "parameters": {"temperature": 1, "max_output_tokens": 100}, + }, + ] + + for dummy_prompt_dict in dummy_prompt_dicts: + test_case = await gemini_api._obtain_model_inputs(dummy_prompt_dict) + + assert isinstance(test_case, tuple) + assert isinstance(test_case[3], GenerateContentConfig) + assert isinstance(test_case[3].thinking_config, ThinkingConfig) + assert test_case[3].thinking_config.thinking_budget == 1234 + assert test_case[3].thinking_config.include_thoughts is True + + @pytest.mark.asyncio async def test_gemini_obtain_model_inputs_safety_filters( temporary_data_folders, monkeypatch From 2e001dfa8e6fa51cc9ba77ba4004da60b8ba702d Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Thu, 15 May 2025 15:54:16 +0100 Subject: [PATCH 17/21] read thinking params from general params dict --- src/prompto/apis/gemini/gemini.py | 22 ++++++++++--- tests/apis/gemini/test_gemini.py | 52 +++++++++++++++---------------- 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index e5a1a702..99c90985 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -340,15 +340,27 @@ async def _obtain_model_inputs( ) # Derive the required ThinkingConfig from the parameters - # TBC - how do we get these values from the prompt_dict? 
- - # Placeholder values for now - include_thoughts = False - thinking_budget = 9999 + # `pop` removes the key from the dictionary + include_thoughts = generation_config_params.pop("include_thoughts", None) + thinking_budget = generation_config_params.pop("thinking_budget", None) if include_thoughts is None and thinking_budget is None: thinking_config = None else: + if not isinstance(include_thoughts, bool | None): + err_msg = "If include_thoughts is set, it must be a boolean" + raise ValueError(err_msg) + + try: + assert isinstance(thinking_budget, int | None) + if thinking_budget is not None: + assert isinstance(thinking_budget, int) + assert 0 <= thinking_budget <= 24576 + + except AssertionError as ae: + err_msg = "if thinking_budget is set, it must be an integer between 0 and 24576" + raise ValueError(err_msg) from ae + thinking_config = ThinkingConfig( include_thoughts=include_thoughts, thinking_budget=thinking_budget, diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index b6263b18..75725817 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -476,9 +476,6 @@ async def test_gemini_obtain_model_inputs(temporary_data_folders, monkeypatch): ) -@pytest.mark.xfail( - reason="This cannot work until we agree on the schema for the thinking config parameters" -) @pytest.mark.asyncio async def test_gemini_obtain_model_inputs_thinking_config( temporary_data_folders, monkeypatch @@ -506,10 +503,7 @@ async def test_gemini_obtain_model_inputs_thinking_config( assert test_case[3].thinking_config is None # Case 2: test with thinking config parameters provided - # There are two possible ways we could provide the thinking config parameters. - # We need to select from one of these to options: - dummy_prompt_dicts = [ - # Either within the parameters dictionary + test_case = await gemini_api._obtain_model_inputs( { "id": "gemini_id", "api": "gemini", @@ -521,27 +515,33 @@ async def test_gemini_obtain_model_inputs_thinking_config( "thinking_budget": 1234, "include_thoughts": True, }, - }, - # OR as top-level keys within the prompt dictionary - { - "id": "gemini_id", - "api": "gemini", - "model_name": "gemini_model_name", - "prompt": "test prompt", - "thinking_budget": 1234, - "include_thoughts": True, - "parameters": {"temperature": 1, "max_output_tokens": 100}, - }, - ] + } + ) - for dummy_prompt_dict in dummy_prompt_dicts: - test_case = await gemini_api._obtain_model_inputs(dummy_prompt_dict) + assert isinstance(test_case, tuple) + assert isinstance(test_case[3], GenerateContentConfig) + assert isinstance(test_case[3].thinking_config, ThinkingConfig) + assert test_case[3].thinking_config.thinking_budget == 1234 + assert test_case[3].thinking_config.include_thoughts is True - assert isinstance(test_case, tuple) - assert isinstance(test_case[3], GenerateContentConfig) - assert isinstance(test_case[3].thinking_config, ThinkingConfig) - assert test_case[3].thinking_config.thinking_budget == 1234 - assert test_case[3].thinking_config.include_thoughts is True + # Case 3: test with invalid thinking config parameters provided + + with pytest.raises(ValueError): + # thinking_budget is out of range + test_case = await gemini_api._obtain_model_inputs( + { + "id": "gemini_id", + "api": "gemini", + "model_name": "gemini_model_name", + "prompt": "test prompt", + "parameters": { + "temperature": 1, + "max_output_tokens": 100, + "thinking_budget": 12345678901234567890, # out of range + "include_thoughts": True, + }, + } + ) @pytest.mark.asyncio From 
cac04f420af26e335528f48beac0920546c0246e Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Fri, 16 May 2025 13:18:10 +0100 Subject: [PATCH 18/21] Adds failing test case --- tests/apis/gemini/test_gemini_utils.py | 114 ++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 2 deletions(-) diff --git a/tests/apis/gemini/test_gemini_utils.py b/tests/apis/gemini/test_gemini_utils.py index 6e98fe5e..9706a016 100644 --- a/tests/apis/gemini/test_gemini_utils.py +++ b/tests/apis/gemini/test_gemini_utils.py @@ -1,7 +1,10 @@ import pytest -from google.genai.types import Content, Part +from google.genai.types import Candidate, Content, FinishReason, Part -from prompto.apis.gemini.gemini_utils import convert_history_dict_to_content +from prompto.apis.gemini.gemini_utils import ( + convert_history_dict_to_content, + process_response, +) from .test_gemini import prompt_dict_history @@ -27,3 +30,110 @@ def test_convert_history_dict_to_content(prompt_dict_history): assert ( actual_result == expected_result ), f"Expected {expected_result}, but got {actual_result}" + + +def test_process_response(): + """ + Test the process_response function. + """ + # Some example responses, with and without thinking included. + non_thinking_candidates = [ + Candidate( + content=Content( + parts=[ + Part( + video_metadata=None, + thought=None, + inline_data=None, + code_execution_result=None, + executable_code=None, + file_data=None, + function_call=None, + function_response=None, + text="A spontaneous answer...", + ), + ], + role="model", + ), + citation_metadata=None, + finish_message=None, + token_count=None, + finish_reason=FinishReason.STOP, + avg_logprobs=None, + grounding_metadata=None, + index=0, + logprobs_result=None, + safety_ratings=None, + ) + ] + + thinking_candidates = [ + Candidate( + content=Content( + parts=[ + Part( + video_metadata=None, + thought=True, + inline_data=None, + code_execution_result=None, + executable_code=None, + file_data=None, + function_call=None, + function_response=None, + text="Some thinking...", + ), + Part( + video_metadata=None, + thought=None, + inline_data=None, + code_execution_result=None, + executable_code=None, + file_data=None, + function_call=None, + function_response=None, + text="A thought out answer...", + ), + ], + role="model", + ), + citation_metadata=None, + finish_message=None, + token_count=None, + finish_reason=FinishReason.STOP, + avg_logprobs=None, + grounding_metadata=None, + index=0, + logprobs_result=None, + safety_ratings=None, + ) + ] + + # Each test case is a tuple of (candidates, expected_answer, expected_thoughts) + test_cases = [ + (non_thinking_candidates, "A spontaneous answer...", [None]), + (thinking_candidates, "A thought out answer...", ["Some thinking..."]), + ] + + for candidates, expected_answer, expected_thoughts in test_cases: + + # Create a mock response - the candidates key is the only one used in the + # process_response function. 
+ response = { + "candidates": candidates, + } + + # Call the function with the test case + actual_answer = process_response(response) + actual_thoughts = process_thoughts(response) + + # Assert the expected response + assert ( + actual_answer == expected_answer + ), f"Expected {expected_answer}, but got {actual_answer}" + + assert isinstance( + actual_thoughts, list + ), f"Expected a list, but got {type(actual_thoughts)}" + assert ( + actual_thoughts == expected_thoughts + ), f"Expected {expected_thoughts}, but got {actual_thoughts}" From 1c8764616f51c5db7e8904dfa1ec757ec993f7cf Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Fri, 16 May 2025 13:41:48 +0100 Subject: [PATCH 19/21] adds parsing function --- src/prompto/apis/gemini/gemini_utils.py | 43 +++++++++++++++++++++++-- tests/apis/gemini/test_gemini_utils.py | 17 +++++++--- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/src/prompto/apis/gemini/gemini_utils.py b/src/prompto/apis/gemini/gemini_utils.py index 09389ea3..c3cde417 100644 --- a/src/prompto/apis/gemini/gemini_utils.py +++ b/src/prompto/apis/gemini/gemini_utils.py @@ -143,8 +143,47 @@ def process_response(response: dict) -> str: str The processed response text as a string """ - response_text = response.candidates[0].content.parts[0].text - return response_text + answer, _ = _process_answer_and_thoughts(response) + return answer + + +def process_thoughts(response: dict) -> list[str]: + """ + Helper function to process the thoughts from Gemini API. + + Parameters + ---------- + response : dict + The response from the Gemini API as a dictionary + + Returns + ------- + list[str] + The thoughts as a list of strings + """ + _, thoughts = _process_answer_and_thoughts(response) + return thoughts + + +def _process_answer_and_thoughts(response: dict) -> tuple[str, list[str]]: + # This is a helper function that implements process_response and process_thoughts. + # + # These two functions are only trivially different. This implementation is potentially + # inefficient as it encourages the caller to call both functions. + # However the pattern of `process_response` is widely used throughout the codebase. So + # it is better to keep the same pattern. 
+ thoughts = [] + answers = [] + for candidate in response.candidates: + for part in candidate.content.parts: + if part.thought: + thoughts.append(part.text) + else: + answers.append(part.text) + + assert len(answers) == 1, "There should be only one answer" + + return answers[0], thoughts def process_safety_attributes(response: dict) -> dict: diff --git a/tests/apis/gemini/test_gemini_utils.py b/tests/apis/gemini/test_gemini_utils.py index 9706a016..4ac7bd38 100644 --- a/tests/apis/gemini/test_gemini_utils.py +++ b/tests/apis/gemini/test_gemini_utils.py @@ -1,9 +1,16 @@ import pytest -from google.genai.types import Candidate, Content, FinishReason, Part +from google.genai.types import ( + Candidate, + Content, + FinishReason, + GenerateContentResponse, + Part, +) from prompto.apis.gemini.gemini_utils import ( convert_history_dict_to_content, process_response, + process_thoughts, ) from .test_gemini import prompt_dict_history @@ -110,7 +117,7 @@ def test_process_response(): # Each test case is a tuple of (candidates, expected_answer, expected_thoughts) test_cases = [ - (non_thinking_candidates, "A spontaneous answer...", [None]), + (non_thinking_candidates, "A spontaneous answer...", []), (thinking_candidates, "A thought out answer...", ["Some thinking..."]), ] @@ -118,9 +125,9 @@ def test_process_response(): # Create a mock response - the candidates key is the only one used in the # process_response function. - response = { - "candidates": candidates, - } + response = GenerateContentResponse( + candidates=candidates, + ) # Call the function with the test case actual_answer = process_response(response) From fd234411d79dffed84ab6a12047f74307b859a3c Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Fri, 16 May 2025 15:45:18 +0100 Subject: [PATCH 20/21] Add "thinking_text" key to output --- src/prompto/apis/gemini/gemini.py | 26 +++++- tests/apis/gemini/test_gemini.py | 88 ++++++++++++++++++ tests/apis/gemini/test_gemini_chat_input.py | 70 +++++++------- .../apis/gemini/test_gemini_history_input.py | 37 ++++---- tests/apis/gemini/test_gemini_string_input.py | 29 +++--- tests/apis/gemini/test_gemini_utils.py | 91 ++----------------- 6 files changed, 185 insertions(+), 156 deletions(-) diff --git a/src/prompto/apis/gemini/gemini.py b/src/prompto/apis/gemini/gemini.py index 99c90985..ba3918f9 100644 --- a/src/prompto/apis/gemini/gemini.py +++ b/src/prompto/apis/gemini/gemini.py @@ -17,6 +17,7 @@ parse_parts, process_response, process_safety_attributes, + process_thoughts, ) from prompto.settings import Settings from prompto.utils import ( @@ -47,6 +48,11 @@ "finish_reason": "block_reason: OTHER", } +# See https://ai.google.dev/gemini-api/docs/thinking#set-budget +# for more details +MIN_THINKING_BUDGET = 0 +MAX_THINKING_BUDGET = 24576 + class GeminiAPI(AsyncAPI): """ @@ -355,10 +361,12 @@ async def _obtain_model_inputs( assert isinstance(thinking_budget, int | None) if thinking_budget is not None: assert isinstance(thinking_budget, int) - assert 0 <= thinking_budget <= 24576 + # The ThinkingConfig constructor does not seem to check that the + # thinking_budget is in the valid range (0, 24576), so we do it here + assert MIN_THINKING_BUDGET <= thinking_budget <= MAX_THINKING_BUDGET except AssertionError as ae: - err_msg = "if thinking_budget is set, it must be an integer between 0 and 24576" + err_msg = f"if thinking_budget is set, it must be an integer between {MIN_THINKING_BUDGET} and {MAX_THINKING_BUDGET}" raise ValueError(err_msg) from ae thinking_config = ThinkingConfig( @@ -394,6 +402,7 @@ 
async def _query_string(self, prompt_dict: dict, index: int | str): config=generation_config, ) response_text = process_response(response) + thinking_text = process_thoughts(response) safety_attributes = process_safety_attributes(response) log_success_response_query( @@ -406,6 +415,7 @@ async def _query_string(self, prompt_dict: dict, index: int | str): prompt_dict["response"] = response_text prompt_dict["safety_attributes"] = safety_attributes + prompt_dict["thinking_text"] = thinking_text return prompt_dict except IndexError as err: error_as_string = ( @@ -426,6 +436,7 @@ async def _query_string(self, prompt_dict: dict, index: int | str): log_file=self.log_file, log_message=log_message, log=True ) response_text = "" + thinking_text = [] try: if len(response.candidates) == 0: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES @@ -436,6 +447,7 @@ async def _query_string(self, prompt_dict: dict, index: int | str): prompt_dict["response"] = response_text prompt_dict["safety_attributes"] = safety_attributes + prompt_dict["thinking_text"] = thinking_text return prompt_dict except Exception as err: @@ -475,6 +487,7 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): ) response_list = [] safety_attributes_list = [] + thinking_list = [] try: for message_index, message in enumerate(prompt): # send the messages sequentially @@ -484,10 +497,12 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): config=generation_config, ) response_text = process_response(response) + thinking_text = process_thoughts(response) safety_attributes = process_safety_attributes(response) response_list.append(response_text) safety_attributes_list.append(safety_attributes) + thinking_list.append(thinking_text) log_success_response_chat( index=index, @@ -504,6 +519,7 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): ) prompt_dict["response"] = response_list + prompt_dict["thinking_text"] = thinking_list prompt_dict["safety_attributes"] = safety_attributes_list return prompt_dict except IndexError as err: @@ -528,6 +544,7 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): log_file=self.log_file, log_message=log_message, log=True ) response_text = response_list + [""] + thinking_text = thinking_list + [[]] try: if len(response.candidates) == 0: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES @@ -537,6 +554,7 @@ async def _query_chat(self, prompt_dict: dict, index: int | str): safety_attributes = BLOCKED_SAFETY_ATTRIBUTES prompt_dict["response"] = response_text + prompt_dict["thinking_text"] = thinking_text prompt_dict["safety_attributes"] = safety_attributes return prompt_dict except Exception as err: @@ -613,6 +631,7 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: response = await chat.send_message(message=msg_to_send) response_text = process_response(response) + thinking_text = process_thoughts(response) safety_attributes = process_safety_attributes(response) log_success_response_query( @@ -624,6 +643,7 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: ) prompt_dict["response"] = response_text + prompt_dict["thinking_text"] = thinking_text prompt_dict["safety_attributes"] = safety_attributes return prompt_dict except IndexError as err: @@ -645,6 +665,7 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: log_file=self.log_file, log_message=log_message, log=True ) response_text = "" + thinking_text = [] try: if len(response.candidates) == 0: safety_attributes = 
BLOCKED_SAFETY_ATTRIBUTES @@ -654,6 +675,7 @@ async def _query_history(self, prompt_dict: dict, index: int | str) -> dict: safety_attributes = BLOCKED_SAFETY_ATTRIBUTES prompt_dict["response"] = response_text + prompt_dict["thinking_text"] = thinking_text prompt_dict["safety_attributes"] = safety_attributes return prompt_dict except Exception as err: diff --git a/tests/apis/gemini/test_gemini.py b/tests/apis/gemini/test_gemini.py index 75725817..10c33d65 100644 --- a/tests/apis/gemini/test_gemini.py +++ b/tests/apis/gemini/test_gemini.py @@ -4,9 +4,14 @@ import regex as re from google.genai.client import AsyncClient, Client from google.genai.types import ( + Candidate, + Content, + FinishReason, GenerateContentConfig, + GenerateContentResponse, HarmBlockThreshold, HarmCategory, + Part, SafetySetting, ThinkingConfig, ) @@ -68,6 +73,89 @@ def prompt_dict_history_no_system(): } +@pytest.fixture +def non_thinking_response(): + + # Some example responses, with and without thinking included. + return GenerateContentResponse( + candidates=[ + Candidate( + content=Content( + parts=[ + Part( + video_metadata=None, + thought=None, + inline_data=None, + code_execution_result=None, + executable_code=None, + file_data=None, + function_call=None, + function_response=None, + text="A spontaneous answer", + ), + ], + role="model", + ), + citation_metadata=None, + finish_message=None, + token_count=None, + finish_reason=FinishReason.STOP, + avg_logprobs=None, + grounding_metadata=None, + index=0, + logprobs_result=None, + safety_ratings=None, + ) + ] + ) + + +@pytest.fixture +def thinking_response(): + return GenerateContentResponse( + candidates=[ + Candidate( + content=Content( + parts=[ + Part( + video_metadata=None, + thought=True, + inline_data=None, + code_execution_result=None, + executable_code=None, + file_data=None, + function_call=None, + function_response=None, + text="Some thinking", + ), + Part( + video_metadata=None, + thought=None, + inline_data=None, + code_execution_result=None, + executable_code=None, + file_data=None, + function_call=None, + function_response=None, + text="A thought out answer", + ), + ], + role="model", + ), + citation_metadata=None, + finish_message=None, + token_count=None, + finish_reason=FinishReason.STOP, + avg_logprobs=None, + grounding_metadata=None, + index=0, + logprobs_result=None, + safety_ratings=None, + ) + ] + ) + + DEFAULT_SAFETY_SETTINGS = [ SafetySetting( category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, diff --git a/tests/apis/gemini/test_gemini_chat_input.py b/tests/apis/gemini/test_gemini_chat_input.py index 6000c78b..b10a620c 100644 --- a/tests/apis/gemini/test_gemini_chat_input.py +++ b/tests/apis/gemini/test_gemini_chat_input.py @@ -10,7 +10,12 @@ from prompto.settings import Settings from ...conftest import CopyingAsyncMock -from .test_gemini import DEFAULT_SAFETY_SETTINGS, prompt_dict_chat +from .test_gemini import ( + DEFAULT_SAFETY_SETTINGS, + non_thinking_response, + prompt_dict_chat, + thinking_response, +) pytest_plugins = ("pytest_asyncio",) @@ -42,14 +47,14 @@ async def test_gemini_query_chat_no_env_var( "send_message", new_callable=CopyingAsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_chat( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_chat, temporary_data_folders, + non_thinking_response, + thinking_response, monkeypatch, caplog, ): @@ -60,19 +65,17 
@@ async def test_gemini_query_chat( gemini_api = GeminiAPI(settings=settings, log_file=log_file) # Mock the response from the API - # NOTE: The actual response from the API is a - # `google.genai.types.GenerateContentResponse` object - # not a string value, but for the purpose of this test, we are using a string value - # and testing that this is the input to the process_response function gemini_api_sequence_responses = [ - "response Messages object 1", - "response Messages object 2", + thinking_response, + non_thinking_response, ] mock_gemini_call.side_effect = gemini_api_sequence_responses # mock the process_response function - process_response_sequence_responses = ["response text 1", "response text 2"] - mock_process_response.side_effect = process_response_sequence_responses + process_response_sequence_responses = [ + "A thought out answer", + "A spontaneous answer", + ] # make sure that the input prompt_dict does not have a response key assert "response" not in prompt_dict_chat.keys() @@ -80,8 +83,13 @@ async def test_gemini_query_chat( # call the _query_chat method prompt_dict = await gemini_api._query_chat(prompt_dict_chat, index=0) + expected_thinking_text = [["Some thinking"], []] + # assert that the response key is added to the prompt_dict assert "response" in prompt_dict.keys() + assert "thinking_text" in prompt_dict.keys() + assert isinstance(prompt_dict["thinking_text"], list) + assert prompt_dict["thinking_text"] == expected_thinking_text assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 @@ -102,10 +110,6 @@ async def test_gemini_query_chat( ), ) - assert mock_process_response.call_count == 2 - mock_process_response.assert_any_call(gemini_api_sequence_responses[0]) - mock_process_response.assert_called_with(gemini_api_sequence_responses[1]) - assert mock_process_safety_attr.call_count == 2 mock_process_safety_attr.assert_any_call(gemini_api_sequence_responses[0]) mock_process_safety_attr.assert_called_with(gemini_api_sequence_responses[1]) @@ -289,14 +293,13 @@ async def test_gemini_query_chat_error_1( "send_message", new_callable=CopyingAsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_chat_index_error_2( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_chat, temporary_data_folders, + non_thinking_response, monkeypatch, caplog, ): @@ -309,23 +312,25 @@ async def test_gemini_query_chat_index_error_2( # mock error response from the API from second response gemini_api_sequence_responses = [ - "response Messages object 1", + non_thinking_response, IndexError("Test error"), ] mock_gemini_call.side_effect = gemini_api_sequence_responses - # mock the process_response function - mock_process_response.return_value = "response text 1" - # make sure that the input prompt_dict does not have a response key assert "response" not in prompt_dict_chat.keys() # call the _query_chat method prompt_dict = await gemini_api._query_chat(prompt_dict_chat, index=0) + expected_answer = "A spontaneous answer" + # assert that the response key is added to the prompt_dict assert "response" in prompt_dict.keys() + assert "thinking_text" in prompt_dict.keys() + assert isinstance(prompt_dict["thinking_text"], list) + assert prompt_dict["thinking_text"] == [[], []] assert mock_gemini_call.call_count == 2 assert mock_gemini_call.await_count == 2 @@ -348,14 +353,13 @@ async def 
test_gemini_query_chat_index_error_2( ), ) - mock_process_response.assert_called_once_with(gemini_api_sequence_responses[0]) - mock_process_safety_attr.assert_called_once_with(gemini_api_sequence_responses[0]) + mock_process_safety_attr.assert_any_call(gemini_api_sequence_responses[0]) expected_log_message_1 = ( f"Response received for model Gemini ({prompt_dict_chat['model_name']}) " "(i=0, id=gemini_id, message=1/2)\n" f"Prompt: {prompt_dict_chat['prompt'][0][:50]}...\n" - f"Response: {mock_process_response.return_value[:50]}...\n" + f"Response: {expected_answer[:50]}...\n" ) assert expected_log_message_1 in caplog.text @@ -363,13 +367,13 @@ async def test_gemini_query_chat_index_error_2( f"Error with model Gemini ({prompt_dict_chat['model_name']}) " "(i=0, id=gemini_id, message=2/2)\n" f"Prompt: {prompt_dict_chat['prompt'][1][:50]}...\n" - f"Responses so far: {[mock_process_response.return_value]}...\n" + f"Responses so far: {[expected_answer]}...\n" "Error: Response is empty and blocked (IndexError - Test error)" ) assert expected_log_message_2 in caplog.text # assert that the response value is the first response value and an empty string - assert prompt_dict["response"] == [mock_process_response.return_value, ""] + assert prompt_dict["response"] == [expected_answer, ""] @pytest.mark.asyncio @@ -378,14 +382,13 @@ async def test_gemini_query_chat_index_error_2( "send_message", new_callable=CopyingAsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_chat_error_2( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_chat, temporary_data_folders, + non_thinking_response, monkeypatch, caplog, ): @@ -395,17 +398,15 @@ async def test_gemini_query_chat_error_2( settings = Settings(data_folder="data") log_file = "log.txt" gemini_api = GeminiAPI(settings=settings, log_file=log_file) + expected_answer = "A spontaneous answer" # mock error response from the API from second response gemini_api_sequence_responses = [ - "response Messages object 1", + non_thinking_response, Exception("Test error"), ] mock_gemini_call.side_effect = gemini_api_sequence_responses - # mock the process_response function - mock_process_response.return_value = "response text 1" - # raise error if the API call fails with pytest.raises(Exception, match="Test error"): await gemini_api._query_chat(prompt_dict_chat, index=0) @@ -431,14 +432,13 @@ async def test_gemini_query_chat_error_2( ), ) - mock_process_response.assert_called_once_with(gemini_api_sequence_responses[0]) mock_process_safety_attr.assert_called_once_with(gemini_api_sequence_responses[0]) expected_log_message_1 = ( f"Response received for model Gemini ({prompt_dict_chat['model_name']}) " "(i=0, id=gemini_id, message=1/2)\n" f"Prompt: {prompt_dict_chat['prompt'][0][:50]}...\n" - f"Response: {mock_process_response.return_value[:50]}...\n" + f"Response: {expected_answer[:50]}...\n" ) assert expected_log_message_1 in caplog.text @@ -446,7 +446,7 @@ async def test_gemini_query_chat_error_2( f"Error with model Gemini ({prompt_dict_chat['model_name']}) " "(i=0, id=gemini_id, message=2/2)\n" f"Prompt: {prompt_dict_chat['prompt'][1][:50]}...\n" - f"Responses so far: {[mock_process_response.return_value]}...\n" + f"Responses so far: {[expected_answer]}...\n" "Error: Exception - Test error" ) assert expected_log_message_2 in caplog.text diff --git a/tests/apis/gemini/test_gemini_history_input.py 
b/tests/apis/gemini/test_gemini_history_input.py index d5db7bb7..c422ce96 100644 --- a/tests/apis/gemini/test_gemini_history_input.py +++ b/tests/apis/gemini/test_gemini_history_input.py @@ -10,6 +10,7 @@ from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings +from .test_gemini import non_thinking_response # nopa: F401 from .test_gemini import prompt_dict_history # nopa: F401 from .test_gemini import prompt_dict_history_no_system # nopa: F401 from .test_gemini import DEFAULT_SAFETY_SETTINGS @@ -43,14 +44,13 @@ async def test_gemini_query_history_no_env_var( "send_message", new_callable=AsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_history( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_history, temporary_data_folders, + non_thinking_response, monkeypatch, caplog, ): @@ -61,14 +61,7 @@ async def test_gemini_query_history( gemini_api = GeminiAPI(settings=settings, log_file=log_file) # Mock the response from the API - # NOTE: The actual response from the API is a - # `google.genai.types.GenerateContentResponse`` object - # not a string value, but for the purpose of this test, we are using a string value - # and testing that this is the input to the process_response function - mock_gemini_call.return_value = "response Messages object" - - # mock the process_response function - mock_process_response.return_value = "response text" + mock_gemini_call.return_value = non_thinking_response # make sure that the input prompt_dict does not have a response key assert "response" not in prompt_dict_history.keys() @@ -76,8 +69,14 @@ async def test_gemini_query_history( # call the _query_history method prompt_dict = await gemini_api._query_history(prompt_dict_history, index=0) + expected_answer = "A spontaneous answer" + # assert that the response key is added to the prompt_dict assert "response" in prompt_dict.keys() + # assert "thinking_text" is in prompt_dict + assert "thinking_text" in prompt_dict.keys() + assert isinstance(prompt_dict["thinking_text"], list) + assert prompt_dict["thinking_text"] == [] mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() @@ -86,17 +85,16 @@ async def test_gemini_query_history( message=Part(text=prompt_dict_history["prompt"][1]["parts"]), ) - mock_process_response.assert_called_once_with(mock_gemini_call.return_value) mock_process_safety_attr.assert_called_once_with(mock_gemini_call.return_value) # assert that the response value is the return value of the process_response function - assert prompt_dict["response"] == mock_process_response.return_value + assert prompt_dict["response"] == expected_answer expected_log_message = ( f"Response received for model Gemini ({prompt_dict_history['model_name']}) " "(i=0, id=gemini_id)\n" f"Prompt: {prompt_dict_history['prompt'][:50]}...\n" - f"Response: {mock_process_response.return_value[:50]}...\n" + f"Response: {expected_answer[:50]}...\n" ) assert expected_log_message in caplog.text @@ -244,14 +242,13 @@ async def test_gemini_query_history_check_chat_init( "send_message", new_callable=AsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_history_no_system( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_history_no_system, 
temporary_data_folders, + non_thinking_response, monkeypatch, caplog, ): @@ -265,10 +262,11 @@ async def test_gemini_query_history_no_system( # NOTE: The actual response from the API is a gemini.types.message.Message object # not a string value, but for the purpose of this test, we are using a string value # and testing that this is the input to the process_response function - mock_gemini_call.return_value = "response Messages object" + mock_gemini_call.return_value = non_thinking_response # mock the process_response function - mock_process_response.return_value = "response text" + # mock_process_response.return_value = "response text" + expected_answer = "A spontaneous answer" # make sure that the input prompt_dict does not have a response key assert "response" not in prompt_dict_history_no_system.keys() @@ -288,17 +286,16 @@ async def test_gemini_query_history_no_system( message=Part(text=prompt_dict_history_no_system["prompt"][2]["parts"]) ) - mock_process_response.assert_called_once_with(mock_gemini_call.return_value) mock_process_safety_attr.assert_called_once_with(mock_gemini_call.return_value) # assert that the response value is the return value of the process_response function - assert prompt_dict["response"] == mock_process_response.return_value + assert prompt_dict["response"] == expected_answer expected_log_message = ( f"Response received for model Gemini ({prompt_dict_history_no_system['model_name']}) " "(i=0, id=gemini_id)\n" f"Prompt: {prompt_dict_history_no_system['prompt'][:50]}...\n" - f"Response: {mock_process_response.return_value[:50]}...\n" + f"Response: {expected_answer[:50]}...\n" ) assert expected_log_message in caplog.text diff --git a/tests/apis/gemini/test_gemini_string_input.py b/tests/apis/gemini/test_gemini_string_input.py index 08d1af75..2d2fbc72 100644 --- a/tests/apis/gemini/test_gemini_string_input.py +++ b/tests/apis/gemini/test_gemini_string_input.py @@ -10,7 +10,11 @@ from prompto.apis.gemini import GeminiAPI from prompto.settings import Settings -from .test_gemini import DEFAULT_SAFETY_SETTINGS, prompt_dict_string +from .test_gemini import ( + DEFAULT_SAFETY_SETTINGS, + non_thinking_response, + prompt_dict_string, +) pytest_plugins = ("pytest_asyncio",) @@ -41,14 +45,13 @@ async def test_gemini_query_string_no_env_var( "generate_content", new_callable=AsyncMock, ) -@patch("prompto.apis.gemini.gemini.process_response", new_callable=Mock) @patch("prompto.apis.gemini.gemini.process_safety_attributes", new_callable=Mock) async def test_gemini_query_string( mock_process_safety_attr, - mock_process_response, mock_gemini_call, prompt_dict_string, temporary_data_folders, + non_thinking_response, monkeypatch, caplog, ): @@ -59,14 +62,7 @@ async def test_gemini_query_string( gemini_api = GeminiAPI(settings=settings, log_file=log_file) # Mock the response from the API - # NOTE: The actual response from the API is a (probably) - # google.genai.types.GenerateContentResponse object (or a promise of it), - # not a string value, but for the purpose of this test, we are using a string value - # and testing that this is the input to the process_response function - mock_gemini_call.return_value = "response Messages object" - - # mock the process_response function - mock_process_response.return_value = "response text" + mock_gemini_call.return_value = non_thinking_response # make sure that the input prompt_dict does not have a response key assert "response" not in prompt_dict_string.keys() @@ -74,8 +70,14 @@ async def test_gemini_query_string( # call the _query_string 
method prompt_dict = await gemini_api._query_string(prompt_dict_string, index=0) + expected_answer = "A spontaneous answer" + # assert that the response key is added to the prompt_dict assert "response" in prompt_dict.keys() + # assert "thinking_text" is in prompt_dict + assert "thinking_text" in prompt_dict.keys() + assert isinstance(prompt_dict["thinking_text"], list) + assert prompt_dict["thinking_text"] == [] mock_gemini_call.assert_called_once() mock_gemini_call.assert_awaited_once() @@ -89,17 +91,16 @@ async def test_gemini_query_string( ), ) - mock_process_response.assert_called_once_with(mock_gemini_call.return_value) mock_process_safety_attr.assert_called_once_with(mock_gemini_call.return_value) # assert that the response value is the return value of the process_response function - assert prompt_dict["response"] == mock_process_response.return_value + assert prompt_dict["response"] == expected_answer expected_log_message = ( f"Response received for model Gemini ({prompt_dict_string['model_name']}) " "(i=0, id=gemini_id)\n" f"Prompt: {prompt_dict_string['prompt'][:50]}...\n" - f"Response: {mock_process_response.return_value[:50]}...\n" + f"Response: {expected_answer[:50]}...\n" ) assert expected_log_message in caplog.text diff --git a/tests/apis/gemini/test_gemini_utils.py b/tests/apis/gemini/test_gemini_utils.py index 4ac7bd38..4d81059b 100644 --- a/tests/apis/gemini/test_gemini_utils.py +++ b/tests/apis/gemini/test_gemini_utils.py @@ -13,7 +13,7 @@ process_thoughts, ) -from .test_gemini import prompt_dict_history +from .test_gemini import non_thinking_response, prompt_dict_history, thinking_response def test_convert_history_dict_to_content(prompt_dict_history): @@ -39,96 +39,17 @@ def test_convert_history_dict_to_content(prompt_dict_history): ), f"Expected {expected_result}, but got {actual_result}" -def test_process_response(): +def test_process_response(thinking_response, non_thinking_response): """ Test the process_response function. """ - # Some example responses, with and without thinking included. 
- non_thinking_candidates = [ - Candidate( - content=Content( - parts=[ - Part( - video_metadata=None, - thought=None, - inline_data=None, - code_execution_result=None, - executable_code=None, - file_data=None, - function_call=None, - function_response=None, - text="A spontaneous answer...", - ), - ], - role="model", - ), - citation_metadata=None, - finish_message=None, - token_count=None, - finish_reason=FinishReason.STOP, - avg_logprobs=None, - grounding_metadata=None, - index=0, - logprobs_result=None, - safety_ratings=None, - ) - ] - - thinking_candidates = [ - Candidate( - content=Content( - parts=[ - Part( - video_metadata=None, - thought=True, - inline_data=None, - code_execution_result=None, - executable_code=None, - file_data=None, - function_call=None, - function_response=None, - text="Some thinking...", - ), - Part( - video_metadata=None, - thought=None, - inline_data=None, - code_execution_result=None, - executable_code=None, - file_data=None, - function_call=None, - function_response=None, - text="A thought out answer...", - ), - ], - role="model", - ), - citation_metadata=None, - finish_message=None, - token_count=None, - finish_reason=FinishReason.STOP, - avg_logprobs=None, - grounding_metadata=None, - index=0, - logprobs_result=None, - safety_ratings=None, - ) - ] - - # Each test case is a tuple of (candidates, expected_answer, expected_thoughts) + # Each test case is a tuple of (response, expected_answer, expected_thoughts) test_cases = [ - (non_thinking_candidates, "A spontaneous answer...", []), - (thinking_candidates, "A thought out answer...", ["Some thinking..."]), + (non_thinking_response, "A spontaneous answer", []), + (thinking_response, "A thought out answer", ["Some thinking"]), ] - for candidates, expected_answer, expected_thoughts in test_cases: - - # Create a mock response - the candidates key is the only one used in the - # process_response function. - response = GenerateContentResponse( - candidates=candidates, - ) - + for response, expected_answer, expected_thoughts in test_cases: # Call the function with the test case actual_answer = process_response(response) actual_thoughts = process_thoughts(response) From 4511cf06e15f0afb7773d922edc383e3d904c218 Mon Sep 17 00:00:00 2001 From: Andy Smith Date: Fri, 16 May 2025 15:53:45 +0100 Subject: [PATCH 21/21] Minor tweaks from copilot code review --- src/prompto/apis/gemini/gemini_media.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/prompto/apis/gemini/gemini_media.py b/src/prompto/apis/gemini/gemini_media.py index 7fd08964..8b0f9ddd 100644 --- a/src/prompto/apis/gemini/gemini_media.py +++ b/src/prompto/apis/gemini/gemini_media.py @@ -64,8 +64,6 @@ async def _upload_single_file( results later.) """ local_hash = compute_sha256_base64(local_file_path) - print(f"local_file_path: {local_file_path}") - print(f"local_hash: {local_hash}") if local_hash in already_uploaded_files: logger.info( @@ -105,11 +103,11 @@ def list_uploaded_files(settings: Settings): """ List all previously uploaded files to the Gemini API. """ - genmini_api = GeminiAPI(settings=settings, log_file=None) + gemini_api = GeminiAPI(settings=settings, log_file=None) # TODO: We need a model name, because our API caters for different API keys # for different models. Maybe our API is too complicated.... 
default_model_name = "default" - client = genmini_api._get_client(default_model_name) + client = gemini_api._get_client(default_model_name) uploaded_files = asyncio.run(_get_previously_uploaded_files(client)) for file_hash, file_name in uploaded_files.items(): @@ -122,11 +120,11 @@ def delete_uploaded_files(settings: Settings): """ Delete all previously uploaded files from the Gemini API. """ - genmini_api = GeminiAPI(settings=settings, log_file=None) + gemini_api = GeminiAPI(settings=settings, log_file=None) # TODO: We need a model name, because our API caters for different API keys # for different models. Maybe our API to complicated.... default_model_name = "default" - client = genmini_api._get_client(default_model_name) + client = gemini_api._get_client(default_model_name) # This just uses the synchronous API. Using the async API did not # seem reliable. In particular `client.aio.files.delete()` did not appear