1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_ 18 #define MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_ 19 20 #include <functional> 21 #include <list> 22 #include <memory> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 27 #include "utils/hash_map.h" 28 #include "utils/hash_set.h" 29 #include "ir/anf.h" 30 #include "ir/func_graph.h" 31 #include "ir/manager.h" 32 #include "utils/hashing.h" 33 #include "mindapi/base/macros.h" 34 35 namespace mindspore { 36 class Cloner; 37 using ClonerPtr = std::shared_ptr<Cloner>; 38 using NodeToNodeMap = mindspore::HashMap<AnfNodePtr, AnfNodePtr>; 39 40 enum CloneType { kBasic = 0, kInline = 1, kLifting = 2, kDropping = 3 }; 41 42 struct CloneInfo { 43 FuncGraphPtr origin; 44 FuncGraphPtr target; 45 AnfNodePtrList params; 46 }; 47 48 struct UpdateInfo { UpdateInfoUpdateInfo49 UpdateInfo(const ScopePtr &scope, const NodeDebugInfoPtr &debug_info) : scope_(scope), debug_info_(debug_info) {} 50 ~UpdateInfo() = default; 51 52 ScopePtr scope_; 53 NodeDebugInfoPtr debug_info_; 54 }; 55 56 using UpdateInfoPtr = std::shared_ptr<UpdateInfo>; 57 58 class MS_CORE_API Cloner { 59 public: 60 explicit Cloner(const FuncGraphVector &func_graphs = {}, bool clone_all_valuenodes = false, 61 bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, 62 const TraceInfoPtr &relation = std::make_shared<TraceCopy>(), 63 const TraceInfoPtr &target_relation = nullptr); 64 ~Cloner() = default; 65 void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr, 66 const AnfNodePtrList ¶ms = {}, CloneType type = kBasic); 67 void Run(); 68 69 // Interfaces for specializer 70 AnfNodePtr CloneDisconnected(const AnfNodePtr &root); 71 AnfNodePtr operator[](const AnfNodePtr &node); 72 FuncGraphPtr operator[](const FuncGraphPtr &func_graph); 73 74 // Map of replicate nodes and graphs cloned_nodes()75 const NodeToNodeMap &cloned_nodes() const { return replicated_node_; } cloned_func_graphs()76 const mindspore::HashMap<FuncGraphPtr, FuncGraphPtr> &cloned_func_graphs() const { return replicated_func_graph_; } 77 78 // Scope of cloned graphs set_scope(const ScopePtr & scope)79 void set_scope(const ScopePtr &scope) { scope_ = scope; } scope()80 const ScopePtr scope() const { return scope_; } 81 82 // When clone nodes, the same debug info and scope. set_update_info(const UpdateInfoPtr & update_info)83 void set_update_info(const UpdateInfoPtr &update_info) { update_info_ = update_info; } update_info()84 const UpdateInfoPtr update_info() const { return update_info_; } 85 86 // set call node debug info of InlineClone. set_inline_call_node(const AnfNodePtr & call_node)87 void set_inline_call_node(const AnfNodePtr &call_node) { inline_call_node_ = call_node; } 88 preset_abstract()89 bool preset_abstract() const { return preset_abstract_; } set_preset_abstract(bool preset_abstract)90 void set_preset_abstract(bool preset_abstract) { preset_abstract_ = preset_abstract; } 91 lifting_func_graph_filter()92 GraphFilterFunc lifting_func_graph_filter() const { return lifting_func_graph_filter_; } set_lifting_func_graph_filter(GraphFilterFunc filter)93 void set_lifting_func_graph_filter(GraphFilterFunc filter) { lifting_func_graph_filter_ = filter; } 94 95 private: 96 void CloneNodes(); 97 void LinkCNodeEdges(); 98 void SetDefaults(); 99 void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target); 100 void CloneValueNode(const AnfNodePtr &node); 101 void CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr &target); 102 void CloneCNodeWithoutInputs(const AnfNodePtr &node, const FuncGraphPtr &target); 103 void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false); 104 void CloneValueNodes(const FuncGraphPtr &func_graph); 105 void AddChildGraphs(const FuncGraphPtr &func_graph); 106 void AddTotalGraphs(const FuncGraphPtr &func_graph); 107 bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline); 108 void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 109 void CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 110 void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 111 void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 112 void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); 113 void SetFuncGraphInfo(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) const; 114 void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 115 void GenParameters(const FuncGraphPtr &func_graph); 116 void CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) const; 117 ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true); 118 void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList &inputs, size_t arg_start_index); 119 CNodePtr SetPartialEdges(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FuncGraphTransaction *tx); 120 void SetEdges(const FuncGraphPtr &func_graph, FuncGraphTransaction *tx); 121 void SetEdgesBfs(const FuncGraphPtr &root_fg, FuncGraphTransaction *tx); 122 void AddParameters(const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList ¶ms, 123 AnfNodeWeakPtrList *const lift_params, AnfNodeWeakPtrList *const input_params); 124 void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList ¶ms); 125 void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, 126 const AnfNodeWeakPtrList ¶ms); 127 void Lift(const std::vector<FuncGraphPtr> &sorted); 128 void LiftParameters(const FuncGraphVector &todo_func_graphs); 129 bool IsLiftTopFuncGraph(const FuncGraphPtr &func_graph); 130 131 bool clone_all_valuenodes_; 132 bool clone_all_child_graphs_; 133 bool clone_all_used_graphs_; 134 bool preset_abstract_{true}; 135 GraphFilterFunc lifting_func_graph_filter_; 136 TraceInfoPtr relation_; 137 TraceInfoPtr target_relation_; 138 NodeToNodeMap replicated_node_; 139 mindspore::HashMap<FuncGraphPtr, FuncGraphPtr> replicated_func_graph_; 140 FuncGraphManagerPtr manager_; 141 FuncGraphSet graph_set_; 142 ScopePtr scope_; 143 UpdateInfoPtr update_info_; 144 AnfNodePtr inline_call_node_{nullptr}; 145 CloneType type_; 146 std::vector<CloneInfo> todo_; 147 mindspore::HashMap<FuncGraphPtr, bool> status_; 148 mindspore::HashMap<FuncGraphPtr, NodeToNodeMap> replicated_map_node_; 149 mindspore::HashMap<FuncGraphPtr, mindspore::HashMap<FuncGraphPtr, AnfNodePtr>> replicated_map_func_graph_; 150 mindspore::HashMap<FuncGraphPtr, AnfNodeWeakPtrList> replicated_func_graph_params_; 151 }; 152 153 MS_CORE_API AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, 154 const AnfNodePtrList &func_graph_args, const AnfNodePtr &call_node = nullptr); 155 156 MS_CORE_API FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph, bool preset_abstract = true, 157 const GraphFilterFunc &lifting_func_graph_filter = GraphFilterFunc()); 158 MS_CORE_API FuncGraphVector LiftingCloneMulti(const FuncGraphVector &func_graphs); 159 160 MS_CORE_API ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation); 161 162 MS_CORE_API FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, 163 const TraceInfoPtr &relation = std::make_shared<TraceTransform>()); 164 MS_CORE_API FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes = false, 165 const UpdateInfoPtr update_info = nullptr); 166 } // namespace mindspore 167 168 #endif // MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_ 169