@@ -31,7 +31,7 @@ namespace ts.codefix {
3131 errorCodes,
3232 getCodeActions : context => {
3333 const { sourceFile, errorCode, span, cancellationToken, program } = context ;
34- const expression = getAwaitableExpression ( sourceFile , errorCode , span , cancellationToken , program ) ;
34+ const expression = getFixableErrorSpanExpression ( sourceFile , errorCode , span , cancellationToken , program ) ;
3535 if ( ! expression ) {
3636 return ;
3737 }
@@ -45,32 +45,40 @@ namespace ts.codefix {
4545 getAllCodeActions : context => {
4646 const { sourceFile, program, cancellationToken } = context ;
4747 const checker = context . program . getTypeChecker ( ) ;
48+ const fixedDeclarations = createMap < true > ( ) ;
4849 return codeFixAll ( context , errorCodes , ( t , diagnostic ) => {
49- const expression = getAwaitableExpression ( sourceFile , diagnostic . code , diagnostic , cancellationToken , program ) ;
50+ const expression = getFixableErrorSpanExpression ( sourceFile , diagnostic . code , diagnostic , cancellationToken , program ) ;
5051 if ( ! expression ) {
5152 return ;
5253 }
5354 const trackChanges : ContextualTrackChangesFunction = cb => ( cb ( t ) , [ ] ) ;
54- return getDeclarationSiteFix ( context , expression , diagnostic . code , checker , trackChanges )
55- || getUseSiteFix ( context , expression , diagnostic . code , checker , trackChanges ) ;
55+ return getDeclarationSiteFix ( context , expression , diagnostic . code , checker , trackChanges , fixedDeclarations )
56+ || getUseSiteFix ( context , expression , diagnostic . code , checker , trackChanges , fixedDeclarations ) ;
5657 } ) ;
5758 } ,
5859 } ) ;
5960
60- function getDeclarationSiteFix ( context : CodeFixContext | CodeFixAllContext , expression : Expression , errorCode : number , checker : TypeChecker , trackChanges : ContextualTrackChangesFunction ) {
61- const { sourceFile } = context ;
62- const awaitableInitializer = findAwaitableInitializer ( expression , sourceFile , checker ) ;
63- if ( awaitableInitializer ) {
64- const initializerChanges = trackChanges ( t => makeChange ( t , errorCode , sourceFile , checker , awaitableInitializer ) ) ;
61+ function getDeclarationSiteFix ( context : CodeFixContext | CodeFixAllContext , expression : Expression , errorCode : number , checker : TypeChecker , trackChanges : ContextualTrackChangesFunction , fixedDeclarations ?: Map < true > ) {
62+ const { sourceFile, program, cancellationToken } = context ;
63+ const awaitableInitializers = findAwaitableInitializers ( expression , sourceFile , cancellationToken , program , checker ) ;
64+ if ( awaitableInitializers ) {
65+ const initializerChanges = trackChanges ( t => {
66+ forEach ( awaitableInitializers . initializers , ( { expression } ) => makeChange ( t , errorCode , sourceFile , checker , expression , fixedDeclarations ) ) ;
67+ if ( fixedDeclarations && awaitableInitializers . needsSecondPassForFixAll ) {
68+ makeChange ( t , errorCode , sourceFile , checker , expression , fixedDeclarations ) ;
69+ }
70+ } ) ;
6571 return createCodeFixActionNoFixId (
6672 "addMissingAwaitToInitializer" ,
6773 initializerChanges ,
68- [ Diagnostics . Add_await_to_initializer_for_0 , expression . getText ( sourceFile ) ] ) ;
74+ awaitableInitializers . initializers . length === 1
75+ ? [ Diagnostics . Add_await_to_initializer_for_0 , awaitableInitializers . initializers [ 0 ] . declarationSymbol . name ]
76+ : Diagnostics . Add_await_to_initializers ) ;
6977 }
7078 }
7179
72- function getUseSiteFix ( context : CodeFixContext | CodeFixAllContext , expression : Expression , errorCode : number , checker : TypeChecker , trackChanges : ContextualTrackChangesFunction ) {
73- const changes = trackChanges ( t => makeChange ( t , errorCode , context . sourceFile , checker , expression ) ) ;
80+ function getUseSiteFix ( context : CodeFixContext | CodeFixAllContext , expression : Expression , errorCode : number , checker : TypeChecker , trackChanges : ContextualTrackChangesFunction , fixedDeclarations ?: Map < true > ) {
81+ const changes = trackChanges ( t => makeChange ( t , errorCode , context . sourceFile , checker , expression , fixedDeclarations ) ) ;
7482 return createCodeFixAction ( fixId , changes , Diagnostics . Add_await , fixId , Diagnostics . Fix_all_expressions_possibly_missing_await ) ;
7583 }
7684
@@ -84,7 +92,7 @@ namespace ts.codefix {
8492 some ( relatedInformation , related => related . code === Diagnostics . Did_you_forget_to_use_await . code ) ) ;
8593 }
8694
87- function getAwaitableExpression ( sourceFile : SourceFile , errorCode : number , span : TextSpan , cancellationToken : CancellationToken , program : Program ) : Expression | undefined {
95+ function getFixableErrorSpanExpression ( sourceFile : SourceFile , errorCode : number , span : TextSpan , cancellationToken : CancellationToken , program : Program ) : Expression | undefined {
8896 const token = getTokenAtPosition ( sourceFile , span . start ) ;
8997 // Checker has already done work to determine that await might be possible, and has attached
9098 // related info to the node, so start by finding the expression that exactly matches up
@@ -103,38 +111,117 @@ namespace ts.codefix {
103111 : undefined ;
104112 }
105113
106- function findAwaitableInitializer ( expression : Node , sourceFile : SourceFile , checker : TypeChecker ) : Expression | undefined {
107- if ( ! isIdentifier ( expression ) ) {
108- return ;
109- }
114+ interface AwaitableInitializer {
115+ expression : Expression ;
116+ declarationSymbol : Symbol ;
117+ }
110118
111- const symbol = checker . getSymbolAtLocation ( expression ) ;
112- if ( ! symbol ) {
119+ interface AwaitableInitializers {
120+ initializers : readonly AwaitableInitializer [ ] ;
121+ needsSecondPassForFixAll : boolean ;
122+ }
123+
124+ function findAwaitableInitializers (
125+ expression : Node ,
126+ sourceFile : SourceFile ,
127+ cancellationToken : CancellationToken ,
128+ program : Program ,
129+ checker : TypeChecker ,
130+ ) : AwaitableInitializers | undefined {
131+ const identifiers = getIdentifiersFromErrorSpanExpression ( expression , checker ) ;
132+ if ( ! identifiers ) {
113133 return ;
114134 }
115135
116- const declaration = tryCast ( symbol . valueDeclaration , isVariableDeclaration ) ;
117- const variableName = tryCast ( declaration && declaration . name , isIdentifier ) ;
118- const variableStatement = getAncestor ( declaration , SyntaxKind . VariableStatement ) ;
119- if ( ! declaration || ! variableStatement ||
120- declaration . type ||
121- ! declaration . initializer ||
122- variableStatement . getSourceFile ( ) !== sourceFile ||
123- hasModifier ( variableStatement , ModifierFlags . Export ) ||
124- ! variableName ||
125- ! isInsideAwaitableBody ( declaration . initializer ) ) {
126- return ;
136+ let isCompleteFix = identifiers . isCompleteFix ;
137+ let initializers : AwaitableInitializer [ ] | undefined ;
138+ for ( const identifier of identifiers . identifiers ) {
139+ const symbol = checker . getSymbolAtLocation ( identifier ) ;
140+ if ( ! symbol ) {
141+ continue ;
142+ }
143+
144+ const declaration = tryCast ( symbol . valueDeclaration , isVariableDeclaration ) ;
145+ const variableName = declaration && tryCast ( declaration . name , isIdentifier ) ;
146+ const variableStatement = getAncestor ( declaration , SyntaxKind . VariableStatement ) ;
147+ if ( ! declaration || ! variableStatement ||
148+ declaration . type ||
149+ ! declaration . initializer ||
150+ variableStatement . getSourceFile ( ) !== sourceFile ||
151+ hasModifier ( variableStatement , ModifierFlags . Export ) ||
152+ ! variableName ||
153+ ! isInsideAwaitableBody ( declaration . initializer ) ) {
154+ isCompleteFix = false ;
155+ continue ;
156+ }
157+
158+ const diagnostics = program . getSemanticDiagnostics ( sourceFile , cancellationToken ) ;
159+ const isUsedElsewhere = FindAllReferences . Core . eachSymbolReferenceInFile ( variableName , checker , sourceFile , reference => {
160+ return identifier !== reference && ! symbolReferenceIsAlsoMissingAwait ( reference , diagnostics , sourceFile , checker ) ;
161+ } ) ;
162+
163+ if ( isUsedElsewhere ) {
164+ isCompleteFix = false ;
165+ continue ;
166+ }
167+
168+ ( initializers || ( initializers = [ ] ) ) . push ( {
169+ expression : declaration . initializer ,
170+ declarationSymbol : symbol ,
171+ } ) ;
127172 }
173+ return initializers && {
174+ initializers,
175+ needsSecondPassForFixAll : ! isCompleteFix ,
176+ } ;
177+ }
128178
129- const isUsedElsewhere = FindAllReferences . Core . eachSymbolReferenceInFile ( variableName , checker , sourceFile , identifier => {
130- return identifier !== expression ;
131- } ) ;
179+ interface Identifiers {
180+ identifiers : readonly Identifier [ ] ;
181+ isCompleteFix : boolean ;
182+ }
132183
133- if ( isUsedElsewhere ) {
134- return ;
184+ function getIdentifiersFromErrorSpanExpression ( expression : Node , checker : TypeChecker ) : Identifiers | undefined {
185+ if ( isPropertyAccessExpression ( expression . parent ) && isIdentifier ( expression . parent . expression ) ) {
186+ return { identifiers : [ expression . parent . expression ] , isCompleteFix : true } ;
187+ }
188+ if ( isIdentifier ( expression ) ) {
189+ return { identifiers : [ expression ] , isCompleteFix : true } ;
190+ }
191+ if ( isBinaryExpression ( expression ) ) {
192+ let sides : Identifier [ ] | undefined ;
193+ let isCompleteFix = true ;
194+ for ( const side of [ expression . left , expression . right ] ) {
195+ const type = checker . getTypeAtLocation ( side ) ;
196+ if ( checker . getPromisedTypeOfPromise ( type ) ) {
197+ if ( ! isIdentifier ( side ) ) {
198+ isCompleteFix = false ;
199+ continue ;
200+ }
201+ ( sides || ( sides = [ ] ) ) . push ( side ) ;
202+ }
203+ }
204+ return sides && { identifiers : sides , isCompleteFix } ;
135205 }
206+ }
207+
208+ function symbolReferenceIsAlsoMissingAwait ( reference : Identifier , diagnostics : readonly Diagnostic [ ] , sourceFile : SourceFile , checker : TypeChecker ) {
209+ const errorNode = isPropertyAccessExpression ( reference . parent ) ? reference . parent . name :
210+ isBinaryExpression ( reference . parent ) ? reference . parent :
211+ reference ;
212+ const diagnostic = find ( diagnostics , diagnostic =>
213+ diagnostic . start === errorNode . getStart ( sourceFile ) &&
214+ diagnostic . start + diagnostic . length ! === errorNode . getEnd ( ) ) ;
136215
137- return declaration . initializer ;
216+ return diagnostic && contains ( errorCodes , diagnostic . code ) ||
217+ // A Promise is usually not correct in a binary expression (it’s not valid
218+ // in an arithmetic expression and an equality comparison seems unusual),
219+ // but if the other side of the binary expression has an error, the side
220+ // is typed `any` which will squash the error that would identify this
221+ // Promise as an invalid operand. So if the whole binary expression is
222+ // typed `any` as a result, there is a strong likelihood that this Promise
223+ // is accidentally missing `await`.
224+ checker . getTypeAtLocation ( errorNode ) . flags & TypeFlags . Any ;
138225 }
139226
140227 function isInsideAwaitableBody ( node : Node ) {
@@ -147,26 +234,48 @@ namespace ts.codefix {
147234 ancestor . parent . kind === SyntaxKind . MethodDeclaration ) ) ;
148235 }
149236
150- function makeChange ( changeTracker : textChanges . ChangeTracker , errorCode : number , sourceFile : SourceFile , checker : TypeChecker , insertionSite : Expression ) {
237+ function makeChange ( changeTracker : textChanges . ChangeTracker , errorCode : number , sourceFile : SourceFile , checker : TypeChecker , insertionSite : Expression , fixedDeclarations ?: Map < true > ) {
151238 if ( isBinaryExpression ( insertionSite ) ) {
152- const { left, right } = insertionSite ;
153- const leftType = checker . getTypeAtLocation ( left ) ;
154- const rightType = checker . getTypeAtLocation ( right ) ;
155- const newLeft = checker . getPromisedTypeOfPromise ( leftType ) ? createAwait ( left ) : left ;
156- const newRight = checker . getPromisedTypeOfPromise ( rightType ) ? createAwait ( right ) : right ;
157- changeTracker . replaceNode ( sourceFile , left , newLeft ) ;
158- changeTracker . replaceNode ( sourceFile , right , newRight ) ;
239+ for ( const side of [ insertionSite . left , insertionSite . right ] ) {
240+ if ( fixedDeclarations && isIdentifier ( side ) ) {
241+ const symbol = checker . getSymbolAtLocation ( side ) ;
242+ if ( symbol && fixedDeclarations . has ( getSymbolId ( symbol ) . toString ( ) ) ) {
243+ continue ;
244+ }
245+ }
246+ const type = checker . getTypeAtLocation ( side ) ;
247+ const newNode = checker . getPromisedTypeOfPromise ( type ) ? createAwait ( side ) : side ;
248+ changeTracker . replaceNode ( sourceFile , side , newNode ) ;
249+ }
159250 }
160251 else if ( errorCode === propertyAccessCode && isPropertyAccessExpression ( insertionSite . parent ) ) {
252+ if ( fixedDeclarations && isIdentifier ( insertionSite . parent . expression ) ) {
253+ const symbol = checker . getSymbolAtLocation ( insertionSite . parent . expression ) ;
254+ if ( symbol && fixedDeclarations . has ( getSymbolId ( symbol ) . toString ( ) ) ) {
255+ return ;
256+ }
257+ }
161258 changeTracker . replaceNode (
162259 sourceFile ,
163260 insertionSite . parent . expression ,
164261 createParen ( createAwait ( insertionSite . parent . expression ) ) ) ;
165262 }
166263 else if ( contains ( callableConstructableErrorCodes , errorCode ) && isCallOrNewExpression ( insertionSite . parent ) ) {
264+ if ( fixedDeclarations && isIdentifier ( insertionSite ) ) {
265+ const symbol = checker . getSymbolAtLocation ( insertionSite ) ;
266+ if ( symbol && fixedDeclarations . has ( getSymbolId ( symbol ) . toString ( ) ) ) {
267+ return ;
268+ }
269+ }
167270 changeTracker . replaceNode ( sourceFile , insertionSite , createParen ( createAwait ( insertionSite ) ) ) ;
168271 }
169272 else {
273+ if ( fixedDeclarations && isVariableDeclaration ( insertionSite . parent ) && isIdentifier ( insertionSite . parent . name ) ) {
274+ const symbol = checker . getSymbolAtLocation ( insertionSite . parent . name ) ;
275+ if ( symbol && ! addToSeen ( fixedDeclarations , getSymbolId ( symbol ) ) ) {
276+ return ;
277+ }
278+ }
170279 changeTracker . replaceNode ( sourceFile , insertionSite , createAwait ( insertionSite ) ) ;
171280 }
172281 }
0 commit comments