• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* @internal */
namespace ts.refactor.convertToOptionalChainExpression {
    const refactorName = "Convert to optional chain expression";
    const convertToOptionalChainExpressionMessage = getLocaleSpecificMessage(Diagnostics.Convert_to_optional_chain_expression);

    const toOptionalChainAction = {
        name: refactorName,
        description: convertToOptionalChainExpressionMessage,
        kind: "refactor.rewrite.expression.optionalChain",
    };
    registerRefactor(refactorName, {
        kinds: [toOptionalChainAction.kind],
        getAvailableActions,
        getEditsForAction
    });

    /**
     * Offers the refactor when the selection resolves to a convertible expression. Otherwise, a
     * not-applicable entry carrying the reason is returned only when the client requested reasons.
     */
    function getAvailableActions(context: RefactorContext): readonly ApplicableRefactorInfo[] {
        // Empty spans are only considered when the refactor was explicitly invoked.
        const info = getInfo(context, context.triggerReason === "invoked");
        if (!info) return emptyArray;

        if (!isRefactorErrorInfo(info)) {
            return [{
                name: refactorName,
                description: convertToOptionalChainExpressionMessage,
                actions: [toOptionalChainAction],
            }];
        }

        if (context.preferences.provideRefactorNotApplicableReason) {
            return [{
                name: refactorName,
                description: convertToOptionalChainExpressionMessage,
                actions: [{ ...toOptionalChainAction, notApplicableReason: info.error }],
            }];
        }
        return emptyArray;
    }

    /** Computes the text edits; asserts that the refactor is applicable at the requested span. */
    function getEditsForAction(context: RefactorContext, actionName: string): RefactorEditInfo | undefined {
        const info = getInfo(context);
        Debug.assert(info && !isRefactorErrorInfo(info), "Expected applicable refactor info");
        const edits = textChanges.ChangeTracker.with(context, t =>
            doChange(context.file, t, info, actionName)
        );
        return { edits, renameFilename: undefined, renameLocation: undefined };
    }

    type Occurrence = PropertyAccessExpression | ElementAccessExpression | Identifier;

    interface OptionalChainInfo {
        /** The rightmost access or call of the chain; it is rebuilt with `?.` tokens. */
        finalExpression: PropertyAccessExpression | ElementAccessExpression | CallExpression,
        /**
         * Prefixes of `finalExpression` that were tested before it, ordered from the rightmost
         * tested operand to the leftmost one.
         */
        occurrences: Occurrence[],
        /** The `&&` expression or conditional expression being converted. */
        expression: ValidExpression,
    }

    type ValidExpressionOrStatement = ValidExpression | ValidStatement;

    /**
     * Types for which a "Convert to optional chain refactor" are offered.
     */
    type ValidExpression = BinaryExpression | ConditionalExpression;

    /**
     * Types of statements which are likely to include a valid expression for extraction.
     */
    type ValidStatement = ExpressionStatement | ReturnStatement | VariableStatement;

    function isValidExpression(node: Node): node is ValidExpression {
        return isBinaryExpression(node) || isConditionalExpression(node);
    }

    function isValidStatement(node: Node): node is ValidStatement {
        return isExpressionStatement(node) || isReturnStatement(node) || isVariableStatement(node);
    }

    function isValidExpressionOrStatement(node: Node): node is ValidExpressionOrStatement {
        return isValidExpression(node) || isValidStatement(node);
    }

    /**
     * Resolves the refactor span to a convertible expression and analyzes it.
     * Returns undefined for an empty span when `considerEmptySpans` is false.
     */
    function getInfo(context: RefactorContext, considerEmptySpans = true): OptionalChainInfo | RefactorErrorInfo | undefined {
        const { file, program } = context;
        const span = getRefactorContextSpan(context);

        const forEmptySpan = span.length === 0;
        if (forEmptySpan && !considerEmptySpans) return undefined;

        // selecting fo[|o && foo.ba|]r should be valid, so adjust span to fit start and end tokens
        const startToken = getTokenAtPosition(file, span.start);
        const endToken = findTokenOnLeftOfPosition(file, span.start + span.length);
        const adjustedSpan = createTextSpanFromBounds(startToken.pos, endToken && endToken.end >= startToken.pos ? endToken.getEnd() : startToken.getEnd());

        const parent = forEmptySpan ? getValidParentNodeOfEmptySpan(startToken) : getValidParentNodeContainingSpan(startToken, adjustedSpan);
        const expression = parent && isValidExpressionOrStatement(parent) ? getExpression(parent) : undefined;
        if (!expression) return { error: getLocaleSpecificMessage(Diagnostics.Could_not_find_convertible_access_expression) };

        const checker = program.getTypeChecker();
        return isConditionalExpression(expression) ? getConditionalInfo(expression, checker) : getBinaryInfo(expression);
    }

