• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15 
16 #include "dangling_pointers_checker.h"
17 #include "compiler/optimizer/ir/basicblock.h"
18 #include "compiler/optimizer/ir/graph.h"
19 #include "runtime/interpreter/frame.h"
20 #include "runtime/include/managed_thread.h"
21 
22 namespace ark::compiler {
DanglingPointersChecker(Graph * graph)23 DanglingPointersChecker::DanglingPointersChecker(Graph *graph)
24     : Analysis {graph},
25       objectsUsers_ {graph->GetLocalAllocator()->Adapter()},
26       checkedBlocks_ {graph->GetLocalAllocator()->Adapter()},
27       phiInsts_ {graph->GetLocalAllocator()->Adapter()},
28       objectInsts_ {graph->GetLocalAllocator()->Adapter()}
29 {
30 }
31 
RunImpl()32 bool DanglingPointersChecker::RunImpl()
33 {
34     return CheckAccSyncCallRuntime();
35 }
36 
MoveToPrevInst(Inst * inst,BasicBlock * bb)37 static Inst *MoveToPrevInst(Inst *inst, BasicBlock *bb)
38 {
39     if (inst == bb->GetFirstInst()) {
40         return nullptr;
41     }
42     return inst->GetPrev();
43 }
44 
IsRefType(Inst * inst)45 static bool IsRefType(Inst *inst)
46 {
47     return inst->GetType() == DataType::REFERENCE;
48 }
49 
IsPointerType(Inst * inst)50 static bool IsPointerType(Inst *inst)
51 {
52     return inst->GetType() == DataType::POINTER;
53 }
54 
IsObjectDef(Inst * inst)55 static bool IsObjectDef(Inst *inst)
56 {
57     if (!IsRefType(inst) && !IsPointerType(inst)) {
58         return false;
59     }
60     if (inst->IsPhi()) {
61         return false;
62     }
63     if (inst->GetOpcode() != Opcode::AddI) {
64         return true;
65     }
66     auto imm = static_cast<BinaryImmOperation *>(inst)->GetImm();
67     return imm != static_cast<uint64_t>(cross_values::GetFrameAccOffset(inst->GetBasicBlock()->GetGraph()->GetArch()));
68 }
69 
InitLiveIns()70 void DanglingPointersChecker::InitLiveIns()
71 {
72     auto arch = GetGraph()->GetArch();
73     for (auto inst : GetGraph()->GetStartBlock()->Insts()) {
74         if (inst->GetOpcode() != Opcode::LiveIn) {
75             continue;
76         }
77         if (inst->GetDstReg() == regmap_[arch]["acc"]) {
78             accLivein_ = inst;
79         }
80         if (inst->GetDstReg() == regmap_[arch]["acc_tag"]) {
81             accTagLivein_ = inst;
82         }
83         if (inst->GetDstReg() == regmap_[arch]["frame"]) {
84             frameLivein_ = inst;
85         }
86         if (inst->GetDstReg() == regmap_[arch]["thread"]) {
87             threadLivein_ = inst;
88         }
89     }
90 }
91 
92 // Frame can be defined in three ways:
93 // 1. LiveIn(frame)
94 // 2. LoadI(LiveIn(thread)).Imm(ManagedThread::GetFrameOffset())
95 // 3. LoadI(<another_frame_def>).Imm(Frame::GetPrevFrameOffset())
96 
IsFrameDef(Inst * inst)97 bool DanglingPointersChecker::IsFrameDef(Inst *inst)
98 {
99     //    inst := LiveIn(frame)
100     if (inst == frameLivein_) {
101         return true;
102     }
103     // or
104     //    inst := LoadI(inst_input).Imm(imm)
105     if (inst->GetOpcode() == Opcode::LoadI) {
106         // where
107         //       inst_input := LiveIn(thread)
108         //       imm := ManagedThread::GetFrameOffset()
109         auto instInput = inst->GetInput(0).GetInst();
110         if (instInput == threadLivein_ &&
111             static_cast<LoadInstI *>(inst)->GetImm() == ark::ManagedThread::GetFrameOffset()) {
112             return true;
113         }
114         // or
115         //       inst_input := <frame_def>
116         //       imm := Frame::GetPrevFrameOffset()
117         if (static_cast<LoadInstI *>(inst)->GetImm() == static_cast<uint64_t>(ark::Frame::GetPrevFrameOffset()) &&
118             IsFrameDef(instInput)) {
119             return true;
120         }
121     }
122     return false;
123 }
124 
CheckSuccessors(BasicBlock * bb,bool prevRes)125 bool DanglingPointersChecker::CheckSuccessors(BasicBlock *bb, bool prevRes)
126 {
127     if (checkedBlocks_.find(bb) != checkedBlocks_.end()) {
128         return prevRes;
129     }
130     for (auto succBb : bb->GetSuccsBlocks()) {
131         if (!prevRes) {
132             return false;
133         }
134         for (auto inst : succBb->AllInsts()) {
135             auto user = std::find(objectsUsers_.begin(), objectsUsers_.end(), inst);
136             if (user == objectsUsers_.end() || (*user)->IsPhi()) {
137                 continue;
138             }
139             return false;
140         }
141         checkedBlocks_.insert(bb);
142 
143         prevRes &= CheckSuccessors(succBb, prevRes);
144     }
145 
146     return prevRes;
147 }
148 
149 // Accumulator can be defined in three ways:
150 // 1. acc_def := LiveIn(acc)
151 // 2. acc_def := LoadI(last_frame_def).Imm(frame_acc_offset)
152 // 3. acc_ptr := AddI(last_frame_def).Imm(frame_acc_offset)
153 //    acc_def := LoadI(acc_ptr).Imm(0)
154 
GetAccAndFrameDefs(Inst * inst)155 std::tuple<Inst *, Inst *> DanglingPointersChecker::GetAccAndFrameDefs(Inst *inst)
156 {
157     auto arch = GetGraph()->GetArch();
158     if (inst == accLivein_) {
159         return std::make_pair(accLivein_, nullptr);
160     }
161     if (inst->GetOpcode() != Opcode::LoadI) {
162         return std::make_pair(nullptr, nullptr);
163     }
164 
165     auto instInput = inst->GetInput(0).GetInst();
166     auto frameAccOffset = static_cast<uint64_t>(cross_values::GetFrameAccOffset(arch));
167     auto loadImm = static_cast<LoadInstI *>(inst)->GetImm();
168     if (loadImm == frameAccOffset && IsFrameDef(instInput)) {
169         return std::make_pair(inst, instInput);
170     }
171 
172     if (loadImm == 0 && IsAccPtr(instInput)) {
173         return std::make_pair(inst, instInput->GetInput(0).GetInst());
174     }
175 
176     return std::make_pair(nullptr, nullptr);
177 }
178 
179 // Accumulator tag can be defined in three ways:
180 // 1. acc_tag_def := LiveIn(acc_tag)
181 // 2. acc_ptr     := AddI(last_frame_def).Imm(frame_acc_offset)
182 //    acc_tag_def := LoadI(acc_ptr).Imm(acc_tag_offset)
183 // 3. acc_ptr     := AddI(last_frame_def).Imm(frame_acc_offset)
184 //    acc_tag_ptr := AddI(acc_ptr).Imm(acc_tag_offset)
185 //    acc_tag_def := LoadI(acc_tag_ptr).Imm(0)
186 
IsAccTagDef(Inst * inst)187 bool DanglingPointersChecker::IsAccTagDef(Inst *inst)
188 {
189     if (inst == accTagLivein_) {
190         return true;
191     }
192     if (inst->GetOpcode() != Opcode::LoadI) {
193         return false;
194     }
195 
196     auto instInput = inst->GetInput(0).GetInst();
197     auto arch = GetGraph()->GetArch();
198     auto accTagOffset = static_cast<uint64_t>(cross_values::GetFrameAccMirrorOffset(arch));
199     auto loadImm = static_cast<LoadInstI *>(inst)->GetImm();
200     if (loadImm == accTagOffset && IsAccPtr(instInput)) {
201         return true;
202     }
203 
204     if (loadImm == 0 && IsAccTagPtr(instInput)) {
205         return true;
206     }
207 
208     return false;
209 }
210 
IsAccTagPtr(Inst * inst)211 bool DanglingPointersChecker::IsAccTagPtr(Inst *inst)
212 {
213     if (inst->GetOpcode() != Opcode::AddI) {
214         return false;
215     }
216     auto arch = GetGraph()->GetArch();
217     auto instImm = static_cast<BinaryImmOperation *>(inst)->GetImm();
218     auto accTagOffset = static_cast<uint64_t>(cross_values::GetFrameAccMirrorOffset(arch));
219     if (instImm != accTagOffset) {
220         return false;
221     }
222     auto accPtrInst = inst->GetInput(0).GetInst();
223     return IsAccPtr(accPtrInst);
224 }
225 
IsAccPtr(Inst * inst)226 bool DanglingPointersChecker::IsAccPtr(Inst *inst)
227 {
228     if (inst->GetOpcode() != Opcode::AddI) {
229         return false;
230     }
231     auto arch = GetGraph()->GetArch();
232     auto instImm = static_cast<BinaryImmOperation *>(inst)->GetImm();
233     auto frameAccOffset = static_cast<uint64_t>(cross_values::GetFrameAccOffset(arch));
234     if (instImm != frameAccOffset) {
235         return false;
236     }
237     auto frameInst = inst->GetInput(0).GetInst();
238     if (!IsFrameDef(frameInst)) {
239         return false;
240     }
241     if (lastFrameDef_ == nullptr) {
242         return true;
243     }
244     return lastFrameDef_ == frameInst;
245 }
246 
UpdateLastAccAndFrameDef(Inst * inst)247 void DanglingPointersChecker::UpdateLastAccAndFrameDef(Inst *inst)
248 {
249     auto [acc_def, frame_def] = GetAccAndFrameDefs(inst);
250     if (acc_def != nullptr) {
251         // inst is acc definition
252         if (lastAccDef_ == nullptr) {
253             // don't have acc definition before
254             lastAccDef_ = acc_def;
255             lastFrameDef_ = frame_def;
256         }
257     } else {
258         // inst isn't acc definition
259         if (IsObjectDef(inst) && !IsPointerType(inst)) {
260             // objects defs should be only ref type
261             objectInsts_.insert(inst);
262         }
263     }
264 }
265 
GetLastAccDefinition(CallInst * runtimeCallInst)266 void DanglingPointersChecker::GetLastAccDefinition(CallInst *runtimeCallInst)
267 {
268     auto block = runtimeCallInst->GetBasicBlock();
269     auto prevInst = runtimeCallInst->GetPrev();
270 
271     phiInsts_.clear();
272     objectInsts_.clear();
273     while (block != GetGraph()->GetStartBlock()) {
274         while (prevInst != nullptr) {
275             UpdateLastAccAndFrameDef(prevInst);
276 
277             if (lastAccTagDef_ == nullptr && IsAccTagDef(prevInst)) {
278                 lastAccTagDef_ = prevInst;
279             }
280 
281             prevInst = MoveToPrevInst(prevInst, block);
282         }
283 
284         objectInsts_.insert(accLivein_);
285 
286         for (auto *phiInst : block->PhiInsts()) {
287             phiInsts_.push_back(phiInst);
288         }
289         block = block->GetDominator();
290         prevInst = block->GetLastInst();
291     }
292 
293     // Check that accumulator has not been overwritten in any execution branch except restored acc
294     auto [phi_def_acc, phi_def_frame] = GetPhiAccDef();
295     if (phi_def_acc != nullptr) {
296         lastAccDef_ = phi_def_acc;
297         lastFrameDef_ = phi_def_frame;
298     }
299     if (lastAccTagDef_ == nullptr) {
300         lastAccTagDef_ = GetPhiAccTagDef();
301     }
302 
303     if (lastAccDef_ == nullptr) {
304         lastAccDef_ = accLivein_;
305     }
306 
307     if (lastAccTagDef_ == nullptr) {
308         lastAccTagDef_ = accTagLivein_;
309     }
310 }
311 
GetPhiAccDef()312 std::tuple<Inst *, Inst *> DanglingPointersChecker::GetPhiAccDef()
313 {
314     // If any input isn't a definition (or there are no definitions among its inputs),
315     // then the phi is not a definition.
316     // Otherwise, if we have reached the last input and it is a definition (or there is a definition in among its
317     // inputs), then the phi is a definition.
318     for (auto *phiInst : phiInsts_) {
319         bool isAccDefPhi = true;
320         auto inputsCount = phiInst->GetInputsCount();
321         Inst *accDef {nullptr};
322         Inst *frameDef {nullptr};
323         for (uint32_t inputIdx = 0; inputIdx < inputsCount; inputIdx++) {
324             auto inputInst = phiInst->GetInput(inputIdx).GetInst();
325             std::tie(accDef, frameDef) = GetAccAndFrameDefs(inputInst);
326             if (accDef != nullptr || inputInst == nullptr) {
327                 continue;
328             }
329             if (inputInst->IsConst() ||
330                 (inputInst->GetOpcode() == Opcode::Bitcast && inputInst->GetInput(0).GetInst()->IsConst())) {
331                 accDef = inputInst;
332                 continue;
333             }
334             std::tie(accDef, frameDef) = GetAccDefFromInputs(inputInst);
335             if (accDef == nullptr) {
336                 isAccDefPhi = false;
337                 break;
338             }
339         }
340         if (!isAccDefPhi) {
341             continue;
342         }
343         if (accDef != nullptr) {
344             return std::make_pair(phiInst, frameDef);
345         }
346     }
347     return std::make_pair(nullptr, nullptr);
348 }
349 
GetAccDefFromInputs(Inst * inst)350 std::tuple<Inst *, Inst *> DanglingPointersChecker::GetAccDefFromInputs(Inst *inst)
351 {
352     auto inputsCount = inst->GetInputsCount();
353     Inst *accDef {nullptr};
354     Inst *frameDef {nullptr};
355     for (uint32_t inputIdx = 0; inputIdx < inputsCount; inputIdx++) {
356         auto inputInst = inst->GetInput(inputIdx).GetInst();
357 
358         std::tie(accDef, frameDef) = GetAccAndFrameDefs(inputInst);
359         if (accDef != nullptr || inputInst == nullptr) {
360             continue;
361         }
362         if (inputInst->IsConst() ||
363             (inputInst->GetOpcode() == Opcode::Bitcast && inputInst->GetInput(0).GetInst()->IsConst())) {
364             accDef = inputInst;
365             continue;
366         }
367         std::tie(accDef, frameDef) = GetAccDefFromInputs(inputInst);
368         if (accDef == nullptr) {
369             return std::make_pair(nullptr, nullptr);
370         }
371     }
372     return std::make_pair(accDef, frameDef);
373 }
374 
GetPhiAccTagDef()375 Inst *DanglingPointersChecker::GetPhiAccTagDef()
376 {
377     for (auto *phiInst : phiInsts_) {
378         if (IsRefType(phiInst) || IsPointerType(phiInst)) {
379             continue;
380         }
381         auto inputsCount = phiInst->GetInputsCount();
382         for (uint32_t inputIdx = 0; inputIdx < inputsCount; inputIdx++) {
383             auto inputInst = phiInst->GetInput(inputIdx).GetInst();
384             auto isAccTagDef = IsAccTagDef(inputInst);
385             if ((isAccTagDef || inputInst->IsConst()) && (inputIdx == inputsCount - 1)) {
386                 return phiInst;
387             }
388 
389             if (isAccTagDef || inputInst->IsConst()) {
390                 continue;
391             }
392 
393             if (!IsAccTagDefInInputs(inputInst)) {
394                 break;
395             }
396             return phiInst;
397         }
398     }
399     return nullptr;
400 }
401 
IsAccTagDefInInputs(Inst * inst)402 bool DanglingPointersChecker::IsAccTagDefInInputs(Inst *inst)
403 {
404     auto inputsCount = inst->GetInputsCount();
405     for (uint32_t inputIdx = 0; inputIdx < inputsCount; inputIdx++) {
406         auto inputInst = inst->GetInput(inputIdx).GetInst();
407         if (IsAccTagDef(inputInst)) {
408             return true;
409         }
410 
411         if ((inputIdx == inputsCount - 1) && inputInst->IsConst()) {
412             return true;
413         }
414 
415         if (IsAccTagDefInInputs(inputInst)) {
416             return true;
417         }
418     }
419     return false;
420 }
421 
IsSaveAcc(const Inst * inst)422 bool DanglingPointersChecker::IsSaveAcc(const Inst *inst)
423 {
424     if (inst->GetOpcode() != Opcode::StoreI) {
425         return false;
426     }
427 
428     auto arch = GetGraph()->GetArch();
429     auto frameAccOffset = static_cast<uint64_t>(cross_values::GetFrameAccOffset(arch));
430     if (static_cast<const StoreInstI *>(inst)->GetImm() != frameAccOffset) {
431         return false;
432     }
433     auto storeInput1 = inst->GetInput(1).GetInst();
434     if (storeInput1 != lastAccDef_) {
435         return false;
436     }
437     auto storeInput0 = inst->GetInput(0).GetInst();
438     if (lastFrameDef_ == nullptr) {
439         if (IsFrameDef(storeInput0)) {
440             return true;
441         }
442     } else if (storeInput0 == lastFrameDef_) {
443         return true;
444     }
445     return false;
446 }
447 
448 // Accumulator is saved using the StoreI instruction:
449 // StoreI(last_frame_def, last_acc_def).Imm(cross_values::GetFrameAccOffset(GetArch()))
450 
CheckStoreAcc(CallInst * runtimeCallInst)451 bool DanglingPointersChecker::CheckStoreAcc(CallInst *runtimeCallInst)
452 {
453     auto prevInst = runtimeCallInst->GetPrev();
454     auto block = runtimeCallInst->GetBasicBlock();
455     while (block != GetGraph()->GetStartBlock()) {
456         while (prevInst != nullptr && prevInst != lastAccDef_) {
457             if (IsSaveAcc(prevInst)) {
458                 return true;
459             }
460             prevInst = MoveToPrevInst(prevInst, block);
461         }
462         block = block->GetDominator();
463         prevInst = block->GetLastInst();
464     }
465     return false;
466 }
467 
468 // Accumulator tag is saved using the StoreI instruction:
469 // StoreI(acc_ptr, last_acc_tag_def).Imm(cross_values::GetFrameAccMirrorOffset(GetArch()))
470 
CheckStoreAccTag(CallInst * runtimeCallInst)471 bool DanglingPointersChecker::CheckStoreAccTag(CallInst *runtimeCallInst)
472 {
473     bool isSaveAccTag = false;
474     auto arch = GetGraph()->GetArch();
475     auto prevInst = runtimeCallInst->GetPrev();
476     auto block = runtimeCallInst->GetBasicBlock();
477     auto accTagOffset = static_cast<uint64_t>(cross_values::GetFrameAccMirrorOffset(arch));
478     while (block != GetGraph()->GetStartBlock()) {
479         while (prevInst != nullptr && prevInst != lastAccDef_) {
480             if (prevInst->GetOpcode() != Opcode::StoreI) {
481                 prevInst = MoveToPrevInst(prevInst, block);
482                 continue;
483             }
484             if (static_cast<StoreInstI *>(prevInst)->GetImm() != accTagOffset) {
485                 prevInst = MoveToPrevInst(prevInst, block);
486                 continue;
487             }
488             auto storeInput1 = prevInst->GetInput(1).GetInst();
489             if (lastAccTagDef_ == nullptr) {
490                 lastAccTagDef_ = storeInput1;
491             }
492             if (storeInput1 != lastAccTagDef_ && !storeInput1->IsConst()) {
493                 prevInst = MoveToPrevInst(prevInst, block);
494                 continue;
495             }
496             auto storeInput0 = prevInst->GetInput(0).GetInst();
497             if (IsAccPtr(storeInput0)) {
498                 isSaveAccTag = true;
499                 break;
500             }
501 
502             prevInst = MoveToPrevInst(prevInst, block);
503         }
504         if (isSaveAccTag) {
505             break;
506         }
507         block = block->GetDominator();
508         prevInst = block->GetLastInst();
509     }
510     return isSaveAccTag;
511 }
512 
CheckAccUsers(CallInst * runtimeCallInst)513 bool DanglingPointersChecker::CheckAccUsers(CallInst *runtimeCallInst)
514 {
515     objectsUsers_.clear();
516     for (const auto &user : lastAccDef_->GetUsers()) {
517         objectsUsers_.push_back(user.GetInst());
518     }
519 
520     return CheckUsers(runtimeCallInst);
521 }
522 
CheckObjectsUsers(CallInst * runtimeCallInst)523 bool DanglingPointersChecker::CheckObjectsUsers(CallInst *runtimeCallInst)
524 {
525     objectsUsers_.clear();
526     for (auto *objectInst : objectInsts_) {
527         for (const auto &user : objectInst->GetUsers()) {
528             objectsUsers_.push_back(user.GetInst());
529         }
530     }
531 
532     return CheckUsers(runtimeCallInst);
533 }
534 
CheckUsers(CallInst * runtimeCallInst)535 bool DanglingPointersChecker::CheckUsers(CallInst *runtimeCallInst)
536 {
537     bool checkObjectUsers = true;
538     auto runtimeCallBlock = runtimeCallInst->GetBasicBlock();
539 
540     auto nextInst = runtimeCallInst->GetNext();
541     while (nextInst != nullptr) {
542         auto user = std::find(objectsUsers_.begin(), objectsUsers_.end(), nextInst);
543         if (user == objectsUsers_.end() || (*user)->IsPhi()) {
544             nextInst = nextInst->GetNext();
545             continue;
546         }
547         return false;
548     }
549 
550     checkedBlocks_.clear();
551     return CheckSuccessors(runtimeCallBlock, checkObjectUsers);
552 }
553 
CheckAccSyncCallRuntime()554 bool DanglingPointersChecker::CheckAccSyncCallRuntime()
555 {
556     if (regmap_.find(GetGraph()->GetArch()) == regmap_.end()) {
557         return true;
558     }
559 
560     if (GetGraph()->GetRelocationHandler() == nullptr) {
561         return true;
562     }
563 
564     // collect runtime calls
565     ArenaVector<CallInst *> runtimeCalls(GetGraph()->GetLocalAllocator()->Adapter());
566     for (auto block : GetGraph()->GetBlocksRPO()) {
567         for (auto inst : block->Insts()) {
568             if (!inst->IsCall()) {
569                 continue;
570             }
571             auto callInst = static_cast<CallInst *>(inst);
572             auto callFuncName =
573                 GetGraph()->GetRuntime()->GetExternalMethodName(GetGraph()->GetMethod(), callInst->GetCallMethodId());
574             if (targetFuncs_.find(callFuncName) == targetFuncs_.end()) {
575                 continue;
576             }
577             runtimeCalls.push_back(callInst);
578         }
579     }
580     if (runtimeCalls.empty()) {
581         return true;
582     }
583 
584     // find LiveIns for acc and frame
585     InitLiveIns();
586 
587     for (auto runtimeCallInst : runtimeCalls) {
588         lastAccDef_ = nullptr;
589         lastAccTagDef_ = nullptr;
590         GetLastAccDefinition(runtimeCallInst);
591 
592         if (!IsRefType(lastAccDef_) && !IsPointerType(lastAccDef_)) {
593             continue;
594         }
595 
596         // check that acc has been stored in the frame before call
597         if (!CheckStoreAcc(runtimeCallInst)) {
598             return false;
599         }
600 
601         if (!GetGraph()->IsDynamicMethod() && !CheckStoreAccTag(runtimeCallInst)) {
602             return false;
603         }
604 
605         // check that acc isn't used after call
606         if (!CheckAccUsers(runtimeCallInst)) {
607             return false;
608         }
609 
610         // check that other objects aren't used after call
611         if (!CheckObjectsUsers(runtimeCallInst)) {
612             return false;
613         }
614     }
615     return true;
616 }
617 }  // namespace ark::compiler
618