/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_H_

#include <vector>
#include <algorithm>
#include <utility>
#include "frontend/optimizer/irpass.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "include/common/utils/anfalgo.h"
#include "ir/func_graph.h"

namespace mindspore {
namespace opt {
namespace irpass {
constexpr auto kHandledNotRecomputeNodeFlag = "handled_not_recompute_node";
constexpr auto kPrimalFgCallerUserDataKey = "primal_fg_caller";
bool EnableCellReuse();

bool HasBpropGetter(const OptimizerPtr &opt, const AnfNodePtr &k_fg_caller);

AnfNodePtr GetBpropCaller(const FuncGraphManagerPtr &manager, const AnfNodePtr &bprop_getter);

bool AddRecomputeNodes(const FuncGraphPtr &root, const opt::OptimizerPtr &opt);

// For a call to a recompute k graph whose primal graph is marked FUNC_GRAPH_OUTPUT_NO_RECOMPUTE, forward the
// results of nodes marked as not to be recomputed from the primal graph into the k graph as extra parameters,
// so the backward pass reuses the forward results instead of recomputing them.
class RemoveNotRecomputeNode : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &opt, const AnfNodePtr &node) override {
    if (!EnableCellReuse()) {
      return nullptr;
    }
    Reset();
    auto k_fg_caller = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(k_fg_caller);
    if (!IsMatch(k_fg_caller)) {
      return nullptr;
    }

    MS_EXCEPTION_IF_NULL(opt);
    auto manager = opt->manager();
    MS_EXCEPTION_IF_NULL(manager);
    auto fg = node->func_graph();
    MS_EXCEPTION_IF_NULL(fg);

    bool has_bprop_getter = HasBpropGetter(opt, k_fg_caller);
    // If the k graph has already been handled, only the call nodes of the k graph and the primal graph
    // need to be handled.
    if (k_fg_->has_flag(kHandledNotRecomputeNodeFlag)) {
      return CreateNewCallerForHandledKGraph(manager, fg, k_fg_caller, has_bprop_getter);
    }

    // A k graph that only contains the primal part should not be recomputed.
    if (!has_bprop_getter) {
      return nullptr;
    }

    k_fg_->set_flag(kHandledNotRecomputeNodeFlag, true);
    std::vector<AnfNodePtr> new_primal_fg_outputs{NewValueNode(prim::kPrimMakeTuple), primal_fg_->output()};
    std::vector<AnfNodePtr> k_fg_nodes = TopoSort(k_fg_->get_return(), SuccDeeperSimple);
    int64_t not_recompute_count = 0;
    for (const auto &node_in_k_fg : k_fg_nodes) {
      auto [cnode_k_fg, primal_cnode] = GetNotRecomputeKGraphAndPrimalCNode(node_in_k_fg);
      if (cnode_k_fg == nullptr || primal_cnode == nullptr) {
        continue;
      }
      ++not_recompute_count;
      // Erase the flag to do inline later.
      cnode_k_fg->erase_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH);
      // Replace the primal node in the k graph with the corresponding node in the primal graph.
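      // The primal forward value is appended to the primal graph's outputs and a matching parameter is
      // added to the enclosing k graph; the first element of the inner k graph's make_tuple output (the
      // value that would otherwise be recomputed) is then replaced by that parameter.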
      (void)new_primal_fg_outputs.emplace_back(primal_cnode);
      auto para = k_fg_->add_parameter();
      auto cnode_k_fg_output = cnode_k_fg->output();
      if (!IsPrimitiveCNode(cnode_k_fg_output, prim::kPrimMakeTuple)) {
        MS_LOG(INTERNAL_EXCEPTION) << "The output of k graph should be make_tuple, but got "
                                   << cnode_k_fg_output->DebugString();
      }
      (void)manager->Replace(cnode_k_fg_output->cast<CNodePtr>()->input(1), para);
    }
    if (not_recompute_count == 0) {
      return nullptr;
    }

    primal_fg_->set_output(primal_fg_->NewCNode(new_primal_fg_outputs));
    auto primal_fg_caller = k_fg_caller->user_data<CNode>(kPrimalFgCallerUserDataKey);
    UpdateForwardResult(manager, primal_fg_caller);
    // Add new arguments to k graph caller.
    return CreateNewKGraphCaller(fg, k_fg_caller, primal_fg_caller, not_recompute_count);
  }

  AnfNodePtr CreateNewCallerForHandledKGraph(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
                                             const CNodePtr &k_fg_caller, bool has_bprop_getter) {
    auto not_recompute_count = SizeToLong(k_fg_->parameters().size() - (k_fg_caller->size() - 1));
    if (not_recompute_count == 0) {
      return nullptr;
    }
    if (!has_bprop_getter) {
      std::vector<AnfNodePtr> new_primal_caller_inputs{NewValueNode(primal_fg_)};
      (void)new_primal_caller_inputs.insert(new_primal_caller_inputs.cend(), k_fg_caller->inputs().begin() + 1,
                                            k_fg_caller->inputs().end());
      auto new_primal_caller = fg->NewCNodeInOrder(new_primal_caller_inputs);
      return new_primal_caller;
    }

    auto primal_fg_caller = k_fg_caller->user_data<CNode>(kPrimalFgCallerUserDataKey);
    UpdateForwardResult(manager, primal_fg_caller);
    return CreateNewKGraphCaller(fg, k_fg_caller, primal_fg_caller, not_recompute_count);
  }

  static AnfNodePtr CreateNewKGraphCaller(const FuncGraphPtr &fg, const CNodePtr &k_fg_caller,
                                          const CNodePtr &primal_fg_caller, int64_t not_recompute_count) {
    std::vector<AnfNodePtr> new_k_fg_caller_inputs;
    (void)new_k_fg_caller_inputs.insert(new_k_fg_caller_inputs.cend(), k_fg_caller->inputs().begin(),
                                        k_fg_caller->inputs().end());
    auto primal_fg_caller_fg = primal_fg_caller->func_graph();
    for (int64_t i = 1; i <= not_recompute_count; ++i) {
      auto extra_forward_result = primal_fg_caller_fg->NewCNodeInOrder(
        {NewValueNode(prim::kPrimTupleGetItem), primal_fg_caller, NewValueNode(i)});
      (void)new_k_fg_caller_inputs.emplace_back(extra_forward_result);
    }
    auto new_k_fg_caller = fg->NewCNodeInOrder(new_k_fg_caller_inputs);
    if (k_fg_caller->HasAttr(kAddedRecomputeDependAttr)) {
      new_k_fg_caller->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
    }
    return new_k_fg_caller;
  }

  static void UpdateForwardResult(const FuncGraphManagerPtr &manager, const AnfNodePtr &primal_fg_caller) {
    MS_EXCEPTION_IF_NULL(primal_fg_caller);
    auto fg = primal_fg_caller->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    auto forward_result = fg->NewCNodeInOrder(
      {NewValueNode(prim::kPrimTupleGetItem), primal_fg_caller, NewValueNode(static_cast<int64_t>(0))});
    (void)manager->Replace(primal_fg_caller, forward_result);
  }

  static std::pair<FuncGraphPtr, AnfNodePtr> GetNotRecomputeKGraphAndPrimalCNode(const AnfNodePtr &node) {
    auto cnode = dyn_cast<CNode>(node);
    if (cnode == nullptr) {
      return std::make_pair(nullptr, nullptr);
    }
    // call (k_fg, ...)
    auto cnode_k_fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
    if (cnode_k_fg == nullptr || !cnode_k_fg->has_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH)) {
      return std::make_pair(nullptr, nullptr);
    }
    // k_fg -> primal
    auto primal_cnode = GetPrimalCNode(cnode_k_fg);
    if (primal_cnode == nullptr) {
      MS_LOG(DEBUG) << "The cnode k_fg " << cnode_k_fg->ToString() << " should have corresponding primal_cnode.";
      return std::make_pair(nullptr, nullptr);
    }
    MS_LOG(DEBUG) << "primal_cnode: " << primal_cnode->DebugString();
    return std::make_pair(cnode_k_fg, primal_cnode);
  }

  static AnfNodePtr GetPrimalCNode(const FuncGraphPtr &cnode_k_fg) {
    auto primal_cnode_iter = cnode_k_fg->transforms().find("primal_cnode");
    if (primal_cnode_iter == cnode_k_fg->transforms().end()) {
      MS_LOG(DEBUG) << "Not found the primal cnode of k graph " << cnode_k_fg->ToString();
      return nullptr;
    }
    auto primal_cnode = primal_cnode_iter->second.primal_cnode();
    return primal_cnode;
  }

  bool IsMatch(const CNodePtr &k_fg_caller) {
    auto k_fg = GetValueNode<FuncGraphPtr>(k_fg_caller->input(0));
    if (k_fg == nullptr || !k_fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
      return false;
    }

    auto primal_iter = k_fg->transforms().find("primal");
    if (primal_iter == k_fg->transforms().end()) {
      return false;
    }
    auto primal_fg = primal_iter->second.func_graph();
    if (primal_fg == nullptr || !primal_fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
      return false;
    }

    k_fg_ = k_fg;
    primal_fg_ = primal_fg;
    return true;
  }

  void Reset() {
    k_fg_ = nullptr;
    primal_fg_ = nullptr;
  }

 private:
  FuncGraphPtr k_fg_{nullptr};
  FuncGraphPtr primal_fg_{nullptr};
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_RECOMPUTE_H_