1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_CODE_GENERATOR_H 17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_CODE_GENERATOR_H 18 19 #include <string> 20 #include <vector> 21 #include <unordered_map> 22 #include <algorithm> 23 #include <utility> 24 #include <memory> 25 #include "pipeline/jit/pi/graph_capture/graph_analyzer.h" 26 #include "pipeline/jit/pi/graph_capture/graph_build.h" 27 #include "pipeline/jit/pi/graph_capture/side_effect.h" 28 #include "pipeline/jit/pi/graph_build/func_graph_builder.h" 29 #include "utils/convert_utils_base.h" 30 31 namespace mindspore { 32 namespace pijit { 33 34 namespace py = pybind11; 35 36 class GraphParameterBuilder; 37 38 struct NodeSet { 39 std::vector<ValueNode *> inputs; // index is parameters index 40 std::vector<ValueNode *> outputs; 41 std::vector<ValueNode *> operations; 42 }; 43 44 struct GraphInputInfo { 45 std::vector<ValueNode *> args; 46 std::vector<ValueNode *> globals; 47 ValueNode *vargs = nullptr; 48 ValueNode *kwargs = nullptr; 49 }; 50 51 class CodeGenerator { 52 public: 53 struct Code { 54 int co_argcount; 55 int co_kwonlyargcount; 56 int co_nlocals; 57 int co_flags; 58 int co_firstlineno; 59 std::vector<std::unique_ptr<Instr>> co_code; 60 std::vector<std::string> co_varnames; 61 std::vector<std::string> co_cellvars; 62 std::vector<std::string> co_freevars; 63 std::string co_name; 64 py::object co_filename; 65 }; 66 CodeGenerator(const NodeSet * nodes)67 explicit CodeGenerator(const NodeSet *nodes) : nodes_(nodes), globals_(), code_(), nodes_alive_(), locals_map_() {} 68 SetGlobals(const py::dict & dict)69 void SetGlobals(const py::dict &dict) { globals_ = dict; } MoveCode()70 std::vector<std::unique_ptr<Instr>> MoveCode() { return std::move(code_.co_code); } GetGlobals()71 const py::dict &GetGlobals() const { return globals_; } GetLocalsMap()72 const std::unordered_map<ValueNode *, int> &GetLocalsMap() const { return locals_map_; } EarseLocal(ValueNode * item)73 bool EarseLocal(ValueNode *item) { 74 auto it = GetLocalsMap().find(item); 75 if (it != GetLocalsMap().end()) { 76 locals_map_.erase(it); 77 } else { 78 return false; 79 } 80 return true; 81 } GetCode()82 const Code &GetCode() const { return code_; } SetArgsInfo(int argcount,int kwonlyargcount)83 void SetArgsInfo(int argcount, int kwonlyargcount) { 84 code_.co_argcount = argcount; 85 code_.co_kwonlyargcount = kwonlyargcount; 86 } SetCodeFlags(unsigned flags)87 void SetCodeFlags(unsigned flags) { code_.co_flags |= flags; } SetLocalsCount(int nlocals)88 void SetLocalsCount(int nlocals) { code_.co_nlocals = std::max(nlocals, code_.co_nlocals); } SetFirstLineNumber(int line)89 void SetFirstLineNumber(int line) { code_.co_firstlineno = line; } SetVariableNames(const std::vector<std::string> & names)90 void SetVariableNames(const std::vector<std::string> &names) { code_.co_varnames = names; } SetCellVariableNames(const std::vector<std::string> & names)91 void SetCellVariableNames(const std::vector<std::string> &names) { code_.co_cellvars = names; } SetFreeVariableNames(const std::vector<std::string> & names)92 void SetFreeVariableNames(const std::vector<std::string> &names) { code_.co_freevars = names; } SetCodeName(const std::string & name)93 void SetCodeName(const std::string &name) { code_.co_name = name; } SetFileName(const py::object & file)94 void SetFileName(const py::object &file) { code_.co_filename = file; } 95 MarkAlive(ValueNode * node)96 void MarkAlive(ValueNode *node) { nodes_alive_[node] = INT_MAX; } 97 void MarkAlive(); 98 void NewInstr(int op, int arg = 0, int line = -1); 99 void AddInstrs(std::vector<std::unique_ptr<Instr>> &&list); 100 void AddInstr(std::unique_ptr<Instr> &&instr); 101 void EraseUnusedInstr(); 102 103 // initialize local map of parameters 104 void Init(); 105 106 // build bytecode by nodes 107 void Build(); 108 109 // generate return operations of outputs 110 void GenReturn(); 111 112 // build single node 113 void BuildOper(ValueNode *node, int index); 114 115 // generator local operations of node 116 void LoadValue(ValueNode *node); 117 118 void LoadConst(const py::object &); 119 120 // add node to locals map 121 int AllocLocal(ValueNode *node, int index = INT_MAX); 122 123 std::string PrintAlive() const; 124 125 /** 126 * Transform code info to PyCodeObject 127 * 128 * \param ccode code info 129 * \return PyCodeObject 130 */ 131 static py::object Transform(const Code &ccode); 132 133 /** 134 * Calculate max stack size 135 * 136 * \param list instruct nodes list 137 * \param sp start of stack depth 138 * \return max depth of stack, or -1 if stack out of bound 139 */ 140 static int CalculateStackSize(const std::vector<std::unique_ptr<Instr>> &list, int sp = 0); 141 142 /** 143 * Convert instruction list to bytes object. generate line table. 144 * 145 * \param list instruct nodes list 146 * \param first_line first line 147 * \return first is co_code, second is co_lnotab 148 */ 149 static std::pair<py::bytes, py::bytes> ConvertToCodeBytes(const std::vector<std::unique_ptr<Instr>> &list, 150 int first_line); 151 152 /** 153 * Copy instruction list at range [start, end). 154 * NOTE: reset opcode: 155 * LOAD_METHOD -> LOAD_ATTR, 156 * CALL_METHOD -> CALL_FUNCTION 157 * 158 * \param list instruct nodes list 159 * \param start 160 * \param end 161 * \return instruction list 162 */ 163 static std::vector<std::unique_ptr<Instr>> CopyInstr(const std::vector<std::unique_ptr<Instr>> &list, size_t start, 164 size_t end = -1); 165 166 /** 167 * Erase unused instr 168 * 169 * \param list instruction list 170 */ 171 static void EraseUnusedInstr(std::vector<std::unique_ptr<Instr>> *list); 172 173 /** 174 * generate rot instructions 175 */ 176 static std::vector<std::unique_ptr<Instr>> RotStack(int stack); 177 178 private: 179 const NodeSet *nodes_; 180 py::dict globals_; 181 Code code_; 182 std::unordered_map<ValueNode *, int> nodes_alive_; 183 std::unordered_map<ValueNode *, int> locals_map_; 184 }; 185 186 class CodeBreakGenerator; 187 class MindCodeBreakGenerator; 188 using CodeBreakGeneratorPtr = std::shared_ptr<CodeBreakGenerator>; 189 using MindCodeBreakGeneratorPtr = std::shared_ptr<MindCodeBreakGenerator>; 190 class CodeBreakGenerator { 191 public: CodeBreakGenerator(PyCodeObject * co)192 explicit CodeBreakGenerator(PyCodeObject *co) : co_(co), cfg_(nullptr), break_bci_(-1), extra_local_(-1) {} Creator(const GraphBuilderPtr & builder,PyCodeObject * co)193 static CodeBreakGeneratorPtr Creator(const GraphBuilderPtr &builder, PyCodeObject *co) { 194 return builder->trace_flag() 195 ? std::static_pointer_cast<CodeBreakGenerator>(std::make_shared<MindCodeBreakGenerator>(builder, co)) 196 : std::make_shared<CodeBreakGenerator>(co); 197 } 198 SetGlobals(const py::dict & dict)199 void SetGlobals(const py::dict &dict) { globals_ = dict; } GetGlobals()200 const py::dict &GetGlobals() const { return globals_; } 201 202 // collect nodes inputs and outputs at graph analyze 203 void Init(const Graph *, const GraphAnalyzer &); 204 205 // generate a code to call graph, unsupported operations, and untracked operations that will be compiled 206 py::object MakeDispatchCode(); 207 208 // used to replace origin code, extend attribute from origin code. 209 virtual py::object MakeCapturedCode() const; 210 211 const CFG *GetCFG() const; 212 213 protected: 214 void ExtendCodeInfo(CodeGenerator *cg, bool merge_kw_only) const; 215 216 // rebuild parameters of graph, identify parameters that graph only support as constant 217 void BuildGraphParameters(const std::unordered_map<ValueNode *, int> &locals, GraphParameterBuilder *); 218 219 // rebuild captured nodes to bytecode, build parameters load operations 220 virtual py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&sort, int argc, unsigned flag) const; 221 222 // make call operations of graph, build parameters load operations 223 void CallCapturedCode(CodeGenerator *code_gen); 224 225 void FixInterpretOuput(CodeGenerator *code_gen); 226 227 // make function of untracked bytecode, build restore frame operations of untracked bytecode 228 py::object MakeUntrackedCode(int untracked_bci, int untracked_stack_effect) const; 229 230 void ReconstructStack(CodeGenerator *code_gen, int untracked_bci, int untracked_stack_effect) const; 231 232 // make call operations of untracked bytecode 233 void CallUntrackedCode(CodeGenerator *code_gen); 234 235 void MakeReturn(CodeGenerator *code_gen) const; 236 237 // build operations of block, build restore frame operations of block 238 void BreakAtBlock(CodeGenerator *code_gen, int untracked_bci, int untracked_stack_effect); 239 240 // make call operations of untracked bytecode for each branch 241 void BreakAtIf(CodeGenerator *code_gen) const; 242 243 void RestoreStack(CodeGenerator *code_gen) const; 244 245 void RestoreLocals(CodeGenerator *code_gen, bool load) const; 246 247 // return co_cellvars and co_freevars 248 std::vector<std::string> GetClosureNames() const; 249 250 // root function 251 PyCodeObject *const co_; 252 253 // instructions for break graph 254 const CFG *cfg_; 255 256 // function globals 257 py::dict globals_; 258 259 /** 260 * first execute node, 261 * inputs must be same as the start of function locals(include unbound local) 262 * outputs is alive values 263 **/ 264 NodeSet interpret_; 265 266 // followed interpret execute node 267 NodeSet captured_; 268 269 GraphInputInfo graph_inputs_info_; 270 271 // break bci alive locals 272 std::vector<int> alive_locals_; 273 274 std::shared_ptr<SideEffect> side_effect_handler_; 275 276 // break bci 277 int break_bci_; 278 279 // used to store graph outputs 280 int extra_local_; 281 }; 282 283 class MindCodeBreakGenerator : public CodeBreakGenerator { 284 public: MindCodeBreakGenerator(const GraphBuilderPtr & builder,PyCodeObject * co)285 MindCodeBreakGenerator(const GraphBuilderPtr &builder, PyCodeObject *co) 286 : CodeBreakGenerator(co), builder_(builder) {} 287 FGBuilder()288 mindspore::FuncGraphBuilderPtr FGBuilder() const { 289 return std::dynamic_pointer_cast<MindGraphBuilder>(builder_)->FGBuilder(); 290 } 291 292 py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&, int argc, unsigned code_flag) const override; 293 294 py::object MakeCapturedCode() const override; 295 296 private: 297 void Compile(const std::string &name, int argc, int kw_only, int flags, const py::object &stub) const; 298 299 GraphBuilderPtr builder_; 300 }; 301 // add a key and value to py::dict, check key conflict or rename the key 302 void MapAdd(const py::dict &dict, const std::string &key, const py::object &value, std::string *rename = nullptr); 303 304 // make new code by graph and captured information 305 py::object MakeCodeFromCodeGen(const GraphBuilderPtr &builder, const GraphAnalyzerPtr &analyzer, PyObject *globals); 306 } // namespace pijit 307 } // namespace mindspore 308 309 #endif // MINDSPORE_CCSRC_PIPELINE_GRAPH_JIT_GRAPH_CAPTURE_CODE_GEN_H 310