• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
19 #include <memory>
20 #include <functional>
21 #include "mindspore/core/ops/structure_ops.h"
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/math_ops.h"
24 #include "mindspore/core/ops/lite_ops.h"
25 #include "mindspore/core/ops/comparison_ops.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "tools/optimizer/common/helper.h"
29 #include "ops/gru.h"
30 #include "ops/squeeze.h"
31 #include "ops/stack.h"
32 #include "ops/auto_generate/gen_lite_ops.h"
33 #include "src/common/utils.h"
34 #include "tools/common/tensor_util.h"
35 #include "include/common/utils/utils.h"
36 #include "nnacl/op_base.h"
37 #include "ops/op_utils.h"
38 
39 namespace mindspore {
40 namespace opt {
41 namespace {
42 constexpr int kOffsetTwo = 2;
43 constexpr int kReservedParamNodesNum = 13;
44 constexpr size_t kCondNodesNum = 12;
45 constexpr size_t kCondCNodesNum = 4;
46 constexpr size_t kBodyNodesNum = 69;
47 constexpr size_t kBodyCNodesNum = 25;
48 constexpr auto kGateNum = 2;
49 const auto &p1 = std::placeholders::_1;
GenerateBodyGraphHiddenPattern(const BaseRef & sigmoid1,const BaseRef & get_item,const std::vector<CondVarPtr> & placeholders)50 VectorRef GenerateBodyGraphHiddenPattern(const BaseRef &sigmoid1, const BaseRef &get_item,
51                                          const std::vector<CondVarPtr> &placeholders) {
52   MS_CHECK_TRUE_RET(placeholders.size() >= kCondCNodesNum, {});
53   auto is_var_split = std::make_shared<Var>("Split");
54   MS_CHECK_TRUE_RET(is_var_split != nullptr, {});
55   VectorRef split = VectorRef({is_var_split, sigmoid1});
56   auto is_var_tuple_getitem1 = std::make_shared<Var>("TupleGetItem");
57   MS_CHECK_TRUE_RET(is_var_tuple_getitem1 != nullptr, {});
58   auto is_var4 = std::make_shared<Var>();
59   MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
60   VectorRef get_item1 = VectorRef({is_var_tuple_getitem1, split, is_var4});
61   auto is_var_tuple_getitem2 = std::make_shared<Var>("TupleGetItem");
62   MS_CHECK_TRUE_RET(is_var_tuple_getitem2 != nullptr, {});
63   auto is_var5 = std::make_shared<Var>();
64   MS_CHECK_TRUE_RET(is_var5 != nullptr, {});
65   VectorRef get_item2 = VectorRef({is_var_tuple_getitem2, split, is_var5});
66 
67   auto is_var_mul1 = std::make_shared<Var>("Mul");
68   MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
69   VectorRef pre_reset = VectorRef({is_var_mul1, get_item1, placeholders[4]});
70   auto is_var_concat = std::make_shared<Var>("Concat");
71   MS_CHECK_TRUE_RET(is_var_concat != nullptr, {});
72   VectorRef concat2 = VectorRef({is_var_concat, get_item, pre_reset});
73   auto is_var_matmul2 = std::make_shared<Var>("Matmul");
74   MS_CHECK_TRUE_RET(is_var_matmul2 != nullptr, {});
75   VectorRef matmul2 = VectorRef({is_var_matmul2, concat2, placeholders[10]});
76   auto is_var_biasadd2 = std::make_shared<Var>("BiasAdd");
77   MS_CHECK_TRUE_RET(is_var_biasadd2 != nullptr, {});
78   VectorRef biasadd2 = VectorRef({is_var_biasadd2, matmul2, placeholders[11]});
79   auto is_var_tanh = std::make_shared<Var>("Tanh");
80   MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
81   VectorRef tanh = VectorRef({is_var_tanh, biasadd2});
82 
83   auto is_var_mul2 = std::make_shared<Var>("Mul");
84   MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
85   VectorRef update_hidden = VectorRef({is_var_mul2, get_item2, placeholders[4]});
86   auto is_var_sub = std::make_shared<Var>("Sub");
87   MS_CHECK_TRUE_RET(is_var_sub != nullptr, {});
88   auto is_param = std::make_shared<CondVar>(IsParameterNode);
89   MS_CHECK_TRUE_RET(is_param != nullptr, {});
90   VectorRef minus_update = VectorRef({is_var_sub, is_param, get_item2});
91   auto is_var_mul3 = std::make_shared<Var>("Mul");
92   MS_CHECK_TRUE_RET(is_var_mul3 != nullptr, {});
93   VectorRef updated = VectorRef({is_var_mul3, minus_update, tanh});
94 
95   auto is_var_add = std::make_shared<Var>("Add");
96   MS_CHECK_TRUE_RET(is_var_add != nullptr, {});
97   VectorRef new_hidden = VectorRef({is_var_add, update_hidden, updated});
98 
99   return new_hidden;
100 }
101 }  // namespace
102 
Init() const103 bool TfBidirectionGruFusion::Init() const {
104   /*
105    * vars for while input
106    * fw_while_inputs:
107    * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
108    * bw_while_inputs:
109    * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
110    */
111   for (int i = 0; i < num_fw_vars_; ++i) {
112     auto is_var = std::make_shared<Var>();
113     MS_CHECK_TRUE_RET(is_var != nullptr, false);
114     fw_vars_.emplace_back(is_var);
115   }
116   for (int i = 0; i < num_bw_vars_; ++i) {
117     auto is_var = std::make_shared<Var>();
118     MS_CHECK_TRUE_RET(is_var != nullptr, false);
119     bw_vars_.emplace_back(is_var);
120   }
121   input_ = std::make_shared<Var>();
122   MS_CHECK_TRUE_RET(input_ != nullptr, false);
123   input_length_ = std::make_shared<Var>();
124   MS_CHECK_TRUE_RET(input_length_ != nullptr, false);
125   transpose_input_ = std::make_shared<Var>();
126   MS_CHECK_TRUE_RET(transpose_input_ != nullptr, false);
127   fw_init_state_ = std::make_shared<Var>();
128   MS_CHECK_TRUE_RET(fw_init_state_ != nullptr, false);
129   bw_init_state_ = std::make_shared<Var>();
130   MS_CHECK_TRUE_RET(bw_init_state_ != nullptr, false);
131   return true;
132 }
133 
const VectorRef TfBidirectionGruFusion::DefineFowardPattern() const {
  // Matches the forward half of a TF bidirectional GRU: a While loop fed by
  // TensorListReserve / TensorListFromTensor, whose stacked output is transposed
  // back to the caller's layout.
  //
  // max(param, reduce(input_length)) bounds the loop trip count.
  auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion));
  MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
  auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param1 != nullptr, {});
  auto fw_reduce = VectorRef({is_reduce, input_length_, is_param1});
  auto is_maximum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum));
  MS_CHECK_TRUE_RET(is_maximum != nullptr, {});
  auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param2 != nullptr, {});
  auto fw_max = VectorRef({is_maximum, is_param2, fw_reduce});

  // Time-step count: StridedSlice over Shape(transposed input), then
  // min(step_count, fw_max).
  auto is_shape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape));
  MS_CHECK_TRUE_RET(is_shape != nullptr, {});
  auto fw_shape = VectorRef({is_shape, transpose_input_});
  auto is_strided_slice = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice));
  MS_CHECK_TRUE_RET(is_strided_slice != nullptr, {});
  auto is_seq_var = std::make_shared<SeqVar>();
  MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
  auto fw_stride = VectorRef({is_strided_slice, fw_shape, is_seq_var});
  auto is_minimum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum));
  MS_CHECK_TRUE_RET(is_minimum != nullptr, {});
  auto fw_min = VectorRef({is_minimum, fw_stride, fw_max});
  // TensorList plumbing that feeds the while loop: an output list reserved to
  // the step count and an input list built from the transposed input.
  auto is_tensor_list_reserve = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve));
  MS_CHECK_TRUE_RET(is_tensor_list_reserve != nullptr, {});
  auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
  auto fw_reserve = VectorRef({is_tensor_list_reserve, is_param3, fw_stride});
  auto is_tensor_list_from_tensor = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListFromTensor));
  MS_CHECK_TRUE_RET(is_tensor_list_from_tensor != nullptr, {});
  auto is_param4 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param4 != nullptr, {});
  auto fw_from_tensor = VectorRef({is_tensor_list_from_tensor, transpose_input_, is_param4});
  // The While node: cond/body graphs are fw_vars_[0]/[1]; the remaining
  // fw_vars_ (weights and biases, see Init's comment) are appended after the
  // fixed inputs, skipping the first two (cond/body) via kOffsetTwo.
  auto is_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimWhile));
  MS_CHECK_TRUE_RET(is_while != nullptr, {});
  auto is_param5 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param5 != nullptr, {});
  auto is_param6 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
  auto fw_while = VectorRef({is_while, fw_vars_[0], fw_vars_[1], is_param5, fw_stride, is_param6, fw_reserve,
                             fw_init_state_, fw_min, fw_from_tensor, input_length_});
  fw_while.insert(fw_while.end(), fw_vars_.begin() + kOffsetTwo, fw_vars_.end());
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
  fw_while.emplace_back(is_var1);
  // Read the output TensorList from the loop result, stack it into a tensor and
  // transpose it back.
  auto is_tuple_getitem = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem));
  MS_CHECK_TRUE_RET(is_tuple_getitem != nullptr, {});
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
  auto fw_get_item = VectorRef({is_tuple_getitem, fw_while, is_var2});
  auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListStack));
  MS_CHECK_TRUE_RET(is_tensor_list_stack != nullptr, {});
  auto is_param7 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param7 != nullptr, {});
  auto fw_stack = VectorRef({is_tensor_list_stack, fw_get_item, is_param7});
  auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
  MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
  auto is_var3 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
  auto fw_out_trans = VectorRef({is_transpose, fw_stack, is_var3});
  return fw_out_trans;
}
196 
const VectorRef TfBidirectionGruFusion::DefinebackwardPattern() const {
  // Matches the backward half of a TF bidirectional GRU.  Structure mirrors
  // DefineFowardPattern, except the input is first reversed along the sequence
  // (ReverseSequence) and the loop output is reversed again at the end.
  auto is_reverse_sequence = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReverseSequence));
  MS_CHECK_TRUE_RET(is_reverse_sequence != nullptr, {});
  auto bw_reverse_seq = VectorRef({is_reverse_sequence, input_, input_length_});
  // Loop bound: max(param, reduce(input_length)).
  auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion));
  MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
  auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param1 != nullptr, {});
  auto bw_max1 = VectorRef({is_reduce, input_length_, is_param1});
  auto is_maximum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum));
  MS_CHECK_TRUE_RET(is_maximum != nullptr, {});
  auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param2 != nullptr, {});
  auto bw_max2 = VectorRef({is_maximum, is_param2, bw_max1});
  // Transpose the reversed input, then derive the step count from its shape.
  auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
  MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
  auto bw_trans = VectorRef({is_transpose, bw_reverse_seq, is_var1});
  auto is_shape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape));
  MS_CHECK_TRUE_RET(is_shape != nullptr, {});
  auto bw_shape = VectorRef({is_shape, bw_trans});
  auto is_strided_slice = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice));
  MS_CHECK_TRUE_RET(is_strided_slice != nullptr, {});
  auto is_seq_var = std::make_shared<SeqVar>();
  MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
  auto bw_stride = VectorRef({is_strided_slice, bw_shape, is_seq_var});
  auto is_minimum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum));
  MS_CHECK_TRUE_RET(is_minimum != nullptr, {});
  auto bw_min = VectorRef({is_minimum, bw_stride, bw_max2});
  // TensorList plumbing feeding the backward while loop.
  auto is_tensor_list_reserve = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve));
  MS_CHECK_TRUE_RET(is_tensor_list_reserve != nullptr, {});
  auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
  auto bw_reserve = VectorRef({is_tensor_list_reserve, is_param3, bw_stride});
  auto is_tensor_list_from_tensor = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListFromTensor));
  MS_CHECK_TRUE_RET(is_tensor_list_from_tensor != nullptr, {});
  auto is_param4 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param4 != nullptr, {});
  auto bw_from_tensor = VectorRef({is_tensor_list_from_tensor, bw_trans, is_param4});
  // While node; the remaining bw_vars_ (weights/biases) are appended after the
  // fixed inputs, skipping cond/body via kOffsetTwo.
  auto is_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimWhile));
  MS_CHECK_TRUE_RET(is_while != nullptr, {});
  auto is_param5 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param5 != nullptr, {});
  auto is_param6 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
  auto bw_while = VectorRef({is_while, bw_vars_[0], bw_vars_[1], is_param5, bw_stride, is_param6, bw_reserve,
                             bw_init_state_, bw_min, bw_from_tensor, input_length_});
  bw_while.insert(bw_while.end(), bw_vars_.begin() + kOffsetTwo, bw_vars_.end());
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
  bw_while.emplace_back(is_var2);
  auto is_tuple_getitem = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem));
  MS_CHECK_TRUE_RET(is_tuple_getitem != nullptr, {});
  auto is_var3 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
  auto bw_get_item = VectorRef({is_tuple_getitem, bw_while, is_var3});
  auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListStack));
  MS_CHECK_TRUE_RET(is_tensor_list_stack != nullptr, {});
  auto is_param7 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param7 != nullptr, {});
  auto bw_stack = VectorRef({is_tensor_list_stack, bw_get_item, is_param7});
  auto is_var4 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
  // NOTE(review): `is_transpose` and `is_reverse_sequence` CondVars are reused
  // below for second occurrences in the pattern, binding the same pattern var to
  // two distinct graph nodes.  Presumably the matcher's equivalence rules allow
  // this for primitive-position vars -- confirm against the pattern engine
  // before relying on it.
  auto bw_out_trans = VectorRef({is_transpose, bw_stack, is_var4});
  auto bw_reverse1 = VectorRef({is_reverse_sequence, bw_out_trans, input_length_});
  return bw_reverse1;
}
265 
DefinePattern() const266 const BaseRef TfBidirectionGruFusion::DefinePattern() const {
267   if (!Init()) {
268     MS_LOG(ERROR) << "initial member failed.";
269     return {};
270   }
271 
272   // forward
273   auto fw_out_trans = DefineFowardPattern();
274   MS_CHECK_TRUE_RET(!fw_out_trans.empty(), {});
275 
276   // backward
277   auto bw_reverse1 = DefinebackwardPattern();
278   MS_CHECK_TRUE_RET(!bw_reverse1.empty(), {});
279 
280   auto is_concat = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimConcat));
281   MS_CHECK_TRUE_RET(is_concat != nullptr, {});
282   auto concat = VectorRef({is_concat, fw_out_trans, bw_reverse1});
283   return concat;
284 }
285 
GetCondGraphPattern(const PrimitiveVarMapPtr & primitive_vars) const286 AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
287   MS_ASSERT(primitive_vars != nullptr);
288   auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
289   MS_CHECK_TRUE_RET(is_less1 != nullptr, nullptr);
290   auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
291   MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
292   auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
293   MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
294   VectorRef less1_ref = VectorRef({is_less1, is_param1, is_param2});
295   auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
296   MS_CHECK_TRUE_RET(is_less2 != nullptr, nullptr);
297   auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
298   MS_CHECK_TRUE_RET(is_param3 != nullptr, nullptr);
299   auto is_param4 = std::make_shared<CondVar>(IsParameterNode);
300   MS_CHECK_TRUE_RET(is_param4 != nullptr, nullptr);
301   VectorRef less2_ref = VectorRef({is_less2, is_param3, is_param4});
302   auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
303   MS_CHECK_TRUE_RET(is_logical_and != nullptr, nullptr);
304   VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
305   auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
306   MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
307   VectorRef return_ref = VectorRef({is_return, logicaland_ref});
308   VarPtr is_fg = std::make_shared<Var>("RootG");
309   MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
310   auto pattern = Helper::SexpToNode(return_ref, is_fg, primitive_vars.get(), true);
311   return pattern;
312 }
313 
AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
  MS_ASSERT(primitive_vars != nullptr);
  // Matches the TF while-body graph of one GRU step.  The body takes
  // kReservedParamNodesNum (13) parameters; from their uses below,
  // placeholders[0]/[2] look like loop counters, [3] the output TensorList,
  // [4] the previous hidden state, [6] the input TensorList, [7] a
  // length/limit value, and [8]-[11] gate/candidate weights and biases --
  // TODO confirm against the TF exporter.
  std::vector<CondVarPtr> placeholders;
  for (int i = 0; i < kReservedParamNodesNum; ++i) {
    auto is_param_placeholder = std::make_shared<CondVar>(IsParameterNode);
    MS_CHECK_TRUE_RET(is_param_placeholder != nullptr, nullptr);
    placeholders.emplace_back(is_param_placeholder);
  }
  // Loop-counter increments: each counter is advanced by a parameter constant.
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, nullptr);
  auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
  VectorRef add = VectorRef({is_var1, placeholders[2], is_param1});
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, nullptr);
  auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
  VectorRef add1 = VectorRef({is_var2, placeholders[0], is_param2});

  // Read the current step's input from the input TensorList and concatenate it
  // with the previous hidden state.
  auto is_getitem = std::make_shared<Var>("GetItem");
  MS_CHECK_TRUE_RET(is_getitem != nullptr, nullptr);
  auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param3 != nullptr, nullptr);
  VectorRef get_item = VectorRef({is_getitem, placeholders[6], placeholders[2], is_param3});
  auto is_var3 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var3 != nullptr, nullptr);
  VectorRef concat_input_h = VectorRef({is_var3, get_item, placeholders[4]});

  // Fused gate computation: Sigmoid(BiasAdd(MatMul(concat, W_gate), b_gate)).
  auto is_var_matmul1 = std::make_shared<Var>("Matmul");
  MS_CHECK_TRUE_RET(is_var_matmul1 != nullptr, nullptr);
  VectorRef matmul1 = VectorRef({is_var_matmul1, concat_input_h, placeholders[8]});
  auto is_var_biasadd1 = std::make_shared<Var>("BiasAdd");
  MS_CHECK_TRUE_RET(is_var_biasadd1 != nullptr, nullptr);
  VectorRef biasadd1 = VectorRef({is_var_biasadd1, matmul1, placeholders[9]});
  auto is_var_sigmoid = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid != nullptr, nullptr);
  VectorRef sigmoid1 = VectorRef({is_var_sigmoid, biasadd1});

  // Candidate/new-hidden-state sub-pattern (split gates, tanh candidate, blend).
  auto new_hidden = GenerateBodyGraphHiddenPattern(sigmoid1, get_item, placeholders);
  MS_CHECK_TRUE_RET(!new_hidden.empty(), nullptr);

  // Switch nodes keyed on GreaterEqual(counter, placeholders[7]) select between
  // the fresh hidden state and a passthrough value -- presumably masking steps
  // beyond the valid sequence length; confirm the Switch branch order.
  auto is_var_ge = std::make_shared<Var>("GreaterEqual");
  MS_CHECK_TRUE_RET(is_var_ge != nullptr, nullptr);
  VectorRef greater_equal = VectorRef({is_var_ge, placeholders[2], placeholders[7]});

  auto is_var_switch1 = std::make_shared<Var>("Switch");
  MS_CHECK_TRUE_RET(is_var_switch1 != nullptr, {});
  VectorRef select_output = VectorRef({is_var_switch1, greater_equal, placeholders[12], new_hidden});
  auto is_var_setitem = std::make_shared<Var>("SetItem");
  MS_CHECK_TRUE_RET(is_var_setitem != nullptr, {});
  VectorRef output = VectorRef({is_var_setitem, placeholders[3], placeholders[2], select_output});

  auto is_var_switch2 = std::make_shared<Var>("Switch");
  MS_CHECK_TRUE_RET(is_var_switch2 != nullptr, {});
  VectorRef select_hidden = VectorRef({is_var_switch2, greater_equal, placeholders[4], new_hidden});

  // The body returns a MakeTuple mirroring the while-loop state: counters,
  // output list, hidden state, then the remaining untouched parameters.
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
  MS_CHECK_TRUE_RET(is_make_tuple != nullptr, nullptr);
  std::vector<BaseRef> outputs = {is_make_tuple,  add1,          placeholders[1], add,
                                  output,         select_hidden, placeholders[5], placeholders[6],
                                  placeholders[7]};
  outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end());
  VectorRef make_tuple_node = VectorRef(outputs);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
  VectorRef return_node = VectorRef({is_return, make_tuple_node});

  VarPtr is_fg = std::make_shared<Var>("RootG");
  MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
  // Lower the s-expression pattern to AnfNode form for the matcher.
  auto pattern = Helper::SexpToNode(return_node, is_fg, primitive_vars.get(), true);
  return pattern;
}
386 
GetDefaultTensorInfo(const AnfNodePtr & parameter_anf)387 tensor::TensorPtr TfBidirectionGruFusion::GetDefaultTensorInfo(const AnfNodePtr &parameter_anf) {
388   MS_ASSERT(parameter_anf != nullptr);
389   if (!utils::isa<ParameterPtr>(parameter_anf)) {
390     MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr";
391     return nullptr;
392   }
393   auto parameter = utils::cast<ParameterPtr>(parameter_anf);
394   if (!parameter->has_default()) {
395     MS_LOG(DEBUG) << "parameter not have default value";
396     return nullptr;
397   }
398   auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param());
399   return tensor_info;
400 }
401 
GetInputAndHiddenSize(const AnfNodePtr & fw_cand_kernel_anf,const AnfNodePtr & bw_cand_kernel_anf,int * input_size,int * hidden_size)402 STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf,
403                                                      const AnfNodePtr &bw_cand_kernel_anf, int *input_size,
404                                                      int *hidden_size) {
405   MS_ASSERT(fw_cand_kernel_anf != nullptr);
406   MS_ASSERT(bw_cand_kernel_anf != nullptr);
407   MS_ASSERT(input_size != nullptr);
408   MS_ASSERT(hidden_size != nullptr);
409   auto fw_cand_kernel_value = GetDefaultTensorInfo(fw_cand_kernel_anf);
410   if (fw_cand_kernel_value == nullptr) {
411     return RET_ERROR;
412   }
413   auto fw_cand_kernel_shape = fw_cand_kernel_value->shape();
414   if (fw_cand_kernel_shape.size() != kInputSizeTwo) {
415     return RET_ERROR;
416   }
417   auto bw_cand_kernel_value = GetDefaultTensorInfo(bw_cand_kernel_anf);
418   if (bw_cand_kernel_value == nullptr) {
419     return RET_ERROR;
420   }
421   auto bw_cand_kernel_shape = bw_cand_kernel_value->shape();
422   if (bw_cand_kernel_shape.size() != kInputSizeTwo) {
423     return RET_ERROR;
424   }
425   if (fw_cand_kernel_shape != bw_cand_kernel_shape) {
426     return RET_ERROR;
427   }
428   if (fw_cand_kernel_shape[1] <= 0 || fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1] <= 0) {
429     MS_LOG(DEBUG) << "gru input size or hidden size illegal";
430     return RET_ERROR;
431   }
432   *hidden_size = fw_cand_kernel_shape[1];
433   *input_size = fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1];
434   return RET_OK;
435 }
436 
AddDefaultParameter(const FuncGraphPtr & func_graph,const std::string & name,const std::vector<int> & shape,const TypeId type,void ** tensor_data)437 ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
438                                                          const std::vector<int> &shape, const TypeId type,
439                                                          void **tensor_data) {
440   MS_ASSERT(func_graph != nullptr);
441   MS_ASSERT(tensor_data != nullptr);
442   auto parameter = func_graph->add_parameter();
443   MS_CHECK_TRUE_RET(parameter != nullptr, nullptr);
444   parameter->set_name(name);
445   std::vector<int64_t> shape_vector(shape.begin(), shape.end());
446   auto abstract = lite::CreateTensorAbstract(shape_vector, type);
447   if (abstract == nullptr) {
448     MS_LOG(ERROR) << "Create tensor abstarct failed";
449     return nullptr;
450   }
451   parameter->set_abstract(abstract);
452 
453   auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector);
454   if (gate_weight_default == nullptr) {
455     MS_LOG(ERROR) << "gate_weight_default is nullptr";
456     return nullptr;
457   }
458 
459   *tensor_data = gate_weight_default->data_c();
460   parameter->set_default_param(gate_weight_default);
461   return parameter;
462 }
463 
CopyFlattenMatData(const float * mat,const int C,const int r0,const int r1,const int c0,const int c1,float * data,bool t)464 void TfBidirectionGruFusion::CopyFlattenMatData(const float *mat, const int C, const int r0, const int r1, const int c0,
465                                                 const int c1, float *data, bool t) {
466   MS_ASSERT(mat != nullptr);
467   MS_ASSERT(data != nullptr);
468   MS_ASSERT(r0 >= 0 && r0 < r1);
469   MS_ASSERT(c0 >= 0 && c0 < c1 && c1 <= C);
470   const int RT = r1 - r0;
471   const int CT = c1 - c0;
472   for (int i = r0; i < r1; ++i) {
473     for (int j = c0; j < c1; ++j) {
474       if (t) {
475         data[(j - c0) * RT + (i - r0)] = mat[i * C + j];
476       } else {
477         data[(i - r0) * CT + (j - c0)] = mat[i * C + j];
478       }
479     }
480   }
481 }
482 
STATUS TfBidirectionGruFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight,
                                                 const int input_size, const int hidden_size, float *gate_tensor_data,
                                                 float *recu_tensor_data) {
  MS_ASSERT(gate_weight != nullptr);
  MS_ASSERT(cand_weight != nullptr);
  MS_ASSERT(gate_tensor_data != nullptr);
  MS_ASSERT(recu_tensor_data != nullptr);
  // TF stores one fused gate kernel [input+hidden, 2*hidden]; judging from the
  // column ranges copied below, columns [0, hidden) are the reset gate and
  // [hidden, 2*hidden) the update gate.
  // NOTE(review): cand_shape is [2*hidden, hidden], but the candidate-kernel
  // copies below read rows [0, input+hidden) -- the shape check therefore only
  // admits input_size == hidden_size.  Confirm whether
  // {input_size + hidden_size, hidden_size} was intended.
  const std::vector<int64_t> gate_shape{input_size + hidden_size, hidden_size * kGateNum};
  const std::vector<int64_t> cand_shape{hidden_size * kGateNum, hidden_size};
  auto gate_weight_value = GetDefaultTensorInfo(gate_weight);
  if (gate_weight_value == nullptr) {
    return RET_ERROR;
  }
  auto gate_weight_data = reinterpret_cast<float *>(gate_weight_value->data_c());
  if (gate_weight_data == nullptr) {
    return RET_ERROR;
  }
  auto gate_weight_shape = gate_weight_value->shape();

  auto cand_weight_value = GetDefaultTensorInfo(cand_weight);
  if (cand_weight_value == nullptr) {
    return RET_ERROR;
  }
  auto cand_weight_data = reinterpret_cast<float *>(cand_weight_value->data_c());
  if (cand_weight_data == nullptr) {
    return RET_ERROR;
  }
  auto cand_weight_shape = cand_weight_value->shape();

  if (gate_weight_shape != gate_shape || cand_weight_shape != cand_shape) {
    return RET_ERROR;
  }

  // Repack into the fused GRU op layout: input-side weights are the first
  // input_size rows, recurrent weights the remaining hidden_size rows, each in
  // [update, reset, hidden] order and transposed (t == true).
  // input_update_weight
  CopyFlattenMatData(gate_weight_data, hidden_size * kGateNum, 0, input_size, hidden_size, hidden_size * kGateNum,
                     gate_tensor_data, true);
  // input_reset_weight
  CopyFlattenMatData(gate_weight_data, hidden_size * kGateNum, 0, input_size, 0, hidden_size,
                     gate_tensor_data + input_size * hidden_size, true);
  // input_hidden_weight
  CopyFlattenMatData(cand_weight_data, hidden_size, 0, input_size, 0, hidden_size,
                     gate_tensor_data + input_size * hidden_size * kGateNum, true);

  // state_update_weight
  CopyFlattenMatData(gate_weight_data, hidden_size * kGateNum, input_size, input_size + hidden_size, hidden_size,
                     hidden_size * kGateNum, recu_tensor_data, true);
  // state_reset_weight
  CopyFlattenMatData(gate_weight_data, hidden_size * kGateNum, input_size, input_size + hidden_size, 0, hidden_size,
                     recu_tensor_data + hidden_size * hidden_size, true);
  // state_hidden_weight
  CopyFlattenMatData(cand_weight_data, hidden_size, input_size, input_size + hidden_size, 0, hidden_size,
                     recu_tensor_data + hidden_size * hidden_size * kGateNum, true);
  return RET_OK;
}
537 
ConvertBiasData(const AnfNodePtr & gate_bias,const AnfNodePtr & cand_bias,const int hidden_size,float * tensor_data)538 STATUS TfBidirectionGruFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias,
539                                                const int hidden_size, float *tensor_data) {
540   MS_ASSERT(gate_bias != nullptr && cand_bias != nullptr);
541   MS_ASSERT(tensor_data != nullptr);
542   std::vector<int64_t> gate_shape{hidden_size * kGateNum};
543   std::vector<int64_t> cand_shape{hidden_size};
544   auto gate_bias_value = GetDefaultTensorInfo(gate_bias);
545   if (gate_bias_value == nullptr) {
546     return RET_ERROR;
547   }
548   auto gate_bias_data = reinterpret_cast<float *>(gate_bias_value->data_c());
549   auto gate_bias_shape = gate_bias_value->shape();
550   auto cand_bias_value = GetDefaultTensorInfo(cand_bias);
551   if (cand_bias_value == nullptr) {
552     return RET_ERROR;
553   }
554   auto cand_bias_data = reinterpret_cast<float *>(cand_bias_value->data_c());
555   auto cand_bias_shape = cand_bias_value->shape();
556   if (gate_bias_shape != gate_shape || cand_bias_shape != cand_shape) {
557     return RET_ERROR;
558   }
559 
560   // update_gate bias
561   CopyFlattenMatData(gate_bias_data, hidden_size * kGateNum, 0, 1, hidden_size, hidden_size * kGateNum, tensor_data,
562                      false);
563   // reset_gate bias
564   CopyFlattenMatData(gate_bias_data, hidden_size * kGateNum, 0, 1, 0, hidden_size, tensor_data + hidden_size, false);
565   // hidden_gate bias
566   CopyFlattenMatData(cand_bias_data, hidden_size, 0, 1, 0, hidden_size, tensor_data + hidden_size * kGateNum, false);
567 
568   return RET_OK;
569 }
570 
GetStackedHiddenState(const FuncGraphPtr & func_graph,const AnfNodePtr & fw_init_state,const AnfNodePtr & bw_init_state,const std::string & base_name)571 CNodePtr TfBidirectionGruFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
572                                                        const AnfNodePtr &bw_init_state, const std::string &base_name) {
573   MS_ASSERT(func_graph != nullptr);
574   MS_ASSERT(fw_init_state != nullptr);
575   MS_ASSERT(bw_init_state != nullptr);
576   auto stack_prim = std::make_shared<ops::Stack>();
577   MS_CHECK_TRUE_RET(stack_prim != nullptr, nullptr);
578   auto stack_prim_c = stack_prim->GetPrim();
579   MS_CHECK_TRUE_RET(stack_prim_c != nullptr, nullptr);
580   stack_prim->set_axis(0);
581   auto value_node = NewValueNode(stack_prim_c);
582   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
583   std::vector<AnfNodePtr> new_node_inputs = {value_node, fw_init_state, bw_init_state};
584   auto new_node = func_graph->NewCNode(new_node_inputs);
585   MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
586   if (fw_init_state->abstract() != nullptr) {
587     new_node->set_abstract(fw_init_state->abstract()->Clone());
588   }
589   new_node->set_fullname_with_scope("stack_hidden_" + base_name);
590   return new_node;
591 }
592 
CreateBiDirectionGruNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const EquivPtr & equiv,const std::string & base_name,int var_offset) const593 CNodePtr TfBidirectionGruFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
594                                                           const EquivPtr &equiv, const std::string &base_name,
595                                                           int var_offset) const {
596   MS_ASSERT(func_graph != nullptr);
597   MS_ASSERT(input != nullptr);
598   MS_ASSERT(equiv != nullptr);
599   auto gru_prim = std::make_shared<ops::GRU>();
600   MS_CHECK_TRUE_RET(gru_prim != nullptr, nullptr);
601   auto gru_prim_c = gru_prim->GetPrim();
602   MS_CHECK_TRUE_RET(gru_prim_c != nullptr, nullptr);
603   gru_prim->set_bidirectional(true);
604   auto value_node = NewValueNode(gru_prim_c);
605   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
606 
607   auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset]]);
608   MS_ASSERT(fw_gate_kernel != nullptr);
609   auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 1]]);
610   MS_ASSERT(fw_gate_bias != nullptr);
611   auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 2]]);
612   MS_ASSERT(fw_cand_kernel != nullptr);
613   auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 3]]);
614   MS_ASSERT(fw_cand_bias != nullptr);
615 
616   auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset]]);
617   MS_ASSERT(bw_gate_kernel != nullptr);
618   auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 1]]);
619   MS_ASSERT(bw_gate_bias != nullptr);
620   auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 2]]);
621   MS_ASSERT(bw_cand_kernel != nullptr);
622   auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 3]]);
623   MS_ASSERT(bw_cand_bias != nullptr);
624 
625   auto fw_init_state = utils::cast<AnfNodePtr>((*equiv)[fw_init_state_]);
626   MS_ASSERT(fw_init_state != nullptr);
627   auto bw_init_state = utils::cast<AnfNodePtr>((*equiv)[bw_init_state_]);
628   MS_ASSERT(bw_init_state != nullptr);
629   auto stacked_hidden = GetStackedHiddenState(func_graph, fw_init_state, bw_init_state, base_name);
630   if (stacked_hidden == nullptr) {
631     return nullptr;
632   }
633   auto input_length = utils::cast<AnfNodePtr>((*equiv)[input_length_]);
634   MS_ASSERT(input_length != nullptr);
635 
636   int input_size = 0;
637   int hidden_size = 0;
638   auto status = GetInputAndHiddenSize(fw_cand_kernel, bw_cand_kernel, &input_size, &hidden_size);
639   if (status != RET_OK) {
640     return nullptr;
641   }
642   std::vector<int> gate_weight_shape{2, hidden_size * 3, input_size};
643   float *gate_tensor_data = nullptr;
644   auto gate_weight = AddDefaultParameter(func_graph, base_name + "_gate_weight", gate_weight_shape, kNumberTypeFloat32,
645                                          reinterpret_cast<void **>(&gate_tensor_data));
646   if (gate_weight == nullptr) {
647     return nullptr;
648   }
649   std::vector<int> recu_weight_shape{2, hidden_size * 3, hidden_size};
650   float *recu_tensor_data = nullptr;
651   auto recu_weight = AddDefaultParameter(func_graph, base_name + "_cand_weight", recu_weight_shape, kNumberTypeFloat32,
652                                          reinterpret_cast<void **>(&recu_tensor_data));
653   if (recu_weight == nullptr || recu_tensor_data == nullptr) {
654     return nullptr;
655   }
656   std::vector<int> bias_shape{2, hidden_size * 6};
657   float *bias_tensor_data = nullptr;
658   auto bias = AddDefaultParameter(func_graph, base_name + "_bias", bias_shape, kNumberTypeFloat32,
659                                   reinterpret_cast<void **>(&bias_tensor_data));
660   if (bias == nullptr || bias_tensor_data == nullptr) {
661     return nullptr;
662   }
663   for (int i = 0; i < 2 * hidden_size * 6; ++i) {
664     bias_tensor_data[i] = 0.0f;
665   }
666 
667   if (ConvertWeightData(fw_gate_kernel, fw_cand_kernel, input_size, hidden_size, gate_tensor_data, recu_tensor_data) !=
668       RET_OK) {
669     return nullptr;
670   }
671   auto gate_data_diff = hidden_size * input_size * 3;
672   auto recu_data_diff = hidden_size * hidden_size * 3;
673   if (ConvertWeightData(bw_gate_kernel, bw_cand_kernel, input_size, hidden_size, gate_tensor_data + gate_data_diff,
674                         recu_tensor_data + recu_data_diff) != RET_OK) {
675     return nullptr;
676   }
677 
678   if (ConvertBiasData(fw_gate_bias, fw_cand_bias, hidden_size, bias_tensor_data) != RET_OK) {
679     return nullptr;
680   }
681   auto bias_data_diff = hidden_size * 6;
682   if (ConvertBiasData(bw_gate_bias, bw_cand_bias, hidden_size, bias_tensor_data + bias_data_diff) != RET_OK) {
683     return nullptr;
684   }
685   std::vector<AnfNodePtr> new_node_inputs = {value_node, input,          gate_weight, recu_weight,
686                                              bias,       stacked_hidden, input_length};
687   auto new_node = func_graph->NewCNode(new_node_inputs);
688   MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
689   auto prim = GetValueNode<PrimitivePtr>(new_node->input(0));
690   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
691   prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format::NHWC));
692   new_node->set_fullname_with_scope(base_name);
693   return new_node;
694 }
695 
GetPostProcessNode(const FuncGraphPtr & func_graph,const CNodePtr & gru_output,const std::string & base_name)696 CNodePtr TfBidirectionGruFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
697                                                     const std::string &base_name) {
698   MS_ASSERT(func_graph != nullptr);
699   MS_ASSERT(gru_output != nullptr);
700   auto split_prim = std::make_shared<ops::Split>();
701   MS_CHECK_TRUE_RET(split_prim != nullptr, nullptr);
702   auto split_prim_c = split_prim->GetPrim();
703   MS_CHECK_TRUE_RET(split_prim_c != nullptr, nullptr);
704   split_prim->set_output_num(2);
705   split_prim->set_axis(1);
706   auto split_value_node = NewValueNode(split_prim_c);
707   MS_CHECK_TRUE_RET(split_value_node != nullptr, nullptr);
708   std::vector<AnfNodePtr> new_node_inputs = {split_value_node, gru_output};
709   auto split_new_node = func_graph->NewCNode(new_node_inputs);
710   MS_CHECK_TRUE_RET(split_new_node != nullptr, nullptr);
711   split_new_node->set_fullname_with_scope("split_" + base_name);
712   if (TfliteLstmCellFusion::SetAbstractTuple(split_new_node, 2) != RET_OK) {
713     return nullptr;
714   }
715 
716   auto split_out1 = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 0);
717   if (split_out1 == nullptr) {
718     return nullptr;
719   }
720   auto split_out2 = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_new_node, 1);
721   if (split_out2 == nullptr) {
722     return nullptr;
723   }
724 
725   auto concat_prim = std::make_shared<ops::Concat>();
726   MS_CHECK_TRUE_RET(concat_prim != nullptr, nullptr);
727   auto concat_prim_c = concat_prim->GetPrim();
728   MS_CHECK_TRUE_RET(concat_prim_c != nullptr, nullptr);
729   concat_prim->set_axis(3);
730   auto concat_value_node = NewValueNode(concat_prim_c);
731   MS_CHECK_TRUE_RET(concat_value_node != nullptr, nullptr);
732   std::vector<AnfNodePtr> concat_new_node_inputs = {concat_value_node, split_out1, split_out2};
733   auto concat_new_node = func_graph->NewCNode(concat_new_node_inputs);
734   MS_CHECK_TRUE_RET(concat_new_node != nullptr, nullptr);
735   concat_new_node->set_fullname_with_scope("concat_" + base_name);
736   if (gru_output->abstract() != nullptr) {
737     concat_new_node->set_abstract(gru_output->abstract()->Clone());
738   }
739 
740   auto squeeze_prim = std::make_shared<ops::Squeeze>();
741   MS_CHECK_TRUE_RET(squeeze_prim != nullptr, nullptr);
742   auto squeeze_prim_c = squeeze_prim->GetPrim();
743   MS_CHECK_TRUE_RET(squeeze_prim_c != nullptr, nullptr);
744   squeeze_prim->set_axis(std::vector<int64_t>{1});
745   auto squeeze_value_node = NewValueNode(squeeze_prim_c);
746   MS_CHECK_TRUE_RET(squeeze_value_node != nullptr, nullptr);
747   std::vector<AnfNodePtr> squeeze_new_node_inputs = {squeeze_value_node, concat_new_node};
748   auto squeeze_new_node = func_graph->NewCNode(squeeze_new_node_inputs);
749   MS_CHECK_TRUE_RET(squeeze_new_node != nullptr, nullptr);
750   squeeze_new_node->set_fullname_with_scope("squeeze_" + base_name);
751   if (gru_output->abstract() != nullptr) {
752     squeeze_new_node->set_abstract(gru_output->abstract()->Clone());
753   }
754 
755   auto transpose_prim = std::make_shared<ops::Transpose>();
756   MS_CHECK_TRUE_RET(transpose_prim != nullptr, nullptr);
757   auto transpose_perm = BuildIntVecParameterNode(func_graph, {1, 0, 2}, "transpose_" + base_name + "_perm");
758   MS_CHECK_TRUE_RET(transpose_perm != nullptr, nullptr);
759   auto transpose_prim_c = transpose_prim->GetPrim();
760   MS_CHECK_TRUE_RET(transpose_prim_c != nullptr, nullptr);
761   auto transpose_new_node = func_graph->NewCNode(transpose_prim_c, {squeeze_new_node, transpose_perm});
762   MS_CHECK_TRUE_RET(transpose_new_node != nullptr, nullptr);
763   transpose_new_node->set_fullname_with_scope("transpose_" + base_name);
764   if (gru_output->abstract() != nullptr) {
765     transpose_new_node->set_abstract(gru_output->abstract()->Clone());
766   }
767 
768   return transpose_new_node;
769 }
770 
Process(const FuncGraphPtr & func_graph,const AnfNodePtr & concat_node,const EquivPtr & equiv) const771 const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
772                                                  const EquivPtr &equiv) const {
773   if (func_graph == nullptr || concat_node == nullptr) {
774     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
775     return nullptr;
776   }
777 
778   auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
779   MS_ASSERT(transpose_input != nullptr);
780   if (!utils::isa<CNodePtr>(transpose_input) || !CheckPrimitiveType(transpose_input, prim::kPrimTranspose)) {
781     return nullptr;
782   }
783 
784   auto fw_cond_primitive_vars = std::make_shared<PrimitiveVarMap>();
785   MS_CHECK_TRUE_RET(fw_cond_primitive_vars != nullptr, nullptr);
786   auto fw_cond_graph_pattern = GetCondGraphPattern(fw_cond_primitive_vars);
787   MS_CHECK_TRUE_RET(fw_cond_graph_pattern != nullptr, nullptr);
788   auto fw_cond = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[0]]);
789   MS_ASSERT(fw_cond != nullptr);
790   auto fw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(fw_cond_graph_pattern, fw_cond_primitive_vars, fw_cond,
791                                                            kCondCNodesNum, kCondNodesNum);
792   if (fw_cond_equiv == nullptr || fw_cond_equiv->empty()) {
793     return nullptr;
794   }
795 
796   auto bw_cond_primitive_vars = std::make_shared<PrimitiveVarMap>();
797   MS_CHECK_TRUE_RET(bw_cond_primitive_vars != nullptr, nullptr);
798   auto bw_cond_graph_pattern = GetCondGraphPattern(bw_cond_primitive_vars);
799   MS_CHECK_TRUE_RET(bw_cond_graph_pattern != nullptr, nullptr);
800   auto bw_cond = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[0]]);
801   MS_ASSERT(bw_cond != nullptr);
802   auto bw_cond_equiv = TfliteLstmCellFusion::CheckSubGraph(bw_cond_graph_pattern, bw_cond_primitive_vars, bw_cond,
803                                                            kCondCNodesNum, kCondNodesNum);
804   if (bw_cond_equiv == nullptr || bw_cond_equiv->empty()) {
805     return nullptr;
806   }
807 
808   auto fw_primitive_vars_body = std::make_shared<PrimitiveVarMap>();
809   MS_CHECK_TRUE_RET(fw_primitive_vars_body != nullptr, nullptr);
810   auto fw_body_graph_pattern = GetBodyGraphPattern(fw_primitive_vars_body);
811   MS_CHECK_TRUE_RET(fw_body_graph_pattern != nullptr, nullptr);
812   auto fw_body = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[1]]);
813   MS_ASSERT(fw_body != nullptr);
814   auto fw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(fw_body_graph_pattern, fw_primitive_vars_body, fw_body,
815                                                            kBodyCNodesNum, kBodyNodesNum);
816   if (fw_body_equiv == nullptr || fw_body_equiv->empty()) {
817     return nullptr;
818   }
819 
820   auto bw_primitive_vars_body = std::make_shared<PrimitiveVarMap>();
821   MS_CHECK_TRUE_RET(bw_primitive_vars_body != nullptr, nullptr);
822   auto bw_body_graph_pattern = GetBodyGraphPattern(bw_primitive_vars_body);
823   MS_CHECK_TRUE_RET(bw_body_graph_pattern != nullptr, nullptr);
824   auto bw_body = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[1]]);
825   MS_ASSERT(bw_body != nullptr);
826   auto bw_body_equiv = TfliteLstmCellFusion::CheckSubGraph(bw_body_graph_pattern, bw_primitive_vars_body, bw_body,
827                                                            kBodyCNodesNum, kBodyNodesNum);
828   if (bw_body_equiv == nullptr || bw_body_equiv->empty()) {
829     return nullptr;
830   }
831 
832   const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
833   auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 2);
834   if (gru_node == nullptr) {
835     return nullptr;
836   }
837   if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) {
838     return nullptr;
839   }
840 
841   auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0);
842   if (get_item_node == nullptr) {
843     return nullptr;
844   }
845 
846   auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope());
847   MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success";
848   return output_node;
849 }
850 }  // namespace opt
851 }  // namespace mindspore
852