• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_
18 #define MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_
19 
20 #include <functional>
21 #include <list>
22 #include <memory>
23 #include <string>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <utility>
27 #include <vector>
28 
29 #include "ir/anf.h"
30 #include "ir/func_graph.h"
31 #include "ir/manager.h"
32 
33 namespace mindspore {
// Forward declaration: lets the shared-pointer alias below precede the full class definition.
class Cloner;

// Shared-ownership handle to a Cloner (this is, for example, what SpecializerClone() returns).
using ClonerPtr = std::shared_ptr<Cloner>;
36 
// Cloning mode. Passed as the `type` argument of Cloner::AddClone() (default kBasic) and kept in
// Cloner::type_. Kept as an unscoped enum because callers use the unqualified enumerator names.
enum CloneType {
  kBasic = 0,     // default mode
  kInline = 1,    // NOTE(review): presumably the mode used by InlineClone(); confirm in the .cc file
  kLifting = 2,   // NOTE(review): presumably the mode used by LiftingClone(); confirm in the .cc file
  kDropping = 3,  // NOTE(review): semantics not visible in this header
};
38 
39 struct CloneInfo {
40   FuncGraphPtr origin;
41   FuncGraphPtr target;
42   AnfNodePtrList params;
43 };
44 
45 class Cloner {
46  public:
47   explicit Cloner(const FuncGraphVector &func_graphs = {}, bool clone_all_valuenodes = false,
48                   bool clone_all_child_graphs = true, bool clone_all_used_graphs = false,
49                   const TraceInfoPtr &relation = std::make_shared<TraceCopy>(),
50                   const TraceInfoPtr &target_relation = nullptr);
51   ~Cloner() = default;
52   void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr,
53                 const AnfNodePtrList &params = {}, CloneType type = kBasic);
54   void Run();
55 
56   // Interfaces for specializer
57   AnfNodePtr CloneDisconnected(const AnfNodePtr &root);
58   AnfNodePtr operator[](const AnfNodePtr &node);
59   FuncGraphPtr operator[](const FuncGraphPtr &func_graph);
60 
61   // Map of replicate nodes and graphs
cloned_node()62   std::unordered_map<AnfNodePtr, AnfNodePtr> *cloned_node() { return &repl_node_; }
cloned_func_graph()63   std::unordered_map<FuncGraphPtr, FuncGraphPtr> &cloned_func_graph() { return repl_func_graph_; }
64 
65   // Scope of cloned graphs
set_scope(const ScopePtr & scope)66   void set_scope(const ScopePtr &scope) { scope_ = scope; }
scope()67   const ScopePtr scope() const { return scope_; }
68 
69   std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_;
70   std::unordered_map<FuncGraphPtr, FuncGraphPtr> repl_func_graph_;
71 
72  private:
73   void CloneNodes();
74   void LinkEdges();
75   void SetDefaults();
76   void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target);
77   void CloneValueNode(const AnfNodePtr &node);
78   void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target);
79   void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target);
80   void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false);
81   void CloneValueNodes(const FuncGraphPtr &func_graph);
82   void AddChildGraphs(const FuncGraphPtr &func_graph);
83   void AddTotalGraphs(const FuncGraphPtr &func_graph);
84   bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline);
85   void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
86   void CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
87   void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
88   void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
89   void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
90   void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph);
91   void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
92   void GenParameters(const FuncGraphPtr &func_graph);
93   void CloneParameter(const ParameterPtr &param, const AnfNodePtr &node);
94   ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true);
95   void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params, AnfNodePtrList *const lift_params,
96                      AnfNodePtrList *const input_params);
97   void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
98   void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs, size_t arg_start_index);
99   void SetEdges(const FuncGraphPtr &func_graph, FuncGraphTransaction *tx);
100   void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
101                       const AnfNodePtrList &params);
102   void Lift(const std::vector<FuncGraphPtr> &sorted);
103   void LiftParameters(const FuncGraphPtr &lift_top_func_graph);
104 
105   bool clone_all_valuenodes_;
106   bool clone_all_child_graphs_;
107   bool clone_all_used_graphs_;
108   TraceInfoPtr relation_;
109   TraceInfoPtr target_relation_;
110   FuncGraphManagerPtr manager_;
111   FuncGraphSet graph_set_;
112   ScopePtr scope_;
113   CloneType type_;
114   std::list<CloneInfo> todo_;
115   std::list<std::pair<CNodePtr, CNodePtr>> nodes_;
116   std::unordered_map<FuncGraphPtr, bool> status_;
117   std::unordered_map<FuncGraphPtr, std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_map_node_;
118   std::unordered_map<FuncGraphPtr, std::unordered_map<FuncGraphPtr, AnfNodePtr>> repl_map_func_graph_;
119   std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_;
120 };
121 
122 AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
123                        const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr);
124 
125 FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph);
126 
127 ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation);
128 
129 FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph,
130                                 const TraceInfoPtr &relation = std::make_shared<TraceTransform>());
131 FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes = false);
132 }  // namespace mindspore
133 
134 #endif  // MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_
135