• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_PI_JIT_GRAPH_BUILD_FUNC_GRAPH_BUILDER_H_
18 #define MINDSPORE_PI_JIT_GRAPH_BUILD_FUNC_GRAPH_BUILDER_H_
19 
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "ir/value.h"
#include "ops/sequence_ops.h"
#include "pipeline/jit/ps/parse/parse_base.h"
#include "pipeline/jit/ps/parse/parse.h"
27 
28 namespace mindspore {
29 class FuncGraphBuilder;
30 using FuncGraphBuilderPtr = std::shared_ptr<FuncGraphBuilder>;
31 
32 class FuncGraphBuilder {
33  public:
graph_(std::make_shared<FuncGraph> ())34   explicit FuncGraphBuilder(bool is_top = false) : graph_(std::make_shared<FuncGraph>()) {
35     if (is_top) {
36       parse::Parser::UpdateTopFuncGraph(graph_);
37     }
38   }
~FuncGraphBuilder()39   virtual ~FuncGraphBuilder() { py_obj_to_node_.clear(); }
40 
41   /// \brief Add an input parameter to the graph.
42   ///
43   /// \param[in] obj The input python object.
44   ///
45   /// \return If the input is a tensor, return a fake tensor python object, else return the origin python object.
46   py::object AddSubGraphInput(const py::object &obj);
47 
48   /// \brief Add an input parameter to the top graph.
49   ///
50   /// \param[in] packed_inputs The input python object for top graph.
51   ///
52   /// \return True if add top graph success, otherwise false.
53   bool AddTopGraphInputs(std::vector<py::object> packed_inputs);
54 
55   /// \brief Add a cnode to the graph.
56   ///
57   /// \param[in] callable_obj The callable python object.
58   /// \param[in] inputs_obj The input python objects.
59   ///
60   /// \return The python object of the infer result.
61   py::object AddNode(const py::object &callable_obj, const std::vector<py::object> &inputs_obj);
62 
63   /// \brief Add a cnode to the graph.
64   ///
65   /// \param[in] callable_value The callable value.
66   /// \param[in] inputs_obj The input python objects.
67   ///
68   /// \return The python object of the infer result.
69   py::object AddNode(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj);
70 
71   /// \brief Add a python object to graph.
72   ///
73   /// \param[in] object The python object add to graph.
74   ///
75   /// \return Indicate whether the python object add to graph successfully.
76   bool AddAttrPythonObject(const py::object &object);
77 
78   /// \brief Add a binary operation cnode to the graph.
79   ///
80   /// \param[in] opcode The binary operation code.
81   /// \param[in] inputs_obj The input python objects.
82   ///
83   /// \return The python object of the infer result.
84   py::object AddMultiNode(const std::string &opcode, const std::vector<py::object> &inputs_obj);
85 
86   /// \brief Add an output node to the graph.
87   ///
88   /// \param[in] output_obj The output python object.
89   /// \param[in] is_top_graph Indicate whether the graph to add output is top graph.
90   ///
91   /// \return Return true if the output object can be used as the output of the graph.
92   bool AddOutput(const py::object &output_obj, bool is_top_graph = true);
93 
94   /// \brief Remove an output node of the graph.
95   ///
96   /// \param[in] output_obj The output python object.
97   void RemoveOutput(const py::object &output_obj);
98 
99   /// \brief Clear all output node of the graph.
ClearOutputNodes()100   void ClearOutputNodes() { output_nodes_.clear(); }
101 
102   /// \brief Get the callable python primitive or function.
103   ///
104   /// \param[in] obj The method of a python object.
105   ///
106   /// \return Return the corresponding primitive of function of the func.
107   static py::object ConvertMethod(const py::object &obj);
108 
109   /// \brief Get the callable python primitive, meta_func_graph or function.
110   ///
111   /// \param[in] obj The python object of a function.
112   ///
113   /// \return Return the corresponding primitive of function of the func.
114   static py::object ConvertFunction(const py::object &obj);
115 
116   /// \brief Check if the python object can be converted to a cnode directly.
117   ///
118   /// \param[in] obj A python object.
119   ///
120   /// \return Return true if the python object can be converted to a cnode directly.
121   static bool CheckCallable(const py::object &obj);
122 
123   /// \brief Check if the python object is a function which can be constantly folded.
124   ///
125   /// \param[in] obj A python object.
126   ///
127   /// \return Return true if the python object is a function which can be constantly folded.
128   static bool CanConstantFoldFunc(const py::object &obj);
129 
130   /// \brief Check if the python object is valid as the callable object in graph.
131   ///
132   /// \param[in] obj A python object.
133   ///
134   /// \return Return true if the python object is valid as the callable object in graph.
135   static bool ValidateCallableObject(const py::object &obj);
136 
137   /// \brief Set the final outputs and get the graph.
138   ///
139   /// \return The graph constructed.
140   FuncGraphPtr graph();
141 
142   /// \brief Clear abstract for nodes.
143   void ClearNodeAbstract();
144 
145   /// \brief Set the name of the func_graph.
146   ///
147   /// \param[in] name The func_graph name to set.
148   void SetGraphName(const std::string &name);
149 
150   static ValuePtr ConvertPyObjToValue(const py::object &obj);
151 
152   static AbstractBasePtr EvalValue(const ValuePtr &value, const AbstractBasePtrList &inputs_abs_list);
153 
154   using PyTensorConverter = std::function<py::object(const py::object &)>;
155   static py::object ConvertToPyObj(const AbstractBasePtr &abs);
156   static py::object ConvertToPyObj(const AbstractBasePtr &abs, const PyTensorConverter &tensor_convert_func);
157 
158   void AddPrevBuilder(const FuncGraphBuilderPtr &builder);
159 
prev_builders()160   const std::vector<FuncGraphBuilder *> &prev_builders() const { return prev_builders_; }
161 
162   AnfNodePtr GetNodeByObject(const py::object &obj);
163 
164   AnfNodePtr ReadLocalVariable(const py::object &obj);
165 
166   bool AddLocalVariable(const py::object &obj);
167 
168  private:
169   static bool CheckCallable(const ValuePtr &value, const AbstractBasePtr &abs);
170 
171   static bool CheckGraphOutput(const AbstractBasePtr &abs);
172 
173   AnfNodePtr ConvertObjToNode(const py::object &input_obj);
174 
175   py::object AddFgCallNode(const FuncGraphPtr &fg, const std::vector<py::object> &inputs_obj);
176 
177   bool GetInputNodesAndAbstracts(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj,
178                                  std::vector<AnfNodePtr> *input_node_list,
179                                  std::vector<AbstractBasePtr> *input_abs_list);
180 
181   static AbstractBasePtr DoInferAndCheck(const ValuePtr &callable_value,
182                                          const std::vector<AbstractBasePtr> &input_abs_list);
183 
184   CNodePtr DoPrimitiveInferAndCheck(const PrimitivePtr &primitive, const AnfNodePtrList &input_node_list,
185                                     const AbstractBasePtrList &args_abs_list);
186   CNodePtr AddPrimitiveCNode(const PrimitivePtr &primitive, const AnfNodePtrList &input_node_list,
187                              const AbstractBasePtrList &args_abs_list);
188 
189   static AbstractBasePtr GetAbstractOf(const AnfNodePtr &node);
190 
191   py::object TryToAddNode(const ValuePtr &callable_value, const std::vector<py::object> &inputs_obj);
192 
193   py::object ConvertToPyTensorOrParameter(const py::object &cpp_tensor);
194 
195   static bool CheckInvalidCellListDictMethod(const py::object &obj);
196 
197   bool AddTopGraphArgsInputs(const py::object &object);
198 
199   bool AddTopGraphVargsInputs(const py::object &vargs);
200 
201   bool AddTopGraphKwargsInputs(const py::object &vargs);
202 
203   FuncGraphPtr graph_{nullptr};
204   bool has_set_output_{false};
205   HashMap<PyObject *, AnfNodePtr> py_obj_to_node_;
206   std::vector<AnfNodePtr> output_nodes_;
207 
208   // Store all previous builders for subgraph call and control flow.
209   std::vector<FuncGraphBuilder *> prev_builders_;
210 };
211 }  // namespace mindspore
212 #endif  // MINDSPORE_PI_JIT_GRAPH_BUILD_FUNC_GRAPH_BUILDER_H_
213