1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_PI_JIT_FUNC_GRAPH_BUILDER_H_ 18 #define MINDSPORE_PI_JIT_FUNC_GRAPH_BUILDER_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 #include "ir/func_graph.h" 25 #include "ir/value.h" 26 #include "pipeline/jit/pi/graph_compiler/pi_ir/ir_mutator.h" 27 #include "utils/trace_info.h" 28 29 namespace mindspore::pijit { 30 namespace py = pybind11; 31 32 class MindNode : public ir::Node { 33 public: MindNode(const AnfNodePtr & node)34 explicit MindNode(const AnfNodePtr &node) : node_(node) {} 35 36 // Destructor. 37 ~MindNode() override = default; 38 JIT_DECLARE_PARENT(MindNode, Node); 39 GetAnfNode()40 const AnfNodePtr &GetAnfNode() const { return node_; } 41 42 /** 43 * \brief Get the description of this Mind Node. 44 * \return The description. 45 */ ToString()46 std::string ToString() const override { return node_->DebugString(); } 47 48 private: 49 AnfNodePtr node_; 50 }; 51 52 using MindNodePtr = std::shared_ptr<MindNode>; 53 54 // FuncGraphBuilder to convert ir graph to function graph 55 class FuncGraphBuilder : public ir::IRMutator { 56 public: FuncGraphBuilder(const ir::FunctionNodePtr & func)57 explicit FuncGraphBuilder(const ir::FunctionNodePtr &func) : FuncGraphBuilder(func, {}, NewValueNode(kNone)) {} FuncGraphBuilder(const ir::FunctionNodePtr & func,const AnfNodePtrList & args,const AnfNodePtr & kwargs)58 FuncGraphBuilder(const ir::FunctionNodePtr &func, const AnfNodePtrList &args, const AnfNodePtr &kwargs) 59 : func_(func), 60 args_(args), 61 kwargs_(kwargs), 62 func_graph_(std::make_shared<FuncGraph>()), 63 last_line_no_(func->GetFirstLineNo()) {} 64 virtual ~FuncGraphBuilder() = default; 65 static FuncGraphPtr BuildFuncGraph(const ir::FunctionNodePtr &func, const py::tuple &args, const py::dict &kwargs); 66 static FuncGraphPtr BuildFuncGraph(const ir::FunctionNodePtr &func, const AnfNodePtrList &args, 67 const AnfNodePtr &kwargs); 68 69 // overloadable Mutate function. 70 ir::NodePtr Mutate_(const ir::RefNodePtr &node) override; 71 ir::NodePtr Mutate_(const ir::ParameterPtr &node) override; 72 ir::NodePtr Mutate_(const ir::FunctionNodePtr &node) override; 73 ir::NodePtr Mutate_(const ir::ValuePtr &node) override; 74 ir::NodePtr Mutate_(const ir::IfNodePtr &node) override; 75 ir::NodePtr Mutate_(const ir::BinaryOperationPtr &node) override; 76 ir::NodePtr Mutate_(const ir::NegativeNodePtr &node) override; 77 ir::NodePtr Mutate_(const ir::NotNodePtr &node) override; 78 ir::NodePtr Mutate_(const ir::InvertNodePtr &node) override; 79 ir::NodePtr Mutate_(const ir::ReturnNodePtr &node) override; 80 ir::NodePtr Mutate_(const ir::CastNodePtr &node) override; 81 ir::NodePtr Mutate_(const ir::FormatNodePtr &node) override; 82 ir::NodePtr Mutate_(const ir::AddNodePtr &node) override; 83 ir::NodePtr Mutate_(const ir::SubNodePtr &node) override; 84 ir::NodePtr Mutate_(const ir::MulNodePtr &node) override; 85 ir::NodePtr Mutate_(const ir::DivNodePtr &node) override; 86 ir::NodePtr Mutate_(const ir::BitwiseNodePtr &node) override; 87 ir::NodePtr Mutate_(const ir::IsNodePtr &node) override; 88 ir::NodePtr Mutate_(const ir::ContainsNodePtr &node) override; 89 ir::NodePtr Mutate_(const ir::StoreNodePtr &node) override; 90 ir::NodePtr Mutate_(const ir::CompareNodePtr &node) override; 91 ir::NodePtr Mutate_(const ir::LoadValueNodePtr &node) override; 92 ir::NodePtr Mutate_(const ir::LoadFieldNodePtr &node) override; 93 ir::NodePtr Mutate_(const ir::BuildNodePtr &node) override; 94 ir::NodePtr Mutate_(const ir::CallNodePtr &node) override; 95 ir::NodePtr Mutate_(const ir::UpdateNodePtr &node) override; 96 ir::NodePtr Mutate_(const ir::SubscrNodePtr &node) override; 97 ir::NodePtr Mutate_(const ir::AttrNodePtr &node) override; 98 99 private: 100 void UpdateLocation(const AnfNodePtr &anf_node, const ir::NodePtr &node); 101 AnfNodePtr ConvertListOrTupleToCNode(const py::object &obj); 102 AnfNodePtr GetAnfNode(const ir::NodePtr &node); 103 AnfNodePtr MergeList(const AnfNodePtr &left, const AnfNodePtr &right); 104 std::pair<AnfNodePtrList, AnfNodePtrList> GetKeysAndValueOfDict(const AnfNodePtr &node); 105 AnfNodePtr MergeDict(const AnfNodePtr &left, const AnfNodePtr &right); 106 107 const ir::FunctionNodePtr func_; 108 AnfNodePtrList args_; 109 AnfNodePtr kwargs_; 110 FuncGraphPtr func_graph_; 111 int last_line_no_; 112 bool enable_debug_info_{false}; 113 114 // Store variable's name, variable's node. 115 std::map<std::string, AnfNodePtr> assigned_vars_; 116 }; 117 118 using FuncGraphBuilderPtr = std::shared_ptr<FuncGraphBuilder>; 119 } // namespace mindspore::pijit 120 121 #endif // MINDSPORE_PI_JIT_FUNC_GRAPH_BUILDER_H_ 122