• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_
18 #define MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_
19 
20 #include <functional>
21 #include <list>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "utils/hash_map.h"
28 #include "utils/hash_set.h"
29 #include "ir/anf.h"
30 #include "ir/func_graph.h"
31 #include "ir/manager.h"
32 #include "utils/hashing.h"
33 #include "mindapi/base/macros.h"
34 
35 namespace mindspore {
36 class Cloner;
37 using ClonerPtr = std::shared_ptr<Cloner>;
38 using NodeToNodeMap = mindspore::HashMap<AnfNodePtr, AnfNodePtr>;
39 
// Clone mode accepted by Cloner::AddClone (defaults to kBasic there).
// The numeric values are spelled out explicitly so the integer mapping stays stable.
enum CloneType {
  kBasic = 0,
  kInline = 1,    // NOTE(review): presumably the mode used by InlineClone -- confirm in the .cc
  kLifting = 2,   // NOTE(review): presumably the mode used by LiftingClone -- confirm in the .cc
  kDropping = 3,
};
41 
42 struct CloneInfo {
43   FuncGraphPtr origin;
44   FuncGraphPtr target;
45   AnfNodePtrList params;
46 };
47 
48 struct UpdateInfo {
UpdateInfoUpdateInfo49   UpdateInfo(const ScopePtr &scope, const NodeDebugInfoPtr &debug_info) : scope_(scope), debug_info_(debug_info) {}
50   ~UpdateInfo() = default;
51 
52   ScopePtr scope_;
53   NodeDebugInfoPtr debug_info_;
54 };
55 
56 using UpdateInfoPtr = std::shared_ptr<UpdateInfo>;
57 
58 class MS_CORE_API Cloner {
59  public:
60   explicit Cloner(const FuncGraphVector &func_graphs = {}, bool clone_all_valuenodes = false,
61                   bool clone_all_child_graphs = true, bool clone_all_used_graphs = false,
62                   const TraceInfoPtr &relation = std::make_shared<TraceCopy>(),
63                   const TraceInfoPtr &target_relation = nullptr);
64   ~Cloner() = default;
65   void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr,
66                 const AnfNodePtrList &params = {}, CloneType type = kBasic);
67   void Run();
68 
69   // Interfaces for specializer
70   AnfNodePtr CloneDisconnected(const AnfNodePtr &root);
71   AnfNodePtr operator[](const AnfNodePtr &node);
72   FuncGraphPtr operator[](const FuncGraphPtr &func_graph);
73 
74   // Map of replicate nodes and graphs
cloned_nodes()75   const NodeToNodeMap &cloned_nodes() const { return replicated_node_; }
cloned_func_graphs()76   const mindspore::HashMap<FuncGraphPtr, FuncGraphPtr> &cloned_func_graphs() const { return replicated_func_graph_; }
77 
78   // Scope of cloned graphs
set_scope(const ScopePtr & scope)79   void set_scope(const ScopePtr &scope) { scope_ = scope; }
scope()80   const ScopePtr scope() const { return scope_; }
81 
82   // When clone nodes, the same debug info and scope.
set_update_info(const UpdateInfoPtr & update_info)83   void set_update_info(const UpdateInfoPtr &update_info) { update_info_ = update_info; }
update_info()84   const UpdateInfoPtr update_info() const { return update_info_; }
85 
86   // set call node debug info of InlineClone.
set_inline_call_node(const AnfNodePtr & call_node)87   void set_inline_call_node(const AnfNodePtr &call_node) { inline_call_node_ = call_node; }
88 
preset_abstract()89   bool preset_abstract() const { return preset_abstract_; }
set_preset_abstract(bool preset_abstract)90   void set_preset_abstract(bool preset_abstract) { preset_abstract_ = preset_abstract; }
91 
lifting_func_graph_filter()92   GraphFilterFunc lifting_func_graph_filter() const { return lifting_func_graph_filter_; }
set_lifting_func_graph_filter(GraphFilterFunc filter)93   void set_lifting_func_graph_filter(GraphFilterFunc filter) { lifting_func_graph_filter_ = filter; }
94 
95  private:
96   void CloneNodes();
97   void LinkCNodeEdges();
98   void SetDefaults();
99   void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target);
100   void CloneValueNode(const AnfNodePtr &node);
101   void CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr &target);
102   void CloneCNodeWithoutInputs(const AnfNodePtr &node, const FuncGraphPtr &target);
103   void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false);
104   void CloneValueNodes(const FuncGraphPtr &func_graph);
105   void AddChildGraphs(const FuncGraphPtr &func_graph);
106   void AddTotalGraphs(const FuncGraphPtr &func_graph);
107   bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline);
108   void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
109   void CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
110   void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
111   void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
112   void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
113   void SetFuncGraphInfo(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) const;
114   void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
115   void GenParameters(const FuncGraphPtr &func_graph);
116   void CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) const;
117   ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true);
118   void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList &inputs, size_t arg_start_index);
119   CNodePtr SetPartialEdges(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FuncGraphTransaction *tx);
120   void SetEdges(const FuncGraphPtr &func_graph, FuncGraphTransaction *tx);
121   void SetEdgesBfs(const FuncGraphPtr &root_fg, FuncGraphTransaction *tx);
122   void AddParameters(const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList &params,
123                      AnfNodeWeakPtrList *const lift_params, AnfNodeWeakPtrList *const input_params);
124   void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList &params);
125   void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
126                       const AnfNodeWeakPtrList &params);
127   void Lift(const std::vector<FuncGraphPtr> &sorted);
128   void LiftParameters(const FuncGraphVector &todo_func_graphs);
129   bool IsLiftTopFuncGraph(const FuncGraphPtr &func_graph);
130 
131   bool clone_all_valuenodes_;
132   bool clone_all_child_graphs_;
133   bool clone_all_used_graphs_;
134   bool preset_abstract_{true};
135   GraphFilterFunc lifting_func_graph_filter_;
136   TraceInfoPtr relation_;
137   TraceInfoPtr target_relation_;
138   NodeToNodeMap replicated_node_;
139   mindspore::HashMap<FuncGraphPtr, FuncGraphPtr> replicated_func_graph_;
140   FuncGraphManagerPtr manager_;
141   FuncGraphSet graph_set_;
142   ScopePtr scope_;
143   UpdateInfoPtr update_info_;
144   AnfNodePtr inline_call_node_{nullptr};
145   CloneType type_;
146   std::vector<CloneInfo> todo_;
147   mindspore::HashMap<FuncGraphPtr, bool> status_;
148   mindspore::HashMap<FuncGraphPtr, NodeToNodeMap> replicated_map_node_;
149   mindspore::HashMap<FuncGraphPtr, mindspore::HashMap<FuncGraphPtr, AnfNodePtr>> replicated_map_func_graph_;
150   mindspore::HashMap<FuncGraphPtr, AnfNodeWeakPtrList> replicated_func_graph_params_;
151 };
152 
153 MS_CORE_API AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
154                                    const AnfNodePtrList &func_graph_args, const AnfNodePtr &call_node = nullptr);
155 
156 MS_CORE_API FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph, bool preset_abstract = true,
157                                       const GraphFilterFunc &lifting_func_graph_filter = GraphFilterFunc());
158 MS_CORE_API FuncGraphVector LiftingCloneMulti(const FuncGraphVector &func_graphs);
159 
160 MS_CORE_API ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation);
161 
162 MS_CORE_API FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph,
163                                             const TraceInfoPtr &relation = std::make_shared<TraceTransform>());
164 MS_CORE_API FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes = false,
165                                     const UpdateInfoPtr update_info = nullptr);
166 }  // namespace mindspore
167 
168 #endif  // MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_
169