1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CORE_IR_MANAGER_H_
20 #define MINDSPORE_CORE_IR_MANAGER_H_
21
22 #include <deque>
23 #include <set>
24 #include <map>
25 #include <list>
26 #include <string>
27 #include <vector>
28 #include <utility>
29 #include <memory>
30 #include <functional>
31 #include "utils/any.h"
32 #include "utils/misc.h"
33 #include "utils/signal.h"
34 #include "utils/hash_map.h"
35 #include "utils/hash_set.h"
36 #include "utils/compact_set.h"
37 #include "utils/ordered_set.h"
38 #include "utils/ordered_map.h"
39 #include "ir/anf.h"
40 #include "ir/graph_utils.h"
41 #include "utils/hashing.h"
42 #include "base/base_ref.h"
43
44 namespace mindspore {
namespace change {
// Forward declaration only; the counter type is defined outside this header.
struct ChangeCounter;

// Abstract base for one recorded graph edit. Concrete edits implement Apply().
// FuncGraphTransaction stores these and FuncGraphManager::CommitChanges consumes them.
struct Change {
  virtual ~Change() = default;
  // Perform this edit. `counter` is forwarded to the implementation; presumably it
  // accumulates bookkeeping about applied edits — TODO confirm against implementations.
  virtual void Apply(ChangeCounter *counter) = 0;
};

// Owning handle to a Change (single ownership).
using ChangePtr = std::unique_ptr<Change>;
}  // namespace change
53
class FuncGraphTransaction;
class FuncGraphManager;
// Shared handle to a manager; this is what Manage()/MakeManager() return.
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;

// Set of (node, input index) pairs; the value type of NodeUsersMap.
using AnfNodeIndexSet = CompactSet<std::pair<AnfNodePtr, int>>;
// Maps a node to a set of (node, input index) pairs — presumably the user nodes and the
// input slot through which they reference the key node; confirm in the manager sources.
using NodeUsersMap = mindspore::HashMap<AnfNodePtr, AnfNodeIndexSet>;

// A func graph paired with a set of func graphs.
using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
// Shared handle to a func graph set (returned, e.g., by FuncGraphParentsTotalComputer::SeekParents).
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
63
// Manage a single func graph.
// If no manager exists yet, one is created and associated with all the func graphs;
// otherwise the existing manager is simply reused.
//   func_graph: the graph to be managed.
//   manage: if true, the created manager is set on func_graph.
//   drop_unused_graph: forwarded to the manager; presumably controls whether graphs that
//     become unused are dropped — TODO confirm.
// Returns the (created or reused) manager.
MS_CORE_API FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true, bool drop_unused_graph = false);

// Same as above, for several func graphs at once.
MS_CORE_API FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage = true,
                                       bool drop_unused_graph = false);

// Build a manager for `func_graphs`; the parameters mirror the FuncGraphManager constructor.
// Presumably always creates a new manager rather than reusing one — TODO confirm.
MS_CORE_API FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true,
                                            bool drop_unused_graph = false);
76
// Signals published by a FuncGraphManager (see FuncGraphManager::signals()).
struct Signals {
  // Presumably emitted when cached DepComputer results must be discarded
  // (DepComputer::OnInvalidateComputer resets a computer) — TODO confirm the wiring.
  Signal<void()> InvalidateComputer;
};

