• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) 2022-2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15import * as arkts from '@koalaui/libarkts';
16
17import { factory } from './memo-factory';
18import { AbstractVisitor, VisitorOptions } from '../common/abstract-visitor';
19import {
20    buildReturnTypeInfo,
21    castParameters,
22    findReturnTypeFromTypeAnnotation,
23    isMemoETSParameterExpression,
24    isMemoParametersDeclaration,
25    isUnmemoizedInFunction,
26    isVoidType,
27    MemoInfo,
28    ParamInfo,
29    PositionalIdTracker,
30    ReturnTypeInfo,
31    RuntimeNames,
32} from './utils';
33import { ReturnTransformer } from './return-transformer';
34
/**
 * Construction options for `ParameterTransformer`: the base visitor options plus the
 * tracker the transformer draws positional ids from.
 */
export interface ParameterTransformerOptions extends VisitorOptions {
    /** Supplies the positional ids handed to the memo factory helpers. */
    positionalIdTracker: PositionalIdTracker;
}

/**
 * Per-parameter memo information, extended with the peer of the parameter's identifier
 * so the parameter declaration can later be resolved from it.
 */
interface RewriteMemoInfo extends MemoInfo {
    /** Peer of the parameter's identifier node. */
    rewritePeer: number;
}
42
/**
 * Rewrites references to the parameters of a `@memo` function body so that they go through
 * the memo parameter accessors built by `factory`.
 *
 * Typical use: configure with `withThis()`, `withParameters()` and `skip()`, run `visitor()`
 * over the function body, then call `reset()` before reusing the instance.
 */
export class ParameterTransformer extends AbstractVisitor {
    /**
     * Parameter key (see `parameterKey`) -> builder for the expression that replaces a read of
     * that parameter. Looked up with the peer of an identifier's resolved declaration.
     */
    private rewriteIdentifiers?: Map<number, () => arkts.MemberExpression | arkts.Identifier>;
    /**
     * Parameter key -> builder for the call that replaces a direct call of that parameter.
     * Only parameters whose type is an ETS function type or ETS union type are registered.
     */
    private rewriteCalls?: Map<number, (passArgs: arkts.Expression[]) => arkts.CallExpression>;
    /** Parameter key -> name/@memo information used when the parameter is re-declared or called. */
    private rewriteMemoInfos?: Map<number, RewriteMemoInfo>;
    /** When true, `this` expressions are replaced by the memo parameter access for `RuntimeNames.THIS`. */
    private rewriteThis?: boolean;
    /** Declaration registered via `skip()`; stored here but not compared against in `visitor()`. */
    private skipNode?: arkts.VariableDeclaration;
    /** Peers recorded by `track()`. */
    private readonly visited: Set<number>;

    private readonly positionalIdTracker: PositionalIdTracker;

    constructor(options: ParameterTransformerOptions) {
        super(options);
        this.positionalIdTracker = options.positionalIdTracker;
        this.visited = new Set();
    }

    /**
     * Clears all per-body configuration so the instance can be reused for another body.
     *
     * `rewriteThis` is cleared together with the other settings. Otherwise a later run that
     * does not call `withThis()` would silently keep rewriting `this` from an earlier run.
     */
    reset(): void {
        super.reset();
        this.rewriteIdentifiers = undefined;
        this.rewriteCalls = undefined;
        this.rewriteMemoInfos = undefined;
        this.rewriteThis = undefined;
        this.skipNode = undefined;
        this.visited.clear();
    }

    /**
     * Enables or disables rewriting of `this` expressions.
     * @returns this transformer, for chaining.
     */
    withThis(flag: boolean): ParameterTransformer {
        this.rewriteThis = flag;
        return this;
    }

    /**
     * Computes the key under which a parameter's rewrite rules are registered.
     *
     * Parameters whose name starts with `RuntimeNames.GENSYM` are keyed by the peer of
     * `ident`. All other parameters are keyed by the parameter node's own peer.
     */
    private static parameterKey(info: ParamInfo): number {
        return info.param.identifier.name.startsWith(RuntimeNames.GENSYM)
            ? info.ident.originalPeer
            : info.param.originalPeer;
    }

