1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_capture/side_effect.h"
17 #include <algorithm>
18 #include <utility>
19 #include "pipeline/jit/pi/graph_capture/code_generator.h"
20 #include "pipeline/jit/pi/graph_capture/graph.h"
21
22 namespace mindspore {
23 namespace pijit {
24
GetSelfFromListAppendCall(ValueNode * call_node,bool * is_method_descriptor)25 ValueNode *GetSelfFromListAppendCall(ValueNode *call_node, bool *is_method_descriptor) {
26 ValueNode *method_node = call_node->input(0);
27 PyObject *method_object = method_node->GetVobj()->GetPyObject().ptr();
28 ValueNode *self = nullptr;
29 if (Py_IS_TYPE(method_object, &PyMethodDescr_Type)) {
30 self = call_node->input(1);
31 } else if (method_node->GetOpcode() == LOAD_ATTR) {
32 self = method_node->input(0);
33 }
34 if (is_method_descriptor != nullptr) {
35 *is_method_descriptor = Py_IS_TYPE(method_object, &PyMethodDescr_Type);
36 }
37 return self;
38 }
39
RecordModifiedAndReplacedNode(ValueNode * old_node,ValueNode * new_node)40 void SideEffectData::RecordModifiedAndReplacedNode(ValueNode *old_node, ValueNode *new_node) {
41 ValueNode **old_record = &modified_and_replaced_map_[new_node];
42 ValueNode *real_src = old_node;
43 const auto &m = modified_and_replaced_map_;
44 for (auto iter = m.find(real_src); iter != m.end(); iter = m.find(real_src)) {
45 real_src = iter->second;
46 }
47 *old_record = real_src;
48 }
49
AddAttrData(const std::string & name,ValueNode * src,ValueNode * new_attr)50 void SideEffectData::AddAttrData(const std::string &name, ValueNode *src, ValueNode *new_attr) {
51 auto &map = attr_cache_.modified_attrs_[src];
52 map[name] = new_attr;
53 }
54
AddGlobalData(const std::string & module_name,const std::string & name,ValueNode * node)55 void SideEffectData::AddGlobalData(const std::string &module_name, const std::string &name, ValueNode *node) {
56 auto &dict = global_cache_.modified_globals_[module_name];
57 dict[name] = node;
58 }
59
ClearCache()60 void SideEffectData::ClearCache() {
61 attr_cache_.modified_attrs_.clear();
62 global_cache_.modified_globals_.clear();
63 }
64
LoadAttr(ValueNode * src,const std::string & name) const65 SideEffect::CacheResult SideEffect::LoadAttr(ValueNode *src, const std::string &name) const {
66 const auto &cache = data_->attr_cache().modified_attrs_;
67 if (cache.empty()) {
68 return {}; // no attribute modified
69 }
70
71 CacheResult result{};
72 auto Find = [&cache, &name, &result](ValueNode *src_node) {
73 auto map_iter = cache.find(src_node);
74 if (map_iter == cache.end()) {
75 return false; // not find attr map of this node
76 }
77 auto attr_iter = map_iter->second.empty() ? map_iter->second.end() : map_iter->second.find(name);
78 if (attr_iter == map_iter->second.end()) {
79 return false; // not find attr of this node
80 }
81 result = {attr_iter->second, attr_iter->second == nullptr};
82 return true;
83 };
84
85 PyObject *src_object = src->GetVobj() ? src->GetVobj()->GetPyObject().ptr() : nullptr;
86 if (src_object == nullptr) {
87 Find(src);
88 } else {
89 auto iter = data()->id_map().find(src_object);
90 MS_EXCEPTION_IF_CHECK_FAIL(iter != data()->id_map().end(), "can't find the node of object");
91 (void)std::find_if(iter->second.begin(), iter->second.end(), Find);
92 }
93 return result;
94 }
95
LoadGlobal(const std::string & module_name,const std::string & name) const96 SideEffect::CacheResult SideEffect::LoadGlobal(const std::string &module_name, const std::string &name) const {
97 const auto &cache = data_->global_cache().modified_globals_;
98 if (cache.empty()) {
99 return {}; // no global modified
100 }
101 auto m_iter = cache.find(module_name);
102 if (m_iter == cache.end()) {
103 return {}; // this module global not modified
104 }
105 auto value_iter = m_iter->second.find(name);
106 if (value_iter == m_iter->second.end()) {
107 return {}; // this name not modified
108 }
109 return {value_iter->second, value_iter->second == nullptr};
110 }
111
GetRequiredNodes() const112 const std::set<ValueNode *> &SideEffect::GetRequiredNodes() const { return keep_alive_; }
113
Record(ValueNode * node,Type type)114 bool SideEffect::Record(ValueNode *node, Type type) {
115 int opcode = node->GetOpcode();
116 if (opcode == STORE_ATTR || opcode == DELETE_ATTR) {
117 ValueNode *src_node = opcode == DELETE_ATTR ? node->input(0) : node->input(1);
118 ValueNode *attr_node = opcode == DELETE_ATTR ? nullptr : node->input(0);
119 data_->AddAttrData(node->GetName(), src_node, attr_node);
120 type = kSetAttr;
121 } else if (opcode == STORE_GLOBAL || opcode == DELETE_GLOBAL) {
122 MS_EXCEPTION_IF_NULL(node->GetGraph());
123 ValueNode *new_value = opcode == DELETE_GLOBAL ? nullptr : node->input(0);
124 std::string module_name = node->GetGraph()->GetModuleName();
125 if (module_name.empty()) {
126 return false; // empty module name, unknown global source
127 }
128 data_->AddGlobalData(module_name, node->GetName(), new_value);
129 type = kSetGlobal;
130 } else if (opcode == STORE_SUBSCR || opcode == DELETE_SUBSCR) {
131 type = kDefault;
132 } else if (Opcode(opcode).IsCall() && RecordFuncCall(node, type)) {
133 } else {
134 MS_LOG(INFO) << "unimplemented side-effect " << node->ToString();
135 return false;
136 }
137 size_t order_index = nodes_.size();
138 nodes_[node] = {type, order_index};
139 AddKeepAlive(GetKeepAlive(node, type));
140 return true;
141 }
142
RecordFuncCall(ValueNode * node,Type type)143 bool SideEffect::RecordFuncCall(ValueNode *node, Type type) {
144 if (type == kDefault) {
145 return true;
146 }
147 if (type == kSetAttr) { // only builtin-function getattr
148 size_t index = 1;
149 ValueNode *src_node = node->input(index++);
150 py::object name = node->input(index++)->GetVobj()->GetPyObject();
151 ValueNode *attr_node = node->getInputs().size() == index ? nullptr : node->input(index);
152 data_->AddAttrData(PyUnicode_AsUTF8(name.ptr()), src_node, attr_node);
153 return true;
154 }
155 // check list.append, dict.pop, list.__setitem__, dict.__setitem__
156 if (GetSelfFromListAppendCall(node) != nullptr) {
157 return true;
158 }
159 return false;
160 }
161
GetKeepAlive(ValueNode * node,Type type) const162 std::vector<ValueNode *> SideEffect::GetKeepAlive(ValueNode *node, Type type) const {
163 int opcode = node->GetOpcode();
164 std::vector<ValueNode *> alive = node->getInputs();
165 if (Opcode(opcode).IsCall() && type >= kListSetItem) {
166 alive[0] = GetSelfFromListAppendCall(node); // replace function
167 }
168 auto erase_iter = alive.begin();
169 for (auto iter = erase_iter; iter != alive.end(); ++iter) {
170 if (!IsNonLocalValue(*iter)) {
171 *erase_iter = GetSource(*iter);
172 ++erase_iter;
173 }
174 }
175 alive.erase(erase_iter, alive.end());
176 return alive;
177 }
178
ResetRecord(const std::set<ValueNode * > & nodes_set)179 void SideEffect::ResetRecord(const std::set<ValueNode *> &nodes_set) {
180 // remove if record not find in final node set
181 auto size = nodes_.size();
182 for (auto iter = nodes_.begin(), end = nodes_.end(); iter != end;) {
183 iter = nodes_set.find(iter->first) == nodes_set.end() ? nodes_.erase(iter) : (++iter);
184 }
185 if (size == nodes_.size()) {
186 return;
187 }
188 // sort
189 std::map<int, std::pair<ValueNode *, Type>> ordered_nodes;
190 for (const auto &i : nodes_) {
191 ordered_nodes[i.second.order_] = {i.first, i.second.type_};
192 }
193 // rollback
194 keep_alive_.clear();
195 nodes_.clear();
196 data_->ClearCache();
197 for (const auto &i : ordered_nodes) {
198 this->Record(i.second.first, i.second.second);
199 }
200 }
201
Restore(CodeGenerator * cg) const202 void SideEffect::Restore(CodeGenerator *cg) const {
203 if (nodes_.empty()) {
204 return;
205 }
206 std::vector<std::pair<ValueNode *, Type>> ordered_nodes(nodes_.size());
207 for (const auto &i : nodes_) {
208 ordered_nodes[i.second.order_] = {i.first, i.second.type_};
209 }
210 for (const auto &pair : ordered_nodes) {
211 if (pair.second != SideEffect::kSetAttr && pair.second != SideEffect::kSetGlobal) {
212 RestoreEntry(cg, pair.first, pair.second);
213 }
214 }
215 RestoreAttrs(cg);
216 RestoreGlobal(cg);
217 }
218
RestoreEntry(CodeGenerator * cg,ValueNode * node,Type type) const219 void SideEffect::RestoreEntry(CodeGenerator *cg, ValueNode *node, Type type) const {
220 if (type != kDefault) {
221 RestoreSpecializeEntry(cg, node, type);
222 return;
223 }
224 int opcode = node->GetOpcode();
225 int oparg = node->GetOparg();
226 for (const auto &i : node->getInputs()) {
227 cg->LoadValue(GetSource(i));
228 }
229 cg->NewInstr(opcode, oparg);
230 if (Opcode(node->GetOpcode()).IsCall()) {
231 cg->NewInstr(POP_TOP);
232 }
233 }
234
MakeAttrModify(CodeGenerator * cg,const std::string & name,ValueNode * src_node,ValueNode * value)235 static void MakeAttrModify(CodeGenerator *cg, const std::string &name, ValueNode *src_node, ValueNode *value) {
236 auto instr = std::make_unique<Instr>(STORE_ATTR, 0, name);
237 if (value != nullptr) {
238 cg->LoadValue(value);
239 cg->LoadValue(src_node);
240 } else {
241 cg->LoadValue(src_node);
242 instr->set_op(DELETE_ATTR);
243 }
244 cg->AddInstr(std::move(instr));
245 }
246
MakeModuleAttrModify(CodeGenerator * cg,const std::string & name,const py::object & mod,ValueNode * value)247 static void MakeModuleAttrModify(CodeGenerator *cg, const std::string &name, const py::object &mod, ValueNode *value) {
248 auto instr = std::make_unique<Instr>(STORE_ATTR, 0, name);
249 if (value != nullptr) {
250 cg->LoadValue(value);
251 cg->LoadConst(mod);
252 } else {
253 cg->LoadConst(mod);
254 instr->set_op(DELETE_ATTR);
255 }
256 cg->AddInstr(std::move(instr));
257 }
258
RestoreAttrs(CodeGenerator * cg) const259 void SideEffect::RestoreAttrs(CodeGenerator *cg) const {
260 if (data()->attr_cache().modified_attrs_.empty()) {
261 return;
262 }
263 for (const auto &map : data()->attr_cache().modified_attrs_) {
264 const auto &src_node = GetSource(map.first);
265 for (const auto &pair : map.second) {
266 MakeAttrModify(cg, pair.first, src_node, GetSource(pair.second));
267 }
268 }
269 }
270
RestoreGlobal(CodeGenerator * cg) const271 void SideEffect::RestoreGlobal(CodeGenerator *cg) const {
272 if (data()->global_cache().modified_globals_.empty()) {
273 return;
274 }
275 PyObject *tmp = PyDict_GetItemString(cg->GetGlobals().ptr(), "__name__");
276 const char *cur_module_name = tmp == nullptr ? "" : PyUnicode_AsUTF8(tmp);
277
278 for (const auto &map : data()->global_cache().modified_globals_) {
279 const auto &module_name = map.first;
280 if (module_name != cur_module_name) {
281 py::object module_object = py::reinterpret_steal<py::object>(PyImport_ImportModule(module_name.c_str()));
282 for (const auto &pair : map.second) {
283 MakeModuleAttrModify(cg, pair.first, module_object, GetSource(pair.second));
284 }
285 continue;
286 }
287 for (const auto &pair : map.second) {
288 auto instr = std::make_unique<Instr>(STORE_GLOBAL, 0, pair.first);
289 if (pair.second != nullptr) {
290 cg->LoadValue(GetSource(pair.second));
291 } else {
292 instr->set_op(DELETE_GLOBAL);
293 }
294 cg->AddInstr(std::move(instr));
295 }
296 }
297 }
298
GetSource(ValueNode * src_node) const299 ValueNode *SideEffect::GetSource(ValueNode *src_node) const {
300 const auto &map = data()->modified_and_replaced_map();
301 if (map.empty() || src_node == nullptr) {
302 return src_node;
303 }
304 auto iter = map.find(src_node);
305 return iter != map.end() ? iter->second : src_node;
306 }
307
RestoreSpecializeEntry(CodeGenerator * cg,ValueNode * node,Type type) const308 void SideEffect::RestoreSpecializeEntry(CodeGenerator *cg, ValueNode *node, Type type) const {
309 MS_EXCEPTION_IF_CHECK_FAIL(type >= kListSetItem && type <= kDictPop, "not implemented function");
310 constexpr const char *name_map[] = {"__setitem__", "__setitem__", "append", "pop"};
311 static_assert(kDictPop - kListSetItem + 1 == sizeof(name_map) / sizeof(name_map[0]));
312 std::string method_name = name_map[type - kListSetItem];
313
314 bool is_method_descriptor = false;
315 auto self = GetSelfFromListAppendCall(node, &is_method_descriptor);
316 cg->LoadValue(GetSource(self));
317 cg->AddInstr(std::make_unique<Instr>(LOAD_ATTR, 0, method_name));
318 for (size_t i = 1 + is_method_descriptor; i < node->getInputs().size(); ++i) {
319 cg->LoadValue(GetSource(node->input(i)));
320 }
321 cg->NewInstr(node->GetOpcode(), node->getInputs().size() - 1);
322 cg->NewInstr(POP_TOP);
323 }
324
Optimize(const std::vector<ValueNode * > & alive_locals)325 void SideEffect::Optimize(const std::vector<ValueNode *> &alive_locals) {
326 /**
327 * check data_.unique(), validate record is all in final nodes set......
328 */
329 // liveness analysis, remove dead local side-effect
330 // not implement
331 // merge dict, list modify operations
332 // not implement
333 }
334
335 } // namespace pijit
336 } // namespace mindspore
337