/* @internal */
namespace ts.codefix {
    // Code fix that inserts a missing `await`, either at the use site of a Promise-typed
    // expression or at the initializer of the local variable(s) that expression refers to.

    /** Runs `cb` against a change tracker and returns the resulting text changes (empty in fix-all mode). */
    type ContextualTrackChangesFunction = (cb: (changeTracker: textChanges.ChangeTracker) => void) => FileTextChanges[];
    const fixId = "addMissingAwait";
    const propertyAccessCode = Diagnostics.Property_0_does_not_exist_on_type_1.code;
    const callableConstructableErrorCodes = [
        Diagnostics.This_expression_is_not_callable.code,
        Diagnostics.This_expression_is_not_constructable.code,
    ];
    /**
     * Diagnostic codes this fix is registered for. A fix is only produced when the reported
     * diagnostic also carries `Did_you_forget_to_use_await` related information (see isMissingAwaitError).
     */
    const errorCodes = [
        Diagnostics.An_arithmetic_operand_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.The_left_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.The_right_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.Operator_0_cannot_be_applied_to_type_1.code,
        Diagnostics.Operator_0_cannot_be_applied_to_types_1_and_2.code,
        Diagnostics.This_condition_will_always_return_0_since_the_types_1_and_2_have_no_overlap.code,
        Diagnostics.Type_0_is_not_an_array_type.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_Use_compiler_option_downlevelIteration_to_allow_iterating_of_iterators.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_is_not_an_array_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_must_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_must_have_a_Symbol_asyncIterator_method_that_returns_an_async_iterator.code,
        Diagnostics.Argument_of_type_0_is_not_assignable_to_parameter_of_type_1.code,
        propertyAccessCode,
        ...callableConstructableErrorCodes,
    ];

    registerCodeFix({
        fixIds: [fixId],
        errorCodes,
        getCodeActions: context => {
            const { sourceFile, errorCode, span, cancellationToken, program } = context;
            const expression = getFixableErrorSpanExpression(sourceFile, errorCode, span, cancellationToken, program);
            if (!expression) {
                return;
            }

            const checker = program.getTypeChecker();
            const trackChanges: ContextualTrackChangesFunction = cb => textChanges.ChangeTracker.with(context, cb);
            // Offer both the declaration-site and the use-site variants; `compact` drops the
            // declaration-site one when no suitable initializer exists.
            return compact([
                getDeclarationSiteFix(context, expression, errorCode, checker, trackChanges),
                getUseSiteFix(context, expression, errorCode, checker, trackChanges)]);
        },
        getAllCodeActions: context => {
            const { sourceFile, program, cancellationToken } = context;
            const checker = program.getTypeChecker();
            // Symbol ids of variables whose initializer already received `await`, shared across
            // all diagnostics so the same declaration/use is not awaited twice.
            const fixedDeclarations = new Set<number>();
            return codeFixAll(context, errorCodes, (t, diagnostic) => {
                const expression = getFixableErrorSpanExpression(sourceFile, diagnostic.code, diagnostic, cancellationToken, program);
                if (!expression) {
                    return;
                }
                // In fix-all mode all edits go into the shared tracker `t`; the returned list is unused.
                const trackChanges: ContextualTrackChangesFunction = cb => (cb(t), []);
                return getDeclarationSiteFix(context, expression, diagnostic.code, checker, trackChanges, fixedDeclarations)
                    || getUseSiteFix(context, expression, diagnostic.code, checker, trackChanges, fixedDeclarations);
            });
        },
    });

