• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* @internal */
2namespace ts.codefix {
3    type ContextualTrackChangesFunction = (cb: (changeTracker: textChanges.ChangeTracker) => void) => FileTextChanges[];
4    const fixId = "addMissingAwait";
5    const propertyAccessCode = Diagnostics.Property_0_does_not_exist_on_type_1.code;
6    const callableConstructableErrorCodes = [
7        Diagnostics.This_expression_is_not_callable.code,
8        Diagnostics.This_expression_is_not_constructable.code,
9    ];
10    const errorCodes = [
11        Diagnostics.An_arithmetic_operand_must_be_of_type_any_number_bigint_or_an_enum_type.code,
12        Diagnostics.The_left_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
13        Diagnostics.The_right_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
14        Diagnostics.Operator_0_cannot_be_applied_to_type_1.code,
15        Diagnostics.Operator_0_cannot_be_applied_to_types_1_and_2.code,
16        Diagnostics.This_condition_will_always_return_0_since_the_types_1_and_2_have_no_overlap.code,
17        Diagnostics.Type_0_is_not_an_array_type.code,
18        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type.code,
19        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_Use_compiler_option_downlevelIteration_to_allow_iterating_of_iterators.code,
20        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
21        Diagnostics.Type_0_is_not_an_array_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
22        Diagnostics.Type_0_must_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
23        Diagnostics.Type_0_must_have_a_Symbol_asyncIterator_method_that_returns_an_async_iterator.code,
24        Diagnostics.Argument_of_type_0_is_not_assignable_to_parameter_of_type_1.code,
25        propertyAccessCode,
26        ...callableConstructableErrorCodes,
27    ];
28
29    registerCodeFix({
30        fixIds: [fixId],
31        errorCodes,
32        getCodeActions: context => {
33            const { sourceFile, errorCode, span, cancellationToken, program } = context;
34            const expression = getFixableErrorSpanExpression(sourceFile, errorCode, span, cancellationToken, program);
35            if (!expression) {
36                return;
37            }
38
39            const checker = context.program.getTypeChecker();
40            const trackChanges: ContextualTrackChangesFunction = cb => textChanges.ChangeTracker.with(context, cb);
41            return compact([
42                getDeclarationSiteFix(context, expression, errorCode, checker, trackChanges),
43                getUseSiteFix(context, expression, errorCode, checker, trackChanges)]);
44        },
45        getAllCodeActions: context => {
46            const { sourceFile, program, cancellationToken } = context;
47            const checker = context.program.getTypeChecker();
48            const fixedDeclarations = new Set<number>();
49            return codeFixAll(context, errorCodes, (t, diagnostic) => {
50                const expression = getFixableErrorSpanExpression(sourceFile, diagnostic.code, diagnostic, cancellationToken, program);
51                if (!expression) {
52                    return;
53                }
54                const trackChanges: ContextualTrackChangesFunction = cb => (cb(t), []);
55                return getDeclarationSiteFix(context, expression, diagnostic.code, checker, trackChanges, fixedDeclarations)
56                    || getUseSiteFix(context, expression, diagnostic.code, checker, trackChanges, fixedDeclarations);
57            });
58        },
59    });
60
61    function getDeclarationSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction, fixedDeclarations?: Set<number>) {
62        const { sourceFile, program, cancellationToken } = context;
63        const awaitableInitializers = findAwaitableInitializers(expression, sourceFile, cancellationToken, program, checker);
64        if (awaitableInitializers) {
65            const initializerChanges = trackChanges(t => {
66                forEach(awaitableInitializers.initializers, ({ expression }) => makeChange(t, errorCode, sourceFile, checker, expression, fixedDeclarations));
67                if (fixedDeclarations && awaitableInitializers.needsSecondPassForFixAll) {
68                    makeChange(t, errorCode, sourceFile, checker, expression, fixedDeclarations);
69                }
70            });
71            // No fix-all because it will already be included once with the use site fix,
72            // and for simplicity the fix-all doesn‘t let the user choose between use-site and declaration-site fixes.
73            return createCodeFixActionWithoutFixAll(
74                "addMissingAwaitToInitializer",
75                initializerChanges,
76                awaitableInitializers.initializers.length === 1
77                    ? [Diagnostics.Add_await_to_initializer_for_0, awaitableInitializers.initializers[0].declarationSymbol.name]
78                    : Diagnostics.Add_await_to_initializers);
79        }
80    }
81
82    function getUseSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction, fixedDeclarations?: Set<number>) {
83        const changes = trackChanges(t => makeChange(t, errorCode, context.sourceFile, checker, expression, fixedDeclarations));
84        return createCodeFixAction(fixId, changes, Diagnostics.Add_await, fixId, Diagnostics.Fix_all_expressions_possibly_missing_await);
85    }
86
87    function isMissingAwaitError(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program) {
88        const checker = program.getDiagnosticsProducingTypeChecker();
89        const diagnostics = checker.getDiagnostics(sourceFile, cancellationToken);
90        return some(diagnostics, ({ start, length, relatedInformation, code }) =>
91            isNumber(start) && isNumber(length) && textSpansEqual({ start, length }, span) &&
92            code === errorCode &&
93            !!relatedInformation &&
94            some(relatedInformation, related => related.code === Diagnostics.Did_you_forget_to_use_await.code));
95    }
96
97    function getFixableErrorSpanExpression(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program): Expression | undefined {
98        const token = getTokenAtPosition(sourceFile, span.start);
99        // Checker has already done work to determine that await might be possible, and has attached
100        // related info to the node, so start by finding the expression that exactly matches up
101        // with the diagnostic range.
102        const expression = findAncestor(token, node => {
103            if (node.getStart(sourceFile) < span.start || node.getEnd() > textSpanEnd(span)) {
104                return "quit";
105            }
106            return isExpression(node) && textSpansEqual(span, createTextSpanFromNode(node, sourceFile));
107        }) as Expression | undefined;
108
109        return expression
110            && isMissingAwaitError(sourceFile, errorCode, span, cancellationToken, program)
111            && isInsideAwaitableBody(expression) ? expression : undefined;
112    }
113
114    interface AwaitableInitializer {
115        expression: Expression;
116        declarationSymbol: Symbol;
117    }
118
119    interface AwaitableInitializers {
120        initializers: readonly AwaitableInitializer[];
121        needsSecondPassForFixAll: boolean;
122    }
123
124    function findAwaitableInitializers(
125        expression: Node,
126        sourceFile: SourceFile,
127        cancellationToken: CancellationToken,
128        program: Program,
129        checker: TypeChecker,
130    ): AwaitableInitializers | undefined {
131        const identifiers = getIdentifiersFromErrorSpanExpression(expression, checker);
132        if (!identifiers) {
133            return;
134        }
135
136        let isCompleteFix = identifiers.isCompleteFix;
137        let initializers: AwaitableInitializer[] | undefined;
138        for (const identifier of identifiers.identifiers) {
139            const symbol = checker.getSymbolAtLocation(identifier);
140            if (!symbol) {
141                continue;
142            }
143
144            const declaration = tryCast(symbol.valueDeclaration, isVariableDeclaration);
145            const variableName = declaration && tryCast(declaration.name, isIdentifier);
146            const variableStatement = getAncestor(declaration, SyntaxKind.VariableStatement);
147            if (!declaration || !variableStatement ||
148                declaration.type ||
149                !declaration.initializer ||
150                variableStatement.getSourceFile() !== sourceFile ||
151                hasSyntacticModifier(variableStatement, ModifierFlags.Export) ||
152                !variableName ||
153                !isInsideAwaitableBody(declaration.initializer)) {
154                isCompleteFix = false;
155                continue;
156            }
157
158            const diagnostics = program.getSemanticDiagnostics(sourceFile, cancellationToken);
159            const isUsedElsewhere = FindAllReferences.Core.eachSymbolReferenceInFile(variableName, checker, sourceFile, reference => {
160                return identifier !== reference && !symbolReferenceIsAlsoMissingAwait(reference, diagnostics, sourceFile, checker);
161            });
162
163            if (isUsedElsewhere) {
164                isCompleteFix = false;
165                continue;
166            }
167
168            (initializers || (initializers = [])).push({
169                expression: declaration.initializer,
170                declarationSymbol: symbol,
171            });
172        }
173        return initializers && {
174            initializers,
175            needsSecondPassForFixAll: !isCompleteFix,
176        };
177    }
178
179    interface Identifiers {
180        identifiers: readonly Identifier[];
181        isCompleteFix: boolean;
182    }
183
184    function getIdentifiersFromErrorSpanExpression(expression: Node, checker: TypeChecker): Identifiers | undefined {
185        if (isPropertyAccessExpression(expression.parent) && isIdentifier(expression.parent.expression)) {
186            return { identifiers: [expression.parent.expression], isCompleteFix: true };
187        }
188        if (isIdentifier(expression)) {
189            return { identifiers: [expression], isCompleteFix: true };
190        }
191        if (isBinaryExpression(expression)) {
192            let sides: Identifier[] | undefined;
193            let isCompleteFix = true;
194            for (const side of [expression.left, expression.right]) {
195                const type = checker.getTypeAtLocation(side);
196                if (checker.getPromisedTypeOfPromise(type)) {
197                    if (!isIdentifier(side)) {
198                        isCompleteFix = false;
199                        continue;
200                    }
201                    (sides || (sides = [])).push(side);
202                }
203            }
204            return sides && { identifiers: sides, isCompleteFix };
205        }
206    }
207
208    function symbolReferenceIsAlsoMissingAwait(reference: Identifier, diagnostics: readonly Diagnostic[], sourceFile: SourceFile, checker: TypeChecker) {
209        const errorNode = isPropertyAccessExpression(reference.parent) ? reference.parent.name :
210            isBinaryExpression(reference.parent) ? reference.parent :
211            reference;
212        const diagnostic = find(diagnostics, diagnostic =>
213            diagnostic.start === errorNode.getStart(sourceFile) &&
214            (diagnostic.start + diagnostic.length!) === errorNode.getEnd());
215
216        return diagnostic && contains(errorCodes, diagnostic.code) ||
217            // A Promise is usually not correct in a binary expression (it’s not valid
218            // in an arithmetic expression and an equality comparison seems unusual),
219            // but if the other side of the binary expression has an error, the side
220            // is typed `any` which will squash the error that would identify this
221            // Promise as an invalid operand. So if the whole binary expression is
222            // typed `any` as a result, there is a strong likelihood that this Promise
223            // is accidentally missing `await`.
224            checker.getTypeAtLocation(errorNode).flags & TypeFlags.Any;
225    }
226
227    function isInsideAwaitableBody(node: Node) {
228        return node.kind & NodeFlags.AwaitContext || !!findAncestor(node, ancestor =>
229            ancestor.parent && isArrowFunction(ancestor.parent) && ancestor.parent.body === ancestor ||
230            isBlock(ancestor) && (
231                ancestor.parent.kind === SyntaxKind.FunctionDeclaration ||
232                ancestor.parent.kind === SyntaxKind.FunctionExpression ||
233                ancestor.parent.kind === SyntaxKind.ArrowFunction ||
234                ancestor.parent.kind === SyntaxKind.MethodDeclaration));
235    }
236
237    function makeChange(changeTracker: textChanges.ChangeTracker, errorCode: number, sourceFile: SourceFile, checker: TypeChecker, insertionSite: Expression, fixedDeclarations?: Set<number>) {
238        if (isBinaryExpression(insertionSite)) {
239            for (const side of [insertionSite.left, insertionSite.right]) {
240                if (fixedDeclarations && isIdentifier(side)) {
241                    const symbol = checker.getSymbolAtLocation(side);
242                    if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
243                        continue;
244                    }
245                }
246                const type = checker.getTypeAtLocation(side);
247                const newNode = checker.getPromisedTypeOfPromise(type) ? factory.createAwaitExpression(side) : side;
248                changeTracker.replaceNode(sourceFile, side, newNode);
249            }
250        }
251        else if (errorCode === propertyAccessCode && isPropertyAccessExpression(insertionSite.parent)) {
252            if (fixedDeclarations && isIdentifier(insertionSite.parent.expression)) {
253                const symbol = checker.getSymbolAtLocation(insertionSite.parent.expression);
254                if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
255                    return;
256                }
257            }
258            changeTracker.replaceNode(
259                sourceFile,
260                insertionSite.parent.expression,
261                factory.createParenthesizedExpression(factory.createAwaitExpression(insertionSite.parent.expression)));
262            insertLeadingSemicolonIfNeeded(changeTracker, insertionSite.parent.expression, sourceFile);
263        }
264        else if (contains(callableConstructableErrorCodes, errorCode) && isCallOrNewExpression(insertionSite.parent)) {
265            if (fixedDeclarations && isIdentifier(insertionSite)) {
266                const symbol = checker.getSymbolAtLocation(insertionSite);
267                if (symbol && fixedDeclarations.has(getSymbolId(symbol))) {
268                    return;
269                }
270            }
271            changeTracker.replaceNode(sourceFile, insertionSite, factory.createParenthesizedExpression(factory.createAwaitExpression(insertionSite)));
272            insertLeadingSemicolonIfNeeded(changeTracker, insertionSite, sourceFile);
273        }
274        else {
275            if (fixedDeclarations && isVariableDeclaration(insertionSite.parent) && isIdentifier(insertionSite.parent.name)) {
276                const symbol = checker.getSymbolAtLocation(insertionSite.parent.name);
277                if (symbol && !tryAddToSet(fixedDeclarations, getSymbolId(symbol))) {
278                    return;
279                }
280            }
281            changeTracker.replaceNode(sourceFile, insertionSite, factory.createAwaitExpression(insertionSite));
282        }
283    }
284
285    function insertLeadingSemicolonIfNeeded(changeTracker: textChanges.ChangeTracker, beforeNode: Node, sourceFile: SourceFile) {
286        const precedingToken = findPrecedingToken(beforeNode.pos, sourceFile);
287        if (precedingToken && positionIsASICandidate(precedingToken.end, precedingToken.parent, sourceFile)) {
288            changeTracker.insertText(sourceFile, beforeNode.getStart(sourceFile), ";");
289        }
290    }
291}
292