1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_compiler/cg/byte_code_generator.h"
17 #include <memory>
18 #include <string>
19 #include <algorithm>
20 #include "utils/log_adapter.h"
21
22 namespace mindspore {
23 namespace pijit {
// Width, in bits, of one byte. Used as the shift distance that separates the opcode
// from its argument inside a packed code unit.
constexpr unsigned int bits_per_byte = 8;

// MAKE_BYTE_CODE_UNIT(op, arg) packs an opcode and its argument into one code unit.
// Which of the two lands in the high bits depends on WORDS_BIGENDIAN. A definition
// supplied earlier (e.g. by an included header) takes precedence over this one.
#if !defined(MAKE_BYTE_CODE_UNIT)
#if defined(WORDS_BIGENDIAN)
#define MAKE_BYTE_CODE_UNIT(op, arg) (((op) << bits_per_byte) | (arg))
#else
#define MAKE_BYTE_CODE_UNIT(op, arg) ((op) | ((arg) << bits_per_byte))
#endif
#endif
34
GenFunction(const ir::FunctionNodePtr & func)35 py::object ByteCodeGenerator::GenFunction(const ir::FunctionNodePtr &func) {
36 ByteCodeGeneratorPtr generator = std::make_shared<ByteCodeGenerator>();
37 auto use_global = func->GetUseGlobal();
38 if (use_global != nullptr) {
39 generator->globals_ = use_global->cast<ir::ValuePtr>()->GetValue().cast<py::dict>();
40 }
41 return generator->Generate(func);
42 }
43
Generate(const ir::FunctionNodePtr & func)44 py::object ByteCodeGenerator::Generate(const ir::FunctionNodePtr &func) {
45 cell_var_cnt_ = CellVarCounter::GetCount(func);
46 co_consts_.append(py::none());
47 first_line_no_ = func->GetFirstLineNo();
48 func->Sort();
49 Visit(func);
50 auto byte_code = py::reinterpret_steal<py::object>(
51 PyBytes_FromStringAndSize((const char *)co_code_.data(), co_code_.size() * sizeof(co_code_[0])));
52 auto lnotab = py::reinterpret_steal<py::object>(
53 PyBytes_FromStringAndSize(co_lnotab_.data(), co_lnotab_.size() * sizeof(co_lnotab_[0])));
54 auto var_names = py::cast<py::tuple>(co_var_names_);
55 auto consts = py::cast<py::tuple>(co_consts_);
56 auto names = py::cast<py::tuple>(co_names_);
57 auto free_vars = py::cast<py::tuple>(co_free_vars_);
58 auto cell_vars = py::cast<py::tuple>(co_cell_vars_);
59 PyCodeObject *code = PyCode_New(func->GetPosArgsCnt(), func->GetKwOnlyArgsCnt(), var_names.size(),
60 func->GetStackSize(), func->GetFlags(), byte_code.ptr(), consts.ptr(), names.ptr(),
61 var_names.ptr(), free_vars.ptr(), cell_vars.ptr(), py::str(func->GetFileName()).ptr(),
62 py::str(func->GetName()).ptr(), func->GetFirstLineNo(), lnotab.ptr());
63 globals_[py::str("__builtins__")] = builtins_.ptr();
64 auto function = py::reinterpret_steal<py::object>(PyFunction_New(reinterpret_cast<PyObject *>(code), globals_.ptr()));
65 Py_DECREF(code);
66 auto tuple = py::cast<py::tuple>(defaults_);
67 (void)PyFunction_SetDefaults(function.ptr(), tuple.ptr());
68 tuple = py::cast<py::tuple>(clousre_);
69 (void)PyFunction_SetClosure(function.ptr(), tuple.ptr());
70 return function;
71 }
72
Visit_(const ir::ParameterPtr & node)73 void ByteCodeGenerator::Visit_(const ir::ParameterPtr &node) {
74 const std::string name = node->GetName();
75 MS_EXCEPTION_IF_CHECK_FAIL((co_var_names_map_.find(name) == co_var_names_map_.end()),
76 "Duplicate parameter name " + name + ".");
77 co_var_names_map_[name] = SizeToInt(co_var_names_.size());
78 co_var_names_.append(py::str(name));
79 ir::NodePtr default_value = node->GetDefaultValue();
80 if (default_value != nullptr) {
81 if (node->GetCategory() == ir::Parameter::KEYWORD_ONLY) {
82 kwdefaults_[co_var_names_[co_var_names_map_[name]]] = default_value;
83 } else {
84 MS_EXCEPTION_IF_CHECK_FAIL(node->GetCategory() == 0, "Error category of parameter.");
85 defaults_.append(default_value->cast<ir::ValuePtr>()->GetValue());
86 }
87 }
88 }
89
90 #define DEFINE_UN_NODE_VISIT_(OP) \
91 void ByteCodeGenerator::Visit_(const OP &node) { \
92 Visit(node->GetArg()); \
93 CheckInstrOffset(node); \
94 GenerateInstr(node->GetOpCode()); \
95 SetStartsLine(node); \
96 }
97
98 DEFINE_UN_NODE_VISIT_(ir::UnaryOperationPtr)
DEFINE_UN_NODE_VISIT_(ir::NegativeNodePtr)99 DEFINE_UN_NODE_VISIT_(ir::NegativeNodePtr)
100 DEFINE_UN_NODE_VISIT_(ir::NotNodePtr)
101 DEFINE_UN_NODE_VISIT_(ir::InvertNodePtr)
102 DEFINE_UN_NODE_VISIT_(ir::ReturnNodePtr)
103 DEFINE_UN_NODE_VISIT_(ir::CastNodePtr)
104 DEFINE_UN_NODE_VISIT_(ir::GetNodePtr)
105
106 #define DEFINE_BIN_NODE_VISIT_(OP) \
107 void ByteCodeGenerator::Visit_(const OP &node) { \
108 Visit(node->GetLeftArg()); \
109 Visit(node->GetRightArg()); \
110 CheckInstrOffset(node); \
111 GenerateInstr(node->GetOpCode()); \
112 SetStartsLine(node); \
113 }
114
115 DEFINE_BIN_NODE_VISIT_(ir::BinaryOperationPtr)
116 DEFINE_BIN_NODE_VISIT_(ir::AddNodePtr)
117 DEFINE_BIN_NODE_VISIT_(ir::SubNodePtr)
118 DEFINE_BIN_NODE_VISIT_(ir::MulNodePtr)
119 DEFINE_BIN_NODE_VISIT_(ir::DivNodePtr)
120 DEFINE_BIN_NODE_VISIT_(ir::BitwiseNodePtr)
121
122 void ByteCodeGenerator::Visit_(const ir::ValuePtr &node) { (void)GetValueIndex(node); }
123
Visit_(const ir::NaryOperationPtr & node)124 void ByteCodeGenerator::Visit_(const ir::NaryOperationPtr &node) {
125 VISIT_NODE_LIST(node->GetArgs())
126 CheckInstrOffset(node);
127 GenerateInstr(node->GetOpCode(), node->GetArgsCnt());
128 SetStartsLine(node);
129 }
130
Visit_(const ir::DeleteNodePtr & node)131 void ByteCodeGenerator::Visit_(const ir::DeleteNodePtr &node) {
132 Visit(node->GetArg());
133 int arg = 0;
134 if (node->GetOpCode() != DELETE_SUBSCR) {
135 MS_EXCEPTION_IF_CHECK_FAIL(node->GetArg()->isa<ir::Value>(), "Expect delete a value.");
136 arg = GetValueIndex(node->GetArg()->cast<ir::ValuePtr>());
137 }
138 CheckInstrOffset(node);
139 GenerateInstr(node->GetOpCode(), arg);
140 SetStartsLine(node);
141 }
142
Visit_(const ir::FormatNodePtr & node)143 void ByteCodeGenerator::Visit_(const ir::FormatNodePtr &node) {
144 VISIT_NODE_LIST(node->GetArgs())
145 CheckInstrOffset(node);
146 GenerateInstr(node->GetOpCode(), node->GetFormatType());
147 SetStartsLine(node);
148 }
149
Visit_(const ir::IsNodePtr & node)150 void ByteCodeGenerator::Visit_(const ir::IsNodePtr &node) {
151 Visit(node->GetLeftArg());
152 Visit(node->GetRightArg());
153 CheckInstrOffset(node);
154 GenerateInstr(node->GetOpCode(), node->IsInvert());
155 SetStartsLine(node);
156 }
157
Visit_(const ir::ContainsNodePtr & node)158 void ByteCodeGenerator::Visit_(const ir::ContainsNodePtr &node) {
159 Visit(node->GetLeftArg());
160 Visit(node->GetRightArg());
161 CheckInstrOffset(node);
162 GenerateInstr(node->GetOpCode(), node->IsInvert());
163 SetStartsLine(node);
164 }
165
Visit_(const ir::StoreNodePtr & node)166 void ByteCodeGenerator::Visit_(const ir::StoreNodePtr &node) {
167 Visit(node->GetLeftArg());
168 Visit(node->GetRightArg());
169 ir::NodePtr target = node->GetRightArg();
170 if (target->isa<ir::AttrNode>()) {
171 target = target->cast<ir::AttrNodePtr>()->GetAttr();
172 }
173 int arg = 0;
174 if (!target->isa<ir::SubscrNode>()) {
175 MS_EXCEPTION_IF_CHECK_FAIL(target->isa<ir::Value>(), "Expect store to a var.");
176 arg = GetValueIndex(target->cast<ir::ValuePtr>());
177 }
178 CheckInstrOffset(node);
179 GenerateInstr(node->GetOpCode(), arg);
180 SetStartsLine(node);
181 }
182
Visit_(const ir::JumpNodePtr & node)183 void ByteCodeGenerator::Visit_(const ir::JumpNodePtr &node) {
184 ir::IRVisitor::Visit_(node);
185 ir::OpCode op = node->GetOpCode();
186 size_t arg = node->GetRightArg()->GetOffset() * 2;
187 if (op == JUMP_FORWARD || op == FOR_ITER) {
188 arg -= (node->GetOffset() + 1) * 2;
189 }
190 CheckInstrOffset(node);
191 GenerateInstr(node->GetOpCode(), SizeToInt(arg));
192 SetStartsLine(node);
193 }
194
Visit_(const ir::CompareNodePtr & node)195 void ByteCodeGenerator::Visit_(const ir::CompareNodePtr &node) {
196 Visit(node->GetLeftArg());
197 Visit(node->GetRightArg());
198 CheckInstrOffset(node);
199 GenerateInstr(node->GetOpCode(), node->GetInstrArg());
200 SetStartsLine(node);
201 }
202
Visit_(const ir::UpdateNodePtr & node)203 void ByteCodeGenerator::Visit_(const ir::UpdateNodePtr &node) {
204 Visit(node->GetLeftArg());
205 Visit(node->GetRightArg());
206 CheckInstrOffset(node);
207 GenerateInstr(node->GetOpCode(), node->GetInstrArg());
208 SetStartsLine(node);
209 }
210
Visit_(const ir::LoadValueNodePtr & node)211 void ByteCodeGenerator::Visit_(const ir::LoadValueNodePtr &node) {
212 VISIT_NODE_LIST(node->GetArgs())
213 int arg = GetValueIndex(node->GetArg()->cast<ir::ValuePtr>());
214 CheckInstrOffset(node);
215 GenerateInstr(node->GetOpCode(), arg);
216 SetStartsLine(node);
217 }
218
Visit_(const ir::LoadFieldNodePtr & node)219 void ByteCodeGenerator::Visit_(const ir::LoadFieldNodePtr &node) {
220 VISIT_NODE_LIST(node->GetArgs())
221 int arg = GetValueIndex(node->GetArg(1)->cast<ir::ValuePtr>());
222 CheckInstrOffset(node);
223 GenerateInstr(node->GetOpCode(), arg);
224 SetStartsLine(node);
225 }
226
Visit_(const ir::BuildNodePtr & node)227 void ByteCodeGenerator::Visit_(const ir::BuildNodePtr &node) {
228 VISIT_NODE_LIST(node->GetArgs())
229 CheckInstrOffset(node);
230 size_t arg = node->GetArgsCnt();
231 if (node->GetOpCode() == BUILD_CONST_KEY_MAP) {
232 arg--;
233 }
234 GenerateInstr(node->GetOpCode(), SizeToInt(arg));
235 SetStartsLine(node);
236 }
237
Visit_(const ir::CallNodePtr & node)238 void ByteCodeGenerator::Visit_(const ir::CallNodePtr &node) {
239 VISIT_NODE_LIST(node->GetArgs())
240 size_t arg = node->GetArgsCnt() - 1;
241 if (node->GetOpCode() == CALL_FUNCTION_KW || node->GetOpCode() == CALL_FUNCTION_EX) {
242 arg--;
243 }
244 CheckInstrOffset(node);
245 GenerateInstr(node->GetOpCode(), SizeToInt(arg));
246 SetStartsLine(node);
247 }
248
Visit_(const ir::NaryWithFlagNodePtr & node)249 void ByteCodeGenerator::Visit_(const ir::NaryWithFlagNodePtr &node) {
250 VISIT_NODE_LIST(node->GetArgs())
251 CheckInstrOffset(node);
252 GenerateInstr(node->GetOpCode(), node->GetFlag());
253 SetStartsLine(node);
254 }
255
GetValueIndex(const ir::ValuePtr & node)256 int ByteCodeGenerator::GetValueIndex(const ir::ValuePtr &node) {
257 auto scope = node->GetScope();
258 MS_EXCEPTION_IF_CHECK_FAIL(scope_inquire_map_.find(scope) != scope_inquire_map_.end(),
259 "Invalid scope in " + node->ToString());
260 auto name_map = scope_inquire_map_.at(scope);
261 auto name = node->GetName();
262 if (name_map->find(name) != name_map->end()) {
263 return name_map->at(name);
264 }
265 auto values = scope_value_list_.at(scope);
266 (*name_map)[name] = SizeToInt(values.first.size());
267 if (values.first.is(values.second)) {
268 values.first.append(node->GetValue());
269 } else {
270 auto obj = py::str(name);
271 values.first.append(obj);
272 if (scope == ir::kScopeClousre) {
273 (py::cast<py::list>(values.second)).append(node->GetValue());
274 } else {
275 (py::cast<py::dict>(values.second))[obj] = node->GetValue();
276 }
277 }
278 return name_map->at(name);
279 }
280
CheckInstrOffset(const ir::NodePtr & node)281 void ByteCodeGenerator::CheckInstrOffset(const ir::NodePtr &node) {
282 MS_EXCEPTION_IF_CHECK_FAIL(
283 node->GetOffset() - (node->NeedExtInstr() ? 1 : 0) == co_code_.size(),
284 "The offset of " + node->GetNodeName() + "(%" + std::to_string(node->GetNodeId()) + ") is not expected.");
285 }
286
IsExtendedArg(int arg)287 bool IsExtendedArg(int arg) { return (IntToSize(arg) >> bits_per_byte) > 0; }
288
GenerateInstr(ir::OpCode op,int arg)289 void ByteCodeGenerator::GenerateInstr(ir::OpCode op, int arg) {
290 if (IsExtendedArg(arg)) {
291 int ext_arg = SizeToInt((IntToSize(arg) >> bits_per_byte));
292 co_code_.push_back(MAKE_BYTE_CODE_UNIT(EXTENDED_ARG, ext_arg));
293 arg = SizeToInt((IntToSize(arg) & 0xff));
294 }
295 co_code_.push_back(MAKE_BYTE_CODE_UNIT(op, arg));
296 }
297
298 // co_lnotab_ : A string encoding the mapping from bytecode offsets to line numbers.
299 // Elements are value pairs
300 // Value pair : the first one is offset of bytecode
301 // the second one is the increment of the line number relative to the previous value pair.
302 // For example :
303 // co_firstlineno (8)
304 // co_lnotab_ = {(0, 1), (6, 1), (9, 2)}
305 // (0, 1) ----> 0 : the first bytecode
306 // 1 : the line no. of first bytecode is (1 + 8) = 9
307 // (6, 1) ----> 6 : the seventh bytecode, means The line number of the second to sixth bytecodes is 9
308 // 1 : the line no. of seventh bytecode is (1 + 9) = 10
309 // (9, 2) ----> 9 : the tenth bytecode, means The line number of the eighth and ninth bytecodes is 10
310 // 1 : the line no. of tenth is bytecode (2 + 10) = 12
SetStartsLine(const ir::NodePtr & node)311 void ByteCodeGenerator::SetStartsLine(const ir::NodePtr &node) {
312 int new_line_no = node->GetDebugInfo()->GetLineNo();
313 if (new_line_no == 0) {
314 return;
315 }
316 size_t dis = 0;
317 int inc = new_line_no;
318 if (co_lnotab_.empty()) {
319 inc -= first_line_no_;
320 } else {
321 MS_EXCEPTION_IF_CHECK_FAIL(last_starts_instr_ != nullptr, "last_starts_instr_ should not be nullptr.");
322 dis = sizeof(_Py_CODEUNIT) * (node->GetOffset() - last_starts_instr_->GetOffset());
323 inc -= last_starts_instr_->GetDebugInfo()->GetLineNo();
324 }
325 last_starts_instr_ = node;
326 co_lnotab_.push_back(SizeToInt(dis));
327 co_lnotab_.push_back(inc);
328 }
329 } // namespace pijit
330 } // namespace mindspore
331