1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "compiler_logger.h"
17 #include "optimizer/ir/graph_visitor.h"
18 #include "optimizer/ir/basicblock.h"
19 #include "optimizer/ir/inst.h"
20 #include "optimizer/ir/analysis.h"
21 #include "optimizer/analysis/alias_analysis.h"
22 #include "optimizer/analysis/rpo.h"
23 #include "optimizer/analysis/loop_analyzer.h"
24 #include "optimizer/ir/runtime_interface.h"
25 #include "optimizer/optimizations/memory_coalescing.h"
26 
27 namespace ark::compiler {
28 /**
29  * Basic analysis for variables used in loops. It works as follows:
30  * 1) Identify variables that are derived from other variables and the difference between them (AddI and SubI are supported).
31  * 2) Based on the previous step, reveal loop variables and, if possible, their per-iteration increment.
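 *
 * An illustrative sketch (hypothetical IR, value numbers made up for this comment):
 *   v0 = Constant 0x0       // initial value coming from the loop pre-header
 *   v1 = Phi(v0, v2)        // loop variable: base_[v1] = {initial = 0, step = +1}
 *   v2 = AddI v1, 0x1       // derived variable: derived_[v2] = {base = v1, diff = +1}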
32  */
33 class VariableAnalysis {
34 public:
35     struct BaseVariable {
36         int64_t initial;
37         int64_t step;
38     };
39     struct DerivedVariable {
40         Inst *base;
41         int64_t diff;
42     };
43 
44     explicit VariableAnalysis(Graph *graph)
45         : base_(graph->GetLocalAllocator()->Adapter()), derived_(graph->GetLocalAllocator()->Adapter())
46     {
47         for (auto block : graph->GetBlocksRPO()) {
48             for (auto inst : block->AllInsts()) {
49                 if (GetCommonType(inst->GetType()) == DataType::INT64) {
50                     AddUsers(inst);
51                 }
52             }
53         }
54         for (auto loop : graph->GetRootLoop()->GetInnerLoops()) {
55             if (loop->IsIrreducible()) {
56                 continue;
57             }
58             auto header = loop->GetHeader();
59             for (auto phi : header->PhiInsts()) {
60                 constexpr auto INPUTS_COUNT = 2;
61                 if (phi->GetInputsCount() != INPUTS_COUNT || GetCommonType(phi->GetType()) != DataType::INT64) {
62                     continue;
63                 }
64                 auto var = phi->CastToPhi();
65                 Inst *initial = var->GetPhiInput(var->GetPhiInputBb(0));
66                 Inst *update = var->GetPhiInput(var->GetPhiInputBb(1));
67                 if (var->GetPhiInputBb(0) != loop->GetPreHeader()) {
68                     std::swap(initial, update);
69                 }
70 
71                 if (!initial->IsConst()) {
72                     continue;
73                 }
74 
75                 if (derived_.find(update) != derived_.end()) {
76                     auto initVal = static_cast<int64_t>(initial->CastToConstant()->GetIntValue());
77                     base_[var] = {initVal, derived_[update].diff};
78                 }
79             }
80         }
81 
82         COMPILER_LOG(DEBUG, MEMORY_COALESCING) << "Evolution variables:";
83         for (auto entry : base_) {
84             COMPILER_LOG(DEBUG, MEMORY_COALESCING)
85                 << "v" << entry.first->GetId() << " = {" << entry.second.initial << ", " << entry.second.step << "}";
86         }
87         COMPILER_LOG(DEBUG, MEMORY_COALESCING) << "Loop variables:";
88         for (auto entry : derived_) {
89             COMPILER_LOG(DEBUG, MEMORY_COALESCING)
90                 << "v" << entry.first->GetId() << " = v" << entry.second.base->GetId() << " + " << entry.second.diff;
91         }
92     }
93 
94     DEFAULT_MOVE_SEMANTIC(VariableAnalysis);
95     DEFAULT_COPY_SEMANTIC(VariableAnalysis);
96     ~VariableAnalysis() = default;
97 
98     bool IsAnalyzed(Inst *inst) const
99     {
100         return derived_.find(inst) != derived_.end();
101     }
102 
103     Inst *GetBase(Inst *inst) const
104     {
105         return derived_.at(inst).base;
106     }
107 
108     int64_t GetInitial(Inst *inst) const
109     {
110         auto var = derived_.at(inst);
111         return base_.at(var.base).initial + var.diff;
112     }
113 
114     int64_t GetDiff(Inst *inst) const
115     {
116         return derived_.at(inst).diff;
117     }
118 
119     int64_t GetStep(Inst *inst) const
120     {
121         return base_.at(derived_.at(inst).base).step;
122     }
123 
124     bool IsEvoluted(Inst *inst) const
125     {
126         return derived_.at(inst).base->IsPhi();
127     }
128 
129     bool HasKnownEvolution(Inst *inst) const
130     {
131         Inst *base = derived_.at(inst).base;
132         return base->IsPhi() && base_.find(base) != base_.end();
133     }
134 
135 private:
136     /// Add derived variables if we can deduce the change from INST
137     void AddUsers(Inst *inst)
138     {
139         int64_t acc = 0;  // 64-bit accumulator: the diff values added below are int64_t
140         auto base = inst;
141         if (derived_.find(inst) != derived_.end()) {
142             acc += derived_[inst].diff;
143             base = derived_[inst].base;
144         } else {
145             derived_[inst] = {inst, 0};
146         }
147         for (auto &user : inst->GetUsers()) {
148             auto uinst = user.GetInst();
149             ASSERT(uinst->IsPhi() || derived_.find(uinst) == derived_.end());
150             switch (uinst->GetOpcode()) {
151                 case Opcode::AddI: {
152                     auto val = static_cast<int64_t>(uinst->CastToAddI()->GetImm());
153                     derived_[uinst] = {base, acc + val};
154                     break;
155                 }
156                 case Opcode::SubI: {
157                     auto val = static_cast<int64_t>(uinst->CastToSubI()->GetImm());
158                     derived_[uinst] = {base, acc - val};
159                     break;
160                 }
161                 default:
162                     break;
163             }
164         }
165     }
166 
167 private:
168     ArenaUnorderedMap<Inst *, struct BaseVariable> base_;
169     ArenaUnorderedMap<Inst *, struct DerivedVariable> derived_;
170 };
171 
172 /**
173  * The visitor collects pairs of memory instructions that can be coalesced.
174  * It operates within the scope of a basic block. While observing instructions, we
175  * collect memory instructions into one common queue of candidates that can be merged.
176  *
177  * A candidate is marked as invalid under the following conditions:
178  * - it has been paired already
179  * - it is a store and a SaveState has been met
180  * - a BARRIER or CAN_THROW instruction has been met
181  *
182  * To pair valid array accesses:
183  * - check that the accesses happen at consecutive indices of the same array
184  * - find the lowest position to which the dominating access can be sunk
185  * - find the highest position to which the dominated access can be hoisted
186  * - if the highest position dominates the lowest position, coalescing is possible
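 *
 * An illustrative pairing (hypothetical IR, value numbers made up for this comment):
 *   v3 = LoadArray v1, v2        // earlier candidate
 *   v4 = AddI v2, 0x1
 *   v5 = LoadArray v1, v4        // access at the consecutive index -> paired with v3
 * Both accesses can later be replaced by a single LoadArrayPair inserted at a position where
 * the inputs of both are ready and that precedes all of their users.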
187  */
188 class PairCreatorVisitor : public GraphVisitor {
189 public:
190     explicit PairCreatorVisitor(Graph *graph, const AliasAnalysis &aliases, const VariableAnalysis &vars, Marker mrk,
191                                 bool aligned)
192         : alignedOnly_(aligned),
193           mrkInvalid_(mrk),
194           graph_(graph),
195           aliases_(aliases),
196           vars_(vars),
197           pairs_(graph->GetLocalAllocator()->Adapter()),
198           candidates_(graph->GetLocalAllocator()->Adapter())
199     {
200     }
201 
202     const ArenaVector<BasicBlock *> &GetBlocksToVisit() const override
203     {
204         return graph_->GetBlocksRPO();
205     }
206 
207     NO_MOVE_SEMANTIC(PairCreatorVisitor);
208     NO_COPY_SEMANTIC(PairCreatorVisitor);
209     ~PairCreatorVisitor() override = default;
210 
211     static void VisitLoadArray(GraphVisitor *v, Inst *inst)
212     {
213         static_cast<PairCreatorVisitor *>(v)->HandleArrayAccess(inst);
214     }
215 
216     static void VisitStoreArray(GraphVisitor *v, Inst *inst)
217     {
218         static_cast<PairCreatorVisitor *>(v)->HandleArrayAccess(inst);
219     }
220 
221     static void VisitLoadArrayI(GraphVisitor *v, Inst *inst)
222     {
223         static_cast<PairCreatorVisitor *>(v)->HandleArrayAccessI(inst);
224     }
225 
226     static void VisitStoreArrayI(GraphVisitor *v, Inst *inst)
227     {
228         static_cast<PairCreatorVisitor *>(v)->HandleArrayAccessI(inst);
229     }
230 
231     static void VisitLoadObject(GraphVisitor *v, Inst *inst)
232     {
233         static_cast<PairCreatorVisitor *>(v)->HandleObjectAccess(inst->CastToLoadObject());
234     }
235 
236     static void VisitStoreObject(GraphVisitor *v, Inst *inst)
237     {
238         static_cast<PairCreatorVisitor *>(v)->HandleObjectAccess(inst->CastToStoreObject());
239     }
240 
241     static bool IsNotAcceptableForStore(Inst *inst)
242     {
243         if (inst->GetOpcode() == Opcode::SaveState) {
244             for (auto &user : inst->GetUsers()) {
245                 auto *ui = user.GetInst();
246                 if (ui->CanThrow() || ui->CanDeoptimize()) {
247                     return true;
248                 }
249             }
250         }
251         return inst->GetOpcode() == Opcode::SaveStateDeoptimize;
252     }
253 
254     void VisitDefault(Inst *inst) override
255     {
256         if (inst->IsMemory()) {
257             candidates_.push_back(inst);
258             return;
259         }
260         if (inst->IsSaveState()) {
261             // 1. Load & Store can be moved through SafePoint
262             if (inst->GetOpcode() == Opcode::SafePoint) {
263                 return;
264             }
265             // 2. Load & Store can't be moved through SaveStateOsr
266             if (inst->GetOpcode() == Opcode::SaveStateOsr) {
267                 candidates_.clear();
268                 return;
269             }
270             // 3. Load can be moved through SaveState and SaveStateDeoptimize
271             // 4. Store can't be moved through SaveStateDeoptimize or through a SaveState that has users that
272             //    can throw or deoptimize. This is checked in IsNotAcceptableForStore
273             if (IsNotAcceptableForStore(inst)) {
274                 InvalidateStores();
275                 return;
276             }
277         }
278         if (inst->IsBarrier()) {
279             candidates_.clear();
280             return;
281         }
282         if (inst->CanThrow()) {
283             InvalidateStores();
284             return;
285         }
286     }
287 
288     void Reset()
289     {
290         candidates_.clear();
291     }
292 
293     ArenaUnorderedMap<Inst *, MemoryCoalescing::CoalescedPair> &GetPairs()
294     {
295         return pairs_;
296     }
297 
298 #include "optimizer/ir/visitor.inc"
299 private:
300     void InvalidateStores()
301     {
302         for (auto cand : candidates_) {
303             if (cand->IsStore()) {
304                 cand->SetMarker(mrkInvalid_);
305             }
306         }
307     }
308 
309     static bool IsPairInst(Inst *inst)
310     {
311         switch (inst->GetOpcode()) {
312             case Opcode::LoadArrayPair:
313             case Opcode::LoadArrayPairI:
314             case Opcode::StoreArrayPair:
315             case Opcode::StoreArrayPairI:
316             case Opcode::LoadObjectPair:
317             case Opcode::StoreObjectPair:
318                 return true;
319             default:
320                 return false;
321         }
322     }
323 
324     /**
325      * Return the highest instruction that INST can be inserted after (within the scope of the basic block).
326      * Aliased memory accesses and volatile operations are taken into account. CHECK_CFG additionally enables
327      * the check of INST inputs.
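     *
     * For illustration (hypothetical candidates within one block, oldest first):
     *   bound: v10 = StoreArray v1, v2, v3
     *          v20 = StoreArray v4, v5, v6   // MAY_ALIAS with INST
     *   inst:  v30 = LoadArray  v1, v7
     * The backward scan stops at v20, so INST may be inserted no higher than right after v20.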
328      */
329     Inst *FindUpperInsertAfter(Inst *inst, Inst *bound, bool checkCfg)
330     {
331         ASSERT(bound != nullptr);
332         auto upperAfter = bound;
333         // We do not move higher than bound
334         auto lowerInput = upperAfter;
335         if (checkCfg) {
336             // Update upper bound according to def-use chains
337             for (auto &inputItem : inst->GetInputs()) {
338                 auto input = inputItem.GetInst();
339                 if (input->GetBasicBlock() == inst->GetBasicBlock() && lowerInput->IsPrecedingInSameBlock(input)) {
340                     ASSERT(input->IsPrecedingInSameBlock(inst));
341                     lowerInput = input;
342                 }
343             }
344             upperAfter = lowerInput;
345         }
346 
347         auto boundIt = std::find(candidates_.rbegin(), candidates_.rend(), bound);
348         ASSERT(boundIt != candidates_.rend());
349         for (auto it = candidates_.rbegin(); it != boundIt; it++) {
350             auto cand = *it;
351             if (checkCfg && cand->IsPrecedingInSameBlock(lowerInput)) {
352                 return lowerInput;
353             }
354             // Can't hoist load over aliased store and store over aliased memory instructions
355             if (inst->IsStore() || cand->IsStore()) {
356                 auto checkInst = cand;
357                 if (IsPairInst(cand)) {
358                     // We have already checked the second inst. We now want to check the first one
359                     // for alias.
360                     auto pair = pairs_[cand];
361                     checkInst = pair.first->IsPrecedingInSameBlock(pair.second) ? pair.first : pair.second;
362                 }
363                 if (aliases_.CheckInstAlias(inst, checkInst) != NO_ALIAS) {
364                     return cand;
365                 }
366             }
367             // Can't hoist over volatile load
368             if (cand->IsLoad() && IsVolatileMemInst(cand)) {
369                 return cand;
370             }
371         }
372         return upperAfter;
373     }
374 
375     /**
376      * Return the lowest instruction that INST can be inserted after (within the scope of the basic block).
377      * Aliased memory accesses and volatile operations are taken into account. CHECK_CFG additionally enables
378      * the check of INST users.
379      */
380     Inst *FindLowerInsertAfter(Inst *inst, Inst *bound, bool checkCfg = true)
381     {
382         ASSERT(bound != nullptr);
383         auto lowerAfter = bound->GetPrev();
384         // We do not move lower than bound
385         auto upperUser = lowerAfter;
386         ASSERT(upperUser != nullptr);
387         if (checkCfg) {
388             // Update lower bound according to def-use chains
389             for (auto &userItem : inst->GetUsers()) {
390                 auto user = userItem.GetInst();
391                 if (!user->IsPhi() && user->GetBasicBlock() == inst->GetBasicBlock() &&
392                     user->IsPrecedingInSameBlock(upperUser)) {
393                     ASSERT(inst->IsPrecedingInSameBlock(user));
394                     upperUser = user->GetPrev();
395                     ASSERT(upperUser != nullptr);
396                 }
397             }
398             lowerAfter = upperUser;
399         }
400 
401         auto instIt = std::find(candidates_.begin(), candidates_.end(), inst);
402         ASSERT(instIt != candidates_.end());
403         for (auto it = instIt + 1; it != candidates_.end(); it++) {
404             auto cand = *it;
405             if (checkCfg && upperUser->IsPrecedingInSameBlock(cand)) {
406                 return upperUser;
407             }
408             // Can't lower load over aliased store and store over aliased memory instructions
409             if (inst->IsStore() || cand->IsStore()) {
410                 auto checkInst = cand;
411                 if (IsPairInst(cand)) {
412                     // We have already checked the first inst. We now want to check the second one
413                     // for alias.
414                     auto pair = pairs_[cand];
415                     checkInst = pair.first->IsPrecedingInSameBlock(pair.second) ? pair.second : pair.first;
416                 }
417                 if (aliases_.CheckInstAlias(inst, checkInst) != NO_ALIAS) {
418                     ASSERT(cand->GetPrev() != nullptr);
419                     return cand->GetPrev();
420                 }
421             }
422             // Can't lower over volatile store
423             if (cand->IsStore() && IsVolatileMemInst(cand)) {
424                 ASSERT(cand->GetPrev() != nullptr);
425                 return cand->GetPrev();
426             }
427         }
428         return lowerAfter;
429     }
430 
431     /// Add a pair if the difference between indices equals one. The first element of the pair has the lower index.
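    /// E.g. (hypothetical indices): instIdx = 4 and candIdx = 5 give first = inst, second = cand;
    /// indices 4 and 6 are not adjacent, so no pair is created and false is returned.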
432     bool TryAddCoalescedPair(Inst *inst, int64_t instIdx, Inst *cand, int64_t candIdx)
433     {
434         Inst *first = nullptr;
435         Inst *second = nullptr;
436         Inst *insertAfter = nullptr;
437         if (instIdx == candIdx - 1) {
438             first = inst;
439             second = cand;
440         } else if (candIdx == instIdx - 1) {
441             first = cand;
442             second = inst;
443         } else {
444             return false;
445         }
446 
447         ASSERT(inst->IsMemory() && cand->IsMemory());
448         ASSERT(inst->GetOpcode() == cand->GetOpcode());
449         ASSERT(inst != cand && cand->IsPrecedingInSameBlock(inst));
450         Inst *candLowerAfter = nullptr;
451         Inst *instUpperAfter = nullptr;
452         if (first->IsLoad()) {
453             // Consider dominance of load users
454             bool checkCfg = true;
455             candLowerAfter = FindLowerInsertAfter(cand, inst, checkCfg);
456             // The index is not needed if v0[v1] precedes v0[v1 + 1] because v1 + 1 is not used in the paired load.
457             checkCfg = second->IsPrecedingInSameBlock(first);
458             instUpperAfter = FindUpperInsertAfter(inst, cand, checkCfg);
459         } else if (first->IsStore()) {
460             // Store instructions do not have users. Don't check them
461             bool checkCfg = false;
462             candLowerAfter = FindLowerInsertAfter(cand, inst, checkCfg);
463             // Should check that stored value is ready
464             checkCfg = true;
465             instUpperAfter = FindUpperInsertAfter(inst, cand, checkCfg);
466         } else {
467             UNREACHABLE();
468         }
469 
470         // No intersection in reordering ranges
471         if (!instUpperAfter->IsPrecedingInSameBlock(candLowerAfter)) {
472             return false;
473         }
474         if (cand->IsPrecedingInSameBlock(instUpperAfter)) {
475             insertAfter = instUpperAfter;
476         } else {
477             insertAfter = cand;
478         }
479 
480         first->SetMarker(mrkInvalid_);
481         second->SetMarker(mrkInvalid_);
482         InsertPair(first, second, insertAfter);
483         return true;
484     }
485 
486     void HandleArrayAccessI(Inst *inst)
487     {
488         Inst *obj = inst->GetDataFlowInput(inst->GetInput(0).GetInst());
489         uint64_t idx = GetInstImm(inst);
490         if (!MemoryCoalescing::AcceptedType(inst->GetType())) {
491             candidates_.push_back(inst);
492             return;
493         }
494         /* The last candidates are more likely to be coalesced */
495         for (auto iter = candidates_.rbegin(); iter != candidates_.rend(); iter++) {
496             auto cand = *iter;
497             /* Skip uninteresting candidates */
498             if (cand->IsMarked(mrkInvalid_) || cand->GetOpcode() != inst->GetOpcode()) {
499                 continue;
500             }
501 
502             Inst *candObj = cand->GetDataFlowInput(cand->GetInput(0).GetInst());
503             /* Array objects must alias each other */
504             if (aliases_.CheckRefAlias(obj, candObj) != MUST_ALIAS) {
505                 continue;
506             }
507             /* The difference between indices should be equal to one */
508             uint64_t candIdx = GetInstImm(cand);
509             /* To keep alignment the lowest index should be even */
510             if (alignedOnly_ && ((idx < candIdx && (idx & 1U) != 0) || (candIdx < idx && (candIdx & 1U) != 0))) {
511                 continue;
512             }
513             if (TryAddCoalescedPair(inst, idx, cand, candIdx)) {
514                 break;
515             }
516         }
517 
518         candidates_.push_back(inst);
519     }
520 
521     bool HandleKnownEvolutionArrayAccessVar(Inst *idx, Inst *candIdx, int64_t idxInitial, int64_t candInitial)
522     {
523         /* Accesses inside loop */
524         auto idxStep = vars_.GetStep(idx);
525         auto candStep = vars_.GetStep(candIdx);
526         /* Indices should be incremented by the same value and their
527             increment should be even to preserve alignment */
528         if (idxStep != candStep) {
529             return false;
530         }
531         /* To keep alignment we need an even step and an even lowest initial index */
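        /* E.g. (illustrative): initial indices {2, 3} with step 2 give pairs {2, 3}, {4, 5}, ... -- the
           lower index stays even and alignment holds; initial indices {1, 2} would give pairs that
           start at an odd index. */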
532         constexpr auto IMM_2 = 2;
533         // NOLINTBEGIN(readability-simplify-boolean-expr)
534         if (alignedOnly_ && idxStep % IMM_2 != 0 &&
535             ((idxInitial < candInitial && idxInitial % IMM_2 != 0) ||
536              (candInitial < idxInitial && candInitial % IMM_2 != 0))) {
537             return false;
538         }
539         return true;
540         // NOLINTEND(readability-simplify-boolean-expr)
541     }
542 
543     void HandleArrayAccess(Inst *inst)
544     {
545         Inst *obj = inst->GetDataFlowInput(inst->GetInput(0).GetInst());
546         Inst *idx = inst->GetDataFlowInput(inst->GetInput(1).GetInst());
547         if (!vars_.IsAnalyzed(idx) || !MemoryCoalescing::AcceptedType(inst->GetType())) {
548             candidates_.push_back(inst);
549             return;
550         }
551         /* The last candidates are more likely to be coalesced */
552         for (auto iter = candidates_.rbegin(); iter != candidates_.rend(); iter++) {
553             auto cand = *iter;
554             /* Skip uninteresting candidates */
555             if (cand->IsMarked(mrkInvalid_) || cand->GetOpcode() != inst->GetOpcode()) {
556                 continue;
557             }
558 
559             Inst *candObj = cand->GetDataFlowInput(cand->GetInput(0).GetInst());
560             auto candIdx = cand->GetDataFlowInput(cand->GetInput(1).GetInst());
561             /* We need info about the candidate's index, and the array objects must alias each other */
562             if (!vars_.IsAnalyzed(candIdx) || aliases_.CheckRefAlias(obj, candObj) != MUST_ALIAS) {
563                 continue;
564             }
565             if (vars_.HasKnownEvolution(idx) && vars_.HasKnownEvolution(candIdx)) {
566                 auto idxInitial = vars_.GetInitial(idx);
567                 auto candInitial = vars_.GetInitial(candIdx);
568                 if (!HandleKnownEvolutionArrayAccessVar(idx, candIdx, idxInitial, candInitial)) {
569                     continue;
570                 }
571                 if (TryAddCoalescedPair(inst, idxInitial, cand, candInitial)) {
572                     break;
573                 }
574             } else if (!alignedOnly_ && !vars_.HasKnownEvolution(idx) && !vars_.HasKnownEvolution(candIdx)) {
575                 /* Accesses outside loop */
576                 if (vars_.GetBase(idx) != vars_.GetBase(candIdx)) {
577                     continue;
578                 }
579                 if (TryAddCoalescedPair(inst, vars_.GetDiff(idx), cand, vars_.GetDiff(candIdx))) {
580                     break;
581                 }
582             }
583         }
584 
585         candidates_.push_back(inst);
586     }
587 
588     template <typename T>
589     void CheckForObjectCandidates(T *inst, uint8_t fieldSize, size_t fieldOffset)
590     {
591         Inst *obj = inst->GetDataFlowInput(inst->GetInput(0).GetInst());
592         /* The last candidates are more likely to be coalesced */
593         for (auto iter = candidates_.rbegin(); iter != candidates_.rend(); iter++) {
594             auto cand = *iter;
595             if (cand->GetOpcode() != Opcode::LoadObject && cand->GetOpcode() != Opcode::StoreObject) {
596                 continue;
597             }
598             Inst *candObj = cand->GetDataFlowInput(cand->GetInput(0).GetInst());
599             if (aliases_.CheckRefAlias(obj, candObj) != MUST_ALIAS) {
600                 if (aliases_.CheckInstAlias(inst, cand) == MAY_ALIAS) {
601                     cand->SetMarker(mrkInvalid_);
602                     inst->SetMarker(mrkInvalid_);
603                     break;
604                 }
605                 continue;
606             }
607             if (cand->IsMarked(mrkInvalid_)) {
608                 continue;
609             }
610             if (cand->GetOpcode() != inst->GetOpcode() || cand->GetType() != inst->GetType()) {
611                 continue;
612             }
613             size_t candFieldOffset;
614             if constexpr (std::is_same_v<T, LoadObjectInst>) {
615                 auto candLoadObj = cand->CastToLoadObject();
616                 candFieldOffset = GetObjectOffset(graph_, candLoadObj->GetObjectType(), candLoadObj->GetObjField(),
617                                                   candLoadObj->GetTypeId());
618             } else {
619                 auto candStoreObj = cand->CastToStoreObject();
620                 candFieldOffset = GetObjectOffset(graph_, candStoreObj->GetObjectType(), candStoreObj->GetObjField(),
621                                                   candStoreObj->GetTypeId());
622             }
623             auto candFieldSize = GetTypeByteSize(cand->GetType(), graph_->GetArch());
624             if ((fieldOffset + fieldSize == candFieldOffset && TryAddCoalescedPair(inst, 0, cand, 1)) ||
625                 (candFieldOffset + candFieldSize == fieldOffset && TryAddCoalescedPair(inst, 1, cand, 0))) {
626                 break;
627             }
628         }
629     }
630 
631     template <typename T>
632     void HandleObjectAccess(T *inst)
633     {
634         ObjectType objType = inst->GetObjectType();
635         auto fieldSize = GetTypeByteSize(inst->GetType(), graph_->GetArch());
636         size_t fieldOffset = GetObjectOffset(graph_, objType, inst->GetObjField(), inst->GetTypeId());
637         bool isVolatile = inst->GetVolatile();
638         if (isVolatile) {
639             inst->SetMarker(mrkInvalid_);
640         }
641         if (!MemoryCoalescing::AcceptedType(inst->GetType()) || objType != MEM_OBJECT || isVolatile) {
642             candidates_.push_back(inst);
643             return;
644         }
645         CheckForObjectCandidates(inst, fieldSize, fieldOffset);
646         candidates_.push_back(inst);
647     }
648 
649     void InsertPair(Inst *first, Inst *second, Inst *insertAfter)
650     {
651         COMPILER_LOG(DEBUG, MEMORY_COALESCING)
652             << "Access that may be coalesced: v" << first->GetId() << " v" << second->GetId();
653 
654         ASSERT(first->GetType() == second->GetType());
655         Inst *paired = nullptr;
656         switch (first->GetOpcode()) {
657             case Opcode::LoadArray:
658                 paired = ReplaceLoadArray(first, second, insertAfter);
659                 break;
660             case Opcode::LoadArrayI:
661                 paired = ReplaceLoadArrayI(first, second, insertAfter);
662                 break;
663             case Opcode::StoreArray:
664                 paired = ReplaceStoreArray(first, second, insertAfter);
665                 break;
666             case Opcode::StoreArrayI:
667                 paired = ReplaceStoreArrayI(first, second, insertAfter);
668                 break;
669             case Opcode::LoadObject:
670                 paired = ReplaceLoadObject(first, second, insertAfter);
671                 break;
672             case Opcode::StoreObject:
673                 paired = ReplaceStoreObject(first, second, insertAfter);
674                 break;
675             default:
676                 UNREACHABLE();
677         }
678 
679         ASSERT(paired != nullptr);
680         COMPILER_LOG(DEBUG, MEMORY_COALESCING) << "Coalescing of {v" << first->GetId() << " v" << second->GetId()
681                                                << "} by " << paired->GetId() << " is successful";
682         graph_->GetEventWriter().EventMemoryCoalescing(first->GetId(), first->GetPc(), second->GetId(), second->GetPc(),
683                                                        paired->GetId(), paired->IsStore() ? "Store" : "Load");
684 
685         pairs_[paired] = {first, second};
686         paired->SetMarker(mrkInvalid_);
687         candidates_.insert(std::find_if(candidates_.rbegin(), candidates_.rend(),
688                                         [paired](auto x) { return x->IsPrecedingInSameBlock(paired); })
689                                .base(),
690                            paired);
691     }
692 
693     Inst *ReplaceLoadArray(Inst *first, Inst *second, Inst *insertAfter)
694     {
695         ASSERT(first->GetOpcode() == Opcode::LoadArray);
696         ASSERT(second->GetOpcode() == Opcode::LoadArray);
697 
698         auto pload = graph_->CreateInstLoadArrayPair(first->GetType(), INVALID_PC, first->GetInput(0).GetInst(),
699                                                      first->GetInput(1).GetInst());
700         pload->CastToLoadArrayPair()->SetNeedBarrier(first->CastToLoadArray()->GetNeedBarrier() ||
701                                                      second->CastToLoadArray()->GetNeedBarrier());
702         insertAfter->InsertAfter(pload);
703         if (first->CanThrow() || second->CanThrow()) {
704             pload->SetFlag(compiler::inst_flags::CAN_THROW);
705         }
706         MemoryCoalescing::RemoveAddI(pload);
707         return pload;
708     }
709 
710     Inst *ReplaceLoadObject(Inst *first, Inst *second, Inst *insertAfter)
711     {
712         ASSERT(first->GetOpcode() == Opcode::LoadObject);
713         ASSERT(second->GetOpcode() == Opcode::LoadObject);
714         ASSERT(!first->CastToLoadObject()->GetVolatile());
715         ASSERT(!second->CastToLoadObject()->GetVolatile());
716 
717         auto pload = graph_->CreateInstLoadObjectPair(first->GetType(), INVALID_PC);
718         pload->SetInput(InputOrd::INP0, first->GetInput(InputOrd::INP0).GetInst());
719         pload->SetType(first->GetType());
720         pload->SetTypeId0(first->CastToLoadObject()->GetTypeId());
721         pload->SetTypeId1(second->CastToLoadObject()->GetTypeId());
722         pload->SetObjField0(first->CastToLoadObject()->GetObjField());
723         pload->SetObjField1(second->CastToLoadObject()->GetObjField());
724 
725         pload->CastToLoadObjectPair()->SetNeedBarrier(first->CastToLoadObject()->GetNeedBarrier() ||
726                                                       second->CastToLoadObject()->GetNeedBarrier());
727         if (first->CanThrow() || second->CanThrow()) {
728             pload->SetFlag(compiler::inst_flags::CAN_THROW);
729         }
730         insertAfter->InsertAfter(pload);
731 
732         return pload;
733     }
734 
735     Inst *ReplaceLoadArrayI(Inst *first, Inst *second, Inst *insertAfter)
736     {
737         ASSERT(first->GetOpcode() == Opcode::LoadArrayI);
738         ASSERT(second->GetOpcode() == Opcode::LoadArrayI);
739 
740         auto pload = graph_->CreateInstLoadArrayPairI(first->GetType(), INVALID_PC, first->GetInput(0).GetInst(),
741                                                       first->CastToLoadArrayI()->GetImm());
742         pload->CastToLoadArrayPairI()->SetNeedBarrier(first->CastToLoadArrayI()->GetNeedBarrier() ||
743                                                       second->CastToLoadArrayI()->GetNeedBarrier());
744         insertAfter->InsertAfter(pload);
745         if (first->CanThrow() || second->CanThrow()) {
746             pload->SetFlag(compiler::inst_flags::CAN_THROW);
747         }
748 
749         return pload;
750     }
751 
752     Inst *ReplaceStoreArray(Inst *first, Inst *second, Inst *insertAfter)
753     {
754         ASSERT(first->GetOpcode() == Opcode::StoreArray);
755         ASSERT(second->GetOpcode() == Opcode::StoreArray);
756 
757         auto pstore = graph_->CreateInstStoreArrayPair(
758             first->GetType(), INVALID_PC,
759             std::array<Inst *, 4U> {first->GetInput(0).GetInst(), first->CastToStoreArray()->GetIndex(),
760                                     first->CastToStoreArray()->GetStoredValue(),
761                                     second->CastToStoreArray()->GetStoredValue()});
762         pstore->CastToStoreArrayPair()->SetNeedBarrier(first->CastToStoreArray()->GetNeedBarrier() ||
763                                                        second->CastToStoreArray()->GetNeedBarrier());
764         insertAfter->InsertAfter(pstore);
765         if (first->CanThrow() || second->CanThrow()) {
766             pstore->SetFlag(compiler::inst_flags::CAN_THROW);
767         }
768         MemoryCoalescing::RemoveAddI(pstore);
769         return pstore;
770     }
771 
772     Inst *ReplaceStoreObject(Inst *first, Inst *second, Inst *insertAfter)
773     {
774         ASSERT(first->GetOpcode() == Opcode::StoreObject);
775         ASSERT(second->GetOpcode() == Opcode::StoreObject);
776         ASSERT(!first->CastToStoreObject()->GetVolatile());
777         ASSERT(!second->CastToStoreObject()->GetVolatile());
778 
779         auto pstore = graph_->CreateInstStoreObjectPair();
780         pstore->SetType(first->GetType());
781         pstore->SetTypeId0(first->CastToStoreObject()->GetTypeId());
782         pstore->SetTypeId1(second->CastToStoreObject()->GetTypeId());
783         pstore->SetInput(InputOrd::INP0, first->GetInput(InputOrd::INP0).GetInst());
784         pstore->SetInput(InputOrd::INP1, first->GetInput(InputOrd::INP1).GetInst());
785         pstore->SetInput(InputOrd::INP2, second->GetInput(InputOrd::INP1).GetInst());
786         pstore->CastToStoreObjectPair()->SetObjField0(first->CastToStoreObject()->GetObjField());
787         pstore->CastToStoreObjectPair()->SetObjField1(second->CastToStoreObject()->GetObjField());
788 
789         pstore->CastToStoreObjectPair()->SetNeedBarrier(first->CastToStoreObject()->GetNeedBarrier() ||
790                                                         second->CastToStoreObject()->GetNeedBarrier());
791         if (first->CanThrow() || second->CanThrow()) {
792             pstore->SetFlag(compiler::inst_flags::CAN_THROW);
793         }
794         insertAfter->InsertAfter(pstore);
795 
796         return pstore;
797     }
798 
799     Inst *ReplaceStoreArrayI(Inst *first, Inst *second, Inst *insertAfter)
800     {
801         ASSERT(first->GetOpcode() == Opcode::StoreArrayI);
802         ASSERT(second->GetOpcode() == Opcode::StoreArrayI);
803 
804         auto pstore = graph_->CreateInstStoreArrayPairI(
805             first->GetType(), INVALID_PC, first->GetInput(0).GetInst(), first->CastToStoreArrayI()->GetStoredValue(),
806             second->CastToStoreArrayI()->GetStoredValue(), first->CastToStoreArrayI()->GetImm());
807         pstore->CastToStoreArrayPairI()->SetNeedBarrier(first->CastToStoreArrayI()->GetNeedBarrier() ||
808                                                         second->CastToStoreArrayI()->GetNeedBarrier());
809         insertAfter->InsertAfter(pstore);
810         if (first->CanThrow() || second->CanThrow()) {
811             pstore->SetFlag(compiler::inst_flags::CAN_THROW);
812         }
813 
814         return pstore;
815     }
816 
817     uint64_t GetInstImm(Inst *inst)
818     {
819         switch (inst->GetOpcode()) {
820             case Opcode::LoadArrayI:
821                 return inst->CastToLoadArrayI()->GetImm();
822             case Opcode::StoreArrayI:
823                 return inst->CastToStoreArrayI()->GetImm();
824             default:
825                 UNREACHABLE();
826         }
827     }
828 
829 private:
830     bool alignedOnly_;
831     Marker mrkInvalid_;
832     Graph *graph_ {nullptr};
833     const AliasAnalysis &aliases_;
834     const VariableAnalysis &vars_;
835     ArenaUnorderedMap<Inst *, MemoryCoalescing::CoalescedPair> pairs_;
836     InstVector candidates_;
837 };
838 
839 static void ReplaceLoadByPair(Inst *load, Inst *pairedLoad, int32_t dstIdx)
840 {
841     auto graph = pairedLoad->GetBasicBlock()->GetGraph();
842     auto pairGetter = graph->CreateInstLoadPairPart(load->GetType(), INVALID_PC, pairedLoad, dstIdx);
843     load->ReplaceUsers(pairGetter);
844     pairedLoad->InsertAfter(pairGetter);
845 }
846 
847 void MemoryCoalescing::RemoveAddI(Inst *inst)
848 {
849     auto opcode = inst->GetOpcode();
850     ASSERT(opcode == Opcode::LoadArrayPair || opcode == Opcode::StoreArrayPair);
851     auto input1 = inst->GetInput(1).GetInst();
852     if (input1->GetOpcode() == Opcode::AddI) {
853         uint64_t imm = input1->CastToAddI()->GetImm();
854         if (opcode == Opcode::LoadArrayPair) {
855             inst->CastToLoadArrayPair()->SetImm(imm);
856         } else if (opcode == Opcode::StoreArrayPair) {
857             inst->CastToStoreArrayPair()->SetImm(imm);
858         }
859         inst->SetInput(1, input1->GetInput(0).GetInst());
860     }
861 }
862 
863 /**
864  * This optimization coalesces two loads (stores) that read (write) values from (to) consecutive memory
865  * locations into a single operation.
866  *
867  * 1) If we have two memory instructions that can be coalesced, we try to find a position for the
868  *    coalesced operation. If one is found, the memory operations are coalesced; otherwise they are skipped.
869  * 2) The coalesced instruction on AArch64 requires memory address alignment. For arrays this means
870  *    we can coalesce only accesses that start at an even index.
871  * 3) The implemented coalescing for arrays assumes there are no volatile array element accesses.
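 *
 * A sketch of the transformation (hypothetical IR, value numbers made up for this comment):
 *   v3 = LoadArray v1, v2              v5 = LoadArrayPair v1, v2
 *   v4 = LoadArray v1, v2 + 1    =>    v6 = LoadPairPart v5, 0x0
 *                                      v7 = LoadPairPart v5, 0x1
 * Stores are coalesced into StoreArrayPair / StoreObjectPair in the same way, without the
 * LoadPairPart getters.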
872  */
873 bool MemoryCoalescing::RunImpl()
874 {
875     if (GetGraph()->GetArch() != Arch::AARCH64) {
876         COMPILER_LOG(INFO, MEMORY_COALESCING) << "Skipping Memory Coalescing for unsupported architecture";
877         return false;
878     }
879     COMPILER_LOG(DEBUG, MEMORY_COALESCING) << "Memory Coalescing running";
880     GetGraph()->RunPass<DominatorsTree>();
881     GetGraph()->RunPass<LoopAnalyzer>();
882     GetGraph()->RunPass<AliasAnalysis>();
883 
884     VariableAnalysis variables(GetGraph());
885     auto &aliases = GetGraph()->GetValidAnalysis<AliasAnalysis>();
886     Marker mrk = GetGraph()->NewMarker();
887     PairCreatorVisitor collector(GetGraph(), aliases, variables, mrk, alignedOnly_);
888     for (auto block : GetGraph()->GetBlocksRPO()) {
889         collector.VisitBlock(block);
890         collector.Reset();
891     }
892     GetGraph()->EraseMarker(mrk);
893     for (auto pair : collector.GetPairs()) {
894         auto bb = pair.first->GetBasicBlock();
895         if (pair.first->IsLoad()) {
896             ReplaceLoadByPair(pair.second.second, pair.first, 1);
897             ReplaceLoadByPair(pair.second.first, pair.first, 0);
898         }
899         bb->RemoveInst(pair.second.first);
900         bb->RemoveInst(pair.second.second);
901     }
902 
903     if (!collector.GetPairs().empty()) {
904         SaveStateBridgesBuilder ssb;
905         for (auto bb : GetGraph()->GetBlocksRPO()) {
906             if (!bb->IsEmpty() && !bb->IsStartBlock()) {
907                 ssb.FixSaveStatesInBB(bb);
908             }
909         }
910     }
911     COMPILER_LOG(DEBUG, MEMORY_COALESCING) << "Memory Coalescing completed";
912     return !collector.GetPairs().empty();
913 }
914 }  // namespace ark::compiler
915