• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2021-2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "alias_visitor.h"
17 #include "loop_analyzer.h"
18 #include "object_type_propagation.h"
19 #include "typed_ref_set.h"
20 #include "optimizer/analysis/live_in_analysis.h"
21 #include "optimizer/ir/basicblock.h"
22 #include "compiler_logger.h"
23 
24 namespace ark::compiler {
25 
Merge(ObjectTypeInfo lhs,ObjectTypeInfo rhs)26 ObjectTypeInfo Merge(ObjectTypeInfo lhs, ObjectTypeInfo rhs)
27 {
28     if (lhs == ObjectTypeInfo::INVALID || rhs == ObjectTypeInfo::UNKNOWN) {
29         return lhs;
30     }
31     if (rhs == ObjectTypeInfo::INVALID || lhs == ObjectTypeInfo::UNKNOWN) {
32         return rhs;
33     }
34     // can be improved by finding a common superclass
35     if (lhs.GetClass() != rhs.GetClass()) {
36         return ObjectTypeInfo::INVALID;
37     }
38     return {lhs.GetClass(), lhs.IsExact() && rhs.IsExact()};
39 }
40 
TryImproveTypeInfo(ObjectTypeInfo & lhs,ObjectTypeInfo rhs)41 void TryImproveTypeInfo(ObjectTypeInfo &lhs, ObjectTypeInfo rhs)
42 {
43     if (!lhs || (rhs && !lhs.IsExact() && rhs.IsExact())) {
44         lhs = rhs;
45     }
46 }
47 
// Deleter that intentionally does nothing. COWPtr uses it for arena-allocated
// objects, whose memory is owned by the arena rather than by the shared_ptr.
struct EmptyDeleter {
    void operator()(void * /* unused */) {}
};
51 
52 template <typename T>
53 class COWPtr {
54 public:
COWPtr()55     COWPtr() : ptr_(nullptr) {}
COWPtr(ArenaAllocator * alloc)56     explicit COWPtr(ArenaAllocator *alloc) : ptr_(alloc->New<T>(alloc->Adapter()), EmptyDeleter {}, alloc->Adapter()) {}
COWPtr(T * obj,ArenaAllocator * alloc)57     COWPtr(T *obj, ArenaAllocator *alloc) : ptr_(obj, EmptyDeleter {}, alloc->Adapter()) {}
58 
operator *() const59     const T &operator*() const
60     {
61         return *ptr_;
62     }
63 
operator ->() const64     const T *operator->() const
65     {
66         return Get();
67     }
68 
Get() const69     const T *Get() const
70     {
71         return ptr_.get();
72     }
73 
Mut(ArenaAllocator * alloc)74     T *Mut(ArenaAllocator *alloc)
75     {
76         if (ptr_.unique()) {
77             return ptr_.get();
78         }
79         T *copy = alloc->New<T>(*ptr_);
80         ptr_ = std::shared_ptr<T> {copy, EmptyDeleter {}, alloc->Adapter()};
81         ASSERT(ptr_.unique());
82         ASSERT(ptr_.get() == copy);
83         return copy;
84     }
85 
86 private:
87     std::shared_ptr<T> ptr_;
88 };
89 
90 namespace {
91 
92 class ObjectTypePropagationVisitor;
93 constexpr Ref NULL_REF = 0;
94 
95 class BasicBlockState {
96 public:
97     BasicBlockState(ObjectTypePropagationVisitor *visitor, BasicBlock *bb);
98     BasicBlockState(const BasicBlockState &other, BasicBlock *bb);
99     ~BasicBlockState() = default;
100 
101     const ArenaTypedRefSet &GetFieldRefSet(Ref base, const PointerOffset &offset);
102     void CreateFieldRefSetForNewObject(Ref base, bool defaultConstructed);
103     ArenaTypedRefSet &CreateFieldRefSet(Ref base, const PointerOffset &offset, bool &changed);
104     const ArenaTypedRefSet &NewUnknownRef();
105     const ArenaTypedRefSet &GetNullRef();
106     void Merge(BasicBlockState *other);
107     void TryEscape(Ref ref);
108     void ForceEscape(Ref ref);
109     void Escape(const Inst *inst);
110     void InvalidateEscaped();
111     bool IsEscaped(Ref ref) const;
112     template <typename F>
113     void Cleanup(const F &remove);
114     size_t GetBlockId() const;
115 
116     constexpr static size_t REAL_REF_START = 3U;
117 
118 private:
119     void Escape(ArenaVector<Ref> &worklist, const ArenaTypedRefSet &refs);
120     void Escape(ArenaVector<Ref> &worklist);
121     void InvalidateEscapedRef(Ref ref);
122 
123 private:
124     DEFAULT_COPY_CTOR(BasicBlockState);
125     NO_COPY_OPERATOR(BasicBlockState);
126     NO_MOVE_SEMANTIC(BasicBlockState);
127 
128     ArenaAllocator *GetLocalAllocator();
129     friend std::ostream &operator<<(std::ostream &os, const BasicBlockState &state);
130 
131     ObjectTypePropagationVisitor *visitor_;
132     ArenaMap<Ref, COWPtr<PointerOffset::Map<ArenaTypedRefSet>>> fieldRefs_;
133     ArenaTypedRefSet escaped_;
134 #ifndef NDEBUG
135     BasicBlock *bb_ {nullptr};
136 #endif
137 };
138 
139 class TypePropagationVisitor : public GraphVisitor {
140 public:
TypePropagationVisitor(Graph * graph)141     explicit TypePropagationVisitor(Graph *graph) : graph_(graph) {}
142     ~TypePropagationVisitor() override = default;
143     NO_COPY_SEMANTIC(TypePropagationVisitor);
144     NO_MOVE_SEMANTIC(TypePropagationVisitor);
145 
146     static void VisitNewObject(GraphVisitor *v, Inst *i);
147     static void VisitParameter(GraphVisitor *v, Inst *i);
148     static void VisitNewArray(GraphVisitor *v, Inst *i);
149     static void VisitLoadArray(GraphVisitor *v, Inst *i);
150     static void VisitLoadString(GraphVisitor *v, Inst *i);
151     static void VisitLoadObject(GraphVisitor *v, Inst *i);
152     static void VisitLoadStatic(GraphVisitor *v, Inst *i);
153     static void VisitCallStatic(GraphVisitor *v, Inst *i);
154     static void VisitCallVirtual(GraphVisitor *v, Inst *i);
155     static void VisitRefTypeCheck(GraphVisitor *v, Inst *i);
156 
157 #include "optimizer/ir/visitor.inc"
158 
159 protected:
160     virtual void SetTypeInfo(Inst *inst, ObjectTypeInfo info) = 0;
161 
162 private:
163     static void ProcessManagedCall(GraphVisitor *v, CallInst *inst);
GetGraph()164     Graph *GetGraph()
165     {
166         return graph_;
167     }
168 
169     Graph *graph_;
170 };
171 
172 // NOLINTNEXTLINE(fuchsia-multiple-inheritance)
173 class LoopPropagationVisitor : public AliasVisitor, public TypePropagationVisitor {
174 public:
175     explicit LoopPropagationVisitor(ObjectTypePropagationVisitor *parent);
176     ~LoopPropagationVisitor() override = default;
177     NO_COPY_SEMANTIC(LoopPropagationVisitor);
178     NO_MOVE_SEMANTIC(LoopPropagationVisitor);
179 
GetBlocksToVisit() const180     const ArenaVector<BasicBlock *> &GetBlocksToVisit() const override
181     {
182         UNREACHABLE();
183     }
184 
185     void AddDirectEdge(const Pointer &p) override;
186 
187     void VisitAllocation(Inst *inst) override;
188 
189     void AddConstantDirectEdge(Inst *inst, uint32_t id) override;
190 
191     void AddCopyEdge(const Pointer &from, const Pointer &to) override;
192     void AddPseudoCopyEdge(const Pointer &base, const Pointer &field) override;
193 
VisitHeapInv(Inst * inst)194     void VisitHeapInv([[maybe_unused]] Inst *inst) override
195     {
196         heapInv_ = true;
197     }
198 
199     void Escape(const Inst *inst) override;
200 
201     void SetTypeInfo(Inst *inst, ObjectTypeInfo info) override;
202 
203     void VisitLoop(Loop *loop);
204 
205     void VisitBlock(BasicBlock *bb) override;
206 
207     void VisitLoopRec(Loop *loop);
208 
209 private:
210     ObjectTypePropagationVisitor *parent_;
211     bool heapInv_ {false};
212     BasicBlockState *headerState_ {nullptr};
213     ArenaSet<const Inst *> escapedInsts_;
214 };
215 
216 // NOLINTNEXTLINE(fuchsia-multiple-inheritance)
217 class ObjectTypePropagationVisitor : public Analysis, public AliasVisitor, public TypePropagationVisitor {
218 public:
ObjectTypePropagationVisitor(Graph * graph)219     explicit ObjectTypePropagationVisitor(Graph *graph)
220         : Analysis(graph),
221           TypePropagationVisitor(graph),
222           states_(graph->GetVectorBlocks().size(), graph->GetLocalAllocator()->Adapter()),
223           refInfos_(BasicBlockState::REAL_REF_START, ObjectTypeInfo::INVALID, graph->GetLocalAllocator()->Adapter()),
224           instRefSets_(graph->GetLocalAllocator()->Adapter()),
225           loopVisitor_(this),
226           liveIns_(graph),
227           loopEdges_(graph->GetLocalAllocator()->Adapter()),
228           loopStoreEdges_(graph->GetLocalAllocator()->Adapter()),
229           workSet_(graph->GetLocalAllocator()->Adapter()),
230           nullSet_(graph->GetLocalAllocator(), ObjectTypeInfo::UNKNOWN),
231           allSet_(graph->GetLocalAllocator(), ObjectTypeInfo::INVALID)
232     {
233         AliasVisitor::Init(graph->GetLocalAllocator());
234         refInfos_[NULL_REF] = ObjectTypeInfo::UNKNOWN;
235         nullSet_.SetBit(NULL_REF, ObjectTypeInfo::UNKNOWN);
236         allSet_.SetBit(ALL_REF, ObjectTypeInfo::INVALID);
237     }
238 
GetPassName() const239     const char *GetPassName() const override
240     {
241         return "ObjectTypePropagationVisitor";
242     }
243 
GetBlocksToVisit() const244     const ArenaVector<BasicBlock *> &GetBlocksToVisit() const override
245     {
246         // We use only VisitBlock
247         UNREACHABLE();
248     }
249 
250     using Analysis::GetGraph;
251 
252     bool RunImpl() override;
253     void SetTypeInfosInGraph();
254     void ResetTypeInfosInGraph();
255     void WalkOutgoingEdges(const Inst *fromInst, const PointerOffset::Map<ArenaVector<Pointer>> &edges);
256     void WalkStoreEdges(const Inst *toInst, const PointerOffset::Map<ArenaVector<const Inst *>> &edges);
257     void WalkEdges();
258     Ref NewRef(ObjectTypeInfo info = ObjectTypeInfo::INVALID);
259     ObjectTypeInfo GetRefTypeInfo(Ref ref) const;
260     std::string ToString(ObjectTypeInfo typeInfo);
261     void DumpRefInfos(std::ostream &os);
262     ArenaTypedRefSet &GetInstRefSet(const Inst *inst);
263     ArenaTypedRefSet &TryGetInstRefSet(const Inst *inst);
264     const ArenaTypedRefSet &GetNullSet() const;
265     const ArenaTypedRefSet &GetAllSet() const;
266     ArenaTypedRefSet &GetOrCreateInstRefSet(const Inst *inst, bool &changed);
GetInstRefSets() const267     const auto &GetInstRefSets() const
268     {
269         return instRefSets_;
270     }
271 
272 protected:
273     void AddDirectEdge(const Pointer &p) override;
274     void AddConstantDirectEdge(Inst *inst, uint32_t id) override;
275 
276     void AddCopyEdge(const Pointer &from, const Pointer &to) override;
AddPseudoCopyEdge(const Pointer & base,const Pointer & field)277     void AddPseudoCopyEdge([[maybe_unused]] const Pointer &base, [[maybe_unused]] const Pointer &field) override
278     {
279         // we don't do that here
280     }
281 
282     void VisitHeapInv(Inst *inst) override;
283     void Escape(const Inst *inst) override;
284     void VisitAllocation(Inst *inst) override;
285     void SetTypeInfo(Inst *inst, ObjectTypeInfo info) override;
286 
287 private:
288     BasicBlockState *GetState(const BasicBlock *block);
289     void CleanupState(BasicBlock *block);
290     void VisitInstsInBlock(BasicBlock *bb);
291     bool VisitBlockInternal(BasicBlock *block);
292     void VisitLoop(Loop *loop);
293     void RollbackChains();
294     void AddTempEdge(const Pointer &from, const Pointer &to);
295     bool AddEdge(Pointer from, Pointer to);
296     bool AddLoadEdge(const Pointer &from, const Inst *toObj, const ArenaTypedRefSet &srcRefSet);
297     bool AddStoreEdge(const Inst *fromObj, const Pointer &to, const ArenaTypedRefSet &srcRefSet);
298 
299     friend class LoopPropagationVisitor;
300 
301 private:
302     // main state:
303     Marker visited_ {};
304     ArenaVector<BasicBlockState *> states_;
305     ArenaVector<ObjectTypeInfo> refInfos_;
306     ArenaMap<const Inst *, ArenaTypedRefSet> instRefSets_;
307     // helper structs:
308     LoopPropagationVisitor loopVisitor_;
309     LiveInAnalysis liveIns_;
310 
311     // helper containers:
312     // src base -> src offset -> dest pointers
313     ArenaMap<const Inst *, PointerOffset::Map<ArenaVector<Pointer>>> loopEdges_;
314     // store base -> store offset -> stored objects
315     ArenaMap<const Inst *, PointerOffset::Map<ArenaVector<const Inst *>>> loopStoreEdges_;
316     ArenaSet<const Inst *> workSet_;
317     BasicBlockState *currentBlockState_ {nullptr};
318     ArenaTypedRefSet nullSet_;
319     ArenaTypedRefSet allSet_;
320     bool inLoop_ {false};
321 };
322 
RunImpl()323 bool ObjectTypePropagationVisitor::RunImpl()
324 {
325     if (!liveIns_.Run(false)) {
326         UNREACHABLE();
327     }
328     auto *graph = GetGraph();
329     ASSERT(graph != nullptr);
330     graph->RunPass<LoopAnalyzer>();
331     MarkerHolder holder(graph);
332     visited_ = holder.GetMarker();
333     for (auto *block : graph->GetBlocksRPO()) {
334         if (!VisitBlockInternal(block)) {
335             ResetTypeInfosInGraph();
336             return false;
337         }
338     }
339     SetTypeInfosInGraph();
340     return true;
341 }
342 
SetTypeInfosInGraph()343 void ObjectTypePropagationVisitor::SetTypeInfosInGraph()
344 {
345     for (auto &[inst, refs] : instRefSets_) {
346         if (inst->IsParameter() && inst->GetObjectTypeInfo()) {
347             // type info was set during inlining
348             continue;
349         }
350         ObjectTypeInfo info = ObjectTypeInfo::UNKNOWN;
351         [[maybe_unused]] bool isNull = true;
352         refs.Visit([this, &info, &isNull](Ref ref) {
353             info = Merge(info, GetRefTypeInfo(ref));
354             if (ref != NULL_REF) {
355                 isNull = false;
356             }
357         });
358         ASSERT_DO(info != ObjectTypeInfo::UNKNOWN || isNull,
359                   (std::cerr << "Inst shoulnd't have UNKNOWN TypeInfo: " << *inst << "\n"
360                              << "refs: " << refs << "\n",
361                    GetGraph()->Dump(&std::cerr)));
362         auto commonInfo = refs.GetTypeInfo();
363         TryImproveTypeInfo(info, commonInfo);
364         const_cast<Inst *>(inst)->SetObjectTypeInfo(info);
365     }
366 }
367 
ResetTypeInfosInGraph()368 void ObjectTypePropagationVisitor::ResetTypeInfosInGraph()
369 {
370     for (auto *block : GetGraph()->GetBlocksRPO()) {
371         for (auto inst : block->AllInsts()) {
372             inst->SetObjectTypeInfo(ObjectTypeInfo::INVALID);
373         }
374     }
375 }
376 
WalkOutgoingEdges(const Inst * fromInst,const PointerOffset::Map<ArenaVector<Pointer>> & edges)377 void ObjectTypePropagationVisitor::WalkOutgoingEdges(const Inst *fromInst,
378                                                      const PointerOffset::Map<ArenaVector<Pointer>> &edges)
379 {
380     for (const auto &[fromOffset, toPtrs] : edges) {
381         for (const auto &toPtr : toPtrs) {
382             Pointer fromPtr {fromInst, fromOffset};
383             bool changed = AddEdge(fromPtr, toPtr);
384             if (changed) {
385                 ASSERT(toPtr.GetBase() != nullptr);
386                 workSet_.insert(toPtr.GetBase());
387             }
388         }
389     }
390 }
391 
WalkStoreEdges(const Inst * toInst,const PointerOffset::Map<ArenaVector<const Inst * >> & edges)392 void ObjectTypePropagationVisitor::WalkStoreEdges(const Inst *toInst,
393                                                   const PointerOffset::Map<ArenaVector<const Inst *>> &edges)
394 {
395     for (const auto &[toOffset, fromInsts] : edges) {
396         Pointer toPtr {toInst, toOffset};
397         for (const auto *fromInst : fromInsts) {
398             auto fromPtr = Pointer::CreateObject(fromInst);
399             bool changed = AddEdge(fromPtr, toPtr);
400             if (changed) {
401                 ASSERT(toPtr.GetBase() != nullptr);
402             }
403         }
404     }
405 }
406 
WalkEdges()407 void ObjectTypePropagationVisitor::WalkEdges()
408 {
409     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "LOOP - WalkEdges";
410     workSet_ = {};
411     for (const auto &[fromInst, edges] : loopEdges_) {
412         ASSERT(fromInst != nullptr);
413         WalkOutgoingEdges(fromInst, edges);
414     }
415     while (!workSet_.empty()) {
416         auto inst = *workSet_.begin();
417         workSet_.erase(workSet_.begin());
418 
419         if (auto it = loopStoreEdges_.find(inst); it != loopStoreEdges_.end()) {
420             WalkStoreEdges(inst, it->second);
421         }
422         if (auto it = loopEdges_.find(inst); it != loopEdges_.end()) {
423             WalkOutgoingEdges(inst, it->second);
424         }
425     }
426 }
427 
NewRef(ObjectTypeInfo info)428 Ref ObjectTypePropagationVisitor::NewRef(ObjectTypeInfo info)
429 {
430     refInfos_.push_back(info);
431     return refInfos_.size() - 1;
432 }
433 
GetRefTypeInfo(Ref ref) const434 ObjectTypeInfo ObjectTypePropagationVisitor::GetRefTypeInfo(Ref ref) const
435 {
436     if (ref == ALL_REF) {
437         return ObjectTypeInfo::INVALID;
438     }
439     return refInfos_.at(ref);
440 }
441 
ToString(ObjectTypeInfo typeInfo)442 std::string ObjectTypePropagationVisitor::ToString(ObjectTypeInfo typeInfo)
443 {
444     if (typeInfo == ObjectTypeInfo::UNKNOWN) {
445         return "UNKNOWN";
446     }
447     if (typeInfo == ObjectTypeInfo::INVALID) {
448         return "INVALID";
449     }
450     return GetGraph()->GetRuntime()->GetClassName(typeInfo.GetClass());
451 }
452 
DumpRefInfos(std::ostream & os)453 void ObjectTypePropagationVisitor::DumpRefInfos(std::ostream &os)
454 {
455     for (Ref ref = BasicBlockState::REAL_REF_START; ref < refInfos_.size(); ref++) {
456         auto typeInfo = GetRefTypeInfo(ref);
457         os << ref << ": " << ToString(typeInfo) << "\n";
458     }
459 }
460 
GetInstRefSet(const Inst * inst)461 ArenaTypedRefSet &ObjectTypePropagationVisitor::GetInstRefSet(const Inst *inst)
462 {
463     switch (inst->GetOpcode()) {
464         // No passes that check class references aliasing
465         case Opcode::GetInstanceClass:
466         case Opcode::LoadImmediate:
467             return nullSet_;
468         default:
469             break;
470     }
471     if (instRefSets_.find(inst) == instRefSets_.end()) {
472         std::cerr << "no inst: " << *inst << '\n';
473         GetGraph()->Dump(&std::cerr);
474         UNREACHABLE();
475     }
476     return instRefSets_.at(inst);
477 }
478 
TryGetInstRefSet(const Inst * inst)479 ArenaTypedRefSet &ObjectTypePropagationVisitor::TryGetInstRefSet(const Inst *inst)
480 {
481     auto it = instRefSets_.find(inst);
482     if (it == instRefSets_.end()) {
483         return nullSet_;
484     }
485     return it->second;
486 }
487 
GetNullSet() const488 const ArenaTypedRefSet &ObjectTypePropagationVisitor::GetNullSet() const
489 {
490     ASSERT(nullSet_.PopCount() == 1);
491     return nullSet_;
492 }
493 
GetAllSet() const494 const ArenaTypedRefSet &ObjectTypePropagationVisitor::GetAllSet() const
495 {
496     return allSet_;
497 }
498 
GetOrCreateInstRefSet(const Inst * inst,bool & changed)499 ArenaTypedRefSet &ObjectTypePropagationVisitor::GetOrCreateInstRefSet(const Inst *inst, bool &changed)
500 {
501     ASSERT(inst != nullptr);
502     auto ret = instRefSets_.try_emplace(inst, GetGraph()->GetLocalAllocator(), ObjectTypeInfo::UNKNOWN);
503     changed |= ret.second;
504     return ret.first->second;
505 }
506 
AddDirectEdge(const Pointer & p)507 void ObjectTypePropagationVisitor::AddDirectEdge(const Pointer &p)
508 {
509     if (p.GetType() == STATIC_FIELD) {
510         // Load/Store Static - processed in AddCopyEdge
511         return;
512     }
513     auto inst = p.GetBase();
514     ASSERT(inst != nullptr);
515     auto &refs = IsZeroConstantOrNullPtr(inst) ? nullSet_ : currentBlockState_->NewUnknownRef();
516     instRefSets_.try_emplace(inst, refs);
517     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "Add Direct " << *inst << ": " << instRefSets_.at(inst);
518 }
519 
AddConstantDirectEdge(Inst * inst,uint32_t id)520 void ObjectTypePropagationVisitor::AddConstantDirectEdge(Inst *inst, [[maybe_unused]] uint32_t id)
521 {
522     ASSERT(inst != nullptr);
523     instRefSets_.try_emplace(inst, currentBlockState_->NewUnknownRef());
524 }
525 
AddCopyEdge(const Pointer & from,const Pointer & to)526 void ObjectTypePropagationVisitor::AddCopyEdge(const Pointer &from, const Pointer &to)
527 {
528     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "Add Copy " << from << " -> " << to;
529     [[maybe_unused]] auto changed = AddEdge(from, to);
530     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "  ToSet: " << TryGetInstRefSet(to.GetBase()) << " changed: " << changed;
531 }
532 
VisitHeapInv(Inst * inst)533 void ObjectTypePropagationVisitor::VisitHeapInv([[maybe_unused]] Inst *inst)
534 {
535     currentBlockState_->InvalidateEscaped();
536 }
537 
Escape(const Inst * inst)538 void ObjectTypePropagationVisitor::Escape(const Inst *inst)
539 {
540     currentBlockState_->Escape(inst);
541 }
542 
VisitAllocation(Inst * inst)543 void ObjectTypePropagationVisitor::VisitAllocation(Inst *inst)
544 {
545     bool defaultConstructed = inst->GetOpcode() != Opcode::InitObject;
546     auto ref = NewRef();
547     ASSERT(inst != nullptr);
548     auto [it, inserted] = instRefSets_.try_emplace(inst, GetGraph()->GetLocalAllocator(), ObjectTypeInfo::UNKNOWN);
549     it->second.SetBit(ref, ObjectTypeInfo::UNKNOWN);
550     currentBlockState_->CreateFieldRefSetForNewObject(ref, defaultConstructed);
551 }
552 
SetTypeInfo(Inst * inst,ObjectTypeInfo info)553 void ObjectTypePropagationVisitor::SetTypeInfo(Inst *inst, ObjectTypeInfo info)
554 {
555     // should be data-flow input instead of check inst
556     ASSERT(!inst->IsCheck() || inst->GetOpcode() == Opcode::RefTypeCheck);
557     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "SetTypeInfo inst " << inst->GetId() << " " << ToString(info);
558     auto &refSet = TryGetInstRefSet(inst);
559     if (&refSet == &nullSet_) {
560         ASSERT(inLoop_);
561         return;
562     }
563     if (refSet.PopCount() == 1) {
564         auto ref = refSet.GetSingle();
565         TryImproveTypeInfo(refInfos_[ref], info);
566     }
567     refSet.TryImproveTypeInfo(info);
568 }
569 
GetState(const BasicBlock * block)570 BasicBlockState *ObjectTypePropagationVisitor::GetState(const BasicBlock *block)
571 {
572     ASSERT(block->GetId() < states_.size());
573     return states_[block->GetId()];
574 }
575 
576 // Remove escaped (!) refs dead in the beginning of the block
577 // Non-escaped refs cannot be removed because they can be part of escape-chain later
CleanupState(BasicBlock * block)578 void ObjectTypePropagationVisitor::CleanupState(BasicBlock *block)
579 {
580     ASSERT(currentBlockState_ == GetState(block));
581     ArenaBitVector liveRefs(GetGraph()->GetLocalAllocator());
582     liveIns_.VisitAlive(block, [this, &liveRefs](const Inst *inst) {
583         // Currently CatchPhi's inputs can be marked live in the beginning of the start block,
584         // so there is not *always* RefSet for inst
585         TryGetInstRefSet(inst).Visit([&liveRefs](Ref ref) { liveRefs.SetBit(ref); });
586     });
587     currentBlockState_->Cleanup([&liveRefs](Ref ref) { return !liveRefs.GetBit(ref); });
588 }
589 
VisitInstsInBlock(BasicBlock * bb)590 void ObjectTypePropagationVisitor::VisitInstsInBlock(BasicBlock *bb)
591 {
592     for (auto inst : bb->AllInsts()) {
593         AliasVisitor::VisitInstruction(inst);
594         TypePropagationVisitor::VisitInstruction(inst);
595     }
596 }
597 
VisitBlockInternal(BasicBlock * block)598 bool ObjectTypePropagationVisitor::VisitBlockInternal(BasicBlock *block)
599 {
600     Loop *loop = nullptr;
601     if (block->IsLoopHeader()) {
602         loop = block->GetLoop();
603         if (loop->IsIrreducible()) {
604             return false;
605         }
606     }
607 
608     BasicBlockState *state = nullptr;
609     for (auto *pred : block->GetPredsBlocks()) {
610         if (!pred->IsMarked(visited_)) {
611             ASSERT(block->IsLoopHeader());
612             [[maybe_unused]] auto *predLoop = pred->GetLoop();
613             ASSERT(predLoop != nullptr);
614             ASSERT(predLoop == loop || predLoop->IsInside(loop));
615             continue;
616         }
617         if (state == nullptr) {
618             state = GetGraph()->GetLocalAllocator()->New<BasicBlockState>(*GetState(pred), block);
619         } else {
620             state->Merge(GetState(pred));
621         }
622     }
623     if (state == nullptr) {
624         state = GetGraph()->GetLocalAllocator()->New<BasicBlockState>(this, block);
625     }
626     states_[block->GetId()] = state;
627     currentBlockState_ = state;
628     if (loop != nullptr) {
629         VisitLoop(loop);
630     }
631     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "before visit: " << *state;
632     block->SetMarker(visited_);
633     CleanupState(block);
634     VisitInstsInBlock(block);
635     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "after visit: " << *state;
636     currentBlockState_ = nullptr;
637     return true;
638 }
639 
VisitLoop(Loop * loop)640 void ObjectTypePropagationVisitor::VisitLoop(Loop *loop)
641 {
642     ASSERT(!inLoop_);
643     ASSERT(GetGraph()->IsAnalysisValid<LoopAnalyzer>());
644     inLoop_ = true;
645     loopVisitor_.VisitLoop(loop);
646     RollbackChains();
647     inLoop_ = false;
648 }
649 
RollbackChains()650 void ObjectTypePropagationVisitor::RollbackChains()
651 {
652     for (auto &[base, edges] : loopEdges_) {
653         edges.clear();
654     }
655     for (auto &[base, edges] : loopStoreEdges_) {
656         edges.clear();
657     }
658 }
659 
AddTempEdge(const Pointer & from,const Pointer & to)660 void ObjectTypePropagationVisitor::AddTempEdge(const Pointer &from, const Pointer &to)
661 {
662     ASSERT(from.GetBase() != nullptr);
663     {
664         auto &instEdges =
665             loopEdges_.try_emplace(from.GetBase(), GetGraph()->GetLocalAllocator()->Adapter()).first->second;
666         auto &ptrEdges =
667             instEdges.try_emplace(from.GetOffset(), GetGraph()->GetLocalAllocator()->Adapter()).first->second;
668         ptrEdges.push_back(to);
669     }
670     if (from.IsObject() && !to.IsObject() && to.GetType() != STATIC_FIELD && to.GetType() != POOL_CONSTANT) {
671         // store
672         auto &instEdges =
673             loopStoreEdges_.try_emplace(to.GetBase(), GetGraph()->GetLocalAllocator()->Adapter()).first->second;
674         auto &ptrEdges =
675             instEdges.try_emplace(to.GetOffset(), GetGraph()->GetLocalAllocator()->Adapter()).first->second;
676         ptrEdges.push_back(from.GetBase());
677     }
678 }
679 
AddEdge(Pointer from,Pointer to)680 bool ObjectTypePropagationVisitor::AddEdge(Pointer from, Pointer to)
681 {
682     from = from.DropIdx();
683     to = to.DropIdx();
684     auto fromObj = from.GetBase();
685     auto it = instRefSets_.find(fromObj);
686     const ArenaTypedRefSet *srcRefSet = nullptr;
687     auto toObj = to.GetBase();
688     if (fromObj == nullptr) {
689         ASSERT(from.GetType() == STATIC_FIELD || from.GetType() == POOL_CONSTANT);
690         srcRefSet = &currentBlockState_->NewUnknownRef();
691     } else if (it != instRefSets_.end()) {
692         srcRefSet = &it->second;
693     } else if (!fromObj->IsReferenceOrAny()) {
694         // inputs of instructions with ANY type may all have primitive types
695         srcRefSet = &currentBlockState_->NewUnknownRef();
696     } else {
697         ASSERT(inLoop_ || toObj->IsPhi() || toObj->IsCatchPhi());
698         return false;
699     }
700     ASSERT(currentBlockState_ != nullptr);
701     if (from.IsObject() && to.IsObject()) {
702         // copy
703         bool changed = false;
704         auto &destRefSet = GetOrCreateInstRefSet(toObj, changed);
705         if (!destRefSet.Includes(*srcRefSet)) {
706             destRefSet |= *srcRefSet;
707             changed = true;
708         }
709         return changed;
710     }
711     if (to.IsObject()) {
712         return AddLoadEdge(from, toObj, *srcRefSet);
713     }
714     if (from.IsObject()) {
715         return AddStoreEdge(fromObj, to, *srcRefSet);
716     }
717     // both are not objects
718     UNREACHABLE();
719 }
720 
AddLoadEdge(const Pointer & from,const Inst * toObj,const ArenaTypedRefSet & srcRefSet)721 bool ObjectTypePropagationVisitor::AddLoadEdge(const Pointer &from, const Inst *toObj,
722                                                const ArenaTypedRefSet &srcRefSet)
723 {
724     bool changed = false;
725     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << " LOAD";
726     auto &destRefSet = GetOrCreateInstRefSet(toObj, changed);
727     if (from.GetType() == STATIC_FIELD || from.GetType() == POOL_CONSTANT) {
728         COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "  STATIC FIELD -> " << toObj->GetId();
729         changed |= destRefSet != srcRefSet;
730         destRefSet = srcRefSet;
731         return changed;
732     }
733     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "  initial destRefSet: " << destRefSet;
734     srcRefSet.Visit([offset = from.GetOffset(), &destRefSet, &changed, this](Ref ref) {
735         COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "  Visit " << ref;
736         auto &srcFieldSet = currentBlockState_->GetFieldRefSet(ref, offset);
737         COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "  srcFieldSet: " << srcFieldSet;
738         if (!destRefSet.Includes(srcFieldSet)) {
739             destRefSet |= srcFieldSet;
740             changed = true;
741         }
742     });
743     return changed;
744 }
745 
AddStoreEdge(const Inst * fromObj,const Pointer & to,const ArenaTypedRefSet & srcRefSet)746 bool ObjectTypePropagationVisitor::AddStoreEdge(const Inst *fromObj, const Pointer &to,
747                                                 const ArenaTypedRefSet &srcRefSet)
748 {
749     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << " STORE";
750     ASSERT(fromObj != nullptr);
751     if (!fromObj->IsReferenceOrAny()) {
752         // primitive ANY type or null as integer is stored
753         return false;
754     }
755     if (to.GetType() == STATIC_FIELD || to.GetType() == POOL_CONSTANT) {
756         return false;
757     }
758     if (instRefSets_.find(to.GetBase()) == instRefSets_.end()) {
759         ASSERT(inLoop_);
760         return true;
761     }
762     auto &destRefSet = instRefSets_.at(to.GetBase());
763     bool changed = false;
764     bool escape = false;
765     destRefSet.Visit([&to, &srcRefSet, &changed, &escape, this](Ref ref) {
766         if (currentBlockState_->IsEscaped(ref)) {
767             escape = true;
768         }
769         if (ref == ALL_REF) {
770             return;
771         }
772         auto offset = to.GetOffset();
773         auto &dstFieldSet = currentBlockState_->CreateFieldRefSet(ref, offset, changed);
774         COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "  ref: " << ref;
775         COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "  src: " << srcRefSet;
776         if (!inLoop_ && to.GetType() != UNKNOWN_OFFSET) {
777             changed |= dstFieldSet != srcRefSet;
778             dstFieldSet = srcRefSet;
779         } else if (!dstFieldSet.Includes(srcRefSet)) {
780             dstFieldSet |= srcRefSet;
781             changed = true;
782         }
783         COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "  dst: " << dstFieldSet;
784     });
785     if (escape) {
786         srcRefSet.Visit([this](Ref ref) { currentBlockState_->TryEscape(ref); });
787     }
788     return changed;
789 }
790 
LoopPropagationVisitor(ObjectTypePropagationVisitor * parent)791 LoopPropagationVisitor::LoopPropagationVisitor(ObjectTypePropagationVisitor *parent)
792     : TypePropagationVisitor(parent->GetGraph()),
793       parent_(parent),
794       escapedInsts_(parent_->GetGraph()->GetLocalAllocator()->Adapter())
795 {
796     AliasVisitor::Init(parent->GetGraph()->GetLocalAllocator());
797 }
798 
AddDirectEdge(const Pointer & p)799 void LoopPropagationVisitor::AddDirectEdge(const Pointer &p)
800 {
801     parent_->AddDirectEdge(p);
802 }
803 
VisitAllocation(Inst * inst)804 void LoopPropagationVisitor::VisitAllocation(Inst *inst)
805 {
806     parent_->VisitAllocation(inst);
807 }
808 
AddConstantDirectEdge(Inst * inst,uint32_t id)809 void LoopPropagationVisitor::AddConstantDirectEdge(Inst *inst, uint32_t id)
810 {
811     parent_->AddConstantDirectEdge(inst, id);
812 }
813 
AddCopyEdge(const Pointer & from,const Pointer & to)814 void LoopPropagationVisitor::AddCopyEdge(const Pointer &from, const Pointer &to)
815 {
816     if (from.GetType() == STATIC_FIELD || from.GetType() == POOL_CONSTANT) {
817         ASSERT(to.IsObject());
818         AddDirectEdge(to);
819     } else if (to.GetType() == STATIC_FIELD || to.GetType() == POOL_CONSTANT) {
820         ASSERT(from.IsObject());
821         Escape(from.GetBase());
822     } else {
823         parent_->AddTempEdge(from, to);
824     }
825 }
826 
AddPseudoCopyEdge(const Pointer & base,const Pointer & field)827 void LoopPropagationVisitor::AddPseudoCopyEdge([[maybe_unused]] const Pointer &base,
828                                                [[maybe_unused]] const Pointer &field)
829 {
830     // no action
831 }
832 
Escape(const Inst * inst)833 void LoopPropagationVisitor::Escape(const Inst *inst)
834 {
835     escapedInsts_.insert(inst);
836 }
837 
SetTypeInfo(Inst * inst,ObjectTypeInfo info)838 void LoopPropagationVisitor::SetTypeInfo(Inst *inst, ObjectTypeInfo info)
839 {
840     parent_->SetTypeInfo(inst, info);
841 }
842 
VisitLoop(Loop * loop)843 void LoopPropagationVisitor::VisitLoop(Loop *loop)
844 {
845     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "Visit loop " << loop->GetId();
846     ASSERT(loop != nullptr && !loop->IsIrreducible());
847     heapInv_ = false;
848     headerState_ = parent_->GetState(loop->GetHeader());
849     escapedInsts_.clear();
850     VisitLoopRec(loop);
851     parent_->WalkEdges();
852     for (auto *inst : escapedInsts_) {
853         headerState_->Escape(inst);
854     }
855     if (heapInv_) {
856         headerState_->InvalidateEscaped();
857     }
858 }
859 
VisitBlock(BasicBlock * bb)860 void LoopPropagationVisitor::VisitBlock(BasicBlock *bb)
861 {
862     for (auto inst : bb->AllInsts()) {
863         AliasVisitor::VisitInstruction(inst);
864         TypePropagationVisitor::VisitInstruction(inst);
865     }
866 }
867 
VisitLoopRec(Loop * loop)868 void LoopPropagationVisitor::VisitLoopRec(Loop *loop)
869 {
870     ASSERT(loop != nullptr && !loop->IsRoot());
871     for (auto *block : loop->GetBlocks()) {
872         VisitBlock(block);
873     }
874     for (auto *inner : loop->GetInnerLoops()) {
875         VisitLoopRec(inner);
876     }
877 }
878 
BasicBlockState(ObjectTypePropagationVisitor * visitor,BasicBlock * bb)879 BasicBlockState::BasicBlockState(ObjectTypePropagationVisitor *visitor, [[maybe_unused]] BasicBlock *bb)
880     : visitor_(visitor),
881       fieldRefs_(visitor->GetGraph()->GetLocalAllocator()->Adapter()),
882       escaped_(visitor->GetGraph()->GetLocalAllocator(), ObjectTypeInfo::UNKNOWN)
883 {
884     escaped_.SetBit(NULL_REF, ObjectTypeInfo::UNKNOWN);
885     // two refs representing unknown objects
886     // two of them to express that 2 unknown objects are not equal (intersection size > 1)
887     escaped_.SetBit(1U, ObjectTypeInfo::INVALID);
888     escaped_.SetBit(2U, ObjectTypeInfo::INVALID);
889 #ifndef NDEBUG
890     bb_ = bb;
891 #endif
892 }
893 
// Copies `other` and, in debug builds, re-targets the copy to block `bb`.
BasicBlockState::BasicBlockState(const BasicBlockState &other, [[maybe_unused]] BasicBlock *bb) : BasicBlockState(other)
{
#ifndef NDEBUG
    bb_ = bb;
#endif
}
900 
GetFieldRefSet(Ref base,const PointerOffset & offset)901 const ArenaTypedRefSet &BasicBlockState::GetFieldRefSet(Ref base, const PointerOffset &offset)
902 {
903     if (base == ALL_REF) {
904         return visitor_->GetAllSet();
905     }
906     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "    GetFieldRefSet " << base << ' ' << offset;
907     auto it = fieldRefs_.find(base);
908     if (it == fieldRefs_.end()) {
909         return escaped_;
910     }
911     auto &fieldRefsOfBase = it->second;
912     if (fieldRefsOfBase->find(offset) == fieldRefsOfBase->end()) {
913         return escaped_;
914     }
915     // lazy propagation of refs on unknown offset
916     if (auto overwriting = fieldRefsOfBase->find(PointerOffset::CreateUnknownOffset());
917         overwriting != fieldRefsOfBase->end()) {
918         fieldRefsOfBase.Mut(GetLocalAllocator())->at(offset) |= overwriting->second;
919     }
920     return fieldRefsOfBase->at(offset);
921 }
922 
CreateFieldRefSetForNewObject(Ref base,bool defaultConstructed)923 void BasicBlockState::CreateFieldRefSetForNewObject(Ref base, bool defaultConstructed)
924 {
925     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "    CreateFieldRefSetForNewObject " << base;
926     auto [it, inserted] = fieldRefs_.try_emplace(base, GetLocalAllocator());
927     ASSERT(inserted);
928     auto &objectInfos = it->second;
929     const auto &refs = defaultConstructed ? visitor_->GetNullSet() : escaped_;
930     auto mut = objectInfos.Mut(GetLocalAllocator());
931     ASSERT(mut != nullptr);
932     // create "other refs" entry
933     if (!mut->try_emplace(PointerOffset::CreateDefaultField(), refs).second) {
934         UNREACHABLE();
935     }
936 }
937 
CreateFieldRefSet(Ref base,const PointerOffset & offset,bool & changed)938 ArenaTypedRefSet &BasicBlockState::CreateFieldRefSet(Ref base, const PointerOffset &offset, bool &changed)
939 {
940     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "    CreateFieldRefSet " << base << ' ' << offset;
941     auto [it1, ins1] = fieldRefs_.try_emplace(base, GetLocalAllocator());
942     changed |= ins1;
943     auto &objectInfos = it1->second;
944     if (ins1) {
945         auto mut = objectInfos.Mut(GetLocalAllocator());
946         ASSERT(mut != nullptr);
947         // create "other refs" entry
948         if (!mut->try_emplace(PointerOffset::CreateDefaultField(), escaped_).second) {
949             UNREACHABLE();
950         }
951     }
952     auto [it2, ins2] =
953         objectInfos.Mut(GetLocalAllocator())->try_emplace(offset, GetLocalAllocator(), ObjectTypeInfo::UNKNOWN);
954     changed |= ins2;
955     return it2->second;
956 }
957 
NewUnknownRef()958 const ArenaTypedRefSet &BasicBlockState::NewUnknownRef()
959 {
960     return escaped_;
961 }
962 
// Trait: true_type for any std::map specialization (any comparator/allocator), false_type otherwise.
template <typename T>
struct IsMap : std::false_type {
};

template <typename K, typename T, typename C, typename A>
struct IsMap<std::map<K, T, C, A>> : std::true_type {
};

// Whether T, ignoring references and cv-qualifiers, is an ordered std::map.
template <typename T>
constexpr bool IS_ORDERED_MAP_V = IsMap<std::decay_t<T>>::value;
973 
974 // Harsh - set intersection for keys + set union for keys in intersection
975 // Not harsh (more precise) - set union for keys + set union for keys in intersection
976 template <bool HARSH = false>
MergeFieldSets(PointerOffset::Map<ArenaTypedRefSet> & to,const PointerOffset::Map<ArenaTypedRefSet> & from)977 void MergeFieldSets(PointerOffset::Map<ArenaTypedRefSet> &to, const PointerOffset::Map<ArenaTypedRefSet> &from)
978 {
979     static_assert(!IS_ORDERED_MAP_V<decltype(to)>);
980     ASSERT(to.count(PointerOffset::CreateDefaultField()));
981     ASSERT(from.count(PointerOffset::CreateDefaultField()));
982     auto &defaultRefsTo = to.at(PointerOffset::CreateDefaultField());
983     auto defaultRefsFrom = from.at(PointerOffset::CreateDefaultField());
984     for (auto it = to.begin(); it != to.end();) {
985         auto fromIt = from.find(it->first);
986         if (fromIt != from.end()) {
987             it->second |= fromIt->second;
988             it++;
989         } else if constexpr (!HARSH) {
990             it->second |= defaultRefsFrom;
991             it++;
992         } else {
993             defaultRefsTo |= it->second;
994             it = to.erase(it);
995         }
996     }
997     for (const auto &[offset, refs] : from) {
998         if (to.find(offset) == to.end()) {
999             if constexpr (HARSH) {
1000                 defaultRefsTo |= refs;
1001             } else {
1002                 auto it = to.emplace(offset, refs).first;
1003                 // defaultRefsTo may be invalidated, recompute
1004                 it->second |= to.at(PointerOffset::CreateDefaultField());
1005             }
1006         }
1007     }
1008 }
1009 
Merge(BasicBlockState * other)1010 void BasicBlockState::Merge(BasicBlockState *other)
1011 {
1012     // in-place set intersection
1013     auto otherIt = other->fieldRefs_.begin();
1014     static_assert(IS_ORDERED_MAP_V<decltype(fieldRefs_)>);
1015     for (auto it = fieldRefs_.begin(); it != fieldRefs_.end();) {
1016         otherIt = std::find_if(otherIt, other->fieldRefs_.end(),
1017                                [&it](auto otherElem) { return otherElem.first >= it->first; });
1018         auto &[ref, refFields] = *it;
1019         if (otherIt != other->fieldRefs_.end() && otherIt->first == ref) {
1020             auto mut = refFields.Mut(GetLocalAllocator());
1021             ASSERT(mut != nullptr);
1022             MergeFieldSets(*mut, *otherIt->second);
1023             if (IsEscaped(ref) != other->IsEscaped(ref)) {
1024                 ForceEscape(ref);
1025             }
1026             it++;
1027         } else {
1028             ASSERT(otherIt == other->fieldRefs_.end() || otherIt->first > it->first);
1029             it = fieldRefs_.erase(it);
1030         }
1031     }
1032     escaped_ |= other->escaped_;
1033 }
1034 
Escape(ArenaVector<Ref> & worklist,const ArenaTypedRefSet & refs)1035 void BasicBlockState::Escape(ArenaVector<Ref> &worklist, const ArenaTypedRefSet &refs)
1036 {
1037     refs.Visit([this, &worklist](Ref fieldRef) {
1038         if (!IsEscaped(fieldRef)) {
1039             escaped_.SetBit(fieldRef, ObjectTypeInfo::INVALID);
1040             worklist.push_back(fieldRef);
1041         }
1042     });
1043 }
1044 
Escape(ArenaVector<Ref> & worklist)1045 void BasicBlockState::Escape(ArenaVector<Ref> &worklist)
1046 {
1047     while (!worklist.empty()) {
1048         auto ref = worklist.back();
1049         worklist.pop_back();
1050         auto it = fieldRefs_.find(ref);
1051         if (it == fieldRefs_.end()) {
1052             continue;
1053         }
1054         for (const auto &fieldRefs : *it->second) {
1055             Escape(worklist, fieldRefs.second);
1056         }
1057     }
1058 }
1059 
Escape(const Inst * inst)1060 void BasicBlockState::Escape(const Inst *inst)
1061 {
1062     if (!inst->IsReferenceOrAny()) {
1063         return;
1064     }
1065     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "ESCAPE inst: " << *inst;
1066 
1067     ArenaVector<Ref> worklist(visitor_->GetGraph()->GetLocalAllocator()->Adapter());
1068     visitor_->GetInstRefSet(inst).Visit([this, &worklist](Ref ref) {
1069         if (!IsEscaped(ref)) {
1070             COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "ESCAPE ref: " << ref;
1071             escaped_.SetBit(ref, ObjectTypeInfo::INVALID);
1072             worklist.push_back(ref);
1073         }
1074     });
1075 
1076     Escape(worklist);
1077 }
1078 
TryEscape(Ref ref)1079 void BasicBlockState::TryEscape(Ref ref)
1080 {
1081     if (IsEscaped(ref)) {
1082         return;
1083     }
1084 
1085     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "ESCAPE ref: " << ref;
1086     escaped_.SetBit(ref, ObjectTypeInfo::INVALID);
1087     ArenaVector<Ref> worklist({ref}, visitor_->GetGraph()->GetLocalAllocator()->Adapter());
1088     Escape(worklist);
1089 }
1090 
1091 /* Updates `escaped_` RefSet if:
1092  * - ref has escaped already, but new refs to non-escaped refs were added;
1093  * - or ref has not escaped yet.
1094  */
ForceEscape(Ref ref)1095 void BasicBlockState::ForceEscape(Ref ref)
1096 {
1097     escaped_.SetBit(ref, ObjectTypeInfo::INVALID);
1098     ArenaVector<Ref> worklist({ref}, visitor_->GetGraph()->GetLocalAllocator()->Adapter());
1099     Escape(worklist);
1100 }
1101 
InvalidateEscapedRef(Ref ref)1102 void BasicBlockState::InvalidateEscapedRef(Ref ref)
1103 {
1104     if (ref == ALL_REF) {
1105         if (fieldRefs_.size() == 1 && fieldRefs_.begin()->first == ALL_REF) {
1106             // already invalidated
1107             return;
1108         }
1109         fieldRefs_.clear();
1110         fieldRefs_.try_emplace(ALL_REF, GetLocalAllocator());
1111         return;
1112     }
1113     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "INVALIDATE ref " << ref;
1114     auto it = fieldRefs_.find(ref);
1115     if (it == fieldRefs_.end()) {
1116         return;
1117     }
1118     auto *refFields = it->second.Mut(GetLocalAllocator());
1119     ASSERT(refFields != nullptr);
1120     auto oldDefault = refFields->find(PointerOffset::CreateDefaultField());
1121     ASSERT(oldDefault != refFields->end());
1122     // erase all except `default value`
1123     refFields->erase(std::next(oldDefault), refFields->end());
1124     refFields->erase(refFields->begin(), oldDefault);
1125     ASSERT(refFields->size() == 1);
1126     ASSERT(refFields->begin()->first == PointerOffset::CreateDefaultField());
1127     // invoke copy-assignment - no additional space is consumed on repeated invalidations
1128     refFields->begin()->second = escaped_;
1129 }
1130 
InvalidateEscaped()1131 void BasicBlockState::InvalidateEscaped()
1132 {
1133     COMPILER_LOG(DEBUG, TYPE_PROPAGATION) << "INVALIDATE BB " << GetBlockId();
1134     escaped_.Visit([this](Ref ref) {
1135         if (ref >= REAL_REF_START) {
1136             InvalidateEscapedRef(ref);
1137         }
1138     });
1139 }
1140 
IsEscaped(Ref ref) const1141 bool BasicBlockState::IsEscaped(Ref ref) const
1142 {
1143     if (ref == ALL_REF) {
1144         return true;
1145     }
1146     return escaped_.GetBit(ref);
1147 }
1148 
1149 template <typename F>
Cleanup(const F & remove)1150 void BasicBlockState::Cleanup(const F &remove)
1151 {
1152     for (auto it = fieldRefs_.begin(); it != fieldRefs_.end();) {
1153         if (IsEscaped(it->first) && remove(it->first)) {
1154             it = fieldRefs_.erase(it);
1155         } else {
1156             it++;
1157         }
1158     }
1159 }
1160 
GetBlockId() const1161 size_t BasicBlockState::GetBlockId() const
1162 {
1163 #ifndef NDEBUG
1164     return bb_->GetId();
1165 #else
1166     return 0;
1167 #endif  // NDEBUG
1168 }
1169 
GetLocalAllocator()1170 ArenaAllocator *BasicBlockState::GetLocalAllocator()
1171 {
1172     return visitor_->GetGraph()->GetLocalAllocator();
1173 }
1174 
operator <<(std::ostream & os,const BasicBlockState & state)1175 [[maybe_unused]] std::ostream &operator<<(std::ostream &os, const BasicBlockState &state)
1176 {
1177     auto *visitor = state.visitor_;
1178     os << "BB " << state.GetBlockId() << " state:\n";
1179     for (auto [ref, refInfo] : state.fieldRefs_) {
1180         os << "  ref " << ref << " " << visitor->ToString(visitor->GetRefTypeInfo(ref)) << ":\n";
1181         for (auto [offset, refs] : *refInfo) {
1182             os << "    " << offset << " -> " << refs << " " << visitor->ToString(refs.GetTypeInfo()) << "\n";
1183         }
1184     }
1185     os << "Escaped: " << state.escaped_ << "\n";
1186     os << "Instructions RefSets (global):\n";
1187     for (auto &[inst, refs] : visitor->GetInstRefSets()) {
1188         // short inst dump without inputs/outputs
1189         os << "  " << inst->GetId() << "." << inst->GetType() << " ";
1190         inst->DumpOpcode(&os);
1191         os << "-> " << refs << " " << visitor->ToString(refs.GetTypeInfo()) << "\n";
1192     }
1193     visitor->DumpRefInfos(os);
1194     return os;
1195 }
1196 
VisitNewObject(GraphVisitor * v,Inst * i)1197 void TypePropagationVisitor::VisitNewObject(GraphVisitor *v, Inst *i)
1198 {
1199     auto *self = static_cast<TypePropagationVisitor *>(v);
1200     auto inst = i->CastToNewObject();
1201     auto klass = self->GetGraph()->GetRuntime()->GetClass(inst->GetMethod(), inst->GetTypeId());
1202     if (klass != nullptr) {
1203         self->SetTypeInfo(inst, {klass, true});
1204     }
1205 }
1206 
VisitNewArray(GraphVisitor * v,Inst * i)1207 void TypePropagationVisitor::VisitNewArray(GraphVisitor *v, Inst *i)
1208 {
1209     auto *self = static_cast<TypePropagationVisitor *>(v);
1210     auto inst = i->CastToNewArray();
1211     auto klass = self->GetGraph()->GetRuntime()->GetClass(inst->GetMethod(), inst->GetTypeId());
1212     if (klass != nullptr) {
1213         self->SetTypeInfo(inst, {klass, true});
1214     }
1215 }
1216 
VisitLoadString(GraphVisitor * v,Inst * i)1217 void TypePropagationVisitor::VisitLoadString(GraphVisitor *v, Inst *i)
1218 {
1219     auto *self = static_cast<TypePropagationVisitor *>(v);
1220     auto inst = i->CastToLoadString();
1221     auto klass = self->GetGraph()->GetRuntime()->GetStringClass(inst->GetMethod(), nullptr);
1222     if (klass != nullptr) {
1223         self->SetTypeInfo(inst, {klass, true});
1224     }
1225 }
1226 
VisitLoadArray(GraphVisitor * v,Inst * i)1227 void TypePropagationVisitor::VisitLoadArray([[maybe_unused]] GraphVisitor *v, [[maybe_unused]] Inst *i)
1228 {
1229     if (i->GetType() != DataType::REFERENCE) {
1230         return;
1231     }
1232     VisitRefTypeCheck(v, i);
1233 }
1234 
VisitLoadObject(GraphVisitor * v,Inst * i)1235 void TypePropagationVisitor::VisitLoadObject(GraphVisitor *v, Inst *i)
1236 {
1237     if (i->GetType() != DataType::REFERENCE || i->CastToLoadObject()->GetObjectType() != ObjectType::MEM_OBJECT) {
1238         return;
1239     }
1240     auto *self = static_cast<TypePropagationVisitor *>(v);
1241     auto inst = i->CastToLoadObject();
1242     auto fieldId = inst->GetTypeId();
1243     if (fieldId == 0) {
1244         return;
1245     }
1246     auto runtime = self->GetGraph()->GetRuntime();
1247     auto method = inst->GetMethod();
1248     auto typeId = runtime->GetFieldValueTypeId(method, fieldId);
1249     auto klass = runtime->GetClass(method, typeId);
1250     if (klass != nullptr) {
1251         auto isExact = runtime->GetClassType(method, typeId) == ClassType::FINAL_CLASS;
1252         self->SetTypeInfo(inst, {klass, isExact});
1253     }
1254 }
1255 
VisitLoadStatic(GraphVisitor * v,Inst * i)1256 void TypePropagationVisitor::VisitLoadStatic(GraphVisitor *v, Inst *i)
1257 {
1258     if (i->GetType() != DataType::REFERENCE) {
1259         return;
1260     }
1261     auto *self = static_cast<TypePropagationVisitor *>(v);
1262     auto inst = i->CastToLoadStatic();
1263     auto fieldId = inst->GetTypeId();
1264     if (fieldId == 0) {
1265         return;
1266     }
1267     auto runtime = self->GetGraph()->GetRuntime();
1268     auto method = inst->GetMethod();
1269     auto typeId = runtime->GetFieldValueTypeId(method, fieldId);
1270     auto klass = runtime->GetClass(method, typeId);
1271     if (klass != nullptr) {
1272         auto isExact = runtime->GetClassType(method, typeId) == ClassType::FINAL_CLASS;
1273         self->SetTypeInfo(inst, {klass, isExact});
1274     }
1275 }
1276 
VisitCallStatic(GraphVisitor * v,Inst * i)1277 void TypePropagationVisitor::VisitCallStatic(GraphVisitor *v, Inst *i)
1278 {
1279     ProcessManagedCall(v, i->CastToCallStatic());
1280 }
1281 
VisitCallVirtual(GraphVisitor * v,Inst * i)1282 void TypePropagationVisitor::VisitCallVirtual(GraphVisitor *v, Inst *i)
1283 {
1284     ProcessManagedCall(v, i->CastToCallVirtual());
1285 }
1286 
VisitRefTypeCheck(GraphVisitor * v,Inst * i)1287 void TypePropagationVisitor::VisitRefTypeCheck(GraphVisitor *v, Inst *i)
1288 {
1289     auto arrayTypeInfo = i->GetDataFlowInput(0)->GetObjectTypeInfo();
1290     if (!arrayTypeInfo) {
1291         return;
1292     }
1293     auto *self = static_cast<TypePropagationVisitor *>(v);
1294     auto runtime = self->GetGraph()->GetRuntime();
1295     if (!runtime->IsArrayClass(arrayTypeInfo.GetClass())) {
1296         return;
1297     }
1298     if (runtime->GetArrayComponentType(arrayTypeInfo.GetClass()) != DataType::REFERENCE) {
1299         return;
1300     }
1301     auto storedClass = runtime->GetArrayElementClass(arrayTypeInfo.GetClass());
1302     if (storedClass == nullptr) {
1303         return;
1304     }
1305     auto isExact = runtime->GetClassType(storedClass) == ClassType::FINAL_CLASS;
1306     self->SetTypeInfo(i, {storedClass, isExact});
1307 }
1308 
VisitParameter(GraphVisitor * v,Inst * i)1309 void TypePropagationVisitor::VisitParameter(GraphVisitor *v, Inst *i)
1310 {
1311     auto inst = i->CastToParameter();
1312     auto graph = i->GetBasicBlock()->GetGraph();
1313     if (inst->GetType() != DataType::REFERENCE || graph->IsBytecodeOptimizer() || inst->HasObjectTypeInfo()) {
1314         return;
1315     }
1316     auto refNum = inst->GetArgRefNumber();
1317     auto runtime = graph->GetRuntime();
1318     auto method = graph->GetMethod();
1319     RuntimeInterface::ClassPtr klass;
1320     if (refNum == ParameterInst::INVALID_ARG_REF_NUM) {
1321         // This parametr doesn't have ArgRefNumber
1322         if (inst->GetArgNumber() != 0 || runtime->IsMethodStatic(method)) {
1323             return;
1324         }
1325         klass = runtime->GetClass(method);
1326     } else {
1327         auto typeId = runtime->GetMethodArgReferenceTypeId(method, refNum);
1328         klass = runtime->GetClass(method, typeId);
1329     }
1330     if (klass != nullptr) {
1331         auto isExact = runtime->GetClassType(klass) == ClassType::FINAL_CLASS;
1332         auto *self = static_cast<TypePropagationVisitor *>(v);
1333         self->SetTypeInfo(inst, {klass, isExact});
1334     }
1335 }
1336 
ProcessManagedCall(GraphVisitor * v,CallInst * inst)1337 void TypePropagationVisitor::ProcessManagedCall(GraphVisitor *v, CallInst *inst)
1338 {
1339     if (inst->GetType() != DataType::REFERENCE) {
1340         return;
1341     }
1342     if (inst->IsInlined()) {
1343         return;
1344     }
1345     auto *self = static_cast<TypePropagationVisitor *>(v);
1346     auto runtime = self->GetGraph()->GetRuntime();
1347     auto method = inst->GetCallMethod();
1348     auto typeId = runtime->GetMethodReturnTypeId(method);
1349     auto klass = runtime->GetClass(method, typeId);
1350     if (klass != nullptr) {
1351         auto isExact = runtime->GetClassType(method, typeId) == ClassType::FINAL_CLASS;
1352         self->SetTypeInfo(inst, {klass, isExact});
1353     }
1354 }
1355 
1356 }  // namespace
1357 
ObjectTypePropagation(Graph * graph)1358 ObjectTypePropagation::ObjectTypePropagation(Graph *graph) : Analysis(graph) {}
1359 
RunImpl()1360 bool ObjectTypePropagation::RunImpl()
1361 {
1362     ObjectTypePropagationVisitor visitor(GetGraph());
1363     visitor.RunImpl();
1364     return true;
1365 }
1366 
1367 }  // namespace ark::compiler
1368