    /**
     * Analyzes `cond ? chain : whenFalse`, which is rewritten to `convertedChain ?? whenFalse`.
     */
    function getConditionalInfo(expression: ConditionalExpression, checker: TypeChecker): OptionalChainInfo | RefactorErrorInfo | undefined {
        const condition = expression.condition;
        const finalExpression = getFinalExpressionInChain(expression.whenTrue);

        // The whole conditional is replaced, so `whenTrue` must be the access chain itself: for
        // `a ? a.b === 1 : c` the `=== 1` would otherwise be dropped. A nullable chain result would
        // make `??` fall through to `whenFalse` where the original conditional did not.
        if (!finalExpression
            || finalExpression !== skipParentheses(expression.whenTrue)
            || checker.isNullableType(checker.getTypeAtLocation(finalExpression))) {
            return { error: getLocaleSpecificMessage(Diagnostics.Could_not_find_convertible_access_expression) };
        }

        if (isBinaryExpression(condition)) {
            const occurrences = getOccurrencesInExpression(finalExpression.expression, condition);
            return occurrences ? { finalExpression, occurrences, expression } :
                { error: getLocaleSpecificMessage(Diagnostics.Could_not_find_matching_access_expressions) };
        }

        // The accessed receiver must start with the tested expression (`a.b ? a.b.c : d`), not the reverse.
        const match = getMatchingStart(finalExpression.expression, condition);
        return match ? { finalExpression, occurrences: [match], expression } :
            { error: getLocaleSpecificMessage(Diagnostics.Could_not_find_matching_access_expressions) };
    }

    /**
     * Analyzes `a && a.b && a.b.c`, whose matched operands are rewritten to a single optional chain.
     */
    function getBinaryInfo(expression: BinaryExpression): OptionalChainInfo | RefactorErrorInfo | undefined {
        if (expression.operatorToken.kind !== SyntaxKind.AmpersandAmpersandToken) {
            return { error: getLocaleSpecificMessage(Diagnostics.Can_only_convert_logical_AND_access_chains) };
        }
        const finalExpression = getFinalExpressionInChain(expression.right);

        // Only the text from the first occurrence up to the end of `finalExpression` is replaced. If the
        // final expression sits inside parentheses that also hold other code (`a && (a.b === 1)`), that
        // range would cut through the parentheses and leave them unbalanced.
        if (!finalExpression || isInsideEnclosingParentheses(finalExpression, expression)) {
            return { error: getLocaleSpecificMessage(Diagnostics.Could_not_find_convertible_access_expression) };
        }

        const occurrences = getOccurrencesInExpression(finalExpression.expression, expression.left);
        return occurrences ? { finalExpression, occurrences, expression } :
            { error: getLocaleSpecificMessage(Diagnostics.Could_not_find_matching_access_expressions) };
    }

    /**
     * Gets a list of property accesses that appear in matchTo and occur in sequence in expression.
     * Operands are matched right to left; parentheses around an operand are ignored. `matchTo` is
     * compared as written because `convertOccurrences` matches receivers by their source text.
     */
    function getOccurrencesInExpression(matchTo: Expression, expression: Expression): Occurrence[] | undefined {
        const occurrences: Occurrence[] = [];
        while (isBinaryExpression(expression) && expression.operatorToken.kind === SyntaxKind.AmpersandAmpersandToken) {
            const match = getMatchingStart(matchTo, skipParentheses(expression.right));
            if (!match) {
                break;
            }
            occurrences.push(match);
            matchTo = match;
            expression = expression.left;
        }
        const finalMatch = getMatchingStart(matchTo, skipParentheses(expression));
        if (finalMatch) {
            occurrences.push(finalMatch);
        }
        return occurrences.length > 0 ? occurrences : undefined;
    }

