/*
 * Copyright (c) 2022-2025 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import * as arkts from '@koalaui/libarkts';

import { factory } from './memo-factory';
import { AbstractVisitor, VisitorOptions } from '../common/abstract-visitor';
import {
    buildReturnTypeInfo,
    castParameters,
    findReturnTypeFromTypeAnnotation,
    isMemoETSParameterExpression,
    isMemoParametersDeclaration,
    isUnmemoizedInFunction,
    isVoidType,
    MemoInfo,
    ParamInfo,
    PositionalIdTracker,
    ReturnTypeInfo,
    RuntimeNames,
} from './utils';
import { ReturnTransformer } from './return-transformer';

/** Construction options for {@link ParameterTransformer}. */
export interface ParameterTransformerOptions extends VisitorOptions {
    /** Source of positional ids passed to the memo factory when synthesizing memo calls. */
    positionalIdTracker: PositionalIdTracker;
}

/** Memo metadata for one parameter, plus the peer of that parameter's identifier. */
interface RewriteMemoInfo extends MemoInfo {
    /** `originalPeer` of the parameter's identifier; resolved back to its declaration via `arkts.getPeerDecl`. */
    rewritePeer: number;
}

/**
 * Visitor that rewrites a function body so that its parameters are accessed through the
 * memo runtime.
 *
 * Configure it with `withParameters` (and optionally `withThis` / `skip`), call `visitor` on
 * the body, then `reset`. Given that configuration it:
 * - replaces references to a configured parameter with `factory.createMemoParameterAccess(name)`;
 * - replaces direct calls of a function- or union-typed parameter with
 *   `factory.createMemoParameterAccessCall(name, args)`, additionally inserting hidden arguments
 *   when the parameter is memo and the callee name differs from the parameter's declared name;
 * - updates variable re-declarations of memo parameters (type annotation and initializer);
 * - rewrites `this` to the memo parameter access for `RuntimeNames.THIS` when `withThis(true)`;
 * - leaves the memo-parameters declaration itself untouched.
 */
export class ParameterTransformer extends AbstractVisitor {
    /** Parameter key -> factory for the replacement of a plain reference to that parameter. */
    private rewriteIdentifiers?: Map<number, () => arkts.MemberExpression | arkts.Identifier>;
    /** Parameter key -> factory for the replacement of a direct call of a function/union-typed parameter. */
    private rewriteCalls?: Map<number, (passArgs: arkts.Expression[]) => arkts.CallExpression>;
    /** Parameter key -> memo metadata used for re-declarations and calls of that parameter. */
    private rewriteMemoInfos?: Map<number, RewriteMemoInfo>;
    /** When true, `this` expressions are rewritten to the memo parameter access for `RuntimeNames.THIS`. */
    private rewriteThis?: boolean;
    /** Declaration registered via `skip`; `visitor` currently detects it with `isMemoParametersDeclaration` instead. */
    private skipNode?: arkts.VariableDeclaration;
    /** Peers of nodes recorded with `track`. */
    private visited: Set<number>;

    private positionalIdTracker: PositionalIdTracker;

    constructor(options: ParameterTransformerOptions) {
        super(options);
        this.positionalIdTracker = options.positionalIdTracker;
        this.visited = new Set();
    }

    /**
     * Clears the state configured by `withParameters` and `skip` and forgets tracked nodes.
     *
     * NOTE(review): the flag set by `withThis` is not cleared here; confirm that is intended.
     */
    reset(): void {
        super.reset();
        this.rewriteIdentifiers = undefined;
        this.rewriteCalls = undefined;
        this.rewriteMemoInfos = undefined;
        this.skipNode = undefined;
        this.visited.clear();
    }

    /** Enables or disables rewriting of `this` expressions. Returns `this` for chaining. */
    withThis(flag: boolean): ParameterTransformer {
        this.rewriteThis = flag;
        return this;
    }

    /**
     * Computes the map key under which a parameter's rewrite entries are stored.
     *
     * Parameters whose name starts with `RuntimeNames.GENSYM` are keyed by the original peer of
     * `it.ident`; all other parameters are keyed by the original peer of the parameter node.
     * Shared by all three rewrite maps so their keys can never disagree.
     */
    private static rewriteKey(it: ParamInfo): number {
        return it.param.identifier.name.startsWith(RuntimeNames.GENSYM)
            ? it.ident.originalPeer
            : it.param.originalPeer;
    }

