|
6 | 6 | type Provider as UsageProvider, |
7 | 7 | } from '@pydantic/genai-prices' |
8 | 8 | import * as logfire from '@pydantic/logfire-api' |
| 9 | +import { EventStreamCodec } from '@smithy/eventstream-codec' |
9 | 10 | import { createParser, type EventSourceMessage } from 'eventsource-parser' |
10 | 11 |
|
11 | 12 | import type { GatewayOptions } from '..' |
@@ -385,12 +386,14 @@ export class DefaultProviderProxy { |
385 | 386 |
|
386 | 387 | // IMPORTANT: Start consuming BOTH streams immediately to prevent tee() from buffering |
387 | 388 | // The tee() requires both streams to be consumed concurrently, otherwise it will buffer |
388 | | - |
389 | | - // Tee stream: one for client, one for processing |
390 | 389 | const [responseStream, processingStream] = response.body.tee() |
391 | 390 |
|
392 | | - // Parse SSE events from processing stream |
393 | | - const events = this.parseSSE(processingStream) |
| 391 | + let events: AsyncIterable<JsonData> |
| 392 | + if (responseHeaders.get('content-type')?.toLowerCase().startsWith('application/vnd.amazon.eventstream')) { |
| 393 | + events = this.parseAmazonEventStream(processingStream) |
| 394 | + } else { |
| 395 | + events = this.parseSSE(processingStream) |
| 396 | + } |
394 | 397 |
|
395 | 398 | // @ts-expect-error: TODO(Marcelo): Fix this type error. |
396 | 399 | const extractionPromise = this.processChunks(modelAPI, events) |
@@ -474,9 +477,42 @@ export class DefaultProviderProxy { |
474 | 477 | } |
475 | 478 | } |
476 | 479 |
|
/**
 * Decodes an `application/vnd.amazon.eventstream` byte stream into JSON payloads.
 *
 * Chunks are accumulated in a buffer because one event-stream message can span several
 * network chunks, and one chunk can carry several messages. Each complete message is decoded
 * with `EventStreamCodec`; if its body is non-empty, the body is UTF-8 decoded, parsed as JSON
 * and yielded. Messages with an empty body are skipped.
 *
 * Error handling:
 * - A message that fails to decode, or whose body is not valid JSON, is reported through
 *   `logfire.reportError` and dropped. Parsing then continues with the next message, so a single
 *   bad message cannot stall the stream or be reported repeatedly.
 * - A length prefix smaller than the fixed framing overhead means the framing itself is corrupt
 *   and cannot be resynchronised; the error is reported and parsing stops.
 *
 * Bytes of an incomplete trailing message left in the buffer when the stream ends are discarded.
 *
 * @param stream - Raw response bytes in Amazon event-stream framing.
 * @returns An async iterable of the parsed JSON body of every message that has one.
 */
protected async *parseAmazonEventStream(stream: ReadableStream<Uint8Array>): AsyncIterable<JsonData> {
  // Smallest valid frame: total length (4) + headers length (4) + prelude CRC (4) + message CRC (4).
  const minMessageLength = 16
  const encoder = new TextEncoder()
  const decoder = new TextDecoder()
  // Wrap the methods in closures: passing `encoder.encode` unbound loses its `this` binding.
  const codec = new EventStreamCodec(
    (bytes) => decoder.decode(bytes),
    (text) => encoder.encode(text),
  )
  let buffer: Uint8Array = new Uint8Array(0)

  for await (const chunk of stream) {
    if (buffer.length === 0) {
      // Nothing carried over from the previous chunk, so no copy is needed.
      buffer = chunk
    } else {
      // Append the incoming chunk to the bytes of the still-incomplete message.
      const combined = new Uint8Array(buffer.length + chunk.length)
      combined.set(buffer, 0)
      combined.set(chunk, buffer.length)
      buffer = combined
    }

    // Extract every complete message (frame starts with a 4-byte big-endian total length).
    while (buffer.length >= 4) {
      const messageLength = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength).getUint32(0, false)

      if (messageLength < minMessageLength) {
        // Skipping forward by a bogus length could never make progress; give up on this stream.
        logfire.reportError(
          'Error parsing Amazon EventStream',
          new Error(`Invalid event stream message length: ${messageLength}`),
        )
        return
      }
      if (buffer.length < messageLength) break

      // Consume the frame before decoding so that a failure cannot leave it stuck in the buffer.
      const frame = buffer.subarray(0, messageLength)
      buffer = buffer.subarray(messageLength)

      let payload: JsonData | undefined
      try {
        const message = codec.decode(frame)
        if (message.body.length > 0) {
          payload = JSON.parse(decoder.decode(message.body))
        }
      } catch (error) {
        logfire.reportError(
          'Error parsing Amazon EventStream',
          error instanceof Error ? error : new Error(String(error)),
        )
        continue
      }

      // `JSON.parse` never returns `undefined`, so this only filters out empty-body messages.
      if (payload !== undefined) {
        yield payload
      }
    }
  }
}
| 511 | + |
/**
 * Reports whether the exchange should be handled as a streaming response.
 *
 * It is treated as streaming when the response `content-type` header (compared
 * case-insensitively) starts with `text/event-stream` or `application/vnd.amazon.eventstream`,
 * or, failing that, when the request body has a `stream` property that is exactly `true`.
 *
 * @param responseHeaders - Headers of the upstream response.
 * @param requestBodyData - Parsed request body.
 * @returns `true` when either condition above holds, otherwise `false`.
 */
protected isStreaming(responseHeaders: Headers, requestBodyData: JsonData): boolean {
  const streamingPrefixes = ['text/event-stream', 'application/vnd.amazon.eventstream']
  const contentType = responseHeaders.get('content-type')?.toLowerCase()

  // The header check comes first so the request body is only inspected when the header does not match.
  if (contentType !== undefined && streamingPrefixes.some((prefix) => contentType.startsWith(prefix))) {
    return true
  }

  return 'stream' in requestBodyData && requestBodyData.stream === true
}
|
0 commit comments