• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
17 #include <memory>
18 #include <set>
19 #include <functional>
20 #include "src/common/utils.h"
21 #include "utils/utils.h"
22 #include "tools/optimizer/common/gllo_utils.h"
23 #include "securec/include/securec.h"
24 #include "tools/converter/ops/ops_def.h"
25 #include "nnacl/op_base.h"
26 
27 namespace mindspore {
28 namespace opt {
29 namespace {
// Pattern vars per direction; layout is 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias.
constexpr size_t kNumFwVars{4};
constexpr size_t kNumBwVars{4};
// Placeholder standing for the matched node when IsOpType is bound into a CondVar predicate.
const auto &p1 = std::placeholders::_1;
GetPrim(const PrimitivePtr & prim)33 BaseRef GetPrim(const PrimitivePtr &prim) {
34   auto ptr = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim));
35   MS_CHECK_TRUE_MSG(ptr != nullptr, nullptr, "is nullptr.");
36   return ptr;
37 }
38 
GetPrim(const std::string & prim_name)39 BaseRef GetPrim(const std::string &prim_name) {
40   auto prim = std::make_shared<Primitive>(prim_name);
41   MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "is nullptr.");
42   return GetPrim(prim);
43 }
44 }  // namespace
45 
TfBidirectionGruCfFusion(const std::string & name,bool multi_graph)46 TfBidirectionGruCfFusion::TfBidirectionGruCfFusion(const std::string &name, bool multi_graph)
47     : TfBidirectionGruFusion(kNumFwVars, kNumBwVars, name, multi_graph) {
48   /*
49    * vars for fw/bw input
50    * fw:
51    * 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
52    * bw:
53    * 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
54    */
55 }
56 
/*
 * Builds the pattern for one step of the GRU cell inside the TF while-loop body:
 *
 *   gates = Activation(BiasAdd(MatMul(Concat(x, h), kernel_gate), bias_gate))
 *   r, z  = TupleGetItem(Split(gates), *)
 *   cand  = Activation(BiasAdd(MatMul(Concat(x, r * h), cand_kernel), cand_bias))
 *   h_new = z * h + (p - z) * cand
 *
 * The local names say sigmoid/tanh, but the pattern only checks for an Activation op.
 * `p` is any parameter node (presumably the constant 1 -- TODO confirm).
 *
 * in_ta_read:   per-step input x, read from the input TensorArray.
 * switch3_true: previous hidden state h carried around the loop.
 * vars:         0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias; each is matched behind an Enter node.
 * Returns the AddFusion node producing h_new, or an empty BaseRef when `vars` has fewer than 4 entries.
 */
BaseRef TfBidirectionGruCfFusion::DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
                                                       const std::vector<VarPtr> &vars) const {
  // Indices into `vars`; the forward and backward cells share this layout.
  constexpr size_t kKernelGateIdx = 0;
  constexpr size_t kBiasGateIdx = 1;
  constexpr size_t kCandKernelIdx = 2;
  constexpr size_t kCandBiasIdx = 3;
  constexpr size_t kCellVarNum = 4;
  // Guard the indexing below; an undersized vector would otherwise be read out of bounds.
  if (vars.size() < kCellVarNum) {
    MS_LOG(ERROR) << "GRU cell pattern needs " << kCellVarNum << " vars, but got " << vars.size();
    return {};
  }

  // Gates: reset/update computed together from [x, h].
  auto concat = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, switch3_true});
  auto matmul_enter = VectorRef({GetPrim(lite::kNameEnter), vars[kKernelGateIdx]});
  auto matmul = VectorRef({GetPrim(prim::kPrimMatMul), concat, matmul_enter});
  auto bias_enter = VectorRef({GetPrim(lite::kNameEnter), vars[kBiasGateIdx]});
  auto bias = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul, bias_enter});
  auto sigmoid = VectorRef({GetPrim(prim::kPrimActivation), bias});
  auto split = VectorRef({GetPrim(prim::kPrimSplit), sigmoid});
  auto rt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared<Var>()});
  auto zt = VectorRef({GetPrim(prim::kPrimTupleGetItem), split, std::make_shared<Var>()});

  // Candidate state from [x, r * h].
  auto mul = VectorRef({GetPrim(prim::kPrimMulFusion), rt, switch3_true});
  auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, mul});
  auto matmul1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[kCandKernelIdx]});
  auto matmul1 = VectorRef({GetPrim(prim::kPrimMatMul), concat1, matmul1_enter});
  auto bias1_enter = VectorRef({GetPrim(lite::kNameEnter), vars[kCandBiasIdx]});
  auto bias1 = VectorRef({GetPrim(prim::kPrimBiasAdd), matmul1, bias1_enter});
  auto tanh = VectorRef({GetPrim(prim::kPrimActivation), bias1});

  // New hidden state: z * h + (p - z) * cand.
  auto sub = VectorRef({GetPrim(prim::kPrimSubFusion), std::make_shared<CondVar>(IsParameterNode), zt});
  auto mul2 = VectorRef({GetPrim(prim::kPrimMulFusion), sub, tanh});
  auto mul1 = VectorRef({GetPrim(prim::kPrimMulFusion), zt, switch3_true});
  auto add = VectorRef({GetPrim(prim::kPrimAddFusion), mul1, mul2});
  return add;
}
81 
/*
 * Builds the pattern for one direction of the TF control-flow (while-loop) RNN that wraps the GRU cell.
 *
 * input:      sequence tensor consumed by the loop.
 * vars:       GRU cell weight vars, forwarded to DefineGruCellPattern.
 * init_state: initial hidden state entering the loop.
 * Returns the Transpose node applied to the TensorArrayGatherV3 of the per-step outputs.
 *
 * Nodes that close a loop cycle cannot be expressed as a finite VectorRef tree, so they are
 * matched with SeqVar placeholders (noted inline below).
 */
