diff --git a/packages/inference/src/providers/replicate.ts b/packages/inference/src/providers/replicate.ts index 75496fecee..4e93074e4e 100644 --- a/packages/inference/src/providers/replicate.ts +++ b/packages/inference/src/providers/replicate.ts @@ -33,6 +33,51 @@ export interface ReplicateOutput { output?: string | string[]; } +type ReplicatePredictionStatus = "starting" | "processing" | "succeeded" | "failed" | "canceled" | "queued"; + +interface ReplicateAsyncResponse extends ReplicateOutput { + id?: string; + status?: ReplicatePredictionStatus; + error?: unknown; + urls?: { + get?: string; + }; +} + +const POLLING_INTERVAL_MS = 1_000; + +function headersInitToRecord(headers?: HeadersInit): Record { + if (!headers) { + return {}; + } + if (headers instanceof Headers) { + return Object.fromEntries(headers.entries()); + } + if (Array.isArray(headers)) { + return Object.fromEntries(headers); + } + return { ...headers }; +} + +function getErrorMessage(error: unknown): string | undefined { + if (!error) { + return undefined; + } + if (typeof error === "string") { + return error; + } + if (typeof error === "object" && "message" in error && typeof error.message === "string") { + return error.message; + } + return undefined; +} + +async function sleep(ms: number): Promise { + await new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} + abstract class ReplicateTask extends TaskProviderHelper { constructor(url?: string) { super("replicate", url || "https://api.replicate.com"); @@ -69,6 +114,97 @@ abstract class ReplicateTask extends TaskProviderHelper { } return `${baseUrl}/v1/models/${params.model}/predictions`; } + + protected async ensureFinalResponse( + response: ReplicateOutput | Blob | ReplicateAsyncResponse, + requestUrl?: string, + headers?: HeadersInit + ): Promise { + if (response instanceof Blob) { + return response; + } + + if (!response || typeof response !== "object") { + return response as ReplicateOutput; + } + + const status = "status" in response ? response.status : undefined; + + if (!status || status === "succeeded") { + return response as ReplicateOutput; + } + + if (status === "failed" || status === "canceled") { + const message = getErrorMessage((response as ReplicateAsyncResponse).error); + throw new InferenceClientProviderOutputError(`Replicate prediction ${status}${message ? `: ${message}` : ""}`); + } + + const pollUrl = this.getPollUrl(response as ReplicateAsyncResponse, requestUrl); + if (!pollUrl) { + throw new InferenceClientProviderOutputError( + "Received incomplete response from Replicate API: missing polling URL" + ); + } + + const headerRecord = headersInitToRecord(headers); + const pollHeaders: Record = {}; + if (headerRecord.Authorization) { + pollHeaders.Authorization = headerRecord.Authorization; + } + pollHeaders.Accept = "application/json"; + + // Poll the prediction endpoint until completion + while (true) { + await sleep(POLLING_INTERVAL_MS); + const pollResponse = await fetch(pollUrl, { + method: "GET", + headers: pollHeaders, + }); + + if (!pollResponse.ok) { + throw new InferenceClientProviderOutputError( + `Failed to poll Replicate prediction status: HTTP ${pollResponse.status}` + ); + } + + const prediction = (await pollResponse.json()) as ReplicateAsyncResponse; + const predictionStatus = prediction.status; + + if (!predictionStatus || predictionStatus === "succeeded") { + return prediction as ReplicateOutput; + } + + if (predictionStatus === "failed" || predictionStatus === "canceled") { + const message = getErrorMessage(prediction.error); + throw new InferenceClientProviderOutputError( + `Replicate prediction ${predictionStatus}${message ? `: ${message}` : ""}` + ); + } + } + } + + private getPollUrl(response: ReplicateAsyncResponse, requestUrl?: string): string | undefined { + if (response.urls && typeof response.urls === "object" && typeof response.urls.get === "string") { + return response.urls.get; + } + + if (!response.id || !requestUrl) { + return undefined; + } + + try { + const url = new URL(requestUrl); + const pathname = url.pathname.replace(/\/$/, ""); + if (pathname.endsWith("/predictions")) { + url.pathname = `${pathname}/${response.id}`; + return url.toString(); + } + } catch { + return undefined; + } + + return undefined; + } } export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper { @@ -94,21 +230,22 @@ export class ReplicateTextToImageTask extends ReplicateTask implements TextToIma outputType?: "url" | "blob" | "json" ): Promise> { void url; - void headers; + const finalResponse = (await this.ensureFinalResponse(res, url, headers)) as ReplicateOutput; + if ( - typeof res === "object" && - "output" in res && - Array.isArray(res.output) && - res.output.length > 0 && - typeof res.output[0] === "string" + typeof finalResponse === "object" && + "output" in finalResponse && + Array.isArray(finalResponse.output) && + finalResponse.output.length > 0 && + typeof finalResponse.output[0] === "string" ) { if (outputType === "json") { - return { ...res }; + return { ...finalResponse }; } if (outputType === "url") { - return res.output[0]; + return finalResponse.output[0]; } - const urlResponse = await fetch(res.output[0]); + const urlResponse = await fetch(finalResponse.output[0]); return await urlResponse.blob(); } @@ -130,17 +267,19 @@ export class ReplicateTextToSpeechTask extends ReplicateTask { return payload; } - override async getResponse(response: ReplicateOutput): Promise { - if (response instanceof Blob) { - return response; + override async getResponse(response: ReplicateOutput | Blob, url?: string, headers?: HeadersInit): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput | Blob; + + if (finalResponse instanceof Blob) { + return finalResponse; } - if (response && typeof response === "object") { - if ("output" in response) { - if (typeof response.output === "string") { - const urlResponse = await fetch(response.output); + if (finalResponse && typeof finalResponse === "object") { + if ("output" in finalResponse) { + if (typeof finalResponse.output === "string") { + const urlResponse = await fetch(finalResponse.output); return await urlResponse.blob(); - } else if (Array.isArray(response.output)) { - const urlResponse = await fetch(response.output[0]); + } else if (Array.isArray(finalResponse.output)) { + const urlResponse = await fetch(finalResponse.output[0]); return await urlResponse.blob(); } } @@ -150,15 +289,16 @@ export class ReplicateTextToSpeechTask extends ReplicateTask { } export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVideoTaskHelper { - override async getResponse(response: ReplicateOutput): Promise { + override async getResponse(response: ReplicateOutput | Blob, url?: string, headers?: HeadersInit): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; if ( - typeof response === "object" && - !!response && - "output" in response && - typeof response.output === "string" && - isUrl(response.output) + typeof finalResponse === "object" && + !!finalResponse && + "output" in finalResponse && + typeof finalResponse.output === "string" && + isUrl(finalResponse.output) ) { - const urlResponse = await fetch(response.output); + const urlResponse = await fetch(finalResponse.output); return await urlResponse.blob(); } @@ -199,11 +339,17 @@ export class ReplicateAutomaticSpeechRecognitionTask }; } - override async getResponse(response: ReplicateOutput): Promise { - if (typeof response?.output === "string") return { text: response.output }; - if (Array.isArray(response?.output) && typeof response.output[0] === "string") return { text: response.output[0] }; + override async getResponse( + response: ReplicateOutput | Blob, + url?: string, + headers?: HeadersInit + ): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; + if (typeof finalResponse?.output === "string") return { text: finalResponse.output }; + if (Array.isArray(finalResponse?.output) && typeof finalResponse.output[0] === "string") + return { text: finalResponse.output[0] }; - const out = response?.output as + const out = finalResponse?.output as | undefined | { transcription?: string; @@ -254,27 +400,28 @@ export class ReplicateImageToImageTask extends ReplicateTask implements ImageToI }; } - override async getResponse(response: ReplicateOutput): Promise { + override async getResponse(response: ReplicateOutput | Blob, url?: string, headers?: HeadersInit): Promise { + const finalResponse = (await this.ensureFinalResponse(response, url, headers)) as ReplicateOutput; if ( - typeof response === "object" && - !!response && - "output" in response && - Array.isArray(response.output) && - response.output.length > 0 && - typeof response.output[0] === "string" + typeof finalResponse === "object" && + !!finalResponse && + "output" in finalResponse && + Array.isArray(finalResponse.output) && + finalResponse.output.length > 0 && + typeof finalResponse.output[0] === "string" ) { - const urlResponse = await fetch(response.output[0]); + const urlResponse = await fetch(finalResponse.output[0]); return await urlResponse.blob(); } if ( - typeof response === "object" && - !!response && - "output" in response && - typeof response.output === "string" && - isUrl(response.output) + typeof finalResponse === "object" && + !!finalResponse && + "output" in finalResponse && + typeof finalResponse.output === "string" && + isUrl(finalResponse.output) ) { - const urlResponse = await fetch(response.output); + const urlResponse = await fetch(finalResponse.output); return await urlResponse.blob(); }