
Commit 17fb3a7

Authored by arabot777, hanouticelina, and SBrandeis
[inference provider] Add wavespeed.ai as an inference provider (#1424)
**What’s in this PR**

[WaveSpeedAI](https://wavespeed.ai/) is a high-performance AI image and video generation service platform offering industry-leading generation speeds, and we would like to be listed as an Inference Provider on the Hugging Face Hub. The JS client integration was completed based on the inference-providers documentation and passes the tests. We are submitting the PR now and look forward to further communication with you.

**Test**

```
pnpm --filter @huggingface/inference test "test/InferenceClient.spec.ts" -t "^Wavespeed AI"

> @huggingface/inference@3.11.0 test /Users/shanliu/work/huggingface.js/packages/inference
> vitest run --config vitest.config.mts "test/InferenceClient.spec.ts"

 RUN  v0.34.6 /Users/shanliu/work/huggingface.js/packages/inference

 ✓ test/InferenceClient.spec.ts (104) 198160ms
   ✓ InferenceClient (104) 198160ms
     ✓ backward compatibility (1)
       ✓ works with old HfInference name
     ↓ HF Inference (49) [skipped]
       ↓ throws error if model does not exist [skipped]
       ↓ fillMask [skipped]
       ↓ works without model [skipped]
       ↓ summarization [skipped]
       ↓ questionAnswering [skipped]
       ↓ tableQuestionAnswering [skipped]
       ↓ documentQuestionAnswering [skipped]
       ↓ documentQuestionAnswering with non-array output [skipped]
       ↓ visualQuestionAnswering [skipped]
       ↓ textClassification [skipped]
       ↓ textGeneration - gpt2 [skipped]
       ↓ textGeneration - openai-community/gpt2 [skipped]
       ↓ textGenerationStream - meta-llama/Llama-3.2-3B [skipped]
       ↓ textGenerationStream - catch error [skipped]
       ↓ textGenerationStream - Abort [skipped]
       ↓ tokenClassification [skipped]
       ↓ translation [skipped]
       ↓ zeroShotClassification [skipped]
       ↓ sentenceSimilarity [skipped]
       ↓ FeatureExtraction [skipped]
       ↓ FeatureExtraction - auto-compatibility sentence similarity [skipped]
       ↓ FeatureExtraction - facebook/bart-base [skipped]
       ↓ FeatureExtraction - facebook/bart-base, list input [skipped]
       ↓ automaticSpeechRecognition [skipped]
       ↓ audioClassification [skipped]
       ↓ audioToAudio [skipped]
       ↓ textToSpeech [skipped]
       ↓ imageClassification [skipped]
       ↓ zeroShotImageClassification [skipped]
       ↓ objectDetection [skipped]
       ↓ imageSegmentation [skipped]
       ↓ imageToImage [skipped]
       ↓ imageToImage blob data [skipped]
       ↓ textToImage [skipped]
       ↓ textToImage with parameters [skipped]
       ↓ imageToText [skipped]
       ↓ request - openai-community/gpt2 [skipped]
       ↓ tabularRegression [skipped]
       ↓ tabularClassification [skipped]
       ↓ endpoint - makes request to specified endpoint [skipped]
       ↓ endpoint - makes request to specified endpoint - alternative syntax [skipped]
       ↓ chatCompletion modelId - OpenAI Specs [skipped]
       ↓ chatCompletionStream modelId - OpenAI Specs [skipped]
       ↓ chatCompletionStream modelId Fail - OpenAI Specs [skipped]
       ↓ chatCompletion - OpenAI Specs [skipped]
       ↓ chatCompletionStream - OpenAI Specs [skipped]
       ↓ custom mistral - OpenAI Specs [skipped]
       ↓ custom openai - OpenAI Specs [skipped]
       ↓ OpenAI client side routing - model should have provider as prefix [skipped]
     ↓ Fal AI (4) [skipped]
       ↓ textToImage - black-forest-labs/FLUX.1-schnell [skipped]
       ↓ textToImage - SD LoRAs [skipped]
       ↓ textToImage - Flux LoRAs [skipped]
       ↓ automaticSpeechRecognition - openai/whisper-large-v3 [skipped]
     ↓ Featherless (3) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
       ↓ textGeneration [skipped]
     ↓ Replicate (10) [skipped]
       ↓ textToImage canonical - black-forest-labs/FLUX.1-schnell [skipped]
       ↓ textToImage canonical - black-forest-labs/FLUX.1-dev [skipped]
       ↓ textToImage canonical - stabilityai/stable-diffusion-3.5-large-turbo [skipped]
       ↓ textToImage versioned - ByteDance/SDXL-Lightning [skipped]
       ↓ textToImage versioned - ByteDance/Hyper-SD [skipped]
       ↓ textToImage versioned - playgroundai/playground-v2.5-1024px-aesthetic [skipped]
       ↓ textToImage versioned - stabilityai/stable-diffusion-xl-base-1.0 [skipped]
       ↓ textToSpeech versioned [skipped]
       ↓ textToSpeech OuteTTS - usually Cold [skipped]
       ↓ textToSpeech Kokoro [skipped]
     ↓ SambaNova (3) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
       ↓ featureExtraction [skipped]
     ↓ Together (4) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
       ↓ textToImage [skipped]
       ↓ textGeneration [skipped]
     ↓ Nebius (3) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
       ↓ textToImage [skipped]
     ↓ 3rd party providers (1) [skipped]
       ↓ chatCompletion - fails with unsupported model [skipped]
     ↓ Fireworks (2) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
     ↓ Hyperbolic (4) [skipped]
       ↓ chatCompletion - hyperbolic [skipped]
       ↓ chatCompletion stream [skipped]
       ↓ textToImage [skipped]
       ↓ textGeneration [skipped]
     ↓ Novita (2) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
     ↓ Black Forest Labs (2) [skipped]
       ↓ textToImage [skipped]
       ↓ textToImage URL [skipped]
     ↓ Cohere (2) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
     ↓ Cerebras (2) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
     ↓ Nscale (3) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
       ↓ textToImage [skipped]
     ↓ Groq (2) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
     ↓ OVHcloud (4) [skipped]
       ↓ chatCompletion [skipped]
       ↓ chatCompletion stream [skipped]
       ↓ textGeneration [skipped]
       ↓ textGeneration stream [skipped]
     ✓ Wavespeed AI (5) 89033ms
       ✓ textToImage - wavespeed-ai/flux-schnell 89032ms
       ✓ textToImage - wavespeed-ai/flux-dev-lora 12369ms
       ✓ textToImage - wavespeed-ai/flux-dev-lora-ultra-fast 17936ms
       ✓ textToVideo - wavespeed-ai/wan-2.1/t2v-480p 79507ms
       ✓ imageToImage - wavespeed-ai/hidream-e1-full 74481ms

 Test Files  1 passed (1)
      Tests  5 passed | 103 skipped (108)
   Start at  14:33:17
   Duration  89.62s (transform 315ms, setup 14ms, collect 368ms, tests 89.03s, environment 0ms, prepare 74ms)
```

---------

Co-authored-by: célina <hanouticelina@gmail.com>
Co-authored-by: Simon Brandeis <33657802+SBrandeis@users.noreply.github.com>
1 parent bd64ddf commit 17fb3a7

File tree

6 files changed, +308 −0 lines changed

packages/inference/README.md (2 additions, 0 deletions)

```diff
@@ -67,6 +67,7 @@ Currently, we support the following providers:
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
 - [Groq](https://groq.com)
+- [Wavespeed.ai](https://wavespeed.ai/)
 - [Z.ai](https://z.ai/)

 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
@@ -105,6 +106,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
+- [Wavespeed.ai supported models](https://huggingface.co/api/partners/wavespeed/models)
 - [Z.ai supported models](https://huggingface.co/api/partners/zai-org/models)

 **Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
```
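For reference, here is what the new provider looks like from the client side, per the README instructions above. A minimal usage sketch, assuming `HF_TOKEN` is set and the code runs in an async/ESM context; the model ID is one of those mapped to WaveSpeed in the tests below:

```ts
import { InferenceClient } from "@huggingface/inference";

// Minimal sketch, assuming HF_TOKEN grants access to inference providers.
const client = new InferenceClient(process.env.HF_TOKEN);

const image = await client.textToImage({
	model: "black-forest-labs/FLUX.1-schnell",
	provider: "wavespeed", // bypass "auto" routing and send the request to WaveSpeed AI
	inputs: "A serene mountain lake at sunrise, photorealistic",
});
// `image` is a Blob containing the generated picture
```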

packages/inference/src/lib/getProviderHelper.ts (6 additions, 0 deletions)

```diff
@@ -52,6 +52,7 @@ import * as Replicate from "../providers/replicate.js";
 import * as Sambanova from "../providers/sambanova.js";
 import * as Scaleway from "../providers/scaleway.js";
 import * as Together from "../providers/together.js";
+import * as Wavespeed from "../providers/wavespeed.js";
 import * as Zai from "../providers/zai-org.js";
 import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
 import { InferenceClientInputError } from "../errors.js";
@@ -173,6 +174,11 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
 		conversational: new Together.TogetherConversationalTask(),
 		"text-generation": new Together.TogetherTextGenerationTask(),
 	},
+	wavespeed: {
+		"text-to-image": new Wavespeed.WavespeedAITextToImageTask(),
+		"text-to-video": new Wavespeed.WavespeedAITextToVideoTask(),
+		"image-to-image": new Wavespeed.WavespeedAIImageToImageTask(),
+	},
 	"zai-org": {
 		conversational: new Zai.ZaiConversationalTask(),
 	},
```
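For context, the client resolves a task helper from this `PROVIDERS` registry before dispatching a request, so the new `wavespeed` entry is what makes the three tasks routable. A simplified sketch of the lookup, for illustration only (the real `getProviderHelper` performs additional validation):

```ts
// Illustrative sketch of the registry lookup; not the actual implementation.
function resolveHelper(provider: InferenceProvider, task: InferenceTask): TaskProviderHelper {
	const helper = PROVIDERS[provider]?.[task];
	if (!helper) {
		throw new InferenceClientInputError(`Task "${task}" is not supported for provider "${provider}"`);
	}
	return helper;
}

// e.g. resolveHelper("wavespeed", "text-to-image") yields the WavespeedAITextToImageTask instance
```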

packages/inference/src/providers/consts.ts (1 addition, 0 deletions)

```diff
@@ -39,5 +39,6 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	sambanova: {},
 	scaleway: {},
 	together: {},
+	wavespeed: {},
 	"zai-org": {},
 };
```
packages/inference/src/providers/wavespeed.ts (new file, 185 additions)

```ts
import type { TextToImageArgs } from "../tasks/cv/textToImage.js";
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
import type { TextToVideoArgs } from "../tasks/cv/textToVideo.js";
import type { BodyParams, RequestArgs, UrlParams } from "../types.js";
import { delay } from "../utils/delay.js";
import { omit } from "../utils/omit.js";
import { base64FromBytes } from "../utils/base64FromBytes.js";
import type { TextToImageTaskHelper, TextToVideoTaskHelper, ImageToImageTaskHelper } from "./providerHelper.js";
import { TaskProviderHelper } from "./providerHelper.js";
import {
	InferenceClientInputError,
	InferenceClientProviderApiError,
	InferenceClientProviderOutputError,
} from "../errors.js";

const WAVESPEEDAI_API_BASE_URL = "https://api.wavespeed.ai";

/**
 * Response structure for task status and results
 */
interface WaveSpeedAITaskResponse {
	id: string;
	model: string;
	outputs: string[];
	urls: {
		get: string;
	};
	has_nsfw_contents: boolean[];
	status: "created" | "processing" | "completed" | "failed";
	created_at: string;
	error: string;
	executionTime: number;
	timings: {
		inference: number;
	};
}

/**
 * Response structure for initial task submission
 */
interface WaveSpeedAISubmitResponse {
	id: string;
	urls: {
		get: string;
	};
}

/**
 * Response structure for WaveSpeed AI API
 */
interface WaveSpeedAIResponse {
	code: number;
	message: string;
	data: WaveSpeedAITaskResponse;
}

/**
 * Response structure for WaveSpeed AI API with submit response data
 */
interface WaveSpeedAISubmitTaskResponse {
	code: number;
	message: string;
	data: WaveSpeedAISubmitResponse;
}

abstract class WavespeedAITask extends TaskProviderHelper {
	constructor(url?: string) {
		super("wavespeed", url || WAVESPEEDAI_API_BASE_URL);
	}

	makeRoute(params: UrlParams): string {
		return `/api/v3/${params.model}`;
	}

	preparePayload(params: BodyParams<ImageToImageArgs | TextToImageArgs | TextToVideoArgs>): Record<string, unknown> {
		const payload: Record<string, unknown> = {
			...omit(params.args, ["inputs", "parameters"]),
			...params.args.parameters,
			prompt: params.args.inputs,
		};
		// Add LoRA support if adapter is specified in the mapping
		if (params.mapping?.adapter === "lora") {
			payload.loras = [
				{
					path: params.mapping.hfModelId,
					scale: 1, // Default scale value
				},
			];
		}
		return payload;
	}

	override async getResponse(
		response: WaveSpeedAISubmitTaskResponse,
		url?: string,
		headers?: Record<string, string>
	): Promise<Blob> {
		if (!headers) {
			throw new InferenceClientInputError("Headers are required for WaveSpeed AI API calls");
		}

		const resultUrl = response.data.urls.get;

		// Poll for results until completion
		while (true) {
			const resultResponse = await fetch(resultUrl, { headers });

			if (!resultResponse.ok) {
				throw new InferenceClientProviderApiError(
					"Failed to fetch response status from WaveSpeed AI API",
					{ url: resultUrl, method: "GET" },
					{
						requestId: resultResponse.headers.get("x-request-id") ?? "",
						status: resultResponse.status,
						body: await resultResponse.text(),
					}
				);
			}

			const result: WaveSpeedAIResponse = await resultResponse.json();
			const taskResult = result.data;

			switch (taskResult.status) {
				case "completed": {
					// Get the media data from the first output URL
					if (!taskResult.outputs?.[0]) {
						throw new InferenceClientProviderOutputError(
							"Received malformed response from WaveSpeed AI API: No output URL in completed response"
						);
					}
					const mediaResponse = await fetch(taskResult.outputs[0]);
					if (!mediaResponse.ok) {
						throw new InferenceClientProviderApiError(
							"Failed to fetch generation output from WaveSpeed AI API",
							{ url: taskResult.outputs[0], method: "GET" },
							{
								requestId: mediaResponse.headers.get("x-request-id") ?? "",
								status: mediaResponse.status,
								body: await mediaResponse.text(),
							}
						);
					}
					return await mediaResponse.blob();
				}
				case "failed": {
					throw new InferenceClientProviderOutputError(taskResult.error || "Task failed");
				}
				default: {
					// Wait before polling again
					await delay(500);
					continue;
				}
			}
		}
	}
}

export class WavespeedAITextToImageTask extends WavespeedAITask implements TextToImageTaskHelper {
	constructor() {
		super(WAVESPEEDAI_API_BASE_URL);
	}
}

export class WavespeedAITextToVideoTask extends WavespeedAITask implements TextToVideoTaskHelper {
	constructor() {
		super(WAVESPEEDAI_API_BASE_URL);
	}
}

export class WavespeedAIImageToImageTask extends WavespeedAITask implements ImageToImageTaskHelper {
	constructor() {
		super(WAVESPEEDAI_API_BASE_URL);
	}

	async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
		return {
			...args,
			inputs: args.parameters?.prompt,
			image: base64FromBytes(
				new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
			),
		};
	}
}
```
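To make the flow of `getResponse` above concrete, here is a standalone sketch of the same submit-then-poll cycle against the WaveSpeed endpoints. The `/api/v3/` route and the response fields mirror the code above; the `Bearer` authorization scheme and the hypothetical `apiKey` parameter are assumptions for illustration, since the client normally constructs the submit request itself:

```ts
// Standalone sketch of the submit/poll cycle; field names mirror the interfaces
// above, but this is not the client's actual code path.
async function generate(model: string, prompt: string, apiKey: string): Promise<Blob> {
	// Assumed Bearer auth; check WaveSpeed's docs for the actual scheme.
	const headers = { Authorization: `Bearer ${apiKey}`, "Content-Type": "application/json" };

	// Submit the generation task.
	const submit = await fetch(`https://api.wavespeed.ai/api/v3/${model}`, {
		method: "POST",
		headers,
		body: JSON.stringify({ prompt }),
	});
	const { data } = (await submit.json()) as { data: { urls: { get: string } } };

	// Poll the status URL until the task completes or fails.
	for (;;) {
		const res = await fetch(data.urls.get, { headers });
		const { data: task } = (await res.json()) as {
			data: { status: string; outputs: string[]; error: string };
		};
		if (task.status === "completed") {
			const media = await fetch(task.outputs[0]); // first output URL holds the media
			return await media.blob();
		}
		if (task.status === "failed") throw new Error(task.error || "Task failed");
		await new Promise((resolve) => setTimeout(resolve, 500)); // wait before polling again
	}
}
```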

packages/inference/src/types.ts (1 addition, 0 deletions)

```diff
@@ -66,6 +66,7 @@ export const INFERENCE_PROVIDERS = [
 	"sambanova",
 	"scaleway",
 	"together",
+	"wavespeed",
 	"zai-org",
 ] as const;
```

packages/inference/test/InferenceClient.spec.ts (113 additions, 0 deletions)

```diff
@@ -2291,6 +2291,119 @@ describe.skip("InferenceClient", () => {
 		TIMEOUT
 	);

+	describe.concurrent(
+		"Wavespeed AI",
+		() => {
+			const client = new InferenceClient(env.HF_WAVESPEED_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["wavespeed"] = {
+				"black-forest-labs/FLUX.1-schnell": {
+					provider: "wavespeed",
+					hfModelId: "black-forest-labs/FLUX.1-schnell",
+					providerId: "wavespeed-ai/flux-schnell",
+					status: "live",
+					task: "text-to-image",
+				},
+				"Wan-AI/Wan2.1-T2V-14B": {
+					provider: "wavespeed",
+					hfModelId: "wavespeed-ai/wan-2.1/t2v-480p",
+					providerId: "wavespeed-ai/wan-2.1/t2v-480p",
+					status: "live",
+					task: "text-to-video",
+				},
+				"HiDream-ai/HiDream-E1-Full": {
+					provider: "wavespeed",
+					hfModelId: "wavespeed-ai/hidream-e1-full",
+					providerId: "wavespeed-ai/hidream-e1-full",
+					status: "live",
+					task: "image-to-image",
+				},
+				"openfree/flux-chatgpt-ghibli-lora": {
+					provider: "wavespeed",
+					hfModelId: "openfree/flux-chatgpt-ghibli-lora",
+					providerId: "wavespeed-ai/flux-dev-lora",
+					status: "live",
+					task: "text-to-image",
+					adapter: "lora",
+					adapterWeightsPath: "flux-chatgpt-ghibli-lora.safetensors",
+				},
+				"linoyts/yarn_art_Flux_LoRA": {
+					provider: "wavespeed",
+					hfModelId: "linoyts/yarn_art_Flux_LoRA",
+					providerId: "wavespeed-ai/flux-dev-lora-ultra-fast",
+					status: "live",
+					task: "text-to-image",
+					adapter: "lora",
+					adapterWeightsPath: "pytorch_lora_weights.safetensors",
+				},
+			};
+			it(`textToImage - black-forest-labs/FLUX.1-schnell`, async () => {
+				const res = await client.textToImage({
+					model: "black-forest-labs/FLUX.1-schnell",
+					provider: "wavespeed",
+					inputs:
+						"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+
+			it(`textToImage - openfree/flux-chatgpt-ghibli-lora`, async () => {
+				const res = await client.textToImage({
+					model: "openfree/flux-chatgpt-ghibli-lora",
+					provider: "wavespeed",
+					inputs:
+						"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+
+			it(`textToImage - linoyts/yarn_art_Flux_LoRA`, async () => {
+				const res = await client.textToImage({
+					model: "linoyts/yarn_art_Flux_LoRA",
+					provider: "wavespeed",
+					inputs:
+						"Cute boy with a hat, exploring nature, holding a telescope, backpack, surrounded by flowers, cartoon style, vibrant colors.",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+
+			it(`textToVideo - Wan-AI/Wan2.1-T2V-14B`, async () => {
+				const res = await client.textToVideo({
+					model: "Wan-AI/Wan2.1-T2V-14B",
+					provider: "wavespeed",
+					inputs:
+						"A cool street dancer, wearing a baggy hoodie and hip-hop pants, dancing in front of a graffiti wall, night neon background, quick camera cuts, urban trends.",
+					parameters: {
+						guidance_scale: 5,
+						num_inference_steps: 30,
+						seed: -1,
+					},
+					duration: 5,
+					enable_safety_checker: true,
+					flow_shift: 2.9,
+					size: "480*832",
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+
+			it(`imageToImage - HiDream-ai/HiDream-E1-Full`, async () => {
+				const res = await client.imageToImage({
+					model: "HiDream-ai/HiDream-E1-Full",
+					provider: "wavespeed",
+					inputs: new Blob([readTestFile("cheetah.png")], { type: "image/png" }),
+					parameters: {
+						prompt: "The leopard chases its prey",
+						guidance_scale: 5,
+						num_inference_steps: 30,
+						seed: -1,
+					},
+				});
+				expect(res).toBeInstanceOf(Blob);
+			});
+		},
+		TIMEOUT
+	);
+
 	describe.concurrent(
 		"PublicAI",
 		() => {
```
