1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ 19 20 #include <vector> 21 #include <algorithm> 22 #include <utility> 23 #include <memory> 24 25 #include "utils/hash_map.h" 26 #include "mindspore/core/ops/framework_ops.h" 27 #include "frontend/optimizer/irpass.h" 28 #include "frontend/optimizer/optimizer.h" 29 #include "frontend/optimizer/anf_visitor.h" 30 #include "ir/func_graph.h" 31 #include "ir/func_graph_cloner.h" 32 #include "frontend/operator/ops.h" 33 34 namespace mindspore { 35 namespace opt { 36 namespace irpass { 37 namespace internal { 38 class CallOutputTransform { 39 public: CallOutputTransform()40 CallOutputTransform() : cache_() {} 41 ~CallOutputTransform() = default; 42 operator()43 FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs, bool xs_first) { 44 if (cache_.find(fg) == cache_.end()) { 45 cache_[fg] = {}; 46 } 47 48 auto &cache = cache_[fg]; 49 auto key = std::make_pair(nargs, xs_first); 50 if (cache.find(key) == cache.end()) { 51 FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("call")); 52 53 std::vector<AnfNodePtr> new_items; 54 new_items.push_back(new_fg->output()); 55 if (xs_first) { 56 for (size_t i = 0; i < nargs; i++) { 57 new_items.push_back(new_fg->add_parameter()); 58 } 59 } else { 60 for (size_t i = 0; i < nargs; i++) { 61 new_items.push_back(new_fg->InsertFrontParameter()); 62 } 63 } 64 new_fg->set_output(new_fg->NewCNode(new_items)); 65 66 cache[key] = new_fg; 67 } 68 return cache[key]; 69 } 70 71 private: 72 mindspore::HashMap<FuncGraphPtr, mindspore::HashMap<std::pair<size_t, bool>, FuncGraphPtr, PairHasher>> cache_; 73 }; 74 } // namespace internal 75 76 // {{G, Xs}, Ys} 77 class IncorporateCall : public AnfVisitor { 78 public: IncorporateCall()79 IncorporateCall() : call_output_transform_() {} 80 ~IncorporateCall() override = default; 81 operator()82 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 83 Reset(); 84 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 85 return nullptr; 86 } 87 88 auto &inputs = node->cast<CNodePtr>()->inputs(); 89 if (inputs[0] == nullptr || !inputs[0]->isa<CNode>()) { 90 return nullptr; 91 } 92 93 AnfVisitor::Visit(inputs[0]); 94 if (fg_ == nullptr) { 95 return nullptr; 96 } 97 98 auto xs_size = Xs_.size(); 99 auto ys_size = inputs.size() - 1; 100 bool xs_first = true; 101 if ((xs_size > 0) && (Xs_[xs_size - 1]->abstract() != nullptr) && 102 (Xs_[xs_size - 1]->abstract()->isa<abstract::AbstractMonad>())) { 103 xs_first = false; 104 } 105 auto new_fg = call_output_transform_(fg_, ys_size, xs_first); 106 107 std::vector<AnfNodePtr> args; 108 args.push_back(NewValueNode(new_fg)); 109 110 if (xs_first) { 111 if (xs_size > 0) { 112 (void)args.insert(args.cend(), Xs_.cbegin(), Xs_.cend()); 113 } 114 if (ys_size > 0) { 115 (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend()); 116 } 117 } else { 118 if (ys_size > 0) { 119 (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend()); 120 } 121 if (xs_size > 0) { 122 (void)args.insert(args.cend(), Xs_.cbegin(), Xs_.cend()); 123 } 124 } 125 return MakeNewNode(node, args); 126 } 127 MakeNewNode(const AnfNodePtr & node,const std::vector<AnfNodePtr> & args)128 AnfNodePtr MakeNewNode(const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) { 129 auto new_node = node->func_graph()->NewCNode(args); 130 new_node->set_abstract(node->abstract()); 131 // Check if the another only usage of {G, Xs} is UpdateState{s, {G, Xs}}, if yes, replace 132 // UpdateState{s, {G, Xs}} with UpdateState{s, new_node}; 133 const auto &manager = fg_->manager(); 134 MS_EXCEPTION_IF_NULL(manager); 135 auto &node_users_map = manager->node_users(); 136 auto it = node_users_map.find(fg_call_cnode_); 137 if (it != node_users_map.end()) { 138 AnfNodePtr update_state_node = nullptr; 139 auto &node_users = it->second; 140 constexpr size_t users_size = 2; 141 if (node_users.size() == users_size) { 142 for (auto &node_user : node_users) { 143 if (IsPrimitiveCNode(node_user.first, prim::kPrimUpdateState)) { 144 update_state_node = node_user.first; 145 } 146 } 147 } 148 if (update_state_node != nullptr) { 149 auto update_state_cnode = update_state_node->cast<CNodePtr>(); 150 // double check; 151 const size_t attach_index = 2; 152 if (update_state_cnode->input(attach_index) == fg_call_cnode_) { 153 constexpr int recursive_level = 2; 154 MS_LOG(DEBUG) << "Replace UpdateState node: " << update_state_cnode->DebugString(recursive_level) 155 << ", input 2 with: " << new_node->DebugString(); 156 manager->SetEdge(update_state_cnode, attach_index, new_node); 157 } 158 } 159 } 160 return new_node; 161 } 162 Visit(const CNodePtr & cnode)163 void Visit(const CNodePtr &cnode) override { 164 // {G, Xs} 165 if (cnode->size() < 1 || !IsValueNode<FuncGraph>(cnode->input(0))) { 166 return; 167 } 168 169 auto &inputs = cnode->inputs(); 170 fg_ = GetValueNode<FuncGraphPtr>(inputs[0]); 171 fg_call_cnode_ = cnode; 172 (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); 173 } 174 Reset()175 void Reset() { 176 Xs_.clear(); 177 fg_ = nullptr; 178 fg_call_cnode_ = nullptr; 179 } 180 181 private: 182 FuncGraphPtr fg_; 183 CNodePtr fg_call_cnode_{nullptr}; 184 std::vector<AnfNodePtr> Xs_{}; 185 internal::CallOutputTransform call_output_transform_; 186 }; 187 188 // {{{prim::kPrimSwitch, X, G1, G2}, Xs}, Ys} 189 class IncorporateCallSwitch : public AnfVisitor { 190 public: IncorporateCallSwitch()191 IncorporateCallSwitch() : call_output_transform_() {} 192 ~IncorporateCallSwitch() override = default; 193 operator()194 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 195 Reset(); 196 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 197 return nullptr; 198 } 199 200 // {{...}, Ys} 201 auto &inputs = node->cast<CNodePtr>()->inputs(); 202 if (inputs[0] == nullptr || !inputs[0]->isa<CNode>()) { 203 return nullptr; 204 } 205 206 // {{...}, Xs} 207 auto &inputs_x = inputs[0]->cast<CNodePtr>()->inputs(); 208 if (inputs_x[0] == nullptr || !inputs_x[0]->isa<CNode>()) { 209 return nullptr; 210 } 211 212 // {prim::kPrimSwitch, X, G1, G2} 213 AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode<FuncGraph>, IsValueNode<FuncGraph>})(inputs_x[0]); 214 if (g2_ == nullptr) { 215 return nullptr; 216 } 217 218 auto fg = node->func_graph(); 219 auto xs_size = inputs_x.size() - 1; 220 auto ys_size = inputs.size() - 1; 221 bool xs_first = true; 222 if ((xs_size > 0) && (inputs_x[xs_size]->abstract() != nullptr) && 223 (inputs_x[xs_size]->abstract()->isa<abstract::AbstractMonad>())) { 224 xs_first = false; 225 } 226 auto new_g1 = call_output_transform_(g1_, ys_size, xs_first); 227 auto new_g2 = call_output_transform_(g2_, ys_size, xs_first); 228 auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); 229 230 std::vector<AnfNodePtr> args{sw_node}; 231 if (xs_first) { 232 if (xs_size > 0) { 233 (void)args.insert(args.cend(), inputs_x.cbegin() + 1, inputs_x.cend()); 234 } 235 if (ys_size > 0) { 236 (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend()); 237 } 238 } else { 239 if (ys_size > 0) { 240 (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend()); 241 } 242 if (xs_size > 0) { 243 (void)args.insert(args.cend(), inputs_x.cbegin() + 1, inputs_x.cend()); 244 } 245 } 246 247 auto new_node = fg->NewCNode(args); 248 new_node->set_abstract(node->abstract()); 249 return new_node; 250 } 251 Visit(const AnfNodePtr & node)252 void Visit(const AnfNodePtr &node) override { 253 if (x_ == nullptr) { 254 x_ = node; 255 return; 256 } 257 AnfVisitor::Visit(node); 258 } 259 Visit(const ValueNodePtr & vnode)260 void Visit(const ValueNodePtr &vnode) override { 261 auto g = GetValueNode<FuncGraphPtr>(vnode); 262 if (g1_ == nullptr) { 263 g1_ = g; 264 } else { 265 g2_ = g; 266 } 267 } 268 Reset()269 void Reset() { 270 x_ = nullptr; 271 g1_ = nullptr; 272 g2_ = nullptr; 273 } 274 275 private: 276 AnfNodePtr x_{nullptr}; 277 FuncGraphPtr g1_{nullptr}, g2_{nullptr}; 278 internal::CallOutputTransform call_output_transform_; 279 }; 280 } // namespace irpass 281 } // namespace opt 282 } // namespace mindspore 283 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ 284