    /**
     * Registers the parameters whose reads, calls and re-declarations should be rewritten.
     * Replaces any previously registered parameters.
     * @returns this transformer, for chaining.
     */
    withParameters(parameters: ParamInfo[]): ParameterTransformer {
        // Only function-typed (or union-typed) parameters can be called directly.
        this.rewriteCalls = new Map(
            parameters
                .filter(
                    (it) =>
                        it.param.type && (arkts.isETSFunctionType(it.param.type) || arkts.isETSUnionType(it.param.type))
                )
                .map((it) => [
                    ParameterTransformer.parameterKey(it),
                    (passArgs: arkts.Expression[]): arkts.CallExpression =>
                        factory.createMemoParameterAccessCall(it.ident.name, passArgs),
                ])
        );
        this.rewriteIdentifiers = new Map(
            parameters.map((it) => [
                ParameterTransformer.parameterKey(it),
                (): arkts.MemberExpression => factory.createMemoParameterAccess(it.ident.name),
            ])
        );
        this.rewriteMemoInfos = new Map(
            parameters.map((it) => [
                ParameterTransformer.parameterKey(it),
                {
                    name: it.param.identifier.name,
                    rewritePeer: it.param.identifier.originalPeer,
                    isMemo: isMemoETSParameterExpression(it.param),
                },
            ])
        );
        return this;
    }

    /**
     * Records the generated memo-parameters declaration that should be left untouched.
     * Note: `visitor()` currently detects that declaration with `isMemoParametersDeclaration`
     * instead of comparing against this value.
     * @returns this transformer, for chaining.
     */
    skip(memoParametersDeclaration?: arkts.VariableDeclaration): ParameterTransformer {
        this.skipNode = memoParametersDeclaration;
        return this;
    }

    /** Records the peer of `node` (if any) in the visited set. */
    track(node: arkts.AstNode | undefined): void {
        if (!!node?.peer) {
            this.visited.add(node.peer);
        }
    }

    /** Returns whether `node`'s peer was previously recorded by `track()`. */
    isTracked(node: arkts.AstNode | undefined): boolean {
        return !!node?.peer && this.visited.has(node.peer);
    }

    /**
     * Memoizes an arrow function used as the initializer of a re-declared @memo parameter.
     *
     * Rewrites the function body with a fresh `ParameterTransformer` and a `ReturnTransformer`,
     * then rebuilds the script function with memo parameters. The initializer is returned
     * unchanged when it has no block body, or when `isUnmemoizedInFunction` holds for its params.
     *
     * @param returnType return type taken from the parameter's type annotation; falls back to the
     *     script function's own return type annotation when undefined.
     */
    private updateArrowFunctionFromVariableDeclareInit(
        initializer: arkts.ArrowFunctionExpression,
        returnType: arkts.TypeNode | undefined
    ): arkts.ArrowFunctionExpression {
        const scriptFunction = initializer.scriptFunction;
        if (!scriptFunction.body || !arkts.isBlockStatement(scriptFunction.body)) {
            return initializer;
        }
        if (isUnmemoizedInFunction(scriptFunction.params)) {
            return initializer;
        }
        const returnTypeInfo: ReturnTypeInfo = buildReturnTypeInfo(
            returnType ?? scriptFunction.returnTypeAnnotation,
            true
        );
        const [body, parameterIdentifiers, memoParametersDeclaration, syntheticReturnStatement] =
            factory.updateFunctionBody(
                scriptFunction.body,
                castParameters(scriptFunction.params),
                returnTypeInfo,
                this.positionalIdTracker.id()
            );
        // Nested bodies get their own transformers so this instance's registrations are untouched.
        const parameterTransformer = new ParameterTransformer({
            positionalIdTracker: this.positionalIdTracker,
        });
        const returnTransformer = new ReturnTransformer();
        const afterParameterTransformer = parameterTransformer
            .withParameters(parameterIdentifiers)
            .skip(memoParametersDeclaration)
            .visitor(body);
        const afterReturnTransformer = returnTransformer
            .skip(syntheticReturnStatement)
            .registerReturnTypeInfo(returnTypeInfo)
            .visitor(afterParameterTransformer);
        const updateScriptFunction = factory.updateScriptFunctionWithMemoParameters(
            scriptFunction,
            afterReturnTransformer,
            returnTypeInfo.node
        );
        parameterTransformer.reset();
        returnTransformer.reset();
        this.track(updateScriptFunction.body);
        return arkts.factory.updateArrowFunction(initializer, updateScriptFunction);
    }

