• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) 2024-2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16import { DefUseChain } from '../base/DefUseChain';
17import { Local } from '../base/Local';
18import { Stmt } from '../base/Stmt';
19import { ArkError, ArkErrorCode } from '../common/ArkError';
20import { ArkMethod } from '../model/ArkMethod';
21import { BasicBlock } from './BasicBlock';
22import Logger, { LOG_MODULE_TYPE } from '../../utils/logger';
23import { ArkStaticInvokeExpr } from '../base/Expr';
24import { Value } from '../base/Value';
// Module-scoped logger for CFG diagnostics (used by Cfg.validate()); tagged with this module's name.
const logger = Logger.getLogger(LOG_MODULE_TYPE.ARKANALYZER, 'Cfg');
26
/**
 * Control-flow graph of one method body.
 *
 * Stores the method's basic blocks, a reverse index from every statement to the block that
 * contains it, the entry statement, and the def-use chains produced by `buildDefUseChain()`.
 *
 * @category core/graph
 */
export class Cfg {
    /** Every basic block of this graph. */
    private blocks: Set<BasicBlock> = new Set();
    /** Reverse index statement -> owning block; kept in sync by addBlock, insertAfter, insertBefore and remove. */
    private stmtToBlock: Map<Stmt, BasicBlock> = new Map();
    /** Entry statement; unset until setStartingStmt() is called. */
    private startingStmt!: Stmt;

    /** Def-use chains appended by buildDefUseChain(). */
    private defUseChains: DefUseChain[] = [];
    /** Method this CFG belongs to; a placeholder ArkMethod until setDeclaringMethod() is called. */
    private declaringMethod: ArkMethod = new ArkMethod();

    constructor() {}

    /**
     * Collects the statements of all blocks into a new array, block by block, in block-set
     * iteration order and statement order within each block.
     * @returns a fresh array; mutating it does not affect the CFG
     */
    public getStmts(): Stmt[] {
        const stmts: Stmt[] = [];
        for (const block of this.blocks) {
            for (const stmt of block.getStmts()) {
                stmts.push(stmt);
            }
        }
        return stmts;
    }

    /**
     * Inserts `toInsert` after `point` in the basic block that contains `point`.
     * @param toInsert statement(s) to insert
     * @param point existing statement to insert after
     * @returns the number of successfully inserted statements; 0 if `point` belongs to no block of this CFG
     */
    public insertAfter(toInsert: Stmt | Stmt[], point: Stmt): number {
        const block = this.stmtToBlock.get(point);
        if (!block) {
            return 0;
        }

        this.updateStmt2BlockMap(block, toInsert);
        return block.insertAfter(toInsert, point);
    }

    /**
     * Inserts `toInsert` before `point` in the basic block that contains `point`.
     * @param toInsert statement(s) to insert
     * @param point existing statement to insert before
     * @returns the number of successfully inserted statements; 0 if `point` belongs to no block of this CFG
     */
    public insertBefore(toInsert: Stmt | Stmt[], point: Stmt): number {
        const block = this.stmtToBlock.get(point);
        if (!block) {
            return 0;
        }

        this.updateStmt2BlockMap(block, toInsert);
        return block.insertBefore(toInsert, point);
    }

    /**
     * Removes `stmt` from its basic block and from the statement-to-block index.
     * Does nothing if `stmt` belongs to no block of this CFG.
     * @param stmt statement to remove
     */
    public remove(stmt: Stmt): void {
        const block = this.stmtToBlock.get(stmt);
        if (!block) {
            return;
        }
        this.stmtToBlock.delete(stmt);
        block.remove(stmt);
    }

    /**
     * Maps statements to `block` in the statement-to-block index.
     * @param block owning block
     * @param changed statement(s) to map; when omitted, every statement currently in `block` is (re)mapped
     */
    public updateStmt2BlockMap(block: BasicBlock, changed?: Stmt | Stmt[]): void {
        if (!changed) {
            for (const stmt of block.getStmts()) {
                this.stmtToBlock.set(stmt, block);
            }
        } else if (changed instanceof Stmt) {
            this.stmtToBlock.set(changed, block);
        } else {
            for (const insert of changed) {
                this.stmtToBlock.set(insert, block);
            }
        }
    }

