1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ 19 20 #include <algorithm> 21 #include <memory> 22 #include <utility> 23 #include <vector> 24 #include <set> 25 26 #include "utils/hash_map.h" 27 #include "mindspore/core/ops/sequence_ops.h" 28 #include "mindspore/core/ops/framework_ops.h" 29 #include "ir/func_graph_cloner.h" 30 #include "frontend/optimizer/irpass.h" 31 #include "frontend/optimizer/optimizer.h" 32 #include "frontend/optimizer/anf_visitor.h" 33 #include "frontend/operator/ops.h" 34 35 namespace mindspore { 36 namespace opt { 37 namespace irpass { 38 const auto kMinInputSizeOfCallWithArgs = 2; 39 // {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} or {X, Ys, Xs} 40 class PartialEliminater : public AnfVisitor { 41 public: operator()42 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 43 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 44 return nullptr; 45 } 46 X_ = nullptr; 47 Xs_.clear(); 48 auto &inputs = node->cast<CNodePtr>()->inputs(); 49 Visit(inputs[0]); 50 51 if (Xs_.size() == 0) { 52 return nullptr; 53 } 54 55 // {X, Xs, Ys} 56 std::vector<AnfNodePtr> args{}; 57 const auto xs_size = Xs_.size(); 58 // Xs_ don't have monad or Ys_ is 0. 59 if (!HasAbstractMonad(Xs_.back()) || inputs.empty()) { 60 args.push_back(X_); 61 (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args)); 62 (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); 63 TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info())); 64 auto new_node = node->func_graph()->NewCNode(args); 65 new_node->set_abstract(node->abstract()); 66 return new_node; 67 } 68 // {X, Ys, Xs} if Xs has monad 69 if (!IsValueNode<FuncGraph>(X_)) { 70 constexpr auto recursive_level = 2; 71 MS_LOG(INTERNAL_EXCEPTION) << "Not support yet as X_ is not a funcgraph. node: " 72 << node->DebugString(recursive_level); 73 } 74 auto fg = GetValueNode<FuncGraphPtr>(X_); 75 MS_EXCEPTION_IF_NULL(fg); 76 if (fg->func_graph_cnodes_index().size() != 1) { 77 // If a graph is used by 2 or more partial nodes at the same time, clone the graph. 78 auto new_fg = BasicClone(fg); 79 auto new_fg_node = NewValueNode(new_fg); 80 fg->manager()->Replace(X_, new_fg_node); 81 fg = new_fg; 82 X_ = new_fg_node; 83 } 84 args.push_back(X_); 85 // Ys first; 86 (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); 87 (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args)); 88 TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info())); 89 auto new_node = node->func_graph()->NewCNode(args); 90 new_node->set_abstract(node->abstract()); 91 92 // reorder the formal parameter of fg. 93 AnfNodePtrList new_params; 94 (void)std::copy(fg->parameters().cbegin() + SizeToLong(xs_size), fg->parameters().cend(), 95 std::back_inserter(new_params)); 96 (void)std::copy(fg->parameters().cbegin(), fg->parameters().cbegin() + SizeToLong(xs_size), 97 std::back_inserter(new_params)); 98 fg->manager()->SetParameters(fg, new_params); 99 return new_node; 100 } 101 Visit(const AnfNodePtr & node)102 void Visit(const AnfNodePtr &node) override { 103 if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { 104 return; 105 } 106 107 auto &inputs = node->cast<CNodePtr>()->inputs(); 108 // {prim::kPrimPartial, X, Xs} 109 if (inputs.size() <= 1) { 110 return; 111 } 112 113 X_ = inputs[1]; 114 // fill Xs 115 // {Partial, Function, Args....} 116 constexpr auto args_index = 2; 117 (void)std::copy(inputs.begin() + args_index, inputs.end(), std::back_inserter(Xs_)); 118 } 119 120 private: 121 AnfNodePtr X_{nullptr}; 122 std::vector<AnfNodePtr> Xs_{}; 123 }; 124 125 class ChoicePartialEliminater : public AnfVisitor { 126 public: 127 virtual ~ChoicePartialEliminater() = default; 128 Visit(const AnfNodePtr & node)129 void Visit(const AnfNodePtr &node) override { 130 if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { 131 if (IsValueNode<FuncGraph>(node)) { 132 fg_list_.push_back(node); 133 (void)args_list_.emplace_back(AnfNodePtrList{}); 134 } 135 return; 136 } 137 138 auto &inputs = node->cast<CNodePtr>()->inputs(); 139 // {prim::kPrimPartial, G} 140 if (inputs.size() < kPartialMinInputSize) { 141 MS_LOG(INTERNAL_EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString(); 142 } 143 if (IsValueNode<FuncGraph>(inputs[1])) { 144 fg_list_.push_back(inputs[1]); 145 AnfNodePtrList args; 146 // {Partial, Function, Args....} 147 constexpr auto args_index = 2; 148 (void)std::copy(inputs.begin() + args_index, inputs.end(), std::back_inserter(args)); 149 args_list_.push_back(args); 150 } 151 return; 152 } 153 154 protected: 155 AnfNodePtrList fg_list_{}; 156 std::vector<AnfNodePtrList> args_list_{}; 157 158 // return value: true -- continue replace; false -- return nullptr; CheckFuncGraphAndArgs()159 bool CheckFuncGraphAndArgs() { 160 // Either one should be {Partial, G, X} 161 auto has_partial_args = 162 std::any_of(args_list_.cbegin(), args_list_.cend(), [](auto &args) { return args.size() != 0; }); 163 if (!has_partial_args) { 164 return false; 165 } 166 167 // check funcgraph should be used once only. 168 for (size_t i = 0; i < fg_list_.size(); i++) { 169 auto fg_node = fg_list_[i]; 170 auto fg = GetValueNode<FuncGraphPtr>(fg_node); 171 MS_EXCEPTION_IF_NULL(fg); 172 if (fg->func_graph_cnodes_index().size() != 1) { 173 // If a graph is used by 2 or more partial nodes at the same time, clone the graph. 174 // BasicClone should be replaced by TransformableClone to avoid recursive. 175 auto new_fg = TransformableClone(fg); 176 auto manager = fg->manager(); 177 MS_EXCEPTION_IF_NULL(manager); 178 manager->AddFuncGraph(new_fg); 179 fg_list_[i] = NewValueNode(new_fg); 180 } 181 } 182 return true; 183 } 184 185 // Merge partial's args and call's args 186 // branch1: {{primPartial, Xs}, Zs} -> {{primPartial, Xs, Zs}} 187 // branch2: {{primPartial, Ys}, Zs} -> {{primPartial, Ys, Zs}} MergeArgs(const CNodePtr & call_node)188 void MergeArgs(const CNodePtr &call_node) { 189 for (auto &args : args_list_) { 190 (void)args.insert(args.end(), call_node->inputs().begin() + 1, call_node->inputs().end()); 191 } 192 } 193 194 // f(x1, x2, x3, z1, z2 ,monad1) 195 // g(x4, x2, z1, z2, monad2) 196 // h(x5, x2, x7, x8, z1, z2, monad3) 197 // --> union_args = (x1, x2, x3, z1, z2, x4, x5, x7 ,x8, monad1, monad2, monad3) 198 // h(x1, x2, x3, z1, z2, x4, x5, x7 ,x8, monad1, monad2, monad3) 199 // f(x1, x2, x3, z1, z2, x4, x5, x7 ,x8, monad1, monad2, monad3) 200 // g(x1, x2, x3, z1, z2, x4, x5, x7 ,x8, monad1, monad2, monad3) UnifyParameters(const AnfNodePtrList & fg_list,const std::vector<AnfNodePtrList> args_list)201 static AnfNodePtrList UnifyParameters(const AnfNodePtrList &fg_list, const std::vector<AnfNodePtrList> args_list) { 202 if (fg_list.empty()) { 203 return {}; 204 } 205 auto first_func_graph = GetValueNode<FuncGraphPtr>(fg_list[0]); 206 MS_EXCEPTION_IF_NULL(first_func_graph); 207 const auto manager = first_func_graph->manager(); 208 MS_EXCEPTION_IF_NULL(manager); 209 auto txn = manager->Transact(); 210 // Get all new args, new args is the union set of old args. 211 auto new_args = ArgsUnion(args_list); 212 auto old_args_index_map = GenOldArgsIndexes(fg_list, args_list); 213 for (size_t branch_index = 0; branch_index < fg_list.size(); ++branch_index) { 214 auto func_graph = GetValueNode<FuncGraphPtr>(fg_list[branch_index]); 215 MS_EXCEPTION_IF_NULL(func_graph); 216 auto new_parameters = GetFuncGraphNewParameters(func_graph, new_args, old_args_index_map); 217 txn.SetParameters(func_graph, new_parameters); 218 } 219 txn.Commit(); 220 return new_args; 221 } 222 223 private: ArgsUnion(const std::vector<AnfNodePtrList> args_list)224 static std::vector<AnfNodePtr> ArgsUnion(const std::vector<AnfNodePtrList> args_list) { 225 std::vector<AnfNodePtr> no_monad_args; 226 std::vector<AnfNodePtr> monad_args; 227 for (const auto &args : args_list) { 228 for (const auto &arg : args) { 229 if (HasAbstractMonad(arg)) { 230 if (count(monad_args.begin(), monad_args.end(), arg) == 0) { 231 monad_args.push_back(arg); 232 } 233 continue; 234 } 235 if (count(no_monad_args.begin(), no_monad_args.end(), arg) == 0) { 236 no_monad_args.push_back(arg); 237 } 238 } 239 } 240 // Keep monad args after no monad args. 241 (void)no_monad_args.insert(no_monad_args.end(), monad_args.begin(), monad_args.end()); 242 return no_monad_args; 243 } 244 GenOldArgsIndexes(const AnfNodePtrList & fg_list,const std::vector<AnfNodePtrList> & args_list)245 static HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> GenOldArgsIndexes( 246 const AnfNodePtrList &fg_list, const std::vector<AnfNodePtrList> &args_list) { 247 HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> old_args_indexes; 248 for (size_t i = 0; i < fg_list.size(); ++i) { 249 const auto func_graph = GetValueNode<FuncGraphPtr>(fg_list[i]); 250 MS_EXCEPTION_IF_NULL(func_graph); 251 const auto &args = args_list[i]; 252 HashMap<AnfNodePtr, size_t> args_indexes; 253 size_t arg_index = 0; 254 for (const auto &arg : args) { 255 (void)args_indexes.emplace(arg, arg_index++); 256 } 257 old_args_indexes[func_graph] = args_indexes; 258 } 259 return old_args_indexes; 260 } 261 GetParameterByArg(const HashMap<FuncGraphPtr,HashMap<AnfNodePtr,size_t>> & all_old_args_index_map,const AnfNodePtr & arg)262 static AnfNodePtr GetParameterByArg(const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map, 263 const AnfNodePtr &arg) { 264 MS_LOG(DEBUG) << "Get parameter by arg:" << arg->DebugString(); 265 for (const auto &[fg, old_args_index] : all_old_args_index_map) { 266 auto it = old_args_index.find(arg); 267 if (it == old_args_index.end()) { 268 continue; 269 } 270 size_t arg_index = it->second; 271 if (arg_index >= fg->parameters().size()) { 272 MS_LOG(INTERNAL_EXCEPTION) << "Index:" << arg_index << " out of range:" << fg->parameters().size(); 273 } 274 return fg->parameters()[arg_index]; 275 } 276 MS_LOG(INTERNAL_EXCEPTION) << "Can't find parameter of arg:" << arg->DebugString(); 277 } 278 GetFuncGraphNewParameters(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & new_args,const HashMap<FuncGraphPtr,HashMap<AnfNodePtr,size_t>> & all_old_args_index_map)279 static std::vector<AnfNodePtr> GetFuncGraphNewParameters( 280 const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &new_args, 281 const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map) { 282 MS_EXCEPTION_IF_NULL(func_graph); 283 const auto &old_parameters = func_graph->parameters(); 284 std::vector<AnfNodePtr> new_parameters(new_args.size()); 285 const auto &old_args_index_map = all_old_args_index_map.find(func_graph)->second; 286 for (size_t new_arg_index = 0; new_arg_index < new_args.size(); ++new_arg_index) { 287 const auto &new_arg = new_args[new_arg_index]; 288 auto arg_old_index_it = old_args_index_map.find(new_arg); 289 // The new_arg is the arg of current func graph. 290 if (arg_old_index_it != old_args_index_map.end()) { 291 auto arg_old_index = arg_old_index_it->second; 292 new_parameters[new_arg_index] = old_parameters[arg_old_index]; 293 MS_LOG(DEBUG) << "Find exist parameter:" << new_parameters[new_arg_index]->DebugString() 294 << ", arg_old_index:" << arg_old_index; 295 continue; 296 } 297 // The new_arg is the arg of other func graph. 298 const auto other_fg_parameter = GetParameterByArg(all_old_args_index_map, new_arg); 299 MS_LOG(DEBUG) << "Get other fg's parameter:" << other_fg_parameter->DebugString(); 300 TraceGuard guard(std::make_shared<TraceCopy>(other_fg_parameter->debug_info())); 301 ParameterPtr param = std::make_shared<Parameter>(func_graph); 302 param->set_abstract(other_fg_parameter->abstract()); 303 new_parameters[new_arg_index] = param; 304 } 305 return new_parameters; 306 } 307 }; 308 309 // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}, Zs} -> 310 // {{prim::kPrimSwitch, cond, G1, G2}, Xs Union Ys Union Zs} 311 // {{prim::kPrimSwitch, cond, {G1}, {prim::kPrimPartial, G2, Ys}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Ys Union 312 // Zs} 313 // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {G2}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Xs Union 314 // Zs} 315 class SwitchPartialEliminater : public ChoicePartialEliminater { 316 public: operator()317 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 318 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 319 return nullptr; 320 } 321 auto switch_call = node->cast<CNodePtr>(); 322 if (!IsPrimitiveCNode(switch_call->input(0), prim::kPrimSwitch)) { 323 return nullptr; 324 } 325 auto switch_node = switch_call->input(0)->cast<CNodePtr>(); 326 if (switch_node->size() != kSwitchInputSize) { 327 return nullptr; 328 } 329 fg_list_.clear(); 330 args_list_.clear(); 331 const auto maybe_partial_1 = switch_node->input(kSwitchTrueBranchIndex); 332 Visit(maybe_partial_1); 333 const auto maybe_partial_2 = switch_node->input(kSwitchFalseBranchIndex); 334 Visit(maybe_partial_2); 335 336 // Either one should be {Partial, G, X} 337 if (fg_list_.size() != kSwitchBranchesNum && args_list_.size() != kSwitchBranchesNum) { 338 return nullptr; 339 } 340 if (!CheckFuncGraphAndArgs()) { 341 return nullptr; 342 } 343 MergeArgs(switch_call); 344 if (args_list_[0] == args_list_[1]) { 345 return BuildNewSwitchNode(switch_call, args_list_[0]); 346 } else { 347 const auto new_args = UnifyParameters(fg_list_, args_list_); 348 return BuildNewSwitchNode(switch_call, new_args); 349 } 350 } 351 352 private: BuildNewSwitchNode(const CNodePtr & switch_call,const std::vector<AnfNodePtr> & new_args)353 AnfNodePtr BuildNewSwitchNode(const CNodePtr &switch_call, const std::vector<AnfNodePtr> &new_args) { 354 auto fg = switch_call->func_graph(); 355 MS_EXCEPTION_IF_NULL(fg); 356 const auto input0 = switch_call->input(0); 357 MS_EXCEPTION_IF_NULL(input0); 358 const auto switch_node = input0->cast<CNodePtr>(); 359 TraceGuard guard1(std::make_shared<TraceCopy>(switch_node->debug_info())); 360 // {Switch, cond, G1, G2} 361 std::vector<AnfNodePtr> switch_inputs = {switch_node->input(0), switch_node->input(1)}; 362 (void)switch_inputs.insert(switch_inputs.end(), fg_list_.begin(), fg_list_.end()); 363 const auto new_switch_cnode = fg->NewCNode(std::move(switch_inputs)); 364 new_switch_cnode->set_abstract(switch_node->abstract()); 365 // Create switch call. 366 TraceGuard guard2(std::make_shared<TraceCopy>(switch_call->debug_info())); 367 AnfNodePtrList switch_call_inputs{new_switch_cnode}; 368 (void)switch_call_inputs.insert(switch_call_inputs.end(), new_args.begin(), new_args.end()); 369 const auto new_call_node = fg->NewCNode(std::move(switch_call_inputs)); 370 new_call_node->set_abstract(switch_call->abstract()); 371 return new_call_node; 372 } 373 }; 374 375 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}}, Zs} -> 376 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}, Xs Union Ys Union Zs} 377 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{G1}, {prim::kPrimPartial, G2, Ys}}}, Zs} -> 378 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Ys Union Zs} 379 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {G2}}{}, Zs} -> 380 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Xs Union Zs} 381 class SwitchLayerPartialEliminater : public ChoicePartialEliminater { 382 public: operator()383 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 384 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 385 return nullptr; 386 } 387 auto switch_layer_call = node->cast<CNodePtr>(); 388 MS_EXCEPTION_IF_NULL(switch_layer_call); 389 // {SwitchLayer{}, Zs} 390 if (!IsPrimitiveCNode(switch_layer_call->input(0), prim::kPrimSwitchLayer)) { 391 return nullptr; 392 } 393 auto switch_layer_cnode = switch_layer_call->input(0)->cast<CNodePtr>(); 394 // {SwitchLayer, cond, MakeTuple{}} 395 if (switch_layer_cnode->size() != kSwitchLayerInputSize) { 396 return nullptr; 397 } 398 if (!IsPrimitiveCNode(switch_layer_cnode->input(kSwitchLayerBranchesIndex), prim::kPrimMakeTuple)) { 399 return nullptr; 400 } 401 auto make_tuple_cnode = switch_layer_cnode->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>(); 402 if (make_tuple_cnode->size() <= 1) { 403 return nullptr; 404 } 405 406 fg_list_.clear(); 407 args_list_.clear(); 408 // Build funcgraph list and args list; 409 for (size_t i = 1; i < make_tuple_cnode->size(); ++i) { 410 Visit(make_tuple_cnode->input(i)); 411 } 412 413 if (!CheckFuncGraphAndArgs()) { 414 return nullptr; 415 } 416 MergeArgs(switch_layer_call); 417 // All have the same args; 418 auto args_equal = 419 std::all_of(args_list_.cbegin() + 1, args_list_.cend(), [this](auto &args) { return args == args_list_[0]; }); 420 if (args_equal) { 421 return BuildNewSwitchLayerNode(switch_layer_call, args_list_[0]); 422 } else { 423 const auto new_args = UnifyParameters(fg_list_, args_list_); 424 return BuildNewSwitchLayerNode(switch_layer_call, new_args); 425 } 426 } 427 428 private: BuildNewSwitchLayerNode(const CNodePtr & switch_layer_call_node,const AnfNodePtrList & new_args)429 AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &switch_layer_call_node, const AnfNodePtrList &new_args) { 430 const auto switch_layer = switch_layer_call_node->input(0)->cast<CNodePtr>(); 431 MS_EXCEPTION_IF_NULL(switch_layer); 432 auto make_tuple_cnode = switch_layer->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>(); 433 MS_EXCEPTION_IF_NULL(make_tuple_cnode); 434 // {primMakeTuple, G1, G2, ...} 435 AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)}; 436 (void)make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end()); 437 TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info())); 438 auto new_make_tuple_cnode = make_tuple_cnode->func_graph()->NewCNode(std::move(make_tuple_args)); 439 // {primSwitchLayer, cond, MakeTuple{}} 440 TraceGuard guard2(std::make_shared<TraceCopy>(switch_layer->debug_info())); 441 auto new_switch_layer = 442 switch_layer->func_graph()->NewCNode({switch_layer->input(0), switch_layer->input(1), new_make_tuple_cnode}); 443 // Create new switch_layer call node. 444 TraceGuard guard3(std::make_shared<TraceCopy>(switch_layer_call_node->debug_info())); 445 AnfNodePtrList switch_layer_call_inputs{new_switch_layer}; 446 (void)switch_layer_call_inputs.insert(switch_layer_call_inputs.cend(), new_args.cbegin(), new_args.cend()); 447 auto new_node = switch_layer_call_node->func_graph()->NewCNode(std::move(switch_layer_call_inputs)); 448 new_node->set_abstract(switch_layer_call_node->abstract()); 449 return new_node; 450 } 451 }; 452 } // namespace irpass 453 } // namespace opt 454 } // namespace mindspore 455 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ 456