    /**
     * Builds a fix that adds `await` to the initializers of the local variables referenced by
     * `expression` instead of at the use site.
     *
     * @param fixedDeclarations Present only in fix-all mode; tracks already-fixed variable symbols.
     * @returns A code fix action without fix-all, or `undefined` when no initializer qualifies.
     */
    function getDeclarationSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction, fixedDeclarations?: Set<number>) {
        const { sourceFile, program, cancellationToken } = context;
        const awaitableInitializers = findAwaitableInitializers(expression, sourceFile, cancellationToken, program, checker);
        if (awaitableInitializers) {
            const initializerChanges = trackChanges(t => {
                forEach(awaitableInitializers.initializers, ({ expression: initializer }) => makeChange(t, errorCode, sourceFile, checker, initializer, fixedDeclarations));
                // Some operands could not be fixed at their declaration; in fix-all mode also
                // fix the use site so the diagnostic is fully resolved.
                if (fixedDeclarations && awaitableInitializers.needsSecondPassForFixAll) {
                    makeChange(t, errorCode, sourceFile, checker, expression, fixedDeclarations);
                }
            });
            // No fix-all because it will already be included once with the use site fix,
            // and for simplicity the fix-all doesn't let the user choose between use-site and declaration-site fixes.
            return createCodeFixActionWithoutFixAll(
                "addMissingAwaitToInitializer",
                initializerChanges,
                awaitableInitializers.initializers.length === 1
                    ? [Diagnostics.Add_await_to_initializer_for_0, awaitableInitializers.initializers[0].declarationSymbol.name]
                    : Diagnostics.Add_await_to_initializers);
        }
    }

    /** Builds the fix that adds `await` directly at the error location; participates in fix-all under `fixId`. */
    function getUseSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction, fixedDeclarations?: Set<number>) {
        const changes = trackChanges(t => makeChange(t, errorCode, context.sourceFile, checker, expression, fixedDeclarations));
        return createCodeFixAction(fixId, changes, Diagnostics.Add_await, fixId, Diagnostics.Fix_all_expressions_possibly_missing_await);
    }

    /**
     * Returns true when the checker reported a diagnostic with exactly `span` and `errorCode`
     * that carries `Did_you_forget_to_use_await` related information.
     */
    function isMissingAwaitError(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program) {
        const checker = program.getDiagnosticsProducingTypeChecker();
        const diagnostics = checker.getDiagnostics(sourceFile, cancellationToken);
        return some(diagnostics, ({ start, length, relatedInformation, code }) =>
            isNumber(start) && isNumber(length) && textSpansEqual({ start, length }, span) &&
            code === errorCode &&
            !!relatedInformation &&
            some(relatedInformation, related => related.code === Diagnostics.Did_you_forget_to_use_await.code));
    }

    /**
     * Finds the expression whose text span exactly matches the diagnostic `span`, provided the
     * diagnostic is a missing-await error and the expression is inside an awaitable body.
     */
    function getFixableErrorSpanExpression(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program): Expression | undefined {
        const token = getTokenAtPosition(sourceFile, span.start);
        // Checker has already done work to determine that await might be possible, and has attached
        // related info to the node, so start by finding the expression that exactly matches up
        // with the diagnostic range.
        const expression = findAncestor(token, node => {
            // Once an ancestor extends beyond the span, no further ancestor can match exactly.
            if (node.getStart(sourceFile) < span.start || node.getEnd() > textSpanEnd(span)) {
                return "quit";
            }
            return isExpression(node) && textSpansEqual(span, createTextSpanFromNode(node, sourceFile));
        }) as Expression | undefined;

        return expression
            && isMissingAwaitError(sourceFile, errorCode, span, cancellationToken, program)
            && isInsideAwaitableBody(expression) ? expression : undefined;
    }

    interface AwaitableInitializer {
        expression: Expression;
        declarationSymbol: Symbol;
    }

    interface AwaitableInitializers {
        initializers: readonly AwaitableInitializer[];
        /** True when some Promise-typed operand could not be fixed at its declaration. */
        needsSecondPassForFixAll: boolean;
    }

    /**
     * Collects initializers of local, non-exported, untyped variable declarations in `sourceFile`
     * that are referenced by the error expression and are safe to `await` at the declaration:
     * every other reference in the file must also be a missing-await site.
     *
     * @returns `undefined` when no initializer qualifies.
     */
    function findAwaitableInitializers(
        expression: Node,
        sourceFile: SourceFile,
        cancellationToken: CancellationToken,
        program: Program,
        checker: TypeChecker,
    ): AwaitableInitializers | undefined {
        const identifiers = getIdentifiersFromErrorSpanExpression(expression, checker);
        if (!identifiers) {
            return;
        }

        let isCompleteFix = identifiers.isCompleteFix;
        let initializers: AwaitableInitializer[] | undefined;
        // Fetched lazily (only for candidates that pass the syntactic checks) and at most once,
        // since the result does not depend on the loop variable.
        let diagnostics: readonly Diagnostic[] | undefined;
        for (const identifier of identifiers.identifiers) {
            const symbol = checker.getSymbolAtLocation(identifier);
            if (!symbol) {
                continue;
            }

            const declaration = tryCast(symbol.valueDeclaration, isVariableDeclaration);
            const variableName = declaration && tryCast(declaration.name, isIdentifier);
            const variableStatement = getAncestor(declaration, SyntaxKind.VariableStatement);
            if (!declaration || !variableStatement ||
                declaration.type ||
                !declaration.initializer ||
                variableStatement.getSourceFile() !== sourceFile ||
                hasSyntacticModifier(variableStatement, ModifierFlags.Export) ||
                !variableName ||
                !isInsideAwaitableBody(declaration.initializer)) {
                isCompleteFix = false;
                continue;
            }

            // A `const` captures the narrowed value so the callback below sees a defined array.
            const semanticDiagnostics = diagnostics || (diagnostics = program.getSemanticDiagnostics(sourceFile, cancellationToken));
            const isUsedElsewhere = FindAllReferences.Core.eachSymbolReferenceInFile(variableName, checker, sourceFile, reference => {
                return identifier !== reference && !symbolReferenceIsAlsoMissingAwait(reference, semanticDiagnostics, sourceFile, checker);
            });

            if (isUsedElsewhere) {
                isCompleteFix = false;
                continue;
            }

            (initializers || (initializers = [])).push({
                expression: declaration.initializer,
                declarationSymbol: symbol,
            });
        }
        return initializers && {
            initializers,
            needsSecondPassForFixAll: !isCompleteFix,
        };
    }

    interface Identifiers {
        identifiers: readonly Identifier[];
        /** False when some Promise-typed operand is not a plain identifier. */
        isCompleteFix: boolean;
    }

    /**
     * Extracts the identifiers whose declarations could receive `await`:
     * - the object of a property access;
     * - the expression itself when it is an identifier;
     * - the Promise-typed identifier operands of a binary expression.
     * Returns `undefined` for any other shape.
     */
    function getIdentifiersFromErrorSpanExpression(expression: Node, checker: TypeChecker): Identifiers | undefined {
        if (isPropertyAccessExpression(expression.parent) && isIdentifier(expression.parent.expression)) {
            return { identifiers: [expression.parent.expression], isCompleteFix: true };
        }
        if (isIdentifier(expression)) {
            return { identifiers: [expression], isCompleteFix: true };
        }
        if (isBinaryExpression(expression)) {
            let sides: Identifier[] | undefined;
            let isCompleteFix = true;
            for (const side of [expression.left, expression.right]) {
                const type = checker.getTypeAtLocation(side);
                if (checker.getPromisedTypeOfPromise(type)) {
                    if (!isIdentifier(side)) {
                        isCompleteFix = false;
                        continue;
                    }
                    (sides || (sides = [])).push(side);
                }
            }
            return sides && { identifiers: sides, isCompleteFix };
        }
        return undefined;
    }

    /**
     * Returns true when `reference` is itself a likely missing-await site. That is the case when:
     * - a diagnostic from `errorCodes` covers exactly its error node; or
     * - the error node's type is `any`.
     */
    function symbolReferenceIsAlsoMissingAwait(reference: Identifier, diagnostics: readonly Diagnostic[], sourceFile: SourceFile, checker: TypeChecker): boolean {
        const errorNode = isPropertyAccessExpression(reference.parent) ? reference.parent.name :
            isBinaryExpression(reference.parent) ? reference.parent :
            reference;
        const errorStart = errorNode.getStart(sourceFile);
        const errorEnd = errorNode.getEnd();
        const diagnostic = find(diagnostics, ({ start, length }) =>
            isNumber(start) && isNumber(length) &&
            start === errorStart &&
            start + length === errorEnd);

        return !!diagnostic && contains(errorCodes, diagnostic.code) ||
            // A Promise is usually not correct in a binary expression (it's not valid
            // in an arithmetic expression and an equality comparison seems unusual),
            // but if the other side of the binary expression has an error, the side
            // is typed `any` which will squash the error that would identify this
            // Promise as an invalid operand. So if the whole binary expression is
            // typed `any` as a result, there is a strong likelihood that this Promise
            // is accidentally missing `await`.
            !!(checker.getTypeAtLocation(errorNode).flags & TypeFlags.Any);
    }

    /**
     * Returns true when `node` is parsed in an await context, or lies within one of these:
     * - the body of an arrow function;
     * - a block whose parent is a function declaration, function expression, arrow function, or method.
     */
    function isInsideAwaitableBody(node: Node): boolean {
        // `AwaitContext` is a NodeFlags bit, so it must be tested against `node.flags`;
        // testing it against `node.kind` (a SyntaxKind) can never match.
        return !!(node.flags & NodeFlags.AwaitContext) || !!findAncestor(node, ancestor =>
            ancestor.parent && isArrowFunction(ancestor.parent) && ancestor.parent.body === ancestor ||
            isBlock(ancestor) && (
                ancestor.parent.kind === SyntaxKind.FunctionDeclaration ||
                ancestor.parent.kind === SyntaxKind.FunctionExpression ||
                ancestor.parent.kind === SyntaxKind.ArrowFunction ||
                ancestor.parent.kind === SyntaxKind.MethodDeclaration));
    }

    /**
     * Records the `await` insertion for `insertionSite` in `changeTracker`. The edit depends on the site:
     * - binary expression: each Promise-typed operand is wrapped in `await`;
     * - property-access error: the accessed object becomes `(await x)`;
     * - not-callable/not-constructable error: the callee becomes `(await x)`;
     * - otherwise: the whole site is wrapped in `await`.
     *
     * @param fixedDeclarations Fix-all only. Identifiers whose symbol is already in the set are skipped.
     * In the fallback branch, a variable initializer's declaration symbol is added, so each
     * initializer is awaited at most once.
     */
    function makeChange(changeTracker: textChanges.ChangeTracker, errorCode: number, sourceFile: SourceFile, checker: TypeChecker, insertionSite: Expression, fixedDeclarations?: Set<number>) {
        if (isBinaryExpression(insertionSite)) {
            for (const side of [insertionSite.left, insertionSite.right]) {
                if (fixedDeclarations && isIdentifier(side)) {
                    const symbol = checker.getSymbolAtLocation(side);
                    if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
                        continue;
                    }
                }
                const type = checker.getTypeAtLocation(side);
                const newNode = checker.getPromisedTypeOfPromise(type) ? factory.createAwaitExpression(side) : side;
                changeTracker.replaceNode(sourceFile, side, newNode);
            }
        }
        else if (errorCode === propertyAccessCode && isPropertyAccessExpression(insertionSite.parent)) {
            if (fixedDeclarations && isIdentifier(insertionSite.parent.expression)) {
                const symbol = checker.getSymbolAtLocation(insertionSite.parent.expression);
                if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
                    return;
                }
            }
            changeTracker.replaceNode(
                sourceFile,
                insertionSite.parent.expression,
                factory.createParenthesizedExpression(factory.createAwaitExpression(insertionSite.parent.expression)));
            insertLeadingSemicolonIfNeeded(changeTracker, insertionSite.parent.expression, sourceFile);
        }
        else if (contains(callableConstructableErrorCodes, errorCode) && isCallOrNewExpression(insertionSite.parent)) {
            if (fixedDeclarations && isIdentifier(insertionSite)) {
                const symbol = checker.getSymbolAtLocation(insertionSite);
                if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
                    return;
                }
            }
            changeTracker.replaceNode(sourceFile, insertionSite, factory.createParenthesizedExpression(factory.createAwaitExpression(insertionSite)));
            insertLeadingSemicolonIfNeeded(changeTracker, insertionSite, sourceFile);
        }
        else {
            if (fixedDeclarations && isVariableDeclaration(insertionSite.parent) && isIdentifier(insertionSite.parent.name)) {
                const symbol = checker.getSymbolAtLocation(insertionSite.parent.name);
                // `tryAddToSet` returns false when the symbol was already present: this initializer was fixed before.
                if (symbol && !tryAddToSet(fixedDeclarations, getSymbolId(symbol))) {
                    return;
                }
            }
            changeTracker.replaceNode(sourceFile, insertionSite, factory.createAwaitExpression(insertionSite));
        }
    }

    /**
     * Inserts `;` before `beforeNode` when the preceding token ends at an ASI candidate position.
     * The callers wrap the node in parentheses, so the new leading `(` must not be parsed as
     * continuing the previous statement.
     */
    function insertLeadingSemicolonIfNeeded(changeTracker: textChanges.ChangeTracker, beforeNode: Node, sourceFile: SourceFile) {
        const precedingToken = findPrecedingToken(beforeNode.pos, sourceFile);
        if (precedingToken && positionIsASICandidate(precedingToken.end, precedingToken.parent, sourceFile)) {
            changeTracker.insertText(sourceFile, beforeNode.getStart(sourceFile), ";");
        }
    }
}