    /**
     * Registers the parameters whose references, calls and re-declarations should be rewritten.
     *
     * @param parameters parameter infos produced by `factory.updateFunctionBody`
     * @returns `this`, for chaining
     */
    withParameters(parameters: ParamInfo[]): ParameterTransformer {
        // Only function- or union-typed parameters can be called directly.
        this.rewriteCalls = new Map(
            parameters
                .filter(
                    (it) =>
                        it.param.type && (arkts.isETSFunctionType(it.param.type) || arkts.isETSUnionType(it.param.type))
                )
                .map((it) => {
                    return [
                        ParameterTransformer.rewriteKey(it),
                        (passArgs: arkts.Expression[]): arkts.CallExpression => {
                            return factory.createMemoParameterAccessCall(it.ident.name, passArgs);
                        },
                    ];
                })
        );
        this.rewriteIdentifiers = new Map(
            parameters.map((it) => {
                return [
                    ParameterTransformer.rewriteKey(it),
                    (): arkts.MemberExpression => {
                        return factory.createMemoParameterAccess(it.ident.name);
                    },
                ];
            })
        );
        this.rewriteMemoInfos = new Map(
            parameters.map((it) => {
                return [
                    ParameterTransformer.rewriteKey(it),
                    {
                        name: it.param.identifier.name,
                        rewritePeer: it.param.identifier.originalPeer,
                        isMemo: isMemoETSParameterExpression(it.param),
                    },
                ];
            })
        );
        return this;
    }

    /** Records the memo-parameters declaration to skip. Returns `this` for chaining. */
    skip(memoParametersDeclaration?: arkts.VariableDeclaration): ParameterTransformer {
        this.skipNode = memoParametersDeclaration;
        return this;
    }

    /** Records `node`'s peer in the visited set; no-op for a missing node or a falsy peer. */
    track(node: arkts.AstNode | undefined): void {
        if (!!node?.peer) {
            this.visited.add(node.peer);
        }
    }

    /** Returns whether `node`'s peer was previously recorded with `track`. */
    isTracked(node: arkts.AstNode | undefined): boolean {
        return !!node?.peer && this.visited.has(node.peer);
    }

    /**
     * Memo-transforms an arrow function used as the initializer of a re-declared memo parameter.
     *
     * The arrow function is returned unchanged when its body is not a block statement or when
     * `isUnmemoizedInFunction` holds for its parameters. Otherwise its body is rewritten by a
     * fresh `ParameterTransformer` (for the arrow's own parameters) followed by a
     * `ReturnTransformer`, and its script function is rebuilt with the memo parameters.
     *
     * @param initializer the arrow function being assigned
     * @param returnType return type found on the re-declared parameter's type, if any; falls back
     *   to the arrow function's own return type annotation
     * @returns the updated arrow function; the new body is recorded with `track`
     */
    private updateArrowFunctionFromVariableDeclareInit(
        initializer: arkts.ArrowFunctionExpression,
        returnType: arkts.TypeNode | undefined
    ): arkts.ArrowFunctionExpression {
        const scriptFunction = initializer.scriptFunction;
        if (!scriptFunction.body || !arkts.isBlockStatement(scriptFunction.body)) {
            return initializer;
        }
        if (isUnmemoizedInFunction(scriptFunction.params)) {
            return initializer;
        }
        const returnTypeInfo: ReturnTypeInfo = buildReturnTypeInfo(
            returnType ?? scriptFunction.returnTypeAnnotation,
            true
        );
        const [body, parameterIdentifiers, memoParametersDeclaration, syntheticReturnStatement] =
            factory.updateFunctionBody(
                scriptFunction.body,
                castParameters(scriptFunction.params),
                returnTypeInfo,
                this.positionalIdTracker.id()
            );
        // The nested function gets its own transformers so this instance's configuration is untouched.
        const parameterTransformer = new ParameterTransformer({
            positionalIdTracker: this.positionalIdTracker,
        });
        const returnTransformer = new ReturnTransformer();
        const afterParameterTransformer = parameterTransformer
            .withParameters(parameterIdentifiers)
            .skip(memoParametersDeclaration)
            .visitor(body);
        const afterReturnTransformer = returnTransformer
            .skip(syntheticReturnStatement)
            .registerReturnTypeInfo(returnTypeInfo)
            .visitor(afterParameterTransformer);
        const updateScriptFunction = factory.updateScriptFunctionWithMemoParameters(
            scriptFunction,
            afterReturnTransformer,
            returnTypeInfo.node
        );
        parameterTransformer.reset();
        returnTransformer.reset();
        this.track(updateScriptFunction.body);
        return arkts.factory.updateArrowFunction(initializer, updateScriptFunction);
    }

