1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H 19 20 #include <string> 21 #include <vector> 22 #include <algorithm> 23 #include <memory> 24 25 #include "utils/hash_map.h" 26 #include "mindspore/core/ops/sequence_ops.h" 27 #include "frontend/optimizer/optimizer.h" 28 #include "ir/func_graph_cloner.h" 29 30 namespace mindspore { 31 namespace opt { 32 bool FuncGraphHasSequenceInput(const FuncGraphPtr &fg); 33 bool FuncGraphHasConstantSequenceInput(const FuncGraphPtr &fg); 34 bool IsSequenceExpandable(const AbstractBasePtr &abs); 35 std::vector<AnfNodePtr> TransformSequenceArgument(const FuncGraphPtr &fg, const AnfNodePtr &node, 36 const abstract::AbstractSequencePtr &abs); 37 bool ContainSparseTensor(const abstract::AbstractBasePtr &abs); 38 bool ParamContainSparseTensor(const AnfNodePtr ¶m); 39 40 class GraphSequenceParamTransform { 41 public: GraphSequenceParamTransform()42 GraphSequenceParamTransform() : cache_() {} ~GraphSequenceParamTransform()43 ~GraphSequenceParamTransform() { cache_.clear(); } operator()44 FuncGraphPtr operator()(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { 45 if (cache_.find(fg) != cache_.end()) { 46 return cache_[fg]; 47 } 48 auto new_fg = TransformGraphParam(fg, mng); 49 cache_[fg] = new_fg; 50 return new_fg; 51 } 52 GenerateSequenceParams(const abstract::AbstractSequencePtr & seq_abs,const FuncGraphPtr & fg,std::vector<AnfNodePtr> * params)53 AnfNodePtr GenerateSequenceParams(const abstract::AbstractSequencePtr &seq_abs, const FuncGraphPtr &fg, 54 std::vector<AnfNodePtr> *params) { 55 std::vector<AnfNodePtr> inputs; 56 auto prim_sequence = seq_abs->isa<abstract::AbstractTuple>() ? prim::kPrimMakeTuple : prim::kPrimMakeList; 57 inputs.push_back(NewValueNode(prim_sequence)); 58 auto &elements = seq_abs->elements(); 59 for (auto &item : elements) { 60 if (item->isa<abstract::AbstractSequence>()) { 61 inputs.push_back(GenerateSequenceParams(item->cast<abstract::AbstractSequencePtr>(), fg, params)); 62 } else { 63 auto p = std::make_shared<Parameter>(fg); 64 p->set_abstract(item); 65 params->push_back(p); 66 inputs.push_back(params->back()); 67 } 68 } 69 auto node = fg->NewCNode(inputs); 70 node->set_abstract(seq_abs); 71 return node; 72 } 73 TransformGraphParam(const FuncGraphPtr & fg,const FuncGraphManagerPtr & mng)74 FuncGraphPtr TransformGraphParam(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) { 75 Cloner cloner({fg}, false, false, false, std::make_shared<TraceCopy>(), std::make_shared<TraceCopy>()); 76 auto new_fg = cloner[fg]; 77 auto ¶ms = new_fg->parameters(); 78 std::vector<AnfNodePtr> new_params; 79 mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl; 80 for (auto ¶m : params) { 81 auto abs = param->abstract(); 82 if (IsSequenceExpandable(abs)) { 83 auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>(); 84 std::vector<AnfNodePtr> sequence_params; 85 (void)repl.emplace(param, GenerateSequenceParams(sequence_abs, new_fg, &sequence_params)); 86 std::transform(sequence_params.begin(), sequence_params.end(), std::back_inserter(new_params), 87 [](AnfNodePtr p) { return p; }); 88 } else { 89 new_params.push_back(param); 90 } 91 } 92 auto tmp_mng = mindspore::Manage(new_fg, false); 93 auto tr = tmp_mng->Transact(); 94 for (auto &item : repl) { 95 bool ret = tr.Replace(item.first, item.second); 96 if (ret == false) { 97 MS_LOG(ERROR) << "replace failed" << item.first->DebugString() << " with__" 98 << item.second->DebugString(SizeToInt(kIndex2)); 99 } 100 } 101 tr.SetParameters(new_fg, new_params); 102 tr.Commit(); 103 mng->AddFuncGraph(new_fg); 104 return new_fg; 105 } 106 107 private: 108 mindspore::HashMap<FuncGraphPtr, FuncGraphPtr> cache_; 109 }; 110 } // namespace opt 111 } // namespace mindspore 112 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H 113