    /**
     * Adds `block` to the graph and indexes all of its statements.
     * Edges between blocks are not created here.
     */
    // TODO: add the edges between blocks
    public addBlock(block: BasicBlock): void {
        this.blocks.add(block);

        for (const stmt of block.getStmts()) {
            this.stmtToBlock.set(stmt, block);
        }
    }

    /** @returns the live block set of this graph (not a copy) */
    public getBlocks(): Set<BasicBlock> {
        return this.blocks;
    }

    /** @returns the block containing the starting statement, or undefined if it is unset or not indexed */
    public getStartingBlock(): BasicBlock | undefined {
        return this.stmtToBlock.get(this.startingStmt);
    }

    public getStartingStmt(): Stmt {
        return this.startingStmt;
    }

    public setStartingStmt(newStartingStmt: Stmt): void {
        this.startingStmt = newStartingStmt;
    }

    public getDeclaringMethod(): ArkMethod {
        return this.declaringMethod;
    }

    public setDeclaringMethod(method: ArkMethod): void {
        this.declaringMethod = method;
    }

    /** @returns the live list of def-use chains (not a copy) */
    public getDefUseChains(): DefUseChain[] {
        return this.defUseChains;
    }

    // TODO: produce Jimple-like textual output
    public toString(): string {
        return 'cfg';
    }

    /**
     * For every statement: if it defines a Local whose declaring statement is still null, records this
     * statement as the declaring one (so the first definition in iteration order wins); then registers
     * the statement as a use site of each value it uses.
     * @param locals the method's locals; consulted to resolve static calls whose callee name matches a local
     */
    public buildDefUseStmt(locals: Set<Local>): void {
        for (const block of this.blocks) {
            for (const stmt of block.getStmts()) {
                const defValue = stmt.getDef();
                if (defValue instanceof Local && defValue.getDeclaringStmt() === null) {
                    defValue.setDeclaringStmt(stmt);
                }
                for (const value of stmt.getUses()) {
                    this.buildUseStmt(value, locals, stmt);
                }
            }
        }
    }

    /**
     * Registers `stmt` as a use site of `value`.
     * A Local is recorded directly. For a static invoke, the first local whose name equals the callee's
     * method name is recorded instead (presumably a call through a locally declared function — verify
     * against the IR builder). Other value kinds are ignored.
     */
    private buildUseStmt(value: Value, locals: Set<Local>, stmt: Stmt): void {
        if (value instanceof Local) {
            value.addUsedStmt(stmt);
        } else if (value instanceof ArkStaticInvokeExpr) {
            for (const local of locals) {
                if (local.getName() === value.getMethodSignature().getMethodSubSignature().getMethodName()) {
                    local.addUsedStmt(stmt);
                    return;
                }
            }
        }
    }

    /**
     * Scans `stmts[0 .. endExclusive)` backwards and returns the last statement whose def prints as `name`.
     * @returns the matching statement, or undefined if none defines `name`
     */
    private findLastDefBefore(stmts: Stmt[], name: string, endExclusive: number): Stmt | undefined {
        for (let i = endExclusive - 1; i >= 0; i--) {
            const candidate = stmts[i];
            const def = candidate.getDef();
            if (def && def.toString() === name) {
                return candidate;
            }
        }
        return undefined;
    }

    /**
     * Appends def-use chains for one value used by `stmt` (located at `stmtIndex` in `block`).
     * Values are matched to definitions by their `toString()` form.
     *
     * If the value is defined earlier in the same block, only that nearest def is linked. Otherwise the
     * predecessor blocks are walked transitively; along each path the walk stops at the first block that
     * defines the value (its last such def is linked). Each block is visited at most once.
     */
    private handleDefUseForValue(value: Value, block: BasicBlock, stmt: Stmt, stmtIndex: number): void {
        const name = value.toString();

        // A def earlier in this block shadows every def that could flow in from predecessors.
        const localDef = this.findLastDefBefore(block.getStmts(), name, stmtIndex);
        if (localDef) {
            this.defUseChains.push(new DefUseChain(value, localDef, stmt));
            return;
        }

        // No local def: search all predecessor blocks. pop() + unshift() makes this a FIFO worklist.
        const defStmts: Stmt[] = [];
        const needWalkBlocks: BasicBlock[] = [...block.getPredecessors()];
        const walkedBlocks = new Set<BasicBlock>();
        while (needWalkBlocks.length > 0) {
            const predecessor = needWalkBlocks.pop();
            if (!predecessor) {
                // Skip a hole rather than dropping the defs already collected.
                continue;
            }
            walkedBlocks.add(predecessor);
            const predecessorStmts = predecessor.getStmts();
            const predecessorDef = this.findLastDefBefore(predecessorStmts, name, predecessorStmts.length);
            if (predecessorDef) {
                // This def kills any earlier one on this path, so do not look further back.
                defStmts.push(predecessorDef);
                continue;
            }
            for (const morePredecessor of predecessor.getPredecessors()) {
                if (!walkedBlocks.has(morePredecessor) && !needWalkBlocks.includes(morePredecessor)) {
                    needWalkBlocks.unshift(morePredecessor);
                }
            }
        }
        for (const def of defStmts) {
            this.defUseChains.push(new DefUseChain(value, def, stmt));
        }
    }

