• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_capture/side_effect.h"
17 #include <algorithm>
18 #include <utility>
19 #include "pipeline/jit/pi/graph_capture/code_generator.h"
20 #include "pipeline/jit/pi/graph_capture/graph.h"
21 
22 namespace mindspore {
23 namespace pijit {
24 
GetSelfFromListAppendCall(ValueNode * call_node,bool * is_method_descriptor)25 ValueNode *GetSelfFromListAppendCall(ValueNode *call_node, bool *is_method_descriptor) {
26   ValueNode *method_node = call_node->input(0);
27   PyObject *method_object = method_node->GetVobj()->GetPyObject().ptr();
28   ValueNode *self = nullptr;
29   if (Py_IS_TYPE(method_object, &PyMethodDescr_Type)) {
30     self = call_node->input(1);
31   } else if (method_node->GetOpcode() == LOAD_ATTR) {
32     self = method_node->input(0);
33   }
34   if (is_method_descriptor != nullptr) {
35     *is_method_descriptor = Py_IS_TYPE(method_object, &PyMethodDescr_Type);
36   }
37   return self;
38 }
39 
RecordModifiedAndReplacedNode(ValueNode * old_node,ValueNode * new_node)40 void SideEffectData::RecordModifiedAndReplacedNode(ValueNode *old_node, ValueNode *new_node) {
41   ValueNode **old_record = &modified_and_replaced_map_[new_node];
42   ValueNode *real_src = old_node;
43   const auto &m = modified_and_replaced_map_;
44   for (auto iter = m.find(real_src); iter != m.end(); iter = m.find(real_src)) {
45     real_src = iter->second;
46   }
47   *old_record = real_src;
48 }
49 
AddAttrData(const std::string & name,ValueNode * src,ValueNode * new_attr)50 void SideEffectData::AddAttrData(const std::string &name, ValueNode *src, ValueNode *new_attr) {
51   auto &map = attr_cache_.modified_attrs_[src];
52   map[name] = new_attr;
53 }
54 
AddGlobalData(const std::string & module_name,const std::string & name,ValueNode * node)55 void SideEffectData::AddGlobalData(const std::string &module_name, const std::string &name, ValueNode *node) {
56   auto &dict = global_cache_.modified_globals_[module_name];
57   dict[name] = node;
58 }
59 
ClearCache()60 void SideEffectData::ClearCache() {
61   attr_cache_.modified_attrs_.clear();
62   global_cache_.modified_globals_.clear();
63 }
64 
LoadAttr(ValueNode * src,const std::string & name) const65 SideEffect::CacheResult SideEffect::LoadAttr(ValueNode *src, const std::string &name) const {
66   const auto &cache = data_->attr_cache().modified_attrs_;
67   if (cache.empty()) {
68     return {};  // no attribute modified
69   }
70 
71   CacheResult result{};
72   auto Find = [&cache, &name, &result](ValueNode *src_node) {
73     auto map_iter = cache.find(src_node);
74     if (map_iter == cache.end()) {
75       return false;  // not find attr map of this node
76     }
77     auto attr_iter = map_iter->second.empty() ? map_iter->second.end() : map_iter->second.find(name);
78     if (attr_iter == map_iter->second.end()) {
79       return false;  // not find attr of this node
80     }
81     result = {attr_iter->second, attr_iter->second == nullptr};
82     return true;
83   };
84 
85   PyObject *src_object = src->GetVobj() ? src->GetVobj()->GetPyObject().ptr() : nullptr;
86   if (src_object == nullptr) {
87     Find(src);
88   } else {
89     auto iter = data()->id_map().find(src_object);
90     MS_EXCEPTION_IF_CHECK_FAIL(iter != data()->id_map().end(), "can't find the node of object");
91     (void)std::find_if(iter->second.begin(), iter->second.end(), Find);
92   }
93   return result;
94 }
95 
LoadGlobal(const std::string & module_name,const std::string & name) const96 SideEffect::CacheResult SideEffect::LoadGlobal(const std::string &module_name, const std::string &name) const {
97   const auto &cache = data_->global_cache().modified_globals_;
98   if (cache.empty()) {
99     return {};  // no global modified
100   }
101   auto m_iter = cache.find(module_name);
102   if (m_iter == cache.end()) {
103     return {};  // this module global not modified
104   }
105   auto value_iter = m_iter->second.find(name);
106   if (value_iter == m_iter->second.end()) {
107     return {};  // this name not modified
108   }
109   return {value_iter->second, value_iter->second == nullptr};
110 }
111 
GetRequiredNodes() const112 const std::set<ValueNode *> &SideEffect::GetRequiredNodes() const { return keep_alive_; }
113 
Record(ValueNode * node,Type type)114 bool SideEffect::Record(ValueNode *node, Type type) {
115   int opcode = node->GetOpcode();
116   if (opcode == STORE_ATTR || opcode == DELETE_ATTR) {
117     ValueNode *src_node = opcode == DELETE_ATTR ? node->input(0) : node->input(1);
118     ValueNode *attr_node = opcode == DELETE_ATTR ? nullptr : node->input(0);
119     data_->AddAttrData(node->GetName(), src_node, attr_node);
120     type = kSetAttr;
121   } else if (opcode == STORE_GLOBAL || opcode == DELETE_GLOBAL) {
122     MS_EXCEPTION_IF_NULL(node->GetGraph());
123     ValueNode *new_value = opcode == DELETE_GLOBAL ? nullptr : node->input(0);
124     std::string module_name = node->GetGraph()->GetModuleName();
125     if (module_name.empty()) {
126       return false;  // empty module name, unknown global source
127     }
128     data_->AddGlobalData(module_name, node->GetName(), new_value);
129     type = kSetGlobal;
130   } else if (opcode == STORE_SUBSCR || opcode == DELETE_SUBSCR) {
131     type = kDefault;
132   } else if (Opcode(opcode).IsCall() && RecordFuncCall(node, type)) {
133   } else {
134     MS_LOG(INFO) << "unimplemented side-effect " << node->ToString();
135     return false;
136   }
137   size_t order_index = nodes_.size();
138   nodes_[node] = {type, order_index};
139   AddKeepAlive(GetKeepAlive(node, type));
140   return true;
141 }
142 
RecordFuncCall(ValueNode * node,Type type)143 bool SideEffect::RecordFuncCall(ValueNode *node, Type type) {
144   if (type == kDefault) {
145     return true;
146   }
147   if (type == kSetAttr) {  // only builtin-function getattr
148     size_t index = 1;
149     ValueNode *src_node = node->input(index++);
150     py::object name = node->input(index++)->GetVobj()->GetPyObject();
151     ValueNode *attr_node = node->getInputs().size() == index ? nullptr : node->input(index);
152     data_->AddAttrData(PyUnicode_AsUTF8(name.ptr()), src_node, attr_node);
153     return true;
154   }
155   // check list.append, dict.pop, list.__setitem__, dict.__setitem__
156   if (GetSelfFromListAppendCall(node) != nullptr) {
157     return true;
158   }
159   return false;
160 }
161 
GetKeepAlive(ValueNode * node,Type type) const162 std::vector<ValueNode *> SideEffect::GetKeepAlive(ValueNode *node, Type type) const {
163   int opcode = node->GetOpcode();
164   std::vector<ValueNode *> alive = node->getInputs();
165   if (Opcode(opcode).IsCall() && type >= kListSetItem) {
166     alive[0] = GetSelfFromListAppendCall(node);  // replace function
167   }
168   auto erase_iter = alive.begin();
169   for (auto iter = erase_iter; iter != alive.end(); ++iter) {
170     if (!IsNonLocalValue(*iter)) {
171       *erase_iter = GetSource(*iter);
172       ++erase_iter;
173     }
174   }
175   alive.erase(erase_iter, alive.end());
176   return alive;
177 }
178 
ResetRecord(const std::set<ValueNode * > & nodes_set)179 void SideEffect::ResetRecord(const std::set<ValueNode *> &nodes_set) {
180   // remove if record not find in final node set
181   auto size = nodes_.size();
182   for (auto iter = nodes_.begin(), end = nodes_.end(); iter != end;) {
183     iter = nodes_set.find(iter->first) == nodes_set.end() ? nodes_.erase(iter) : (++iter);
184   }
185   if (size == nodes_.size()) {
186     return;
187   }
188   // sort
189   std::map<int, std::pair<ValueNode *, Type>> ordered_nodes;
190   for (const auto &i : nodes_) {
191     ordered_nodes[i.second.order_] = {i.first, i.second.type_};
192   }
193   // rollback
194   keep_alive_.clear();
195   nodes_.clear();
196   data_->ClearCache();
197   for (const auto &i : ordered_nodes) {
198     this->Record(i.second.first, i.second.second);
199   }
200 }
201 
Restore(CodeGenerator * cg) const202 void SideEffect::Restore(CodeGenerator *cg) const {
203   if (nodes_.empty()) {
204     return;
205   }
206   std::vector<std::pair<ValueNode *, Type>> ordered_nodes(nodes_.size());
207   for (const auto &i : nodes_) {
208     ordered_nodes[i.second.order_] = {i.first, i.second.type_};
209   }
210   for (const auto &pair : ordered_nodes) {
211     if (pair.second != SideEffect::kSetAttr && pair.second != SideEffect::kSetGlobal) {
212       RestoreEntry(cg, pair.first, pair.second);
213     }
214   }
215   RestoreAttrs(cg);
216   RestoreGlobal(cg);
217 }
218 
RestoreEntry(CodeGenerator * cg,ValueNode * node,Type type) const219 void SideEffect::RestoreEntry(CodeGenerator *cg, ValueNode *node, Type type) const {
220   if (type != kDefault) {
221     RestoreSpecializeEntry(cg, node, type);
222     return;
223   }
224   int opcode = node->GetOpcode();
225   int oparg = node->GetOparg();
226   for (const auto &i : node->getInputs()) {
227     cg->LoadValue(GetSource(i));
228   }
229   cg->NewInstr(opcode, oparg);
230   if (Opcode(node->GetOpcode()).IsCall()) {
231     cg->NewInstr(POP_TOP);
232   }
233 }
234 
MakeAttrModify(CodeGenerator * cg,const std::string & name,ValueNode * src_node,ValueNode * value)235 static void MakeAttrModify(CodeGenerator *cg, const std::string &name, ValueNode *src_node, ValueNode *value) {
236   auto instr = std::make_unique<Instr>(STORE_ATTR, 0, name);
237   if (value != nullptr) {
238     cg->LoadValue(value);
239     cg->LoadValue(src_node);
240   } else {
241     cg->LoadValue(src_node);
242     instr->set_op(DELETE_ATTR);
243   }
244   cg->AddInstr(std::move(instr));
245 }
246 
MakeModuleAttrModify(CodeGenerator * cg,const std::string & name,const py::object & mod,ValueNode * value)247 static void MakeModuleAttrModify(CodeGenerator *cg, const std::string &name, const py::object &mod, ValueNode *value) {
248   auto instr = std::make_unique<Instr>(STORE_ATTR, 0, name);
249   if (value != nullptr) {
250     cg->LoadValue(value);
251     cg->LoadConst(mod);
252   } else {
253     cg->LoadConst(mod);
254     instr->set_op(DELETE_ATTR);
255   }
256   cg->AddInstr(std::move(instr));
257 }
258 
RestoreAttrs(CodeGenerator * cg) const259 void SideEffect::RestoreAttrs(CodeGenerator *cg) const {
260   if (data()->attr_cache().modified_attrs_.empty()) {
261     return;
262   }
263   for (const auto &map : data()->attr_cache().modified_attrs_) {
264     const auto &src_node = GetSource(map.first);
265     for (const auto &pair : map.second) {
266       MakeAttrModify(cg, pair.first, src_node, GetSource(pair.second));
267     }
268   }
269 }
270 
RestoreGlobal(CodeGenerator * cg) const271 void SideEffect::RestoreGlobal(CodeGenerator *cg) const {
272   if (data()->global_cache().modified_globals_.empty()) {
273     return;
274   }
275   PyObject *tmp = PyDict_GetItemString(cg->GetGlobals().ptr(), "__name__");
276   const char *cur_module_name = tmp == nullptr ? "" : PyUnicode_AsUTF8(tmp);
277 
278   for (const auto &map : data()->global_cache().modified_globals_) {
279     const auto &module_name = map.first;
280     if (module_name != cur_module_name) {
281       py::object module_object = py::reinterpret_steal<py::object>(PyImport_ImportModule(module_name.c_str()));
282       for (const auto &pair : map.second) {
283         MakeModuleAttrModify(cg, pair.first, module_object, GetSource(pair.second));
284       }
285       continue;
286     }
287     for (const auto &pair : map.second) {
288       auto instr = std::make_unique<Instr>(STORE_GLOBAL, 0, pair.first);
289       if (pair.second != nullptr) {
290         cg->LoadValue(GetSource(pair.second));
291       } else {
292         instr->set_op(DELETE_GLOBAL);
293       }
294       cg->AddInstr(std::move(instr));
295     }
296   }
297 }
298 
GetSource(ValueNode * src_node) const299 ValueNode *SideEffect::GetSource(ValueNode *src_node) const {
300   const auto &map = data()->modified_and_replaced_map();
301   if (map.empty() || src_node == nullptr) {
302     return src_node;
303   }
304   auto iter = map.find(src_node);
305   return iter != map.end() ? iter->second : src_node;
306 }
307 
RestoreSpecializeEntry(CodeGenerator * cg,ValueNode * node,Type type) const308 void SideEffect::RestoreSpecializeEntry(CodeGenerator *cg, ValueNode *node, Type type) const {
309   MS_EXCEPTION_IF_CHECK_FAIL(type >= kListSetItem && type <= kDictPop, "not implemented function");
310   constexpr const char *name_map[] = {"__setitem__", "__setitem__", "append", "pop"};
311   static_assert(kDictPop - kListSetItem + 1 == sizeof(name_map) / sizeof(name_map[0]));
312   std::string method_name = name_map[type - kListSetItem];
313 
314   bool is_method_descriptor = false;
315   auto self = GetSelfFromListAppendCall(node, &is_method_descriptor);
316   cg->LoadValue(GetSource(self));
317   cg->AddInstr(std::make_unique<Instr>(LOAD_ATTR, 0, method_name));
318   for (size_t i = 1 + is_method_descriptor; i < node->getInputs().size(); ++i) {
319     cg->LoadValue(GetSource(node->input(i)));
320   }
321   cg->NewInstr(node->GetOpcode(), node->getInputs().size() - 1);
322   cg->NewInstr(POP_TOP);
323 }
324 
Optimize(const std::vector<ValueNode * > & alive_locals)325 void SideEffect::Optimize(const std::vector<ValueNode *> &alive_locals) {
326   /**
327    * check data_.unique(), validate record is all in final nodes set......
328    */
329   // liveness analysis, remove dead local side-effect
330   // not implement
331   // merge dict, list modify operations
332   // not implement
333 }
334 
335 }  // namespace pijit
336 }  // namespace mindspore
337