diff --git a/src/extension/agents/copilotcli/node/copilotCli.ts b/src/extension/agents/copilotcli/node/copilotCli.ts index 3b1af7158c..843461511e 100644 --- a/src/extension/agents/copilotcli/node/copilotCli.ts +++ b/src/extension/agents/copilotcli/node/copilotCli.ts @@ -13,6 +13,7 @@ import { IWorkspaceService } from '../../../../platform/workspace/common/workspa import { createServiceIdentifier } from '../../../../util/common/services'; import { Lazy } from '../../../../util/vs/base/common/lazy'; import { Disposable, IDisposable, toDisposable } from '../../../../util/vs/base/common/lifecycle'; +import { ILanguageModelServer } from '../../node/langModelServer'; import { getCopilotLogger } from './logger'; import { ensureNodePtyShim } from './nodePtyShim'; @@ -116,15 +117,19 @@ export class CopilotCLISessionOptionsService implements ICopilotCLISessionOption @IWorkspaceService private readonly workspaceService: IWorkspaceService, @IAuthenticationService private readonly _authenticationService: IAuthenticationService, @ILogService private readonly logService: ILogService, + @ILanguageModelServer private readonly languageModelServer: ILanguageModelServer, ) { } public async createOptions(options: SessionOptions, permissionHandler: CopilotCLIPermissionsHandler) { const copilotToken = await this._authenticationService.getAnyGitHubSession(); const workingDirectory = options.workingDirectory ?? await this.getWorkspaceFolderPath(); + const serverConfig = this.languageModelServer.getConfig(); const allOptions: SessionOptions = { env: { ...process.env, - COPILOTCLI_DISABLE_NONESSENTIAL_TRAFFIC: '1' + COPILOTCLI_DISABLE_NONESSENTIAL_TRAFFIC: '1', + OPENAI_BASE_URL: `http://localhost:${serverConfig.port}`, + OPENAI_API_KEY: serverConfig.nonce, }, logger: getCopilotLogger(this.logService), requestPermission: async (permissionRequest) => { diff --git a/src/extension/agents/node/adapters/openaiAdapter.spec.ts b/src/extension/agents/node/adapters/openaiAdapter.spec.ts new file mode 100644 index 0000000000..e437cd1898 --- /dev/null +++ b/src/extension/agents/node/adapters/openaiAdapter.spec.ts @@ -0,0 +1,162 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as http from 'http'; +import { describe, expect, it } from 'vitest'; +import { OpenAIAdapterFactory } from './openaiAdapter'; + +describe('OpenAIAdapterFactory', () => { + it('should create an OpenAI adapter instance', () => { + const factory = new OpenAIAdapterFactory(); + const adapter = factory.createAdapter(); + + // Verify the adapter has the correct name + expect(adapter.name).toBe('openai'); + }); + + it('should parse a basic OpenAI request', () => { + const factory = new OpenAIAdapterFactory(); + const adapter = factory.createAdapter(); + + const requestBody = { + model: 'gpt-4o', + messages: [ + { role: 'user', content: 'Hello' } + ], + temperature: 0.7 + }; + + const parsedRequest = adapter.parseRequest(JSON.stringify(requestBody)); + + expect(parsedRequest.model).toBe('gpt-4o'); + expect(parsedRequest.messages).toHaveLength(1); + expect(parsedRequest.messages[0]).toEqual({ role: 'user', content: 'Hello' }); + expect(parsedRequest.options?.temperature).toBe(0.7); + }); + + it('should parse an OpenAI request with tools', () => { + const factory = new OpenAIAdapterFactory(); + const adapter = factory.createAdapter(); + + const requestBody = { + model: 'gpt-4o', + messages: [ + { role: 'user', content: 'What is the weather?' } + ], + tools: [ + { + type: 'function', + function: { + name: 'get_weather', + description: 'Get the current weather', + parameters: { + type: 'object', + properties: { + location: { type: 'string' } + } + } + } + } + ] + }; + + const parsedRequest = adapter.parseRequest(JSON.stringify(requestBody)); + + expect(parsedRequest.model).toBe('gpt-4o'); + expect(parsedRequest.messages).toHaveLength(1); + expect(parsedRequest.options?.tools).toBeDefined(); + expect(parsedRequest.options?.tools).toHaveLength(1); + }); + + it('should extract auth key from headers', () => { + const factory = new OpenAIAdapterFactory(); + const adapter = factory.createAdapter(); + + const headers: http.IncomingHttpHeaders = { + 'authorization': 'Bearer test-key-123' + }; + + const authKey = adapter.extractAuthKey(headers); + + expect(authKey).toBe('test-key-123'); + }); + + it('should format text stream response', () => { + const factory = new OpenAIAdapterFactory(); + const adapter = factory.createAdapter(); + + const context = { + requestId: 'test-request-id', + endpoint: { + modelId: 'gpt-4o', + modelMaxPromptTokens: 128000 + } + }; + + const streamData = { + type: 'text' as const, + content: 'Hello, world!' + }; + + const events = adapter.formatStreamResponse(streamData, context); + + expect(events).toHaveLength(1); + expect(events[0].event).toBe('message'); + expect(events[0].data).toContain('Hello, world!'); + }); + + it('should format tool call stream response', () => { + const factory = new OpenAIAdapterFactory(); + const adapter = factory.createAdapter(); + + const context = { + requestId: 'test-request-id', + endpoint: { + modelId: 'gpt-4o', + modelMaxPromptTokens: 128000 + } + }; + + const streamData = { + type: 'tool_call' as const, + callId: 'call_123', + name: 'get_weather', + input: { location: 'Boston' } + }; + + const events = adapter.formatStreamResponse(streamData, context); + + expect(events).toHaveLength(1); + expect(events[0].event).toBe('message'); + expect(events[0].data).toContain('get_weather'); + expect(events[0].data).toContain('Boston'); + }); + + it('should generate final events with usage', () => { + const factory = new OpenAIAdapterFactory(); + const adapter = factory.createAdapter(); + + const context = { + requestId: 'test-request-id', + endpoint: { + modelId: 'gpt-4o', + modelMaxPromptTokens: 128000 + } + }; + + const usage = { + prompt_tokens: 10, + completion_tokens: 20, + total_tokens: 30 + }; + + const events = adapter.generateFinalEvents(context, usage); + + expect(events).toHaveLength(1); + expect(events[0].event).toBe('message'); + expect(events[0].data).toContain('"prompt_tokens":10'); + expect(events[0].data).toContain('"completion_tokens":20'); + }); +}); \ No newline at end of file diff --git a/src/extension/agents/node/adapters/openaiAdapter.ts b/src/extension/agents/node/adapters/openaiAdapter.ts new file mode 100644 index 0000000000..7d5678b025 --- /dev/null +++ b/src/extension/agents/node/adapters/openaiAdapter.ts @@ -0,0 +1,196 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as http from 'http'; +import type { OpenAiFunctionTool } from '../../../../platform/networking/common/fetch'; +import { IMakeChatRequestOptions } from '../../../../platform/networking/common/networking'; +import { APIUsage } from '../../../../platform/networking/common/openai'; +import { coalesce } from '../../../../util/vs/base/common/arrays'; +import { IAgentStreamBlock, IParsedRequest, IProtocolAdapter, IProtocolAdapterFactory, IStreamEventData, IStreamingContext } from './types'; + +export class OpenAIAdapterFactory implements IProtocolAdapterFactory { + createAdapter(): IProtocolAdapter { + return new OpenAIAdapter(); + } +} + +class OpenAIAdapter implements IProtocolAdapter { + readonly name = 'openai'; + + // Per-request state + private currentBlockIndex = 0; + private hasTextBlock = false; + private hadToolCalls = false; + + parseRequest(body: string): IParsedRequest { + const requestBody: any = JSON.parse(body); + + // Extract model information + const model = requestBody.model; + + // Convert messages format if needed + const messages = Array.isArray(requestBody.messages) ? requestBody.messages : []; + + const options: IMakeChatRequestOptions['requestOptions'] = { + temperature: requestBody.temperature, + max_tokens: requestBody.max_tokens, + }; + + if (requestBody.tools && Array.isArray(requestBody.tools) && requestBody.tools.length > 0) { + // Map OpenAI tools to VS Code chat tools + const tools = coalesce(requestBody.tools.map((tool: any) => { + if (tool.type === 'function' && tool.function) { + const chatTool: OpenAiFunctionTool = { + type: 'function', + function: { + name: tool.function.name, + description: tool.function.description || '', + parameters: tool.function.parameters || {}, + } + }; + return chatTool; + } + return undefined; + })); + if (tools.length) { + options.tools = tools as OpenAiFunctionTool[]; + } + } + + return { + model, + messages, + options + }; + } + + formatStreamResponse( + streamData: IAgentStreamBlock, + context: IStreamingContext + ): IStreamEventData[] { + const events: IStreamEventData[] = []; + + if (streamData.type === 'text') { + if (!this.hasTextBlock) { + this.hasTextBlock = true; + } + + // Send text delta events + const textDelta = { + id: context.requestId, + object: 'chat.completion.chunk', + created: Math.floor(Date.now() / 1000), + model: context.endpoint.modelId, + choices: [{ + index: this.currentBlockIndex, + delta: { + content: streamData.content, + role: 'assistant' + }, + finish_reason: null + }] + }; + events.push({ + event: 'message', + data: this.formatEventData(textDelta) + }); + + } else if (streamData.type === 'tool_call') { + // End current text block if it exists + if (this.hasTextBlock) { + this.currentBlockIndex++; + this.hasTextBlock = false; + } + + this.hadToolCalls = true; + + // Send tool call events + const toolCallDelta = { + id: context.requestId, + object: 'chat.completion.chunk', + created: Math.floor(Date.now() / 1000), + model: context.endpoint.modelId, + choices: [{ + index: this.currentBlockIndex, + delta: { + tool_calls: [{ + index: this.currentBlockIndex, + id: streamData.callId, + type: 'function', + function: { + name: streamData.name, + arguments: JSON.stringify(streamData.input || {}) + } + }] + }, + finish_reason: null + }] + }; + events.push({ + event: 'message', + data: this.formatEventData(toolCallDelta) + }); + + this.currentBlockIndex++; + } + + return events; + } + + generateFinalEvents(context: IStreamingContext, usage?: APIUsage): IStreamEventData[] { + const events: IStreamEventData[] = []; + + // Send final completion event with usage information + const finalCompletion = { + id: context.requestId, + object: 'chat.completion', + created: Math.floor(Date.now() / 1000), + model: context.endpoint.modelId, + choices: [{ + index: 0, + message: { + role: 'assistant', + content: '', + }, + finish_reason: this.hadToolCalls ? 'tool_calls' : 'stop' + }], + usage: usage ? { + prompt_tokens: usage.prompt_tokens, + completion_tokens: usage.completion_tokens, + total_tokens: usage.total_tokens + } : { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0 + } + }; + + events.push({ + event: 'message', + data: this.formatEventData(finalCompletion) + }); + + return events; + } + + generateInitialEvents(context: IStreamingContext): IStreamEventData[] { + // OpenAI doesn't typically send initial events, but we can send an empty one if needed + return []; + } + + getContentType(): string { + return 'text/event-stream'; + } + + extractAuthKey(headers: http.IncomingHttpHeaders): string | undefined { + const authHeader = headers.authorization; + const bearerSpace = 'Bearer '; + return authHeader?.startsWith(bearerSpace) ? authHeader.substring(bearerSpace.length) : undefined; + } + + private formatEventData(data: any): string { + return JSON.stringify(data).replace(/\n/g, '\\n'); + } +} \ No newline at end of file diff --git a/src/extension/agents/node/langModelServer.ts b/src/extension/agents/node/langModelServer.ts index 68f13dba8b..91709320f7 100644 --- a/src/extension/agents/node/langModelServer.ts +++ b/src/extension/agents/node/langModelServer.ts @@ -15,6 +15,7 @@ import { CancellationTokenSource } from '../../../util/vs/base/common/cancellati import { generateUuid } from '../../../util/vs/base/common/uuid'; import { LanguageModelError } from '../../../vscodeTypes'; import { AnthropicAdapterFactory } from './adapters/anthropicAdapter'; +import { OpenAIAdapterFactory } from './adapters/openaiAdapter'; import { IAgentStreamBlock, IProtocolAdapter, IProtocolAdapterFactory, IStreamingContext } from './adapters/types'; export interface ILanguageModelServerConfig { @@ -47,6 +48,7 @@ export class LanguageModelServer implements ILanguageModelServer { }; this.adapterFactories = new Map(); this.adapterFactories.set('/v1/messages', new AnthropicAdapterFactory()); + this.adapterFactories.set('/v1/chat/completions', new OpenAIAdapterFactory()); this.server = this.createServer(); } @@ -112,7 +114,15 @@ export class LanguageModelServer implements ILanguageModelServer { private getAdapterFactoryForPath(url: string): IProtocolAdapterFactory | undefined { const pathname = this.parseUrlPathname(url); - return this.adapterFactories.get(pathname); + // Try exact match first + let adapterFactory = this.adapterFactories.get(pathname); + + // If no exact match, try to match OpenAI endpoints + if (!adapterFactory && (pathname === '/v1/chat/completions' || pathname === '/chat/completions' || pathname === '//chat/completions')) { + adapterFactory = this.adapterFactories.get('/v1/chat/completions'); + } + + return adapterFactory; } private async readRequestBody(req: http.IncomingMessage): Promise {