    /**
     * Updates the initializer of a re-declared memo parameter.
     *
     * It recurses through both branches of a conditional expression and through the operand of
     * an `as` expression (whose type annotation is memo-updated). Arrow functions are handled by
     * `updateArrowFunctionFromVariableDeclareInit`. Any other initializer is returned unchanged.
     */
    private updateVariableDeclareInit<T extends arkts.AstNode>(
        initializer: T | undefined,
        returnType: arkts.TypeNode | undefined
    ): T | undefined {
        if (!initializer) {
            return undefined;
        }
        if (arkts.isConditionalExpression(initializer)) {
            return arkts.factory.updateConditionalExpression(
                initializer,
                initializer.test,
                this.updateVariableDeclareInit(initializer.consequent, returnType),
                this.updateVariableDeclareInit(initializer.alternate, returnType)
            ) as unknown as T;
        }
        if (arkts.isTSAsExpression(initializer)) {
            // NOTE(review): updateParamReDeclare treats a falsy result of updateMemoTypeAnnotation as
            // an error, but here the result is passed through unchecked; confirm that is safe.
            return arkts.factory.updateTSAsExpression(
                initializer,
                this.updateVariableDeclareInit(initializer.expr, returnType),
                factory.updateMemoTypeAnnotation(initializer.typeAnnotation),
                initializer.isConst
            ) as unknown as T;
        }
        if (arkts.isArrowFunctionExpression(initializer)) {
            return this.updateArrowFunctionFromVariableDeclareInit(initializer, returnType) as unknown as T;
        }
        return initializer;
    }

    /**
     * Rewrites a declarator that re-declares a memo parameter under a different name.
     *
     * The declarator is returned unchanged when its name equals the parameter's name, when the
     * parameter is not memo, or when `memoInfo.rewritePeer` does not resolve to a parameter
     * expression. Otherwise its type annotation (if any) is memo-updated, and its initializer is
     * updated using the return type found on the parameter's declared type.
     *
     * @throws the string `'Invalid @memo usage'` (after logging) when the declarator has a type
     *   annotation that `factory.updateMemoTypeAnnotation` cannot update
     */
    private updateParamReDeclare(node: arkts.VariableDeclarator, memoInfo: RewriteMemoInfo): arkts.VariableDeclarator {
        const shouldUpdate: boolean = node.name.name !== memoInfo.name && memoInfo.isMemo;
        if (!shouldUpdate) {
            return node;
        }
        const decl = arkts.getPeerDecl(memoInfo.rewritePeer);
        if (!decl || !arkts.isEtsParameterExpression(decl)) {
            return node;
        }

        let typeAnnotation: arkts.TypeNode | undefined;
        if (!!node.name.typeAnnotation) {
            typeAnnotation = factory.updateMemoTypeAnnotation(node.name.typeAnnotation);
            if (!typeAnnotation) {
                console.error(`ETSFunctionType or ETSUnionType expected for @memo-variable-type ${node.name.name}`);
                // NOTE(review): a string (not an Error) is thrown; kept as-is in case callers match on it.
                throw 'Invalid @memo usage';
            }
        }

        const returnType = findReturnTypeFromTypeAnnotation(decl.type);
        return arkts.factory.updateVariableDeclarator(
            node,
            node.flag,
            arkts.factory.updateIdentifier(node.name, node.name.name, typeAnnotation),
            this.updateVariableDeclareInit(node.initializer, returnType)
        );
    }