    /**
     * Recursively updates the initializer of a re-declared @memo parameter.
     *
     * Handling by initializer kind:
     * - conditional expression: both branches are updated; the test is kept as is;
     * - `as` expression: the operand is updated and its type annotation is passed through
     *   `factory.updateMemoTypeAnnotation`;
     * - arrow function: the function is memoized;
     * - anything else (or `undefined`): returned unchanged.
     */
    private updateVariableDeclareInit<T extends arkts.AstNode>(
        initializer: T | undefined,
        returnType: arkts.TypeNode | undefined
    ): T | undefined {
        if (!initializer) {
            return undefined;
        }
        if (arkts.isConditionalExpression(initializer)) {
            return arkts.factory.updateConditionalExpression(
                initializer,
                initializer.test,
                this.updateVariableDeclareInit(initializer.consequent, returnType),
                this.updateVariableDeclareInit(initializer.alternate, returnType)
            ) as unknown as T;
        }
        if (arkts.isTSAsExpression(initializer)) {
            return arkts.factory.updateTSAsExpression(
                initializer,
                this.updateVariableDeclareInit(initializer.expr, returnType),
                factory.updateMemoTypeAnnotation(initializer.typeAnnotation),
                initializer.isConst
            ) as unknown as T;
        }
        if (arkts.isArrowFunctionExpression(initializer)) {
            return this.updateArrowFunctionFromVariableDeclareInit(initializer, returnType) as unknown as T;
        }
        return initializer;
    }

    /**
     * Updates a declarator that re-declares a @memo parameter under a different name.
     *
     * The declarator is returned unchanged when any of these holds:
     * - the name is the same as the parameter's name;
     * - the parameter is not @memo;
     * - the parameter declaration cannot be resolved to an ETS parameter expression.
     *
     * Otherwise its type annotation (if any) is updated and its initializer is memoized.
     *
     * @throws the string `'Invalid @memo usage'` when the declarator has a type annotation that
     *     `factory.updateMemoTypeAnnotation` rejects. It is kept as a string throw because callers
     *     may depend on the thrown value (NOTE(review): consider migrating to `Error`).
     */
    private updateParamReDeclare(node: arkts.VariableDeclarator, memoInfo: RewriteMemoInfo): arkts.VariableDeclarator {
        const shouldUpdate: boolean = node.name.name !== memoInfo.name && memoInfo.isMemo;
        if (!shouldUpdate) {
            return node;
        }
        const decl = arkts.getPeerDecl(memoInfo.rewritePeer);
        if (!decl || !arkts.isEtsParameterExpression(decl)) {
            return node;
        }

        let typeAnnotation: arkts.TypeNode | undefined;
        if (node.name.typeAnnotation) {
            typeAnnotation = factory.updateMemoTypeAnnotation(node.name.typeAnnotation);
            if (!typeAnnotation) {
                console.error(`ETSFunctionType or ETSUnionType expected for @memo-variable-type ${node.name.name}`);
                throw 'Invalid @memo usage';
            }
        }

        const returnType = findReturnTypeFromTypeAnnotation(decl.type);
        return arkts.factory.updateVariableDeclarator(
            node,
            node.flag,
            arkts.factory.updateIdentifier(node.name, node.name.name, typeAnnotation),
            this.updateVariableDeclareInit(node.initializer, returnType)
        );
    }

