Skip to content

Commit 220051e

Browse files
authored
[Inference Providers] Snippets: prefer the namespace/model:provider syntax for conversational (#1830)
# TL;DR For the conversational task, when using the Inference Client libraries, prefer the `namespace/model:provider` syntax over explicitly specifying the provider in the parameters. Also, don't specify `provider="auto"`, because it's the default.
1 parent e5089f8 commit 220051e

File tree

36 files changed

+36
-53
lines changed

36 files changed

+36
-53
lines changed

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ interface TemplateParams {
6060
importBase64?: boolean; // specific to snippetImportRequests
6161
importJson?: boolean; // specific to snippetImportRequests
6262
endpointUrl?: string;
63+
task?: InferenceTask;
64+
directRequest?: boolean;
6365
}
6466

6567
// Helpers to find + load templates
@@ -263,6 +265,8 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
263265
: providerModelId ?? model.id,
264266
billTo: opts?.billTo,
265267
endpointUrl: opts?.endpointUrl,
268+
task,
269+
directRequest: !!opts?.directRequest,
266270
};
267271

268272
/// Iterate over clients => check if a snippet exists => generate

packages/inference/src/snippets/templates/js/huggingface.js/conversational.jinja

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@ const chatCompletion = await client.chatCompletion({
66
{% if endpointUrl %}
77
endpointUrl: "{{ endpointUrl }}",
88
{% endif %}
9+
{% if directRequest %}
910
provider: "{{ provider }}",
1011
model: "{{ model.id }}",
12+
{% else %}
13+
model: "{{ providerModelId }}",
14+
{% endif %}
1115
{{ inputs.asTsString }}
1216
}{% if billTo %}, {
1317
billTo: "{{ billTo }}",

packages/inference/src/snippets/templates/js/huggingface.js/conversationalStream.jinja

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ const stream = client.chatCompletionStream({
88
{% if endpointUrl %}
99
endpointUrl: "{{ endpointUrl }}",
1010
{% endif %}
11-
provider: "{{ provider }}",
12-
model: "{{ model.id }}",
11+
model: "{{ providerModelId }}",
1312
{{ inputs.asTsString }}
1413
}{% if billTo %}, {
1514
billTo: "{{ billTo }}",

packages/inference/src/snippets/templates/python/huggingface_hub/conversational.jinja

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
completion = client.chat.completions.create(
2+
{% if directRequest %}
23
model="{{ model.id }}",
4+
{% else %}
5+
model="{{ providerModelId }}",
6+
{% endif %}
37
{{ inputs.asPythonString }}
48
)
59

packages/inference/src/snippets/templates/python/huggingface_hub/conversationalStream.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
stream = client.chat.completions.create(
2-
model="{{ model.id }}",
2+
model="{{ providerModelId }}",
33
{{ inputs.asPythonString }}
44
stream=True,
55
)

packages/inference/src/snippets/templates/python/huggingface_hub/importInferenceClient.jinja

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ client = InferenceClient(
44
{% if endpointUrl %}
55
base_url="{{ baseUrl }}",
66
{% endif %}
7+
{% if task != "conversational" or directRequest %}
78
provider="{{ provider }}",
9+
{% endif %}
810
api_key="{{ accessToken }}",
911
{% if billTo %}
1012
bill_to="{{ billTo }}",

packages/tasks-gen/snippets-fixtures/bill-to-param/js/huggingface.js/0.hf-inference.js

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ import { InferenceClient } from "@huggingface/inference";
33
const client = new InferenceClient(process.env.HF_TOKEN);
44

55
const chatCompletion = await client.chatCompletion({
6-
provider: "hf-inference",
7-
model: "meta-llama/Llama-3.1-8B-Instruct",
6+
model: "meta-llama/Llama-3.1-8B-Instruct:hf-inference",
87
messages: [
98
{
109
role: "user",

packages/tasks-gen/snippets-fixtures/bill-to-param/python/huggingface_hub/0.hf-inference.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
from huggingface_hub import InferenceClient
33

44
client = InferenceClient(
5-
provider="hf-inference",
65
api_key=os.environ["HF_TOKEN"],
76
bill_to="huggingface",
87
)
98

109
completion = client.chat.completions.create(
11-
model="meta-llama/Llama-3.1-8B-Instruct",
10+
model="meta-llama/Llama-3.1-8B-Instruct:hf-inference",
1211
messages=[
1312
{
1413
"role": "user",

packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/js/huggingface.js/0.hf-inference.js

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ const client = new InferenceClient(process.env.API_TOKEN);
44

55
const chatCompletion = await client.chatCompletion({
66
endpointUrl: "http://localhost:8080/v1",
7-
provider: "hf-inference",
87
model: "meta-llama/Llama-3.1-8B-Instruct",
98
messages: [
109
{

packages/tasks-gen/snippets-fixtures/conversational-llm-custom-endpoint/python/huggingface_hub/0.hf-inference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
client = InferenceClient(
55
base_url="http://localhost:8080/v1",
6-
provider="hf-inference",
76
api_key=os.environ["API_TOKEN"],
87
)
98

0 commit comments

Comments
 (0)