const BaseRef TfBidirectionGruCfFusion::DefineBidirectionRnnPattern(const BaseRef &input,
                                                                    const std::vector<VarPtr> &vars,
                                                                    const VarPtr &init_state) const {
  // in order to match cyclic graph, some node in cycle is represented by SeqVar
  // strided_slice = StridedSlice(Shape(input), ...): presumably the number of time steps; it also sizes
  // both TensorArrays below.
  auto fw_shape1 = VectorRef({GetPrim(prim::kPrimShape), input});
  auto strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape1, std::make_shared<SeqVar>()});
  // Second loop bound: Minimum(strided_slice, Maximum(param, ReduceFusion(input_length_))).
  auto fw_max = VectorRef({GetPrim(prim::kPrimReduceFusion), input_length_, std::make_shared<Var>()});
  auto fw_maximum = VectorRef({GetPrim(prim::kPrimMaximum), std::make_shared<CondVar>(IsParameterNode), fw_max});
  auto fw_minimum = VectorRef({GetPrim(prim::kPrimMinimum), strided_slice, fw_maximum});
  auto fw_less1_enter = VectorRef({GetPrim(lite::kNameEnter), fw_minimum});
  // SeqVar:counter_merge1
  auto fw_less1 = VectorRef({GetPrim(prim::kPrimLess), std::make_shared<SeqVar>(), fw_less1_enter});

  // Loop counter: Merge(Enter(param), NextIteration(counter + param)) compared against strided_slice.
  // SeqVar:fw_merge,loop_cond
  auto fw_switch = VectorRef({GetPrim(prim::kPrimSwitch), std::make_shared<SeqVar>()});
  auto fw_switch_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), fw_switch, std::make_shared<Var>()});  // identity
  auto fw_add = VectorRef({GetPrim(prim::kPrimAddFusion), fw_switch_true, std::make_shared<CondVar>(IsParameterNode)});
  auto fw_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), fw_add});
  auto fw_merge_enter = VectorRef({GetPrim(lite::kNameEnter), std::make_shared<CondVar>(IsParameterNode)});
  auto fw_merge = VectorRef({GetPrim(prim::kPrimMerge), fw_merge_enter, fw_next_iter});
  auto fw_less_enter = VectorRef({GetPrim(lite::kNameEnter), strided_slice});
  auto fw_less = VectorRef({GetPrim(prim::kPrimLess), fw_merge, fw_less_enter});

  // The loop continues while both bounds hold; every loop Switch below is gated by loop_cond.
  auto fw_logical_and = VectorRef({GetPrim(prim::kPrimLogicalAnd), fw_less, fw_less1});
  // SeqVar:fw_logical_and
  auto loop_cond = VectorRef({GetPrim(lite::kNameLoopCond), fw_logical_and});

  // Range(param, StridedSlice(Shape(input)), param): indices used to scatter `input` into the input TensorArray.
  auto fw_shape = VectorRef({GetPrim(prim::kPrimShape), input});
  auto fw_unstack_strided_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), fw_shape, std::make_shared<SeqVar>()});
  auto fw_unstack_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<CondVar>(IsParameterNode),
                                     fw_unstack_strided_slice, std::make_shared<CondVar>(IsParameterNode)});

  // Step index carried by the loop; switch1_true is the index used to read inputs and write outputs.
  // SeqVar:switch1_true
  auto counter_add =
    VectorRef({GetPrim(prim::kPrimAddFusion), std::make_shared<SeqVar>(), std::make_shared<CondVar>(IsParameterNode)});
  auto counter_zero = VectorRef({GetPrim(lite::kNameEnter), std::make_shared<CondVar>(IsParameterNode)});
  auto counter_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), counter_add});
  auto counter_merge1 = VectorRef({GetPrim(prim::kPrimMerge), counter_zero, counter_next_iter});
  auto counter_switch1 = VectorRef({GetPrim(prim::kPrimSwitch), counter_merge1, loop_cond});
  auto switch1_true =
    VectorRef({GetPrim(prim::kPrimTupleGetItem), counter_switch1, std::make_shared<Var>()});  // identity1

  // Input TensorArray: filled by TensorArrayScatterV3 before the loop, read at switch1_true inside it.
  auto in_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
  auto in_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared<Var>()});
  auto in_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, std::make_shared<Var>()});
  auto fw_unstack_ta_scatter =
    VectorRef({GetPrim(lite::kNameTensorArrayScatterV3), in_ta_handle, fw_unstack_range, input, in_ta_flow});
  auto in_ta_enter1 = VectorRef({GetPrim(lite::kNameEnter), fw_unstack_ta_scatter});
  auto in_ta_enter = VectorRef({GetPrim(lite::kNameEnter), in_ta_handle});
  auto in_ta_read = VectorRef({GetPrim(lite::kNameTensorArrayReadV3), in_ta_enter, switch1_true, in_ta_enter1});

  // Sequence-length mask: steps with index >= input_length_ select the alternative branch of each Select.
  auto greater_equal_enter = VectorRef({GetPrim(lite::kNameEnter), input_length_});
  auto greater_equal = VectorRef({GetPrim(prim::kPrimGreaterEqual), switch1_true, greater_equal_enter});
  auto select1 = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, std::make_shared<SeqVar>()});  // select h

  // Hidden-state loop variable: Merge(Enter(init_state), NextIteration(select1)); switch3_true is h_prev.
  auto next_iteration3 = VectorRef({GetPrim(lite::kNameNextIteration), select1});
  auto enter3 = VectorRef({GetPrim(lite::kNameEnter), init_state});
  auto merge3 = VectorRef({GetPrim(prim::kPrimMerge), enter3, next_iteration3});
  auto switch3 = VectorRef({GetPrim(prim::kPrimSwitch), merge3, loop_cond});
  auto switch3_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch3, std::make_shared<Var>()});  // identity3

  // One GRU cell step on the current input and previous hidden state.
  auto rnn_cell_out = DefineGruCellPattern(in_ta_read, switch3_true, vars);

  // Output TensorArray, sized like the input one.
  auto out_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), strided_slice});
  auto out_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared<Var>()});
  auto out_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, std::make_shared<Var>()});
  auto out_ta_enter = VectorRef({GetPrim(lite::kNameEnter), out_ta_handle});

  // In-loop flow value of the output TensorArray; part of the cycle, hence a SeqVar.
  auto switch2_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), std::make_shared<SeqVar>()});  // cycle

  // Per-step output: Fill(param, Concat(...)) (named zeros1) where the mask is set, else the cell output.
  auto concat1 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<SeqVar>()});
  auto zeros1 = VectorRef({GetPrim(prim::kPrimFill), std::make_shared<CondVar>(IsParameterNode), concat1});
  auto select_enter = VectorRef({GetPrim(lite::kNameEnter), zeros1});
  auto select = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, select_enter, rnn_cell_out});  // select x
  auto ta_write = VectorRef({GetPrim(lite::kNameTensorArrayWriteV3), out_ta_enter, switch1_true, select, switch2_true});

  // Output TensorArray flow loop variable; it leaves the loop through switch2 -> Exit.
  auto enter2 = VectorRef({GetPrim(lite::kNameEnter), out_ta_flow});
  auto next_iter2 = VectorRef({GetPrim(lite::kNameNextIteration), ta_write});
  auto merge2 = VectorRef({GetPrim(prim::kPrimMerge), enter2, next_iter2});
  auto switch2 = VectorRef({GetPrim(prim::kPrimSwitch), merge2, loop_cond});
  auto switch2_false = VectorRef({GetPrim(prim::kPrimTupleGetItem), switch2, std::make_shared<Var>()});

  // After the loop: gather every written step and transpose the stacked result.
  auto exit2 = VectorRef({GetPrim(lite::kNameExit), switch2_false});
  auto ta_size = VectorRef({GetPrim(lite::kNameTensorArraySizeV3), out_ta_handle, exit2});
  auto range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<Var>(), ta_size, std::make_shared<Var>()});
  auto tensor_array_gather = VectorRef({GetPrim(lite::kNameTensorArrayGatherV3), out_ta_handle, range, exit2});
  auto range1 = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<SeqVar>()});
  auto concat2 = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<CondVar>(IsParameterNode), range1});
  auto fw_out_trans = VectorRef({GetPrim(prim::kPrimTranspose), tensor_array_gather, concat2});
  return fw_out_trans;
}
173 
DefinePattern() const174 const BaseRef TfBidirectionGruCfFusion::DefinePattern() const {
175   if (!Init()) {
176     MS_LOG(ERROR) << "initial member failed.";
177     return {};
178   }
179 
180   const auto fw_out_trans = DefineBidirectionRnnPattern(transpose_input_, fw_vars_, fw_init_state_);
181 
182   auto bw_reverse_in = VectorRef({GetPrim(prim::kPrimReverseSequence), input_, input_length_});
183   auto bw_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<SeqVar>()});
184   auto bw_concat = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<CondVar>(IsParameterNode), bw_range});
185   auto bw_transpose = VectorRef({GetPrim(prim::kPrimTranspose), bw_reverse_in, bw_concat});
186   auto bw_out_trans = DefineBidirectionRnnPattern(bw_transpose, bw_vars_, bw_init_state_);
187   auto bw_reverse_out = VectorRef({GetPrim(prim::kPrimReverseSequence), bw_out_trans, input_length_});
188   auto concat = VectorRef({GetPrim(prim::kPrimConcat), fw_out_trans, bw_reverse_out});
189   return concat;
190 }
191 
/*
 * Replaces a matched bidirectional GRU control-flow subgraph with a single fused GRU node.
 *
 * func_graph:  graph containing the match.
 * concat_node: root node of the match; its scoped name is used to name the fused node.
 * equiv:       var bindings captured while matching DefinePattern().
 * Returns the post-processed output node, or nullptr on failure. Null arguments additionally
 * record lite::RET_NULL_PTR in the global return code.
 */
const AnfNodePtr TfBidirectionGruCfFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
                                                   const EquivPtr &equiv) const {
  if (func_graph == nullptr || concat_node == nullptr || equiv == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }

  // Runtime check rather than a debug assertion: a null binding must not reach CreateBiDirectionGruNode.
  auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
  MS_CHECK_TRUE_MSG(transpose_input != nullptr, nullptr, "transpose_input is nullptr.");

  const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
  auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 0);
  MS_CHECK_TRUE_MSG(gru_node != nullptr, nullptr, "gru_node is nullptr.");

  // Register a 2-output tuple abstract on the fused node; only output 0 is consumed below.
  if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) {
    MS_LOG(ERROR) << "set abstract tuple for " << gru_name << " failed.";
    return nullptr;
  }

  auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0);
  MS_CHECK_TRUE_MSG(get_item_node != nullptr, nullptr, "get_item_node is nullptr.");

  auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope());
  MS_CHECK_TRUE_MSG(output_node != nullptr, nullptr, "output_node is nullptr.");
  MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success";
  return output_node;
}
218 }  // namespace opt
219 }  // namespace mindspore
220