/* @internal */
namespace ts.codefix {
    /**
     * Abstraction over "run this callback against a change tracker and give me the resulting changes".
     * The single-fix path creates a fresh tracker per call; the fix-all path forwards to the shared
     * tracker and returns an empty array.
     */
    type ContextualTrackChangesFunction = (cb: (changeTracker: textChanges.ChangeTracker) => void) => FileTextChanges[];
    const fixId = "addMissingAwait";
    const propertyAccessCode = Diagnostics.Property_0_does_not_exist_on_type_1.code;
    const callableConstructableErrorCodes = [
        Diagnostics.This_expression_is_not_callable.code,
        Diagnostics.This_expression_is_not_constructable.code,
    ];
    /** Every diagnostic code for which this fix is registered. */
    const errorCodes = [
        Diagnostics.An_arithmetic_operand_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.The_left_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.The_right_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.Operator_0_cannot_be_applied_to_type_1.code,
        Diagnostics.Operator_0_cannot_be_applied_to_types_1_and_2.code,
        Diagnostics.This_comparison_appears_to_be_unintentional_because_the_types_0_and_1_have_no_overlap.code,
        Diagnostics.This_condition_will_always_return_true_since_this_0_is_always_defined.code,
        Diagnostics.Type_0_is_not_an_array_type.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type.code,
        Diagnostics.Type_0_can_only_be_iterated_through_when_using_the_downlevelIteration_flag_or_with_a_target_of_es2015_or_higher.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_is_not_an_array_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_must_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_must_have_a_Symbol_asyncIterator_method_that_returns_an_async_iterator.code,
        Diagnostics.Argument_of_type_0_is_not_assignable_to_parameter_of_type_1.code,
        propertyAccessCode,
        ...callableConstructableErrorCodes,
    ];

    registerCodeFix({
        fixIds: [fixId],
        errorCodes,
        getCodeActions: function getCodeActionsToAddMissingAwait(context) {
            const { sourceFile, errorCode, span, cancellationToken, program } = context;
            const expression = getAwaitErrorSpanExpression(sourceFile, errorCode, span, cancellationToken, program);
            if (!expression) {
                return;
            }

            const checker = program.getTypeChecker();
            // Each offered action gets its own change tracker so the two fixes are independent.
            const trackChanges: ContextualTrackChangesFunction = cb => textChanges.ChangeTracker.with(context, cb);
            return compact([
                getDeclarationSiteFix(context, expression, errorCode, checker, trackChanges),
                getUseSiteFix(context, expression, errorCode, checker, trackChanges)]);
        },
        getAllCodeActions: context => {
            const { sourceFile, program, cancellationToken } = context;
            const checker = program.getTypeChecker();
            // Symbol ids of declarations whose initializer has already been awaited during this fix-all run.
            const fixedDeclarations = new Set<number>();
            return codeFixAll(context, errorCodes, (t, diagnostic) => {
                const expression = getAwaitErrorSpanExpression(sourceFile, diagnostic.code, diagnostic, cancellationToken, program);
                if (!expression) {
                    return;
                }
                // All edits go to the shared tracker `t`; the returned changes array is unused here.
                const trackChanges: ContextualTrackChangesFunction = cb => (cb(t), []);
                return getDeclarationSiteFix(context, expression, diagnostic.code, checker, trackChanges, fixedDeclarations)
                    || getUseSiteFix(context, expression, diagnostic.code, checker, trackChanges, fixedDeclarations);
            });
        },
    });

    /**
     * Returns the fixable expression at `span` when the diagnostic there is a "missing await" error
     * and the expression sits somewhere `await` could be inserted; otherwise `undefined`.
     */
    function getAwaitErrorSpanExpression(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program) {
        const expression = getFixableErrorSpanExpression(sourceFile, span);
        return expression
            && isMissingAwaitError(sourceFile, errorCode, span, cancellationToken, program)
            && isInsideAwaitableBody(expression) ? expression : undefined;
    }

