• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* @internal */
2namespace ts.codefix {
3    const fixId = "convertToAsyncFunction";
4    const errorCodes = [Diagnostics.This_may_be_converted_to_an_async_function.code];
5    let codeActionSucceeded = true;
6    registerCodeFix({
7        errorCodes,
8        getCodeActions(context: CodeFixContext) {
9            codeActionSucceeded = true;
10            const changes = textChanges.ChangeTracker.with(context, (t) => convertToAsyncFunction(t, context.sourceFile, context.span.start, context.program.getTypeChecker()));
11            return codeActionSucceeded ? [createCodeFixAction(fixId, changes, Diagnostics.Convert_to_async_function, fixId, Diagnostics.Convert_all_to_async_functions)] : [];
12        },
13        fixIds: [fixId],
14        getAllCodeActions: context => codeFixAll(context, errorCodes, (changes, err) => convertToAsyncFunction(changes, err.file, err.start, context.program.getTypeChecker())),
15    });
16
17    const enum SynthBindingNameKind {
18        Identifier,
19        BindingPattern,
20    }
21
22    type SynthBindingName = SynthBindingPattern | SynthIdentifier;
23
24    interface SynthBindingPattern {
25        readonly kind: SynthBindingNameKind.BindingPattern;
26        readonly elements: readonly SynthBindingName[];
27        readonly bindingPattern: BindingPattern;
28        readonly types: Type[];
29    }
30
31    interface SynthIdentifier {
32        readonly kind: SynthBindingNameKind.Identifier;
33        readonly identifier: Identifier;
34        readonly types: Type[];
35        /** A declaration for this identifier has already been generated */
36        hasBeenDeclared: boolean;
37        hasBeenReferenced: boolean;
38    }
39
40    interface Transformer {
41        readonly checker: TypeChecker;
42        readonly synthNamesMap: ESMap<string, SynthIdentifier>; // keys are the symbol id of the identifier
43        readonly setOfExpressionsToReturn: ReadonlySet<number>; // keys are the node ids of the expressions
44        readonly isInJSFile: boolean;
45    }
46
47    interface PromiseReturningCallExpression<Name extends string> extends CallExpression {
48        readonly expression: PropertyAccessExpression & {
49            readonly escapedText: Name;
50        };
51    }
52
53    function convertToAsyncFunction(changes: textChanges.ChangeTracker, sourceFile: SourceFile, position: number, checker: TypeChecker): void {
54        // get the function declaration - returns a promise
55        const tokenAtPosition = getTokenAtPosition(sourceFile, position);
56        let functionToConvert: FunctionLikeDeclaration | undefined;
57
58        // if the parent of a FunctionLikeDeclaration is a variable declaration, the convertToAsync diagnostic will be reported on the variable name
59        if (isIdentifier(tokenAtPosition) && isVariableDeclaration(tokenAtPosition.parent) &&
60            tokenAtPosition.parent.initializer && isFunctionLikeDeclaration(tokenAtPosition.parent.initializer)) {
61            functionToConvert = tokenAtPosition.parent.initializer;
62        }
63        else {
64            functionToConvert = tryCast(getContainingFunction(getTokenAtPosition(sourceFile, position)), canBeConvertedToAsync);
65        }
66
67        if (!functionToConvert) {
68            return;
69        }
70
71        const synthNamesMap = new Map<string, SynthIdentifier>();
72        const isInJavascript = isInJSFile(functionToConvert);
73        const setOfExpressionsToReturn = getAllPromiseExpressionsToReturn(functionToConvert, checker);
74        const functionToConvertRenamed = renameCollidingVarNames(functionToConvert, checker, synthNamesMap);
75        if (!returnsPromise(functionToConvertRenamed, checker)) {
76            return;
77        }
78
79        const returnStatements = functionToConvertRenamed.body && isBlock(functionToConvertRenamed.body) ? getReturnStatementsWithPromiseHandlers(functionToConvertRenamed.body, checker) : emptyArray;
80        const transformer: Transformer = { checker, synthNamesMap, setOfExpressionsToReturn, isInJSFile: isInJavascript };
81        if (!returnStatements.length) {
82            return;
83        }
84
85        const pos = skipTrivia(sourceFile.text, moveRangePastModifiers(functionToConvert).pos);
86        changes.insertModifierAt(sourceFile, pos, SyntaxKind.AsyncKeyword, { suffix: " " });
87
88        for (const returnStatement of returnStatements) {
89            forEachChild(returnStatement, function visit(node) {
90                if (isCallExpression(node)) {
91                    const newNodes = transformExpression(node, node, transformer, /*hasContinuation*/ false);
92                    if (hasFailed()) {
93                        return true; // return something truthy to shortcut out of more work
94                    }
95                    changes.replaceNodeWithNodes(sourceFile, returnStatement, newNodes);
96                }
97                else if (!isFunctionLike(node)) {
98                    forEachChild(node, visit);
99                    if (hasFailed()) {
100                        return true; // return something truthy to shortcut out of more work
101                    }
102                }
103            });
104            if (hasFailed()) {
105                return; // shortcut out of more work
106            }
107        }
108    }
109
110    function getReturnStatementsWithPromiseHandlers(body: Block, checker: TypeChecker): readonly ReturnStatement[] {
111        const res: ReturnStatement[] = [];
112        forEachReturnStatement(body, ret => {
113            if (isReturnStatementWithFixablePromiseHandler(ret, checker)) res.push(ret);
114        });
115        return res;
116    }
117
118    /*
119        Finds all of the expressions of promise type that should not be saved in a variable during the refactor
120    */
121    function getAllPromiseExpressionsToReturn(func: FunctionLikeDeclaration, checker: TypeChecker): Set<number> {
122        if (!func.body) {
123            return new Set();
124        }
125
126        const setOfExpressionsToReturn = new Set<number>();
127        forEachChild(func.body, function visit(node: Node) {
128            if (isPromiseReturningCallExpression(node, checker, "then")) {
129                setOfExpressionsToReturn.add(getNodeId(node));
130                forEach(node.arguments, visit);
131            }
132            else if (isPromiseReturningCallExpression(node, checker, "catch") ||
133                isPromiseReturningCallExpression(node, checker, "finally")) {
134                setOfExpressionsToReturn.add(getNodeId(node));
135                // if .catch() or .finally() is the last call in the chain, move leftward in the chain until we hit something else that should be returned
136                forEachChild(node, visit);
137            }
138            else if (isPromiseTypedExpression(node, checker)) {
139                setOfExpressionsToReturn.add(getNodeId(node));
140                // don't recurse here, since we won't refactor any children or arguments of the expression
141            }
142            else {
143                forEachChild(node, visit);
144            }
145        });
146
147        return setOfExpressionsToReturn;
148    }
149
150    function isPromiseReturningCallExpression<Name extends string>(node: Node, checker: TypeChecker, name: Name): node is PromiseReturningCallExpression<Name> {
151        if (!isCallExpression(node)) return false;
152        const isExpressionOfName = hasPropertyAccessExpressionWithName(node, name);
153        const nodeType = isExpressionOfName && checker.getTypeAtLocation(node);
154        return !!(nodeType && checker.getPromisedTypeOfPromise(nodeType));
155    }
156
157    // NOTE: this is a mostly copy of `isReferenceToType` from checker.ts. While this violates DRY, it keeps
158    // `isReferenceToType` in checker local to the checker to avoid the cost of a property lookup on `ts`.
159    function isReferenceToType(type: Type, target: Type) {
160        return (getObjectFlags(type) & ObjectFlags.Reference) !== 0
161            && (type as TypeReference).target === target;
162    }
163
164    function getExplicitPromisedTypeOfPromiseReturningCallExpression(node: PromiseReturningCallExpression<"then" | "catch" | "finally">, callback: Expression, checker: TypeChecker) {
165        if (node.expression.name.escapedText === "finally") {
166            // for a `finally`, there's no type argument
167            return undefined;
168        }
169
170        // If the call to `then` or `catch` comes from the global `Promise` or `PromiseLike` type, we can safely use the
171        // type argument supplied for the callback. For other promise types we would need a more complex heuristic to determine
172        // which type argument is safe to use as an annotation.
173        const promiseType = checker.getTypeAtLocation(node.expression.expression);
174        if (isReferenceToType(promiseType, checker.getPromiseType()) ||
175            isReferenceToType(promiseType, checker.getPromiseLikeType())) {
176            if (node.expression.name.escapedText === "then") {
177                if (callback === elementAt(node.arguments, 0)) {
178                    // for the `onfulfilled` callback, use the first type argument
179                    return elementAt(node.typeArguments, 0);
180                }
181                else if (callback === elementAt(node.arguments, 1)) {
182                    // for the `onrejected` callback, use the second type argument
183                    return elementAt(node.typeArguments, 1);
184                }
185            }
186            else {
187                return elementAt(node.typeArguments, 0);
188            }
189        }
190    }
191
192    function isPromiseTypedExpression(node: Node, checker: TypeChecker): node is Expression {
193        if (!isExpression(node)) return false;
194        return !!checker.getPromisedTypeOfPromise(checker.getTypeAtLocation(node));
195    }
196
197    /*
198        Renaming of identifiers may be necessary as the refactor changes scopes -
199        This function collects all existing identifier names and names of identifiers that will be created in the refactor.
200        It then checks for any collisions and renames them through getSynthesizedDeepClone
201    */
202    function renameCollidingVarNames(nodeToRename: FunctionLikeDeclaration, checker: TypeChecker, synthNamesMap: ESMap<string, SynthIdentifier>): FunctionLikeDeclaration {
203        const identsToRenameMap = new Map<string, Identifier>(); // key is the symbol id
204        const collidingSymbolMap = createMultiMap<Symbol>();
205        forEachChild(nodeToRename, function visit(node: Node) {
206            if (!isIdentifier(node)) {
207                forEachChild(node, visit);
208                return;
209            }
210            const symbol = checker.getSymbolAtLocation(node);
211            if (symbol) {
212                const type = checker.getTypeAtLocation(node);
213                // Note - the choice of the last call signature is arbitrary
214                const lastCallSignature = getLastCallSignature(type, checker);
215                const symbolIdString = getSymbolId(symbol).toString();
216
217                // If the identifier refers to a function, we want to add the new synthesized variable for the declaration. Example:
218                //   fetch('...').then(response => { ... })
219                // will eventually become
220                //   const response = await fetch('...')
221                // so we push an entry for 'response'.
222                if (lastCallSignature && !isParameter(node.parent) && !isFunctionLikeDeclaration(node.parent) && !synthNamesMap.has(symbolIdString)) {
223                    const firstParameter = firstOrUndefined(lastCallSignature.parameters);
224                    const ident = firstParameter?.valueDeclaration
225                        && isParameter(firstParameter.valueDeclaration)
226                        && tryCast(firstParameter.valueDeclaration.name, isIdentifier)
227                        || factory.createUniqueName("result", GeneratedIdentifierFlags.Optimistic);
228                    const synthName = getNewNameIfConflict(ident, collidingSymbolMap);
229                    synthNamesMap.set(symbolIdString, synthName);
230                    collidingSymbolMap.add(ident.text, symbol);
231                }
232                // We only care about identifiers that are parameters, variable declarations, or binding elements
233                else if (node.parent && (isParameter(node.parent) || isVariableDeclaration(node.parent) || isBindingElement(node.parent))) {
234                    const originalName = node.text;
235                    const collidingSymbols = collidingSymbolMap.get(originalName);
236
237                    // if the identifier name conflicts with a different identifier that we've already seen
238                    if (collidingSymbols && collidingSymbols.some(prevSymbol => prevSymbol !== symbol)) {
239                        const newName = getNewNameIfConflict(node, collidingSymbolMap);
240                        identsToRenameMap.set(symbolIdString, newName.identifier);
241                        synthNamesMap.set(symbolIdString, newName);
242                        collidingSymbolMap.add(originalName, symbol);
243                    }
244                    else {
245                        const identifier = getSynthesizedDeepClone(node);
246                        synthNamesMap.set(symbolIdString, createSynthIdentifier(identifier));
247                        collidingSymbolMap.add(originalName, symbol);
248                    }
249                }
250            }
251        });
252
253        return getSynthesizedDeepCloneWithReplacements(nodeToRename, /*includeTrivia*/ true, original => {
254            if (isBindingElement(original) && isIdentifier(original.name) && isObjectBindingPattern(original.parent)) {
255                const symbol = checker.getSymbolAtLocation(original.name);
256                const renameInfo = symbol && identsToRenameMap.get(String(getSymbolId(symbol)));
257                if (renameInfo && renameInfo.text !== (original.name || original.propertyName).getText()) {
258                    return factory.createBindingElement(
259                        original.dotDotDotToken,
260                        original.propertyName || original.name,
261                        renameInfo,
262                        original.initializer);
263                }
264            }
265            else if (isIdentifier(original)) {
266                const symbol = checker.getSymbolAtLocation(original);
267                const renameInfo = symbol && identsToRenameMap.get(String(getSymbolId(symbol)));
268                if (renameInfo) {
269                    return factory.createIdentifier(renameInfo.text);
270                }
271            }
272        });
273    }
274
275    function getNewNameIfConflict(name: Identifier, originalNames: ReadonlyESMap<string, Symbol[]>): SynthIdentifier {
276        const numVarsSameName = (originalNames.get(name.text) || emptyArray).length;
277        const identifier = numVarsSameName === 0 ? name : factory.createIdentifier(name.text + "_" + numVarsSameName);
278        return createSynthIdentifier(identifier);
279    }
280
281    function hasFailed() {
282        return !codeActionSucceeded;
283    }
284
285    function silentFail() {
286        codeActionSucceeded = false;
287        return emptyArray;
288    }
289
290    // dispatch function to recursively build the refactoring
291    // should be kept up to date with isFixablePromiseHandler in suggestionDiagnostics.ts
292    /**
293     * @param hasContinuation Whether another `then`, `catch`, or `finally` continuation follows the continuation to which this expression belongs.
294     * @param continuationArgName The argument name for the continuation that follows this call.
295     */
296    function transformExpression(returnContextNode: Expression, node: Expression, transformer: Transformer, hasContinuation: boolean, continuationArgName?: SynthBindingName): readonly Statement[] {
297        if (isPromiseReturningCallExpression(node, transformer.checker, "then")) {
298            return transformThen(node, elementAt(node.arguments, 0), elementAt(node.arguments, 1), transformer, hasContinuation, continuationArgName);
299        }
300        if (isPromiseReturningCallExpression(node, transformer.checker, "catch")) {
301            return transformCatch(node, elementAt(node.arguments, 0), transformer, hasContinuation, continuationArgName);
302        }
303        if (isPromiseReturningCallExpression(node, transformer.checker, "finally")) {
304            return transformFinally(node, elementAt(node.arguments, 0), transformer, hasContinuation, continuationArgName);
305        }
306        if (isPropertyAccessExpression(node)) {
307            return transformExpression(returnContextNode, node.expression, transformer, hasContinuation, continuationArgName);
308        }
309
310        const nodeType = transformer.checker.getTypeAtLocation(node);
311        if (nodeType && transformer.checker.getPromisedTypeOfPromise(nodeType)) {
312            Debug.assertNode(getOriginalNode(node).parent, isPropertyAccessExpression);
313            return transformPromiseExpressionOfPropertyAccess(returnContextNode, node, transformer, hasContinuation, continuationArgName);
314        }
315
316        return silentFail();
317    }
318
319    function isNullOrUndefined({ checker }: Transformer, node: Expression) {
320        if (node.kind === SyntaxKind.NullKeyword) return true;
321        if (isIdentifier(node) && !isGeneratedIdentifier(node) && idText(node) === "undefined") {
322            const symbol = checker.getSymbolAtLocation(node);
323            return !symbol || checker.isUndefinedSymbol(symbol);
324        }
325        return false;
326    }
327
328    function createUniqueSynthName(prevArgName: SynthIdentifier): SynthIdentifier {
329        const renamedPrevArg = factory.createUniqueName(prevArgName.identifier.text, GeneratedIdentifierFlags.Optimistic);
330        return createSynthIdentifier(renamedPrevArg);
331    }
332
333    function getPossibleNameForVarDecl(node: PromiseReturningCallExpression<"then" | "catch" | "finally">, transformer: Transformer, continuationArgName?: SynthBindingName) {
334        let possibleNameForVarDecl: SynthIdentifier | undefined;
335
336        // If there is another call in the chain after the .catch() or .finally() we are transforming, we will need to save the result of both paths
337        // (try block and catch/finally block). To do this, we will need to synthesize a variable that we were not aware of while we were adding
338        // identifiers to the synthNamesMap. We will use the continuationArgName and then update the synthNamesMap with a new variable name for
339        // the next transformation step
340
341        if (continuationArgName && !shouldReturn(node, transformer)) {
342            if (isSynthIdentifier(continuationArgName)) {
343                possibleNameForVarDecl = continuationArgName;
344                transformer.synthNamesMap.forEach((val, key) => {
345                    if (val.identifier.text === continuationArgName.identifier.text) {
346                        const newSynthName = createUniqueSynthName(continuationArgName);
347                        transformer.synthNamesMap.set(key, newSynthName);
348                    }
349                });
350            }
351            else {
352                possibleNameForVarDecl = createSynthIdentifier(factory.createUniqueName("result", GeneratedIdentifierFlags.Optimistic), continuationArgName.types);
353            }
354
355            // We are about to write a 'let' variable declaration, but `transformExpression` for both
356            // the try block and catch/finally block will assign to this name. Setting this flag indicates
357            // that future assignments should be written as `name = value` instead of `const name = value`.
358            declareSynthIdentifier(possibleNameForVarDecl);
359        }
360
361        return possibleNameForVarDecl;
362    }
363
364    function finishCatchOrFinallyTransform(node: PromiseReturningCallExpression<"then" | "catch" | "finally">, transformer: Transformer, tryStatement: TryStatement, possibleNameForVarDecl: SynthIdentifier | undefined, continuationArgName?: SynthBindingName) {
365        const statements: Statement[] = [];
366
367        // In order to avoid an implicit any, we will synthesize a type for the declaration using the unions of the types of both paths (try block and catch block)
368        let varDeclIdentifier: Identifier | undefined;
369
370        if (possibleNameForVarDecl && !shouldReturn(node, transformer)) {
371            varDeclIdentifier = getSynthesizedDeepClone(declareSynthIdentifier(possibleNameForVarDecl));
372            const typeArray: Type[] = possibleNameForVarDecl.types;
373            const unionType = transformer.checker.getUnionType(typeArray, UnionReduction.Subtype);
374            const unionTypeNode = transformer.isInJSFile ? undefined : transformer.checker.typeToTypeNode(unionType, /*enclosingDeclaration*/ undefined, /*flags*/ undefined);
375            const varDecl = [factory.createVariableDeclaration(varDeclIdentifier, /*exclamationToken*/ undefined, unionTypeNode)];
376            const varDeclList = factory.createVariableStatement(/*modifiers*/ undefined, factory.createVariableDeclarationList(varDecl, NodeFlags.Let));
377            statements.push(varDeclList);
378        }
379
380        statements.push(tryStatement);
381
382        if (continuationArgName && varDeclIdentifier && isSynthBindingPattern(continuationArgName)) {
383            statements.push(factory.createVariableStatement(
384                /*modifiers*/ undefined,
385                factory.createVariableDeclarationList([
386                    factory.createVariableDeclaration(
387                        getSynthesizedDeepClone(declareSynthBindingPattern(continuationArgName)),
388                        /*exclamationToken*/ undefined,
389                        /*type*/ undefined,
390                        varDeclIdentifier
391                    )],
392                    NodeFlags.Const)));
393        }
394
395        return statements;
396    }
397
398    /**
399     * @param hasContinuation Whether another `then`, `catch`, or `finally` continuation follows this continuation.
400     * @param continuationArgName The argument name for the continuation that follows this call.
401     */
402    function transformFinally(node: PromiseReturningCallExpression<"finally">, onFinally: Expression | undefined, transformer: Transformer, hasContinuation: boolean, continuationArgName?: SynthBindingName): readonly Statement[] {
403        if (!onFinally || isNullOrUndefined(transformer, onFinally)) {
404            // Ignore this call as it has no effect on the result
405            return transformExpression(/* returnContextNode */ node, node.expression.expression, transformer, hasContinuation, continuationArgName);
406        }
407
408        const possibleNameForVarDecl = getPossibleNameForVarDecl(node, transformer, continuationArgName);
409
410        // Transform the left-hand-side of `.finally` into an array of inlined statements. We pass `true` for hasContinuation as `node` is the outer continuation.
411        const inlinedLeftHandSide = transformExpression(/*returnContextNode*/ node, node.expression.expression, transformer, /*hasContinuation*/ true, possibleNameForVarDecl);
412        if (hasFailed()) return silentFail(); // shortcut out of more work
413
414        // Transform the callback argument into an array of inlined statements. We pass whether we have an outer continuation here
415        // as that indicates whether `return` is valid.
416        const inlinedCallback = transformCallbackArgument(onFinally, hasContinuation, /*continuationArgName*/ undefined, /*argName*/ undefined, node, transformer);
417        if (hasFailed()) return silentFail(); // shortcut out of more work
418
419        const tryBlock = factory.createBlock(inlinedLeftHandSide);
420        const finallyBlock = factory.createBlock(inlinedCallback);
421        const tryStatement = factory.createTryStatement(tryBlock, /*catchClause*/ undefined, finallyBlock);
422        return finishCatchOrFinallyTransform(node, transformer, tryStatement, possibleNameForVarDecl, continuationArgName);
423    }
424
425    /**
426     * @param hasContinuation Whether another `then`, `catch`, or `finally` continuation follows this continuation.
427     * @param continuationArgName The argument name for the continuation that follows this call.
428     */
429    function transformCatch(node: PromiseReturningCallExpression<"then" | "catch">, onRejected: Expression | undefined, transformer: Transformer, hasContinuation: boolean, continuationArgName?: SynthBindingName): readonly Statement[] {
430        if (!onRejected || isNullOrUndefined(transformer, onRejected)) {
431            // Ignore this call as it has no effect on the result
432            return transformExpression(/* returnContextNode */ node, node.expression.expression, transformer, hasContinuation, continuationArgName);
433        }
434
435        const inputArgName = getArgBindingName(onRejected, transformer);
436        const possibleNameForVarDecl = getPossibleNameForVarDecl(node, transformer, continuationArgName);
437
438        // Transform the left-hand-side of `.then`/`.catch` into an array of inlined statements. We pass `true` for hasContinuation as `node` is the outer continuation.
439        const inlinedLeftHandSide = transformExpression(/*returnContextNode*/ node, node.expression.expression, transformer, /*hasContinuation*/ true, possibleNameForVarDecl);
440        if (hasFailed()) return silentFail(); // shortcut out of more work
441
442        // Transform the callback argument into an array of inlined statements. We pass whether we have an outer continuation here
443        // as that indicates whether `return` is valid.
444        const inlinedCallback = transformCallbackArgument(onRejected, hasContinuation, possibleNameForVarDecl, inputArgName, node, transformer);
445        if (hasFailed()) return silentFail(); // shortcut out of more work
446
447        const tryBlock = factory.createBlock(inlinedLeftHandSide);
448        const catchClause = factory.createCatchClause(inputArgName && getSynthesizedDeepClone(declareSynthBindingName(inputArgName)), factory.createBlock(inlinedCallback));
449        const tryStatement = factory.createTryStatement(tryBlock, catchClause, /*finallyBlock*/ undefined);
450        return finishCatchOrFinallyTransform(node, transformer, tryStatement, possibleNameForVarDecl, continuationArgName);
451    }
452
453    /**
454     * @param hasContinuation Whether another `then`, `catch`, or `finally` continuation follows this continuation.
455     * @param continuationArgName The argument name for the continuation that follows this call.
456     */
457    function transformThen(node: PromiseReturningCallExpression<"then">, onFulfilled: Expression | undefined, onRejected: Expression | undefined, transformer: Transformer, hasContinuation: boolean, continuationArgName?: SynthBindingName): readonly Statement[] {
458        if (!onFulfilled || isNullOrUndefined(transformer, onFulfilled)) {
459            // If we don't have an `onfulfilled` callback, try treating this as a `.catch`.
460            return transformCatch(node, onRejected, transformer, hasContinuation, continuationArgName);
461        }
462
463        // We don't currently support transforming a `.then` with both onfulfilled and onrejected handlers, per GH#38152.
464        if (onRejected && !isNullOrUndefined(transformer, onRejected)) {
465            return silentFail();
466        }
467
468        const inputArgName = getArgBindingName(onFulfilled, transformer);
469
470        // Transform the left-hand-side of `.then` into an array of inlined statements. We pass `true` for hasContinuation as `node` is the outer continuation.
471        const inlinedLeftHandSide = transformExpression(node.expression.expression, node.expression.expression, transformer, /*hasContinuation*/ true, inputArgName);
472        if (hasFailed()) return silentFail(); // shortcut out of more work
473
474        // Transform the callback argument into an array of inlined statements. We pass whether we have an outer continuation here
475        // as that indicates whether `return` is valid.
476        const inlinedCallback = transformCallbackArgument(onFulfilled, hasContinuation, continuationArgName, inputArgName, node, transformer);
477        if (hasFailed()) return silentFail(); // shortcut out of more work
478
479        return concatenate(inlinedLeftHandSide, inlinedCallback);
480    }
481
482    /**
483     * Transforms the 'x' part of `x.then(...)`, or the 'y()' part of `y().catch(...)`, where 'x' and 'y()' are Promises.
484     */
485    function transformPromiseExpressionOfPropertyAccess(returnContextNode: Expression, node: Expression, transformer: Transformer, hasContinuation: boolean, continuationArgName?: SynthBindingName): readonly Statement[] {
486        if (shouldReturn(returnContextNode, transformer)) {
487            let returnValue = getSynthesizedDeepClone(node);
488            if (hasContinuation) {
489                returnValue = factory.createAwaitExpression(returnValue);
490            }
491            return [factory.createReturnStatement(returnValue)];
492        }
493
494        return createVariableOrAssignmentOrExpressionStatement(continuationArgName, factory.createAwaitExpression(node), /*typeAnnotation*/ undefined);
495    }
496
497    function createVariableOrAssignmentOrExpressionStatement(variableName: SynthBindingName | undefined, rightHandSide: Expression, typeAnnotation: TypeNode | undefined): readonly Statement[] {
498        if (!variableName || isEmptyBindingName(variableName)) {
499            // if there's no argName to assign to, there still might be side effects
500            return [factory.createExpressionStatement(rightHandSide)];
501        }
502
503        if (isSynthIdentifier(variableName) && variableName.hasBeenDeclared) {
504            // if the variable has already been declared, we don't need "let" or "const"
505            return [factory.createExpressionStatement(factory.createAssignment(getSynthesizedDeepClone(referenceSynthIdentifier(variableName)), rightHandSide))];
506        }
507
508        return [
509            factory.createVariableStatement(
510                /*modifiers*/ undefined,
511                factory.createVariableDeclarationList([
512                    factory.createVariableDeclaration(
513                        getSynthesizedDeepClone(declareSynthBindingName(variableName)),
514                        /*exclamationToken*/ undefined,
515                        typeAnnotation,
516                        rightHandSide)],
517                    NodeFlags.Const))];
518    }
519
520    function maybeAnnotateAndReturn(expressionToReturn: Expression | undefined, typeAnnotation: TypeNode | undefined): Statement[] {
521        if (typeAnnotation && expressionToReturn) {
522            const name = factory.createUniqueName("result", GeneratedIdentifierFlags.Optimistic);
523            return [
524                ...createVariableOrAssignmentOrExpressionStatement(createSynthIdentifier(name), expressionToReturn, typeAnnotation),
525                factory.createReturnStatement(name)
526            ];
527        }
528        return [factory.createReturnStatement(expressionToReturn)];
529    }
530
    // should be kept up to date with isFixablePromiseArgument in suggestionDiagnostics.ts
    /**
     * Transforms one callback argument of a `then`, `catch`, or `finally` call into the statements that take its
     * place in the generated async function body.
     *
     * - `null` contributes no statements.
     * - An identifier or property access is called with `inputArgName` (when it is an identifier binding). The call is
     *   either returned (when `parent` is one of the expressions to return) or awaited and bound to `continuationArgName`.
     *   With no `inputArgName` (e.g. the handler was `undefined`) nothing is produced.
     * - A function expression or arrow function has its body inlined.
     * - Any other kind of argument bails out through `silentFail()`.
     *
     * @param func The callback expression passed to the continuation.
     * @param hasContinuation Whether another `then`, `catch`, or `finally` continuation follows the continuation to which this callback belongs.
     * @param continuationArgName The argument name for the continuation that follows this call.
     * @param inputArgName The argument name provided to this call.
     * @param parent The `then`/`catch`/`finally` call that `func` was passed to.
     * @param transformer Shared transformation state (checker, synthesized names, expressions to return).
     * @returns The statements to splice in place of the callback; `emptyArray` when the callback contributes nothing.
     */
    function transformCallbackArgument(func: Expression, hasContinuation: boolean, continuationArgName: SynthBindingName | undefined, inputArgName: SynthBindingName | undefined, parent: PromiseReturningCallExpression<"then" | "catch" | "finally">, transformer: Transformer): readonly Statement[] {
        switch (func.kind) {
            case SyntaxKind.NullKeyword:
                // do not produce a transformed statement for a null argument
                break;
            case SyntaxKind.PropertyAccessExpression:
            case SyntaxKind.Identifier: // identifier includes undefined
                if (!inputArgName) {
                    // undefined was argument passed to promise handler
                    break;
                }

                // Call the handler directly. Only an identifier binding is forwarded as the argument;
                // a binding pattern results in a zero-argument call.
                const synthCall = factory.createCallExpression(getSynthesizedDeepClone(func as Identifier | PropertyAccessExpression), /*typeArguments*/ undefined, isSynthIdentifier(inputArgName) ? [referenceSynthIdentifier(inputArgName)] : []);

                // The result of this continuation is returned from the async function as-is.
                if (shouldReturn(parent, transformer)) {
                    return maybeAnnotateAndReturn(synthCall, getExplicitPromisedTypeOfPromiseReturningCallExpression(parent, func, transformer.checker));
                }

                const type = transformer.checker.getTypeAtLocation(func);
                const callSignatures = transformer.checker.getSignaturesOfType(type, SignatureKind.Call);
                if (!callSignatures.length) {
                    // if identifier in handler has no call signatures, it's invalid
                    return silentFail();
                }
                // Await the call and bind it to the next continuation's argument; only the first signature's
                // return type is consulted.
                const returnType = callSignatures[0].getReturnType();
                const varDeclOrAssignment = createVariableOrAssignmentOrExpressionStatement(continuationArgName, factory.createAwaitExpression(synthCall), getExplicitPromisedTypeOfPromiseReturningCallExpression(parent, func, transformer.checker));
                if (continuationArgName) {
                    // Record the awaited result type (falling back to the raw return type) on the receiving binding.
                    continuationArgName.types.push(transformer.checker.getAwaitedType(returnType) || returnType);
                }
                return varDeclOrAssignment;

            case SyntaxKind.FunctionExpression:
            case SyntaxKind.ArrowFunction: {
                const funcBody = (func as FunctionExpression | ArrowFunction).body;
                // Uses the last call signature of the callback; undefined if it has none.
                const returnType = getLastCallSignature(transformer.checker.getTypeAtLocation(func), transformer.checker)?.getReturnType();

                // Function expressions, and arrow functions with block bodies { }, enter this control flow
                if (isBlock(funcBody)) {
                    let refactoredStmts: Statement[] = [];
                    // Tracked so removeReturns knows whether it must declare the continuation argument as `undefined`.
                    let seenReturnStatement = false;
                    for (const statement of funcBody.statements) {
                        if (isReturnStatement(statement)) {
                            seenReturnStatement = true;
                            if (isReturnStatementWithFixablePromiseHandler(statement, transformer.checker)) {
                                // `return p.then(...)`: flatten the nested chain into this body.
                                refactoredStmts = refactoredStmts.concat(transformReturnStatementWithFixablePromiseHandler(transformer, statement, hasContinuation, continuationArgName));
                            }
                            else {
                                const possiblyAwaitedRightHandSide = returnType && statement.expression ? getPossiblyAwaitedRightHandSide(transformer.checker, returnType, statement.expression) : statement.expression;
                                refactoredStmts.push(...maybeAnnotateAndReturn(possiblyAwaitedRightHandSide, getExplicitPromisedTypeOfPromiseReturningCallExpression(parent, func, transformer.checker)));
                            }
                        }
                        else if (hasContinuation && forEachReturnStatement(statement, returnTrue)) {
                            // If there is a nested `return` in a callback that has a trailing continuation, we don't transform it as the resulting complexity is too great. For example:
                            //
                            // source                               | result
                            // -------------------------------------| ---------------------------------------
                            // function f(): Promise<number> {      | async function f(): Promise<number> {
                            //     return foo().then(() => {        |     await foo();
                            //         if (Math.random()) {         |     if (Math.random()) {
                            //             return 1;                |         return 1; // incorrect early return
                            //         }                            |     }
                            //         return 2;                    |     return 2; // incorrect early return
                            //     }).then(a => {                   |     const a = undefined;
                            //         return a + 1;                |     return a + 1;
                            //     });                              | }
                            // }                                    |
                            //
                            // However, branching returns in the outermost continuation are acceptable as no other continuation follows it:
                            //
                            // source                               | result
                            //--------------------------------------|---------------------------------------
                            // function f() {                       | async function f() {
                            //     return foo().then(res => {       |     const res = await foo();
                            //       if (res.ok) {                  |     if (res.ok) {
                            //         return 1;                    |         return 1;
                            //       }                              |     }
                            //       else {                         |     else {
                            //         if (res.buffer.length > 5) { |         if (res.buffer.length > 5) {
                            //           return 2;                  |             return 2;
                            //         }                            |         }
                            //         else {                       |         else {
                            //             return 3;                |             return 3;
                            //         }                            |         }
                            //       }                              |     }
                            //     });                              | }
                            // }                                    |
                            //
                            // We may improve this in the future, but for now the heuristics are too complex

                            return silentFail();
                        }
                        else {
                            refactoredStmts.push(statement);
                        }
                    }

                    // When this continuation's value is returned, its `return`s stay as they are; otherwise removeReturns
                    // turns them into bindings of the next continuation's argument.
                    return shouldReturn(parent, transformer)
                        ? refactoredStmts.map(s => getSynthesizedDeepClone(s))
                        : removeReturns(
                            refactoredStmts,
                            continuationArgName,
                            transformer,
                            seenReturnStatement);
                }
                else {
                    // Expression-bodied arrow function: a nested fixable promise chain is flattened directly.
                    const inlinedStatements = isFixablePromiseHandler(funcBody, transformer.checker) ?
                        transformReturnStatementWithFixablePromiseHandler(transformer, factory.createReturnStatement(funcBody), hasContinuation, continuationArgName) :
                        emptyArray;

                    if (inlinedStatements.length > 0) {
                        return inlinedStatements;
                    }

                    if (returnType) {
                        const possiblyAwaitedRightHandSide = getPossiblyAwaitedRightHandSide(transformer.checker, returnType, funcBody);

                        if (!shouldReturn(parent, transformer)) {
                            // Bind the expression's value to the next continuation's argument (or keep it for side effects).
                            const transformedStatement = createVariableOrAssignmentOrExpressionStatement(continuationArgName, possiblyAwaitedRightHandSide, /*typeAnnotation*/ undefined);
                            if (continuationArgName) {
                                continuationArgName.types.push(transformer.checker.getAwaitedType(returnType) || returnType);
                            }
                            return transformedStatement;
                        }
                        else {
                            return maybeAnnotateAndReturn(possiblyAwaitedRightHandSide, getExplicitPromisedTypeOfPromiseReturningCallExpression(parent, func, transformer.checker));
                        }
                    }
                    else {
                        // No call signature was found for the callback, so its result cannot be typed.
                        return silentFail();
                    }
                }
            }
            default:
                // If no cases apply, we've found a transformation body we don't know how to handle, so the refactoring should no-op to avoid deleting code.
                return silentFail();
        }
        // Reached only via the `break`s above (null handler or no input binding).
        return emptyArray;
    }
675
676    function getPossiblyAwaitedRightHandSide(checker: TypeChecker, type: Type, expr: Expression): AwaitExpression | Expression {
677        const rightHandSide = getSynthesizedDeepClone(expr);
678        return !!checker.getPromisedTypeOfPromise(type) ? factory.createAwaitExpression(rightHandSide) : rightHandSide;
679    }
680
681    function getLastCallSignature(type: Type, checker: TypeChecker): Signature | undefined {
682        const callSignatures = checker.getSignaturesOfType(type, SignatureKind.Call);
683        return lastOrUndefined(callSignatures);
684    }
685
686    function removeReturns(stmts: readonly Statement[], prevArgName: SynthBindingName | undefined, transformer: Transformer, seenReturnStatement: boolean): readonly Statement[] {
687        const ret: Statement[] = [];
688        for (const stmt of stmts) {
689            if (isReturnStatement(stmt)) {
690                if (stmt.expression) {
691                    const possiblyAwaitedExpression = isPromiseTypedExpression(stmt.expression, transformer.checker) ? factory.createAwaitExpression(stmt.expression) : stmt.expression;
692                    if (prevArgName === undefined) {
693                        ret.push(factory.createExpressionStatement(possiblyAwaitedExpression));
694                    }
695                    else if (isSynthIdentifier(prevArgName) && prevArgName.hasBeenDeclared) {
696                        ret.push(factory.createExpressionStatement(factory.createAssignment(referenceSynthIdentifier(prevArgName), possiblyAwaitedExpression)));
697                    }
698                    else {
699                        ret.push(factory.createVariableStatement(/*modifiers*/ undefined,
700                            (factory.createVariableDeclarationList([factory.createVariableDeclaration(declareSynthBindingName(prevArgName), /*exclamationToken*/ undefined, /*type*/ undefined, possiblyAwaitedExpression)], NodeFlags.Const))));
701                    }
702                }
703            }
704            else {
705                ret.push(getSynthesizedDeepClone(stmt));
706            }
707        }
708
709        // if block has no return statement, need to define prevArgName as undefined to prevent undeclared variables
710        if (!seenReturnStatement && prevArgName !== undefined) {
711            ret.push(factory.createVariableStatement(/*modifiers*/ undefined,
712                (factory.createVariableDeclarationList([factory.createVariableDeclaration(declareSynthBindingName(prevArgName), /*exclamationToken*/ undefined, /*type*/ undefined, factory.createIdentifier("undefined"))], NodeFlags.Const))));
713        }
714
715        return ret;
716    }
717
718    /**
719     * @param hasContinuation Whether another `then`, `catch`, or `finally` continuation follows the continuation to which this statement belongs.
720     * @param continuationArgName The argument name for the continuation that follows this call.
721     */
722    function transformReturnStatementWithFixablePromiseHandler(transformer: Transformer, innerRetStmt: ReturnStatement, hasContinuation: boolean, continuationArgName?: SynthBindingName) {
723        let innerCbBody: Statement[] = [];
724        forEachChild(innerRetStmt, function visit(node) {
725            if (isCallExpression(node)) {
726                const temp = transformExpression(node, node, transformer, hasContinuation, continuationArgName);
727                innerCbBody = innerCbBody.concat(temp);
728                if (innerCbBody.length > 0) {
729                    return;
730                }
731            }
732            else if (!isFunctionLike(node)) {
733                forEachChild(node, visit);
734            }
735        });
736        return innerCbBody;
737    }
738
739    function getArgBindingName(funcNode: Expression, transformer: Transformer): SynthBindingName | undefined {
740        const types: Type[] = [];
741        let name: SynthBindingName | undefined;
742
743        if (isFunctionLikeDeclaration(funcNode)) {
744            if (funcNode.parameters.length > 0) {
745                const param = funcNode.parameters[0].name;
746                name = getMappedBindingNameOrDefault(param);
747            }
748        }
749        else if (isIdentifier(funcNode)) {
750            name = getMapEntryOrDefault(funcNode);
751        }
752        else if (isPropertyAccessExpression(funcNode) && isIdentifier(funcNode.name)) {
753            name = getMapEntryOrDefault(funcNode.name);
754        }
755
756        // return undefined argName when arg is null or undefined
757        // eslint-disable-next-line local/no-in-operator
758        if (!name || "identifier" in name && name.identifier.text === "undefined") {
759            return undefined;
760        }
761
762        return name;
763
764        function getMappedBindingNameOrDefault(bindingName: BindingName): SynthBindingName {
765            if (isIdentifier(bindingName)) return getMapEntryOrDefault(bindingName);
766            const elements = flatMap(bindingName.elements, element => {
767                if (isOmittedExpression(element)) return [];
768                return [getMappedBindingNameOrDefault(element.name)];
769            });
770
771            return createSynthBindingPattern(bindingName, elements);
772        }
773
774        function getMapEntryOrDefault(identifier: Identifier): SynthIdentifier {
775            const originalNode = getOriginalNode(identifier);
776            const symbol = getSymbol(originalNode);
777
778            if (!symbol) {
779                return createSynthIdentifier(identifier, types);
780            }
781
782            const mapEntry = transformer.synthNamesMap.get(getSymbolId(symbol).toString());
783            return mapEntry || createSynthIdentifier(identifier, types);
784        }
785
786        function getSymbol(node: Node): Symbol | undefined {
787            return node.symbol ? node.symbol : transformer.checker.getSymbolAtLocation(node);
788        }
789
790        function getOriginalNode(node: Node): Node {
791            return node.original ? node.original : node;
792        }
793    }
794
795    function isEmptyBindingName(bindingName: SynthBindingName | undefined): boolean {
796        if (!bindingName) {
797            return true;
798        }
799        if (isSynthIdentifier(bindingName)) {
800            return !bindingName.identifier.text;
801        }
802        return every(bindingName.elements, isEmptyBindingName);
803    }
804
805    function createSynthIdentifier(identifier: Identifier, types: Type[] = []): SynthIdentifier {
806        return { kind: SynthBindingNameKind.Identifier, identifier, types, hasBeenDeclared: false, hasBeenReferenced: false };
807    }
808
809    function createSynthBindingPattern(bindingPattern: BindingPattern, elements: readonly SynthBindingName[] = emptyArray, types: Type[] = []): SynthBindingPattern {
810        return { kind: SynthBindingNameKind.BindingPattern, bindingPattern, elements, types };
811    }
812
813    function referenceSynthIdentifier(synthId: SynthIdentifier) {
814        synthId.hasBeenReferenced = true;
815        return synthId.identifier;
816    }
817
818    function declareSynthBindingName(synthName: SynthBindingName) {
819        return isSynthIdentifier(synthName) ? declareSynthIdentifier(synthName) : declareSynthBindingPattern(synthName);
820    }
821
822    function declareSynthBindingPattern(synthPattern: SynthBindingPattern) {
823        for (const element of synthPattern.elements) {
824            declareSynthBindingName(element);
825        }
826        return synthPattern.bindingPattern;
827    }
828
829    function declareSynthIdentifier(synthId: SynthIdentifier) {
830        synthId.hasBeenDeclared = true;
831        return synthId.identifier;
832    }
833
834    function isSynthIdentifier(bindingName: SynthBindingName): bindingName is SynthIdentifier {
835        return bindingName.kind === SynthBindingNameKind.Identifier;
836    }
837
838    function isSynthBindingPattern(bindingName: SynthBindingName): bindingName is SynthBindingPattern {
839        return bindingName.kind === SynthBindingNameKind.BindingPattern;
840    }
841
842    function shouldReturn(expression: Expression, transformer: Transformer): boolean {
843        return !!expression.original && transformer.setOfExpressionsToReturn.has(getNodeId(expression.original));
844    }
845}
846