// A (node, input index) pair.
using CNodeIndexPair = std::pair<AnfNodePtr, int>;
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
// Ordered map from a func graph to a set of func graphs; the result type of several computers.
using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
84
85 // For Fast Pass
86 class FuncGraphPassIndex {
87 public:
88 constexpr static char key[] = "FuncGraphPassIndex";
FuncGraphPassIndex()89 FuncGraphPassIndex() : has_gen_index_(false) {}
set_has_gen_index(bool is_gen_index)90 void set_has_gen_index(bool is_gen_index) { has_gen_index_ = is_gen_index; }
has_gen_index()91 bool has_gen_index() const { return has_gen_index_; }
92 mindspore::HashMap<AnfNodePtr, FuncGraphWeakPtr> node_to_fg_;
93 mindspore::HashMap<std::string, std::set<AnfNodePtr>> name_to_cnode_;
94 mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> subgraph_out_caller_map_;
95 mindspore::HashMap<AnfNodePtr, size_t> node_degree_;
96
97 private:
98 bool has_gen_index_;
99 };
100 using FuncGraphIndexPtr = std::shared_ptr<FuncGraphPassIndex>;
101
102 // analysis base class, graphs analysis which need dynamic compute by DepCollector in each read
103 class DepComputer {
104 public:
105 explicit DepComputer(const FuncGraphManager *manager);
~DepComputer()106 virtual ~DepComputer() { manager_ = nullptr; }
107
size()108 virtual size_t size() const { return 0; }
109
Reset()110 void Reset() {
111 ExtraReset();
112 validate_ = false;
113 func_graphs_validate_.clear();
114 }
115
OnInvalidateComputer()116 void OnInvalidateComputer() { Reset(); }
117
118 void Recompute();
119
120 void Recompute(const FuncGraphPtr &fg);
121
IsValidate()122 bool IsValidate() const { return validate_; }
123
IsValidate(const FuncGraphPtr & fg)124 bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }
125
126 protected:
127 // subclass can reset their own member;
ExtraReset()128 virtual void ExtraReset() {}
129 // subclass do the real compute
RealRecompute()130 virtual void RealRecompute() {}
RealRecompute(FuncGraphPtr)131 virtual void RealRecompute(FuncGraphPtr) {}
132
133 const FuncGraphManager *manager_;
134 bool validate_;
135 OrderedMap<FuncGraphPtr, bool> func_graphs_validate_;
136
137 private:
138 friend FuncGraphManager;
139 };
140
141 // graph g's all direct or proxy parents
142 class FuncGraphParentsTotalComputer final : public DepComputer {
143 public:
FuncGraphParentsTotalComputer(const FuncGraphManager * m)144 explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
145 ~FuncGraphParentsTotalComputer() override = default;
146
func_graph_parents_total_analysis()147 FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
148
size()149 size_t size() const override { return func_graph_parents_total_analysis_.size(); }
150
151 FuncGraphToFuncGraphSetMap func_graph_parents_total_analysis_;
152
153 protected:
ExtraReset()154 void ExtraReset() override { func_graph_parents_total_analysis_.clear(); }
155
156 void RealRecompute(FuncGraphPtr fg) override;
157
158 private:
159 FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg);
160 };
161
162 using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
163
164 // graph's nearest parent in parents total
165 class ParentComputer final : public DepComputer {
166 public:
ParentComputer(const FuncGraphManager * m)167 explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {}
168 ~ParentComputer() override = default;
169
parent_analysis()170 FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; }
171
size()172 size_t size() const override { return parent_analysis_.size(); }
173
174 FuncGraphToFuncGraphMap parent_analysis_;
175
176 protected:
ExtraReset()177 void ExtraReset() override { parent_analysis_.clear(); }
178
179 void RealRecompute(FuncGraphPtr fg) override;
180 };
181
182 // graph's children graph except self
183 class ChildrenComputer final : public DepComputer {
184 public:
ChildrenComputer(const FuncGraphManager * m)185 explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {}
186 ~ChildrenComputer() override = default;
187
children_analysis()188 FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; }
189
size()190 size_t size() const override { return children_analysis_.size(); }
191
192 FuncGraphToFuncGraphSetMap children_analysis_;
193
194 protected:
ExtraReset()195 void ExtraReset() override { children_analysis_.clear(); }
196
197 void RealRecompute(FuncGraphPtr fg) override;
198 };
199
200 // graph's children graph include self
201 class ScopeComputer final : public DepComputer {
202 public:
ScopeComputer(const FuncGraphManager * m)203 explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {}
204 ~ScopeComputer() override = default;
205
scope_analysis()206 FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; }
207
size()208 size_t size() const override { return scope_analysis_.size(); }
209
210 FuncGraphToFuncGraphSetMap scope_analysis_;
211
212 protected:
ExtraReset()213 void ExtraReset() override { scope_analysis_.clear(); }
214
215 void RealRecompute(FuncGraphPtr fg) override;
216 };
217
218 using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;
219
220 class FVTotalComputer final : public DepComputer {
221 public:
FVTotalComputer(const FuncGraphManager * m)222 explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
223 ~FVTotalComputer() override = default;
224
fv_total_analysis()225 FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }
226
size()227 size_t size() const override { return fv_total_analysis_.size(); }
228
229 FVTotalMap fv_total_analysis_;
230
231 protected:
ExtraReset()232 void ExtraReset() override { fv_total_analysis_.clear(); }
233
234 void RealRecompute() override;
235 };
236
237 class FuncGraphsUsedTotalComputer final : public DepComputer {
238 public:
FuncGraphsUsedTotalComputer(const FuncGraphManager * m)239 explicit FuncGraphsUsedTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
240 ~FuncGraphsUsedTotalComputer() override = default;
241
func_graph_used_total_analysis()242 FuncGraphToFuncGraphSetMap &func_graph_used_total_analysis() { return func_graph_used_total_analysis_; }
243
size()244 size_t size() const override { return func_graph_used_total_analysis_.size(); }
245
246 FuncGraphToFuncGraphSetMap func_graph_used_total_analysis_;
247
248 protected:
ExtraReset()249 void ExtraReset() override { func_graph_used_total_analysis_.clear(); }
250
251 void RealRecompute(FuncGraphPtr fg) override;
252 };
253
254 using FuncGraphToBoolMap = OrderedMap<FuncGraphPtr, bool>;
255 using RecursiveMap = OrderedMap<FuncGraphPtr, std::shared_ptr<std::list<FuncGraphPtr>>>;
256
257 class RecursiveComputer final : public DepComputer {
258 public:
RecursiveComputer(const FuncGraphManager * m)259 explicit RecursiveComputer(const FuncGraphManager *m) : DepComputer(m) {}
260 ~RecursiveComputer() override = default;
261
recursive_map()262 RecursiveMap &recursive_map() { return recursive_map_; }
recursive_analysis()263 FuncGraphToBoolMap &recursive_analysis() { return recursive_analysis_; }
264
265 void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace);
266
size()267 size_t size() const override { return recursive_analysis_.size(); }
268
269 RecursiveMap recursive_map_;
270 FuncGraphToBoolMap recursive_analysis_;
271
272 protected:
ExtraReset()273 void ExtraReset() override {
274 recursive_analysis_.clear();
275 recursive_map_.clear();
276 }
277
278 void RealRecompute(FuncGraphPtr fg) override;
279 };
280
281 class FuncGraphMetaFgPrimTotalComputer final : public DepComputer {
282 public:
FuncGraphMetaFgPrimTotalComputer(const FuncGraphManager * m)283 explicit FuncGraphMetaFgPrimTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
284 ~FuncGraphMetaFgPrimTotalComputer() override = default;
285
meta_fg_prim_total_analysis()286 FuncGraphToBoolMap &meta_fg_prim_total_analysis() { return meta_fg_prim_total_analysis_; }
287
size()288 size_t size() const override { return meta_fg_prim_total_analysis_.size(); }
289
290 FuncGraphToBoolMap meta_fg_prim_total_analysis_;
291
292 protected:
ExtraReset()293 void ExtraReset() override { meta_fg_prim_total_analysis_.clear(); }
294
295 void RealRecompute(FuncGraphPtr fg) override;
296
297 bool SeekMetaFgPrim(const FuncGraphPtr &fg, SeenNum seen_num);
298 };
299
300 class MS_CORE_API FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
301 public:
302 explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true, bool drop_unused_graph = false);
303 virtual ~FuncGraphManager();
304
305 void Reset();
306 void Init();
307 void Clear() noexcept;
308 void DropFuncGraph(const FuncGraphPtr &fg, bool force = false);
309 void AddFuncGraph(const FuncGraphPtr &func_graph, bool is_root = false);
310 void AddFuncGraphs(const FuncGraphPtr &source_func_graph);
311 void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
312 void RemoveRoots();
313 void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> ¶meters);
314 void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter);
315 void InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter);
316 void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
317 bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
318 bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const AnfNodePtr &mask_node);
319 void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
320 void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value);
321 void MoveAllCNodeDropGraph(const FuncGraphPtr &source, const FuncGraphPtr &target, const AnfNodePtr &call_node,
322 const ScopePtr &scope, bool update_debug_info = false);
323
324 FuncGraphTransaction Transact();
325 void CommitChanges(std::vector<change::ChangePtr> &&changes);
326
roots()327 const FuncGraphSet &roots() const { return roots_; }
328
func_graphs()329 const FuncGraphSet &func_graphs() const { return func_graphs_; }
330
all_nodes()331 AnfNodeSet &all_nodes() { return all_nodes_; }
332
node_users()333 NodeUsersMap &node_users() { return node_users_; }
334
node_users()335 const NodeUsersMap &node_users() const { return node_users_; }
336
337 FVTotalMap &free_variables_total() const;
338
339 FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
340
341 FuncGraphSet &scopes(const FuncGraphPtr &fg) const;
342
343 FuncGraphPtr parent(const FuncGraphPtr &fg) const;
344
345 FuncGraphSet &children(const FuncGraphPtr &fg) const;
346
347 FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const;
348
349 bool recursive(const FuncGraphPtr &fg) const;
350 std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const;
351
352 bool func_graph_meta_fg_prim_total(const FuncGraphPtr &fg) const;
353
signals()354 std::shared_ptr<Signals> signals() const { return signals_; }
355
356 // Static Analysis
357 NodeUsersMap node_users_;
358 AnfNodeSet all_nodes_; // managed nodes
359
360 // Dynamic Analysis
361 std::shared_ptr<ParentComputer> func_graph_parent_;
362
363 void ProcessEdgeRemove(const AnfNodePtr &node, int index, const AnfNodePtr &input);
364
365 private:
366 // Erase OneGraph From Manager
367 void EraseOneGraph(const FuncGraphPtr &fg);
368 void AddIntoManaged(const FuncGraphPtr &fg);
369 void ProcessEdgeAdd(const AnfNodePtr &node, int index, const AnfNodePtr &input);
370 void ProcessInputsEdgeAdd(const CNodePtr &cnode);
371 void ProcessInputsEdgeRemove(const CNodePtr &cnode);
372 void AcquireNodes(std::vector<AnfNodePtr> &&nodes, bool recursive = true);
373 FuncGraphSet MaybeDropNodes(std::vector<AnfNodePtr> &&nodes);
374 void OnEdgeAdded(const AnfNodePtr &node, int index, const AnfNodePtr &input);
375 void OnEdgeRemoved(const AnfNodePtr &node, int index, const AnfNodePtr &input);
376 void MoveAllNodes(const FuncGraphPtr &source, const FuncGraphPtr &target);
377
378 std::deque<FuncGraphPtr> todo_;
379 FuncGraphSet roots_; // Managed roots.
380 FuncGraphSet func_graphs_; // Managed func graphs.
381
382 std::shared_ptr<Signals> signals_;
383
384 // Dynamic Analysis
385 std::shared_ptr<FuncGraphParentsTotalComputer> func_graph_parents_total_;
386 std::shared_ptr<ChildrenComputer> children_;
387 std::shared_ptr<ScopeComputer> scopes_;
388 std::shared_ptr<FVTotalComputer> free_variables_total_;
389 std::shared_ptr<FuncGraphsUsedTotalComputer> func_graphs_used_total_;
390 std::shared_ptr<RecursiveComputer> recursive_;
391 std::shared_ptr<FuncGraphMetaFgPrimTotalComputer> meta_fg_prim_total_;
392
393 bool is_manage_{false};
394 bool drop_unused_graph_{false};
395 };
396
397 class MS_CORE_API FuncGraphTransaction {
398 public:
FuncGraphTransaction(FuncGraphManager * manager)399 explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager) {}
FuncGraphTransaction()400 FuncGraphTransaction() : manager_(nullptr) {}
401 ~FuncGraphTransaction() = default;
402
403 FuncGraphTransaction(const FuncGraphTransaction &other) = delete;
404 FuncGraphTransaction &operator=(const FuncGraphTransaction &other) = delete;
405
406 FuncGraphTransaction(FuncGraphTransaction &&other) = default;
407 FuncGraphTransaction &operator=(FuncGraphTransaction &&other) = default;
408
409 // set parameters of a func graph
410 void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> ¶ms);
411 void AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m);
412 void InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr ¶m);
413
414 // replace old_node with new_node
415 bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
416 bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const AnfNodePtr &mask_node);
417
418 // set edge, i.e., declare setting node.inputs[key] to value.
419 void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v);
420 // Add edge, i.e., append value to node.inputs.
421 void AddEdge(const AnfNodePtr &src_node, const AnfNodePtr &v);
422
423 // commit all changes
424 void Commit();
425
426 private:
427 FuncGraphManager *manager_;
428 std::vector<change::ChangePtr> changes_;
429 };
430
Transact()431 inline FuncGraphTransaction FuncGraphManager::Transact() { return FuncGraphTransaction(this); }
432
433 } // namespace mindspore
434
435 #endif // MINDSPORE_CORE_IR_MANAGER_H_
436