• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_BUILD_H
17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_BUILD_H
18 
19 #include <vector>
20 #include <unordered_map>
21 #include <utility>
22 #include <memory>
23 #include <string>
24 #include "pipeline/jit/pi/graph_capture/graph.h"
25 #include "pipeline/jit/pi/graph_build/func_graph_builder.h"
26 #include "utils/convert_utils_base.h"
27 
28 namespace mindspore {
29 namespace pijit {
// Forward declarations: GraphBuilder::Creator instantiates MindGraphBuilder before its definition,
// and both builders are passed around through shared_ptr aliases.
class GraphBuilder;
using GraphBuilderPtr = std::shared_ptr<GraphBuilder>;

class MindGraphBuilder;
using MindGraphBuilderPtr = std::shared_ptr<MindGraphBuilder>;
34 
/**
 * One entry of the try-block stack that GraphBuilder maintains while tracing exception-handling
 * bytecode (see GraphBuilder::PushStack / PopStack / PeekStack).
 *
 * Every member has a default initializer so a default-constructed TryBlock never holds
 * indeterminate values; the struct remains an aggregate, so brace initialization
 * `TryBlock{type, bci, checkpoint, is_finally}` keeps working.
 */
struct TryBlock {
  int type{0};        /* opcode that opened the block (e.g. SETUP_WITH, SETUP_FINALLY, SETUP_EXCEPT) */
  int bci{0};         /* bytecode index to jump to in order to find the handler */
  int checkpoint{0};  /* the handler to be rolled back */
  // int level;       /* value stack level to pop to */
  bool IsFinallyBlock{false}; /* whether the current block is a finally block (as opposed to an exception block) */
};
42 
// Declared here, defined elsewhere. Judging by its name it reports whether the call described by
// call_node may be handled as creating a class instance.
// NOTE(review): name-based description only — confirm against the definition.
43 bool CheckSupportCreateInstance(CallNode *call_node);
44 class GraphBuilder {
45  public:
46   static const char *ID___self__;
47   static const char *ID___globals__;
48   static const char *ID___call__;
49   static const char *ID_construct;
50 
51   explicit GraphBuilder(const PyFrameObject *f);
GraphBuilder(GraphBuilder * r,GraphBuilder * p,PyCodeObject * co,PyObject * globals)52   GraphBuilder(GraphBuilder *r, GraphBuilder *p, PyCodeObject *co, PyObject *globals)
53       : root_(r), parent_(p), graph_(NewGraph(co, globals)), frame_(), current_block_(nullptr) {}
GraphBuilder(GraphBuilder * r)54   explicit GraphBuilder(GraphBuilder *r) : root_(r), parent_(nullptr), graph_(nullptr), current_block_(nullptr) {}
~GraphBuilder()55   ~GraphBuilder() {
56     for (auto i : graph_pool_) {
57       delete i;
58     }
59     graph_pool_.clear();
60   }
Creator(const PyFrameObject * f,bool trace_flag)61   static GraphBuilderPtr Creator(const PyFrameObject *f, bool trace_flag) {
62     return trace_flag ? std::static_pointer_cast<GraphBuilder>(std::make_shared<MindGraphBuilder>(f))
63                       : std::make_shared<GraphBuilder>(f);
64   }
Creator(GraphBuilder * r,GraphBuilder * p,PyCodeObject * co,PyObject * globals,bool trace_flag)65   static GraphBuilderPtr Creator(GraphBuilder *r, GraphBuilder *p, PyCodeObject *co, PyObject *globals,
66                                  bool trace_flag) {
67     return trace_flag ? std::static_pointer_cast<GraphBuilder>(std::make_shared<MindGraphBuilder>(r, p, co, globals))
68                       : std::make_shared<GraphBuilder>(r, p, co, globals);
69   }
70 
71   StopTraceReason TraceRun();
trace_flag()72   virtual bool trace_flag() { return false; }
73 
74   void CollectInlineInfo(CallNode *node, int depth);
GetGraph()75   Graph *GetGraph() const { return graph_; }
76   void DumpDFG();
77 
78   // NOTE: nn.Cell will return 'construct'
79   static py::object FindPyFunc(AObject *vobj);
80   static py::object GetFuncInfo(ValueNode *func_node);
81 
82   // TryBlockStack operation
83   TryBlock &PeekStack(int p);
PushStack(TryBlock tb)84   void PushStack(TryBlock tb) { tryBlockStacks_.push_back(tb); }
StackSize()85   int StackSize() { return tryBlockStacks_.size(); }
GetTryBlockStacks()86   std::vector<TryBlock> &GetTryBlockStacks() { return tryBlockStacks_; }
87   TryBlock &PopStack();
88 
89   // loop analyze
90   void HandleLoop();
91 
92   /**
93    * Handle call node. Infer call result. Inline call node bytecode
94    * \param depth Current inline depth
95    * \return Ttop trace reason of sub-graph
96    */
97   StopTraceReason HandleCall(int depth);
98 
99   /**
100    * Resolve callable object, if it's unknown object, return infer failed reason.
101    * Check inline white list, infer result and not inline bytecode
102    * If call a class, try to handle class
103    * \param [in] call_node
104    * \param [out] stop_reason
105    * \return The function object of call target
106    */
107   virtual py::object ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason);
108 
109   /**
110    * Resolve closure of function, generate cell free nodes to trace closure
111    * \param func_info The function of call target
112    * \param callable_node The value node of callable object
113    * \param frame FrameStates to place closure node
114    */
115   void ResolveClosure(const py::object &func_info, ValueNode *callable_node, FrameStates *frame);
116 
117   std::pair<PyObject *, ValueNode *> SearchSelfPyObject(PyCodeObject *co);
118   bool HandleSuper(const Instr &instr, AObject *super);
119   AObject *BuildSuperObject(PyCodeObject *co);
120 
121   /**
122    * Collect parameters of call stack and set it to frame
123    * \param func_info The function of call target
124    * \param call_node This calling information
125    * \param frame FrameStates to place parameters nodes
126    * \return false if parameters is illegal
127    */
128   bool HandleCallParameters(const py::object &func_info, CallNode *call_node, FrameStates *frame);
129 
130   bool UnpackDynamicLengthTupleByBytecode(std::vector<ValueNode *> *params, ValueNode *args_node, CallNode *call_node);
131 
132   /**
133    * Unpack CALL_FUNCTION_EX parameters to stack
134    * \param[in] params the call stack
135    * \param[in] extra_local extra local index
136    * \param[out] extra_oper unpack operations by bytecode
137    * \param[out] has_kw this call has key-word arguments
138    * \return false if can't generate unpack operations
139    */
140   virtual bool UnpackCallExParams(std::vector<ValueNode *> *params, int extra_local, bool *has_kw, CallNode *call_node);
141 
142   virtual bool UnpackCallExDict(std::vector<ValueNode *> *params, CallNode *call_node);
143 
144   bool UnpackDynamicLengthDictByBytecode(std::vector<ValueNode *> *params, CallNode *call_node, ValueNode *dict_node);
145   // generate the general unpack operations of dict, return operations
146   std::vector<AbstractNode *> GenerateDictUnpack(ValueNode *kwargs_node);
147 
148   /**
149    * Pack key-word parameters, generate kwvargs value node, check kw-defaults arguments
150    * \param[in] func The function of call target
151    * \param[in] params This calling stack
152    * \param[in] frame FrameStates to place parameters nodes
153    * \param[out] extra_oper the move operations to move parameters to locals
154    * \return false if parameters is illegal
155    */
156   virtual bool HandleKWParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame);
157 
158   /**
159    * Pack key-word parameters to dict, unpack the position arguments by key from the dict.
160    * Set parameters to frame
161    * \param[in] func The function of call target
162    * \param[in] params This calling stack
163    * \param[in] frame FrameStates to place parameters nodes
164    * \param[out] dict_gen the move operations to move parameters to locals
165    * \param[out] dict_op the opcode of dict generation
166    * \return false if parameters is illegal
167    */
168   bool PackKwParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame,
169                     std::vector<ValueNode *> *kwvargs);
170 
171   bool CheckAndSetDefaultParams(const py::object &func, FrameStates *frame, int pargc);
172 
173   /**
174    * Use the call stack without key-word arguments to fill the frame locals
175    */
176   virtual bool HandlePositionParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame);
177 
178   // build subgraph, return stop trace reason
179   virtual StopTraceReason BuildSubGraph(CallNode *call_node, int depth, const py::object &func,
180                                         const GraphBuilderPtr &subgraph);
181 
182   bool ReplaceCall(CallNode *call_node, const py::object &func);
183 
184   // build abstract instance of python class
185   virtual bool HandleCallClass(CallNode *call_node);
186 
187   // return false if has unsupported bytecode
188   bool DoByteCode(const Instr &instr);
189 
190   // unpack elements
191   bool UnpackElements(ValueNode *);
192 
193   // unpack elements
194   bool UnpackSequenceElements(ValueNode *);
195 
196   // unpack object elements as LOAD_CONST
197   std::vector<ValueNode *> UnpackConstObject(const py::object &);
198 
199   // return true if not inline
200   virtual bool WhiteListFuncCheckAndInfer(CallNode *, const py::object &f);
201 
202   bool DoSetItem(ValueNode *map, ValueNode *key, ValueNode *val);
203 
204   // transform dict set item to make a new dict
205   ValueNode *TransformDictSetItem(ValueNode *map, ValueNode *key, ValueNode *val, bool ignore_key_error);
206 
207   // transform list set item to make a new list
208   ValueNode *TransformListSetItem(ValueNode *list, ValueNode *key, ValueNode *val);
209 
210   ValueNode *ReplaceMergeOp(int opcode, const std::vector<ValueNode *> &inputs);
211 
212   bool ClassInstantiationFold(CallNode *, AObject::Type);
213 
214   // frame operation
seek(int p)215   ValueNode *&seek(int p) { return frame_.Peek(p); }
push(ValueNode * v)216   void push(ValueNode *v) { frame_.Push(v); }
pop()217   ValueNode *pop() { return frame_.Pop(); }
popn(int n)218   void popn(int n) { frame_.Popn(n); }
getLocal(int i)219   ValueNode *getLocal(int i) { return frame_.Local(i); }
setLocal(int i,ValueNode * n)220   void setLocal(int i, ValueNode *n) { frame_.SetLocal(i, n); }
221 
222   // pointers
223   std::vector<Graph *> graph_pool_;
224   ValueNode *NewValueNode(AObject *o, int op, int arg, const std::vector<ValueNode *> &p = {},
225                           const std::string &name = "");
226   ValueNode *NewValueNode(AObject *o, const Instr &, const std::vector<ValueNode *> &p = {});
227   Graph *NewGraph(PyCodeObject *co, PyObject *f_globals);
228 
229   bool ReplaceAll(ValueNode *old_node, ValueNode *new_node, bool *referenced = nullptr);
230 
231   bool TraceRunForIterSequence(int jump_bci, bool is_range_type);
232   bool TraceRunForIterEnumerate(int jump_bci);
233   bool TraceRunForIterZip(int jump_bci);
234 
235   // bytecode operations
236   bool TraceRunControl(const Instr &instr);
237   bool TraceRunForIter(const Instr &instr);
238   bool DoUnpack(const Instr &instr);
239   bool DoCall(const Instr &instr);
240   bool DoNop(const Instr &instr);
241   bool DoReturn(const Instr &instr);
242   bool DoLocalAccess(const Instr &instr);
243   bool DoCellAccess(const Instr &instr);
244   bool DoGlobalAccess(const Instr &instr);
245   bool DoAttrAccess(const Instr &instr);
246   virtual ValueNode *HandleGetattr(ValueNode *target_node, const Instr &instr);
247   virtual bool DoGetItem(const Instr &instr);
248   virtual bool DoItemAccess(const Instr &instr);
249   bool DoStackOp(const Instr &instr);
250   bool DoLoadConst(const Instr &instr);
251   bool DoListToTuple(const Instr &instr);
252   bool DoGetIter(const Instr &instr);
253   bool DoMakeFunction(const Instr &instr);
254   AObject *InferUnary(ValueNode *, const Instr &instr);
255   virtual bool DoUnary(const Instr &instr);
256   AObject *InferBinary(ValueNode *, ValueNode *, const Instr &instr);
257   virtual bool DoBinary(const Instr &instr);
258   virtual bool DoIsOp(const Instr &instr);
259   virtual bool DoBinaryMul(const Instr &instr);
260   bool DoBinaryAdd(const Instr &instr);
261   bool DoInplaceAdd(const Instr &instr);
262   virtual bool DoCompare(const Instr &instr);
263   virtual bool DoBuildOp(const Instr &instr);
264   bool DoMergeOp(const Instr &instr);
265   bool DoFormatValue(const Instr &instr);
266   bool DoImport(const Instr &instr);
267   bool DoYieldValue(const Instr &instr);
268   bool DoException(const Instr &instr);
269   bool DoWith(const Instr &instr);
270   bool DoOtherBytecode(const Instr &instr);
271   bool NotImplementBytecode(const Instr &instr);
272 
root()273   const auto &root() const { return root_; }
frame()274   const auto &frame() const { return frame_; }
275 
276  protected:
277   GraphBuilder *root_;
278   GraphBuilder *parent_;
279   Graph *graph_;
280   FrameStates frame_;
281   Block *current_block_;
282   int cur_bci_;
283   std::vector<TryBlock> tryBlockStacks_{};
284 
285   static const std::unordered_map<int, bool (GraphBuilder::*)(const Instr &)> bytecode_meth_map_;
286 
287   ValueNode *GetCallFunctionNode(ValueNode *node, PyObject *dst_dtype);
288   bool DoMixedPrecisionLocalAccess(const Instr &instr, ValueNode *node);
289   ValueNode *DoMixedPrecisionAttrAccess(const Instr &instr, ValueNode *node, ValueNode *attr);
290 };
291 
292 class MindGraphBuilder : public GraphBuilder {
293  public:
294   explicit MindGraphBuilder(const PyFrameObject *f);
MindGraphBuilder(GraphBuilder * r,GraphBuilder * p,PyCodeObject * co,PyObject * globals)295   MindGraphBuilder(GraphBuilder *r, GraphBuilder *p, PyCodeObject *co, PyObject *globals)
296       : GraphBuilder(r, p, co, globals) {
297     std::vector<std::string> comments;
298     auto location = co ? std::make_shared<Location>(py::cast<std::string>(co->co_filename), co->co_firstlineno, 0,
299                                                     co->co_firstlineno, 0, "", std::move(comments))
300                        : std::make_shared<Location>("anonymous", 0, 0, 0, 0, "", std::move(comments));
301     MS_EXCEPTION_IF_NULL(location);
302     TraceGuard trace_guard(location);
303     fg_builder_ = std::make_shared<FuncGraphBuilder>();
304   }
trace_flag()305   bool trace_flag() { return true; }
FGBuilder()306   mindspore::FuncGraphBuilderPtr FGBuilder() const { return fg_builder_; }
307   bool FGAddInputs(const std::vector<py::object> &args);
308   py::object FGAddNode(CallNode *call_node, const py::object &callable_info, const std::vector<py::object> &args,
309                        StopTraceReason *stop_reason);
310   void FGAddOutput(bool is_top_graph);
311   StopTraceReason BuildSubGraph(CallNode *call_node, int depth, const py::object &func,
312                                 const GraphBuilderPtr &subgraph) override;
313   py::object ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason) override;
314   bool WhiteListFuncCheckAndInfer(CallNode *, const py::object &f) override;
315 
316   LocationPtr GetLocation(CallNode *call_node) const;
317 
318  protected:
319   bool DoGetItem(const Instr &instr) override;
320   bool DoItemAccess(const Instr &instr) override;
321   bool DoUnary(const Instr &instr) override;
322   bool DoBinary(const Instr &instr) override;
323   bool DoIsOp(const Instr &instr) override;
324   bool DoBinaryMul(const Instr &instr) override;
325   bool DoCompare(const Instr &instr) override;
326   bool DoBuildOp(const Instr &instr) override;
327   ValueNode *HandleGetattr(ValueNode *target_node, const Instr &instr) override;
328   bool HandlePositionParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) override;
329   bool UnpackCallExParams(std::vector<ValueNode *> *params, int extra_local, bool *has_kw,
330                           CallNode *call_node) override;
331   bool HandleKWParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) override;
332   bool UnpackCallExDict(std::vector<ValueNode *> *params, CallNode *call_node) override;
333   bool HandleCallClass(CallNode *call_node) override;
334 
335  private:
336   std::vector<py::object> GetNewArgs(CallNode *call_node, AObject *vobj = nullptr);
337   bool AllConstantArgs(const std::vector<py::object> &args, const py::object &callable_info, CallNode *call_node);
338 
339   py::object HandleGetShapeOfDynamicLengthTensor(const py::object &object);
340 
341   mindspore::FuncGraphBuilderPtr fg_builder_{nullptr};
342   std::string co_name_;
343   AObject *HandleMultiOp(const Instr &instr, const std::vector<ValueNode *> &p, bool is_compare);
344   AObject *HandleBuildOp(const Instr &instr, const std::vector<ValueNode *> &p);
345 };
346 }  // namespace pijit
347 }  // namespace mindspore
348 
349 #endif  // MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_BUILD_H
350