• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_CODE_GENERATOR_H
17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_CODE_GENERATOR_H
18 
#include <climits>
#include <string>
#include <vector>
#include <unordered_map>
#include <algorithm>
#include <utility>
#include <memory>
#include "pipeline/jit/pi/graph_capture/graph_analyzer.h"
#include "pipeline/jit/pi/graph_capture/graph_build.h"
#include "pipeline/jit/pi/graph_capture/side_effect.h"
#include "pipeline/jit/pi/graph_build/func_graph_builder.h"
#include "utils/convert_utils_base.h"
30 
31 namespace mindspore {
32 namespace pijit {
33 
34 namespace py = pybind11;
35 
36 class GraphParameterBuilder;
37 
38 struct NodeSet {
39   std::vector<ValueNode *> inputs;  // index is parameters index
40   std::vector<ValueNode *> outputs;
41   std::vector<ValueNode *> operations;
42 };
43 
44 struct GraphInputInfo {
45   std::vector<ValueNode *> args;
46   std::vector<ValueNode *> globals;
47   ValueNode *vargs = nullptr;
48   ValueNode *kwargs = nullptr;
49 };
50 
51 class CodeGenerator {
52  public:
53   struct Code {
54     int co_argcount;
55     int co_kwonlyargcount;
56     int co_nlocals;
57     int co_flags;
58     int co_firstlineno;
59     std::vector<std::unique_ptr<Instr>> co_code;
60     std::vector<std::string> co_varnames;
61     std::vector<std::string> co_cellvars;
62     std::vector<std::string> co_freevars;
63     std::string co_name;
64     py::object co_filename;
65   };
66 
CodeGenerator(const NodeSet * nodes)67   explicit CodeGenerator(const NodeSet *nodes) : nodes_(nodes), globals_(), code_(), nodes_alive_(), locals_map_() {}
68 
SetGlobals(const py::dict & dict)69   void SetGlobals(const py::dict &dict) { globals_ = dict; }
MoveCode()70   std::vector<std::unique_ptr<Instr>> MoveCode() { return std::move(code_.co_code); }
GetGlobals()71   const py::dict &GetGlobals() const { return globals_; }
GetLocalsMap()72   const std::unordered_map<ValueNode *, int> &GetLocalsMap() const { return locals_map_; }
EarseLocal(ValueNode * item)73   bool EarseLocal(ValueNode *item) {
74     auto it = GetLocalsMap().find(item);
75     if (it != GetLocalsMap().end()) {
76       locals_map_.erase(it);
77     } else {
78       return false;
79     }
80     return true;
81   }
GetCode()82   const Code &GetCode() const { return code_; }
SetArgsInfo(int argcount,int kwonlyargcount)83   void SetArgsInfo(int argcount, int kwonlyargcount) {
84     code_.co_argcount = argcount;
85     code_.co_kwonlyargcount = kwonlyargcount;
86   }
SetCodeFlags(unsigned flags)87   void SetCodeFlags(unsigned flags) { code_.co_flags |= flags; }
SetLocalsCount(int nlocals)88   void SetLocalsCount(int nlocals) { code_.co_nlocals = std::max(nlocals, code_.co_nlocals); }
SetFirstLineNumber(int line)89   void SetFirstLineNumber(int line) { code_.co_firstlineno = line; }
SetVariableNames(const std::vector<std::string> & names)90   void SetVariableNames(const std::vector<std::string> &names) { code_.co_varnames = names; }
SetCellVariableNames(const std::vector<std::string> & names)91   void SetCellVariableNames(const std::vector<std::string> &names) { code_.co_cellvars = names; }
SetFreeVariableNames(const std::vector<std::string> & names)92   void SetFreeVariableNames(const std::vector<std::string> &names) { code_.co_freevars = names; }
SetCodeName(const std::string & name)93   void SetCodeName(const std::string &name) { code_.co_name = name; }
SetFileName(const py::object & file)94   void SetFileName(const py::object &file) { code_.co_filename = file; }
95 
MarkAlive(ValueNode * node)96   void MarkAlive(ValueNode *node) { nodes_alive_[node] = INT_MAX; }
97   void MarkAlive();
98   void NewInstr(int op, int arg = 0, int line = -1);
99   void AddInstrs(std::vector<std::unique_ptr<Instr>> &&list);
100   void AddInstr(std::unique_ptr<Instr> &&instr);
101   void EraseUnusedInstr();
102 
103   // initialize local map of parameters
104   void Init();
105 
106   // build bytecode by nodes
107   void Build();
108 
109   // generate return operations of outputs
110   void GenReturn();
111 
112   // build single node
113   void BuildOper(ValueNode *node, int index);
114 
115   // generator local operations of node
116   void LoadValue(ValueNode *node);
117 
118   void LoadConst(const py::object &);
119 
120   // add node to locals map
121   int AllocLocal(ValueNode *node, int index = INT_MAX);
122 
123   std::string PrintAlive() const;
124 
125   /**
126    * Transform code info to PyCodeObject
127    *
128    * \param ccode code info
129    * \return PyCodeObject
130    */
131   static py::object Transform(const Code &ccode);
132 
133   /**
134    * Calculate max stack size
135    *
136    * \param list instruct nodes list
137    * \param sp start of stack depth
138    * \return max depth of stack, or -1 if stack out of bound
139    */
140   static int CalculateStackSize(const std::vector<std::unique_ptr<Instr>> &list, int sp = 0);
141 
142   /**
143    * Convert instruction list to bytes object. generate line table.
144    *
145    * \param list instruct nodes list
146    * \param first_line first line
147    * \return first is co_code, second is co_lnotab
148    */
149   static std::pair<py::bytes, py::bytes> ConvertToCodeBytes(const std::vector<std::unique_ptr<Instr>> &list,
150                                                             int first_line);
151 
152   /**
153    * Copy instruction list at range [start, end).
154    * NOTE: reset opcode:
155    *       LOAD_METHOD -> LOAD_ATTR,
156    *       CALL_METHOD -> CALL_FUNCTION
157    *
158    * \param list instruct nodes list
159    * \param start
160    * \param end
161    * \return instruction list
162    */
163   static std::vector<std::unique_ptr<Instr>> CopyInstr(const std::vector<std::unique_ptr<Instr>> &list, size_t start,
164                                                        size_t end = -1);
165 
166   /**
167    * Erase unused instr
168    *
169    * \param list instruction list
170    */
171   static void EraseUnusedInstr(std::vector<std::unique_ptr<Instr>> *list);
172 
173   /**
174    * generate rot instructions
175    */
176   static std::vector<std::unique_ptr<Instr>> RotStack(int stack);
177 
178  private:
179   const NodeSet *nodes_;
180   py::dict globals_;
181   Code code_;
182   std::unordered_map<ValueNode *, int> nodes_alive_;
183   std::unordered_map<ValueNode *, int> locals_map_;
184 };
185 
186 class CodeBreakGenerator;
187 class MindCodeBreakGenerator;
188 using CodeBreakGeneratorPtr = std::shared_ptr<CodeBreakGenerator>;
189 using MindCodeBreakGeneratorPtr = std::shared_ptr<MindCodeBreakGenerator>;
190 class CodeBreakGenerator {
191  public:
CodeBreakGenerator(PyCodeObject * co)192   explicit CodeBreakGenerator(PyCodeObject *co) : co_(co), cfg_(nullptr), break_bci_(-1), extra_local_(-1) {}
Creator(const GraphBuilderPtr & builder,PyCodeObject * co)193   static CodeBreakGeneratorPtr Creator(const GraphBuilderPtr &builder, PyCodeObject *co) {
194     return builder->trace_flag()
195              ? std::static_pointer_cast<CodeBreakGenerator>(std::make_shared<MindCodeBreakGenerator>(builder, co))
196              : std::make_shared<CodeBreakGenerator>(co);
197   }
198 
SetGlobals(const py::dict & dict)199   void SetGlobals(const py::dict &dict) { globals_ = dict; }
GetGlobals()200   const py::dict &GetGlobals() const { return globals_; }
201 
202   // collect nodes inputs and outputs at graph analyze
203   void Init(const Graph *, const GraphAnalyzer &);
204 
205   // generate a code to call graph, unsupported operations, and untracked operations that will be compiled
206   py::object MakeDispatchCode();
207 
208   // used to replace origin code, extend attribute from origin code.
209   virtual py::object MakeCapturedCode() const;
210 
211   const CFG *GetCFG() const;
212 
213  protected:
214   void ExtendCodeInfo(CodeGenerator *cg, bool merge_kw_only) const;
215 
216   // rebuild parameters of graph, identify parameters that graph only support as constant
217   void BuildGraphParameters(const std::unordered_map<ValueNode *, int> &locals, GraphParameterBuilder *);
218 
219   // rebuild captured nodes to bytecode, build parameters load operations
220   virtual py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&sort, int argc, unsigned flag) const;
221 
222   // make call operations of graph, build parameters load operations
223   void CallCapturedCode(CodeGenerator *code_gen);
224 
225   void FixInterpretOuput(CodeGenerator *code_gen);
226 
227   // make function of untracked bytecode, build restore frame operations of untracked bytecode
228   py::object MakeUntrackedCode(int untracked_bci, int untracked_stack_effect) const;
229 
230   void ReconstructStack(CodeGenerator *code_gen, int untracked_bci, int untracked_stack_effect) const;
231 
232   // make call operations of untracked bytecode
233   void CallUntrackedCode(CodeGenerator *code_gen);
234 
235   void MakeReturn(CodeGenerator *code_gen) const;
236 
237   // build operations of block, build restore frame operations of block
238   void BreakAtBlock(CodeGenerator *code_gen, int untracked_bci, int untracked_stack_effect);
239 
240   // make call operations of untracked bytecode for each branch
241   void BreakAtIf(CodeGenerator *code_gen) const;
242 
243   void RestoreStack(CodeGenerator *code_gen) const;
244 
245   void RestoreLocals(CodeGenerator *code_gen, bool load) const;
246 
247   // return co_cellvars and co_freevars
248   std::vector<std::string> GetClosureNames() const;
249 
250   // root function
251   PyCodeObject *const co_;
252 
253   // instructions for break graph
254   const CFG *cfg_;
255 
256   // function globals
257   py::dict globals_;
258 
259   /**
260    * first execute node,
261    * inputs must be same as the start of function locals(include unbound local)
262    * outputs is alive values
263    **/
264   NodeSet interpret_;
265 
266   // followed interpret execute node
267   NodeSet captured_;
268 
269   GraphInputInfo graph_inputs_info_;
270 
271   // break bci alive locals
272   std::vector<int> alive_locals_;
273 
274   std::shared_ptr<SideEffect> side_effect_handler_;
275 
276   // break bci
277   int break_bci_;
278 
279   // used to store graph outputs
280   int extra_local_;
281 };
282 
283 class MindCodeBreakGenerator : public CodeBreakGenerator {
284  public:
MindCodeBreakGenerator(const GraphBuilderPtr & builder,PyCodeObject * co)285   MindCodeBreakGenerator(const GraphBuilderPtr &builder, PyCodeObject *co)
286       : CodeBreakGenerator(co), builder_(builder) {}
287 
FGBuilder()288   mindspore::FuncGraphBuilderPtr FGBuilder() const {
289     return std::dynamic_pointer_cast<MindGraphBuilder>(builder_)->FGBuilder();
290   }
291 
292   py::object MakeCapturedCode(std::vector<std::unique_ptr<Instr>> &&, int argc, unsigned code_flag) const override;
293 
294   py::object MakeCapturedCode() const override;
295 
296  private:
297   void Compile(const std::string &name, int argc, int kw_only, int flags, const py::object &stub) const;
298 
299   GraphBuilderPtr builder_;
300 };
301 // add a key and value to py::dict, check key conflict or rename the key
302 void MapAdd(const py::dict &dict, const std::string &key, const py::object &value, std::string *rename = nullptr);
303 
304 // make new code by graph and captured information
305 py::object MakeCodeFromCodeGen(const GraphBuilderPtr &builder, const GraphAnalyzerPtr &analyzer, PyObject *globals);
306 }  // namespace pijit
307 }  // namespace mindspore
308 
#endif  // MINDSPORE_PI_JIT_GRAPH_CAPTURE_CODE_GENERATOR_H
310