1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_LOOP_UNROLLING_H 17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_LOOP_UNROLLING_H 18 19 #include <map> 20 #include <memory> 21 #include <queue> 22 #include <set> 23 #include <string> 24 #include "pipeline/jit/pi/pydef.h" 25 #include "pipeline/jit/pi/graph_capture/cfg.h" 26 #include "pipeline/jit/pi/graph_capture/loop.h" 27 #include "pipeline/jit/pi/utils/allocator.h" 28 #include "pipeline/jit/pi/utils/utils.h" 29 30 namespace mindspore { 31 namespace pijit { 32 class Graph; 33 class LoopUnrolling { 34 public: IsloopUnorlling(LoopUnrollingReason res)35 static bool IsloopUnorlling(LoopUnrollingReason res) { return res == kCanForItemUnroll || res == kCanWhileUnroll; } 36 LoopUnrolling(Graph & graph)37 explicit LoopUnrolling(Graph &graph) : graph_(graph) {} 38 virtual ~LoopUnrolling() = default; 39 40 LoopUnrollingReason ExecuteLoopUnroll(Block *header); IsCFGChanged()41 bool IsCFGChanged() const { return is_cfg_changed_; } 42 std::string DumpLoopUnrolling(); 43 44 private: 45 void Run(); 46 LoopUnrollingReason AnalyzeForItem(); 47 bool AddLoopGurad(ValueNode *value); 48 LoopUnrollingReason CheckLoopUnrollingSideeffect(); 49 std::map<int, Block *> CopyBB(); 50 void CopyAndInsertBB(); 51 void RemoveBackedge(); 52 void AddLoopUnrollingInstr(Block *bb, int count); 53 void FixupInstr(); 54 55 Graph &graph_; 56 LoopInfo *loop_ = nullptr; 57 LoopUnrollingReason res_ = kCanNotUnroll; 58 int unrolling_count_ = 0; 59 int loop_op_ = -1; 60 int loop_arg_ = -1; 61 Instr *iter_instr_ = nullptr; 62 ValueNode *loop_value_ = nullptr; 63 bool is_cfg_changed_ = false; 64 }; 65 } // namespace pijit 66 } // namespace mindspore 67 68 #endif // MINDSPORE_PI_JIT_GRAPH_CAPTURE_LOOP_UNROLLING_H 69