Skip to content

Commit d60f1f2

Browse files
authored
feat(bedrock): support streaming (#128)
1 parent c035a2d commit d60f1f2

File tree

7 files changed

+521
-12
lines changed

7 files changed

+521
-12
lines changed

gateway/src/api/bedrock.ts

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
* @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Operations_Amazon_Bedrock_Runtime.html
33
*/
44

5-
import type { ConverseRequest, ConverseResponse } from '@aws-sdk/client-bedrock-runtime'
6-
import { BaseAPI } from './base'
5+
import type { ConverseRequest, ConverseResponse, ConverseStreamOutput } from '@aws-sdk/client-bedrock-runtime'
6+
import { BaseAPI, type ExtractedRequest, type ExtractedResponse, type ExtractorConfig } from './base'
77

8-
export class ConverseAPI extends BaseAPI<ConverseRequest, ConverseResponse> {
8+
export class ConverseAPI extends BaseAPI<ConverseRequest, ConverseResponse, ConverseStreamOutput> {
99
defaultBaseUrl = 'https://bedrock-runtime.us-east-1.amazonaws.com'
1010

1111
requestStopSequences = (requestBody: ConverseRequest): string[] | undefined => {
@@ -28,6 +28,38 @@ export class ConverseAPI extends BaseAPI<ConverseRequest, ConverseResponse> {
2828
// Bedrock's Converse response body exposes no provider-assigned response id.
responseId = (_responseBody: ConverseResponse): string | undefined => undefined
31+
32+
// SafeExtractor implementation
33+
requestExtractors: ExtractorConfig<ConverseRequest, ExtractedRequest> = {
34+
requestModel: (requestBody: ConverseRequest) => {
35+
this.extractedRequest.requestModel = requestBody.modelId
36+
},
37+
maxTokens: (requestBody: ConverseRequest) => {
38+
this.extractedRequest.maxTokens = requestBody.inferenceConfig?.maxTokens
39+
},
40+
temperature: (requestBody: ConverseRequest) => {
41+
this.extractedRequest.temperature = requestBody.inferenceConfig?.temperature
42+
},
43+
topP: (requestBody: ConverseRequest) => {
44+
this.extractedRequest.topP = requestBody.inferenceConfig?.topP
45+
},
46+
stopSequences: (requestBody: ConverseRequest) => {
47+
this.extractedRequest.stopSequences = requestBody.inferenceConfig?.stopSequences
48+
},
49+
}
50+
51+
responseExtractors: ExtractorConfig<ConverseResponse, ExtractedResponse> = {}
52+
53+
chunkExtractors: ExtractorConfig<ConverseStreamOutput, ExtractedResponse> = {
54+
usage: (chunk: ConverseStreamOutput) => {
55+
if ('usage' in chunk) {
56+
this.extractedResponse.usage = this.extractUsage(chunk)
57+
}
58+
},
59+
responseModel: (_chunk: ConverseStreamOutput) => {
60+
this.extractedResponse.responseModel = this.requestModel
61+
},
62+
}
3163
}
3264

3365
// TODO(Marcelo): Add input/output messages extraction.

gateway/src/providers/bedrock.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ export class BedrockProvider extends DefaultProviderProxy {
2626
} catch (_error) {
2727
return { error: 'invalid request JSON' }
2828
}
29-
const m = this.inferModel(this.restOfPath)
30-
if (m) {
31-
return { requestBodyText, requestBodyData, requestModel: m[1] }
29+
const model = this.inferModel(this.restOfPath)
30+
if (model) {
31+
return { requestBodyText, requestBodyData, requestModel: model }
3232
}
3333
return { error: 'unable to find model in path' }
3434
}

gateway/src/providers/default.ts

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import {
66
type Provider as UsageProvider,
77
} from '@pydantic/genai-prices'
88
import * as logfire from '@pydantic/logfire-api'
9+
import { EventStreamCodec } from '@smithy/eventstream-codec'
910
import { createParser, type EventSourceMessage } from 'eventsource-parser'
1011

1112
import type { GatewayOptions } from '..'
@@ -385,12 +386,14 @@ export class DefaultProviderProxy {
385386

386387
// IMPORTANT: Start consuming BOTH streams immediately to prevent tee() from buffering
387388
// The tee() requires both streams to be consumed concurrently, otherwise it will buffer
388-
389-
// Tee stream: one for client, one for processing
390389
const [responseStream, processingStream] = response.body.tee()
391390

392-
// Parse SSE events from processing stream
393-
const events = this.parseSSE(processingStream)
391+
let events: AsyncIterable<JsonData>
392+
if (responseHeaders.get('content-type')?.toLowerCase().startsWith('application/vnd.amazon.eventstream')) {
393+
events = this.parseAmazonEventStream(processingStream)
394+
} else {
395+
events = this.parseSSE(processingStream)
396+
}
394397

395398
// @ts-expect-error: TODO(Marcelo): Fix this type error.
396399
const extractionPromise = this.processChunks(modelAPI, events)
@@ -474,9 +477,42 @@ export class DefaultProviderProxy {
474477
}
475478
}
476479

480+
/**
 * Parse an `application/vnd.amazon.eventstream` response body into the JSON
 * payloads of its messages.
 *
 * The eventstream wire format frames each message with a prelude whose first
 * 4 bytes are the big-endian total message length, so incoming network chunks
 * are buffered until at least one complete message is available.
 *
 * Corrupt frames are reported and skipped: the buffer is advanced past a frame
 * BEFORE decoding it, so a single bad message cannot wedge the parser into
 * re-reporting the same error on every subsequent chunk (which is what the
 * previous `break`-and-keep-buffer approach did, while the buffer grew
 * without bound).
 */
protected async *parseAmazonEventStream(stream: ReadableStream<Uint8Array>): AsyncIterable<JsonData> {
  const encoder = new TextEncoder()
  const codec = new EventStreamCodec((str) => str, encoder.encode)
  const decoder = new TextDecoder()
  let buffer = new Uint8Array(0)

  for await (const chunk of stream) {
    // Append incoming chunk to buffer since messages can span multiple network chunks
    const combined = new Uint8Array(buffer.length + chunk.length)
    combined.set(buffer, 0)
    combined.set(chunk, buffer.length)
    buffer = combined

    // Extract complete messages from buffer (eventstream format: 4-byte length prefix + message data)
    while (buffer.length >= 4) {
      const messageLength = new DataView(buffer.buffer, buffer.byteOffset).getUint32(0, false)

      // Minimum eventstream message is 16 bytes (8-byte prelude + prelude CRC +
      // message CRC). A smaller declared length means the framing is corrupt and
      // no further progress is possible — stop parsing instead of spinning.
      if (messageLength < 16) {
        logfire.reportError(
          'Error parsing Amazon EventStream',
          new Error(`invalid eventstream message length: ${messageLength}`),
        )
        return
      }
      if (buffer.length < messageLength) break

      const messageBytes = buffer.subarray(0, messageLength)
      // Advance past the frame before decoding so a corrupt frame is consumed
      // exactly once (subarray is a view; it stays valid after reassignment).
      buffer = buffer.subarray(messageLength)

      try {
        const message = codec.decode(messageBytes)
        if (message.body?.length > 0) {
          yield JSON.parse(decoder.decode(message.body))
        }
      } catch (error) {
        logfire.reportError('Error parsing Amazon EventStream', error as Error)
      }
    }
  }
}
511+
477512
/**
 * A response is treated as streaming when the provider says so (SSE or AWS
 * eventstream content types) or when the request body explicitly asked for it.
 */
protected isStreaming(responseHeaders: Headers, requestBodyData: JsonData): boolean {
  const contentType = responseHeaders.get('content-type')?.toLowerCase() ?? ''
  return (
    contentType.startsWith('text/event-stream') ||
    contentType.startsWith('application/vnd.amazon.eventstream') ||
    ('stream' in requestBodyData && requestBodyData.stream === true)
  )
}

gateway/test/providers/bedrock.spec.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { EventStreamCodec } from '@smithy/eventstream-codec'
12
import { describe, expect } from 'vitest'
23
import { deserializeRequest } from '../otel'
34
import { test } from '../setup'
@@ -24,4 +25,50 @@ describe('bedrock', () => {
2425
expect(otelBatch, 'otelBatch length not 1').toHaveLength(1)
2526
expect(deserializeRequest(otelBatch[0]!)).toMatchSnapshot('span')
2627
})
28+
29+
test('bedrock/stream', async ({ gateway }) => {
30+
const { fetch, otelBatch } = gateway
31+
32+
const result = await fetch('https://example.com/converse/model/amazon.nova-micro-v1%3A0/converse-stream', {
33+
method: 'POST',
34+
headers: { 'Content-Type': 'application/json', Authorization: 'healthy', 'x-vcr-filename': 'stream' },
35+
body: JSON.stringify({
36+
modelId: 'amazon.nova-premier-v1:0',
37+
system: [{ text: 'You are a helpful assistant.' }],
38+
messages: [{ role: 'user', content: [{ text: 'What is the capital of France?' }] }],
39+
}),
40+
})
41+
const chunks: string[] = []
42+
for await (const chunk of parseEventStream(result.body!)) {
43+
chunks.push(chunk)
44+
}
45+
expect(chunks).toMatchSnapshot('chunks')
46+
expect(otelBatch, 'otelBatch length not 1').toHaveLength(1)
47+
expect(deserializeRequest(otelBatch[0]!)).toMatchSnapshot('span')
48+
})
2749
})
50+
51+
async function* parseEventStream(stream: ReadableStream<Uint8Array>): AsyncIterable<string> {
52+
const encoder = new TextEncoder()
53+
const codec = new EventStreamCodec((str) => str, encoder.encode)
54+
const decoder = new TextDecoder()
55+
let buffer = new Uint8Array(0)
56+
57+
for await (const chunk of stream) {
58+
const combined = new Uint8Array(buffer.length + chunk.length)
59+
combined.set(buffer, 0)
60+
combined.set(chunk, buffer.length)
61+
buffer = combined
62+
63+
while (buffer.length >= 4) {
64+
const messageLength = new DataView(buffer.buffer, buffer.byteOffset).getUint32(0, false)
65+
if (buffer.length < messageLength) break
66+
67+
const message = codec.decode(buffer.subarray(0, messageLength))
68+
if (message.body?.length > 0) {
69+
yield JSON.parse(decoder.decode(message.body))
70+
}
71+
buffer = buffer.subarray(messageLength)
72+
}
73+
}
74+
}

0 commit comments

Comments
 (0)