/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 
19 #ifndef MINDSPORE_CORE_IR_MANAGER_H_
20 #define MINDSPORE_CORE_IR_MANAGER_H_
21 
22 #include <unordered_set>
23 #include <unordered_map>
24 #include <set>
25 #include <map>
26 #include <list>
27 #include <string>
28 #include <vector>
29 #include <utility>
30 #include <memory>
31 #include <functional>
32 
33 #include "utils/any.h"
34 #include "utils/misc.h"
35 #include "utils/signal.h"
36 #include "utils/ordered_set.h"
37 #include "utils/ordered_map.h"
38 #include "ir/anf.h"
39 #include "ir/graph_utils.h"
40 #include "utils/hashing.h"
41 #include "base/base_ref.h"
42 #include "api/ir/func_graph_manager.h"
43 
44 namespace mindspore {
namespace change {
struct ChangeCounter;

// Abstract unit of work that is carried out through Apply().
// Instances are owned through ChangePtr.
struct Change {
  // Virtual so that a concrete change can be destroyed through a Change pointer.
  virtual ~Change() = default;
  // Carry out this change. ChangeCounter is only forward-declared in this header,
  // so what is recorded in `counter` is defined by the implementations.
  virtual void Apply(ChangeCounter *counter) = 0;
};

// Exclusive-ownership handle for one Change.
using ChangePtr = std::unique_ptr<Change>;
}  // namespace change
53 
54 class FuncGraphTransaction;
55 class FuncGraphManager;
56 using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
57 
58 using AnfNodeIndexSet = api::AnfNodeIndexSet;
59 // NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
60 using NodeUsersMap = api::NodeUsersMap;
61 using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
62 using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
63 
64 // manage the func graphs.
65 // if no manager exist, just create one and associate it to all func graphs; else reuse simply.
66 // func_graph, be managed graph
67 // manage: if true, created manager will be set in func_graph
68 // FuncGraphManagerPtr: return created manager
69 FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true);
70 
71 FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage = true);
72 
73 FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true);
74 
75 struct Signals {
76   Signal<void()> InvalidateComputer;
77 };
78 
79 using CNodeIndexPair = std::pair<AnfNodePtr, int>;
80 using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
81 using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
82 
83 // analysis base class, graphs analysis which need dynamic compute by DepCollector in each read
84 class DepComputer {
85  public:
86   explicit DepComputer(const FuncGraphManager *manager);
~DepComputer()87   virtual ~DepComputer() { manager_ = nullptr; }
88 
size()89   virtual size_t size() const { return 0; }
90 
Reset()91   void Reset() {
92     ExtraReset();
93     validate_ = false;
94     func_graphs_validate_.clear();
95   }
96 
OnInvalidateComputer()97   void OnInvalidateComputer() { Reset(); }
98 
99   void Recompute();
100 
101   void Recompute(const FuncGraphPtr &fg);
102 
IsValidate()103   bool IsValidate() const { return validate_; }
104 
IsValidate(const FuncGraphPtr & fg)105   bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }
106 
107  protected:
108   // subclass can reset their own member;
ExtraReset()109   virtual void ExtraReset() {}
110   // subclass do the real compute
RealRecompute()111   virtual void RealRecompute() {}
RealRecompute(FuncGraphPtr)112   virtual void RealRecompute(FuncGraphPtr) {}
113 
114   const FuncGraphManager *manager_;
115   bool validate_;
116   OrderedMap<FuncGraphPtr, bool> func_graphs_validate_;
117 
118  private:
119   friend FuncGraphManager;
120 };
121 
122 // graph g's all direct or proxy parents
123 class FuncGraphParentsTotalComputer final : public DepComputer {
124  public:
FuncGraphParentsTotalComputer(const FuncGraphManager * m)125   explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
126   ~FuncGraphParentsTotalComputer() override = default;
127 
func_graph_parents_total_analysis()128   FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
129 
size()130   size_t size() const override { return func_graph_parents_total_analysis_.size(); }
131 
132   FuncGraphToFuncGraphSetMap func_graph_parents_total_analysis_;
133 
134  protected:
ExtraReset()135   void ExtraReset() override { func_graph_parents_total_analysis_.clear(); }
136 
137   void RealRecompute(FuncGraphPtr fg) override;
138 
139  private:
140   FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, std::unordered_map<FuncGraphPtr, FuncGraphSetPtr> *seen_fgs);
141 };
142 
143 using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
144 
145 // graph's nearest parent in parents total
146 class ParentComputer final : public DepComputer {
147  public:
ParentComputer(const FuncGraphManager * m)148   explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {}
149   ~ParentComputer() override = default;
150 
parent_analysis()151   FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; }
152 
size()153   size_t size() const override { return parent_analysis_.size(); }
154 
155   FuncGraphToFuncGraphMap parent_analysis_;
156 
157  protected:
ExtraReset()158   void ExtraReset() override { parent_analysis_.clear(); }
159 
160   void RealRecompute(FuncGraphPtr fg) override;
161 };
162 
163 // graph's children graph except self
164 class ChildrenComputer final : public DepComputer {
165  public:
ChildrenComputer(const FuncGraphManager * m)166   explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {}
167   ~ChildrenComputer() override = default;
168 
children_analysis()169   FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; }
170 
size()171   size_t size() const override { return children_analysis_.size(); }
172 
173   FuncGraphToFuncGraphSetMap children_analysis_;
174 
175  protected:
ExtraReset()176   void ExtraReset() override { children_analysis_.clear(); }
177 
178   void RealRecompute(FuncGraphPtr fg) override;
179 };
180 
181 // graph's children graph include self
182 class ScopeComputer final : public DepComputer {
183  public:
ScopeComputer(const FuncGraphManager * m)184   explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {}
185   ~ScopeComputer() override = default;
186 
scope_analysis()187   FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; }
188 
size()189   size_t size() const override { return scope_analysis_.size(); }
190 
191   FuncGraphToFuncGraphSetMap scope_analysis_;
192 
193  protected:
ExtraReset()194   void ExtraReset() override { scope_analysis_.clear(); }
195 
196   void RealRecompute(FuncGraphPtr fg) override;
197 };
198 
199 using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;
200 
201 class FVTotalComputer final : public DepComputer {
202  public:
FVTotalComputer(const FuncGraphManager * m)203   explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
204   ~FVTotalComputer() override = default;
205 
fv_total_analysis()206   FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }
207 
size()208   size_t size() const override { return fv_total_analysis_.size(); }
209 
210   FVTotalMap fv_total_analysis_;
211 
212  protected:
ExtraReset()213   void ExtraReset() override { fv_total_analysis_.clear(); }
214 
215   void RealRecompute() override;
216 };
217 
218 class FuncGraphsUsedTotalComputer final : public DepComputer {
219  public:
FuncGraphsUsedTotalComputer(const FuncGraphManager * m)220   explicit FuncGraphsUsedTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
221   ~FuncGraphsUsedTotalComputer() override = default;
222 
func_graph_used_total_analysis()223   FuncGraphToFuncGraphSetMap &func_graph_used_total_analysis() { return func_graph_used_total_analysis_; }
224 
size()225   size_t size() const override { return func_graph_used_total_analysis_.size(); }
226 
227   FuncGraphToFuncGraphSetMap func_graph_used_total_analysis_;
228 
229  protected:
ExtraReset()230   void ExtraReset() override { func_graph_used_total_analysis_.clear(); }
231 
232   void RealRecompute(FuncGraphPtr fg) override;
233 };
234 
235 using FuncGraphToBoolMap = OrderedMap<FuncGraphPtr, bool>;
236 using RecursiveMap = OrderedMap<FuncGraphPtr, std::shared_ptr<std::list<FuncGraphPtr>>>;
237 
238 class RecursiveComputer final : public DepComputer {
239  public:
RecursiveComputer(const FuncGraphManager * m)240   explicit RecursiveComputer(const FuncGraphManager *m) : DepComputer(m) {}
241   ~RecursiveComputer() override = default;
242 
recursive_map()243   RecursiveMap &recursive_map() { return recursive_map_; }
recursive_analysis()244   FuncGraphToBoolMap &recursive_analysis() { return recursive_analysis_; }
245 
246   void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace);
247 
size()248   size_t size() const override { return recursive_analysis_.size(); }
249 
250   RecursiveMap recursive_map_;
251   FuncGraphToBoolMap recursive_analysis_;
252 
253  protected:
ExtraReset()254   void ExtraReset() override {
255     recursive_analysis_.clear();
256     recursive_map_.clear();
257   }
258 
259   void RealRecompute(FuncGraphPtr fg) override;
260 };
261 
262 class FuncGraphJTotalComputer final : public DepComputer {
263  public:
FuncGraphJTotalComputer(const FuncGraphManager * m)264   explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
265   ~FuncGraphJTotalComputer() override = default;
266 
j_total_analysis()267   FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; }
268 
size()269   size_t size() const override { return j_total_analysis_.size(); }
270 
271   FuncGraphToBoolMap j_total_analysis_;
272 
273  protected:
ExtraReset()274   void ExtraReset() override { j_total_analysis_.clear(); }
275 
276   void RealRecompute(FuncGraphPtr fg) override;
277   bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
278 };
279 
280 class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>, public api::FuncGraphManager {
281  public:
282   explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
~FuncGraphManager()283   ~FuncGraphManager() {
284     if (is_manage_) {
285       RemoveRoots();
286     }
287     Clear();
288   }
289 
290   void Reset();
291   void Init();
292   void Clear();
293   void AddFuncGraph(const FuncGraphPtr &func_graph, bool is_root = false);
294   void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
295   void RemoveRoots();
296   void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters);
297   void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
298   void InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
299   void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
300   bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) final;
301   void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) final;
302   void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) final;
303   void MoveAllCNodeDropGraph(const FuncGraphPtr &source, const FuncGraphPtr &target, const ScopePtr &scope);
304 
305   FuncGraphTransaction Transact();
306   void CommitChanges(std::vector<change::ChangePtr> &&changes);
307 
IsManaged()308   bool IsManaged() const { return is_manage_; }
309 
roots()310   const FuncGraphSet &roots() const { return roots_; }
311 
func_graphs()312   const FuncGraphSet &func_graphs() const { return func_graphs_; }
313 
all_nodes()314   AnfNodeSet &all_nodes() { return all_nodes_; }
315 
node_users()316   NodeUsersMap &node_users() { return node_users_; }
317 
node_users()318   const NodeUsersMap &node_users() const final { return node_users_; }
319 
320   FVTotalMap &free_variables_total() const;
321 
322   FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
323 
324   FuncGraphSet &scopes(const FuncGraphPtr &fg) const;
325 
326   FuncGraphPtr parent(const FuncGraphPtr &fg) const;
327 
328   FuncGraphSet &children(const FuncGraphPtr &fg) const;
329 
330   FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const;
331 
332   bool recursive(const FuncGraphPtr &fg) const;
333   std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const;
334 
335   bool func_graph_j_total(const FuncGraphPtr &fg) const;
336 
signals()337   std::shared_ptr<Signals> signals() const { return signals_; }
338 
339   // Static Analysis
340   NodeUsersMap node_users_;
341   AnfNodeSet all_nodes_;  // managed nodes
342 
343   // Dynamic Analysis
344   std::shared_ptr<ParentComputer> func_graph_parent_;
345 
346  private:
347   // Erase OneGraph From Manager
348   void EraseOneGraph(const FuncGraphPtr &fg);
349   void AddIntoManaged(const FuncGraphPtr &fg);
350   void ProcessEdgeAdd(const AnfNodePtr &node, int index, const AnfNodePtr &input);
351   void ProcessEdgeRemove(const AnfNodePtr &node, int index, const AnfNodePtr &input);
352   void ProcessInputsEdgeAdd(const CNodePtr &cnode);
353   void ProcessInputsEdgeRemove(const CNodePtr &cnode);
354   void AcquireNodes(std::vector<AnfNodePtr> &&nodes);
355   FuncGraphSet MaybeDropNodes(std::vector<AnfNodePtr> &&nodes);
356   void OnEdgeAdded(const AnfNodePtr &node, int index, const AnfNodePtr &input);
357   void OnEdgeRemoved(const AnfNodePtr &node, int index, const AnfNodePtr &input);
358   void MoveAllNodes(const FuncGraphPtr &source, const FuncGraphPtr &target);
359 
360   FuncGraphSet roots_;        // Managed roots.
361   FuncGraphSet func_graphs_;  // Managed func graphs.
362 
363   std::shared_ptr<Signals> signals_;
364 
365   // Dynamic Analysis
366   std::shared_ptr<FuncGraphParentsTotalComputer> func_graph_parents_total_;
367   std::shared_ptr<ChildrenComputer> children_;
368   std::shared_ptr<ScopeComputer> scopes_;
369   std::shared_ptr<FVTotalComputer> free_variables_total_;
370   std::shared_ptr<FuncGraphsUsedTotalComputer> func_graphs_used_total_;
371   std::shared_ptr<RecursiveComputer> recursive_;
372   std::shared_ptr<FuncGraphJTotalComputer> j_total_;
373 
374   bool is_manage_;
375 };
376 
377 class FuncGraphTransaction {
378  public:
FuncGraphTransaction(FuncGraphManager * manager)379   explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager) {}
FuncGraphTransaction()380   FuncGraphTransaction() : manager_(nullptr) {}
381   ~FuncGraphTransaction() = default;
382 
383   FuncGraphTransaction(const FuncGraphTransaction &other) = delete;
384   FuncGraphTransaction &operator=(const FuncGraphTransaction &other) = delete;
385 
386   FuncGraphTransaction(FuncGraphTransaction &&other) = default;
387   FuncGraphTransaction &operator=(FuncGraphTransaction &&other) = default;
388 
389   // set parameters of a func graph
390   void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params);
391   void AddParameter(FuncGraphPtr fg, const AnfNodePtr &param);
392   void InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr &param);
393 
394   // replace old_node with new_node
395   bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
396 
397   // set edge, i.e., declare setting node.inputs[key] to value.
398   void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v);
399   // Add edge, i.e., append value to node.inputs.
400   void AddEdge(const AnfNodePtr &src_node, const AnfNodePtr &v);
401 
402   // commit all changes
403   void Commit();
404 
405  private:
406   FuncGraphManager *manager_;
407   std::vector<change::ChangePtr> changes_;
408 };
409 
Transact()410 inline FuncGraphTransaction FuncGraphManager::Transact() { return FuncGraphTransaction(this); }
411 
412 }  // namespace mindspore
413 
414 #endif  // MINDSPORE_CORE_IR_MANAGER_H_
415