1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_ 18 #define MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_ 19 20 #include <functional> 21 #include <list> 22 #include <memory> 23 #include <string> 24 #include <unordered_map> 25 #include <unordered_set> 26 #include <utility> 27 #include <vector> 28 29 #include "ir/anf.h" 30 #include "ir/func_graph.h" 31 #include "ir/manager.h" 32 33 namespace mindspore { 34 class Cloner; 35 using ClonerPtr = std::shared_ptr<Cloner>; 36 37 enum CloneType { kBasic = 0, kInline = 1, kLifting = 2, kDropping = 3 }; 38 39 struct CloneInfo { 40 FuncGraphPtr origin; 41 FuncGraphPtr target; 42 AnfNodePtrList params; 43 }; 44 45 class Cloner { 46 public: 47 explicit Cloner(const FuncGraphVector &func_graphs = {}, bool clone_all_valuenodes = false, 48 bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, 49 const TraceInfoPtr &relation = std::make_shared<TraceCopy>(), 50 const TraceInfoPtr &target_relation = nullptr); 51 ~Cloner() = default; 52 void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr, 53 const AnfNodePtrList ¶ms = {}, CloneType type = kBasic); 54 void Run(); 55 56 // Interfaces for specializer 57 AnfNodePtr CloneDisconnected(const AnfNodePtr &root); 58 AnfNodePtr operator[](const AnfNodePtr &node); 59 FuncGraphPtr operator[](const FuncGraphPtr &func_graph); 60 61 // Map of replicate nodes and graphs cloned_node()62 std::unordered_map<AnfNodePtr, AnfNodePtr> *cloned_node() { return &repl_node_; } cloned_func_graph()63 std::unordered_map<FuncGraphPtr, FuncGraphPtr> &cloned_func_graph() { return repl_func_graph_; } 64 65 // Scope of cloned graphs set_scope(const ScopePtr & scope)66 void set_scope(const ScopePtr &scope) { scope_ = scope; } scope()67 const ScopePtr scope() const { return scope_; } 68 69 std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_; 70 std::unordered_map<FuncGraphPtr, FuncGraphPtr> repl_func_graph_; 71 72 private: 73 void CloneNodes(); 74 void LinkEdges(); 75 void SetDefaults(); 76 void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target); 77 void CloneValueNode(const AnfNodePtr &node); 78 void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target); 79 void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target); 80 void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false); 81 void CloneValueNodes(const FuncGraphPtr &func_graph); 82 void AddChildGraphs(const FuncGraphPtr &func_graph); 83 void AddTotalGraphs(const FuncGraphPtr &func_graph); 84 bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline); 85 void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 86 void CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 87 void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 88 void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 89 void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); 90 void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph); 91 void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); 92 void GenParameters(const FuncGraphPtr &func_graph); 93 void CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node); 94 ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true); 95 void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, AnfNodePtrList *const lift_params, 96 AnfNodePtrList *const input_params); 97 void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); 98 void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs, size_t arg_start_index); 99 void SetEdges(const FuncGraphPtr &func_graph, FuncGraphTransaction *tx); 100 void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, 101 const AnfNodePtrList ¶ms); 102 void Lift(const std::vector<FuncGraphPtr> &sorted); 103 void LiftParameters(const FuncGraphPtr &lift_top_func_graph); 104 105 bool clone_all_valuenodes_; 106 bool clone_all_child_graphs_; 107 bool clone_all_used_graphs_; 108 TraceInfoPtr relation_; 109 TraceInfoPtr target_relation_; 110 FuncGraphManagerPtr manager_; 111 FuncGraphSet graph_set_; 112 ScopePtr scope_; 113 CloneType type_; 114 std::list<CloneInfo> todo_; 115 std::list<std::pair<CNodePtr, CNodePtr>> nodes_; 116 std::unordered_map<FuncGraphPtr, bool> status_; 117 std::unordered_map<FuncGraphPtr, std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_map_node_; 118 std::unordered_map<FuncGraphPtr, std::unordered_map<FuncGraphPtr, AnfNodePtr>> repl_map_func_graph_; 119 std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_; 120 }; 121 122 AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, 123 const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr); 124 125 FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph); 126 127 ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation); 128 129 FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, 130 const TraceInfoPtr &relation = std::make_shared<TraceTransform>()); 131 FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes = false); 132 } // namespace mindspore 133 134 #endif // MINDSPORE_CORE_IR_FUNC_GRAPH_CLONER_H_ 135