1 change: 1 addition & 0 deletions packages/inference/src/lib/getProviderHelper.ts
@@ -154,6 +154,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
		conversational: new PublicAI.PublicAIConversationalTask(),
	},
	replicate: {
		"text-generation": new Replicate.ReplicateTextGenerationTask(),
		"text-to-image": new Replicate.ReplicateTextToImageTask(),
		"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
		"text-to-video": new Replicate.ReplicateTextToVideoTask(),
58 changes: 56 additions & 2 deletions packages/inference/src/providers/replicate.ts
@@ -22,15 +22,54 @@ import {
	TaskProviderHelper,
	type AutomaticSpeechRecognitionTaskHelper,
	type ImageToImageTaskHelper,
	type TextGenerationTaskHelper,
	type TextToImageTaskHelper,
	type TextToVideoTaskHelper,
} from "./providerHelper.js";
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js";
import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
import type { AutomaticSpeechRecognitionOutput, TextGenerationOutput } from "@huggingface/tasks";
import { base64FromBytes } from "../utils/base64FromBytes.js";
export interface ReplicateOutput {
	output?: string | string[];
	output?: unknown;
}

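/**
 * Recursively searches a Replicate prediction output for generated text, checking plain
 * strings, array items, and common text-bearing keys such as output_text, generated_text,
 * text, and content.
 */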
function extractTextFromReplicateResponse(value: unknown): string | undefined {
	if (value == null) {
		return undefined;
	}
	if (typeof value === "string") {
		return value;
	}
	if (Array.isArray(value)) {
		for (const item of value) {
			const text = extractTextFromReplicateResponse(item);
			if (typeof text === "string" && text.length > 0) {
				return text;
Comment on lines +44 to +48

P1: Picking first array element returns user prompt instead of model output

The new extractTextFromReplicateResponse walks arrays and returns the first non‑empty string it finds. Replicate’s LLM responses are often arrays of chat messages where the first element is the user (or system) prompt and the assistant’s generated text appears later in the array. In such cases textGeneration will now echo the prompt instead of the model’s reply. Consider iterating from the end of the array or preferring elements with role === "assistant" so the final generated text is returned.

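One possible shape for that fix, sketched here for discussion only (not part of this diff); it assumes chat-style outputs expose message objects with a string role field:

	if (Array.isArray(value)) {
		// Prefer an assistant message when one is present; otherwise scan from the
		// end so the final generated text wins over the echoed prompt.
		const assistant = value.find(
			(item) => typeof item === "object" && item !== null && (item as Record<string, unknown>).role === "assistant"
		);
		if (assistant !== undefined) {
			const text = extractTextFromReplicateResponse(assistant);
			if (typeof text === "string" && text.length > 0) {
				return text;
			}
		}
		for (let i = value.length - 1; i >= 0; i--) {
			const text = extractTextFromReplicateResponse(value[i]);
			if (typeof text === "string" && text.length > 0) {
				return text;
			}
		}
		return undefined;
	}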

			}
		}
		return undefined;
	}
	if (typeof value === "object") {
		const record = value as Record<string, unknown>;
		const directTextKeys = ["output_text", "generated_text", "text", "content"] as const;
		for (const key of directTextKeys) {
			const maybeText = record[key];
			if (typeof maybeText === "string" && maybeText.length > 0) {
				return maybeText;
			}
		}
		const nestedKeys = ["output", "choices", "message", "delta", "content", "data"] as const;
		for (const key of nestedKeys) {
			if (key in record) {
				const text = extractTextFromReplicateResponse(record[key]);
				if (typeof text === "string" && text.length > 0) {
					return text;
				}
			}
		}
	}
	return undefined;
}

abstract class ReplicateTask extends TaskProviderHelper {
@@ -116,6 +155,21 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma
	}
}

export class ReplicateTextGenerationTask extends ReplicateTask implements TextGenerationTaskHelper {
	override async getResponse(response: ReplicateOutput): Promise<TextGenerationOutput> {
		if (response instanceof Blob) {
			throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-generation API");
		}

		const text = extractTextFromReplicateResponse(response);
		if (typeof text === "string") {
			return { generated_text: text };
		}

		throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-generation API");
	}
}

export class ReplicateTextToSpeechTask extends ReplicateTask {
	override preparePayload(params: BodyParams): Record<string, unknown> {
		const payload = super.preparePayload(params);
16 changes: 16 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
@@ -1195,6 +1195,22 @@ describe.skip("InferenceClient", () => {
		() => {
			const client = new InferenceClient(env.HF_REPLICATE_KEY ?? "dummy");

			it("textGeneration - akhaliq/gpt-5", async () => {
				const res = await client.textGeneration({
					model: "akhaliq/gpt-5",
					provider: "replicate",
					inputs: "The capital city of France is",
					parameters: {
						max_new_tokens: 20,
						temperature: 0.2,
					},
				});

				expect(res).toBeDefined();
				expect(typeof res.generated_text).toBe("string");
				expect(res.generated_text.length).toBeGreaterThan(0);
			});

			it("textToImage canonical - black-forest-labs/FLUX.1-schnell", async () => {
				const res = await client.textToImage({
					model: "black-forest-labs/FLUX.1-schnell",