    /**
     * Rewrites each declarator of a variable declaration.
     *
     * A declarator whose name maps to registered memo info goes through `updateParamReDeclare`.
     * A declarator whose initializer is a bare identifier resolving to a configured parameter has
     * that initializer replaced by the parameter's rewritten access. Every other declarator is
     * kept as-is.
     */
    private updateVariableReDeclarationFromParam(node: arkts.VariableDeclaration): arkts.VariableDeclaration {
        return arkts.factory.updateVariableDeclaration(
            node,
            node.modifiers,
            node.declarationKind,
            node.declarators.map((declarator) => {
                const memoInfo = this.rewriteMemoInfos?.get(declarator.name.originalPeer);
                if (memoInfo) {
                    return this.updateParamReDeclare(declarator, memoInfo);
                }
                const initializer = declarator.initializer;
                if (!!initializer && arkts.isIdentifier(initializer)) {
                    const decl = arkts.getPeerDecl(initializer.originalPeer);
                    const rewriteIdentifier = decl ? this.rewriteIdentifiers?.get(decl.peer) : undefined;
                    if (rewriteIdentifier) {
                        return arkts.factory.updateVariableDeclarator(
                            declarator,
                            declarator.flag,
                            declarator.name,
                            rewriteIdentifier()
                        );
                    }
                }
                return declarator;
            })
        );
    }

    /**
     * Inserts hidden memo arguments into an already-rewritten call of a memo parameter.
     *
     * This only happens when the callee name `oriName` differs from the parameter's declared name
     * and the parameter is memo; otherwise `node` is returned unchanged.
     */
    private updateCallReDeclare(
        node: arkts.CallExpression,
        oriName: arkts.Identifier,
        memoInfo: RewriteMemoInfo
    ): arkts.CallExpression {
        const shouldUpdate: boolean = oriName.name !== memoInfo.name && memoInfo.isMemo;
        if (!shouldUpdate) {
            return node;
        }
        return factory.insertHiddenArgumentsToCall(node, this.positionalIdTracker.id(oriName.name));
    }

    visitor(beforeChildren: arkts.AstNode): arkts.AstNode {
        // TODO: temporary checking skip nodes by comparison with expected skip nodes
        // Should be fixed when update procedure implemented properly
        if (/* beforeChildren === this.skipNode */ isMemoParametersDeclaration(beforeChildren)) {
            return beforeChildren;
        }
        // Variable declarations are handled entirely here; their children are not visited further.
        // NOTE(review): parameter references nested inside non-identifier initializers (e.g. `x + 1`)
        // are therefore not rewritten; confirm that is intended.
        if (arkts.isVariableDeclaration(beforeChildren)) {
            return this.updateVariableReDeclarationFromParam(beforeChildren);
        }
        // Direct call of a configured parameter: rewrite the callee and visit the arguments.
        if (arkts.isCallExpression(beforeChildren) && arkts.isIdentifier(beforeChildren.expression)) {
            const decl = arkts.getPeerDecl(beforeChildren.expression.originalPeer);
            const rewriteCall = decl ? this.rewriteCalls?.get(decl.peer) : undefined;
            if (decl && rewriteCall) {
                const updateCall = rewriteCall(
                    beforeChildren.arguments.map((it) => this.visitor(it) as arkts.Expression)
                );
                const memoInfo = this.rewriteMemoInfos?.get(decl.peer);
                if (memoInfo) {
                    return this.updateCallReDeclare(updateCall, beforeChildren.expression, memoInfo);
                }
                return updateCall;
            }
        }
        const node = this.visitEachChild(beforeChildren);
        if (arkts.isIdentifier(node)) {
            const decl = arkts.getPeerDecl(node.originalPeer);
            const rewriteIdentifier = decl ? this.rewriteIdentifiers?.get(decl.peer) : undefined;
            if (rewriteIdentifier) {
                return rewriteIdentifier();
            }
        }
        if (arkts.isThisExpression(node) && this.rewriteThis) {
            return factory.createMemoParameterAccess(RuntimeNames.THIS);
        }
        return node;
    }
}