1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CORE_IR_MANAGER_H_
20 #define MINDSPORE_CORE_IR_MANAGER_H_
21
22 #include <unordered_set>
23 #include <unordered_map>
24 #include <set>
25 #include <map>
26 #include <list>
27 #include <string>
28 #include <vector>
29 #include <utility>
30 #include <memory>
31 #include <functional>
32
33 #include "utils/any.h"
34 #include "utils/misc.h"
35 #include "utils/signal.h"
36 #include "utils/ordered_set.h"
37 #include "utils/ordered_map.h"
38 #include "ir/anf.h"
39 #include "ir/graph_utils.h"
40 #include "utils/hashing.h"
41 #include "base/base_ref.h"
42 #include "api/ir/func_graph_manager.h"
43
44 namespace mindspore {
namespace change {
// Only declared here; this header never needs more than a pointer to it.
struct ChangeCounter;

// Abstract unit of modification. FuncGraphTransaction stores these in its change list and
// FuncGraphManager::CommitChanges receives them as a vector of ChangePtr.
struct Change {
  // Virtual so that a ChangePtr can destroy the concrete change through this base type.
  virtual ~Change() = default;

  // Performs the change. `counter` is supplied by the caller; what it accumulates is decided by
  // ChangeCounter's definition, which is not part of this header.
  virtual void Apply(ChangeCounter *counter) = 0;
};

// Exclusive-ownership handle to a concrete Change.
using ChangePtr = std::unique_ptr<Change>;
}  // namespace change
53
54 class FuncGraphTransaction;
55 class FuncGraphManager;
56 using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
57
58 using AnfNodeIndexSet = api::AnfNodeIndexSet;
59 // NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
60 using NodeUsersMap = api::NodeUsersMap;
61 using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
62 using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
63
64 // manage the func graphs.
65 // if no manager exist, just create one and associate it to all func graphs; else reuse simply.
66 // func_graph, be managed graph
67 // manage: if true, created manager will be set in func_graph
68 // FuncGraphManagerPtr: return created manager
69 FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true);
70
71 FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage = true);
72
73 FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true);
74
75 struct Signals {
76 Signal<void()> InvalidateComputer;
77 };
78
79 using CNodeIndexPair = std::pair<AnfNodePtr, int>;
80 using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
81 using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
82
83 // analysis base class, graphs analysis which need dynamic compute by DepCollector in each read
84 class DepComputer {
85 public:
86 explicit DepComputer(const FuncGraphManager *manager);
~DepComputer()87 virtual ~DepComputer() { manager_ = nullptr; }
88
size()89 virtual size_t size() const { return 0; }
90
Reset()91 void Reset() {
92 ExtraReset();
93 validate_ = false;
94 func_graphs_validate_.clear();
95 }
96
OnInvalidateComputer()97 void OnInvalidateComputer() { Reset(); }
98
99 void Recompute();
100
101 void Recompute(const FuncGraphPtr &fg);
102
IsValidate()103 bool IsValidate() const { return validate_; }
104
IsValidate(const FuncGraphPtr & fg)105 bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }
106
107 protected:
108 // subclass can reset their own member;
ExtraReset()109 virtual void ExtraReset() {}
110 // subclass do the real compute
RealRecompute()111 virtual void RealRecompute() {}
RealRecompute(FuncGraphPtr)112 virtual void RealRecompute(FuncGraphPtr) {}
113
114 const FuncGraphManager *manager_;
115 bool validate_;
116 OrderedMap<FuncGraphPtr, bool> func_graphs_validate_;
117
118 private:
119 friend FuncGraphManager;
120 };
121
122 // graph g's all direct or proxy parents
123 class FuncGraphParentsTotalComputer final : public DepComputer {
124 public:
FuncGraphParentsTotalComputer(const FuncGraphManager * m)125 explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
126 ~FuncGraphParentsTotalComputer() override = default;
127
func_graph_parents_total_analysis()128 FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
129
size()130 size_t size() const override { return func_graph_parents_total_analysis_.size(); }
131
132 FuncGraphToFuncGraphSetMap func_graph_parents_total_analysis_;
133
134 protected:
ExtraReset()135 void ExtraReset() override { func_graph_parents_total_analysis_.clear(); }
136
137 void RealRecompute(FuncGraphPtr fg) override;
138
139 private:
140 FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, std::unordered_map<FuncGraphPtr, FuncGraphSetPtr> *seen_fgs);
141 };
142
143 using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
144
145 // graph's nearest parent in parents total
146 class ParentComputer final : public DepComputer {
147 public:
ParentComputer(const FuncGraphManager * m)148 explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {}
149 ~ParentComputer() override = default;
150
parent_analysis()151 FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; }
152
size()153 size_t size() const override { return parent_analysis_.size(); }
154
155 FuncGraphToFuncGraphMap parent_analysis_;
156
157 protected:
ExtraReset()158 void ExtraReset() override { parent_analysis_.clear(); }
159
160 void RealRecompute(FuncGraphPtr fg) override;
161 };
162
163 // graph's children graph except self
164 class ChildrenComputer final : public DepComputer {
165 public:
ChildrenComputer(const FuncGraphManager * m)166 explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {}
167 ~ChildrenComputer() override = default;
168
children_analysis()169 FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; }
170
size()171 size_t size() const override { return children_analysis_.size(); }
172
173 FuncGraphToFuncGraphSetMap children_analysis_;
174
175 protected:
ExtraReset()176 void ExtraReset() override { children_analysis_.clear(); }
177
178 void RealRecompute(FuncGraphPtr fg) override;
179 };
180
181 // graph's children graph include self
182 class ScopeComputer final : public DepComputer {
183 public:
ScopeComputer(const FuncGraphManager * m)184 explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {}
185 ~ScopeComputer() override = default;
186
scope_analysis()187 FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; }
188
size()189 size_t size() const override { return scope_analysis_.size(); }
190
191 FuncGraphToFuncGraphSetMap scope_analysis_;
192
193 protected:
ExtraReset()194 void ExtraReset() override { scope_analysis_.clear(); }
195
196 void RealRecompute(FuncGraphPtr fg) override;
197 };
198
199 using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;
200
201 class FVTotalComputer final : public DepComputer {
202 public:
FVTotalComputer(const FuncGraphManager * m)203 explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
204 ~FVTotalComputer() override = default;
205
fv_total_analysis()206 FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }
207
size()208 size_t size() const override { return fv_total_analysis_.size(); }
209
210 FVTotalMap fv_total_analysis_;
211
212 protected:
ExtraReset()213 void ExtraReset() override { fv_total_analysis_.clear(); }
214
215 void RealRecompute() override;
216 };
217
218 class FuncGraphsUsedTotalComputer final : public DepComputer {
219 public:
FuncGraphsUsedTotalComputer(const FuncGraphManager * m)220 explicit FuncGraphsUsedTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
221 ~FuncGraphsUsedTotalComputer() override = default;
222
func_graph_used_total_analysis()223 FuncGraphToFuncGraphSetMap &func_graph_used_total_analysis() { return func_graph_used_total_analysis_; }
224
size()225 size_t size() const override { return func_graph_used_total_analysis_.size(); }
226
227 FuncGraphToFuncGraphSetMap func_graph_used_total_analysis_;
228
229 protected:
ExtraReset()230 void ExtraReset() override { func_graph_used_total_analysis_.clear(); }
231
232 void RealRecompute(FuncGraphPtr fg) override;
233 };
234
235 using FuncGraphToBoolMap = OrderedMap<FuncGraphPtr, bool>;
236 using RecursiveMap = OrderedMap<FuncGraphPtr, std::shared_ptr<std::list<FuncGraphPtr>>>;
237
238 class RecursiveComputer final : public DepComputer {
239 public:
RecursiveComputer(const FuncGraphManager * m)240 explicit RecursiveComputer(const FuncGraphManager *m) : DepComputer(m) {}
241 ~RecursiveComputer() override = default;
242
recursive_map()243 RecursiveMap &recursive_map() { return recursive_map_; }
recursive_analysis()244 FuncGraphToBoolMap &recursive_analysis() { return recursive_analysis_; }
245
246 void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace);
247
size()248 size_t size() const override { return recursive_analysis_.size(); }
249
250 RecursiveMap recursive_map_;
251 FuncGraphToBoolMap recursive_analysis_;
252
253 protected:
ExtraReset()254 void ExtraReset() override {
255 recursive_analysis_.clear();
256 recursive_map_.clear();
257 }
258
259 void RealRecompute(FuncGraphPtr fg) override;
260 };
261
262 class FuncGraphJTotalComputer final : public DepComputer {
263 public:
FuncGraphJTotalComputer(const FuncGraphManager * m)264 explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
265 ~FuncGraphJTotalComputer() override = default;
266
j_total_analysis()267 FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; }
268
size()269 size_t size() const override { return j_total_analysis_.size(); }
270
271 FuncGraphToBoolMap j_total_analysis_;
272
273 protected:
ExtraReset()274 void ExtraReset() override { j_total_analysis_.clear(); }
275
276 void RealRecompute(FuncGraphPtr fg) override;
277 bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
278 };
279
280 class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>, public api::FuncGraphManager {
281 public:
282 explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
~FuncGraphManager()283 ~FuncGraphManager() {
284 if (is_manage_) {
285 RemoveRoots();
286 }
287 Clear();
288 }
289
290 void Reset();
291 void Init();
292 void Clear();
293 void AddFuncGraph(const FuncGraphPtr &func_graph, bool is_root = false);
294 void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
295 void RemoveRoots();
296 void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters);
297 void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter);
298 void InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter);
299 void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
300 bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) final;
301 void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) final;
302 void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) final;
303 void MoveAllCNodeDropGraph(const FuncGraphPtr &source, const FuncGraphPtr &target, const ScopePtr &scope);
304
305 FuncGraphTransaction Transact();
306 void CommitChanges(std::vector<change::ChangePtr> &&changes);
307
IsManaged()308 bool IsManaged() const { return is_manage_; }
309
roots()310 const FuncGraphSet &roots() const { return roots_; }
311
func_graphs()312 const FuncGraphSet &func_graphs() const { return func_graphs_; }
313
all_nodes()314 AnfNodeSet &all_nodes() { return all_nodes_; }
315
node_users()316 NodeUsersMap &node_users() { return node_users_; }
317
node_users()318 const NodeUsersMap &node_users() const final { return node_users_; }
319
320 FVTotalMap &free_variables_total() const;
321
322 FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
323
324 FuncGraphSet &scopes(const FuncGraphPtr &fg) const;
325
326 FuncGraphPtr parent(const FuncGraphPtr &fg) const;
327
328 FuncGraphSet &children(const FuncGraphPtr &fg) const;
329
330 FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const;
331
332 bool recursive(const FuncGraphPtr &fg) const;
333 std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const;
334
335 bool func_graph_j_total(const FuncGraphPtr &fg) const;
336
signals()337 std::shared_ptr<Signals> signals() const { return signals_; }
338
339 // Static Analysis
340 NodeUsersMap node_users_;
341 AnfNodeSet all_nodes_; // managed nodes
342
343 // Dynamic Analysis
344 std::shared_ptr<ParentComputer> func_graph_parent_;
345
346 private:
347 // Erase OneGraph From Manager
348 void EraseOneGraph(const FuncGraphPtr &fg);
349 void AddIntoManaged(const FuncGraphPtr &fg);
350 void ProcessEdgeAdd(const AnfNodePtr &node, int index, const AnfNodePtr &input);
351 void ProcessEdgeRemove(const AnfNodePtr &node, int index, const AnfNodePtr &input);
352 void ProcessInputsEdgeAdd(const CNodePtr &cnode);
353 void ProcessInputsEdgeRemove(const CNodePtr &cnode);
354 void AcquireNodes(std::vector<AnfNodePtr> &&nodes);
355 FuncGraphSet MaybeDropNodes(std::vector<AnfNodePtr> &&nodes);
356 void OnEdgeAdded(const AnfNodePtr &node, int index, const AnfNodePtr &input);
357 void OnEdgeRemoved(const AnfNodePtr &node, int index, const AnfNodePtr &input);
358 void MoveAllNodes(const FuncGraphPtr &source, const FuncGraphPtr &target);
359
360 FuncGraphSet roots_; // Managed roots.
361 FuncGraphSet func_graphs_; // Managed func graphs.
362
363 std::shared_ptr<Signals> signals_;
364
365 // Dynamic Analysis
366 std::shared_ptr<FuncGraphParentsTotalComputer> func_graph_parents_total_;
367 std::shared_ptr<ChildrenComputer> children_;
368 std::shared_ptr<ScopeComputer> scopes_;
369 std::shared_ptr<FVTotalComputer> free_variables_total_;
370 std::shared_ptr<FuncGraphsUsedTotalComputer> func_graphs_used_total_;
371 std::shared_ptr<RecursiveComputer> recursive_;
372 std::shared_ptr<FuncGraphJTotalComputer> j_total_;
373
374 bool is_manage_;
375 };
376
377 class FuncGraphTransaction {
378 public:
FuncGraphTransaction(FuncGraphManager * manager)379 explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager) {}
FuncGraphTransaction()380 FuncGraphTransaction() : manager_(nullptr) {}
381 ~FuncGraphTransaction() = default;
382
383 FuncGraphTransaction(const FuncGraphTransaction &other) = delete;
384 FuncGraphTransaction &operator=(const FuncGraphTransaction &other) = delete;
385
386 FuncGraphTransaction(FuncGraphTransaction &&other) = default;
387 FuncGraphTransaction &operator=(FuncGraphTransaction &&other) = default;
388
389 // set parameters of a func graph
390 void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms);
391 void AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m);
392 void InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr ¶m);
393
394 // replace old_node with new_node
395 bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
396
397 // set edge, i.e., declare setting node.inputs[key] to value.
398 void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v);
399 // Add edge, i.e., append value to node.inputs.
400 void AddEdge(const AnfNodePtr &src_node, const AnfNodePtr &v);
401
402 // commit all changes
403 void Commit();
404
405 private:
406 FuncGraphManager *manager_;
407 std::vector<change::ChangePtr> changes_;
408 };
409
Transact()410 inline FuncGraphTransaction FuncGraphManager::Transact() { return FuncGraphTransaction(this); }
411
412 } // namespace mindspore
413
414 #endif // MINDSPORE_CORE_IR_MANAGER_H_
415