Skip to content

Commit 0a60860

Browse files
authored
Merge branch 'main' into add-provider-siliconflow
2 parents 57ce64d + a7cfae4 commit 0a60860

File tree

14 files changed

+130
-13
lines changed

14 files changed

+130
-13
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ You can run our packages with vanilla JS, without any bundler, by using a CDN or
9898

9999
```html
100100
<script type="module">
101-
import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@4.11.3/+esm';
101+
import { InferenceClient } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@4.13.0/+esm';
102102
import { createRepo, commit, deleteRepo, listFiles } from "https://cdn.jsdelivr.net/npm/@huggingface/hub@2.6.12/+esm";
103103
</script>
104104
```

packages/hub/src/lib/parse-safetensors-metadata.spec.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,16 @@ describe("parseSafetensorsMetadata", () => {
143143
assert.strictEqual(safetensorsShardFileInfo?.total, "00072");
144144
});
145145

146+
it("should detect sharded safetensors filename with 6 digits", async () => {
147+
const safetensorsFilename = "model-00001-of-000163.safetensors"; // https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/model-00001-of-000163.safetensors
148+
const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);
149+
150+
assert.strictEqual(safetensorsShardFileInfo?.prefix, "model-");
151+
assert.strictEqual(safetensorsShardFileInfo?.basePrefix, "model");
152+
assert.strictEqual(safetensorsShardFileInfo?.shard, "00001");
153+
assert.strictEqual(safetensorsShardFileInfo?.total, "000163");
154+
});
155+
146156
it("should support sub-byte data types", async () => {
147157
const newDataTypes: Array<"F4" | "F6_E2M3" | "F6_E3M2" | "E8M0"> = ["F4", "F6_E2M3", "F6_E3M2", "E8M0"];
148158

packages/hub/src/lib/parse-safetensors-metadata.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";
1414
export const RE_SAFETENSORS_FILE = /\.safetensors$/;
1515
export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
1616
export const RE_SAFETENSORS_SHARD_FILE =
17-
/^(?<prefix>(?<basePrefix>.*?)[_-])(?<shard>\d{5})-of-(?<total>\d{5})\.safetensors$/;
17+
/^(?<prefix>(?<basePrefix>.*?)[_-])(?<shard>\d{5,6})-of-(?<total>\d{5,6})\.safetensors$/;
1818
export interface SafetensorsShardFileInfo {
1919
prefix: string;
2020
basePrefix: string;

packages/inference/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@huggingface/inference",
3-
"version": "4.11.3",
3+
"version": "4.13.0",
44
"packageManager": "pnpm@10.10.0",
55
"license": "MIT",
66
"author": "Hugging Face and Tim Mikeladze <tim.mikeladze@gmail.com>",

packages/inference/src/errors.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ export class InferenceClientInputError extends InferenceClientError {
1717
}
1818
}
1919

20+
/**
 * Thrown when a request cannot be routed to an inference provider — e.g. when
 * the server-side auto-router is selected with a non-Hugging Face API key.
 */
export class InferenceClientRoutingError extends InferenceClientError {
	constructor(message: string) {
		super(message);
		// NOTE(review): name drops the "InferenceClient" prefix — presumably the
		// convention used by the sibling error classes in this file; confirm.
		this.name = "RoutingError";
	}
}
26+
2027
interface HttpRequest {
2128
url: string;
2229
method: string;

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,17 @@ export async function getInferenceProviderMapping(
124124
}
125125
): Promise<InferenceProviderMappingEntry | null> {
126126
const logger = getLogger();
127+
if (params.provider === ("auto" as InferenceProvider) && params.task === "conversational") {
128+
// Special case for auto + conversational to avoid extra API calls
129+
// Call directly the server-side auto router
130+
return {
131+
hfModelId: params.modelId,
132+
provider: "auto",
133+
providerId: params.modelId,
134+
status: "live",
135+
task: "conversational",
136+
};
137+
}
127138
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
128139
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
129140
}

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
182182
"text-to-image": new Wavespeed.WavespeedAITextToImageTask(),
183183
"text-to-video": new Wavespeed.WavespeedAITextToVideoTask(),
184184
"image-to-image": new Wavespeed.WavespeedAIImageToImageTask(),
185+
"image-to-video": new Wavespeed.WavespeedAIImageToVideoTask(),
185186
},
186187
"zai-org": {
187188
conversational: new Zai.ZaiConversationalTask(),

packages/inference/src/package.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
// Generated file from package.json. Issues importing JSON directly when publishing on commonjs/ESM - see https://github.com/microsoft/TypeScript/issues/51783
2-
export const PACKAGE_VERSION = "4.11.3";
2+
export const PACKAGE_VERSION = "4.13.0";
33
export const PACKAGE_NAME = "@huggingface/inference";

packages/inference/src/providers/providerHelper.ts

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ import type {
4747
ZeroShotImageClassificationOutput,
4848
} from "@huggingface/tasks";
4949
import { HF_ROUTER_URL } from "../config.js";
50-
import { InferenceClientProviderOutputError } from "../errors.js";
50+
import { InferenceClientProviderOutputError, InferenceClientRoutingError } from "../errors.js";
5151
import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio.js";
5252
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, RequestArgs, UrlParams } from "../types.js";
5353
import { toArray } from "../utils/toArray.js";
@@ -62,7 +62,7 @@ import type { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js";
6262
export abstract class TaskProviderHelper {
6363
constructor(
6464
readonly provider: InferenceProvider,
65-
private baseUrl: string,
65+
protected baseUrl: string,
6666
readonly clientSideRoutingOnly: boolean = false
6767
) {}
6868

@@ -369,3 +369,16 @@ export class BaseTextGenerationTask extends TaskProviderHelper implements TextGe
369369
throw new InferenceClientProviderOutputError("Expected Array<{generated_text: string}>");
370370
}
371371
}
372+
373+
export class AutoRouterConversationalTask extends BaseConversationalTask {
374+
constructor() {
375+
super("auto" as InferenceProvider, "https://router.huggingface.co");
376+
}
377+
378+
override makeBaseUrl(params: UrlParams): string {
379+
if (params.authMethod !== "hf-token") {
380+
throw new InferenceClientRoutingError("Cannot select auto-router when using non-Hugging Face API key.");
381+
}
382+
return this.baseUrl;
383+
}
384+
}

packages/inference/src/providers/wavespeed.ts

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
import type { TextToImageArgs } from "../tasks/cv/textToImage.js";
22
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
33
import type { TextToVideoArgs } from "../tasks/cv/textToVideo.js";
4+
import type { ImageToVideoArgs } from "../tasks/cv/imageToVideo.js";
45
import type { BodyParams, RequestArgs, UrlParams } from "../types.js";
56
import { delay } from "../utils/delay.js";
67
import { omit } from "../utils/omit.js";
78
import { base64FromBytes } from "../utils/base64FromBytes.js";
8-
import type { TextToImageTaskHelper, TextToVideoTaskHelper, ImageToImageTaskHelper } from "./providerHelper.js";
9+
import type {
10+
TextToImageTaskHelper,
11+
TextToVideoTaskHelper,
12+
ImageToImageTaskHelper,
13+
ImageToVideoTaskHelper,
14+
} from "./providerHelper.js";
915
import { TaskProviderHelper } from "./providerHelper.js";
1016
import {
1117
InferenceClientInputError,
@@ -72,7 +78,9 @@ abstract class WavespeedAITask extends TaskProviderHelper {
7278
return `/api/v3/${params.model}`;
7379
}
7480

75-
preparePayload(params: BodyParams<ImageToImageArgs | TextToImageArgs | TextToVideoArgs>): Record<string, unknown> {
81+
preparePayload(
82+
params: BodyParams<ImageToImageArgs | TextToImageArgs | TextToVideoArgs | ImageToVideoArgs>
83+
): Record<string, unknown> {
7684
const payload: Record<string, unknown> = {
7785
...omit(params.args, ["inputs", "parameters"]),
7886
...params.args.parameters,
@@ -95,11 +103,17 @@ abstract class WavespeedAITask extends TaskProviderHelper {
95103
url?: string,
96104
headers?: Record<string, string>
97105
): Promise<Blob> {
98-
if (!headers) {
106+
if (!url || !headers) {
99107
throw new InferenceClientInputError("Headers are required for WaveSpeed AI API calls");
100108
}
101109

102-
const resultUrl = response.data.urls.get;
110+
const parsedUrl = new URL(url);
111+
const resultPath = new URL(response.data.urls.get).pathname;
112+
/// override the base url to use the router.huggingface.co if going through huggingface router
113+
const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${
114+
parsedUrl.host === "router.huggingface.co" ? "/wavespeed" : ""
115+
}`;
116+
const resultUrl = `${baseUrl}${resultPath}`;
103117

104118
// Poll for results until completion
105119
while (true) {
@@ -183,3 +197,19 @@ export class WavespeedAIImageToImageTask extends WavespeedAITask implements Imag
183197
};
184198
}
185199
}
200+
201+
export class WavespeedAIImageToVideoTask extends WavespeedAITask implements ImageToVideoTaskHelper {
202+
constructor() {
203+
super(WAVESPEEDAI_API_BASE_URL);
204+
}
205+
206+
async preparePayloadAsync(args: ImageToVideoArgs): Promise<RequestArgs> {
207+
return {
208+
...args,
209+
inputs: args.parameters?.prompt,
210+
image: base64FromBytes(
211+
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
212+
),
213+
};
214+
}
215+
}

0 commit comments

Comments (0)