1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_PI_JIT_GRAPH_BUILD_FUNC_GRAPH_BUILDER_H_ 18 #define MINDSPORE_PI_JIT_GRAPH_BUILD_FUNC_GRAPH_BUILDER_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include "ir/value.h" 24 #include "ops/sequence_ops.h" 25 #include "pipeline/jit/ps/parse/parse_base.h" 26 #include "pipeline/jit/ps/parse/parse.h" 27 28 namespace mindspore { 29 class FuncGraphBuilder; 30 using FuncGraphBuilderPtr = std::shared_ptr<FuncGraphBuilder>; 31 32 class FuncGraphBuilder { 33 public: graph_(std::make_shared<FuncGraph> ())34 explicit FuncGraphBuilder(bool is_top = false) : graph_(std::make_shared<FuncGraph>()) { 35 if (is_top) { 36 parse::Parser::UpdateTopFuncGraph(graph_); 37 } 38 } ~FuncGraphBuilder()39 virtual ~FuncGraphBuilder() { py_obj_to_node_.clear(); } 40 41 /// \brief Add an input parameter to the graph. 42 /// 43 /// \param[in] obj The input python object. 44 /// 45 /// \return If the input is a tensor, return a fake tensor python object, else return the origin python object. 46 py::object AddSubGraphInput(const py::object &obj); 47 48 /// \brief Add an input parameter to the top graph. 49 /// 50 /// \param[in] packed_inputs The input python object for top graph. 51 /// 52 /// \return True if add top graph success, otherwise false. 53 bool AddTopGraphInputs(std::vector<py::object> packed_inputs); 54 55 /// \brief Add a cnode to the graph. 56 /// 57 /// \param[in] callable_obj The callable python object. 58 /// \param[in] inputs_obj The input python objects. 59 /// 60 /// \return The python object of the infer result. 61 py::object AddNode(const py::object &callable_obj, const std::vector<py::object> &inputs_obj); 62 63 /// \brief Add a cnode to the graph. 64 /// 65 /// \param[in] callable_value The callable value. 66 /// \param[in] inputs_obj The input python objects. 67 /// 68 /// \return The python object of the infer result. 69 py::object AddNode(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj); 70 71 /// \brief Add a python object to graph. 72 /// 73 /// \param[in] object The python object add to graph. 74 /// 75 /// \return Indicate whether the python object add to graph successfully. 76 bool AddAttrPythonObject(const py::object &object); 77 78 /// \brief Add a binary operation cnode to the graph. 79 /// 80 /// \param[in] opcode The binary operation code. 81 /// \param[in] inputs_obj The input python objects. 82 /// 83 /// \return The python object of the infer result. 84 py::object AddMultiNode(const std::string &opcode, const std::vector<py::object> &inputs_obj); 85 86 /// \brief Add an output node to the graph. 87 /// 88 /// \param[in] output_obj The output python object. 89 /// \param[in] is_top_graph Indicate whether the graph to add output is top graph. 90 /// 91 /// \return Return true if the output object can be used as the output of the graph. 92 bool AddOutput(const py::object &output_obj, bool is_top_graph = true); 93 94 /// \brief Remove an output node of the graph. 95 /// 96 /// \param[in] output_obj The output python object. 97 void RemoveOutput(const py::object &output_obj); 98 99 /// \brief Clear all output node of the graph. ClearOutputNodes()100 void ClearOutputNodes() { output_nodes_.clear(); } 101 102 /// \brief Get the callable python primitive or function. 103 /// 104 /// \param[in] obj The method of a python object. 105 /// 106 /// \return Return the corresponding primitive of function of the func. 107 static py::object ConvertMethod(const py::object &obj); 108 109 /// \brief Get the callable python primitive, meta_func_graph or function. 110 /// 111 /// \param[in] obj The python object of a function. 112 /// 113 /// \return Return the corresponding primitive of function of the func. 114 static py::object ConvertFunction(const py::object &obj); 115 116 /// \brief Check if the python object can be converted to a cnode directly. 117 /// 118 /// \param[in] obj A python object. 119 /// 120 /// \return Return true if the python object can be converted to a cnode directly. 121 static bool CheckCallable(const py::object &obj); 122 123 /// \brief Check if the python object is a function which can be constantly folded. 124 /// 125 /// \param[in] obj A python object. 126 /// 127 /// \return Return true if the python object is a function which can be constantly folded. 128 static bool CanConstantFoldFunc(const py::object &obj); 129 130 /// \brief Check if the python object is valid as the callable object in graph. 131 /// 132 /// \param[in] obj A python object. 133 /// 134 /// \return Return true if the python object is valid as the callable object in graph. 135 static bool ValidateCallableObject(const py::object &obj); 136 137 /// \brief Set the final outputs and get the graph. 138 /// 139 /// \return The graph constructed. 140 FuncGraphPtr graph(); 141 142 /// \brief Clear abstract for nodes. 143 void ClearNodeAbstract(); 144 145 /// \brief Set the name of the func_graph. 146 /// 147 /// \param[in] name The func_graph name to set. 148 void SetGraphName(const std::string &name); 149 150 static ValuePtr ConvertPyObjToValue(const py::object &obj); 151 152 static AbstractBasePtr EvalValue(const ValuePtr &value, const AbstractBasePtrList &inputs_abs_list); 153 154 using PyTensorConverter = std::function<py::object(const py::object &)>; 155 static py::object ConvertToPyObj(const AbstractBasePtr &abs); 156 static py::object ConvertToPyObj(const AbstractBasePtr &abs, const PyTensorConverter &tensor_convert_func); 157 158 void AddPrevBuilder(const FuncGraphBuilderPtr &builder); 159 prev_builders()160 const std::vector<FuncGraphBuilder *> &prev_builders() const { return prev_builders_; } 161 162 AnfNodePtr GetNodeByObject(const py::object &obj); 163 164 AnfNodePtr ReadLocalVariable(const py::object &obj); 165 166 bool AddLocalVariable(const py::object &obj); 167 168 private: 169 static bool CheckCallable(const ValuePtr &value, const AbstractBasePtr &abs); 170 171 static bool CheckGraphOutput(const AbstractBasePtr &abs); 172 173 AnfNodePtr ConvertObjToNode(const py::object &input_obj); 174 175 py::object AddFgCallNode(const FuncGraphPtr &fg, const std::vector<py::object> &inputs_obj); 176 177 bool GetInputNodesAndAbstracts(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj, 178 std::vector<AnfNodePtr> *input_node_list, 179 std::vector<AbstractBasePtr> *input_abs_list); 180 181 static AbstractBasePtr DoInferAndCheck(const ValuePtr &callable_value, 182 const std::vector<AbstractBasePtr> &input_abs_list); 183 184 CNodePtr DoPrimitiveInferAndCheck(const PrimitivePtr &primitive, const AnfNodePtrList &input_node_list, 185 const AbstractBasePtrList &args_abs_list); 186 CNodePtr AddPrimitiveCNode(const PrimitivePtr &primitive, const AnfNodePtrList &input_node_list, 187 const AbstractBasePtrList &args_abs_list); 188 189 static AbstractBasePtr GetAbstractOf(const AnfNodePtr &node); 190 191 py::object TryToAddNode(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj); 192 193 py::object ConvertToPyTensorOrParameter(const py::object &cpp_tensor); 194 195 static bool CheckInvalidCellListDictMethod(const py::object &obj); 196 197 bool AddTopGraphArgsInputs(const py::object &object); 198 199 bool AddTopGraphVargsInputs(const py::object &vargs); 200 201 bool AddTopGraphKwargsInputs(const py::object &vargs); 202 203 FuncGraphPtr graph_{nullptr}; 204 bool has_set_output_{false}; 205 HashMap<PyObject *, AnfNodePtr> py_obj_to_node_; 206 std::vector<AnfNodePtr> output_nodes_; 207 208 // Store all previous builders for subgraph call and control flow. 209 std::vector<FuncGraphBuilder *> prev_builders_; 210 }; 211 } // namespace mindspore 212 #endif // MINDSPORE_PI_JIT_GRAPH_BUILD_FUNC_GRAPH_BUILDER_H_ 213