Skip to content

Commit 44d9891

Browse files
authored
Fix logic for merging default onDeviceParams with user-provided onDeviceParams (#9314)
1 parent c8263c4 commit 44d9891

File tree

3 files changed

+83
-5
lines changed

3 files changed

+83
-5
lines changed

.changeset/rare-hats-know.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@firebase/ai': patch
3+
---
4+
5+
Fix logic for merging default `onDeviceParams` with user-provided `onDeviceParams`.

packages/ai/src/methods/chrome-adapter-browser.test.ts

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,63 @@ describe('ChromeAdapter', () => {
7878
expectedInputs: [{ type: 'image' }]
7979
});
8080
});
81+
it('sets image as expected input type by default even if other onDeviceParams params are set', async () => {
82+
const languageModelProvider = {
83+
availability: () => Promise.resolve(Availability.AVAILABLE)
84+
} as LanguageModel;
85+
const availabilityStub = stub(
86+
languageModelProvider,
87+
'availability'
88+
).resolves(Availability.AVAILABLE);
89+
const adapter = new ChromeAdapterImpl(
90+
languageModelProvider,
91+
InferenceMode.PREFER_ON_DEVICE,
92+
{
93+
promptOptions: {}
94+
}
95+
);
96+
await adapter.isAvailable({
97+
contents: [
98+
{
99+
role: 'user',
100+
parts: [{ text: 'hi' }]
101+
}
102+
]
103+
});
104+
expect(availabilityStub).to.have.been.calledWith({
105+
expectedInputs: [{ type: 'image' }]
106+
});
107+
});
108+
it('sets image as expected input type by default even if other createOptions params are set', async () => {
109+
const languageModelProvider = {
110+
availability: () => Promise.resolve(Availability.AVAILABLE)
111+
} as LanguageModel;
112+
const availabilityStub = stub(
113+
languageModelProvider,
114+
'availability'
115+
).resolves(Availability.AVAILABLE);
116+
const adapter = new ChromeAdapterImpl(
117+
languageModelProvider,
118+
InferenceMode.PREFER_ON_DEVICE,
119+
{
120+
createOptions: {
121+
topK: 22
122+
}
123+
}
124+
);
125+
await adapter.isAvailable({
126+
contents: [
127+
{
128+
role: 'user',
129+
parts: [{ text: 'hi' }]
130+
}
131+
]
132+
});
133+
expect(availabilityStub).to.have.been.calledWith({
134+
topK: 22,
135+
expectedInputs: [{ type: 'image' }]
136+
});
137+
});
81138
it('honors explicitly set expected inputs', async () => {
82139
const languageModelProvider = {
83140
availability: () => Promise.resolve(Availability.AVAILABLE)

packages/ai/src/methods/chrome-adapter.ts

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,15 @@ import { ChromeAdapter } from '../types/chrome-adapter';
3131
import {
3232
Availability,
3333
LanguageModel,
34+
LanguageModelExpected,
3435
LanguageModelMessage,
3536
LanguageModelMessageContent,
3637
LanguageModelMessageRole
3738
} from '../types/language-model';
3839

40+
// Defaults to support image inputs for convenience.
41+
const defaultExpectedInputs: LanguageModelExpected[] = [{ type: 'image' }];
42+
3943
/**
4044
* Defines an inference "backend" that uses Chrome's on-device model,
4145
* and encapsulates logic for detecting when on-device inference is
@@ -47,16 +51,28 @@ export class ChromeAdapterImpl implements ChromeAdapter {
4751
private isDownloading = false;
4852
private downloadPromise: Promise<LanguageModel | void> | undefined;
4953
private oldSession: LanguageModel | undefined;
54+
onDeviceParams: OnDeviceParams = {
55+
createOptions: {
56+
expectedInputs: defaultExpectedInputs
57+
}
58+
};
5059
constructor(
5160
public languageModelProvider: LanguageModel,
5261
public mode: InferenceMode,
53-
public onDeviceParams: OnDeviceParams = {
54-
createOptions: {
55-
// Defaults to support image inputs for convenience.
56-
expectedInputs: [{ type: 'image' }]
62+
onDeviceParams?: OnDeviceParams
63+
) {
64+
if (onDeviceParams) {
65+
this.onDeviceParams = onDeviceParams;
66+
if (!this.onDeviceParams.createOptions) {
67+
this.onDeviceParams.createOptions = {
68+
expectedInputs: defaultExpectedInputs
69+
};
70+
} else if (!this.onDeviceParams.createOptions.expectedInputs) {
71+
this.onDeviceParams.createOptions.expectedInputs =
72+
defaultExpectedInputs;
5773
}
5874
}
59-
) {}
75+
}
6076

6177
/**
6278
* Checks if a given request can be made on-device.

0 commit comments

Comments
 (0)