• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_compiler/cg/byte_code_generator.h"
17 #include <memory>
18 #include <string>
19 #include <algorithm>
20 #include "utils/log_adapter.h"
21 
22 namespace mindspore {
23 namespace pijit {
// Width of one byte in bits; an opcode and its one-byte oparg each take one byte of a code unit.
constexpr unsigned int bits_per_byte = 8;

// MAKE_BYTE_CODE_UNIT(op, arg) packs an opcode and a one-byte argument into one code unit.
// Assuming a two-byte code unit, the shifts below place the opcode in the byte that comes
// first in memory on either host byte order, followed by the argument byte.
#if !defined(MAKE_BYTE_CODE_UNIT)
#if defined(WORDS_BIGENDIAN)
#define MAKE_BYTE_CODE_UNIT(op, arg) (((op) << bits_per_byte) | (arg))
#else
#define MAKE_BYTE_CODE_UNIT(op, arg) ((op) | ((arg) << bits_per_byte))
#endif  // WORDS_BIGENDIAN
#endif  // MAKE_BYTE_CODE_UNIT
34 
GenFunction(const ir::FunctionNodePtr & func)35 py::object ByteCodeGenerator::GenFunction(const ir::FunctionNodePtr &func) {
36   ByteCodeGeneratorPtr generator = std::make_shared<ByteCodeGenerator>();
37   auto use_global = func->GetUseGlobal();
38   if (use_global != nullptr) {
39     generator->globals_ = use_global->cast<ir::ValuePtr>()->GetValue().cast<py::dict>();
40   }
41   return generator->Generate(func);
42 }
43 
Generate(const ir::FunctionNodePtr & func)44 py::object ByteCodeGenerator::Generate(const ir::FunctionNodePtr &func) {
45   cell_var_cnt_ = CellVarCounter::GetCount(func);
46   co_consts_.append(py::none());
47   first_line_no_ = func->GetFirstLineNo();
48   func->Sort();
49   Visit(func);
50   auto byte_code = py::reinterpret_steal<py::object>(
51     PyBytes_FromStringAndSize((const char *)co_code_.data(), co_code_.size() * sizeof(co_code_[0])));
52   auto lnotab = py::reinterpret_steal<py::object>(
53     PyBytes_FromStringAndSize(co_lnotab_.data(), co_lnotab_.size() * sizeof(co_lnotab_[0])));
54   auto var_names = py::cast<py::tuple>(co_var_names_);
55   auto consts = py::cast<py::tuple>(co_consts_);
56   auto names = py::cast<py::tuple>(co_names_);
57   auto free_vars = py::cast<py::tuple>(co_free_vars_);
58   auto cell_vars = py::cast<py::tuple>(co_cell_vars_);
59   PyCodeObject *code = PyCode_New(func->GetPosArgsCnt(), func->GetKwOnlyArgsCnt(), var_names.size(),
60                                   func->GetStackSize(), func->GetFlags(), byte_code.ptr(), consts.ptr(), names.ptr(),
61                                   var_names.ptr(), free_vars.ptr(), cell_vars.ptr(), py::str(func->GetFileName()).ptr(),
62                                   py::str(func->GetName()).ptr(), func->GetFirstLineNo(), lnotab.ptr());
63   globals_[py::str("__builtins__")] = builtins_.ptr();
64   auto function = py::reinterpret_steal<py::object>(PyFunction_New(reinterpret_cast<PyObject *>(code), globals_.ptr()));
65   Py_DECREF(code);
66   auto tuple = py::cast<py::tuple>(defaults_);
67   (void)PyFunction_SetDefaults(function.ptr(), tuple.ptr());
68   tuple = py::cast<py::tuple>(clousre_);
69   (void)PyFunction_SetClosure(function.ptr(), tuple.ptr());
70   return function;
71 }
72 
Visit_(const ir::ParameterPtr & node)73 void ByteCodeGenerator::Visit_(const ir::ParameterPtr &node) {
74   const std::string name = node->GetName();
75   MS_EXCEPTION_IF_CHECK_FAIL((co_var_names_map_.find(name) == co_var_names_map_.end()),
76                              "Duplicate parameter name " + name + ".");
77   co_var_names_map_[name] = SizeToInt(co_var_names_.size());
78   co_var_names_.append(py::str(name));
79   ir::NodePtr default_value = node->GetDefaultValue();
80   if (default_value != nullptr) {
81     if (node->GetCategory() == ir::Parameter::KEYWORD_ONLY) {
82       kwdefaults_[co_var_names_[co_var_names_map_[name]]] = default_value;
83     } else {
84       MS_EXCEPTION_IF_CHECK_FAIL(node->GetCategory() == 0, "Error category of parameter.");
85       defaults_.append(default_value->cast<ir::ValuePtr>()->GetValue());
86     }
87   }
88 }
89 
90 #define DEFINE_UN_NODE_VISIT_(OP)                  \
91   void ByteCodeGenerator::Visit_(const OP &node) { \
92     Visit(node->GetArg());                         \
93     CheckInstrOffset(node);                        \
94     GenerateInstr(node->GetOpCode());              \
95     SetStartsLine(node);                           \
96   }
97 
98 DEFINE_UN_NODE_VISIT_(ir::UnaryOperationPtr)
DEFINE_UN_NODE_VISIT_(ir::NegativeNodePtr)99 DEFINE_UN_NODE_VISIT_(ir::NegativeNodePtr)
100 DEFINE_UN_NODE_VISIT_(ir::NotNodePtr)
101 DEFINE_UN_NODE_VISIT_(ir::InvertNodePtr)
102 DEFINE_UN_NODE_VISIT_(ir::ReturnNodePtr)
103 DEFINE_UN_NODE_VISIT_(ir::CastNodePtr)
104 DEFINE_UN_NODE_VISIT_(ir::GetNodePtr)
105 
106 #define DEFINE_BIN_NODE_VISIT_(OP)                 \
107   void ByteCodeGenerator::Visit_(const OP &node) { \
108     Visit(node->GetLeftArg());                     \
109     Visit(node->GetRightArg());                    \
110     CheckInstrOffset(node);                        \
111     GenerateInstr(node->GetOpCode());              \
112     SetStartsLine(node);                           \
113   }
114 
115 DEFINE_BIN_NODE_VISIT_(ir::BinaryOperationPtr)
116 DEFINE_BIN_NODE_VISIT_(ir::AddNodePtr)
117 DEFINE_BIN_NODE_VISIT_(ir::SubNodePtr)
118 DEFINE_BIN_NODE_VISIT_(ir::MulNodePtr)
119 DEFINE_BIN_NODE_VISIT_(ir::DivNodePtr)
120 DEFINE_BIN_NODE_VISIT_(ir::BitwiseNodePtr)
121 
122 void ByteCodeGenerator::Visit_(const ir::ValuePtr &node) { (void)GetValueIndex(node); }
123 
Visit_(const ir::NaryOperationPtr & node)124 void ByteCodeGenerator::Visit_(const ir::NaryOperationPtr &node) {
125   VISIT_NODE_LIST(node->GetArgs())
126   CheckInstrOffset(node);
127   GenerateInstr(node->GetOpCode(), node->GetArgsCnt());
128   SetStartsLine(node);
129 }
130 
Visit_(const ir::DeleteNodePtr & node)131 void ByteCodeGenerator::Visit_(const ir::DeleteNodePtr &node) {
132   Visit(node->GetArg());
133   int arg = 0;
134   if (node->GetOpCode() != DELETE_SUBSCR) {
135     MS_EXCEPTION_IF_CHECK_FAIL(node->GetArg()->isa<ir::Value>(), "Expect delete a value.");
136     arg = GetValueIndex(node->GetArg()->cast<ir::ValuePtr>());
137   }
138   CheckInstrOffset(node);
139   GenerateInstr(node->GetOpCode(), arg);
140   SetStartsLine(node);
141 }
142 
Visit_(const ir::FormatNodePtr & node)143 void ByteCodeGenerator::Visit_(const ir::FormatNodePtr &node) {
144   VISIT_NODE_LIST(node->GetArgs())
145   CheckInstrOffset(node);
146   GenerateInstr(node->GetOpCode(), node->GetFormatType());
147   SetStartsLine(node);
148 }
149 
Visit_(const ir::IsNodePtr & node)150 void ByteCodeGenerator::Visit_(const ir::IsNodePtr &node) {
151   Visit(node->GetLeftArg());
152   Visit(node->GetRightArg());
153   CheckInstrOffset(node);
154   GenerateInstr(node->GetOpCode(), node->IsInvert());
155   SetStartsLine(node);
156 }
157 
Visit_(const ir::ContainsNodePtr & node)158 void ByteCodeGenerator::Visit_(const ir::ContainsNodePtr &node) {
159   Visit(node->GetLeftArg());
160   Visit(node->GetRightArg());
161   CheckInstrOffset(node);
162   GenerateInstr(node->GetOpCode(), node->IsInvert());
163   SetStartsLine(node);
164 }
165 
Visit_(const ir::StoreNodePtr & node)166 void ByteCodeGenerator::Visit_(const ir::StoreNodePtr &node) {
167   Visit(node->GetLeftArg());
168   Visit(node->GetRightArg());
169   ir::NodePtr target = node->GetRightArg();
170   if (target->isa<ir::AttrNode>()) {
171     target = target->cast<ir::AttrNodePtr>()->GetAttr();
172   }
173   int arg = 0;
174   if (!target->isa<ir::SubscrNode>()) {
175     MS_EXCEPTION_IF_CHECK_FAIL(target->isa<ir::Value>(), "Expect store to a var.");
176     arg = GetValueIndex(target->cast<ir::ValuePtr>());
177   }
178   CheckInstrOffset(node);
179   GenerateInstr(node->GetOpCode(), arg);
180   SetStartsLine(node);
181 }
182 
Visit_(const ir::JumpNodePtr & node)183 void ByteCodeGenerator::Visit_(const ir::JumpNodePtr &node) {
184   ir::IRVisitor::Visit_(node);
185   ir::OpCode op = node->GetOpCode();
186   size_t arg = node->GetRightArg()->GetOffset() * 2;
187   if (op == JUMP_FORWARD || op == FOR_ITER) {
188     arg -= (node->GetOffset() + 1) * 2;
189   }
190   CheckInstrOffset(node);
191   GenerateInstr(node->GetOpCode(), SizeToInt(arg));
192   SetStartsLine(node);
193 }
194 
Visit_(const ir::CompareNodePtr & node)195 void ByteCodeGenerator::Visit_(const ir::CompareNodePtr &node) {
196   Visit(node->GetLeftArg());
197   Visit(node->GetRightArg());
198   CheckInstrOffset(node);
199   GenerateInstr(node->GetOpCode(), node->GetInstrArg());
200   SetStartsLine(node);
201 }
202 
Visit_(const ir::UpdateNodePtr & node)203 void ByteCodeGenerator::Visit_(const ir::UpdateNodePtr &node) {
204   Visit(node->GetLeftArg());
205   Visit(node->GetRightArg());
206   CheckInstrOffset(node);
207   GenerateInstr(node->GetOpCode(), node->GetInstrArg());
208   SetStartsLine(node);
209 }
210 
Visit_(const ir::LoadValueNodePtr & node)211 void ByteCodeGenerator::Visit_(const ir::LoadValueNodePtr &node) {
212   VISIT_NODE_LIST(node->GetArgs())
213   int arg = GetValueIndex(node->GetArg()->cast<ir::ValuePtr>());
214   CheckInstrOffset(node);
215   GenerateInstr(node->GetOpCode(), arg);
216   SetStartsLine(node);
217 }
218 
Visit_(const ir::LoadFieldNodePtr & node)219 void ByteCodeGenerator::Visit_(const ir::LoadFieldNodePtr &node) {
220   VISIT_NODE_LIST(node->GetArgs())
221   int arg = GetValueIndex(node->GetArg(1)->cast<ir::ValuePtr>());
222   CheckInstrOffset(node);
223   GenerateInstr(node->GetOpCode(), arg);
224   SetStartsLine(node);
225 }
226 
Visit_(const ir::BuildNodePtr & node)227 void ByteCodeGenerator::Visit_(const ir::BuildNodePtr &node) {
228   VISIT_NODE_LIST(node->GetArgs())
229   CheckInstrOffset(node);
230   size_t arg = node->GetArgsCnt();
231   if (node->GetOpCode() == BUILD_CONST_KEY_MAP) {
232     arg--;
233   }
234   GenerateInstr(node->GetOpCode(), SizeToInt(arg));
235   SetStartsLine(node);
236 }
237 
Visit_(const ir::CallNodePtr & node)238 void ByteCodeGenerator::Visit_(const ir::CallNodePtr &node) {
239   VISIT_NODE_LIST(node->GetArgs())
240   size_t arg = node->GetArgsCnt() - 1;
241   if (node->GetOpCode() == CALL_FUNCTION_KW || node->GetOpCode() == CALL_FUNCTION_EX) {
242     arg--;
243   }
244   CheckInstrOffset(node);
245   GenerateInstr(node->GetOpCode(), SizeToInt(arg));
246   SetStartsLine(node);
247 }
248 
Visit_(const ir::NaryWithFlagNodePtr & node)249 void ByteCodeGenerator::Visit_(const ir::NaryWithFlagNodePtr &node) {
250   VISIT_NODE_LIST(node->GetArgs())
251   CheckInstrOffset(node);
252   GenerateInstr(node->GetOpCode(), node->GetFlag());
253   SetStartsLine(node);
254 }
255 
GetValueIndex(const ir::ValuePtr & node)256 int ByteCodeGenerator::GetValueIndex(const ir::ValuePtr &node) {
257   auto scope = node->GetScope();
258   MS_EXCEPTION_IF_CHECK_FAIL(scope_inquire_map_.find(scope) != scope_inquire_map_.end(),
259                              "Invalid scope in " + node->ToString());
260   auto name_map = scope_inquire_map_.at(scope);
261   auto name = node->GetName();
262   if (name_map->find(name) != name_map->end()) {
263     return name_map->at(name);
264   }
265   auto values = scope_value_list_.at(scope);
266   (*name_map)[name] = SizeToInt(values.first.size());
267   if (values.first.is(values.second)) {
268     values.first.append(node->GetValue());
269   } else {
270     auto obj = py::str(name);
271     values.first.append(obj);
272     if (scope == ir::kScopeClousre) {
273       (py::cast<py::list>(values.second)).append(node->GetValue());
274     } else {
275       (py::cast<py::dict>(values.second))[obj] = node->GetValue();
276     }
277   }
278   return name_map->at(name);
279 }
280 
CheckInstrOffset(const ir::NodePtr & node)281 void ByteCodeGenerator::CheckInstrOffset(const ir::NodePtr &node) {
282   MS_EXCEPTION_IF_CHECK_FAIL(
283     node->GetOffset() - (node->NeedExtInstr() ? 1 : 0) == co_code_.size(),
284     "The offset of " + node->GetNodeName() + "(%" + std::to_string(node->GetNodeId()) + ") is not expected.");
285 }
286 
IsExtendedArg(int arg)287 bool IsExtendedArg(int arg) { return (IntToSize(arg) >> bits_per_byte) > 0; }
288 
GenerateInstr(ir::OpCode op,int arg)289 void ByteCodeGenerator::GenerateInstr(ir::OpCode op, int arg) {
290   if (IsExtendedArg(arg)) {
291     int ext_arg = SizeToInt((IntToSize(arg) >> bits_per_byte));
292     co_code_.push_back(MAKE_BYTE_CODE_UNIT(EXTENDED_ARG, ext_arg));
293     arg = SizeToInt((IntToSize(arg) & 0xff));
294   }
295   co_code_.push_back(MAKE_BYTE_CODE_UNIT(op, arg));
296 }
297 
298 // co_lnotab_ : A string encoding the mapping from bytecode offsets to line numbers.
299 // Elements are value pairs
300 // Value pair : the first one is offset of bytecode
301 //              the second one is the increment of the line number relative to the previous value pair.
302 // For example :
303 // co_firstlineno  (8)
304 // co_lnotab_ = {(0, 1), (6, 1), (9, 2)}
305 // (0, 1) ----> 0 : the first bytecode
306 //              1 : the line no. of first bytecode is (1 + 8) = 9
307 // (6, 1) ----> 6 : the seventh bytecode, means The line number of the second to sixth bytecodes is 9
308 //              1 : the line no. of seventh bytecode is (1 + 9) = 10
309 // (9, 2) ----> 9 : the tenth bytecode, means The line number of the eighth and ninth bytecodes is 10
310 //              1 : the line no. of tenth is bytecode (2 + 10) = 12
SetStartsLine(const ir::NodePtr & node)311 void ByteCodeGenerator::SetStartsLine(const ir::NodePtr &node) {
312   int new_line_no = node->GetDebugInfo()->GetLineNo();
313   if (new_line_no == 0) {
314     return;
315   }
316   size_t dis = 0;
317   int inc = new_line_no;
318   if (co_lnotab_.empty()) {
319     inc -= first_line_no_;
320   } else {
321     MS_EXCEPTION_IF_CHECK_FAIL(last_starts_instr_ != nullptr, "last_starts_instr_ should not be nullptr.");
322     dis = sizeof(_Py_CODEUNIT) * (node->GetOffset() - last_starts_instr_->GetOffset());
323     inc -= last_starts_instr_->GetDebugInfo()->GetLineNo();
324   }
325   last_starts_instr_ = node;
326   co_lnotab_.push_back(SizeToInt(dis));
327   co_lnotab_.push_back(inc);
328 }
329 }  // namespace pijit
330 }  // namespace mindspore
331