1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_PI_JIT_FUNC_INLINER_H_ 18 #define MINDSPORE_PI_JIT_FUNC_INLINER_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include "pipeline/jit/pi/graph_compiler/pi_ir/ir_mutator.h" 24 #include "pipeline/jit/pi/graph_compiler/pi_ir/ir_visitor.h" 25 #include "utils/convert_utils_base.h" 26 27 namespace mindspore { 28 namespace pijit { 29 class FuncInlineDetector : public ir::IRVisitor { 30 public: FuncInlineDetector(const ir::FunctionNodePtr & func)31 explicit FuncInlineDetector(const ir::FunctionNodePtr &func) : func_(func), index_(0), cur_root_node_(nullptr) {} 32 virtual ~FuncInlineDetector() = default; 33 void Run(); 34 35 void Visit_(const ir::FunctionNodePtr &node) override; 36 void Visit_(const ir::CallNodePtr &node) override; 37 size_t GetRootNodeIndex(const ir::CallNodePtr &node) const; 38 const ir::NodePtr &GetRootNode(const ir::CallNodePtr &node) const; 39 40 private: 41 bool CanBeInlined(const ir::NodePtr &node) const; 42 void EvolvingFunction(const ir::FunctionNodePtr &func_node, const ir::NodePtrList args) const; 43 44 const ir::FunctionNodePtr func_; 45 size_t index_; 46 ir::NodePtr cur_root_node_; 47 std::map<ir::CallNodePtr, size_t> node_2_index_; 48 std::map<ir::CallNodePtr, ir::NodePtr> node_2_root_; 49 }; 50 51 using FuncInlineDetectorPtr = std::shared_ptr<FuncInlineDetector>; 52 53 class FuncLocalVarRenamer : public ir::IRVisitor { 54 public: FuncLocalVarRenamer(const ir::FunctionNodePtr & func)55 explicit FuncLocalVarRenamer(const ir::FunctionNodePtr &func) : func_(func) {} 56 virtual ~FuncLocalVarRenamer() = default; 57 void Run(); 58 59 void Visit_(const ir::ParameterPtr &node) override; 60 void Visit_(const ir::ValuePtr &node) override; 61 62 private: 63 const ir::FunctionNodePtr func_; 64 }; 65 66 using FuncLocalVarRenamerPtr = std::shared_ptr<FuncLocalVarRenamer>; 67 68 class FuncParameterEliminator : public ir::IRMutator { 69 public: FuncParameterEliminator(const ir::FunctionNodePtr & func,const ir::NodePtrList & args)70 explicit FuncParameterEliminator(const ir::FunctionNodePtr &func, const ir::NodePtrList &args) 71 : func_(func), args_(args) {} 72 virtual ~FuncParameterEliminator() = default; 73 void Run(); 74 75 ir::NodePtr Mutate_(const ir::ParameterPtr &node) override; 76 ir::NodePtr Mutate_(const ir::LoadValueNodePtr &node) override; 77 ir::NodePtr Mutate_(const ir::StoreNodePtr &node) override; 78 79 private: 80 const ir::FunctionNodePtr func_; 81 const ir::NodePtrList args_; 82 std::map<std::string, ir::NodePtr> name_2_node_; 83 }; 84 85 using FuncParameterEliminatorPtr = std::shared_ptr<FuncParameterEliminator>; 86 87 // FuncInliner to convert ir graph to function graph 88 class FuncInliner : public ir::IRMutator { 89 public: FuncInliner(const ir::FunctionNodePtr & func)90 explicit FuncInliner(const ir::FunctionNodePtr &func) 91 : func_(func), detector_(std::make_shared<FuncInlineDetector>(func)), inserted_nodes_cnt_(0) {} 92 virtual ~FuncInliner() = default; 93 void Run(); 94 void InsertSubFunction(); 95 96 // overloadable Mutate function. 97 ir::NodePtr Mutate_(const ir::CallNodePtr &node) override; 98 99 private: 100 const ir::FunctionNodePtr func_; 101 const FuncInlineDetectorPtr detector_; 102 size_t inserted_nodes_cnt_; 103 std::map<size_t, ir::FunctionNodePtr> index_2_function_; 104 }; 105 106 using FuncInlinerPtr = std::shared_ptr<FuncInliner>; 107 } // namespace pijit 108 } // namespace mindspore 109 110 #endif // MINDSPORE_PI_JIT_FUNC_INLINER_H_ 111