• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_IR_MANAGER_H_
20 #define MINDSPORE_CORE_IR_MANAGER_H_
21 
22 #include <deque>
23 #include <set>
24 #include <map>
25 #include <list>
26 #include <string>
27 #include <vector>
28 #include <utility>
29 #include <memory>
30 #include <functional>
31 #include "utils/any.h"
32 #include "utils/misc.h"
33 #include "utils/signal.h"
34 #include "utils/hash_map.h"
35 #include "utils/hash_set.h"
36 #include "utils/compact_set.h"
37 #include "utils/ordered_set.h"
38 #include "utils/ordered_map.h"
39 #include "ir/anf.h"
40 #include "ir/graph_utils.h"
41 #include "utils/hashing.h"
42 #include "base/base_ref.h"
43 
44 namespace mindspore {
namespace change {
// Forward declaration only; the counter type is defined elsewhere.
struct ChangeCounter;

// Abstract base for a single recorded graph edit.
// Instances are held through ChangePtr (see FuncGraphTransaction::changes_ and
// FuncGraphManager::CommitChanges); presumably each one is applied when committed.
struct Change {
  virtual ~Change() = default;

  // Performs this edit; `counter` is handed through to the concrete change.
  virtual void Apply(ChangeCounter *counter) = 0;
};

// Sole-owner handle to a polymorphic Change.
using ChangePtr = std::unique_ptr<Change>;
}  // namespace change
53 
54 class FuncGraphTransaction;
55 class FuncGraphManager;
56 using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
57 
58 using AnfNodeIndexSet = CompactSet<std::pair<AnfNodePtr, int>>;
59 using NodeUsersMap = mindspore::HashMap<AnfNodePtr, AnfNodeIndexSet>;
60 
61 using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
62 using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
63 
64 // manage the func graphs.
65 // if no manager exist, just create one and associate it to all func graphs; else reuse simply.
66 // func_graph, be managed graph
67 // manage: if true, created manager will be set in func_graph
68 // FuncGraphManagerPtr: return created manager
69 MS_CORE_API FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true, bool drop_unused_graph = false);
70 
71 MS_CORE_API FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool manage = true,
72                                        bool drop_unused_graph = false);
73 
74 MS_CORE_API FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true,
75                                             bool drop_unused_graph = false);
76 
77 struct Signals {
78   Signal<void()> InvalidateComputer;
79 };
80 
81 using CNodeIndexPair = std::pair<AnfNodePtr, int>;
82 using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
83 using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
84 
85 // For Fast Pass
86 class FuncGraphPassIndex {
87  public:
88   constexpr static char key[] = "FuncGraphPassIndex";
FuncGraphPassIndex()89   FuncGraphPassIndex() : has_gen_index_(false) {}
set_has_gen_index(bool is_gen_index)90   void set_has_gen_index(bool is_gen_index) { has_gen_index_ = is_gen_index; }
has_gen_index()91   bool has_gen_index() const { return has_gen_index_; }
92   mindspore::HashMap<AnfNodePtr, FuncGraphWeakPtr> node_to_fg_;
93   mindspore::HashMap<std::string, std::set<AnfNodePtr>> name_to_cnode_;
94   mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> subgraph_out_caller_map_;
95   mindspore::HashMap<AnfNodePtr, size_t> node_degree_;
96 
97  private:
98   bool has_gen_index_;
99 };
100 using FuncGraphIndexPtr = std::shared_ptr<FuncGraphPassIndex>;
101 
102 // analysis base class, graphs analysis which need dynamic compute by DepCollector in each read
103 class DepComputer {
104  public:
105   explicit DepComputer(const FuncGraphManager *manager);
~DepComputer()106   virtual ~DepComputer() { manager_ = nullptr; }
107 
size()108   virtual size_t size() const { return 0; }
109 
Reset()110   void Reset() {
111     ExtraReset();
112     validate_ = false;
113     func_graphs_validate_.clear();
114   }
115 
OnInvalidateComputer()116   void OnInvalidateComputer() { Reset(); }
117 
118   void Recompute();
119 
120   void Recompute(const FuncGraphPtr &fg);
121 
IsValidate()122   bool IsValidate() const { return validate_; }
123 
IsValidate(const FuncGraphPtr & fg)124   bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }
125 
126  protected:
127   // subclass can reset their own member;
ExtraReset()128   virtual void ExtraReset() {}
129   // subclass do the real compute
RealRecompute()130   virtual void RealRecompute() {}
RealRecompute(FuncGraphPtr)131   virtual void RealRecompute(FuncGraphPtr) {}
132 
133   const FuncGraphManager *manager_;
134   bool validate_;
135   OrderedMap<FuncGraphPtr, bool> func_graphs_validate_;
136 
137  private:
138   friend FuncGraphManager;
139 };
140 
141 // graph g's all direct or proxy parents
142 class FuncGraphParentsTotalComputer final : public DepComputer {
143  public:
FuncGraphParentsTotalComputer(const FuncGraphManager * m)144   explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
145   ~FuncGraphParentsTotalComputer() override = default;
146 
func_graph_parents_total_analysis()147   FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; }
148 
size()149   size_t size() const override { return func_graph_parents_total_analysis_.size(); }
150 
151   FuncGraphToFuncGraphSetMap func_graph_parents_total_analysis_;
152 
153  protected:
ExtraReset()154   void ExtraReset() override { func_graph_parents_total_analysis_.clear(); }
155 
156   void RealRecompute(FuncGraphPtr fg) override;
157 
158  private:
159   FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg);
160 };
161 
162 using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
163 
164 // graph's nearest parent in parents total
165 class ParentComputer final : public DepComputer {
166  public:
ParentComputer(const FuncGraphManager * m)167   explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {}
168   ~ParentComputer() override = default;
169 
parent_analysis()170   FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; }
171 
size()172   size_t size() const override { return parent_analysis_.size(); }
173 
174   FuncGraphToFuncGraphMap parent_analysis_;
175 
176  protected:
ExtraReset()177   void ExtraReset() override { parent_analysis_.clear(); }
178 
179   void RealRecompute(FuncGraphPtr fg) override;
180 };
181 
182 // graph's children graph except self
183 class ChildrenComputer final : public DepComputer {
184  public:
ChildrenComputer(const FuncGraphManager * m)185   explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {}
186   ~ChildrenComputer() override = default;
187 
children_analysis()188   FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; }
189 
size()190   size_t size() const override { return children_analysis_.size(); }
191 
192   FuncGraphToFuncGraphSetMap children_analysis_;
193 
194  protected:
ExtraReset()195   void ExtraReset() override { children_analysis_.clear(); }
196 
197   void RealRecompute(FuncGraphPtr fg) override;
198 };
199 
200 // graph's children graph include self
201 class ScopeComputer final : public DepComputer {
202  public:
ScopeComputer(const FuncGraphManager * m)203   explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {}
204   ~ScopeComputer() override = default;
205 
scope_analysis()206   FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; }
207 
size()208   size_t size() const override { return scope_analysis_.size(); }
209 
210   FuncGraphToFuncGraphSetMap scope_analysis_;
211 
212  protected:
ExtraReset()213   void ExtraReset() override { scope_analysis_.clear(); }
214 
215   void RealRecompute(FuncGraphPtr fg) override;
216 };
217 
218 using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;
219 
220 class FVTotalComputer final : public DepComputer {
221  public:
FVTotalComputer(const FuncGraphManager * m)222   explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
223   ~FVTotalComputer() override = default;
224 
fv_total_analysis()225   FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }
226 
size()227   size_t size() const override { return fv_total_analysis_.size(); }
228 
229   FVTotalMap fv_total_analysis_;
230 
231  protected:
ExtraReset()232   void ExtraReset() override { fv_total_analysis_.clear(); }
233 
234   void RealRecompute() override;
235 };
236 
237 class FuncGraphsUsedTotalComputer final : public DepComputer {
238  public:
FuncGraphsUsedTotalComputer(const FuncGraphManager * m)239   explicit FuncGraphsUsedTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
240   ~FuncGraphsUsedTotalComputer() override = default;
241 
func_graph_used_total_analysis()242   FuncGraphToFuncGraphSetMap &func_graph_used_total_analysis() { return func_graph_used_total_analysis_; }
243 
size()244   size_t size() const override { return func_graph_used_total_analysis_.size(); }
245 
246   FuncGraphToFuncGraphSetMap func_graph_used_total_analysis_;
247 
248  protected:
ExtraReset()249   void ExtraReset() override { func_graph_used_total_analysis_.clear(); }
250 
251   void RealRecompute(FuncGraphPtr fg) override;
252 };
253 
254 using FuncGraphToBoolMap = OrderedMap<FuncGraphPtr, bool>;
255 using RecursiveMap = OrderedMap<FuncGraphPtr, std::shared_ptr<std::list<FuncGraphPtr>>>;
256 
257 class RecursiveComputer final : public DepComputer {
258  public:
RecursiveComputer(const FuncGraphManager * m)259   explicit RecursiveComputer(const FuncGraphManager *m) : DepComputer(m) {}
260   ~RecursiveComputer() override = default;
261 
recursive_map()262   RecursiveMap &recursive_map() { return recursive_map_; }
recursive_analysis()263   FuncGraphToBoolMap &recursive_analysis() { return recursive_analysis_; }
264 
265   void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace);
266 
size()267   size_t size() const override { return recursive_analysis_.size(); }
268 
269   RecursiveMap recursive_map_;
270   FuncGraphToBoolMap recursive_analysis_;
271 
272  protected:
ExtraReset()273   void ExtraReset() override {
274     recursive_analysis_.clear();
275     recursive_map_.clear();
276   }
277 
278   void RealRecompute(FuncGraphPtr fg) override;
279 };
280 
281 class FuncGraphMetaFgPrimTotalComputer final : public DepComputer {
282  public:
FuncGraphMetaFgPrimTotalComputer(const FuncGraphManager * m)283   explicit FuncGraphMetaFgPrimTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
284   ~FuncGraphMetaFgPrimTotalComputer() override = default;
285 
meta_fg_prim_total_analysis()286   FuncGraphToBoolMap &meta_fg_prim_total_analysis() { return meta_fg_prim_total_analysis_; }
287 
size()288   size_t size() const override { return meta_fg_prim_total_analysis_.size(); }
289 
290   FuncGraphToBoolMap meta_fg_prim_total_analysis_;
291 
292  protected:
ExtraReset()293   void ExtraReset() override { meta_fg_prim_total_analysis_.clear(); }
294 
295   void RealRecompute(FuncGraphPtr fg) override;
296 
297   bool SeekMetaFgPrim(const FuncGraphPtr &fg, SeenNum seen_num);
298 };
299 
300 class MS_CORE_API FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
301  public:
302   explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true, bool drop_unused_graph = false);
303   virtual ~FuncGraphManager();
304 
305   void Reset();
306   void Init();
307   void Clear() noexcept;
308   void DropFuncGraph(const FuncGraphPtr &fg, bool force = false);
309   void AddFuncGraph(const FuncGraphPtr &func_graph, bool is_root = false);
310   void AddFuncGraphs(const FuncGraphPtr &source_func_graph);
311   void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
312   void RemoveRoots();
313   void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters);
314   void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
315   void InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
316   void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
317   bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
318   bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const AnfNodePtr &mask_node);
319   void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
320   void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value);
321   void MoveAllCNodeDropGraph(const FuncGraphPtr &source, const FuncGraphPtr &target, const AnfNodePtr &call_node,
322                              const ScopePtr &scope, bool update_debug_info = false);
323 
324   FuncGraphTransaction Transact();
325   void CommitChanges(std::vector<change::ChangePtr> &&changes);
326 
roots()327   const FuncGraphSet &roots() const { return roots_; }
328 
func_graphs()329   const FuncGraphSet &func_graphs() const { return func_graphs_; }
330 
all_nodes()331   AnfNodeSet &all_nodes() { return all_nodes_; }
332 
node_users()333   NodeUsersMap &node_users() { return node_users_; }
334 
node_users()335   const NodeUsersMap &node_users() const { return node_users_; }
336 
337   FVTotalMap &free_variables_total() const;
338 
339   FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
340 
341   FuncGraphSet &scopes(const FuncGraphPtr &fg) const;
342 
343   FuncGraphPtr parent(const FuncGraphPtr &fg) const;
344 
345   FuncGraphSet &children(const FuncGraphPtr &fg) const;
346 
347   FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const;
348 
349   bool recursive(const FuncGraphPtr &fg) const;
350   std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const;
351 
352   bool func_graph_meta_fg_prim_total(const FuncGraphPtr &fg) const;
353 
signals()354   std::shared_ptr<Signals> signals() const { return signals_; }
355 
356   // Static Analysis
357   NodeUsersMap node_users_;
358   AnfNodeSet all_nodes_;  // managed nodes
359 
360   // Dynamic Analysis
361   std::shared_ptr<ParentComputer> func_graph_parent_;
362 
363   void ProcessEdgeRemove(const AnfNodePtr &node, int index, const AnfNodePtr &input);
364 
365  private:
366   // Erase OneGraph From Manager
367   void EraseOneGraph(const FuncGraphPtr &fg);
368   void AddIntoManaged(const FuncGraphPtr &fg);
369   void ProcessEdgeAdd(const AnfNodePtr &node, int index, const AnfNodePtr &input);
370   void ProcessInputsEdgeAdd(const CNodePtr &cnode);
371   void ProcessInputsEdgeRemove(const CNodePtr &cnode);
372   void AcquireNodes(std::vector<AnfNodePtr> &&nodes, bool recursive = true);
373   FuncGraphSet MaybeDropNodes(std::vector<AnfNodePtr> &&nodes);
374   void OnEdgeAdded(const AnfNodePtr &node, int index, const AnfNodePtr &input);
375   void OnEdgeRemoved(const AnfNodePtr &node, int index, const AnfNodePtr &input);
376   void MoveAllNodes(const FuncGraphPtr &source, const FuncGraphPtr &target);
377 
378   std::deque<FuncGraphPtr> todo_;
379   FuncGraphSet roots_;        // Managed roots.
380   FuncGraphSet func_graphs_;  // Managed func graphs.
381 
382   std::shared_ptr<Signals> signals_;
383 
384   // Dynamic Analysis
385   std::shared_ptr<FuncGraphParentsTotalComputer> func_graph_parents_total_;
386   std::shared_ptr<ChildrenComputer> children_;
387   std::shared_ptr<ScopeComputer> scopes_;
388   std::shared_ptr<FVTotalComputer> free_variables_total_;
389   std::shared_ptr<FuncGraphsUsedTotalComputer> func_graphs_used_total_;
390   std::shared_ptr<RecursiveComputer> recursive_;
391   std::shared_ptr<FuncGraphMetaFgPrimTotalComputer> meta_fg_prim_total_;
392 
393   bool is_manage_{false};
394   bool drop_unused_graph_{false};
395 };
396 
397 class MS_CORE_API FuncGraphTransaction {
398  public:
FuncGraphTransaction(FuncGraphManager * manager)399   explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager) {}
FuncGraphTransaction()400   FuncGraphTransaction() : manager_(nullptr) {}
401   ~FuncGraphTransaction() = default;
402 
403   FuncGraphTransaction(const FuncGraphTransaction &other) = delete;
404   FuncGraphTransaction &operator=(const FuncGraphTransaction &other) = delete;
405 
406   FuncGraphTransaction(FuncGraphTransaction &&other) = default;
407   FuncGraphTransaction &operator=(FuncGraphTransaction &&other) = default;
408 
409   // set parameters of a func graph
410   void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params);
411   void AddParameter(FuncGraphPtr fg, const AnfNodePtr &param);
412   void InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr &param);
413 
414   // replace old_node with new_node
415   bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
416   bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const AnfNodePtr &mask_node);
417 
418   // set edge, i.e., declare setting node.inputs[key] to value.
419   void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v);
420   // Add edge, i.e., append value to node.inputs.
421   void AddEdge(const AnfNodePtr &src_node, const AnfNodePtr &v);
422 
423   // commit all changes
424   void Commit();
425 
426  private:
427   FuncGraphManager *manager_;
428   std::vector<change::ChangePtr> changes_;
429 };
430 
Transact()431 inline FuncGraphTransaction FuncGraphManager::Transact() { return FuncGraphTransaction(this); }
432 
433 }  // namespace mindspore
434 
435 #endif  // MINDSPORE_CORE_IR_MANAGER_H_
436