1 /* 2 * Copyright (c) 2021-2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 #ifndef COMPILER_OPTIMIZER_OPTIMIZATIONS_LOOP_TRANSFORM_H 17 #define COMPILER_OPTIMIZER_OPTIMIZATIONS_LOOP_TRANSFORM_H 18 19 #include "optimizer/ir/graph.h" 20 #include "optimizer/ir/basicblock.h" 21 #include "optimizer/pass.h" 22 #include "optimizer/analysis/loop_analyzer.h" 23 24 namespace ark::compiler { 25 enum class LoopExitPoint : uint8_t { ALL_LOOP, LOOP_EXIT_HEADER, LOOP_EXIT_BACKEDGE }; 26 27 template <const LoopExitPoint EXIT_POINT> 28 class LoopTransform : public Optimization { 29 protected: LoopTransform(Graph * graph)30 explicit LoopTransform(Graph *graph) : Optimization(graph) {} 31 32 virtual bool TransformLoop(Loop *loop) = 0; 33 RunLoopsVisitor()34 void RunLoopsVisitor() 35 { 36 GetGraph()->template RunPass<LoopAnalyzer>(); 37 ASSERT(GetGraph()->GetRootLoop() != nullptr); 38 if (GetGraph()->GetRootLoop()->GetInnerLoops().empty()) { 39 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Graph doesn't have loops"; 40 } 41 42 auto markerHolder = MarkerHolder(GetGraph()); 43 auto markerLoopExit = markerHolder.GetMarker(); 44 MarkLoopExits(GetGraph(), markerLoopExit); 45 for (auto loop : GetGraph()->GetRootLoop()->GetInnerLoops()) { 46 LoopVisitLRN(loop, markerLoopExit); 47 } 48 } 49 IsSupportedLoopType(const Loop * loop)50 bool IsSupportedLoopType(const Loop *loop) 51 { 52 if (loop->IsIrreducible()) { 53 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Irreducible loop isn't visited, id = " << loop->GetId(); 54 return false; 55 } 56 if (loop->IsOsrLoop()) { 57 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "OSR entry isn't visited, loop id = " << loop->GetId(); 58 return false; 59 } 60 if (loop->IsTryCatchLoop()) { 61 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) << "Try-catch loop isn't visited, loop id = " << loop->GetId(); 62 return false; 63 } 64 if (loop->GetBackEdges().size() > 1) { 65 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) 66 << "Loop with more than 1 back-edge isn't visited, id = " << loop->GetId(); 67 return false; 68 } 69 return true; 70 } 71 LoopVisitLRN(Loop * loop,Marker marker)72 bool LoopVisitLRN(Loop *loop, Marker marker) 73 { 74 ASSERT(loop != nullptr); 75 const auto &innerLoops = loop->GetInnerLoops(); 76 bool result = true; 77 for (auto innerLoop : innerLoops) { 78 result &= LoopVisitLRN(innerLoop, marker); 79 } 80 81 if (result && IsSupportedLoopType(loop)) { 82 return VisitLoop(loop, marker); 83 } 84 return false; 85 } 86 87 #ifndef __clang_analyzer__ VisitBlockInLoop(BasicBlock * block,Loop * loop,Marker marker)88 bool VisitBlockInLoop(BasicBlock *block, Loop *loop, Marker marker) 89 { 90 if constexpr (EXIT_POINT == LoopExitPoint::LOOP_EXIT_HEADER) { 91 // NOTE (a.popov) Support infinite loops unrolling 92 if (!block->IsMarked(marker) && block->IsLoopHeader()) { 93 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) 94 << "Loop without exit-point from loop-header isn't visited, id = " << loop->GetId(); 95 return false; 96 } 97 if (block->IsMarked(marker) && !block->IsLoopHeader()) { 98 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) 99 << "Loop with loop-exit not from loop-header isn't visited, id = " << loop->GetId(); 100 return false; 101 } 102 } else if constexpr (EXIT_POINT == LoopExitPoint::LOOP_EXIT_BACKEDGE) { 103 ASSERT(loop->GetBackEdges().size() == 1); 104 auto back_edge = loop->GetBackEdges()[0]; 105 if (!block->IsMarked(marker) && block == back_edge) { 106 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) 107 << "Loop without exit-point from back-edge isn't visited, id = " << loop->GetId(); 108 return false; 109 } 110 if (block->IsMarked(marker) && block != back_edge) { 111 COMPILER_LOG(DEBUG, LOOP_TRANSFORM) 112 << "Loop with loop-exit not from last block isn't visited, id = " << loop->GetId(); 113 return false; 114 } 115 } 116 return true; 117 } 118 #endif 119 VisitLoop(Loop * loop,Marker marker)120 bool VisitLoop(Loop *loop, [[maybe_unused]] Marker marker) 121 { 122 #ifndef __clang_analyzer__ 123 if constexpr (EXIT_POINT != LoopExitPoint::ALL_LOOP) { 124 for (auto block : loop->GetBlocks()) { 125 if (!VisitBlockInLoop(block, loop, marker)) { 126 return false; 127 } 128 } 129 } 130 #endif 131 return TransformLoop(loop); 132 } 133 GetLoopOuterBlock(BasicBlock * exitBlock)134 BasicBlock *GetLoopOuterBlock(BasicBlock *exitBlock) 135 { 136 ASSERT(exitBlock->GetSuccsBlocks().size() == 2U); 137 auto loop = exitBlock->GetLoop(); 138 auto outer = exitBlock->GetTrueSuccessor(); 139 if (outer->GetLoop() == loop) { 140 outer = exitBlock->GetFalseSuccessor(); 141 } 142 ASSERT(outer->GetLoop() != loop); 143 return outer; 144 } 145 }; 146 } // namespace ark::compiler 147 148 #endif // COMPILER_OPTIMIZER_OPTIMIZATIONS_LOOP_TRANSFORM_H 149