• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
19 #include <memory>
20 #include <algorithm>
21 #include <functional>
22 #include "mindspore/core/ops/structure_ops.h"
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/comparison_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "ops/lstm.h"
27 #include "ops/squeeze.h"
28 #include "ops/tuple_get_item.h"
29 #include "src/common/utils.h"
30 #include "tools/common/tensor_util.h"
31 #include "include/common/utils/utils.h"
32 #include "tools/optimizer/common/gllo_utils.h"
33 #include "tools/optimizer/common/helper.h"
34 #include "securec/include/securec.h"
35 #include "nnacl/op_base.h"
36 
37 namespace mindspore {
38 namespace opt {
39 namespace {
// Default structural sizes of the TFLite While-based LSTM: operand count and loop-var count of the While node,
// and node/CNode counts of its cond and body subgraphs. The constructor falls back to these when passed 0.
constexpr size_t kWhileInputsLength{23};
constexpr size_t kWhileInputsVarNum{21};
constexpr size_t kCondNodesNum{12};
constexpr size_t kCondCNodesNum{4};
constexpr size_t kBodyNodesNum{95};
constexpr size_t kBodyCNodesNum{34};
// Presumably the output count of the fused LSTM node (not referenced in this part of the file).
constexpr size_t kLSTMOutputNum{3};
// Number of parameter placeholders the body-graph cell pattern indexes into (placeholders[0..19]).
constexpr size_t kPlaceholderMinSize{20};
// Gates per direction; the weight concatenation stacks this many per-gate matrices.
constexpr auto kUnidirectionalGateNum = 4;
// Placeholder used with std::bind to turn IsOpType into a single-argument pattern predicate.
const auto &p1 = std::placeholders::_1;
// Tolerance for comparing zoneout coefficients.
constexpr float EPSILON = 1e-5;
IsParameterNode(const BaseRef & n)51 bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
52 
GenerateBodyGraphCellPattern(const std::vector<CondVarPtr> & placeholders)53 std::vector<VectorRef> GenerateBodyGraphCellPattern(const std::vector<CondVarPtr> &placeholders) {
54   MS_CHECK_TRUE_RET(placeholders.size() >= kPlaceholderMinSize, {});
55   auto is_var1 = std::make_shared<Var>();
56   MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
57   VectorRef concat_i_w = VectorRef({is_var1, placeholders[8], placeholders[12]});
58   auto is_var2 = std::make_shared<Var>();
59   MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
60   VectorRef concat_f_w = VectorRef({is_var2, placeholders[9], placeholders[13]});
61   auto is_var3 = std::make_shared<Var>();
62   MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
63   VectorRef concat_c_w = VectorRef({is_var3, placeholders[10], placeholders[14]});
64   auto is_var4 = std::make_shared<Var>();
65   MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
66   VectorRef concat_o_w = VectorRef({is_var4, placeholders[11], placeholders[15]});
67 
68   auto is_var_getitem = std::make_shared<Var>("GetItem");
69   MS_CHECK_TRUE_RET(is_var_getitem != nullptr, {});
70   auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
71   MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
72   VectorRef get_item = VectorRef({is_var_getitem, placeholders[7], placeholders[2], is_param3});
73   auto is_var5 = std::make_shared<Var>();
74   MS_CHECK_TRUE_RET(is_var5 != nullptr, {});
75   VectorRef concat_input_h = VectorRef({is_var5, get_item, placeholders[5]});
76 
77   auto is_var6 = std::make_shared<Var>();
78   MS_CHECK_TRUE_RET(is_var6 != nullptr, {});
79   VectorRef matmul_input = VectorRef({is_var6, concat_input_h, concat_i_w});
80   auto is_var7 = std::make_shared<Var>();
81   MS_CHECK_TRUE_RET(is_var7 != nullptr, {});
82   VectorRef matmul_forget = VectorRef({is_var7, concat_input_h, concat_f_w});
83   auto is_var8 = std::make_shared<Var>();
84   MS_CHECK_TRUE_RET(is_var8 != nullptr, {});
85   VectorRef matmul_cell = VectorRef({is_var8, concat_input_h, concat_c_w});
86   auto is_var9 = std::make_shared<Var>();
87   MS_CHECK_TRUE_RET(is_var9 != nullptr, {});
88   VectorRef matmul_output = VectorRef({is_var9, concat_input_h, concat_o_w});
89 
90   auto is_var10 = std::make_shared<Var>();
91   MS_CHECK_TRUE_RET(is_var10 != nullptr, {});
92   VectorRef bias_input = VectorRef({is_var10, matmul_input, placeholders[16]});
93   auto is_var11 = std::make_shared<Var>();
94   MS_CHECK_TRUE_RET(is_var11 != nullptr, {});
95   VectorRef bias_forget = VectorRef({is_var11, matmul_forget, placeholders[17]});
96   auto is_var12 = std::make_shared<Var>();
97   MS_CHECK_TRUE_RET(is_var12 != nullptr, {});
98   VectorRef bias_cell = VectorRef({is_var12, matmul_cell, placeholders[18]});
99   auto is_var13 = std::make_shared<Var>();
100   MS_CHECK_TRUE_RET(is_var13 != nullptr, {});
101   VectorRef bias_output = VectorRef({is_var13, matmul_output, placeholders[19]});
102 
103   auto is_var_tanh = std::make_shared<Var>("Tanh");
104   MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
105   VectorRef cell = VectorRef({is_var_tanh, bias_cell});
106   auto is_var_sigmoid1 = std::make_shared<Var>("Sigmoid");
107   MS_CHECK_TRUE_RET(is_var_sigmoid1 != nullptr, {});
108   VectorRef input_gate = VectorRef({is_var_sigmoid1, bias_input});
109   auto is_var_mul1 = std::make_shared<Var>("Mul");
110   MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
111   VectorRef cell_input = VectorRef({is_var_mul1, input_gate, cell});
112   auto is_var_sigmoid2 = std::make_shared<Var>("Sigmoid");
113   MS_CHECK_TRUE_RET(is_var_sigmoid2 != nullptr, {});
114   VectorRef forget_gate = VectorRef({is_var_sigmoid2, bias_forget});
115   auto is_var_mul2 = std::make_shared<Var>("Mul");
116   MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
117   VectorRef cell_forgeted = VectorRef({is_var_mul2, forget_gate, placeholders[4]});
118   auto is_var_add = std::make_shared<Var>("Add");
119   MS_CHECK_TRUE_RET(is_var_add != nullptr, {});
120   VectorRef cell_new = VectorRef({is_var_add, cell_forgeted, cell_input});
121   return {bias_output, cell_new};
122 }
123 }  // namespace
124 
GetFloatScalarFromTensorInfo(const AnfNodePtr & tensor_info,float * v)125 STATUS TfliteLstmCellFusion::GetFloatScalarFromTensorInfo(const AnfNodePtr &tensor_info, float *v) {
126   if (tensor_info == nullptr || v == nullptr) {
127     MS_LOG(ERROR) << "tensor_info or v is nullptr";
128     return RET_ERROR;
129   }
130   if (!utils::isa<ParameterPtr>(tensor_info)) {
131     MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
132     return RET_ERROR;
133   }
134   auto param_ptr = utils::cast<ParameterPtr>(tensor_info);
135   if (!param_ptr->has_default() || param_ptr->default_param() == nullptr) {
136     MS_LOG(DEBUG) << "param not have default";
137     return RET_ERROR;
138   }
139   auto default_param = param_ptr->default_param();
140   if (!utils::isa<tensor::TensorPtr>(default_param)) {
141     MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
142     return RET_ERROR;
143   }
144   auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
145   auto tensor_shape = default_param_ptr->shape();
146   if (!(tensor_shape.empty() || (tensor_shape.size() == 1 && tensor_shape[0] == 1))) {
147     MS_LOG(DEBUG) << "default param is not scalar";
148     return RET_ERROR;
149   }
150   if (default_param_ptr->data_type() != kNumberTypeFloat32 && default_param_ptr->data_type() != kNumberTypeFloat) {
151     MS_LOG(DEBUG) << "default param is not float";
152     return RET_ERROR;
153   }
154   *v = *(reinterpret_cast<float *>(default_param_ptr->data_c()));
155   return RET_OK;
156 }
157 
Init() const158 bool TfliteLstmCellFusion::Init() const {
159   for (size_t i = 0; i < this->while_input_var_num_; ++i) {
160     auto is_var = std::make_shared<Var>();
161     MS_CHECK_TRUE_RET(is_var != nullptr, false);
162     while_input_vars_.emplace_back(is_var);
163   }
164   cell_zoneout_old_ = std::make_shared<Var>();
165   MS_CHECK_TRUE_RET(cell_zoneout_old_ != nullptr, false);
166   cell_zoneout_new_ = std::make_shared<Var>();
167   MS_CHECK_TRUE_RET(cell_zoneout_new_ != nullptr, false);
168   hidden_zoneout_old_ = std::make_shared<Var>();
169   MS_CHECK_TRUE_RET(hidden_zoneout_old_ != nullptr, false);
170   hidden_zoneout_new_ = std::make_shared<Var>();
171   MS_CHECK_TRUE_RET(hidden_zoneout_new_ != nullptr, false);
172   return true;
173 }
174 
TfliteLstmCellFusion(const std::string & name,bool multigraph,int input_length,int var_num,int cond_nodes_num,int cond_cnodes_num,int body_nodes_num,int body_cnodes_num)175 TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num,
176                                            int cond_nodes_num, int cond_cnodes_num, int body_nodes_num,
177                                            int body_cnodes_num)
178     : LitePatternProcessPass(name, multigraph) {
179   /*
180    * input vars for lstm while node
181    * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
182    * 9:i2i_  10:i2f_ 11:i2c_ 12:i2o_   13:c2i_   14:c2f_ 15:c2c_   16:c2o_   17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
183    */
184   this->while_inputs_num_ = input_length == 0 ? kWhileInputsLength : input_length;
185   this->while_input_var_num_ = var_num == 0 ? kWhileInputsVarNum : var_num;
186   this->cond_nodes_num_ = cond_nodes_num == 0 ? kCondNodesNum : cond_nodes_num;
187   this->cond_cnodes_num_ = cond_cnodes_num == 0 ? kCondCNodesNum : cond_cnodes_num;
188   this->body_nodes_num_ = body_nodes_num == 0 ? kBodyNodesNum : body_nodes_num;
189   this->body_cnodes_num_ = body_cnodes_num == 0 ? kBodyCNodesNum : body_cnodes_num;
190 }
191 
GetCondGraphPattern(const PrimitiveVarMapPtr & primitive_vars)192 AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) {
193   MS_ASSERT(primitive_vars != nullptr);
194   auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode);
195   MS_CHECK_TRUE_RET(is_parameter1 != nullptr, nullptr);
196   auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode);
197   MS_CHECK_TRUE_RET(is_parameter2 != nullptr, nullptr);
198   auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode);
199   MS_CHECK_TRUE_RET(is_parameter3 != nullptr, nullptr);
200   auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode);
201   MS_CHECK_TRUE_RET(is_parameter4 != nullptr, nullptr);
202   auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
203   MS_CHECK_TRUE_RET(is_less1 != nullptr, nullptr);
204   auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
205   MS_CHECK_TRUE_RET(is_less2 != nullptr, nullptr);
206   auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
207   MS_CHECK_TRUE_RET(is_logical_and != nullptr, nullptr);
208   auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
209   MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
210   VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
211   VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
212   VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
213   VectorRef return_ref = VectorRef({is_return, logicaland_ref});
214   VarPtr fg = std::make_shared<Var>("RootG");
215   MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
216   auto pattern = Helper::SexpToNode(return_ref, fg, primitive_vars.get(), true);
217   return pattern;
218 }
219 
GetBodyGraphPattern(const PrimitiveVarMapPtr & primitive_vars) const220 AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
221   std::vector<CondVarPtr> placeholders;
222   for (int i = 0; i < 20; ++i) {
223     auto cond_var_node = std::make_shared<CondVar>(IsParameterNode);
224     MS_CHECK_TRUE_RET(cond_var_node != nullptr, {});
225     placeholders.emplace_back(cond_var_node);
226   }
227   auto is_var1 = std::make_shared<Var>();
228   MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
229   auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
230   MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
231   VectorRef add2 = VectorRef({is_var1, placeholders[2], is_param1});
232   auto is_var2 = std::make_shared<Var>();
233   MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
234   auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
235   MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
236   VectorRef add3 = VectorRef({is_var2, placeholders[0], is_param2});
237 
238   auto hidden_cells = GenerateBodyGraphCellPattern(placeholders);
239   MS_CHECK_TRUE_RET(hidden_cells.size() == kInputSizeTwo, {});
240   auto is_var_mul1 = std::make_shared<Var>("Mul");
241   MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
242   VectorRef zoneout_cell_old = VectorRef({is_var_mul1, cell_zoneout_old_, placeholders[4]});
243   auto is_var_mul2 = std::make_shared<Var>("Mul");
244   MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
245   auto cell_new = hidden_cells[1];
246   MS_CHECK_TRUE_RET(!cell_new.empty(), {});
247   VectorRef zoneout_cell_new = VectorRef({is_var_mul2, cell_zoneout_new_, cell_new});
248   auto is_var_add1 = std::make_shared<Var>("Add");
249   MS_CHECK_TRUE_RET(is_var_add1 != nullptr, {});
250   VectorRef cell_output = VectorRef({is_var_add1, zoneout_cell_new, zoneout_cell_old});
251 
252   auto is_var_sigmoid = std::make_shared<Var>("Sigmoid");
253   MS_CHECK_TRUE_RET(is_var_sigmoid != nullptr, {});
254   auto bias_output = hidden_cells[0];
255   MS_CHECK_TRUE_RET(!bias_output.empty(), {});
256   VectorRef output_gate = VectorRef({is_var_sigmoid, bias_output});
257   auto is_var_tanh = std::make_shared<Var>("Tanh");
258   MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
259   VectorRef cell_to_output = VectorRef({is_var_tanh, cell_new});
260   auto is_var_mul3 = std::make_shared<Var>("Mul");
261   MS_CHECK_TRUE_RET(is_var_mul3 != nullptr, {});
262   VectorRef output = VectorRef({is_var_mul3, output_gate, cell_to_output});
263 
264   auto is_var_mul4 = std::make_shared<Var>("Mul");
265   MS_CHECK_TRUE_RET(is_var_mul4 != nullptr, {});
266   VectorRef zoneout_hidden_old = VectorRef({is_var_mul4, hidden_zoneout_old_, placeholders[5]});
267   auto is_var_mul5 = std::make_shared<Var>("Mul");
268   MS_CHECK_TRUE_RET(is_var_mul5 != nullptr, {});
269   VectorRef zoneout_hidden_new = VectorRef({is_var_mul5, hidden_zoneout_new_, output});
270   auto is_var_add2 = std::make_shared<Var>("Add");
271   MS_CHECK_TRUE_RET(is_var_add2 != nullptr, {});
272   VectorRef hidden_output = VectorRef({is_var_add2, zoneout_hidden_new, zoneout_hidden_old});
273 
274   auto is_var_setitem = std::make_shared<Var>("SetItem");
275   MS_CHECK_TRUE_RET(is_var_setitem != nullptr, {});
276   VectorRef set_item = VectorRef({is_var_setitem, placeholders[3], placeholders[2], output});
277 
278   auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
279   MS_CHECK_TRUE_RET(is_make_tuple != nullptr, nullptr);
280   std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output};
281   outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
282   VectorRef make_tuple_node = VectorRef(outputs);
283   auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
284   MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
285   VectorRef return_node = VectorRef({is_return, make_tuple_node});
286 
287   VarPtr fg = std::make_shared<Var>("RootG");
288   MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
289   auto pattern = Helper::SexpToNode(return_node, fg, primitive_vars.get(), true);
290   return pattern;
291 }
292 
DefinePattern() const293 const BaseRef TfliteLstmCellFusion::DefinePattern() const {
294   if (!Init()) {
295     MS_LOG(ERROR) << "initial member failed.";
296     return {};
297   }
298   auto is_while_node = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimWhile));
299   MS_CHECK_TRUE_RET(is_while_node != nullptr, {});
300   VectorRef while_node = VectorRef({is_while_node});
301   auto while_inputs = while_input_vars_;
302   MS_CHECK_TRUE_RET(while_inputs.size() > kInputSizeThree, {});
303   while_inputs.insert(while_inputs.begin() + 4, while_input_vars_[2]);
304   while_node.insert(while_node.end(), while_inputs.begin(), while_inputs.end());
305 
306   auto is_tuple_get_item = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem));
307   MS_CHECK_TRUE_RET(is_tuple_get_item != nullptr, {});
308   auto is_var = std::make_shared<Var>();
309   MS_CHECK_TRUE_RET(is_var != nullptr, {});
310   VectorRef while_output = VectorRef({is_tuple_get_item, while_node, is_var});
311 
312   auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListStack));
313   MS_CHECK_TRUE_RET(is_tensor_list_stack != nullptr, {});
314   auto is_parameter = std::make_shared<CondVar>(IsParameterNode);
315   MS_CHECK_TRUE_RET(is_parameter != nullptr, {});
316   VectorRef tensor_list_stack_node = VectorRef({is_tensor_list_stack, while_output, is_parameter});
317 
318   return tensor_list_stack_node;
319 }
320 
MatchGraph(const FuncGraphPtr & func_graph,const PrimitiveVarMapPtr & primitive_vars,const AnfNodePtr & pattern)321 EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars,
322                                           const AnfNodePtr &pattern) {
323   MS_ASSERT(func_graph != nullptr);
324   MS_ASSERT(pattern != nullptr);
325   auto return_node = func_graph->get_return();
326   auto visitor = std::make_shared<Visitor>();
327   MS_CHECK_TRUE_RET(visitor != nullptr, nullptr);
328   PatternEngine pattern_engine(visitor);
329   auto empty_equiv = std::make_shared<Equiv>();
330   MS_CHECK_TRUE_RET(empty_equiv != nullptr, nullptr);
331   EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv);
332   return equiv;
333 }
334 
335 // make sure that only 3,4,5 output of while is referenced
CheckReferencedOutputs(const FuncGraphPtr & func_graph,const CNodePtr & while_cnode)336 bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) {
337   MS_ASSERT(func_graph != nullptr);
338   MS_ASSERT(while_cnode != nullptr);
339   auto manager = func_graph->manager();
340   if (manager == nullptr) {
341     MS_LOG(ERROR) << "manager is nullptr";
342     return false;
343   }
344   auto while_node_users = manager->node_users()[while_cnode];
345   std::vector<size_t> valid_indexes{3, 4, 5};
346   for (auto &node_user : while_node_users) {
347     if (!utils::isa<CNodePtr>(node_user.first)) {
348       return false;
349     }
350     auto cnode = utils::cast<CNodePtr>(node_user.first);
351     if (IsMarkedTrainOp(cnode)) {
352       return false;
353     }
354     if (!CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
355       return false;
356     }
357     auto index = GetTupleGetItemOutIndex(cnode);
358     if (!lite::IsContain(valid_indexes, index)) {
359       return false;
360     }
361   }
362   return true;
363 }
364 
CheckSubGraph(const AnfNodePtr & pattern,const PrimitiveVarMapPtr & primitive_vars,const AnfNodePtr & anf_sub_graph,const size_t cnode_num,const size_t all_node_num)365 EquivPtr TfliteLstmCellFusion::CheckSubGraph(const AnfNodePtr &pattern, const PrimitiveVarMapPtr &primitive_vars,
366                                              const AnfNodePtr &anf_sub_graph, const size_t cnode_num,
367                                              const size_t all_node_num) {
368   MS_ASSERT(func_graph != nullptr);
369   MS_ASSERT(pattern != nullptr);
370   MS_ASSERT(primitive_vars != nullptr);
371   MS_ASSERT(anf_sub_graph != nullptr);
372   auto sub_graph = GetValueNode<FuncGraphPtr>(anf_sub_graph);
373   MS_CHECK_TRUE_RET(sub_graph != nullptr, nullptr);
374   auto nodes = TopoSort(sub_graph->get_return());
375   auto cnodes = sub_graph->GetOrderedCnodes();
376   if (cnodes.size() != cnode_num || nodes.size() != all_node_num) {
377     MS_LOG(DEBUG) << "sub graph nodes num not match";
378     return nullptr;
379   }
380   return MatchGraph(sub_graph, primitive_vars, pattern);
381 }
382 
CheckBodyGraph(const EquivPtr & equiv,float * zoneout_cell,float * zoneout_hidden) const383 bool TfliteLstmCellFusion::CheckBodyGraph(const EquivPtr &equiv, float *zoneout_cell, float *zoneout_hidden) const {
384   MS_ASSERT(func_graph != nullptr);
385   MS_ASSERT(equiv != nullptr);
386   MS_ASSERT(while_cnode != nullptr);
387   MS_ASSERT(zoneout_cell != nullptr);
388   MS_ASSERT(zoneout_hidden != nullptr);
389 
390   auto cell_zoneout_old_node = utils::cast<AnfNodePtr>((*equiv)[cell_zoneout_old_]);
391   MS_ASSERT(cell_zoneout_old_node != nullptr);
392   auto cell_zoneout_new_node = utils::cast<AnfNodePtr>((*equiv)[cell_zoneout_new_]);
393   MS_ASSERT(cell_zoneout_new_node != nullptr);
394   auto hidden_zoneout_old_node = utils::cast<AnfNodePtr>((*equiv)[hidden_zoneout_old_]);
395   MS_ASSERT(hidden_zoneout_old_node != nullptr);
396   auto hidden_zoneout_new_node = utils::cast<AnfNodePtr>((*equiv)[hidden_zoneout_new_]);
397   MS_ASSERT(hidden_zoneout_new_node != nullptr);
398 
399   float cell_old;
400   float cell_new;
401   float hidden_old;
402   float hidden_new;
403   if (GetFloatScalarFromTensorInfo(cell_zoneout_old_node, &cell_old) != RET_OK) {
404     return false;
405   }
406   if (GetFloatScalarFromTensorInfo(cell_zoneout_new_node, &cell_new) != RET_OK) {
407     return false;
408   }
409   if (GetFloatScalarFromTensorInfo(hidden_zoneout_old_node, &hidden_old) != RET_OK) {
410     return false;
411   }
412   if (GetFloatScalarFromTensorInfo(hidden_zoneout_new_node, &hidden_new) != RET_OK) {
413     return false;
414   }
415   if (cell_old < 0.0f || cell_old > 1.0f || cell_new < 0.0f || cell_new > 1.0f) {
416     MS_LOG(DEBUG) << "cell zoneout value illegal";
417     return false;
418   }
419   if (hidden_old < 0.0f || hidden_old > 1.0f || hidden_new < 0.0f || hidden_new > 1.0f) {
420     MS_LOG(DEBUG) << "hidden zoneout value illegal";
421     return false;
422   }
423   if (std::abs(cell_old + cell_new - 1.0f) > EPSILON || std::abs(hidden_old + hidden_new - 1.0f) > EPSILON ||
424       std::abs(cell_old - hidden_old) > EPSILON) {
425     MS_LOG(DEBUG) << "zoneout value illegal";
426     return false;
427   }
428   *zoneout_cell = cell_old;
429   *zoneout_hidden = hidden_old;
430   return true;
431 }
432 
GetConcatedParam(const std::vector<AnfNodePtr> & params,const ParameterPtr & new_param,bool is_bias)433 STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> &params, const ParameterPtr &new_param,
434                                               bool is_bias) {
435   MS_ASSERT(new_param != nullptr);
436   MS_ASSERT(params.size() == 4);
437   std::vector<float *> data_ptrs;
438   std::vector<std::vector<int64_t>> data_shapes;
439   for (auto &param : params) {
440     if (!utils::isa<ParameterPtr>(param)) {
441       MS_LOG(DEBUG) << "param is not Parameter node";
442       return RET_FAILED;
443     }
444     auto param_t = utils::cast<ParameterPtr>(param);
445     if (!param_t->has_default() || param_t->default_param() == nullptr) {
446       MS_LOG(DEBUG) << "param not have default value";
447       return RET_FAILED;
448     }
449     if (!utils::isa<tensor::TensorPtr>(param_t->default_param())) {
450       MS_LOG(DEBUG) << "default value is not tensor::Tensor";
451       return RET_FAILED;
452     }
453     auto origin_tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_t->default_param());
454     if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) {
455       MS_LOG(DEBUG) << "origin_tensor is not float32 type";
456       return RET_FAILED;
457     }
458     auto data_ptr = reinterpret_cast<float *>(origin_tensor->data_c());
459     auto data_shape = origin_tensor->shape();
460     data_ptrs.push_back(data_ptr);
461     data_shapes.push_back(data_shape);
462   }
463 
464   for (size_t i = 1; i < data_shapes.size(); ++i) {
465     if (data_shapes[i] != data_shapes[0]) {
466       MS_LOG(DEBUG) << "data shape not same";
467       return RET_FAILED;
468     }
469   }
470   std::vector<int64_t> new_shape;
471   int step = 0;
472   int data_size = 0;
473   MS_ASSERT(!data_shapes.empty());
474   if (is_bias) {
475     if (data_shapes[0].size() != 1) {
476       MS_LOG(ERROR) << "bias data shape error";
477       return RET_ERROR;
478     }
479     step = static_cast<int>(data_shapes[0][0]);
480     MS_CHECK_INT_MUL_NOT_OVERFLOW(C8NUM, step, RET_ERROR);
481     data_size = C8NUM * step;
482     new_shape = std::vector<int64_t>({1, data_size});
483 
484   } else {
485     if (data_shapes[0].size() != 2) {
486       MS_LOG(ERROR) << "weight data shape error";
487       return RET_ERROR;
488     }
489     new_shape = std::vector<int64_t>({1, data_shapes[0][0] * kUnidirectionalGateNum, data_shapes[0][1]});
490     MS_CHECK_INT_MUL_NOT_OVERFLOW(data_shapes[0][0], data_shapes[0][1], RET_ERROR);
491     step = static_cast<int>(data_shapes[0][0] * data_shapes[0][1]);
492     MS_CHECK_INT_MUL_NOT_OVERFLOW(C4NUM, step, RET_ERROR);
493     data_size = C4NUM * step;
494   }
495 
496   auto tensor_info = lite::CreateTensorInfo(nullptr, 0, new_shape, kNumberTypeFloat32);
497   if (tensor_info == nullptr) {
498     MS_LOG(ERROR) << "create tensor info failed.";
499     return RET_ERROR;
500   }
501 
502   auto tensor_data = static_cast<float *>(tensor_info->data_c());
503   for (int i = 0; i < data_size; ++i) {  // bias are stored into first 4*hidden_size buffer, the rest is all 0
504     tensor_data[i] = 0.0f;
505   }
506 
507   for (size_t i = 0; i < data_ptrs.size(); ++i) {
508     auto source_len = std::accumulate(data_shapes[i].begin(), data_shapes[i].end(), 1, std::multiplies<int>());
509     auto ret = memcpy_s(tensor_data + i * step, step * sizeof(float), data_ptrs[i], source_len * sizeof(float));
510     if (ret != EOK) {
511       MS_LOG(ERROR) << "memcpy_s error";
512       return RET_ERROR;
513     }
514   }
515 
516   auto status = lite::InitParameterFromTensorInfo(new_param, tensor_info);
517   if (status != RET_OK) {
518     MS_LOG(ERROR) << "init parameter from tensor info failed";
519     return RET_ERROR;
520   }
521 
522   return RET_OK;
523 }
524 
CreateLSTMNode(const FuncGraphPtr & func_graph,const EquivPtr & equiv,const EquivPtr & body_equiv,const std::string & base_name,const float zoneout_cell,const float zoneout_hidden) const525 CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
526                                               const EquivPtr &body_equiv, const std::string &base_name,
527                                               const float zoneout_cell, const float zoneout_hidden) const {
528   MS_ASSERT(func_graph != nullptr);
529   MS_ASSERT(equiv != nullptr);
530   MS_ASSERT(body_equiv != nullptr);
531   /*
532    * input vars for while node
533    * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
534    * 9:i2i_  10:i2f_ 11:i2c_ 12:i2o_   13:c2i_   14:c2f_ 15:c2c_   16:c2o_   17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
535    */
536   auto lstm_prim = std::make_shared<ops::LSTM>();
537   MS_CHECK_TRUE_RET(lstm_prim != nullptr, nullptr);
538   auto lstm_prim_c = lstm_prim->GetPrim();
539   MS_CHECK_TRUE_RET(lstm_prim_c != nullptr, nullptr);
540   lstm_prim->set_bidirectional(false);
541   lstm_prim->set_zoneout_cell(zoneout_cell);
542   lstm_prim->set_zoneout_hidden(zoneout_hidden);
543   auto value_node = NewValueNode(lstm_prim_c);
544   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
545 
546   auto &vars = while_input_vars_;
547   auto i2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
548   MS_ASSERT(i2i_weight);
549   auto i2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
550   MS_ASSERT(i2f_weight);
551   auto i2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[11]]);
552   MS_ASSERT(i2c_weight);
553   auto i2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[12]]);
554   MS_ASSERT(i2o_weight);
555 
556   auto c2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[13]]);
557   MS_ASSERT(c2i_weight);
558   auto c2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[14]]);
559   MS_ASSERT(c2f_weight);
560   auto c2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[15]]);
561   MS_ASSERT(c2c_weight);
562   auto c2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[16]]);
563   MS_ASSERT(c2o_weight);
564 
565   auto i_bias = utils::cast<AnfNodePtr>((*equiv)[vars[17]]);
566   MS_ASSERT(i_bias);
567   auto f_bias = utils::cast<AnfNodePtr>((*equiv)[vars[18]]);
568   MS_ASSERT(f_bias);
569   auto c_bias = utils::cast<AnfNodePtr>((*equiv)[vars[19]]);
570   MS_ASSERT(c_bias);
571   auto o_bias = utils::cast<AnfNodePtr>((*equiv)[vars[20]]);
572   MS_ASSERT(o_bias);
573 
574   auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]);
575   MS_ASSERT(input);
576   auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]);
577   MS_ASSERT(cell);
578   auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]);
579   MS_ASSERT(hidden);
580 
581   std::vector<AnfNodePtr> i_weights{i2i_weight, i2o_weight, i2f_weight, i2c_weight};
582   auto i_weight = func_graph->add_parameter();
583   MS_CHECK_TRUE_RET(i_weight != nullptr, nullptr);
584   auto status = GetConcatedParam(i_weights, i_weight, false);
585   if (status != RET_OK) {
586     return nullptr;
587   }
588   i_weight->set_name(base_name + "_weight_i");
589 
590   std::vector<AnfNodePtr> c_weights{c2i_weight, c2o_weight, c2f_weight, c2c_weight};
591   auto c_weight = func_graph->add_parameter();
592   MS_CHECK_TRUE_RET(c_weight != nullptr, nullptr);
593   status = GetConcatedParam(c_weights, c_weight, false);
594   if (status != RET_OK) {
595     return nullptr;
596   }
597   c_weight->set_name(base_name + "_weight_c");
598 
599   std::vector<AnfNodePtr> biases{i_bias, o_bias, f_bias, c_bias};
600   auto bias = func_graph->add_parameter();
601   MS_CHECK_TRUE_RET(bias != nullptr, nullptr);
602   status = GetConcatedParam(biases, bias, true);
603   if (status != RET_OK) {
604     return nullptr;
605   }
606   bias->set_name(base_name + "_bias");
607 
608   if (!utils::isa<CNodePtr>(input) || !CheckPrimitiveType(input, prim::kPrimTensorListFromTensor)) {
609     MS_LOG(DEBUG) << "input is not tensorlistfromtensor op";
610     return nullptr;
611   }
612   auto tensor_list_cnode = utils::cast<CNodePtr>(input);
613   auto input_tensor_node = tensor_list_cnode->input(1);
614 
615   std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias, hidden, cell};
616   auto new_node = func_graph->NewCNode(new_node_inputs);
617   MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
618   new_node->set_fullname_with_scope(base_name);
619   return new_node;
620 }
621 
// Creates TupleGetItem(node, item_index) in func_graph with a float32 tensor abstract. The new node's full name is
// node's full name + "_output_getitem_" + item_index.
// @return the new CNode, or nullptr on failure.
CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                                   const int item_index) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
  auto get_item_value = NewValueNode(MakeValue<int64_t>(item_index));
  if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
    MS_LOG(ERROR) << "NewValueNode is nullptr";
    return nullptr;
  }
  auto tuple_get_item_prim_c = tuple_get_item_prim->GetPrim();
  // Passed straight to NewCNode, so check it in release builds too, not only via MS_ASSERT.
  MS_CHECK_TRUE_RET(tuple_get_item_prim_c != nullptr, nullptr);
  CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim_c, {node, get_item_value});
  MS_CHECK_TRUE_RET(get_item_cnode != nullptr, nullptr);
  auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Create tensor abstarct failed";
    return nullptr;
  }
  get_item_cnode->set_abstract(abstract);
  get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
                                          std::to_string(item_index));
  return get_item_cnode;
}
646 
AdjustOtherGetItems(const FuncGraphPtr & func_graph,const CNodePtr & while_cnode,const CNodePtr & lstm_cnode,const CNodePtr & output_get_item)647 STATUS TfliteLstmCellFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode,
648                                                  const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) {
649   MS_ASSERT(func_graph != nullptr && while_cnode != nullptr);
650   MS_ASSERT(lstm_cnode != nullptr && output_get_item != nullptr);
651   auto manager = func_graph->manager();
652   if (manager == nullptr) {
653     MS_LOG(ERROR) << "manager is nullptr";
654     return RET_ERROR;
655   }
656   auto while_node_users = manager->node_users()[while_cnode];
657   for (auto &node_user : while_node_users) {
658     if (node_user.first == output_get_item) {
659       continue;
660     }
661     if (!utils::isa<CNodePtr>(node_user.first)) {
662       return RET_ERROR;
663     }
664     auto get_item = utils::cast<CNodePtr>(node_user.first);
665     if (!CheckPrimitiveType(get_item, prim::kPrimTupleGetItem)) {
666       return RET_ERROR;
667     }
668     auto new_inputs = get_item->inputs();
669     if (new_inputs.size() != 3) {
670       return RET_ERROR;
671     }
672     new_inputs[1] = lstm_cnode;
673     auto index_vnode = get_item->input(2);
674     if (!utils::isa<ValueNode>(index_vnode)) {
675       MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node";
676       return RET_ERROR;
677     }
678     auto value_node = utils::cast<ValueNodePtr>(index_vnode);
679     if (value_node == nullptr) {
680       MS_LOG(ERROR) << "cast to ValueNode failed";
681       return RET_ERROR;
682     }
683     auto origin_index = value_node->value()->type()->number_type() == kNumberTypeInt64
684                           ? GetValue<int64_t>(value_node->value())
685                           : GetValue<int>(value_node->value());
686     int64_t new_index = origin_index == 4 ? 2 : 1;
687     auto new_index_vnode = NewValueNode(MakeValue<int64_t>(new_index));
688     MS_CHECK_TRUE_RET(new_index_vnode != nullptr, RET_ERROR);
689     new_inputs[2] = new_index_vnode;
690     get_item->set_inputs(new_inputs);
691     get_item->set_fullname_with_scope(lstm_cnode->fullname_with_scope() + "_getitem_" + std::to_string(new_index));
692     if (get_item->abstract() == nullptr) {
693       MS_LOG(ERROR) << "get_item's abstract is nullptr";
694       return RET_ERROR;
695     }
696 
697     std::vector<int> squeeze_axis{0};
698     auto squeeze_node = CreateSqueezeNode(func_graph, get_item, squeeze_axis);
699     if (squeeze_node == nullptr) {
700       return RET_ERROR;
701     }
702 
703     auto get_item_users = manager->node_users()[get_item];
704     for (auto &get_item_user : get_item_users) {
705       manager->SetEdge(get_item_user.first, get_item_user.second, squeeze_node);
706     }
707   }
708   return RET_OK;
709 }
710 
SetAbstractTuple(const CNodePtr & cnode,const int output_num)711 STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) {
712   MS_ASSERT(cnode != nullptr);
713   AbstractBasePtrList abstract_list;
714   for (int i = 0; i < output_num; ++i) {
715     auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
716     if (abstract == nullptr) {
717       MS_LOG(ERROR) << "Create tensor abstarct failed";
718       return RET_ERROR;
719     }
720     abstract_list.emplace_back(abstract);
721   }
722   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
723   if (abstract_tuple == nullptr) {
724     MS_LOG(ERROR) << "create abstract_tuple failed";
725     return RET_ERROR;
726   }
727   cnode->set_abstract(abstract_tuple);
728   return RET_OK;
729 }
730 
// Wraps `input_node` in a new Squeeze CNode over the given axes.
// The Squeeze node receives a clone of the input's abstract (when one exists); no squeezed shape is
// computed here. The node is named "squeeze_<input name>".
// @return the Squeeze CNode, or nullptr if the primitive or node cannot be created.
CNodePtr TfliteLstmCellFusion::CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node,
                                                 const std::vector<int> &axis) {
  MS_ASSERT(func_graph != nullptr && input_node != nullptr);
  auto squeeze_op = std::make_shared<ops::Squeeze>();
  MS_CHECK_TRUE_RET(squeeze_op != nullptr, nullptr);
  auto squeeze_primitive = squeeze_op->GetPrim();
  MS_CHECK_TRUE_RET(squeeze_primitive != nullptr, nullptr);
  // The op stores its axes as int64; widen the int axes element by element.
  const std::vector<int64_t> widened_axis(axis.cbegin(), axis.cend());
  squeeze_op->set_axis(widened_axis);
  auto squeeze = func_graph->NewCNode(squeeze_primitive, {input_node});
  MS_CHECK_TRUE_RET(squeeze != nullptr, nullptr);
  const auto &input_abstract = input_node->abstract();
  if (input_abstract != nullptr) {
    squeeze->set_abstract(input_abstract->Clone());
  }
  squeeze->set_fullname_with_scope("squeeze_" + input_node->fullname_with_scope());
  return squeeze;
}
750 
Process(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr & equiv) const751 const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
752                                                const EquivPtr &equiv) const {
753   if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
754     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
755     return nullptr;
756   }
757 
758   if (!utils::isa<CNodePtr>(node)) {
759     return nullptr;
760   }
761   auto tensor_list_stack_cnode = utils::cast<CNodePtr>(node);
762   auto tuple_get_item_node = tensor_list_stack_cnode->input(1);
763   if (!utils::isa<CNodePtr>(tuple_get_item_node)) {
764     return nullptr;
765   }
766   auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item_node);
767   auto while_node = tuple_get_item_cnode->input(1);
768   if (!utils::isa<CNodePtr>(while_node)) {
769     return nullptr;
770   }
771   auto while_cnode = utils::cast<CNodePtr>(while_node);
772 
773   if (while_cnode == nullptr || while_cnode->size() != while_inputs_num_) {
774     return nullptr;
775   }
776   if (!CheckReferencedOutputs(func_graph, while_cnode)) {
777     return nullptr;
778   }
779   auto primitive_vars_cond = std::make_shared<PrimitiveVarMap>();
780   MS_CHECK_TRUE_RET(primitive_vars_cond != nullptr, nullptr);
781   auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond);
782   MS_CHECK_TRUE_RET(cond_graph_pattern != nullptr, nullptr);
783   auto cond_equiv =
784     CheckSubGraph(cond_graph_pattern, primitive_vars_cond, while_cnode->input(1), cond_cnodes_num_, cond_nodes_num_);
785   if (cond_equiv == nullptr || cond_equiv->empty()) {
786     return nullptr;
787   }
788   auto primitive_vars_body = std::make_shared<PrimitiveVarMap>();
789   MS_CHECK_TRUE_RET(primitive_vars_body != nullptr, nullptr);
790   auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body);
791   MS_CHECK_TRUE_RET(body_graph_pattern != nullptr, nullptr);
792   auto body_equiv =
793     CheckSubGraph(body_graph_pattern, primitive_vars_body, while_cnode->input(2), body_cnodes_num_, body_nodes_num_);
794   if (body_equiv == nullptr || body_equiv->empty()) {
795     return nullptr;
796   }
797   float zoneout_cell = 0.0f;
798   float zoneout_hidden = 0.0f;
799   if (!CheckBodyGraph(body_equiv, &zoneout_cell, &zoneout_hidden)) {
800     return nullptr;
801   }
802   const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope();
803   auto lstm_node = CreateLSTMNode(func_graph, equiv, body_equiv, lstm_name, zoneout_cell, zoneout_hidden);
804   if (lstm_node == nullptr) {
805     return nullptr;
806   }
807   auto status = SetAbstractTuple(lstm_node, kLSTMOutputNum);
808   if (status != RET_OK) {
809     return nullptr;
810   }
811 
812   auto get_item_node = CreateOutputGetItem(func_graph, lstm_node, 0);
813   if (get_item_node == nullptr) {
814     MS_LOG(DEBUG) << "create lstm output get_item node failed";
815     return nullptr;
816   }
817 
818   status = AdjustOtherGetItems(func_graph, while_cnode, lstm_node, tuple_get_item_cnode);
819   if (status != RET_OK) {
820     return nullptr;
821   }
822 
823   std::vector<int> squeeze_axis{1};  // our lstm output:0 have an extra axis that tflite not have, it must be squeezed
824   auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis);
825   MS_CHECK_TRUE_MSG(squeeze_node != nullptr, nullptr, "create a squeeze node failed.");
826 
827   auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1);
828   MS_CHECK_TRUE_RET(cond_cnode_index_pair != nullptr, nullptr);
829   func_graph->DropFuncGraphCNodeIndex(cond_cnode_index_pair);
830   auto body_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 2);
831   MS_CHECK_TRUE_RET(body_cnode_index_pair != nullptr, nullptr);
832   func_graph->DropFuncGraphCNodeIndex(body_cnode_index_pair);
833   MS_LOG(INFO) << "lstm node:" << lstm_node->fullname_with_scope() << " fusion success";
834   return squeeze_node;
835 }
836 }  // namespace opt
837 }  // namespace mindspore
838