1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H 19 20 #include <unordered_map> 21 #include <string> 22 #include <vector> 23 #include <algorithm> 24 #include <memory> 25 26 #include "frontend/optimizer/optimizer.h" 27 28 namespace mindspore { 29 namespace opt { 30 bool CNodeHasTupleInput(const CNodePtr &cnode); 31 bool FuncGraphHasTupleInput(const FuncGraphPtr &fg); 32 std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node, 33 const abstract::AbstractTuplePtr &abs); 34 AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode); 35 AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode); 36 AnfNodePtr TransformSwitchCall(const AnfNodePtr &swtich_node, const CNodePtr &cnode); 37 38 class GraphTupleParamTransform { 39 public: GraphTupleParamTransform()40 GraphTupleParamTransform() : cache_() {} ~GraphTupleParamTransform()41 ~GraphTupleParamTransform() { cache_.clear(); } operator()42 FuncGraphPtr operator()(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { 43 if (cache_.find(fg) != cache_.end()) { 44 return cache_[fg]; 45 } 46 auto new_fg = TransformGraphParam(fg, mng); 47 cache_[fg] = new_fg; 48 return new_fg; 49 } 50 GenerateTupleParams(const abstract::AbstractTuplePtr & tuple_abs,const FuncGraphPtr & fg,std::vector<AnfNodePtr> * params)51 AnfNodePtr GenerateTupleParams(const abstract::AbstractTuplePtr &tuple_abs, const FuncGraphPtr &fg, 52 std::vector<AnfNodePtr> *params) { 53 std::vector<AnfNodePtr> inputs; 54 inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); 55 auto &elements = tuple_abs->elements(); 56 for (auto &item : elements) { 57 if (item->isa<abstract::AbstractTuple>()) { 58 inputs.push_back(GenerateTupleParams(item->cast<abstract::AbstractTuplePtr>(), fg, params)); 59 } else { 60 auto p = std::make_shared<Parameter>(fg); 61 p->set_abstract(item); 62 params->push_back(p); 63 inputs.push_back(params->back()); 64 } 65 } 66 auto node = fg->NewCNode(inputs); 67 node->set_abstract(tuple_abs); 68 return node; 69 } 70 TransformGraphParam(const FuncGraphPtr & fg,const FuncGraphManagerPtr & mng)71 FuncGraphPtr TransformGraphParam(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { 72 Cloner cloner({fg}, false, false, false, std::make_shared<TraceCopy>(), std::make_shared<TraceCopy>()); 73 auto new_fg = cloner[fg]; 74 auto ¶ms = new_fg->parameters(); 75 std::vector<AnfNodePtr> new_params; 76 std::unordered_map<AnfNodePtr, AnfNodePtr> repl; 77 for (auto ¶m : params) { 78 auto abs = param->abstract(); 79 if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) { 80 auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>(); 81 std::vector<AnfNodePtr> tuple_params; 82 repl.emplace(param, GenerateTupleParams(tuple_abs, new_fg, &tuple_params)); 83 std::transform(tuple_params.begin(), tuple_params.end(), std::back_inserter(new_params), 84 [](AnfNodePtr p) { return p; }); 85 } else { 86 new_params.push_back(param); 87 } 88 } 89 auto tmp_mng = mindspore::Manage(new_fg, false); 90 auto tr = tmp_mng->Transact(); 91 for (auto &item : repl) { 92 bool ret = tr.Replace(item.first, item.second); 93 if (ret == false) { 94 MS_LOG(ERROR) << "replace failed" << item.first->DebugString() << " with__" << item.second->DebugString(2); 95 } 96 } 97 tr.SetParameters(new_fg, new_params); 98 tr.Commit(); 99 mng->AddFuncGraph(new_fg); 100 return new_fg; 101 } 102 103 private: 104 std::unordered_map<FuncGraphPtr, FuncGraphPtr> cache_; 105 }; 106 } // namespace opt 107 } // namespace mindspore 108 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H 109