• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_BYTECODE_INLINER_H
17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_BYTECODE_INLINER_H
18 
19 #include <vector>
20 #include <map>
21 #include <memory>
22 #include "pipeline/jit/pi/graph_capture/cfg.h"
23 #include "pipeline/jit/pi/graph_capture/code_generator.h"
24 #include "utils/convert_utils_base.h"
25 
26 namespace mindspore {
27 namespace pijit {
28 
29 class CallNode;
30 
/**
 * Used to inline bytecode and rebuild the bytecode from nodes.
 * 1. Collect the traced nodes of each sub-graph.
 * 2. If inlining the break graph is allowed, inline the second half of the bytecode. Guard the global variables if
 *    the globals of the inlined function differ from those of the top function, and either eliminate the side effect
 *    or do not inline a function that has side effects.
 * 3. If eliminating unused instructions is allowed, remove them.
 * 4. Rebuild the bytecode, reset the frame status, and reset the break point.
 */
40 class BytecodeInliner {
41  public:
BytecodeInliner(Graph * graph,const py::dict & global)42   BytecodeInliner(Graph *graph, const py::dict &global)
43       : graph_(graph),
44         traced_nodes_(),
45         extra_globals_(global),
46         cfg_(),
47         last_frame_(),
48         new_frames_(),
49         reconstructed_value_(nullptr),
50         new_break_bci_(-1),
51         inline_partial_(false) {}
52 
53   void Run();
54 
55  private:
56   // prepare and call rebuild bytecodes by nodes
57   void Rebuild();
58 
59   // rebuild bytecodes and frame statue
60   void Rebuild(CodeGenerator *cg);
61 
62   void EraseDeadLocal(const std::vector<ValueNode *> &alive_nodes);
63 
64   // eliminate closure access
65   void EliminateClosureSideEffect();
66 
67   void ResetCFG(CodeGenerator *cg);
68 
69   void ResetGraphStat();
70 
71   // collect traced nodes, collect bytecodes after break
72   void ProcessGraph(Graph *, int local_off = 0);
73 
74   // reconstruct node by bytecode
75   void Reconstruct(ValueNode *node, int local_off);
76 
77   // initialize cfg by instruction list
78   void InitCFG();
79 
80   // reset instruction oparg, guard globals which merge to top func. eliminate sideeffect of inline
81   void FixInstr(Graph *, int local_off, std::vector<std::unique_ptr<Instr>> *);
82 
83   // collect trace nodes for each sub-graph
84   void CollectTracedNodes(Graph *);
85 
86   // top graph
87   Graph *const graph_;
88 
89   // all traced nodes
90   std::vector<ValueNode *> traced_nodes_;
91 
92   // used globals of function and inlined function
93   py::dict extra_globals_;
94 
95   // new cfg
96   std::unique_ptr<CFG> cfg_;
97 
98   // new last frame
99   std::unique_ptr<FrameStates> last_frame_;
100 
101   std::map<int, std::unique_ptr<FrameStates>> new_frames_;
102 
103   ValueNode *reconstructed_value_;
104 
105   int new_break_bci_;
106 
107   bool inline_partial_;
108 };
109 
110 }  // namespace pijit
111 }  // namespace mindspore
112 #endif
113