1818
1919import static com .google .common .truth .Truth .assertThat ;
2020import static com .google .common .truth .Truth .assertWithMessage ;
21-
21+ import static org .mockito .ArgumentMatchers .any ;
22+ import static org .mockito .ArgumentMatchers .anyString ;
23+ import static org .mockito .Mockito .RETURNS_SELF ;
24+ import static org .mockito .Mockito .mock ;
25+ import static org .mockito .Mockito .mockStatic ;
26+ import static org .mockito .Mockito .times ;
27+ import static org .mockito .Mockito .verify ;
28+ import static org .mockito .Mockito .when ;
29+
30+ import com .google .genai .Client ;
31+ import com .google .genai .Models ;
32+ import com .google .genai .types .GenerateContentConfig ;
33+ import com .google .genai .types .GenerateContentResponse ;
2234import java .io .ByteArrayOutputStream ;
2335import java .io .IOException ;
2436import java .io .PrintStream ;
37+ import java .lang .reflect .Field ;
2538import org .junit .After ;
2639import org .junit .Before ;
2740import org .junit .BeforeClass ;
2841import org .junit .Test ;
2942import org .junit .runner .RunWith ;
3043import org .junit .runners .JUnit4 ;
44+ import org .mockito .MockedStatic ;
45+
3146
3247@ RunWith (JUnit4 .class )
3348public class ToolsIT {
@@ -105,4 +120,42 @@ public void testToolsGoogleSearchWithText() {
105120 assertThat (response ).isNotEmpty ();
106121 }
107122
123+ @ Test
124+ public void testToolsVaisWithText () throws NoSuchFieldException , IllegalAccessException {
125+ String response = "The process for making an appointment to renew your driver's license"
126+ + " varies depending on your location." ;
127+
128+ String datastore =
129+ String .format (
130+ "projects/%s/locations/global/collections/default_collection/"
131+ + "dataStores/grounding-test-datastore" ,
132+ PROJECT_ID );
133+
134+ Client .Builder mockedBuilder = mock (Client .Builder .class , RETURNS_SELF );
135+ Client mockedClient = mock (Client .class );
136+ Models mockedModels = mock (Models .class );
137+ GenerateContentResponse mockedResponse = mock (GenerateContentResponse .class );
138+
139+ try (MockedStatic <Client > mockedStatic = mockStatic (Client .class )) {
140+ mockedStatic .when (Client ::builder ).thenReturn (mockedBuilder );
141+ when (mockedBuilder .build ()).thenReturn (mockedClient );
142+
143+ // Using reflection because 'models' is a final field and cannot be mockable directly
144+ Field field = Client .class .getDeclaredField ("models" );
145+ field .setAccessible (true );
146+ field .set (mockedClient , mockedModels );
147+
148+ when (mockedClient .models .generateContent (
149+ anyString (), anyString (), any (GenerateContentConfig .class )))
150+ .thenReturn (mockedResponse );
151+ when (mockedResponse .text ()).thenReturn (response );
152+
153+ String generatedResponse = ToolsVaisWithText .generateContent (GEMINI_FLASH , datastore );
154+
155+ verify (mockedClient .models , times (1 ))
156+ .generateContent (anyString (), anyString (), any (GenerateContentConfig .class ));
157+ assertThat (generatedResponse ).isNotEmpty ();
158+ assertThat (response ).isEqualTo (generatedResponse );
159+ }
160+ }
108161}
0 commit comments