diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index 5f5f16b044..efc205c9ff 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -154,6 +154,7 @@ export const PROVIDERS: Record 0) { + return text; + } + } + return undefined; + } + if (typeof value === "object") { + const record = value as Record; + const directTextKeys = ["output_text", "generated_text", "text", "content"] as const; + for (const key of directTextKeys) { + const maybeText = record[key]; + if (typeof maybeText === "string" && maybeText.length > 0) { + return maybeText; + } + } + const nestedKeys = ["output", "choices", "message", "delta", "content", "data"] as const; + for (const key of nestedKeys) { + if (key in record) { + const text = extractTextFromReplicateResponse(record[key]); + if (typeof text === "string" && text.length > 0) { + return text; + } + } + } + } + return undefined; } abstract class ReplicateTask extends TaskProviderHelper { @@ -116,6 +155,21 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma } } +export class ReplicateTextGenerationTask extends ReplicateTask implements TextGenerationTaskHelper { + override async getResponse(response: ReplicateOutput): Promise { + if (response instanceof Blob) { + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-generation API"); + } + + const text = extractTextFromReplicateResponse(response); + if (typeof text === "string") { + return { generated_text: text }; + } + + throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-generation API"); + } +} + export class ReplicateTextToSpeechTask extends ReplicateTask { override preparePayload(params: BodyParams): Record { const payload = super.preparePayload(params); diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts index fcb6e55cb3..3fe4aa4b8f 100644 --- a/packages/inference/test/InferenceClient.spec.ts +++ b/packages/inference/test/InferenceClient.spec.ts @@ -1195,6 +1195,22 @@ describe.skip("InferenceClient", () => { () => { const client = new InferenceClient(env.HF_REPLICATE_KEY ?? "dummy"); + it("textGeneration - akhaliq/gpt-5", async () => { + const res = await client.textGeneration({ + model: "akhaliq/gpt-5", + provider: "replicate", + inputs: "The capital city of France is", + parameters: { + max_new_tokens: 20, + temperature: 0.2, + }, + }); + + expect(res).toBeDefined(); + expect(typeof res.generated_text).toBe("string"); + expect(res.generated_text.length).toBeGreaterThan(0); + }); + it("textToImage canonical - black-forest-labs/FLUX.1-schnell", async () => { const res = await client.textToImage({ model: "black-forest-labs/FLUX.1-schnell",