• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* @internal */
namespace ts.codefix {
    /**
     * Runs `cb` against a change tracker and returns the resulting file text changes.
     * In fix-all mode the callback is applied to the shared tracker supplied by
     * `codeFixAll`, and the returned array is empty.
     */
    type ContextualTrackChangesFunction = (cb: (changeTracker: textChanges.ChangeTracker) => void) => FileTextChanges[];
    const fixId = "addMissingAwait";
    const propertyAccessCode = Diagnostics.Property_0_does_not_exist_on_type_1.code;
    /** Errors for calling or `new`-ing a value directly; these are fixed as `(await x)(...)`. */
    const callableConstructableErrorCodes = [
        Diagnostics.This_expression_is_not_callable.code,
        Diagnostics.This_expression_is_not_constructable.code,
    ];
    /**
     * Every diagnostic this fix registers for. A diagnostic is only acted on when it also carries
     * the "Did you forget to use 'await'?" related information (see `isMissingAwaitError`).
     */
    const errorCodes = [
        Diagnostics.An_arithmetic_operand_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.The_left_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.The_right_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.Operator_0_cannot_be_applied_to_type_1.code,
        Diagnostics.Operator_0_cannot_be_applied_to_types_1_and_2.code,
        Diagnostics.This_comparison_appears_to_be_unintentional_because_the_types_0_and_1_have_no_overlap.code,
        Diagnostics.This_condition_will_always_return_true_since_this_0_is_always_defined.code,
        Diagnostics.Type_0_is_not_an_array_type.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type.code,
        Diagnostics.Type_0_can_only_be_iterated_through_when_using_the_downlevelIteration_flag_or_with_a_target_of_es2015_or_higher.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_is_not_an_array_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_must_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_must_have_a_Symbol_asyncIterator_method_that_returns_an_async_iterator.code,
        Diagnostics.Argument_of_type_0_is_not_assignable_to_parameter_of_type_1.code,
        propertyAccessCode,
        ...callableConstructableErrorCodes,
    ];

    registerCodeFix({
        fixIds: [fixId],
        errorCodes,
        getCodeActions: function getCodeActionsToAddMissingAwait(context) {
            const { sourceFile, errorCode, span, cancellationToken, program } = context;
            const expression = getAwaitErrorSpanExpression(sourceFile, errorCode, span, cancellationToken, program);
            if (!expression) {
                return;
            }

            const checker = program.getTypeChecker();
            const trackChanges: ContextualTrackChangesFunction = cb => textChanges.ChangeTracker.with(context, cb);
            // Offer both variants: awaiting the variable's initializer (declaration site, when one
            // qualifies) and awaiting the expression at the error location (use site).
            return compact([
                getDeclarationSiteFix(context, expression, errorCode, checker, trackChanges),
                getUseSiteFix(context, expression, errorCode, checker, trackChanges)]);
        },
        getAllCodeActions: context => {
            const { sourceFile, program, cancellationToken } = context;
            const checker = program.getTypeChecker();
            // Symbol ids of declarations whose initializers have already been awaited, so that
            // neither the initializer nor the uses of that symbol get a second `await`.
            const fixedDeclarations = new Set<number>();
            return codeFixAll(context, errorCodes, (t, diagnostic) => {
                const expression = getAwaitErrorSpanExpression(sourceFile, diagnostic.code, diagnostic, cancellationToken, program);
                if (!expression) {
                    return;
                }
                const trackChanges: ContextualTrackChangesFunction = cb => (cb(t), []);
                // Prefer the declaration-site fix; fall back to the use site only when no
                // initializer qualifies.
                return getDeclarationSiteFix(context, expression, diagnostic.code, checker, trackChanges, fixedDeclarations)
                    || getUseSiteFix(context, expression, diagnostic.code, checker, trackChanges, fixedDeclarations);
            });
        },
    });

    /**
     * Returns the fixable expression covered by `span`.
     * It is returned only when the diagnostic at `span` is a genuine missing-await error and
     * the expression sits somewhere an `await` can be inserted; otherwise returns `undefined`.
     */
    function getAwaitErrorSpanExpression(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program) {
        const expression = getFixableErrorSpanExpression(sourceFile, span);
        return expression
            && isMissingAwaitError(sourceFile, errorCode, span, cancellationToken, program)
            && isInsideAwaitableBody(expression) ? expression : undefined;
    }

