1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ 19 20 #include <algorithm> 21 #include <memory> 22 #include <unordered_map> 23 #include <utility> 24 #include <vector> 25 26 #include "frontend/optimizer/irpass.h" 27 #include "frontend/optimizer/optimizer.h" 28 #include "frontend/optimizer/anf_visitor.h" 29 #include "frontend/operator/ops.h" 30 31 namespace mindspore { 32 namespace opt { 33 namespace irpass { 34 // {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} or {X, Ys, Xs} 35 class PartialEliminater : public AnfVisitor { 36 public: operator()37 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 38 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 39 return nullptr; 40 } 41 42 X_ = nullptr; 43 Xs_.clear(); 44 auto &inputs = node->cast<CNodePtr>()->inputs(); 45 Visit(inputs[0]); 46 47 if (Xs_.size() == 0) { 48 return nullptr; 49 } 50 51 // {X, Xs, Ys} 52 std::vector<AnfNodePtr> args{}; 53 const auto &xs_size = Xs_.size(); 54 // Xs_ don't have monad or Ys_ is 0. 55 if (!HasAbstractMonad(Xs_.back()) || inputs.empty()) { 56 args.push_back(X_); 57 (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args)); 58 (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); 59 TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info())); 60 auto new_node = node->func_graph()->NewCNode(args); 61 return new_node; 62 } 63 // {X, Ys, Xs} if Xs has monad 64 if (!IsValueNode<FuncGraph>(X_)) { 65 MS_LOG(EXCEPTION) << "not support yet as X_ is not a funcgraph. node: " << node->DebugString(2); 66 } 67 auto fg = GetValueNode<FuncGraphPtr>(X_); 68 MS_EXCEPTION_IF_NULL(fg); 69 if (fg->func_graph_cnodes_index().size() != 1) { 70 // If a graph is used by 2 or more partial nodes at the same time, clone the graph. 71 auto new_fg = BasicClone(fg); 72 auto new_fg_node = NewValueNode(new_fg); 73 fg->manager()->Replace(X_, new_fg_node); 74 fg = new_fg; 75 X_ = new_fg_node; 76 } 77 args.push_back(X_); 78 // Ys first; 79 (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); 80 (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args)); 81 TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info())); 82 auto new_node = node->func_graph()->NewCNode(args); 83 new_node->set_abstract(node->abstract()); 84 85 // reorder the formal parameter of fg. 86 AnfNodePtrList new_params; 87 std::copy(fg->parameters().cbegin() + xs_size, fg->parameters().cend(), std::back_inserter(new_params)); 88 std::copy(fg->parameters().cbegin(), fg->parameters().cbegin() + xs_size, std::back_inserter(new_params)); 89 fg->manager()->SetParameters(fg, new_params); 90 return new_node; 91 } 92 Visit(const AnfNodePtr & node)93 void Visit(const AnfNodePtr &node) override { 94 if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { 95 return; 96 } 97 98 auto &inputs = node->cast<CNodePtr>()->inputs(); 99 // {prim::kPrimPartial, X, Xs} 100 if (inputs.size() < 2) { 101 return; 102 } 103 104 X_ = inputs[1]; 105 // fill Xs 106 (void)std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(Xs_)); 107 } 108 109 private: 110 AnfNodePtr X_{nullptr}; 111 std::vector<AnfNodePtr> Xs_{}; 112 }; 113 114 class ChoicePartialEliminater : public AnfVisitor { 115 public: 116 virtual ~ChoicePartialEliminater() = default; 117 118 protected: 119 AnfNodePtrList fg_list_{}; 120 std::vector<AnfNodePtrList> args_list_{}; 121 Visit(const AnfNodePtr & node)122 void Visit(const AnfNodePtr &node) override { 123 if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { 124 if (IsValueNode<FuncGraph>(node)) { 125 fg_list_.push_back(node); 126 args_list_.push_back(AnfNodePtrList{}); 127 } 128 return; 129 } 130 131 auto &inputs = node->cast<CNodePtr>()->inputs(); 132 // {prim::kPrimPartial, G, Xs} 133 if (inputs.size() < 3) { 134 MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString(); 135 return; 136 } 137 if (IsValueNode<FuncGraph>(inputs[1])) { 138 fg_list_.push_back(inputs[1]); 139 AnfNodePtrList args; 140 (void)std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(args)); 141 args_list_.push_back(args); 142 } 143 return; 144 } 145 146 // return value: true -- continue replace; false -- return nullptr; CheckFuncGraphAndArgs()147 bool CheckFuncGraphAndArgs() { 148 // Either one should be {Partial, G, X} 149 auto has_partial_args = 150 std::any_of(args_list_.cbegin(), args_list_.cend(), [](auto &args) { return args.size() != 0; }); 151 if (!has_partial_args) { 152 return false; 153 } 154 155 // check funcgraph should be used once only. 156 for (size_t i = 0; i < fg_list_.size(); i++) { 157 auto fg_node = fg_list_[i]; 158 auto fg = GetValueNode<FuncGraphPtr>(fg_node); 159 MS_EXCEPTION_IF_NULL(fg); 160 if (fg->func_graph_cnodes_index().size() != 1) { 161 // If a graph is used by 2 or more partial nodes at the same time, clone the graph. 162 auto new_fg = BasicClone(fg); 163 auto manager = fg->manager(); 164 MS_EXCEPTION_IF_NULL(manager); 165 manager->AddFuncGraph(new_fg); 166 fg_node->cast<ValueNodePtr>()->set_value(new_fg); 167 } 168 } 169 return true; 170 } 171 172 // f(x1, x2, x3, z1, z2) 173 // g(x4, x2, z1, z2) 174 // h(x5, x2, x7, x8, z1, z2) 175 // --> anchor_fg = h 176 // h(x5, x2, x7, x8, x1, x3, x4, z1, z2) 177 // f(x5, x2, x7, x8, x1, x3, x4, z1, z2) 178 // g(x5, x2, x7, x8, x1, x3, x4, z1, z2) 179 // as z1, z2 maybe U or IO monad. UnifyParameters(const size_t & anchor_index,const AnfNodePtrList & fg_list,const std::vector<AnfNodePtrList> args_list)180 AnfNodePtrList UnifyParameters(const size_t &anchor_index, const AnfNodePtrList &fg_list, 181 const std::vector<AnfNodePtrList> args_list) { 182 std::vector<size_t> inputs_index_list[args_list.size()]; 183 size_t extra_input_counter = 0; 184 AnfNodePtrList extra_inputs; 185 const auto &anchor_args = args_list[anchor_index]; 186 size_t anchor_args_size = anchor_args.size(); 187 auto anchor_fg = GetValueNode<FuncGraphPtr>(fg_list[anchor_index]); 188 MS_EXCEPTION_IF_NULL(anchor_fg); 189 // Find the new location of the old_inputs except Zs; 190 for (size_t i = 0; i < args_list.size(); ++i) { 191 if (i == anchor_index) { 192 continue; 193 } 194 const auto &another_args = args_list[i]; 195 auto &curr_inputs_index = inputs_index_list[i]; 196 for (size_t j = 0; j < another_args.size(); ++j) { 197 size_t k; 198 for (k = 0; k < anchor_args_size; ++k) { 199 if (another_args[j] == anchor_args[k]) { 200 curr_inputs_index.push_back(k); 201 break; 202 } 203 } 204 if (k == anchor_args_size) { 205 // check if used by another func_graph; 206 for (k = 0; k < extra_input_counter; ++k) { 207 if (another_args[j] == extra_inputs[k]) { 208 curr_inputs_index.push_back(anchor_args_size + k); 209 break; 210 } 211 } 212 if (k == extra_input_counter) { 213 extra_inputs.push_back(another_args[j]); 214 curr_inputs_index.push_back(anchor_args_size + extra_input_counter); 215 extra_input_counter++; 216 } 217 } 218 } 219 } 220 221 auto manager = anchor_fg->manager(); 222 MS_EXCEPTION_IF_NULL(manager); 223 auto txn = manager->Transact(); 224 225 size_t anchor_params_size = anchor_fg->parameters().size(); 226 const auto &anchor_fg_params = anchor_fg->parameters(); 227 for (size_t i = 0; i < args_list.size(); ++i) { 228 if (i == anchor_index) { 229 continue; 230 } 231 AnfNodePtrList new_params; 232 new_params.resize(anchor_params_size + extra_input_counter); 233 234 const auto &curr_inputs_index = inputs_index_list[i]; 235 auto another_fg = GetValueNode<FuncGraphPtr>(fg_list[i]); 236 MS_EXCEPTION_IF_NULL(another_fg); 237 const auto &old_params = another_fg->parameters(); 238 const auto &old_args = args_list[i]; 239 for (size_t j = 0; j < old_args.size(); j++) { 240 new_params[curr_inputs_index[j]] = old_params[j]; 241 } 242 // Zs_ 243 for (size_t j = old_args.size(), k = 0; j < old_params.size(); ++j, ++k) { 244 new_params[anchor_args_size + extra_input_counter + k] = old_params[j]; 245 } 246 // unused inputs 247 for (size_t j = 0; j < anchor_args_size; ++j) { 248 if (new_params[j] == nullptr) { 249 TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info())); 250 ParameterPtr param = std::make_shared<Parameter>(another_fg); 251 new_params[j] = param; 252 } 253 } 254 // extra inputs used by another func_graph; 255 for (size_t j = 0; j < extra_inputs.size(); ++j) { 256 if (new_params[anchor_args_size + j] == nullptr) { 257 TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info())); 258 ParameterPtr param = std::make_shared<Parameter>(another_fg); 259 new_params[anchor_args_size + j] = param; 260 } 261 } 262 // set the parameter for another_fg and replace it's parameters; 263 txn.SetParameters(another_fg, new_params); 264 } 265 // Reorder Zs_ and add extra parameters for anchor_fg; 266 // add extra parameter for anchor_fg; 267 AnfNodePtrList new_params; 268 new_params.reserve(anchor_params_size + extra_input_counter); 269 // reuse parameters for anchor_args; 270 std::copy(anchor_fg_params.cbegin(), anchor_fg_params.cbegin() + anchor_args_size, std::back_inserter(new_params)); 271 // Extra parameters; 272 for (size_t i = 0; i < extra_inputs.size(); ++i) { 273 TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info())); 274 ParameterPtr param = std::make_shared<Parameter>(anchor_fg); 275 new_params.push_back(param); 276 } 277 // Reorder Zs_ to last; 278 for (size_t i = anchor_args_size; i < anchor_params_size; ++i) { 279 new_params.push_back(anchor_fg_params[i]); 280 } 281 txn.SetParameters(anchor_fg, new_params); 282 txn.Commit(); 283 284 return extra_inputs; 285 } 286 }; 287 288 // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}, Zs} -> 289 // {{prim::kPrimSwitch, cond, G1, G2}, Xs Union Ys Union Zs} 290 // {{prim::kPrimSwitch, cond, {G1}, {prim::kPrimPartial, G2, Ys}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Ys Union 291 // Zs} 292 // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {G2}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Xs Union 293 // Zs} 294 class SwitchPartialEliminater : public ChoicePartialEliminater { 295 public: operator()296 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 297 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 298 return nullptr; 299 } 300 auto cnode = node->cast<CNodePtr>(); 301 if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) { 302 return nullptr; 303 } 304 auto input0_cnode = cnode->input(0)->cast<CNodePtr>(); 305 if (input0_cnode->size() != 4) { 306 return nullptr; 307 } 308 309 fg_list_.clear(); 310 args_list_.clear(); 311 auto &maybe_partial_1 = input0_cnode->input(2); 312 Visit(maybe_partial_1); 313 auto &maybe_partial_2 = input0_cnode->input(3); 314 Visit(maybe_partial_2); 315 316 // Either one should be {Partial, G, X} 317 if (fg_list_.size() != 2 && args_list_.size() != 2) { 318 return nullptr; 319 } 320 // Should not continue; 321 if (!CheckFuncGraphAndArgs()) { 322 return nullptr; 323 } 324 325 if (args_list_[0] == args_list_[1]) { 326 auto new_node = 327 BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[0], AnfNodePtrList{}); 328 return new_node; 329 } else { 330 // find partial funcgraph with the longest args as anchor; 331 size_t max_args_pos = 0; 332 if (args_list_[0].size() > args_list_[1].size()) { 333 max_args_pos = 0; 334 } else { 335 max_args_pos = 1; 336 } 337 338 auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_); 339 auto new_node = 340 BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[max_args_pos], extra_inputs); 341 return new_node; 342 } 343 } 344 345 private: BuildNewSwitchNode(const CNodePtr & old_cnode,const CNodePtr input0_cnode,const AnfNodePtr & G1,const AnfNodePtr & G2,const AnfNodePtrList & partial_args,const AnfNodePtrList & extra_args)346 AnfNodePtr BuildNewSwitchNode(const CNodePtr &old_cnode, const CNodePtr input0_cnode, const AnfNodePtr &G1, 347 const AnfNodePtr &G2, const AnfNodePtrList &partial_args, 348 const AnfNodePtrList &extra_args) { 349 TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info())); 350 // {Switch, cond, G1, G2} 351 auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2}); 352 AnfNodePtrList args{switch_cnode}; 353 (void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args)); 354 (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args)); 355 // Zs 356 if (old_cnode->size() >= 2) { 357 (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args)); 358 } 359 TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info())); 360 auto new_node = old_cnode->func_graph()->NewCNode(args); 361 new_node->set_abstract(old_cnode->abstract()); 362 return new_node; 363 } 364 }; 365 366 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}}, Zs} -> 367 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}, Xs Union Ys Union Zs} 368 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{G1}, {prim::kPrimPartial, G2, Ys}}}, Zs} -> 369 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Ys Union Zs} 370 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {G2}}{}, Zs} -> 371 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Xs Union Zs} 372 class SwitchLayerPartialEliminater : public ChoicePartialEliminater { 373 public: operator()374 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 375 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 376 return nullptr; 377 } 378 auto cnode = node->cast<CNodePtr>(); 379 // {SwitchLayer{}, Zs} 380 if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitchLayer)) { 381 return nullptr; 382 } 383 auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>(); 384 // {SwitchLayer, cond, MakeTuple{}} 385 if (switch_layer_cnode->size() != 3) { 386 return nullptr; 387 } 388 if (!IsPrimitiveCNode(switch_layer_cnode->input(2), prim::kPrimMakeTuple)) { 389 return nullptr; 390 } 391 auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>(); 392 if (make_tuple_cnode->size() < 2) { 393 return nullptr; 394 } 395 396 fg_list_.clear(); 397 args_list_.clear(); 398 // Build funcgraph list and args list; 399 for (size_t i = 1; i < make_tuple_cnode->size(); ++i) { 400 Visit(make_tuple_cnode->input(i)); 401 } 402 403 if (!CheckFuncGraphAndArgs()) { 404 return nullptr; 405 } 406 // All have the same args; 407 auto args_equal = 408 std::all_of(args_list_.cbegin() + 1, args_list_.cend(), [this](auto &args) { return args == args_list_[0]; }); 409 if (args_equal) { 410 auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[0], AnfNodePtrList{}); 411 return new_node; 412 } else { 413 // find partial funcgraph with the longest args as anchor; 414 size_t max_args_pos = 0, max_args_len = 0; 415 for (size_t i = 0; i < args_list_.size(); ++i) { 416 if (max_args_len < args_list_[i].size()) { 417 max_args_len = args_list_[i].size(); 418 max_args_pos = i; 419 } 420 } 421 auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_); 422 auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[max_args_pos], extra_inputs); 423 return new_node; 424 } 425 } 426 427 private: BuildNewSwitchLayerNode(const CNodePtr & old_cnode,const CNodePtr switch_layer_cnode,const AnfNodePtrList & anchor_partial_args,const AnfNodePtrList & extra_args)428 AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr switch_layer_cnode, 429 const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) { 430 auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>(); 431 AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)}; 432 make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end()); 433 TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info())); 434 // {MakeTuple, G1, G2, ...} 435 auto new_make_tuple_cnode = old_cnode->func_graph()->NewCNode(make_tuple_args); 436 437 TraceGuard guard2(std::make_shared<TraceCopy>(switch_layer_cnode->debug_info())); 438 // {SwitchLayer, cond, MakeTuple{}} 439 auto new_switch_layer_cnode = old_cnode->func_graph()->NewCNode( 440 {switch_layer_cnode->input(0), switch_layer_cnode->input(1), new_make_tuple_cnode}); 441 AnfNodePtrList args{new_switch_layer_cnode}; 442 (void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args)); 443 (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args)); 444 // Zs 445 if (old_cnode->size() >= 2) { 446 (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args)); 447 } 448 TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info())); 449 auto new_node = old_cnode->func_graph()->NewCNode(args); 450 new_node->set_abstract(old_cnode->abstract()); 451 return new_node; 452 } 453 }; 454 } // namespace irpass 455 } // namespace opt 456 } // namespace mindspore 457 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ 458