    /**
     * Rewrites the declarators of a variable declaration. For each declarator:
     * - if its name is keyed in `rewriteMemoInfos`, it is treated as a parameter re-declaration;
     * - if its initializer is a bare identifier that resolves to a registered parameter, the
     *   initializer is replaced with the parameter access expression;
     * - otherwise it is kept as is.
     */
    private updateVariableReDeclarationFromParam(node: arkts.VariableDeclaration): arkts.VariableDeclaration {
        return arkts.factory.updateVariableDeclaration(
            node,
            node.modifiers,
            node.declarationKind,
            node.declarators.map((declarator) => {
                if (this.rewriteMemoInfos?.has(declarator.name.originalPeer)) {
                    const memoInfo = this.rewriteMemoInfos.get(declarator.name.originalPeer)!;
                    return this.updateParamReDeclare(declarator, memoInfo);
                }
                if (!!declarator.initializer && arkts.isIdentifier(declarator.initializer)) {
                    const decl = arkts.getPeerDecl(declarator.initializer.originalPeer);
                    if (decl && this.rewriteIdentifiers?.has(decl.peer)) {
                        return arkts.factory.updateVariableDeclarator(
                            declarator,
                            declarator.flag,
                            declarator.name,
                            this.rewriteIdentifiers.get(decl.peer)!()
                        );
                    }
                }
                return declarator;
            })
        );
    }

    /**
     * Adds the hidden memo arguments to a call of a @memo parameter that is invoked under a
     * name different from the parameter's own name. Otherwise returns `node` unchanged.
     */
    private updateCallReDeclare(
        node: arkts.CallExpression,
        oriName: arkts.Identifier,
        memoInfo: RewriteMemoInfo
    ): arkts.CallExpression {
        const shouldUpdate: boolean = oriName.name !== memoInfo.name && memoInfo.isMemo;
        if (!shouldUpdate) {
            return node;
        }
        return factory.insertHiddenArgumentsToCall(node, this.positionalIdTracker.id(oriName.name));
    }

    /**
     * Visits `beforeChildren` and returns the rewritten node.
     *
     * Order of handling:
     * 1. the generated memo-parameters declaration is returned unchanged;
     * 2. variable declarations are handed to `updateVariableReDeclarationFromParam`, and their
     *    children are NOT visited further. As a result, parameter reads nested inside a more
     *    complex initializer are not rewritten (NOTE(review): confirm this is intended);
     * 3. direct calls of registered parameters are rewritten; their arguments are visited first;
     * 4. all other nodes have their children visited; identifiers that resolve to a registered
     *    parameter, and `this` (when enabled), are then replaced.
     */
    visitor(beforeChildren: arkts.AstNode): arkts.AstNode {
        // TODO: temporary checking skip nodes by comparison with expected skip nodes
        // Should be fixed when update procedure implemented properly
        if (/* beforeChildren === this.skipNode */ isMemoParametersDeclaration(beforeChildren)) {
            return beforeChildren;
        }
        if (arkts.isVariableDeclaration(beforeChildren)) {
            return this.updateVariableReDeclarationFromParam(beforeChildren);
        }
        if (arkts.isCallExpression(beforeChildren) && arkts.isIdentifier(beforeChildren.expression)) {
            const decl = arkts.getPeerDecl(beforeChildren.expression.originalPeer);
            if (decl && this.rewriteCalls?.has(decl.peer)) {
                const updateCall = this.rewriteCalls.get(decl.peer)!(
                    beforeChildren.arguments.map((it) => this.visitor(it) as arkts.Expression)
                );
                if (this.rewriteMemoInfos?.has(decl.peer)) {
                    const memoInfo = this.rewriteMemoInfos.get(decl.peer)!;
                    return this.updateCallReDeclare(updateCall, beforeChildren.expression, memoInfo);
                }
                return updateCall;
            }
        }
        const node = this.visitEachChild(beforeChildren);
        if (arkts.isIdentifier(node)) {
            const decl = arkts.getPeerDecl(node.originalPeer);
            if (decl && this.rewriteIdentifiers?.has(decl.peer)) {
                return this.rewriteIdentifiers.get(decl.peer)!();
            }
        }
        if (arkts.isThisExpression(node) && this.rewriteThis) {
            return factory.createMemoParameterAccess(RuntimeNames.THIS);
        }
        return node;
    }
}
311