    /**
     * Builds def-use chains for every used value of every statement and appends them to getDefUseChains().
     * The list is never cleared, so calling this again appends duplicate chains.
     */
    public buildDefUseChain(): void {
        for (const block of this.blocks) {
            const stmts = block.getStmts();
            for (let stmtIndex = 0; stmtIndex < stmts.length; stmtIndex++) {
                const stmt = stmts[stmtIndex];
                for (const value of stmt.getUses()) {
                    this.handleDefUseForValue(value, block, stmt, stmtIndex);
                }
            }
        }
    }

    /**
     * @returns blocks not reachable from the starting block along successor edges;
     *          empty if there is no starting block
     */
    public getUnreachableBlocks(): Set<BasicBlock> {
        const unreachable = new Set<BasicBlock>();
        const startBB = this.getStartingBlock();
        if (!startBB) {
            return unreachable;
        }
        const postOrder = this.dfsPostOrder(startBB);
        for (const bb of this.blocks) {
            if (!postOrder.has(bb)) {
                unreachable.add(bb);
            }
        }
        return unreachable;
    }

    /**
     * Checks that a starting block exists and that every block is reachable from it.
     * Any problem found is also logged.
     * @returns `{ errCode: ArkErrorCode.OK }`, or the first problem found with its message
     */
    public validate(): ArkError {
        const startBB = this.getStartingBlock();
        if (!startBB) {
            const errMsg = 'Not found starting block';
            logger.error(errMsg);
            return {
                errCode: ArkErrorCode.CFG_NOT_FOUND_START_BLOCK,
                errMsg: errMsg,
            };
        }

        const unreachable = this.getUnreachableBlocks();
        if (unreachable.size !== 0) {
            const errMsg = `Unreachable blocks: ${Array.from(unreachable)
                .map(value => value.toString())
                .join('\n')}`;
            logger.error(errMsg);
            return {
                errCode: ArkErrorCode.CFG_HAS_UNREACHABLE_BLOCK,
                errMsg: errMsg,
            };
        }

        return { errCode: ArkErrorCode.OK };
    }

    /**
     * Depth-first traversal from `node` along successor edges, returning the reached blocks in post-order.
     * Uses an explicit stack (one frame per open block) instead of recursion, so deep CFGs cannot
     * overflow the JS call stack; the visiting and post-order are the same as a recursive DFS.
     * @param node block to start from (always visited)
     * @param visitor set of already-discovered blocks; mutated
     * @param postOrder accumulator the reached blocks are appended to, in post-order; mutated and returned
     */
    private dfsPostOrder(node: BasicBlock, visitor: Set<BasicBlock> = new Set(), postOrder: Set<BasicBlock> = new Set()): Set<BasicBlock> {
        const stack: Array<{ block: BasicBlock; successors: Iterator<BasicBlock> }> = [];
        const enter = (block: BasicBlock): void => {
            visitor.add(block);
            stack.push({ block, successors: block.getSuccessors()[Symbol.iterator]() });
        };

        enter(node);
        while (stack.length > 0) {
            const frame = stack[stack.length - 1];
            const next = frame.successors.next();
            if (next.done) {
                // All successors handled: the block is finished.
                stack.pop();
                postOrder.add(frame.block);
            } else if (!visitor.has(next.value)) {
                enter(next.value);
            }
        }
        return postOrder;
    }
}
294