@@ -34,7 +34,7 @@ public class GenerativeModel {
3434 private readonly FirebaseApp _firebaseApp ;
3535
3636 // Various setting fields provided by the user.
37- private readonly string _location ;
37+ private readonly FirebaseAI . Backend _backend ;
3838 private readonly string _modelName ;
3939 private readonly GenerationConfig ? _generationConfig ;
4040 private readonly SafetySetting [ ] _safetySettings ;
@@ -52,7 +52,7 @@ public class GenerativeModel {
5252 /// Use `VertexAI.GetGenerativeModel` instead to ensure proper initialization and configuration of the `GenerativeModel`.
5353 /// </summary>
5454 internal GenerativeModel ( FirebaseApp firebaseApp ,
55- string location ,
55+ FirebaseAI . Backend backend ,
5656 string modelName ,
5757 GenerationConfig ? generationConfig = null ,
5858 SafetySetting [ ] safetySettings = null ,
@@ -61,7 +61,7 @@ internal GenerativeModel(FirebaseApp firebaseApp,
6161 ModelContent ? systemInstruction = null ,
6262 RequestOptions ? requestOptions = null ) {
6363 _firebaseApp = firebaseApp ;
64- _location = location ;
64+ _backend = backend ;
6565 _modelName = modelName ;
6666 _generationConfig = generationConfig ;
6767 _safetySettings = safetySettings ;
@@ -195,7 +195,7 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
195195 SetRequestHeaders ( request ) ;
196196
197197 // Set the content
198- string bodyJson = ModelContentsToJson ( content ) ;
198+ string bodyJson = MakeGenerateContentRequest ( content ) ;
199199 request . Content = new StringContent ( bodyJson , Encoding . UTF8 , "application/json" ) ;
200200
201201#if FIREBASE_LOG_REST_CALLS
@@ -219,7 +219,7 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
219219 UnityEngine . Debug . Log ( "Response:\n " + result ) ;
220220#endif
221221
222- return GenerateContentResponse . FromJson ( result ) ;
222+ return GenerateContentResponse . FromJson ( result , _backend . Provider ) ;
223223 }
224224
225225 private async IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsyncInternal (
@@ -230,7 +230,7 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
230230 SetRequestHeaders ( request ) ;
231231
232232 // Set the content
233- string bodyJson = ModelContentsToJson ( content ) ;
233+ string bodyJson = MakeGenerateContentRequest ( content ) ;
234234 request . Content = new StringContent ( bodyJson , Encoding . UTF8 , "application/json" ) ;
235235
236236#if FIREBASE_LOG_REST_CALLS
@@ -260,7 +260,7 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
260260 UnityEngine . Debug . Log ( "Streaming Response:\n " + line ) ;
261261#endif
262262
263- yield return GenerateContentResponse . FromJson ( line [ StreamPrefix . Length ..] ) ;
263+ yield return GenerateContentResponse . FromJson ( line [ StreamPrefix . Length ..] , _backend . Provider ) ;
264264 }
265265 }
266266 }
@@ -301,10 +301,18 @@ private async Task<CountTokensResponse> CountTokensAsyncInternal(
301301 }
302302
303303 private string GetURL ( ) {
304- return "https://firebasevertexai.googleapis.com/v1beta" +
305- "/projects/" + _firebaseApp . Options . ProjectId +
306- "/locations/" + _location +
307- "/publishers/google/models/" + _modelName ;
304+ if ( _backend . Provider == FirebaseAI . Backend . InternalProvider . VertexAI ) {
305+ return "https://firebasevertexai.googleapis.com/v1beta" +
306+ "/projects/" + _firebaseApp . Options . ProjectId +
307+ "/locations/" + _backend . Location +
308+ "/publishers/google/models/" + _modelName ;
309+ } else if ( _backend . Provider == FirebaseAI . Backend . InternalProvider . GoogleAI ) {
310+ return "https://firebasevertexai.googleapis.com/v1beta" +
311+ "/projects/" + _firebaseApp . Options . ProjectId +
312+ "/models/" + _modelName ;
313+ } else {
314+ throw new NotSupportedException ( $ "Missing support for backend: { _backend . Provider } ") ;
315+ }
308316 }
309317
310318 private void SetRequestHeaders ( HttpRequestMessage request ) {
@@ -313,7 +321,13 @@ private void SetRequestHeaders(HttpRequestMessage request) {
313321 request . Headers . Add ( "x-goog-api-client" , "genai-csharp/0.1.0" ) ;
314322 }
315323
316- private string ModelContentsToJson ( IEnumerable < ModelContent > contents ) {
324+ private string MakeGenerateContentRequest ( IEnumerable < ModelContent > contents ) {
325+ Dictionary < string , object > jsonDict = MakeGenerateContentRequestAsDictionary ( contents ) ;
326+ return Json . Serialize ( jsonDict ) ;
327+ }
328+
329+ private Dictionary < string , object > MakeGenerateContentRequestAsDictionary (
330+ IEnumerable < ModelContent > contents ) {
317331 Dictionary < string , object > jsonDict = new ( ) {
318332 // Convert the Contents into a list of Json dictionaries
319333 [ "contents" ] = contents . Select ( c => c . ToJson ( ) ) . ToList ( )
@@ -322,7 +336,7 @@ private string ModelContentsToJson(IEnumerable<ModelContent> contents) {
322336 jsonDict [ "generationConfig" ] = _generationConfig ? . ToJson ( ) ;
323337 }
324338 if ( _safetySettings != null && _safetySettings . Length > 0 ) {
325- jsonDict [ "safetySettings" ] = _safetySettings . Select ( s => s . ToJson ( ) ) . ToList ( ) ;
339+ jsonDict [ "safetySettings" ] = _safetySettings . Select ( s => s . ToJson ( _backend . Provider ) ) . ToList ( ) ;
326340 }
327341 if ( _tools != null && _tools . Length > 0 ) {
328342 jsonDict [ "tools" ] = _tools . Select ( t => t . ToJson ( ) ) . ToList ( ) ;
@@ -334,23 +348,38 @@ private string ModelContentsToJson(IEnumerable<ModelContent> contents) {
334348 jsonDict [ "systemInstruction" ] = _systemInstruction ? . ToJson ( ) ;
335349 }
336350
337- return Json . Serialize ( jsonDict ) ;
351+ return jsonDict ;
338352 }
339353
340354 // CountTokensRequest is a subset of the full info needed for GenerateContent
341355 private string MakeCountTokensRequest ( IEnumerable < ModelContent > contents ) {
342- Dictionary < string , object > jsonDict = new ( ) {
343- // Convert the Contents into a list of Json dictionaries
344- [ "contents" ] = contents . Select ( c => c . ToJson ( ) ) . ToList ( )
345- } ;
346- if ( _generationConfig . HasValue ) {
347- jsonDict [ "generationConfig" ] = _generationConfig ? . ToJson ( ) ;
348- }
349- if ( _tools != null && _tools . Length > 0 ) {
350- jsonDict [ "tools" ] = _tools . Select ( t => t . ToJson ( ) ) . ToList ( ) ;
351- }
352- if ( _systemInstruction . HasValue ) {
353- jsonDict [ "systemInstruction" ] = _systemInstruction ? . ToJson ( ) ;
356+ Dictionary < string , object > jsonDict ;
357+ switch ( _backend . Provider ) {
358+ case FirebaseAI . Backend . InternalProvider . GoogleAI :
359+ jsonDict = new ( ) {
360+ [ "generateContentRequest" ] = MakeGenerateContentRequestAsDictionary ( contents )
361+ } ;
362+ // GoogleAI wants the model name included as well.
363+ ( ( Dictionary < string , object > ) jsonDict [ "generateContentRequest" ] ) [ "model" ] =
364+ $ "models/{ _modelName } ";
365+ break ;
366+ case FirebaseAI . Backend . InternalProvider . VertexAI :
367+ jsonDict = new ( ) {
368+ // Convert the Contents into a list of Json dictionaries
369+ [ "contents" ] = contents . Select ( c => c . ToJson ( ) ) . ToList ( )
370+ } ;
371+ if ( _generationConfig . HasValue ) {
372+ jsonDict [ "generationConfig" ] = _generationConfig ? . ToJson ( ) ;
373+ }
374+ if ( _tools != null && _tools . Length > 0 ) {
375+ jsonDict [ "tools" ] = _tools . Select ( t => t . ToJson ( ) ) . ToList ( ) ;
376+ }
377+ if ( _systemInstruction . HasValue ) {
378+ jsonDict [ "systemInstruction" ] = _systemInstruction ? . ToJson ( ) ;
379+ }
380+ break ;
381+ default :
382+ throw new NotSupportedException ( $ "Missing support for backend: { _backend . Provider } ") ;
354383 }
355384
356385 return Json . Serialize ( jsonDict ) ;
0 commit comments