• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
19 
20 #include <string>
21 #include <vector>
22 #include <algorithm>
23 #include <memory>
24 
25 #include "utils/hash_map.h"
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "frontend/optimizer/optimizer.h"
28 #include "ir/func_graph_cloner.h"
29 
30 namespace mindspore {
31 namespace opt {
// Free helpers declared here; their definitions are not in this header.
// Of these, only IsSequenceExpandable() is used in this file:
// GraphSequenceParamTransform::TransformGraphParam calls it on each parameter's
// abstract to decide whether that parameter is flattened into per-element parameters.
bool FuncGraphHasSequenceInput(const FuncGraphPtr &fg);
bool FuncGraphHasConstantSequenceInput(const FuncGraphPtr &fg);
// NOTE(review): when this returns true, TransformGraphParam casts `abs` to
// AbstractSequence, so it presumably only accepts sequence abstracts — confirm
// against the definition.
bool IsSequenceExpandable(const AbstractBasePtr &abs);
// NOTE(review): semantics of the helpers below are not visible from this header;
// see their definitions before relying on them.
std::vector<AnfNodePtr> TransformSequenceArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
                                                  const abstract::AbstractSequencePtr &abs);
bool ContainSparseTensor(const abstract::AbstractBasePtr &abs);
bool ParamContainSparseTensor(const AnfNodePtr &param);
39 
40 class GraphSequenceParamTransform {
41  public:
GraphSequenceParamTransform()42   GraphSequenceParamTransform() : cache_() {}
~GraphSequenceParamTransform()43   ~GraphSequenceParamTransform() { cache_.clear(); }
operator()44   FuncGraphPtr operator()(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
45     if (cache_.find(fg) != cache_.end()) {
46       return cache_[fg];
47     }
48     auto new_fg = TransformGraphParam(fg, mng);
49     cache_[fg] = new_fg;
50     return new_fg;
51   }
52 
GenerateSequenceParams(const abstract::AbstractSequencePtr & seq_abs,const FuncGraphPtr & fg,std::vector<AnfNodePtr> * params)53   AnfNodePtr GenerateSequenceParams(const abstract::AbstractSequencePtr &seq_abs, const FuncGraphPtr &fg,
54                                     std::vector<AnfNodePtr> *params) {
55     std::vector<AnfNodePtr> inputs;
56     auto prim_sequence = seq_abs->isa<abstract::AbstractTuple>() ? prim::kPrimMakeTuple : prim::kPrimMakeList;
57     inputs.push_back(NewValueNode(prim_sequence));
58     auto &elements = seq_abs->elements();
59     for (auto &item : elements) {
60       if (item->isa<abstract::AbstractSequence>()) {
61         inputs.push_back(GenerateSequenceParams(item->cast<abstract::AbstractSequencePtr>(), fg, params));
62       } else {
63         auto p = std::make_shared<Parameter>(fg);
64         p->set_abstract(item);
65         params->push_back(p);
66         inputs.push_back(params->back());
67       }
68     }
69     auto node = fg->NewCNode(inputs);
70     node->set_abstract(seq_abs);
71     return node;
72   }
73 
TransformGraphParam(const FuncGraphPtr & fg,const FuncGraphManagerPtr & mng)74   FuncGraphPtr TransformGraphParam(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
75     Cloner cloner({fg}, false, false, false, std::make_shared<TraceCopy>(), std::make_shared<TraceCopy>());
76     auto new_fg = cloner[fg];
77     auto &params = new_fg->parameters();
78     std::vector<AnfNodePtr> new_params;
79     mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl;
80     for (auto &param : params) {
81       auto abs = param->abstract();
82       if (IsSequenceExpandable(abs)) {
83         auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
84         std::vector<AnfNodePtr> sequence_params;
85         (void)repl.emplace(param, GenerateSequenceParams(sequence_abs, new_fg, &sequence_params));
86         std::transform(sequence_params.begin(), sequence_params.end(), std::back_inserter(new_params),
87                        [](AnfNodePtr p) { return p; });
88       } else {
89         new_params.push_back(param);
90       }
91     }
92     auto tmp_mng = mindspore::Manage(new_fg, false);
93     auto tr = tmp_mng->Transact();
94     for (auto &item : repl) {
95       bool ret = tr.Replace(item.first, item.second);
96       if (ret == false) {
97         MS_LOG(ERROR) << "replace failed" << item.first->DebugString() << " with__"
98                       << item.second->DebugString(SizeToInt(kIndex2));
99       }
100     }
101     tr.SetParameters(new_fg, new_params);
102     tr.Commit();
103     mng->AddFuncGraph(new_fg);
104     return new_fg;
105   }
106 
107  private:
108   mindspore::HashMap<FuncGraphPtr, FuncGraphPtr> cache_;
109 };
110 }  // namespace opt
111 }  // namespace mindspore
112 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
113