@@ -91,7 +91,7 @@ public Task<GenerateContentResponse> GenerateContentAsync(
9191 /// <summary>
9292 /// Generates new content from input text given to the model as a prompt.
9393 /// </summary>
94- /// <param name="content ">The text given to the model as a prompt.</param>
94+ /// <param name="text ">The text given to the model as a prompt.</param>
9595 /// <returns>The generated content response from the model.</returns>
9696 /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
9797 public Task < GenerateContentResponse > GenerateContentAsync (
@@ -122,7 +122,7 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
122122 /// <summary>
123123 /// Generates new content as a stream from input text given to the model as a prompt.
124124 /// </summary>
125- /// <param name="content ">The text given to the model as a prompt.</param>
125+ /// <param name="text ">The text given to the model as a prompt.</param>
126126 /// <returns>A stream of generated content responses from the model.</returns>
127127 /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
128128 public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
@@ -140,14 +140,32 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsync(
140140 return GenerateContentStreamAsyncInternal ( content ) ;
141141 }
142142
143+ /// <summary>
144+ /// Counts the number of tokens in a prompt using the model's tokenizer.
145+ /// </summary>
146+ /// <param name="content">The input(s) given to the model as a prompt.</param>
147+ /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
148+ /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
143149 public Task < CountTokensResponse > CountTokensAsync (
144150 params ModelContent [ ] content ) {
145151 return CountTokensAsync ( ( IEnumerable < ModelContent > ) content ) ;
146152 }
153+ /// <summary>
154+ /// Counts the number of tokens in a prompt using the model's tokenizer.
155+ /// </summary>
156+ /// <param name="text">The text input given to the model as a prompt.</param>
157+ /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
158+ /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
147159 public Task < CountTokensResponse > CountTokensAsync (
148160 string text ) {
149161 return CountTokensAsync ( new ModelContent [ ] { ModelContent . Text ( text ) } ) ;
150162 }
163+ /// <summary>
164+ /// Counts the number of tokens in a prompt using the model's tokenizer.
165+ /// </summary>
166+ /// <param name="content">The input(s) given to the model as a prompt.</param>
167+ /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
168+ /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
151169 public Task < CountTokensResponse > CountTokensAsync (
152170 IEnumerable < ModelContent > content ) {
153171 return CountTokensAsyncInternal ( content ) ;
@@ -184,12 +202,16 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
184202 UnityEngine . Debug . Log ( "Request:\n " + bodyJson ) ;
185203#endif
186204
187- HttpResponseMessage response = await _httpClient . SendAsync ( request ) ;
188- // TODO: Convert any timeout exception into a VertexAI equivalent
189- // TODO: Convert any HttpRequestExceptions, see:
190- // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
191- // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
192- response . EnsureSuccessStatusCode ( ) ;
205+ HttpResponseMessage response ;
206+ try {
207+ response = await _httpClient . SendAsync ( request ) ;
208+ response . EnsureSuccessStatusCode ( ) ;
209+ } catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
210+ throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
211+ } catch ( HttpRequestException e ) {
212+ // TODO: Convert to a more precise exception when possible.
213+ throw new VertexAIException ( "HTTP request failed." , e ) ;
214+ }
193215
194216 string result = await response . Content . ReadAsStringAsync ( ) ;
195217
@@ -215,13 +237,16 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
215237 UnityEngine . Debug . Log ( "Request:\n " + bodyJson ) ;
216238#endif
217239
218- HttpResponseMessage response =
219- await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead ) ;
220- // TODO: Convert any timeout exception into a VertexAI equivalent
221- // TODO: Convert any HttpRequestExceptions, see:
222- // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpclient.sendasync?view=net-9.0
223- // https://learn.microsoft.com/en-us/dotnet/api/system.net.http.httpresponsemessage.ensuresuccessstatuscode?view=net-9.0
224- response . EnsureSuccessStatusCode ( ) ;
240+ HttpResponseMessage response ;
241+ try {
242+ response = await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead ) ;
243+ response . EnsureSuccessStatusCode ( ) ;
244+ } catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
245+ throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
246+ } catch ( HttpRequestException e ) {
247+ // TODO: Convert to a more precise exception when possible.
248+ throw new VertexAIException ( "HTTP request failed." , e ) ;
249+ }
225250
226251 // We are expecting a Stream as the response, so handle that.
227252 using var stream = await response . Content . ReadAsStreamAsync ( ) ;
@@ -242,9 +267,37 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
242267
243268 private async Task < CountTokensResponse > CountTokensAsyncInternal (
244269 IEnumerable < ModelContent > content ) {
245- // TODO: Implementation
246- await Task . CompletedTask ;
247- throw new NotImplementedException ( ) ;
270+ HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":countTokens" ) ;
271+
272+ // Set the request headers
273+ SetRequestHeaders ( request ) ;
274+
275+ // Set the content
276+ string bodyJson = MakeCountTokensRequest ( content ) ;
277+ request . Content = new StringContent ( bodyJson , Encoding . UTF8 , "application/json" ) ;
278+
279+ #if FIREBASE_LOG_REST_CALLS
280+ UnityEngine . Debug . Log ( "CountTokensRequest:\n " + bodyJson ) ;
281+ #endif
282+
283+ HttpResponseMessage response ;
284+ try {
285+ response = await _httpClient . SendAsync ( request ) ;
286+ response . EnsureSuccessStatusCode ( ) ;
287+ } catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
288+ throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
289+ } catch ( HttpRequestException e ) {
290+ // TODO: Convert to a more precise exception when possible.
291+ throw new VertexAIException ( "HTTP request failed." , e ) ;
292+ }
293+
294+ string result = await response . Content . ReadAsStringAsync ( ) ;
295+
296+ #if FIREBASE_LOG_REST_CALLS
297+ UnityEngine . Debug . Log ( "CountTokensResponse:\n " + result ) ;
298+ #endif
299+
300+ return CountTokensResponse . FromJson ( result ) ;
248301 }
249302
250303 private string GetURL ( ) {
@@ -283,6 +336,25 @@ private string ModelContentsToJson(IEnumerable<ModelContent> contents) {
283336
284337 return Json . Serialize ( jsonDict ) ;
285338 }
339+
340+ // CountTokensRequest is a subset of the full info needed for GenerateContent
341+ private string MakeCountTokensRequest ( IEnumerable < ModelContent > contents ) {
342+ Dictionary < string , object > jsonDict = new ( ) {
343+ // Convert the Contents into a list of Json dictionaries
344+ [ "contents" ] = contents . Select ( c => c . ToJson ( ) ) . ToList ( )
345+ } ;
346+ if ( _generationConfig . HasValue ) {
347+ jsonDict [ "generationConfig" ] = _generationConfig ? . ToJson ( ) ;
348+ }
349+ if ( _tools != null && _tools . Length > 0 ) {
350+ jsonDict [ "tools" ] = _tools . Select ( t => t . ToJson ( ) ) . ToList ( ) ;
351+ }
352+ if ( _systemInstruction . HasValue ) {
353+ jsonDict [ "systemInstruction" ] = _systemInstruction ? . ToJson ( ) ;
354+ }
355+
356+ return Json . Serialize ( jsonDict ) ;
357+ }
286358}
287359
288360}
0 commit comments