    /**
     * Builds the "add await to initializer(s)" action.
     * The action awaits the initializers of the variables involved in the error. Returns
     * `undefined` when no initializer qualifies (see `findAwaitableInitializers`).
     * In fix-all mode, when not every Promise-typed operand could be handled at its declaration,
     * the use-site expression is also passed through `makeChange`.
     */
    function getDeclarationSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction, fixedDeclarations?: Set<number>) {
        const { sourceFile, program, cancellationToken } = context;
        const awaitableInitializers = findAwaitableInitializers(expression, sourceFile, cancellationToken, program, checker);
        if (awaitableInitializers) {
            const initializerChanges = trackChanges(t => {
                forEach(awaitableInitializers.initializers, ({ expression }) => makeChange(t, errorCode, sourceFile, checker, expression, fixedDeclarations));
                if (fixedDeclarations && awaitableInitializers.needsSecondPassForFixAll) {
                    makeChange(t, errorCode, sourceFile, checker, expression, fixedDeclarations);
                }
            });
            // No fix-all because it will already be included once with the use site fix,
            // and for simplicity the fix-all doesn't let the user choose between use-site and declaration-site fixes.
            return createCodeFixActionWithoutFixAll(
                "addMissingAwaitToInitializer",
                initializerChanges,
                awaitableInitializers.initializers.length === 1
                    ? [Diagnostics.Add_await_to_initializer_for_0, awaitableInitializers.initializers[0].declarationSymbol.name]
                    : Diagnostics.Add_await_to_initializers);
        }
    }

    /** Builds the "Add 'await'" action that inserts `await` at the error location itself. */
    function getUseSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction, fixedDeclarations?: Set<number>) {
        const changes = trackChanges(t => makeChange(t, errorCode, context.sourceFile, checker, expression, fixedDeclarations));
        return createCodeFixAction(fixId, changes, Diagnostics.Add_await, fixId, Diagnostics.Fix_all_expressions_possibly_missing_await);
    }

    /**
     * True when the checker reports, at exactly `span`, a diagnostic with code `errorCode` that
     * carries a "Did you forget to use 'await'?" related-information entry.
     */
    function isMissingAwaitError(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program) {
        const checker = program.getTypeChecker();
        const diagnostics = checker.getDiagnostics(sourceFile, cancellationToken);
        return some(diagnostics, ({ start, length, relatedInformation, code }) =>
            isNumber(start) && isNumber(length) && textSpansEqual({ start, length }, span) &&
            code === errorCode &&
            !!relatedInformation &&
            some(relatedInformation, related => related.code === Diagnostics.Did_you_forget_to_use_await.code));
    }

    interface AwaitableInitializer {
        /** The initializer expression that will receive `await`. */
        expression: Expression;
        /** Symbol of the variable being initialized (its name is shown in the action title). */
        declarationSymbol: Symbol;
    }

    interface AwaitableInitializers {
        initializers: readonly AwaitableInitializer[];
        /**
         * True when at least one identifier in the error expression could not be fixed at its
         * declaration, so fix-all must also handle the use site.
         */
        needsSecondPassForFixAll: boolean;
    }

    /**
     * Finds variable initializers that can safely receive `await` instead of the use site.
     *
     * An identifier's initializer qualifies only if its declaration meets all of these:
     * - it is a variable declaration inside a variable statement in this source file;
     * - it has an identifier name, no type annotation, and an initializer;
     * - it is not exported;
     * - its initializer is inside an awaitable body.
     *
     * In addition, every other reference to the variable in this file must also look like a
     * missing-await use. Returns `undefined` when no initializer qualifies.
     */
    function findAwaitableInitializers(
        expression: Node,
        sourceFile: SourceFile,
        cancellationToken: CancellationToken,
        program: Program,
        checker: TypeChecker,
    ): AwaitableInitializers | undefined {
        const identifiers = getIdentifiersFromErrorSpanExpression(expression, checker);
        if (!identifiers) {
            return;
        }

        let isCompleteFix = identifiers.isCompleteFix;
        let initializers: AwaitableInitializer[] | undefined;
        // The file's semantic diagnostics do not change across iterations: fetch them at most
        // once, and only if some identifier gets far enough to need them.
        let diagnostics: readonly Diagnostic[] | undefined;
        for (const identifier of identifiers.identifiers) {
            const symbol = checker.getSymbolAtLocation(identifier);
            if (!symbol) {
                continue;
            }

            const declaration = tryCast(symbol.valueDeclaration, isVariableDeclaration);
            const variableName = declaration && tryCast(declaration.name, isIdentifier);
            const variableStatement = getAncestor(declaration, SyntaxKind.VariableStatement);
            if (!declaration || !variableStatement ||
                declaration.type ||
                !declaration.initializer ||
                variableStatement.getSourceFile() !== sourceFile ||
                hasSyntacticModifier(variableStatement, ModifierFlags.Export) ||
                !variableName ||
                !isInsideAwaitableBody(declaration.initializer)) {
                isCompleteFix = false;
                continue;
            }

            // Bind to a const so the narrowed (non-undefined) type is visible inside the callback.
            const fileDiagnostics = diagnostics || (diagnostics = program.getSemanticDiagnostics(sourceFile, cancellationToken));
            const isUsedElsewhere = FindAllReferences.Core.eachSymbolReferenceInFile(variableName, checker, sourceFile, reference => {
                return identifier !== reference && !symbolReferenceIsAlsoMissingAwait(reference, fileDiagnostics, sourceFile, checker);
            });

            if (isUsedElsewhere) {
                isCompleteFix = false;
                continue;
            }

            (initializers || (initializers = [])).push({
                expression: declaration.initializer,
                declarationSymbol: symbol,
            });
        }
        return initializers && {
            initializers,
            needsSecondPassForFixAll: !isCompleteFix,
        };
    }

    interface Identifiers {
        identifiers: readonly Identifier[];
        /** False when some Promise-typed operand is not a plain identifier and so cannot be traced to a declaration. */
        isCompleteFix: boolean;
    }

    /**
     * Collects the identifiers whose declarations could be awaited for this error expression.
     * Handles three shapes, in this order:
     * - the object of a property access (`x` in `x.prop`);
     * - a bare identifier;
     * - the Promise-typed identifier operands of a binary expression.
     * Returns `undefined` for any other shape.
     */
    function getIdentifiersFromErrorSpanExpression(expression: Node, checker: TypeChecker): Identifiers | undefined {
        if (isPropertyAccessExpression(expression.parent) && isIdentifier(expression.parent.expression)) {
            return { identifiers: [expression.parent.expression], isCompleteFix: true };
        }
        if (isIdentifier(expression)) {
            return { identifiers: [expression], isCompleteFix: true };
        }
        if (isBinaryExpression(expression)) {
            let sides: Identifier[] | undefined;
            let isCompleteFix = true;
            for (const side of [expression.left, expression.right]) {
                const type = checker.getTypeAtLocation(side);
                if (checker.getPromisedTypeOfPromise(type)) {
                    if (!isIdentifier(side)) {
                        isCompleteFix = false;
                        continue;
                    }
                    (sides || (sides = [])).push(side);
                }
            }
            return sides && { identifiers: sides, isCompleteFix };
        }
        return undefined;
    }

    /**
     * True when `reference` (another use of a candidate variable) would itself be fixed by
     * awaiting the variable's initializer.
     * This holds if either:
     * - a diagnostic from `errorCodes` is reported exactly on the reference's error node; or
     * - the error node's type is `any`.
     */
    function symbolReferenceIsAlsoMissingAwait(reference: Identifier, diagnostics: readonly Diagnostic[], sourceFile: SourceFile, checker: TypeChecker): boolean {
        const errorNode = isPropertyAccessExpression(reference.parent) ? reference.parent.name :
            isBinaryExpression(reference.parent) ? reference.parent :
            reference;
        const diagnostic = find(diagnostics, diagnostic =>
            diagnostic.start === errorNode.getStart(sourceFile) &&
            (diagnostic.start + diagnostic.length!) === errorNode.getEnd());

        return !!(diagnostic && contains(errorCodes, diagnostic.code)) ||
            // A Promise is usually not correct in a binary expression (it's not valid
            // in an arithmetic expression and an equality comparison seems unusual),
            // but if the other side of the binary expression has an error, the side
            // is typed `any` which will squash the error that would identify this
            // Promise as an invalid operand. So if the whole binary expression is
            // typed `any` as a result, there is a strong likelihood that this Promise
            // is accidentally missing `await`.
            !!(checker.getTypeAtLocation(errorNode).flags & TypeFlags.Any);
    }

    /**
     * True when `await` may be inserted at `node`. That is the case if either:
     * - the parser marked the node with the `AwaitContext` flag; or
     * - the node is nested in the body of a function declaration, function expression,
     *   arrow function, or method.
     */
    function isInsideAwaitableBody(node: Node): boolean {
        // `AwaitContext` is a NodeFlags bit, so it must be tested against `node.flags`.
        // (`node.kind` is a SyntaxKind, not a flag set; masking it with a NodeFlags bit is meaningless.)
        return !!(node.flags & NodeFlags.AwaitContext) || !!findAncestor(node, ancestor =>
            ancestor.parent && isArrowFunction(ancestor.parent) && ancestor.parent.body === ancestor ||
            isBlock(ancestor) && (
                ancestor.parent.kind === SyntaxKind.FunctionDeclaration ||
                ancestor.parent.kind === SyntaxKind.FunctionExpression ||
                ancestor.parent.kind === SyntaxKind.ArrowFunction ||
                ancestor.parent.kind === SyntaxKind.MethodDeclaration));
    }

    /**
     * Records the edit that adds `await` at `insertionSite`.
     * When `fixedDeclarations` is given (fix-all mode), identifiers whose declarations were
     * already awaited are skipped, and newly awaited declarations are recorded there.
     */
    function makeChange(changeTracker: textChanges.ChangeTracker, errorCode: number, sourceFile: SourceFile, checker: TypeChecker, insertionSite: Expression, fixedDeclarations?: Set<number>) {
        // The site is directly inside a plain `for...of` and its type is assignable to the global
        // AsyncIterable type: turn the loop into `for await...of` instead.
        if (isForOfStatement(insertionSite.parent) && !insertionSite.parent.awaitModifier) {
            const exprType = checker.getTypeAtLocation(insertionSite);
            const asyncIter = checker.getAsyncIterableType();
            if (asyncIter && checker.isTypeAssignableTo(exprType, asyncIter)) {
                const forOf = insertionSite.parent;
                changeTracker.replaceNode(sourceFile, forOf, factory.updateForOfStatement(forOf, factory.createToken(SyntaxKind.AwaitKeyword), forOf.initializer, forOf.expression, forOf.statement));
                return;
            }
        }
        if (isBinaryExpression(insertionSite)) {
            // Await each Promise-typed operand; non-Promise operands are replaced by themselves.
            for (const side of [insertionSite.left, insertionSite.right]) {
                if (fixedDeclarations && isIdentifier(side)) {
                    const symbol = checker.getSymbolAtLocation(side);
                    if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
                        continue;
                    }
                }
                const type = checker.getTypeAtLocation(side);
                const newNode = checker.getPromisedTypeOfPromise(type) ? factory.createAwaitExpression(side) : side;
                changeTracker.replaceNode(sourceFile, side, newNode);
            }
        }
        else if (errorCode === propertyAccessCode && isPropertyAccessExpression(insertionSite.parent)) {
            // `x.prop` -> `(await x).prop`
            if (fixedDeclarations && isIdentifier(insertionSite.parent.expression)) {
                const symbol = checker.getSymbolAtLocation(insertionSite.parent.expression);
                if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
                    return;
                }
            }
            changeTracker.replaceNode(
                sourceFile,
                insertionSite.parent.expression,
                factory.createParenthesizedExpression(factory.createAwaitExpression(insertionSite.parent.expression)));
            insertLeadingSemicolonIfNeeded(changeTracker, insertionSite.parent.expression, sourceFile);
        }
        else if (contains(callableConstructableErrorCodes, errorCode) && isCallOrNewExpression(insertionSite.parent)) {
            // `x(...)` -> `(await x)(...)`
            if (fixedDeclarations && isIdentifier(insertionSite)) {
                const symbol = checker.getSymbolAtLocation(insertionSite);
                if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
                    return;
                }
            }
            changeTracker.replaceNode(sourceFile, insertionSite, factory.createParenthesizedExpression(factory.createAwaitExpression(insertionSite)));
            insertLeadingSemicolonIfNeeded(changeTracker, insertionSite, sourceFile);
        }
        else {
            // Plain `x` -> `await x`. When the site is a variable initializer in fix-all mode,
            // record the declaration; if it was already recorded, it was already awaited.
            if (fixedDeclarations && isVariableDeclaration(insertionSite.parent) && isIdentifier(insertionSite.parent.name)) {
                const symbol = checker.getSymbolAtLocation(insertionSite.parent.name);
                if (symbol && !tryAddToSet(fixedDeclarations, getSymbolId(symbol))) {
                    return;
                }
            }
            changeTracker.replaceNode(sourceFile, insertionSite, factory.createAwaitExpression(insertionSite));
        }
    }

    /**
     * Inserts `;` before `beforeNode` when the token preceding it ends at an ASI candidate
     * position.
     * This is needed because a statement that now starts with `(` could otherwise be parsed
     * as a continuation (e.g. a call) of the previous line.
     */
    function insertLeadingSemicolonIfNeeded(changeTracker: textChanges.ChangeTracker, beforeNode: Node, sourceFile: SourceFile) {
        const precedingToken = findPrecedingToken(beforeNode.pos, sourceFile);
        if (precedingToken && positionIsASICandidate(precedingToken.end, precedingToken.parent, sourceFile)) {
            changeTracker.insertText(sourceFile, beforeNode.getStart(sourceFile), ";");
        }
    }
}
292