Skip to content

Commit 7e6267d

Browse files
authored
Enable returning LanguageModelDataPart from chat provider (microsoft#259843)
* Start making datapart work with an LMChatProvider * Cleanup IChatResponseDataPart * Handle buffers in LM ipc correctly * Cleanup
1 parent 2b68877 commit 7e6267d

File tree

6 files changed

+33
-30
lines changed

6 files changed

+33
-30
lines changed

src/vs/workbench/api/browser/mainThreadLanguageModels.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,11 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
103103
this._lmProviderChange.fire({ vendor });
104104
}
105105

106-
async $reportResponsePart(requestId: number, chunk: IChatResponseFragment | IChatResponseFragment[]): Promise<void> {
106+
async $reportResponsePart(requestId: number, chunk: SerializableObjectWithBuffers<IChatResponseFragment | IChatResponseFragment[]>): Promise<void> {
107107
const data = this._pendingProgress.get(requestId);
108108
this._logService.trace('[LM] report response PART', Boolean(data), requestId, chunk);
109109
if (data) {
110-
data.stream.emitOne(chunk);
110+
data.stream.emitOne(chunk.value);
111111
}
112112
}
113113

@@ -154,7 +154,7 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
154154
try {
155155
for await (const part of response.stream) {
156156
this._logService.trace('[CHAT] request PART', extension.value, requestId, part);
157-
await this._proxy.$acceptResponsePart(requestId, part);
157+
await this._proxy.$acceptResponsePart(requestId, new SerializableObjectWithBuffers(part));
158158
}
159159
this._logService.trace('[CHAT] request DONE', extension.value, requestId);
160160
} catch (err) {

src/vs/workbench/api/common/extHost.protocol.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,7 +1262,7 @@ export interface MainThreadLanguageModelsShape extends IDisposable {
12621262
$onLMProviderChange(vendor: string): void;
12631263
$unregisterProvider(vendor: string): void;
12641264
$tryStartChatRequest(extension: ExtensionIdentifier, modelIdentifier: string, requestId: number, messages: SerializableObjectWithBuffers<IChatMessage[]>, options: {}, token: CancellationToken): Promise<void>;
1265-
$reportResponsePart(requestId: number, chunk: IChatResponseFragment | IChatResponseFragment[]): Promise<void>;
1265+
$reportResponsePart(requestId: number, chunk: SerializableObjectWithBuffers<IChatResponseFragment | IChatResponseFragment[]>): Promise<void>;
12661266
$reportResponseDone(requestId: number, error: SerializedError | undefined): Promise<void>;
12671267
$selectChatModels(selector: ILanguageModelChatSelector): Promise<string[]>;
12681268
$countTokens(modelId: string, value: string | IChatMessage, token: CancellationToken): Promise<number>;
@@ -1275,7 +1275,7 @@ export interface ExtHostLanguageModelsShape {
12751275
$prepareLanguageModelProvider(vendor: string, options: { silent: boolean }, token: CancellationToken): Promise<ILanguageModelChatMetadataAndIdentifier[]>;
12761276
$updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void;
12771277
$startChatRequest(modelId: string, requestId: number, from: ExtensionIdentifier, messages: SerializableObjectWithBuffers<IChatMessage[]>, options: { [name: string]: any }, token: CancellationToken): Promise<void>;
1278-
$acceptResponsePart(requestId: number, chunk: IChatResponseFragment | IChatResponseFragment[]): Promise<void>;
1278+
$acceptResponsePart(requestId: number, chunk: SerializableObjectWithBuffers<IChatResponseFragment | IChatResponseFragment[]>): Promise<void>;
12791279
$acceptResponseDone(requestId: number, error: SerializedError | undefined): Promise<void>;
12801280
$provideTokenLength(modelId: string, value: string | IChatMessage, token: CancellationToken): Promise<number>;
12811281
$isFileIgnored(handle: number, uri: UriComponents, token: CancellationToken): Promise<boolean>;

src/vs/workbench/api/common/extHostLanguageModels.ts

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import type * as vscode from 'vscode';
77
import { AsyncIterableObject, AsyncIterableSource, RunOnceScheduler } from '../../../base/common/async.js';
8+
import { VSBuffer } from '../../../base/common/buffer.js';
89
import { CancellationToken } from '../../../base/common/cancellation.js';
910
import { SerializedError, transformErrorForSerialization, transformErrorFromSerialization } from '../../../base/common/errors.js';
1011
import { Emitter, Event } from '../../../base/common/event.js';
@@ -16,17 +17,16 @@ import { ExtensionIdentifier, ExtensionIdentifierMap, ExtensionIdentifierSet, IE
1617
import { createDecorator } from '../../../platform/instantiation/common/instantiation.js';
1718
import { ILogService } from '../../../platform/log/common/log.js';
1819
import { Progress } from '../../../platform/progress/common/progress.js';
19-
import { ChatImageMimeType, IChatMessage, IChatResponseFragment, IChatResponsePart, ILanguageModelChatMetadata, ILanguageModelChatMetadataAndIdentifier } from '../../contrib/chat/common/languageModels.js';
20+
import { IChatMessage, IChatResponseFragment, IChatResponsePart, ILanguageModelChatMetadata, ILanguageModelChatMetadataAndIdentifier } from '../../contrib/chat/common/languageModels.js';
21+
import { DEFAULT_MODEL_PICKER_CATEGORY } from '../../contrib/chat/common/modelPicker/modelPickerWidget.js';
2022
import { INTERNAL_AUTH_PROVIDER_PREFIX } from '../../services/authentication/common/authentication.js';
2123
import { checkProposedApiEnabled } from '../../services/extensions/common/extensions.js';
24+
import { SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js';
2225
import { ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from './extHost.protocol.js';
2326
import { IExtHostAuthentication } from './extHostAuthentication.js';
2427
import { IExtHostRpcService } from './extHostRpcService.js';
2528
import * as typeConvert from './extHostTypeConverters.js';
2629
import * as extHostTypes from './extHostTypes.js';
27-
import { SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js';
28-
import { VSBuffer } from '../../../base/common/buffer.js';
29-
import { DEFAULT_MODEL_PICKER_CATEGORY } from '../../contrib/chat/common/modelPicker/modelPickerWidget.js';
3030

3131
export interface IExtHostLanguageModels extends ExtHostLanguageModels { }
3232

@@ -38,15 +38,17 @@ type LanguageModelProviderData = {
3838
readonly provider: vscode.LanguageModelChatProvider2;
3939
};
4040

41+
type LMResponsePart = vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart | vscode.LanguageModelDataPart;
42+
4143
class LanguageModelResponseStream {
4244

43-
readonly stream = new AsyncIterableSource<vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart>();
45+
readonly stream = new AsyncIterableSource<LMResponsePart>();
4446

4547
constructor(
4648
readonly option: number,
47-
stream?: AsyncIterableSource<vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart>
49+
stream?: AsyncIterableSource<LMResponsePart>
4850
) {
49-
this.stream = stream ?? new AsyncIterableSource<vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart>();
51+
this.stream = stream ?? new AsyncIterableSource<LMResponsePart>();
5052
}
5153
}
5254

@@ -55,7 +57,7 @@ class LanguageModelResponse {
5557
readonly apiObject: vscode.LanguageModelChatResponse;
5658

5759
private readonly _responseStreams = new Map<number, LanguageModelResponseStream>();
58-
private readonly _defaultStream = new AsyncIterableSource<vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart>();
60+
private readonly _defaultStream = new AsyncIterableSource<LMResponsePart>();
5961
private _isDone: boolean = false;
6062

6163
constructor() {
@@ -93,15 +95,15 @@ class LanguageModelResponse {
9395
return;
9496
}
9597

96-
const partsByIndex = new Map<number, (vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart)[]>();
98+
const partsByIndex = new Map<number, LMResponsePart[]>();
9799

98100
for (const fragment of Iterable.wrap(fragments)) {
99101

100-
let out: vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart;
102+
let out: LMResponsePart;
101103
if (fragment.part.type === 'text') {
102104
out = new extHostTypes.LanguageModelTextPart(fragment.part.value, fragment.part.audience);
103105
} else if (fragment.part.type === 'data') {
104-
out = new extHostTypes.LanguageModelTextPart('');
106+
out = new extHostTypes.LanguageModelDataPart(fragment.part.data.buffer, fragment.part.mimeType, fragment.part.audience);
105107
} else {
106108
out = new extHostTypes.LanguageModelToolCallPart(fragment.part.toolCallId, fragment.part.name, fragment.part.parameters);
107109
}
@@ -270,7 +272,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
270272
const queue: IChatResponseFragment[] = [];
271273
const sendNow = () => {
272274
if (queue.length > 0) {
273-
this._proxy.$reportResponsePart(requestId, queue);
275+
this._proxy.$reportResponsePart(requestId, new SerializableObjectWithBuffers(queue));
274276
queue.length = 0;
275277
}
276278
};
@@ -298,7 +300,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
298300
} else if (fragment.part instanceof extHostTypes.LanguageModelTextPart) {
299301
part = { type: 'text', value: fragment.part.value, audience: fragment.part.audience };
300302
} else if (fragment.part instanceof extHostTypes.LanguageModelDataPart) {
301-
part = { type: 'data', value: { mimeType: fragment.part.mimeType as ChatImageMimeType, data: VSBuffer.wrap(fragment.part.data) }, audience: fragment.part.audience };
303+
part = { type: 'data', mimeType: fragment.part.mimeType, data: VSBuffer.wrap(fragment.part.data), audience: fragment.part.audience };
302304
}
303305

304306
if (!part) {
@@ -482,10 +484,10 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
482484
return internalMessages;
483485
}
484486

485-
async $acceptResponsePart(requestId: number, chunk: IChatResponseFragment | IChatResponseFragment[]): Promise<void> {
487+
async $acceptResponsePart(requestId: number, chunk: SerializableObjectWithBuffers<IChatResponseFragment | IChatResponseFragment[]>): Promise<void> {
486488
const data = this._pendingRequest.get(requestId);
487489
if (data) {
488-
data.res.handleFragment(chunk);
490+
data.res.handleFragment(chunk.value);
489491
}
490492
}
491493

src/vs/workbench/api/common/extHostTypeConverters.ts

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,13 +2318,15 @@ export namespace LanguageModelChatMessage {
23182318
if (c.type === 'text') {
23192319
return new LanguageModelTextPart(c.value, c.audience);
23202320
} else if (c.type === 'tool_result') {
2321-
const content: (LanguageModelTextPart | LanguageModelPromptTsxPart)[] = c.value.map(part => {
2321+
const content: (LanguageModelTextPart | LanguageModelPromptTsxPart)[] = coalesce(c.value.map(part => {
23222322
if (part.type === 'text') {
23232323
return new types.LanguageModelTextPart(part.value, part.audience);
2324-
} else {
2324+
} else if (part.type === 'prompt_tsx') {
23252325
return new types.LanguageModelPromptTsxPart(part.value);
2326+
} else {
2327+
return undefined; // Strip unknown parts
23262328
}
2327-
});
2329+
}));
23282330
return new types.LanguageModelToolResultPart(c.toolCallId, content, c.isError);
23292331
} else if (c.type === 'image_url') {
23302332
// Non-stable types
@@ -2418,7 +2420,7 @@ export namespace LanguageModelChatMessage2 {
24182420
if (part.type === 'text') {
24192421
return new types.LanguageModelTextPart(part.value, part.audience);
24202422
} else if (part.type === 'data') {
2421-
return new types.LanguageModelDataPart(part.value.data.buffer, part.value.mimeType);
2423+
return new types.LanguageModelDataPart(part.data.buffer, part.mimeType);
24222424
} else {
24232425
return new types.LanguageModelPromptTsxPart(part.value);
24242426
}
@@ -2467,10 +2469,8 @@ export namespace LanguageModelChatMessage2 {
24672469
} else if (part instanceof types.LanguageModelDataPart) {
24682470
return {
24692471
type: 'data',
2470-
value: {
2471-
mimeType: part.mimeType as chatProvider.ChatImageMimeType,
2472-
data: VSBuffer.wrap(part.data)
2473-
},
2472+
mimeType: part.mimeType,
2473+
data: VSBuffer.wrap(part.data),
24742474
audience: part.audience
24752475
} satisfies IChatResponseDataPart;
24762476
} else {

src/vs/workbench/contrib/chat/common/languageModels.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ export interface IChatResponsePromptTsxPart {
111111

112112
export interface IChatResponseDataPart {
113113
type: 'data';
114-
value: IChatImageURLPart;
114+
mimeType: string;
115+
data: VSBuffer;
115116
audience?: LanguageModelPartAudience[];
116117
}
117118

src/vscode-dts/vscode.proposed.chatProvider.d.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,6 @@ declare module 'vscode' {
139139

140140
export interface ChatResponseFragment2 {
141141
index: number;
142-
part: LanguageModelTextPart | LanguageModelToolCallPart;
142+
part: LanguageModelTextPart | LanguageModelToolCallPart | LanguageModelDataPart;
143143
}
144144
}

0 commit comments

Comments
 (0)