1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PIPELINE_GRAPH_JIT_GRAPH_CAPTURE_SIDE_EFFECT_H_ 18 #define MINDSPORE_CCSRC_PIPELINE_GRAPH_JIT_GRAPH_CAPTURE_SIDE_EFFECT_H_ 19 20 #include <memory> 21 #include <vector> 22 #include <map> 23 #include <set> 24 #include <string> 25 #include "pipeline/jit/pi/graph_capture/node.h" 26 27 namespace mindspore { 28 namespace pijit { 29 30 class CodeGenerator; 31 32 // an unique data in the whole compilation 33 class SideEffectData { 34 public: 35 struct AttrCache { 36 // a map of the modified object and it's modified attrs 37 using AttrMap = std::map<std::string, ValueNode *>; 38 std::map<ValueNode *, AttrMap> modified_attrs_; 39 }; 40 41 struct GlobalCache { 42 // a map of module and modified global dict 43 using NameMap = std::map<std::string, ValueNode *>; 44 std::map<std::string, NameMap> modified_globals_; 45 }; 46 attr_cache()47 const auto &attr_cache() const { return attr_cache_; } global_cache()48 const auto &global_cache() const { return global_cache_; } id_map()49 const auto &id_map() const { return id_map_; } modified_and_replaced_map()50 const auto &modified_and_replaced_map() const { return modified_and_replaced_map_; } 51 52 // track object and nodes Track(PyObject * ptr,ValueNode * node)53 void Track(PyObject *ptr, ValueNode *node) { (ptr ? (void)id_map_[ptr].insert(node) : (void)0); } UnTrack(PyObject * ptr,ValueNode * node)54 void UnTrack(PyObject *ptr, ValueNode *node) { (ptr ? (void)id_map_[ptr].erase(node) : (void)0); } 55 56 // record replaced node 57 void RecordModifiedAndReplacedNode(ValueNode *src_node, ValueNode *new_node); 58 59 // merge attr modify operations 60 void AddAttrData(const std::string &name, ValueNode *src, ValueNode *new_attr); 61 62 // merge global modify operations 63 void AddGlobalData(const std::string &module_name, const std::string &name, ValueNode *value); 64 65 void ClearCache(); 66 67 private: 68 // an unique map that record python object and nodes in the whole compilation 69 // used to resolve object consistency 70 std::map<PyObject *, std::set<ValueNode *>> id_map_; 71 72 // an unique map of new value(key) and old_value(value) 73 std::map<ValueNode *, ValueNode *> modified_and_replaced_map_; 74 75 // optimization cache, record modified object 76 // if record is reset, clean cache 77 AttrCache attr_cache_; 78 GlobalCache global_cache_; 79 }; 80 81 class SideEffect { 82 public: 83 enum Type { 84 kDefault, 85 kSetAttr, 86 kSetGlobal, 87 kListSetItem, 88 kDictSetItem, 89 kListAppend, 90 kDictPop, 91 }; 92 93 struct CacheResult { 94 ValueNode *cache_value_; 95 bool is_deleted_value_; 96 }; 97 98 // find attribute from id_map and attr cache 99 CacheResult LoadAttr(ValueNode *src, const std::string &name) const; 100 101 // find global from global cache 102 CacheResult LoadGlobal(const std::string &module_name, const std::string &name) const; 103 104 public: 105 SideEffect() = default; 106 data()107 const auto &data() const { return data_; } set_data(const std::shared_ptr<SideEffectData> & data)108 void set_data(const std::shared_ptr<SideEffectData> &data) { data_ = data; } 109 110 // check the node is a side-effect record IsRecord(ValueNode * node)111 bool IsRecord(ValueNode *node) const { return nodes_.empty() ? false : nodes_.find(node) != nodes_.end(); } 112 113 // check record is empty IsEmpty()114 bool IsEmpty() const { return nodes_.empty(); } 115 116 // return false if unsupported the side-effect 117 bool Record(ValueNode *side_effect_node, Type type = Type::kDefault); 118 119 // generate the code to restore side-effect 120 void Restore(CodeGenerator *cg) const; 121 122 // reset the record if record not find in final nodes set 123 void ResetRecord(const std::set<ValueNode *> &traced_nodes); 124 125 // return the original node(source) if it's replaced, else return the node 126 ValueNode *GetSource(ValueNode *node) const; 127 128 // optimize the side-effect data, remove modify operations of dead local variable 129 void Optimize(const std::vector<ValueNode *> &alive_locals); 130 131 // return the side-effect handler required nodes 132 const std::set<ValueNode *> &GetRequiredNodes() const; 133 134 private: 135 // add nodes to required AddKeepAlive(const std::vector<ValueNode * > & inputs)136 void AddKeepAlive(const std::vector<ValueNode *> &inputs) { keep_alive_.insert(inputs.begin(), inputs.end()); } 137 138 // get required node of the side-effect node 139 std::vector<ValueNode *> GetKeepAlive(ValueNode *node, Type type) const; 140 141 // if side-effect is function call, check it's supported 142 bool RecordFuncCall(ValueNode *node, Type type); 143 144 // restore a side-effect node 145 void RestoreEntry(CodeGenerator *cg, ValueNode *node, Type type) const; 146 147 // restore attribute 148 void RestoreAttrs(CodeGenerator *cg) const; 149 150 // restore global 151 void RestoreGlobal(CodeGenerator *cg) const; 152 153 // restore list, dict, or other specialized object function call 154 void RestoreSpecializeEntry(CodeGenerator *cg, ValueNode *node, Type type) const; 155 156 struct Entry { 157 Type type_; 158 size_t order_; 159 }; 160 161 // shared from other side-effect recorder 162 std::shared_ptr<SideEffectData> data_; 163 164 // record operations, check side-effect order 165 std::map<ValueNode *, Entry> nodes_; 166 167 // side-effect handler required nodes 168 std::set<ValueNode *> keep_alive_; 169 }; 170 171 // return the self node, if return nullptr, unsupported to handle side-effect 172 ValueNode *GetSelfFromListAppendCall(ValueNode *call_node, bool *is_method_descriptor = nullptr); 173 174 } // namespace pijit 175 } // namespace mindspore 176 177 #endif // MINDSPORE_CCSRC_PIPELINE_GRAPH_JIT_GRAPH_CAPTURE_SIDE_EFFECT_H_ 178