    /**
     * Returns subchain if chain begins with subchain syntactically.
     */
    function getMatchingStart(chain: Expression, subchain: Expression): PropertyAccessExpression | ElementAccessExpression | Identifier | undefined {
        if (!isIdentifier(subchain) && !isPropertyAccessExpression(subchain) && !isElementAccessExpression(subchain)) {
            return undefined;
        }
        return chainStartsWith(chain, subchain) ? subchain : undefined;
    }

    /**
     * Returns true if chain begins with subchain syntactically. Accesses and calls are stripped off the
     * end of `chain` until the remaining receiver matches `subchain` access-for-access.
     */
    function chainStartsWith(chain: Node, subchain: Node): boolean {
        let current = chain;
        while (!accessChainsMatch(current, subchain)) {
            if (!isCallExpression(current) && !isPropertyAccessExpression(current) && !isElementAccessExpression(current)) {
                return false;
            }
            current = current.expression;
        }
        return true;
    }

    /**
     * Returns true if both nodes are the same chain of property/element accesses rooted at identifiers with
     * the same text. Calls never match. An element access matches only when `getTextOfChainNode` yields the
     * same text for both arguments. Arguments it yields no text for (for example `a[i + 1]`) are a mismatch.
     */
    function accessChainsMatch(left: Node, right: Node): boolean {
        while ((isPropertyAccessExpression(left) && isPropertyAccessExpression(right)) ||
               (isElementAccessExpression(left) && isElementAccessExpression(right))) {
            const text = getTextOfChainNode(left);
            if (text === undefined || text !== getTextOfChainNode(right)) return false;
            left = left.expression;
            right = right.expression;
        }
        return isIdentifier(left) && isIdentifier(right) && left.getText() === right.getText();
    }

    /**
     * Returns the text of the last link of an access, or of an identifier/literal.
     * Returns undefined for anything else.
     */
    function getTextOfChainNode(node: Node): string | undefined {
        if (isIdentifier(node) || isStringOrNumericLiteralLike(node)) {
            return node.getText();
        }
        if (isPropertyAccessExpression(node)) {
            return getTextOfChainNode(node.name);
        }
        if (isElementAccessExpression(node)) {
            return getTextOfChainNode(node.argumentExpression);
        }
        return undefined;
    }

    /**
     * Find the least ancestor of the input node that is a valid type for extraction and contains the input span.
     */
    function getValidParentNodeContainingSpan(node: Node, span: TextSpan): ValidExpressionOrStatement | undefined {
        while (node.parent) {
            if (isValidExpressionOrStatement(node) && span.length !== 0 && node.end >= span.start + span.length) {
                return node;
            }
            node = node.parent;
        }
        return undefined;
    }

    /**
     * Finds an ancestor of the input node that is a valid type for extraction, skipping subexpressions.
     */
    function getValidParentNodeOfEmptySpan(node: Node): ValidExpressionOrStatement | undefined {
        while (node.parent) {
            if (isValidExpressionOrStatement(node) && !isValidExpressionOrStatement(node.parent)) {
                return node;
            }
            node = node.parent;
        }
        return undefined;
    }

    /**
     * Gets an expression of valid extraction type from a valid statement or expression.
     */
    function getExpression(node: ValidExpressionOrStatement): ValidExpression | undefined {
        if (isValidExpression(node)) {
            return node;
        }
        if (isVariableStatement(node)) {
            const variable = getSingleVariableOfVariableStatement(node);
            const initializer = variable?.initializer;
            return initializer && isValidExpression(initializer) ? initializer : undefined;
        }
        return node.expression && isValidExpression(node.expression) ? node.expression : undefined;
    }

    /**
     * Gets a property access expression which may be nested inside of a binary expression. The final
     * expression in an && chain will occur as the right child of the parent binary expression, unless
     * it is followed by a different binary operator.
     * @param node the right child of a binary expression or a call expression.
     */
    function getFinalExpressionInChain(node: Expression): CallExpression | PropertyAccessExpression | ElementAccessExpression | undefined {
        // foo && |foo.bar === 1|; - here the right child of the && binary expression is another binary expression.
        // the rightmost member of the && chain should be the leftmost child of that expression.
        node = skipParentheses(node);
        if (isBinaryExpression(node)) {
            return getFinalExpressionInChain(node.left);
        }
        // foo && |foo.bar()()| - nested calls are treated like further accesses.
        else if ((isPropertyAccessExpression(node) || isElementAccessExpression(node) || isCallExpression(node)) && !isOptionalChain(node)) {
            return node;
        }
        return undefined;
    }