    /**
     * Builds the "add await to initializer(s)" action, which awaits the initializer of the variable(s)
     * referenced by `expression` instead of the use site. Returns `undefined` when no suitable
     * initializer is found.
     */
    function getDeclarationSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction, fixedDeclarations?: Set<number>) {
        const { sourceFile, program, cancellationToken } = context;
        const awaitableInitializers = findAwaitableInitializers(expression, sourceFile, cancellationToken, program, checker);
        if (awaitableInitializers) {
            const initializerChanges = trackChanges(t => {
                forEach(awaitableInitializers.initializers, ({ expression }) => makeChange(t, errorCode, sourceFile, checker, expression, fixedDeclarations));
                // In fix-all mode, an incomplete declaration-site fix is followed by a use-site pass on the
                // original expression; `fixedDeclarations` lets that pass skip already-fixed symbols.
                if (fixedDeclarations && awaitableInitializers.needsSecondPassForFixAll) {
                    makeChange(t, errorCode, sourceFile, checker, expression, fixedDeclarations);
                }
            });
            // No fix-all because it will already be included once with the use site fix,
            // and for simplicity the fix-all doesn't let the user choose between use-site and declaration-site fixes.
            return createCodeFixActionWithoutFixAll(
                "addMissingAwaitToInitializer",
                initializerChanges,
                awaitableInitializers.initializers.length === 1
                    ? [Diagnostics.Add_await_to_initializer_for_0, awaitableInitializers.initializers[0].declarationSymbol.name]
                    : Diagnostics.Add_await_to_initializers);
        }
    }

    /** Builds the "add await" action that inserts `await` directly at the erroring expression. */
    function getUseSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction, fixedDeclarations?: Set<number>) {
        const changes = trackChanges(t => makeChange(t, errorCode, context.sourceFile, checker, expression, fixedDeclarations));
        return createCodeFixAction(fixId, changes, Diagnostics.Add_await, fixId, Diagnostics.Fix_all_expressions_possibly_missing_await);
    }

    /**
     * True when the checker reports a diagnostic with exactly this `span` and `errorCode` that carries
     * "Did you forget to use 'await'?" related information.
     */
    function isMissingAwaitError(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program) {
        const checker = program.getTypeChecker();
        const diagnostics = checker.getDiagnostics(sourceFile, cancellationToken);
        return some(diagnostics, ({ start, length, relatedInformation, code }) =>
            isNumber(start) && isNumber(length) && textSpansEqual({ start, length }, span) &&
            code === errorCode &&
            !!relatedInformation &&
            some(relatedInformation, related => related.code === Diagnostics.Did_you_forget_to_use_await.code));
    }

    /** A variable initializer that can be awaited, together with the symbol of the variable it initializes. */
    interface AwaitableInitializer {
        expression: Expression;
        declarationSymbol: Symbol;
    }

    interface AwaitableInitializers {
        initializers: readonly AwaitableInitializer[];
        /** True when awaiting the initializers alone does not resolve the whole error. */
        needsSecondPassForFixAll: boolean;
    }

    /**
     * Collects the initializers of the variables referenced by `expression` that can safely be awaited.
     * A declaration is skipped (making the fix incomplete) when it is not a plain identifier-named
     * variable declaration with an initializer and no type annotation, lives in another file, is
     * exported, has an initializer outside an awaitable body, or has other references in the file
     * that are not themselves missing-await errors.
     */
    function findAwaitableInitializers(
        expression: Node,
        sourceFile: SourceFile,
        cancellationToken: CancellationToken,
        program: Program,
        checker: TypeChecker,
    ): AwaitableInitializers | undefined {
        const identifiers = getIdentifiersFromErrorSpanExpression(expression, checker);
        if (!identifiers) {
            return;
        }

        let isCompleteFix = identifiers.isCompleteFix;
        let initializers: AwaitableInitializer[] | undefined;
        for (const identifier of identifiers.identifiers) {
            const symbol = checker.getSymbolAtLocation(identifier);
            if (!symbol) {
                continue;
            }

            const declaration = tryCast(symbol.valueDeclaration, isVariableDeclaration);
            const variableName = declaration && tryCast(declaration.name, isIdentifier);
            const variableStatement = getAncestor(declaration, SyntaxKind.VariableStatement);
            if (!declaration || !variableStatement ||
                declaration.type ||
                !declaration.initializer ||
                variableStatement.getSourceFile() !== sourceFile ||
                hasSyntacticModifier(variableStatement, ModifierFlags.Export) ||
                !variableName ||
                !isInsideAwaitableBody(declaration.initializer)) {
                isCompleteFix = false;
                continue;
            }

            const diagnostics = program.getSemanticDiagnostics(sourceFile, cancellationToken);
            // Awaiting the initializer changes the variable's type everywhere, so it is only safe when
            // every other reference is itself an error that the await would fix.
            const isUsedElsewhere = FindAllReferences.Core.eachSymbolReferenceInFile(variableName, checker, sourceFile, reference => {
                return identifier !== reference && !symbolReferenceIsAlsoMissingAwait(reference, diagnostics, sourceFile, checker);
            });

            if (isUsedElsewhere) {
                isCompleteFix = false;
                continue;
            }

            (initializers || (initializers = [])).push({
                expression: declaration.initializer,
                declarationSymbol: symbol,
            });
        }
        return initializers && {
            initializers,
            needsSecondPassForFixAll: !isCompleteFix,
        };
    }

    interface Identifiers {
        identifiers: readonly Identifier[];
        /** False when some Promise-typed operand is not a bare identifier and so cannot be fixed at a declaration. */
        isCompleteFix: boolean;
    }

    /**
     * Extracts the identifiers whose declarations are candidates for a declaration-site fix:
     * the object of a property access, the expression itself when it is an identifier, or the
     * Promise-typed identifier operands of a binary expression.
     */
    function getIdentifiersFromErrorSpanExpression(expression: Node, checker: TypeChecker): Identifiers | undefined {
        if (isPropertyAccessExpression(expression.parent) && isIdentifier(expression.parent.expression)) {
            return { identifiers: [expression.parent.expression], isCompleteFix: true };
        }
        if (isIdentifier(expression)) {
            return { identifiers: [expression], isCompleteFix: true };
        }
        if (isBinaryExpression(expression)) {
            let sides: Identifier[] | undefined;
            let isCompleteFix = true;
            for (const side of [expression.left, expression.right]) {
                const type = checker.getTypeAtLocation(side);
                if (checker.getPromisedTypeOfPromise(type)) {
                    if (!isIdentifier(side)) {
                        isCompleteFix = false;
                        continue;
                    }
                    (sides || (sides = [])).push(side);
                }
            }
            return sides && { identifiers: sides, isCompleteFix };
        }
    }

    /**
     * Decides whether another reference to the same symbol is itself (very likely) a missing-await
     * error, by looking for a diagnostic from `errorCodes` that exactly covers the node where such an
     * error would be reported.
     */
    function symbolReferenceIsAlsoMissingAwait(reference: Identifier, diagnostics: readonly Diagnostic[], sourceFile: SourceFile, checker: TypeChecker) {
        const errorNode = isPropertyAccessExpression(reference.parent) ? reference.parent.name :
            isBinaryExpression(reference.parent) ? reference.parent :
            reference;
        const diagnostic = find(diagnostics, diagnostic =>
            diagnostic.start === errorNode.getStart(sourceFile) &&
            (diagnostic.start + diagnostic.length!) === errorNode.getEnd());

        return diagnostic && contains(errorCodes, diagnostic.code) ||
            // A Promise is usually not correct in a binary expression (it's not valid
            // in an arithmetic expression and an equality comparison seems unusual),
            // but if the other side of the binary expression has an error, the side
            // is typed `any` which will squash the error that would identify this
            // Promise as an invalid operand. So if the whole binary expression is
            // typed `any` as a result, there is a strong likelihood that this Promise
            // is accidentally missing `await`.
            checker.getTypeAtLocation(errorNode).flags & TypeFlags.Any;
    }

    /**
     * True when `node` is in an await context, or is nested in the body of an arrow function, or in
     * the block of a function declaration, function expression, arrow function or method.
     */
    function isInsideAwaitableBody(node: Node) {
        // `AwaitContext` is a NodeFlags bit, so it must be tested against `node.flags`;
        // masking `node.kind` (a SyntaxKind) with it can never succeed.
        return !!(node.flags & NodeFlags.AwaitContext) || !!findAncestor(node, ancestor =>
            ancestor.parent && isArrowFunction(ancestor.parent) && ancestor.parent.body === ancestor ||
            isBlock(ancestor) && (
                ancestor.parent.kind === SyntaxKind.FunctionDeclaration ||
                ancestor.parent.kind === SyntaxKind.FunctionExpression ||
                ancestor.parent.kind === SyntaxKind.ArrowFunction ||
                ancestor.parent.kind === SyntaxKind.MethodDeclaration));
    }

    /**
     * Records the edit that adds the missing `await` at `insertionSite`.
     *
     * When `fixedDeclarations` is supplied (fix-all mode), use sites whose symbol has already been
     * fixed at its declaration are skipped, and declaration initializers are recorded so that each
     * is awaited only once.
     */
    function makeChange(changeTracker: textChanges.ChangeTracker, errorCode: number, sourceFile: SourceFile, checker: TypeChecker, insertionSite: Expression, fixedDeclarations?: Set<number>) {
        // `for (x of asyncIterable)` becomes `for await (x of asyncIterable)` instead of awaiting the expression.
        if (isForOfStatement(insertionSite.parent) && !insertionSite.parent.awaitModifier) {
            const exprType = checker.getTypeAtLocation(insertionSite);
            const asyncIter = checker.getAsyncIterableType();
            if (asyncIter && checker.isTypeAssignableTo(exprType, asyncIter)) {
                const forOf = insertionSite.parent;
                changeTracker.replaceNode(sourceFile, forOf, factory.updateForOfStatement(forOf, factory.createToken(SyntaxKind.AwaitKeyword), forOf.initializer, forOf.expression, forOf.statement));
                return;
            }
        }
        if (isBinaryExpression(insertionSite)) {
            // Await each Promise-typed operand individually.
            for (const side of [insertionSite.left, insertionSite.right]) {
                if (fixedDeclarations && isIdentifier(side)) {
                    const symbol = checker.getSymbolAtLocation(side);
                    if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
                        continue;
                    }
                }
                const type = checker.getTypeAtLocation(side);
                const newNode = checker.getPromisedTypeOfPromise(type) ? factory.createAwaitExpression(side) : side;
                changeTracker.replaceNode(sourceFile, side, newNode);
            }
        }
        else if (errorCode === propertyAccessCode && isPropertyAccessExpression(insertionSite.parent)) {
            if (fixedDeclarations && isIdentifier(insertionSite.parent.expression)) {
                const symbol = checker.getSymbolAtLocation(insertionSite.parent.expression);
                if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
                    return;
                }
            }
            // `p.x` becomes `(await p).x`.
            changeTracker.replaceNode(
                sourceFile,
                insertionSite.parent.expression,
                factory.createParenthesizedExpression(factory.createAwaitExpression(insertionSite.parent.expression)));
            insertLeadingSemicolonIfNeeded(changeTracker, insertionSite.parent.expression, sourceFile);
        }
        else if (contains(callableConstructableErrorCodes, errorCode) && isCallOrNewExpression(insertionSite.parent)) {
            if (fixedDeclarations && isIdentifier(insertionSite)) {
                const symbol = checker.getSymbolAtLocation(insertionSite);
                if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
                    return;
                }
            }
            // `f()` becomes `(await f)()`.
            changeTracker.replaceNode(sourceFile, insertionSite, factory.createParenthesizedExpression(factory.createAwaitExpression(insertionSite)));
            insertLeadingSemicolonIfNeeded(changeTracker, insertionSite, sourceFile);
        }
        else {
            if (fixedDeclarations && isVariableDeclaration(insertionSite.parent) && isIdentifier(insertionSite.parent.name)) {
                const symbol = checker.getSymbolAtLocation(insertionSite.parent.name);
                // Record the declaration as fixed; bail out if it was already recorded.
                if (symbol && !tryAddToSet(fixedDeclarations, getSymbolId(symbol))) {
                    return;
                }
            }
            changeTracker.replaceNode(sourceFile, insertionSite, factory.createAwaitExpression(insertionSite));
        }
    }

    /**
     * Inserts a `;` before `beforeNode` when the end of the preceding token is an ASI candidate
     * position, so that a newly added leading parenthesis is not parsed as a continuation of the
     * previous statement.
     */
    function insertLeadingSemicolonIfNeeded(changeTracker: textChanges.ChangeTracker, beforeNode: Node, sourceFile: SourceFile) {
        const precedingToken = findPrecedingToken(beforeNode.pos, sourceFile);
        if (precedingToken && positionIsASICandidate(precedingToken.end, precedingToken.parent, sourceFile)) {
            changeTracker.insertText(sourceFile, beforeNode.getStart(sourceFile), ";");
        }
    }
}