1919using System . IO ;
2020using System . Linq ;
2121using System . Net . Http ;
22+ using System . Runtime . CompilerServices ;
2223using System . Text ;
24+ using System . Threading ;
2325using System . Threading . Tasks ;
2426using Google . MiniJSON ;
2527using Firebase . VertexAI . Internal ;
@@ -81,94 +83,102 @@ internal GenerativeModel(FirebaseApp firebaseApp,
8183 /// <summary>
8284 /// Generates new content from input `ModelContent` given to the model as a prompt.
8385 /// </summary>
84- /// <param name="content">The input(s) given to the model as a prompt.</param>
86+ /// <param name="content">The input given to the model as a prompt.</param>
87+ /// <param name="cancellationToken">An optional token to cancel the operation.</param>
8588 /// <returns>The generated content response from the model.</returns>
8689 /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
8790 public Task < GenerateContentResponse > GenerateContentAsync (
88- params ModelContent [ ] content ) {
89- return GenerateContentAsync ( ( IEnumerable < ModelContent > ) content ) ;
91+ ModelContent content , CancellationToken cancellationToken = default ) {
92+ return GenerateContentAsync ( new [ ] { content } , cancellationToken ) ;
9093 }
9194 /// <summary>
9295 /// Generates new content from input text given to the model as a prompt.
9396 /// </summary>
9497 /// <param name="text">The text given to the model as a prompt.</param>
98+ /// <param name="cancellationToken">An optional token to cancel the operation.</param>
9599 /// <returns>The generated content response from the model.</returns>
96100 /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
97101 public Task < GenerateContentResponse > GenerateContentAsync (
98- string text ) {
99- return GenerateContentAsync ( new ModelContent [ ] { ModelContent . Text ( text ) } ) ;
102+ string text , CancellationToken cancellationToken = default ) {
103+ return GenerateContentAsync ( new [ ] { ModelContent . Text ( text ) } , cancellationToken ) ;
100104 }
101105 /// <summary>
102106 /// Generates new content from input `ModelContent` given to the model as a prompt.
103107 /// </summary>
104- /// <param name="content">The input(s) given to the model as a prompt.</param>
108+ /// <param name="content">The input given to the model as a prompt.</param>
109+ /// <param name="cancellationToken">An optional token to cancel the operation.</param>
105110 /// <returns>The generated content response from the model.</returns>
106111 /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
107112 public Task < GenerateContentResponse > GenerateContentAsync (
108- IEnumerable < ModelContent > content ) {
109- return GenerateContentAsyncInternal ( content ) ;
113+ IEnumerable < ModelContent > content , CancellationToken cancellationToken = default ) {
114+ return GenerateContentAsyncInternal ( content , cancellationToken ) ;
110115 }
111116
112117 /// <summary>
113118 /// Generates new content as a stream from input `ModelContent` given to the model as a prompt.
114119 /// </summary>
115- /// <param name="content">The input(s) given to the model as a prompt.</param>
120+ /// <param name="content">The input given to the model as a prompt.</param>
121+ /// <param name="cancellationToken">An optional token to cancel the operation.</param>
116122 /// <returns>A stream of generated content responses from the model.</returns>
117123 /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
118124 public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
119- params ModelContent [ ] content ) {
120- return GenerateContentStreamAsync ( ( IEnumerable < ModelContent > ) content ) ;
125+ ModelContent content , CancellationToken cancellationToken = default ) {
126+ return GenerateContentStreamAsync ( new [ ] { content } , cancellationToken ) ;
121127 }
122128 /// <summary>
123129 /// Generates new content as a stream from input text given to the model as a prompt.
124130 /// </summary>
125131 /// <param name="text">The text given to the model as a prompt.</param>
132+ /// <param name="cancellationToken">An optional token to cancel the operation.</param>
126133 /// <returns>A stream of generated content responses from the model.</returns>
127134 /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
128135 public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
129- string text ) {
130- return GenerateContentStreamAsync ( new ModelContent [ ] { ModelContent . Text ( text ) } ) ;
136+ string text , CancellationToken cancellationToken = default ) {
137+ return GenerateContentStreamAsync ( new [ ] { ModelContent . Text ( text ) } , cancellationToken ) ;
131138 }
132139 /// <summary>
133140 /// Generates new content as a stream from input `ModelContent` given to the model as a prompt.
134141 /// </summary>
135- /// <param name="content">The input(s) given to the model as a prompt.</param>
142+ /// <param name="content">The input given to the model as a prompt.</param>
143+ /// <param name="cancellationToken">An optional token to cancel the operation.</param>
136144 /// <returns>A stream of generated content responses from the model.</returns>
137145 /// <exception cref="VertexAIException">Thrown when an error occurs during content generation.</exception>
138146 public IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsync (
139- IEnumerable < ModelContent > content ) {
140- return GenerateContentStreamAsyncInternal ( content ) ;
147+ IEnumerable < ModelContent > content , CancellationToken cancellationToken = default ) {
148+ return GenerateContentStreamAsyncInternal ( content , cancellationToken ) ;
141149 }
142150
143151 /// <summary>
144152 /// Counts the number of tokens in a prompt using the model's tokenizer.
145153 /// </summary>
146- /// <param name="content">The input(s) given to the model as a prompt.</param>
154+ /// <param name="content">The input given to the model as a prompt.</param>
147155 /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
148156 /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
149157 public Task < CountTokensResponse > CountTokensAsync (
150- params ModelContent [ ] content ) {
151- return CountTokensAsync ( ( IEnumerable < ModelContent > ) content ) ;
158+ ModelContent content , CancellationToken cancellationToken = default ) {
159+ return CountTokensAsync ( new [ ] { content } , cancellationToken ) ;
152160 }
153161 /// <summary>
154162 /// Counts the number of tokens in a prompt using the model's tokenizer.
155163 /// </summary>
156164 /// <param name="text">The text input given to the model as a prompt.</param>
165+ /// <param name="cancellationToken">An optional token to cancel the operation.</param>
157166 /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
158167 /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
159168 public Task < CountTokensResponse > CountTokensAsync (
160- string text ) {
161- return CountTokensAsync ( new ModelContent [ ] { ModelContent . Text ( text ) } ) ;
169+ string text , CancellationToken cancellationToken = default ) {
170+ return CountTokensAsync ( new [ ] { ModelContent . Text ( text ) } , cancellationToken ) ;
162171 }
163172 /// <summary>
164173 /// Counts the number of tokens in a prompt using the model's tokenizer.
165174 /// </summary>
166- /// <param name="content">The input(s) given to the model as a prompt.</param>
175+ /// <param name="content">The input given to the model as a prompt.</param>
176+ /// <param name="cancellationToken">An optional token to cancel the operation.</param>
167177 /// <returns>The `CountTokensResponse` of running the model's tokenizer on the input.</returns>
168178 /// <exception cref="VertexAIException">Thrown when an error occurs during the request.</exception>
169179 public Task < CountTokensResponse > CountTokensAsync (
170- IEnumerable < ModelContent > content ) {
171- return CountTokensAsyncInternal ( content ) ;
180+ IEnumerable < ModelContent > content , CancellationToken cancellationToken = default ) {
181+ return CountTokensAsyncInternal ( content , cancellationToken ) ;
172182 }
173183
174184 /// <summary>
@@ -188,7 +198,8 @@ public Chat StartChat(IEnumerable<ModelContent> history) {
188198#endregion
189199
190200 private async Task < GenerateContentResponse > GenerateContentAsyncInternal (
191- IEnumerable < ModelContent > content ) {
201+ IEnumerable < ModelContent > content ,
202+ CancellationToken cancellationToken ) {
192203 HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":generateContent" ) ;
193204
194205 // Set the request headers
@@ -204,7 +215,7 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
204215
205216 HttpResponseMessage response ;
206217 try {
207- response = await _httpClient . SendAsync ( request ) ;
218+ response = await _httpClient . SendAsync ( request , cancellationToken ) ;
208219 response . EnsureSuccessStatusCode ( ) ;
209220 } catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
210221 throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
@@ -223,7 +234,8 @@ private async Task<GenerateContentResponse> GenerateContentAsyncInternal(
223234 }
224235
225236 private async IAsyncEnumerable < GenerateContentResponse > GenerateContentStreamAsyncInternal (
226- IEnumerable < ModelContent > content ) {
237+ IEnumerable < ModelContent > content ,
238+ [ EnumeratorCancellation ] CancellationToken cancellationToken ) {
227239 HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":streamGenerateContent?alt=sse" ) ;
228240
229241 // Set the request headers
@@ -239,7 +251,7 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
239251
240252 HttpResponseMessage response ;
241253 try {
242- response = await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead ) ;
254+ response = await _httpClient . SendAsync ( request , HttpCompletionOption . ResponseHeadersRead , cancellationToken ) ;
243255 response . EnsureSuccessStatusCode ( ) ;
244256 } catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
245257 throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
@@ -266,7 +278,8 @@ private async IAsyncEnumerable<GenerateContentResponse> GenerateContentStreamAsy
266278 }
267279
268280 private async Task < CountTokensResponse > CountTokensAsyncInternal (
269- IEnumerable < ModelContent > content ) {
281+ IEnumerable < ModelContent > content ,
282+ CancellationToken cancellationToken ) {
270283 HttpRequestMessage request = new ( HttpMethod . Post , GetURL ( ) + ":countTokens" ) ;
271284
272285 // Set the request headers
@@ -282,7 +295,7 @@ private async Task<CountTokensResponse> CountTokensAsyncInternal(
282295
283296 HttpResponseMessage response ;
284297 try {
285- response = await _httpClient . SendAsync ( request ) ;
298+ response = await _httpClient . SendAsync ( request , cancellationToken ) ;
286299 response . EnsureSuccessStatusCode ( ) ;
287300 } catch ( TaskCanceledException e ) when ( e . InnerException is TimeoutException ) {
288301 throw new VertexAIRequestTimeoutException ( "Request timed out." , e ) ;
0 commit comments