    /**
     * Returns the outermost parenthesized expression that wraps exactly `node` (as in `((node))`),
     * or `node` itself when it is not parenthesized.
     */
    function getOutermostParenthesizedWrapper(node: Node): Node {
        let current = node;
        while (isParenthesizedExpression(current.parent)) {
            current = current.parent;
        }
        return current;
    }

    /**
     * Returns true if a parenthesized expression lies between `node` and `container` that holds more than
     * just `node`. Parentheses that wrap `node` exactly are ignored.
     */
    function isInsideEnclosingParentheses(node: Node, container: Node): boolean {
        for (let ancestor = getOutermostParenthesizedWrapper(node).parent; ancestor && ancestor !== container; ancestor = ancestor.parent) {
            if (isParenthesizedExpression(ancestor)) return true;
        }
        return false;
    }

    /**
     * Creates an access chain from toConvert with '?.' accesses at expressions appearing in occurrences.
     * The chain is rebuilt innermost-first. An access or call gets `?.` when its receiver's text equals
     * the last remaining occurrence. That occurrence is then popped, so `occurrences` is consumed in place.
     */
    function convertOccurrences(toConvert: Expression, occurrences: Occurrence[]): Expression {
        if (isPropertyAccessExpression(toConvert) || isElementAccessExpression(toConvert) || isCallExpression(toConvert)) {
            const chain = convertOccurrences(toConvert.expression, occurrences);
            const lastOccurrence = occurrences.length > 0 ? occurrences[occurrences.length - 1] : undefined;
            const isOccurrence = lastOccurrence?.getText() === toConvert.expression.getText();
            if (isOccurrence) occurrences.pop();
            // Non-occurrences keep whatever `?.` they already had.
            const questionDotToken = isOccurrence ? factory.createToken(SyntaxKind.QuestionDotToken) : toConvert.questionDotToken;
            if (isCallExpression(toConvert)) {
                return factory.createCallChain(chain, questionDotToken, toConvert.typeArguments, toConvert.arguments);
            }
            if (isPropertyAccessExpression(toConvert)) {
                return factory.createPropertyAccessChain(chain, questionDotToken, toConvert.name);
            }
            return factory.createElementAccessChain(chain, questionDotToken, toConvert.argumentExpression);
        }
        return toConvert;
    }

    /**
     * Replaces the analyzed expression with its optional-chain form.
     * For `&&` chains, the text from the first occurrence to the final expression is replaced.
     * For conditionals, the whole expression becomes `chain ?? whenFalse`.
     */
    function doChange(sourceFile: SourceFile, changes: textChanges.ChangeTracker, info: OptionalChainInfo, _actionName: string): void {
        const { finalExpression, occurrences, expression } = info;
        // Read before conversion: convertOccurrences pops entries off `occurrences`.
        const firstOccurrence = occurrences[occurrences.length - 1];
        const convertedChain = convertOccurrences(finalExpression, occurrences);
        if (isPropertyAccessExpression(convertedChain) || isElementAccessExpression(convertedChain) || isCallExpression(convertedChain)) {
            if (isBinaryExpression(expression)) {
                // Widen both ends over parentheses that wrap exactly those nodes, so that inputs like
                // `x && (a) && a.b` or `a && (a.b)` do not leave a dangling parenthesis behind.
                changes.replaceNodeRange(
                    sourceFile,
                    getOutermostParenthesizedWrapper(firstOccurrence),
                    getOutermostParenthesizedWrapper(finalExpression),
                    convertedChain
                );
            }
            else if (isConditionalExpression(expression)) {
                changes.replaceNode(sourceFile, expression,
                    factory.createBinaryExpression(convertedChain, factory.createToken(SyntaxKind.QuestionQuestionToken), expression.whenFalse)
                );
            }
        }
    }
}
301