1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_BUILD_H 17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_BUILD_H 18 19 #include <vector> 20 #include <unordered_map> 21 #include <utility> 22 #include <memory> 23 #include <string> 24 #include "pipeline/jit/pi/graph_capture/graph.h" 25 #include "pipeline/jit/pi/graph_build/func_graph_builder.h" 26 #include "utils/convert_utils_base.h" 27 28 namespace mindspore { 29 namespace pijit { 30 class GraphBuilder; 31 class MindGraphBuilder; 32 using GraphBuilderPtr = std::shared_ptr<GraphBuilder>; 33 using MindGraphBuilderPtr = std::shared_ptr<MindGraphBuilder>; 34 35 struct TryBlock { 36 int type; /*what kind of block this is (SETUP_SETUP, SETUP_FINALLY, SETUP_EXCEPT)*/ 37 int bci; /*where to jump to find handler*/ 38 int checkpoint; /*the handler to be rolled back*/ 39 // int level; /* value stack level to pop toe*/ 40 bool IsFinallyBlock; /*record current block is in exception block or finally block*/ 41 }; 42 43 bool CheckSupportCreateInstance(CallNode *call_node); 44 class GraphBuilder { 45 public: 46 static const char *ID___self__; 47 static const char *ID___globals__; 48 static const char *ID___call__; 49 static const char *ID_construct; 50 51 explicit GraphBuilder(const PyFrameObject *f); GraphBuilder(GraphBuilder * r,GraphBuilder * p,PyCodeObject * co,PyObject * globals)52 GraphBuilder(GraphBuilder *r, GraphBuilder *p, PyCodeObject *co, PyObject *globals) 53 : root_(r), parent_(p), graph_(NewGraph(co, globals)), frame_(), current_block_(nullptr) {} GraphBuilder(GraphBuilder * r)54 explicit GraphBuilder(GraphBuilder *r) : root_(r), parent_(nullptr), graph_(nullptr), current_block_(nullptr) {} ~GraphBuilder()55 ~GraphBuilder() { 56 for (auto i : graph_pool_) { 57 delete i; 58 } 59 graph_pool_.clear(); 60 } Creator(const PyFrameObject * f,bool trace_flag)61 static GraphBuilderPtr Creator(const PyFrameObject *f, bool trace_flag) { 62 return trace_flag ? std::static_pointer_cast<GraphBuilder>(std::make_shared<MindGraphBuilder>(f)) 63 : std::make_shared<GraphBuilder>(f); 64 } Creator(GraphBuilder * r,GraphBuilder * p,PyCodeObject * co,PyObject * globals,bool trace_flag)65 static GraphBuilderPtr Creator(GraphBuilder *r, GraphBuilder *p, PyCodeObject *co, PyObject *globals, 66 bool trace_flag) { 67 return trace_flag ? std::static_pointer_cast<GraphBuilder>(std::make_shared<MindGraphBuilder>(r, p, co, globals)) 68 : std::make_shared<GraphBuilder>(r, p, co, globals); 69 } 70 71 StopTraceReason TraceRun(); trace_flag()72 virtual bool trace_flag() { return false; } 73 74 void CollectInlineInfo(CallNode *node, int depth); GetGraph()75 Graph *GetGraph() const { return graph_; } 76 void DumpDFG(); 77 78 // NOTE: nn.Cell will return 'construct' 79 static py::object FindPyFunc(AObject *vobj); 80 static py::object GetFuncInfo(ValueNode *func_node); 81 82 // TryBlockStack operation 83 TryBlock &PeekStack(int p); PushStack(TryBlock tb)84 void PushStack(TryBlock tb) { tryBlockStacks_.push_back(tb); } StackSize()85 int StackSize() { return tryBlockStacks_.size(); } GetTryBlockStacks()86 std::vector<TryBlock> &GetTryBlockStacks() { return tryBlockStacks_; } 87 TryBlock &PopStack(); 88 89 // loop analyze 90 void HandleLoop(); 91 92 /** 93 * Handle call node. Infer call result. Inline call node bytecode 94 * \param depth Current inline depth 95 * \return Ttop trace reason of sub-graph 96 */ 97 StopTraceReason HandleCall(int depth); 98 99 /** 100 * Resolve callable object, if it's unknown object, return infer failed reason. 101 * Check inline white list, infer result and not inline bytecode 102 * If call a class, try to handle class 103 * \param [in] call_node 104 * \param [out] stop_reason 105 * \return The function object of call target 106 */ 107 virtual py::object ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason); 108 109 /** 110 * Resolve closure of function, generate cell free nodes to trace closure 111 * \param func_info The function of call target 112 * \param callable_node The value node of callable object 113 * \param frame FrameStates to place closure node 114 */ 115 void ResolveClosure(const py::object &func_info, ValueNode *callable_node, FrameStates *frame); 116 117 std::pair<PyObject *, ValueNode *> SearchSelfPyObject(PyCodeObject *co); 118 bool HandleSuper(const Instr &instr, AObject *super); 119 AObject *BuildSuperObject(PyCodeObject *co); 120 121 /** 122 * Collect parameters of call stack and set it to frame 123 * \param func_info The function of call target 124 * \param call_node This calling information 125 * \param frame FrameStates to place parameters nodes 126 * \return false if parameters is illegal 127 */ 128 bool HandleCallParameters(const py::object &func_info, CallNode *call_node, FrameStates *frame); 129 130 bool UnpackDynamicLengthTupleByBytecode(std::vector<ValueNode *> *params, ValueNode *args_node, CallNode *call_node); 131 132 /** 133 * Unpack CALL_FUNCTION_EX parameters to stack 134 * \param[in] params the call stack 135 * \param[in] extra_local extra local index 136 * \param[out] extra_oper unpack operations by bytecode 137 * \param[out] has_kw this call has key-word arguments 138 * \return false if can't generate unpack operations 139 */ 140 virtual bool UnpackCallExParams(std::vector<ValueNode *> *params, int extra_local, bool *has_kw, CallNode *call_node); 141 142 virtual bool UnpackCallExDict(std::vector<ValueNode *> *params, CallNode *call_node); 143 144 bool UnpackDynamicLengthDictByBytecode(std::vector<ValueNode *> *params, CallNode *call_node, ValueNode *dict_node); 145 // generate the general unpack operations of dict, return operations 146 std::vector<AbstractNode *> GenerateDictUnpack(ValueNode *kwargs_node); 147 148 /** 149 * Pack key-word parameters, generate kwvargs value node, check kw-defaults arguments 150 * \param[in] func The function of call target 151 * \param[in] params This calling stack 152 * \param[in] frame FrameStates to place parameters nodes 153 * \param[out] extra_oper the move operations to move parameters to locals 154 * \return false if parameters is illegal 155 */ 156 virtual bool HandleKWParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame); 157 158 /** 159 * Pack key-word parameters to dict, unpack the position arguments by key from the dict. 160 * Set parameters to frame 161 * \param[in] func The function of call target 162 * \param[in] params This calling stack 163 * \param[in] frame FrameStates to place parameters nodes 164 * \param[out] dict_gen the move operations to move parameters to locals 165 * \param[out] dict_op the opcode of dict generation 166 * \return false if parameters is illegal 167 */ 168 bool PackKwParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame, 169 std::vector<ValueNode *> *kwvargs); 170 171 bool CheckAndSetDefaultParams(const py::object &func, FrameStates *frame, int pargc); 172 173 /** 174 * Use the call stack without key-word arguments to fill the frame locals 175 */ 176 virtual bool HandlePositionParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame); 177 178 // build subgraph, return stop trace reason 179 virtual StopTraceReason BuildSubGraph(CallNode *call_node, int depth, const py::object &func, 180 const GraphBuilderPtr &subgraph); 181 182 bool ReplaceCall(CallNode *call_node, const py::object &func); 183 184 // build abstract instance of python class 185 virtual bool HandleCallClass(CallNode *call_node); 186 187 // return false if has unsupported bytecode 188 bool DoByteCode(const Instr &instr); 189 190 // unpack elements 191 bool UnpackElements(ValueNode *); 192 193 // unpack elements 194 bool UnpackSequenceElements(ValueNode *); 195 196 // unpack object elements as LOAD_CONST 197 std::vector<ValueNode *> UnpackConstObject(const py::object &); 198 199 // return true if not inline 200 virtual bool WhiteListFuncCheckAndInfer(CallNode *, const py::object &f); 201 202 bool DoSetItem(ValueNode *map, ValueNode *key, ValueNode *val); 203 204 // transform dict set item to make a new dict 205 ValueNode *TransformDictSetItem(ValueNode *map, ValueNode *key, ValueNode *val, bool ignore_key_error); 206 207 // transform list set item to make a new list 208 ValueNode *TransformListSetItem(ValueNode *list, ValueNode *key, ValueNode *val); 209 210 ValueNode *ReplaceMergeOp(int opcode, const std::vector<ValueNode *> &inputs); 211 212 bool ClassInstantiationFold(CallNode *, AObject::Type); 213 214 // frame operation seek(int p)215 ValueNode *&seek(int p) { return frame_.Peek(p); } push(ValueNode * v)216 void push(ValueNode *v) { frame_.Push(v); } pop()217 ValueNode *pop() { return frame_.Pop(); } popn(int n)218 void popn(int n) { frame_.Popn(n); } getLocal(int i)219 ValueNode *getLocal(int i) { return frame_.Local(i); } setLocal(int i,ValueNode * n)220 void setLocal(int i, ValueNode *n) { frame_.SetLocal(i, n); } 221 222 // pointers 223 std::vector<Graph *> graph_pool_; 224 ValueNode *NewValueNode(AObject *o, int op, int arg, const std::vector<ValueNode *> &p = {}, 225 const std::string &name = ""); 226 ValueNode *NewValueNode(AObject *o, const Instr &, const std::vector<ValueNode *> &p = {}); 227 Graph *NewGraph(PyCodeObject *co, PyObject *f_globals); 228 229 bool ReplaceAll(ValueNode *old_node, ValueNode *new_node, bool *referenced = nullptr); 230 231 bool TraceRunForIterSequence(int jump_bci, bool is_range_type); 232 bool TraceRunForIterEnumerate(int jump_bci); 233 bool TraceRunForIterZip(int jump_bci); 234 235 // bytecode operations 236 bool TraceRunControl(const Instr &instr); 237 bool TraceRunForIter(const Instr &instr); 238 bool DoUnpack(const Instr &instr); 239 bool DoCall(const Instr &instr); 240 bool DoNop(const Instr &instr); 241 bool DoReturn(const Instr &instr); 242 bool DoLocalAccess(const Instr &instr); 243 bool DoCellAccess(const Instr &instr); 244 bool DoGlobalAccess(const Instr &instr); 245 bool DoAttrAccess(const Instr &instr); 246 virtual ValueNode *HandleGetattr(ValueNode *target_node, const Instr &instr); 247 virtual bool DoGetItem(const Instr &instr); 248 virtual bool DoItemAccess(const Instr &instr); 249 bool DoStackOp(const Instr &instr); 250 bool DoLoadConst(const Instr &instr); 251 bool DoListToTuple(const Instr &instr); 252 bool DoGetIter(const Instr &instr); 253 bool DoMakeFunction(const Instr &instr); 254 AObject *InferUnary(ValueNode *, const Instr &instr); 255 virtual bool DoUnary(const Instr &instr); 256 AObject *InferBinary(ValueNode *, ValueNode *, const Instr &instr); 257 virtual bool DoBinary(const Instr &instr); 258 virtual bool DoIsOp(const Instr &instr); 259 virtual bool DoBinaryMul(const Instr &instr); 260 bool DoBinaryAdd(const Instr &instr); 261 bool DoInplaceAdd(const Instr &instr); 262 virtual bool DoCompare(const Instr &instr); 263 virtual bool DoBuildOp(const Instr &instr); 264 bool DoMergeOp(const Instr &instr); 265 bool DoFormatValue(const Instr &instr); 266 bool DoImport(const Instr &instr); 267 bool DoYieldValue(const Instr &instr); 268 bool DoException(const Instr &instr); 269 bool DoWith(const Instr &instr); 270 bool DoOtherBytecode(const Instr &instr); 271 bool NotImplementBytecode(const Instr &instr); 272 root()273 const auto &root() const { return root_; } frame()274 const auto &frame() const { return frame_; } 275 276 protected: 277 GraphBuilder *root_; 278 GraphBuilder *parent_; 279 Graph *graph_; 280 FrameStates frame_; 281 Block *current_block_; 282 int cur_bci_; 283 std::vector<TryBlock> tryBlockStacks_{}; 284 285 static const std::unordered_map<int, bool (GraphBuilder::*)(const Instr &)> bytecode_meth_map_; 286 287 ValueNode *GetCallFunctionNode(ValueNode *node, PyObject *dst_dtype); 288 bool DoMixedPrecisionLocalAccess(const Instr &instr, ValueNode *node); 289 ValueNode *DoMixedPrecisionAttrAccess(const Instr &instr, ValueNode *node, ValueNode *attr); 290 }; 291 292 class MindGraphBuilder : public GraphBuilder { 293 public: 294 explicit MindGraphBuilder(const PyFrameObject *f); MindGraphBuilder(GraphBuilder * r,GraphBuilder * p,PyCodeObject * co,PyObject * globals)295 MindGraphBuilder(GraphBuilder *r, GraphBuilder *p, PyCodeObject *co, PyObject *globals) 296 : GraphBuilder(r, p, co, globals) { 297 std::vector<std::string> comments; 298 auto location = co ? std::make_shared<Location>(py::cast<std::string>(co->co_filename), co->co_firstlineno, 0, 299 co->co_firstlineno, 0, "", std::move(comments)) 300 : std::make_shared<Location>("anonymous", 0, 0, 0, 0, "", std::move(comments)); 301 MS_EXCEPTION_IF_NULL(location); 302 TraceGuard trace_guard(location); 303 fg_builder_ = std::make_shared<FuncGraphBuilder>(); 304 } trace_flag()305 bool trace_flag() { return true; } FGBuilder()306 mindspore::FuncGraphBuilderPtr FGBuilder() const { return fg_builder_; } 307 bool FGAddInputs(const std::vector<py::object> &args); 308 py::object FGAddNode(CallNode *call_node, const py::object &callable_info, const std::vector<py::object> &args, 309 StopTraceReason *stop_reason); 310 void FGAddOutput(bool is_top_graph); 311 StopTraceReason BuildSubGraph(CallNode *call_node, int depth, const py::object &func, 312 const GraphBuilderPtr &subgraph) override; 313 py::object ResolveCallable(CallNode *call_node, StopTraceReason *stop_reason) override; 314 bool WhiteListFuncCheckAndInfer(CallNode *, const py::object &f) override; 315 316 LocationPtr GetLocation(CallNode *call_node) const; 317 318 protected: 319 bool DoGetItem(const Instr &instr) override; 320 bool DoItemAccess(const Instr &instr) override; 321 bool DoUnary(const Instr &instr) override; 322 bool DoBinary(const Instr &instr) override; 323 bool DoIsOp(const Instr &instr) override; 324 bool DoBinaryMul(const Instr &instr) override; 325 bool DoCompare(const Instr &instr) override; 326 bool DoBuildOp(const Instr &instr) override; 327 ValueNode *HandleGetattr(ValueNode *target_node, const Instr &instr) override; 328 bool HandlePositionParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) override; 329 bool UnpackCallExParams(std::vector<ValueNode *> *params, int extra_local, bool *has_kw, 330 CallNode *call_node) override; 331 bool HandleKWParams(const py::object &func, std::vector<ValueNode *> *params, FrameStates *frame) override; 332 bool UnpackCallExDict(std::vector<ValueNode *> *params, CallNode *call_node) override; 333 bool HandleCallClass(CallNode *call_node) override; 334 335 private: 336 std::vector<py::object> GetNewArgs(CallNode *call_node, AObject *vobj = nullptr); 337 bool AllConstantArgs(const std::vector<py::object> &args, const py::object &callable_info, CallNode *call_node); 338 339 py::object HandleGetShapeOfDynamicLengthTensor(const py::object &object); 340 341 mindspore::FuncGraphBuilderPtr fg_builder_{nullptr}; 342 std::string co_name_; 343 AObject *HandleMultiOp(const Instr &instr, const std::vector<ValueNode *> &p, bool is_compare); 344 AObject *HandleBuildOp(const Instr &instr, const std::vector<ValueNode *> &p); 345 }; 346 } // namespace pijit 347 } // namespace mindspore 348 349 #endif // MINDSPORE_PI_JIT_GRAPH_CAPTURE_GRAPH_BUILD_H 350