@@ -10,7 +10,7 @@ import { ExtHostChatProviderShape, IMainContext, MainContext, MainThreadChatProv
1010import * as typeConvert from 'vs/workbench/api/common/extHostTypeConverters' ;
1111import type * as vscode from 'vscode' ;
1212import { Progress } from 'vs/platform/progress/common/progress' ;
13- import { IChatMessage , IChatResponseFragment } from 'vs/workbench/contrib/chat/common/chatProvider' ;
13+ import { IChatMessage , IChatResponseFragment , IChatResponseProviderMetadata } from 'vs/workbench/contrib/chat/common/chatProvider' ;
1414import { ExtensionIdentifier , ExtensionIdentifierMap , ExtensionIdentifierSet , IExtensionDescription } from 'vs/platform/extensions/common/extensions' ;
1515import { AsyncIterableSource } from 'vs/base/common/async' ;
1616import { Emitter , Event } from 'vs/base/common/event' ;
@@ -97,7 +97,6 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape {
9797
9898 private readonly _languageModels = new Map < number , LanguageModelData > ( ) ;
9999 private readonly _languageModelIds = new Set < string > ( ) ; // these are ALL models, not just the one in this EH
100- private readonly _accesslist = new ExtensionIdentifierMap < boolean > ( ) ;
101100 private readonly _modelAccessList = new ExtensionIdentifierMap < ExtensionIdentifierSet > ( ) ;
102101 private readonly _pendingRequest = new Map < number , { languageModelId : string ; res : LanguageModelRequest } > ( ) ;
103102
@@ -197,18 +196,6 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape {
197196 return Array . from ( this . _languageModelIds ) ;
198197 }
199198
200- $updateAccesslist ( data : { extension : ExtensionIdentifier ; enabled : boolean } [ ] ) : void {
201- const updated = new ExtensionIdentifierSet ( ) ;
202- for ( const { extension, enabled } of data ) {
203- const oldValue = this . _accesslist . get ( extension ) ;
204- if ( oldValue !== enabled ) {
205- this . _accesslist . set ( extension , enabled ) ;
206- updated . add ( extension ) ;
207- }
208- }
209- this . _onDidChangeAccess . fire ( updated ) ;
210- }
211-
212199 $updateModelAccesslist ( data : { from : ExtensionIdentifier ; to : ExtensionIdentifier ; enabled : boolean } [ ] ) : void {
213200 const updated = new Array < { from : ExtensionIdentifier ; to : ExtensionIdentifier } > ( ) ;
214201 for ( const { from, to, enabled } of data ) {
@@ -230,23 +217,15 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape {
230217
231218 async requestLanguageModelAccess ( extension : IExtensionDescription , languageModelId : string , options ?: vscode . LanguageModelAccessOptions ) : Promise < vscode . LanguageModelAccess > {
232219 const from = extension . identifier ;
233- // check if the extension is in the access list and allowed to make chat requests
234- if ( this . _accesslist . get ( from ) === false ) {
235- throw new Error ( 'Extension is NOT allowed to make chat requests' ) ;
236- }
237-
238220 const justification = options ?. justification ;
239221 const metadata = await this . _proxy . $prepareChatAccess ( from , languageModelId , justification ) ;
240222
241223 if ( ! metadata ) {
242- if ( ! this . _accesslist . get ( from ) ) {
243- throw new Error ( 'Extension is NOT allowed to make chat requests' ) ;
244- }
245224 throw new Error ( `Language model '${ languageModelId } ' NOT found` ) ;
246225 }
247226
248- if ( metadata . auth ) {
249- await this . _checkAuthAccess ( extension , { identifier : metadata . extension , displayName : metadata . auth ? .providerLabel } , justification ) ;
227+ if ( this . _isUsingAuth ( from , metadata ) ) {
228+ await this . _getAuthAccess ( extension , { identifier : metadata . extension , displayName : metadata . auth . providerLabel } , justification ) ;
250229 }
251230
252231 const that = this ;
@@ -256,9 +235,7 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape {
256235 return metadata . model ;
257236 } ,
258237 get isRevoked ( ) {
259- return ! that . _accesslist . get ( from )
260- || ( metadata . auth && ! that . _modelAccessList . get ( from ) ?. has ( metadata . extension ) )
261- || ! that . _languageModelIds . has ( languageModelId ) ;
238+ return ( that . _isUsingAuth ( from , metadata ) && ! that . _modelAccessList . get ( from ) ?. has ( metadata . extension ) ) || ! that . _languageModelIds . has ( languageModelId ) ;
262239 } ,
263240 get onDidChangeAccess ( ) {
264241 const onDidChangeAccess = Event . filter ( that . _onDidChangeAccess . event , set => set . has ( from ) ) ;
@@ -267,7 +244,7 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape {
267244 return Event . signal ( Event . any ( onDidChangeAccess , onDidRemoveLM , onDidChangeModelAccess ) ) ;
268245 } ,
269246 makeChatRequest ( messages , options , token ) {
270- if ( ! that . _accesslist . get ( from ) || ( metadata . auth && ! that . _modelAccessList . get ( from ) ?. has ( metadata . extension ) ) ) {
247+ if ( that . _isUsingAuth ( from , metadata ) && ! that . _modelAccessList . get ( from ) ?. has ( metadata . extension ) ) {
271248 throw new Error ( 'Access to chat has been revoked' ) ;
272249 }
273250 if ( ! that . _languageModelIds . has ( languageModelId ) ) {
@@ -297,7 +274,7 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape {
297274 }
298275
299276 // BIG HACK: Using AuthenticationProviders to check access to Language Models
300- private async _checkAuthAccess ( from : IExtensionDescription , to : { identifier : ExtensionIdentifier ; displayName : string } , detail ?: string ) : Promise < void > {
277+ private async _getAuthAccess ( from : IExtensionDescription , to : { identifier : ExtensionIdentifier ; displayName : string } , detail ?: string ) : Promise < void > {
301278 // This needs to be done in both MainThread & ExtHost ChatProvider
302279 const providerId = INTERNAL_AUTH_PROVIDER_PREFIX + to . identifier . value ;
303280 const session = await this . _extHostAuthentication . getSession ( from , providerId , [ ] , { silent : true } ) ;
@@ -315,4 +292,11 @@ export class ExtHostChatProvider implements ExtHostChatProviderShape {
315292
316293 this . $updateModelAccesslist ( [ { from : from . identifier , to : to . identifier , enabled : true } ] ) ;
317294 }
295+
296+ private _isUsingAuth ( from : ExtensionIdentifier , toMetadata : IChatResponseProviderMetadata ) : toMetadata is IChatResponseProviderMetadata & { auth : NonNullable < IChatResponseProviderMetadata [ 'auth' ] > } {
297+ // If the 'to' extension uses an auth check
298+ return ! ! toMetadata . auth
299+ // And we're asking from a different extension
300+ && ! ExtensionIdentifier . equals ( toMetadata . extension , from ) ;
301+ }
318302}
0 commit comments