• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PIPELINE_GRAPH_JIT_GRAPH_CAPTURE_SIDE_EFFECT_H_
18 #define MINDSPORE_CCSRC_PIPELINE_GRAPH_JIT_GRAPH_CAPTURE_SIDE_EFFECT_H_
19 
20 #include <memory>
21 #include <vector>
22 #include <map>
23 #include <set>
24 #include <string>
25 #include "pipeline/jit/pi/graph_capture/node.h"
26 
27 namespace mindspore {
28 namespace pijit {
29 
30 class CodeGenerator;
31 
32 // an unique data in the whole compilation
33 class SideEffectData {
34  public:
35   struct AttrCache {
36     // a map of the modified object and it's modified attrs
37     using AttrMap = std::map<std::string, ValueNode *>;
38     std::map<ValueNode *, AttrMap> modified_attrs_;
39   };
40 
41   struct GlobalCache {
42     // a map of module and modified global dict
43     using NameMap = std::map<std::string, ValueNode *>;
44     std::map<std::string, NameMap> modified_globals_;
45   };
46 
attr_cache()47   const auto &attr_cache() const { return attr_cache_; }
global_cache()48   const auto &global_cache() const { return global_cache_; }
id_map()49   const auto &id_map() const { return id_map_; }
modified_and_replaced_map()50   const auto &modified_and_replaced_map() const { return modified_and_replaced_map_; }
51 
52   // track object and nodes
Track(PyObject * ptr,ValueNode * node)53   void Track(PyObject *ptr, ValueNode *node) { (ptr ? (void)id_map_[ptr].insert(node) : (void)0); }
UnTrack(PyObject * ptr,ValueNode * node)54   void UnTrack(PyObject *ptr, ValueNode *node) { (ptr ? (void)id_map_[ptr].erase(node) : (void)0); }
55 
56   // record replaced node
57   void RecordModifiedAndReplacedNode(ValueNode *src_node, ValueNode *new_node);
58 
59   // merge attr modify operations
60   void AddAttrData(const std::string &name, ValueNode *src, ValueNode *new_attr);
61 
62   // merge global modify operations
63   void AddGlobalData(const std::string &module_name, const std::string &name, ValueNode *value);
64 
65   void ClearCache();
66 
67  private:
68   // an unique map that record python object and nodes in the whole compilation
69   // used to resolve object consistency
70   std::map<PyObject *, std::set<ValueNode *>> id_map_;
71 
72   // an unique map of new value(key) and old_value(value)
73   std::map<ValueNode *, ValueNode *> modified_and_replaced_map_;
74 
75   // optimization cache, record modified object
76   // if record is reset, clean cache
77   AttrCache attr_cache_;
78   GlobalCache global_cache_;
79 };
80 
81 class SideEffect {
82  public:
83   enum Type {
84     kDefault,
85     kSetAttr,
86     kSetGlobal,
87     kListSetItem,
88     kDictSetItem,
89     kListAppend,
90     kDictPop,
91   };
92 
93   struct CacheResult {
94     ValueNode *cache_value_;
95     bool is_deleted_value_;
96   };
97 
98   // find attribute from id_map and attr cache
99   CacheResult LoadAttr(ValueNode *src, const std::string &name) const;
100 
101   // find global from global cache
102   CacheResult LoadGlobal(const std::string &module_name, const std::string &name) const;
103 
104  public:
105   SideEffect() = default;
106 
data()107   const auto &data() const { return data_; }
set_data(const std::shared_ptr<SideEffectData> & data)108   void set_data(const std::shared_ptr<SideEffectData> &data) { data_ = data; }
109 
110   // check the node is a side-effect record
IsRecord(ValueNode * node)111   bool IsRecord(ValueNode *node) const { return nodes_.empty() ? false : nodes_.find(node) != nodes_.end(); }
112 
113   // check record is empty
IsEmpty()114   bool IsEmpty() const { return nodes_.empty(); }
115 
116   // return false if unsupported the side-effect
117   bool Record(ValueNode *side_effect_node, Type type = Type::kDefault);
118 
119   // generate the code to restore side-effect
120   void Restore(CodeGenerator *cg) const;
121 
122   // reset the record if record not find in final nodes set
123   void ResetRecord(const std::set<ValueNode *> &traced_nodes);
124 
125   // return the original node(source) if it's replaced, else return the node
126   ValueNode *GetSource(ValueNode *node) const;
127 
128   // optimize the side-effect data, remove modify operations of dead local variable
129   void Optimize(const std::vector<ValueNode *> &alive_locals);
130 
131   // return the side-effect handler required nodes
132   const std::set<ValueNode *> &GetRequiredNodes() const;
133 
134  private:
135   // add nodes to required
AddKeepAlive(const std::vector<ValueNode * > & inputs)136   void AddKeepAlive(const std::vector<ValueNode *> &inputs) { keep_alive_.insert(inputs.begin(), inputs.end()); }
137 
138   // get required node of the side-effect node
139   std::vector<ValueNode *> GetKeepAlive(ValueNode *node, Type type) const;
140 
141   // if side-effect is function call, check it's supported
142   bool RecordFuncCall(ValueNode *node, Type type);
143 
144   // restore a side-effect node
145   void RestoreEntry(CodeGenerator *cg, ValueNode *node, Type type) const;
146 
147   // restore attribute
148   void RestoreAttrs(CodeGenerator *cg) const;
149 
150   // restore global
151   void RestoreGlobal(CodeGenerator *cg) const;
152 
153   // restore list, dict, or other specialized object function call
154   void RestoreSpecializeEntry(CodeGenerator *cg, ValueNode *node, Type type) const;
155 
156   struct Entry {
157     Type type_;
158     size_t order_;
159   };
160 
161   // shared from other side-effect recorder
162   std::shared_ptr<SideEffectData> data_;
163 
164   // record operations, check side-effect order
165   std::map<ValueNode *, Entry> nodes_;
166 
167   // side-effect handler required nodes
168   std::set<ValueNode *> keep_alive_;
169 };
170 
171 // return the self node, if return nullptr, unsupported to handle side-effect
172 ValueNode *GetSelfFromListAppendCall(ValueNode *call_node, bool *is_method_descriptor = nullptr);
173 
174 }  // namespace pijit
175 }  // namespace mindspore
176 
177 #endif  // MINDSPORE_CCSRC_PIPELINE_GRAPH_JIT_GRAPH_CAPTURE_SIDE_EFFECT_H_
178