/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include <memory>
#include <algorithm>
#include <functional>
#include <numeric>
#include "mindspore/core/ops/structure_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ops/lstm.h"
#include "ops/squeeze.h"
#include "ops/tuple_get_item.h"
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"
#include "include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/helper.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace opt {
namespace {
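// Expected input and node counts of the while-loop condition/body subgraphs that this
// pass matches, plus the gate count of a unidirectional LSTM and the tolerance used
// when validating the zoneout scalars.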
constexpr size_t kWhileInputsLength = 23;
constexpr size_t kWhileInputsVarNum = 21;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 95;
constexpr size_t kBodyCNodesNum = 34;
constexpr size_t kLSTMOutputNum = 3;
constexpr size_t kPlaceholderMinSize = 20;
constexpr auto kUnidirectionalGateNum = 4;
const auto &p1 = std::placeholders::_1;
constexpr float EPSILON = 1e-5;
bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }

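// Builds the pattern for the LSTM cell computation inside the while-loop body: the
// four gate matmuls and bias-adds on the concatenated [input, hidden] tensor and the
// new cell state. Returns {bias_output, cell_new}, i.e. the output-gate pre-activation
// and the updated cell state, for the caller to extend with the zoneout logic.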
std::vector<VectorRef> GenerateBodyGraphCellPattern(const std::vector<CondVarPtr> &placeholders) {
  MS_CHECK_TRUE_RET(placeholders.size() >= kPlaceholderMinSize, {});
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
  VectorRef concat_i_w = VectorRef({is_var1, placeholders[8], placeholders[12]});
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
  VectorRef concat_f_w = VectorRef({is_var2, placeholders[9], placeholders[13]});
  auto is_var3 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
  VectorRef concat_c_w = VectorRef({is_var3, placeholders[10], placeholders[14]});
  auto is_var4 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
  VectorRef concat_o_w = VectorRef({is_var4, placeholders[11], placeholders[15]});

  auto is_var_getitem = std::make_shared<Var>("GetItem");
  MS_CHECK_TRUE_RET(is_var_getitem != nullptr, {});
  auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
  VectorRef get_item = VectorRef({is_var_getitem, placeholders[7], placeholders[2], is_param3});
  auto is_var5 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var5 != nullptr, {});
  VectorRef concat_input_h = VectorRef({is_var5, get_item, placeholders[5]});

  auto is_var6 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var6 != nullptr, {});
  VectorRef matmul_input = VectorRef({is_var6, concat_input_h, concat_i_w});
  auto is_var7 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var7 != nullptr, {});
  VectorRef matmul_forget = VectorRef({is_var7, concat_input_h, concat_f_w});
  auto is_var8 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var8 != nullptr, {});
  VectorRef matmul_cell = VectorRef({is_var8, concat_input_h, concat_c_w});
  auto is_var9 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var9 != nullptr, {});
  VectorRef matmul_output = VectorRef({is_var9, concat_input_h, concat_o_w});

  auto is_var10 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var10 != nullptr, {});
  VectorRef bias_input = VectorRef({is_var10, matmul_input, placeholders[16]});
  auto is_var11 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var11 != nullptr, {});
  VectorRef bias_forget = VectorRef({is_var11, matmul_forget, placeholders[17]});
  auto is_var12 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var12 != nullptr, {});
  VectorRef bias_cell = VectorRef({is_var12, matmul_cell, placeholders[18]});
  auto is_var13 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var13 != nullptr, {});
  VectorRef bias_output = VectorRef({is_var13, matmul_output, placeholders[19]});

  auto is_var_tanh = std::make_shared<Var>("Tanh");
  MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
  VectorRef cell = VectorRef({is_var_tanh, bias_cell});
  auto is_var_sigmoid1 = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid1 != nullptr, {});
  VectorRef input_gate = VectorRef({is_var_sigmoid1, bias_input});
  auto is_var_mul1 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
  VectorRef cell_input = VectorRef({is_var_mul1, input_gate, cell});
  auto is_var_sigmoid2 = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid2 != nullptr, {});
  VectorRef forget_gate = VectorRef({is_var_sigmoid2, bias_forget});
  auto is_var_mul2 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
  VectorRef cell_forgeted = VectorRef({is_var_mul2, forget_gate, placeholders[4]});
  auto is_var_add = std::make_shared<Var>("Add");
  MS_CHECK_TRUE_RET(is_var_add != nullptr, {});
  VectorRef cell_new = VectorRef({is_var_add, cell_forgeted, cell_input});
  return {bias_output, cell_new};
}
}  // namespace

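// Reads a float scalar out of a Parameter node's default tensor. Fails if the node is
// not a Parameter, has no default value, or the default is not a float scalar.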
STATUS TfliteLstmCellFusion::GetFloatScalarFromTensorInfo(const AnfNodePtr &tensor_info, float *v) {
  if (tensor_info == nullptr || v == nullptr) {
    MS_LOG(ERROR) << "tensor_info or v is nullptr";
    return RET_ERROR;
  }
  if (!utils::isa<ParameterPtr>(tensor_info)) {
    MS_LOG(DEBUG) << "tensor_info is not a Parameter node";
    return RET_ERROR;
  }
  auto param_ptr = utils::cast<ParameterPtr>(tensor_info);
  if (!param_ptr->has_default() || param_ptr->default_param() == nullptr) {
    MS_LOG(DEBUG) << "param does not have a default value";
    return RET_ERROR;
  }
  auto default_param = param_ptr->default_param();
  if (!utils::isa<tensor::TensorPtr>(default_param)) {
    MS_LOG(DEBUG) << "default param is not tensor::TensorPtr";
    return RET_ERROR;
  }
  auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
  auto tensor_shape = default_param_ptr->shape();
  if (!(tensor_shape.empty() || (tensor_shape.size() == 1 && tensor_shape[0] == 1))) {
    MS_LOG(DEBUG) << "default param is not a scalar";
    return RET_ERROR;
  }
  if (default_param_ptr->data_type() != kNumberTypeFloat32 && default_param_ptr->data_type() != kNumberTypeFloat) {
    MS_LOG(DEBUG) << "default param is not float";
    return RET_ERROR;
  }
  *v = *(reinterpret_cast<float *>(default_param_ptr->data_c()));
  return RET_OK;
}

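// Allocates the pattern variables used by this pass: one Var per while input and the
// four zoneout scalar Vars referenced by the body-graph pattern.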
bool TfliteLstmCellFusion::Init() const {
  for (size_t i = 0; i < this->while_input_var_num_; ++i) {
    auto is_var = std::make_shared<Var>();
    MS_CHECK_TRUE_RET(is_var != nullptr, false);
    while_input_vars_.emplace_back(is_var);
  }
  cell_zoneout_old_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(cell_zoneout_old_ != nullptr, false);
  cell_zoneout_new_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(cell_zoneout_new_ != nullptr, false);
  hidden_zoneout_old_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(hidden_zoneout_old_ != nullptr, false);
  hidden_zoneout_new_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(hidden_zoneout_new_ != nullptr, false);
  return true;
}

TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num,
                                           int cond_nodes_num, int cond_cnodes_num, int body_nodes_num,
                                           int body_cnodes_num)
    : LitePatternProcessPass(name, multigraph) {
  /*
   * input vars for lstm while node
   * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
   * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
   */
  this->while_inputs_num_ = input_length == 0 ? kWhileInputsLength : input_length;
  this->while_input_var_num_ = var_num == 0 ? kWhileInputsVarNum : var_num;
  this->cond_nodes_num_ = cond_nodes_num == 0 ? kCondNodesNum : cond_nodes_num;
  this->cond_cnodes_num_ = cond_cnodes_num == 0 ? kCondCNodesNum : cond_cnodes_num;
  this->body_nodes_num_ = body_nodes_num == 0 ? kBodyNodesNum : body_nodes_num;
  this->body_cnodes_num_ = body_cnodes_num == 0 ? kBodyCNodesNum : body_cnodes_num;
}

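// Pattern for the while-loop condition graph: two Less comparisons on the loop
// counters, combined by a LogicalAnd and returned.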
AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) {
  MS_ASSERT(primitive_vars != nullptr);
  auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter1 != nullptr, nullptr);
  auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter2 != nullptr, nullptr);
  auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter3 != nullptr, nullptr);
  auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter4 != nullptr, nullptr);
  auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
  MS_CHECK_TRUE_RET(is_less1 != nullptr, nullptr);
  auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
  MS_CHECK_TRUE_RET(is_less2 != nullptr, nullptr);
  auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
  MS_CHECK_TRUE_RET(is_logical_and != nullptr, nullptr);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
  VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
  VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
  VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
  VectorRef return_ref = VectorRef({is_return, logicaland_ref});
  VarPtr fg = std::make_shared<Var>("RootG");
  MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
  auto pattern = Helper::SexpToNode(return_ref, fg, primitive_vars.get(), true);
  return pattern;
}

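// Pattern for the while-loop body graph: the two loop-counter increments, the LSTM
// cell computation from GenerateBodyGraphCellPattern, zoneout blending of the old and
// new cell/hidden states, and a MakeTuple/Return of all loop-carried values.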
AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
  std::vector<CondVarPtr> placeholders;
  for (int i = 0; i < 20; ++i) {
    auto cond_var_node = std::make_shared<CondVar>(IsParameterNode);
    MS_CHECK_TRUE_RET(cond_var_node != nullptr, {});
    placeholders.emplace_back(cond_var_node);
  }
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
  auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
  VectorRef add2 = VectorRef({is_var1, placeholders[2], is_param1});
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
  auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
  VectorRef add3 = VectorRef({is_var2, placeholders[0], is_param2});

  auto hidden_cells = GenerateBodyGraphCellPattern(placeholders);
  MS_CHECK_TRUE_RET(hidden_cells.size() == kInputSizeTwo, {});
  auto is_var_mul1 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
  VectorRef zoneout_cell_old = VectorRef({is_var_mul1, cell_zoneout_old_, placeholders[4]});
  auto is_var_mul2 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
  auto cell_new = hidden_cells[1];
  MS_CHECK_TRUE_RET(!cell_new.empty(), {});
  VectorRef zoneout_cell_new = VectorRef({is_var_mul2, cell_zoneout_new_, cell_new});
  auto is_var_add1 = std::make_shared<Var>("Add");
  MS_CHECK_TRUE_RET(is_var_add1 != nullptr, {});
  VectorRef cell_output = VectorRef({is_var_add1, zoneout_cell_new, zoneout_cell_old});

  auto is_var_sigmoid = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid != nullptr, {});
  auto bias_output = hidden_cells[0];
  MS_CHECK_TRUE_RET(!bias_output.empty(), {});
  VectorRef output_gate = VectorRef({is_var_sigmoid, bias_output});
  auto is_var_tanh = std::make_shared<Var>("Tanh");
  MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
  VectorRef cell_to_output = VectorRef({is_var_tanh, cell_new});
  auto is_var_mul3 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul3 != nullptr, {});
  VectorRef output = VectorRef({is_var_mul3, output_gate, cell_to_output});

  auto is_var_mul4 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul4 != nullptr, {});
  VectorRef zoneout_hidden_old = VectorRef({is_var_mul4, hidden_zoneout_old_, placeholders[5]});
  auto is_var_mul5 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul5 != nullptr, {});
  VectorRef zoneout_hidden_new = VectorRef({is_var_mul5, hidden_zoneout_new_, output});
  auto is_var_add2 = std::make_shared<Var>("Add");
  MS_CHECK_TRUE_RET(is_var_add2 != nullptr, {});
  VectorRef hidden_output = VectorRef({is_var_add2, zoneout_hidden_new, zoneout_hidden_old});

  auto is_var_setitem = std::make_shared<Var>("SetItem");
  MS_CHECK_TRUE_RET(is_var_setitem != nullptr, {});
  VectorRef set_item = VectorRef({is_var_setitem, placeholders[3], placeholders[2], output});

  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
  MS_CHECK_TRUE_RET(is_make_tuple != nullptr, nullptr);
  std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output};
  outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
  VectorRef make_tuple_node = VectorRef(outputs);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
  VectorRef return_node = VectorRef({is_return, make_tuple_node});

  VarPtr fg = std::make_shared<Var>("RootG");
  MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
  auto pattern = Helper::SexpToNode(return_node, fg, primitive_vars.get(), true);
  return pattern;
}

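// Root pattern of the pass: a While node whose output is read through a TupleGetItem
// and then packed by a TensorListStack; the fusion is anchored on the TensorListStack.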
const BaseRef TfliteLstmCellFusion::DefinePattern() const {
  if (!Init()) {
    MS_LOG(ERROR) << "initialize members failed.";
    return {};
  }
  auto is_while_node = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimWhile));
  MS_CHECK_TRUE_RET(is_while_node != nullptr, {});
  VectorRef while_node = VectorRef({is_while_node});
  auto while_inputs = while_input_vars_;
  MS_CHECK_TRUE_RET(while_inputs.size() > kInputSizeThree, {});
  while_inputs.insert(while_inputs.begin() + 4, while_input_vars_[2]);
  while_node.insert(while_node.end(), while_inputs.begin(), while_inputs.end());

  auto is_tuple_get_item = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem));
  MS_CHECK_TRUE_RET(is_tuple_get_item != nullptr, {});
  auto is_var = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var != nullptr, {});
  VectorRef while_output = VectorRef({is_tuple_get_item, while_node, is_var});

  auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListStack));
  MS_CHECK_TRUE_RET(is_tensor_list_stack != nullptr, {});
  auto is_parameter = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter != nullptr, {});
  VectorRef tensor_list_stack_node = VectorRef({is_tensor_list_stack, while_output, is_parameter});

  return tensor_list_stack_node;
}

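// Matches the given pattern against a (sub)graph, starting from its return node, and
// returns the resulting equivalence map.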
EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars,
                                          const AnfNodePtr &pattern) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(pattern != nullptr);
  auto return_node = func_graph->get_return();
  auto visitor = std::make_shared<Visitor>();
  MS_CHECK_TRUE_RET(visitor != nullptr, nullptr);
  PatternEngine pattern_engine(visitor);
  auto empty_equiv = std::make_shared<Equiv>();
  MS_CHECK_TRUE_RET(empty_equiv != nullptr, nullptr);
  EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv);
  return equiv;
}

// make sure that only outputs 3, 4 and 5 of the while node are referenced
bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(while_cnode != nullptr);
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr";
    return false;
  }
  auto while_node_users = manager->node_users()[while_cnode];
  std::vector<size_t> valid_indexes{3, 4, 5};
  for (auto &node_user : while_node_users) {
    if (!utils::isa<CNodePtr>(node_user.first)) {
      return false;
    }
    auto cnode = utils::cast<CNodePtr>(node_user.first);
    if (IsMarkedTrainOp(cnode)) {
      return false;
    }
    if (!CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
      return false;
    }
    auto index = GetTupleGetItemOutIndex(cnode);
    if (!lite::IsContain(valid_indexes, index)) {
      return false;
    }
  }
  return true;
}

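// Checks that a while condition/body subgraph has the expected CNode and total node
// counts, then matches it against the given pattern and returns the equivalence map.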
EquivPtr TfliteLstmCellFusion::CheckSubGraph(const AnfNodePtr &pattern, const PrimitiveVarMapPtr &primitive_vars,
                                             const AnfNodePtr &anf_sub_graph, const size_t cnode_num,
                                             const size_t all_node_num) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(pattern != nullptr);
  MS_ASSERT(primitive_vars != nullptr);
  MS_ASSERT(anf_sub_graph != nullptr);
  auto sub_graph = GetValueNode<FuncGraphPtr>(anf_sub_graph);
  MS_CHECK_TRUE_RET(sub_graph != nullptr, nullptr);
  auto nodes = TopoSort(sub_graph->get_return());
  auto cnodes = sub_graph->GetOrderedCnodes();
  if (cnodes.size() != cnode_num || nodes.size() != all_node_num) {
    MS_LOG(DEBUG) << "sub graph nodes num not match";
    return nullptr;
  }
  return MatchGraph(sub_graph, primitive_vars, pattern);
}

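// Extracts the four zoneout scalars matched in the body graph and validates them: each
// must lie in [0, 1], each old/new pair must sum to 1, and the cell and hidden zoneout
// factors must match within EPSILON. Outputs the cell and hidden zoneout values.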
bool TfliteLstmCellFusion::CheckBodyGraph(const EquivPtr &equiv, float *zoneout_cell, float *zoneout_hidden) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(while_cnode != nullptr);
  MS_ASSERT(zoneout_cell != nullptr);
  MS_ASSERT(zoneout_hidden != nullptr);

  auto cell_zoneout_old_node = utils::cast<AnfNodePtr>((*equiv)[cell_zoneout_old_]);
  MS_ASSERT(cell_zoneout_old_node != nullptr);
  auto cell_zoneout_new_node = utils::cast<AnfNodePtr>((*equiv)[cell_zoneout_new_]);
  MS_ASSERT(cell_zoneout_new_node != nullptr);
  auto hidden_zoneout_old_node = utils::cast<AnfNodePtr>((*equiv)[hidden_zoneout_old_]);
  MS_ASSERT(hidden_zoneout_old_node != nullptr);
  auto hidden_zoneout_new_node = utils::cast<AnfNodePtr>((*equiv)[hidden_zoneout_new_]);
  MS_ASSERT(hidden_zoneout_new_node != nullptr);

  float cell_old;
  float cell_new;
  float hidden_old;
  float hidden_new;
  if (GetFloatScalarFromTensorInfo(cell_zoneout_old_node, &cell_old) != RET_OK) {
    return false;
  }
  if (GetFloatScalarFromTensorInfo(cell_zoneout_new_node, &cell_new) != RET_OK) {
    return false;
  }
  if (GetFloatScalarFromTensorInfo(hidden_zoneout_old_node, &hidden_old) != RET_OK) {
    return false;
  }
  if (GetFloatScalarFromTensorInfo(hidden_zoneout_new_node, &hidden_new) != RET_OK) {
    return false;
  }
  if (cell_old < 0.0f || cell_old > 1.0f || cell_new < 0.0f || cell_new > 1.0f) {
    MS_LOG(DEBUG) << "cell zoneout value illegal";
    return false;
  }
  if (hidden_old < 0.0f || hidden_old > 1.0f || hidden_new < 0.0f || hidden_new > 1.0f) {
    MS_LOG(DEBUG) << "hidden zoneout value illegal";
    return false;
  }
  if (std::abs(cell_old + cell_new - 1.0f) > EPSILON || std::abs(hidden_old + hidden_new - 1.0f) > EPSILON ||
      std::abs(cell_old - hidden_old) > EPSILON) {
    MS_LOG(DEBUG) << "zoneout value illegal";
    return false;
  }
  *zoneout_cell = cell_old;
  *zoneout_hidden = hidden_old;
  return true;
}

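// Concatenates the four per-gate weight (or bias) Parameters into a single tensor in
// the layout expected by the fused LSTM op: weights are stacked into a
// {1, 4 * rows, cols} tensor, biases are packed into a {1, 8 * len} tensor whose first
// half holds the four gate biases and whose second half stays zero.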
STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> &params, const ParameterPtr &new_param,
                                              bool is_bias) {
  MS_ASSERT(new_param != nullptr);
  MS_ASSERT(params.size() == 4);
  std::vector<float *> data_ptrs;
  std::vector<std::vector<int64_t>> data_shapes;
  for (auto &param : params) {
    if (!utils::isa<ParameterPtr>(param)) {
      MS_LOG(DEBUG) << "param is not a Parameter node";
      return RET_FAILED;
    }
    auto param_t = utils::cast<ParameterPtr>(param);
    if (!param_t->has_default() || param_t->default_param() == nullptr) {
      MS_LOG(DEBUG) << "param does not have a default value";
      return RET_FAILED;
    }
    if (!utils::isa<tensor::TensorPtr>(param_t->default_param())) {
      MS_LOG(DEBUG) << "default value is not tensor::Tensor";
      return RET_FAILED;
    }
    auto origin_tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_t->default_param());
    if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) {
      MS_LOG(DEBUG) << "origin_tensor is not float32 type";
      return RET_FAILED;
    }
    auto data_ptr = reinterpret_cast<float *>(origin_tensor->data_c());
    auto data_shape = origin_tensor->shape();
    data_ptrs.push_back(data_ptr);
    data_shapes.push_back(data_shape);
  }

  for (size_t i = 1; i < data_shapes.size(); ++i) {
    if (data_shapes[i] != data_shapes[0]) {
      MS_LOG(DEBUG) << "data shapes are not the same";
      return RET_FAILED;
    }
  }
  std::vector<int64_t> new_shape;
  int step = 0;
  int data_size = 0;
  MS_ASSERT(!data_shapes.empty());
  if (is_bias) {
    if (data_shapes[0].size() != 1) {
      MS_LOG(ERROR) << "bias data shape error";
      return RET_ERROR;
    }
    step = static_cast<int>(data_shapes[0][0]);
    MS_CHECK_INT_MUL_NOT_OVERFLOW(C8NUM, step, RET_ERROR);
    data_size = C8NUM * step;
    new_shape = std::vector<int64_t>({1, data_size});

  } else {
    if (data_shapes[0].size() != 2) {
      MS_LOG(ERROR) << "weight data shape error";
      return RET_ERROR;
    }
    new_shape = std::vector<int64_t>({1, data_shapes[0][0] * kUnidirectionalGateNum, data_shapes[0][1]});
    MS_CHECK_INT_MUL_NOT_OVERFLOW(data_shapes[0][0], data_shapes[0][1], RET_ERROR);
    step = static_cast<int>(data_shapes[0][0] * data_shapes[0][1]);
    MS_CHECK_INT_MUL_NOT_OVERFLOW(C4NUM, step, RET_ERROR);
    data_size = C4NUM * step;
  }

  auto tensor_info = lite::CreateTensorInfo(nullptr, 0, new_shape, kNumberTypeFloat32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "create tensor info failed.";
    return RET_ERROR;
  }

  auto tensor_data = static_cast<float *>(tensor_info->data_c());
  for (int i = 0; i < data_size; ++i) {  // biases are stored in the first 4*hidden_size elements; the rest stays 0
    tensor_data[i] = 0.0f;
  }

  for (size_t i = 0; i < data_ptrs.size(); ++i) {
    auto source_len = std::accumulate(data_shapes[i].begin(), data_shapes[i].end(), 1, std::multiplies<int>());
    auto ret = memcpy_s(tensor_data + i * step, step * sizeof(float), data_ptrs[i], source_len * sizeof(float));
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy_s error";
      return RET_ERROR;
    }
  }

  auto status = lite::InitParameterFromTensorInfo(new_param, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return RET_ERROR;
  }

  return RET_OK;
}

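// Builds the fused LSTM CNode: concatenates the per-gate weights and biases into
// single Parameters, pulls the real input tensor out of the TensorListFromTensor node,
// and creates a unidirectional LSTM carrying the matched zoneout factors.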
CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                              const EquivPtr &body_equiv, const std::string &base_name,
                                              const float zoneout_cell, const float zoneout_hidden) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(body_equiv != nullptr);
  /*
   * input vars for while node
   * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
   * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
   */
  auto lstm_prim = std::make_shared<ops::LSTM>();
  MS_CHECK_TRUE_RET(lstm_prim != nullptr, nullptr);
  auto lstm_prim_c = lstm_prim->GetPrim();
  MS_CHECK_TRUE_RET(lstm_prim_c != nullptr, nullptr);
  lstm_prim->set_bidirectional(false);
  lstm_prim->set_zoneout_cell(zoneout_cell);
  lstm_prim->set_zoneout_hidden(zoneout_hidden);
  auto value_node = NewValueNode(lstm_prim_c);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);

  auto &vars = while_input_vars_;
  auto i2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
  MS_ASSERT(i2i_weight);
  auto i2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
  MS_ASSERT(i2f_weight);
  auto i2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[11]]);
  MS_ASSERT(i2c_weight);
  auto i2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[12]]);
  MS_ASSERT(i2o_weight);

  auto c2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[13]]);
  MS_ASSERT(c2i_weight);
  auto c2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[14]]);
  MS_ASSERT(c2f_weight);
  auto c2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[15]]);
  MS_ASSERT(c2c_weight);
  auto c2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[16]]);
  MS_ASSERT(c2o_weight);

  auto i_bias = utils::cast<AnfNodePtr>((*equiv)[vars[17]]);
  MS_ASSERT(i_bias);
  auto f_bias = utils::cast<AnfNodePtr>((*equiv)[vars[18]]);
  MS_ASSERT(f_bias);
  auto c_bias = utils::cast<AnfNodePtr>((*equiv)[vars[19]]);
  MS_ASSERT(c_bias);
  auto o_bias = utils::cast<AnfNodePtr>((*equiv)[vars[20]]);
  MS_ASSERT(o_bias);

  auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]);
  MS_ASSERT(input);
  auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]);
  MS_ASSERT(cell);
  auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]);
  MS_ASSERT(hidden);

  std::vector<AnfNodePtr> i_weights{i2i_weight, i2o_weight, i2f_weight, i2c_weight};
  auto i_weight = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(i_weight != nullptr, nullptr);
  auto status = GetConcatedParam(i_weights, i_weight, false);
  if (status != RET_OK) {
    return nullptr;
  }
  i_weight->set_name(base_name + "_weight_i");

  std::vector<AnfNodePtr> c_weights{c2i_weight, c2o_weight, c2f_weight, c2c_weight};
  auto c_weight = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(c_weight != nullptr, nullptr);
  status = GetConcatedParam(c_weights, c_weight, false);
  if (status != RET_OK) {
    return nullptr;
  }
  c_weight->set_name(base_name + "_weight_c");

  std::vector<AnfNodePtr> biases{i_bias, o_bias, f_bias, c_bias};
  auto bias = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(bias != nullptr, nullptr);
  status = GetConcatedParam(biases, bias, true);
  if (status != RET_OK) {
    return nullptr;
  }
  bias->set_name(base_name + "_bias");

  if (!utils::isa<CNodePtr>(input) || !CheckPrimitiveType(input, prim::kPrimTensorListFromTensor)) {
    MS_LOG(DEBUG) << "input is not tensorlistfromtensor op";
    return nullptr;
  }
  auto tensor_list_cnode = utils::cast<CNodePtr>(input);
  auto input_tensor_node = tensor_list_cnode->input(1);

  std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias, hidden, cell};
  auto new_node = func_graph->NewCNode(new_node_inputs);
  MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
  new_node->set_fullname_with_scope(base_name);
  return new_node;
}

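// Creates a TupleGetItem node that extracts output `item_index` of the fused LSTM node
// and gives it a float32 tensor abstract.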
CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                                   const int item_index) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
  auto get_item_value = NewValueNode(MakeValue<int64_t>(item_index));
  if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
    MS_LOG(ERROR) << "NewValueNode is nullptr";
    return nullptr;
  }
  auto tuple_get_item_prim_c = tuple_get_item_prim->GetPrim();
  MS_ASSERT(tuple_get_item_prim_c != nullptr);
  CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim_c, {node, get_item_value});
  MS_CHECK_TRUE_RET(get_item_cnode != nullptr, nullptr);
  auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Create tensor abstract failed";
    return nullptr;
  }
  get_item_cnode->set_abstract(abstract);
  get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
                                          std::to_string(item_index));
  return get_item_cnode;
}

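// Redirects the other TupleGetItem users of the while node (output 4 = cell state,
// output 5 = hidden state) to the fused LSTM node's outputs 2 and 1 respectively, and
// inserts a Squeeze on axis 0 after each redirected getitem for its downstream users.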
STATUS TfliteLstmCellFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode,
                                                 const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) {
  MS_ASSERT(func_graph != nullptr && while_cnode != nullptr);
  MS_ASSERT(lstm_cnode != nullptr && output_get_item != nullptr);
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr";
    return RET_ERROR;
  }
  auto while_node_users = manager->node_users()[while_cnode];
  for (auto &node_user : while_node_users) {
    if (node_user.first == output_get_item) {
      continue;
    }
    if (!utils::isa<CNodePtr>(node_user.first)) {
      return RET_ERROR;
    }
    auto get_item = utils::cast<CNodePtr>(node_user.first);
    if (!CheckPrimitiveType(get_item, prim::kPrimTupleGetItem)) {
      return RET_ERROR;
    }
    auto new_inputs = get_item->inputs();
    if (new_inputs.size() != 3) {
      return RET_ERROR;
    }
    new_inputs[1] = lstm_cnode;
    auto index_vnode = get_item->input(2);
    if (!utils::isa<ValueNode>(index_vnode)) {
      MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node";
      return RET_ERROR;
    }
    auto value_node = utils::cast<ValueNodePtr>(index_vnode);
    if (value_node == nullptr) {
      MS_LOG(ERROR) << "cast to ValueNode failed";
      return RET_ERROR;
    }
    auto origin_index = value_node->value()->type()->number_type() == kNumberTypeInt64
                          ? GetValue<int64_t>(value_node->value())
                          : GetValue<int>(value_node->value());
    int64_t new_index = origin_index == 4 ? 2 : 1;
    auto new_index_vnode = NewValueNode(MakeValue<int64_t>(new_index));
    MS_CHECK_TRUE_RET(new_index_vnode != nullptr, RET_ERROR);
    new_inputs[2] = new_index_vnode;
    get_item->set_inputs(new_inputs);
    get_item->set_fullname_with_scope(lstm_cnode->fullname_with_scope() + "_getitem_" + std::to_string(new_index));
    if (get_item->abstract() == nullptr) {
      MS_LOG(ERROR) << "get_item's abstract is nullptr";
      return RET_ERROR;
    }

    std::vector<int> squeeze_axis{0};
    auto squeeze_node = CreateSqueezeNode(func_graph, get_item, squeeze_axis);
    if (squeeze_node == nullptr) {
      return RET_ERROR;
    }

    auto get_item_users = manager->node_users()[get_item];
    for (auto &get_item_user : get_item_users) {
      manager->SetEdge(get_item_user.first, get_item_user.second, squeeze_node);
    }
  }
  return RET_OK;
}

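// Attaches an AbstractTuple of `output_num` float32 tensor abstracts to the cnode so
// that its multiple outputs can be consumed through TupleGetItem.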
STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) {
  MS_ASSERT(cnode != nullptr);
  AbstractBasePtrList abstract_list;
  for (int i = 0; i < output_num; ++i) {
    auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
    if (abstract == nullptr) {
      MS_LOG(ERROR) << "Create tensor abstract failed";
      return RET_ERROR;
    }
    abstract_list.emplace_back(abstract);
  }
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  if (abstract_tuple == nullptr) {
    MS_LOG(ERROR) << "create abstract_tuple failed";
    return RET_ERROR;
  }
  cnode->set_abstract(abstract_tuple);
  return RET_OK;
}

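// Creates a Squeeze cnode on the given input with the specified axes, reusing the
// input's abstract when it is available.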
CNodePtr TfliteLstmCellFusion::CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node,
                                                 const std::vector<int> &axis) {
  MS_ASSERT(func_graph != nullptr && input_node != nullptr);
  auto squeeze_prim = std::make_shared<ops::Squeeze>();
  MS_CHECK_TRUE_RET(squeeze_prim != nullptr, nullptr);
  auto squeeze_prim_c = squeeze_prim->GetPrim();
  MS_CHECK_TRUE_RET(squeeze_prim_c != nullptr, nullptr);
  std::vector<int64_t> axis_vec;
  std::transform(axis.begin(), axis.end(), std::back_inserter(axis_vec),
                 [](int val) { return static_cast<int64_t>(val); });
  squeeze_prim->set_axis(axis_vec);
  auto squeeze_cnode = func_graph->NewCNode(squeeze_prim_c, {input_node});
  MS_CHECK_TRUE_RET(squeeze_cnode != nullptr, nullptr);
  if (input_node->abstract() != nullptr) {
    squeeze_cnode->set_abstract(input_node->abstract()->Clone());
  }
  squeeze_cnode->set_fullname_with_scope("squeeze_" + input_node->fullname_with_scope());
  return squeeze_cnode;
}

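// Entry point of the fusion: walks up from the matched TensorListStack to the while
// node, verifies the condition and body subgraphs against the LSTM patterns, then
// replaces the loop with a fused LSTM node (plus getitem/squeeze glue) and drops the
// while node's references to its cond and body subgraphs.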
const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }

  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  auto tensor_list_stack_cnode = utils::cast<CNodePtr>(node);
  auto tuple_get_item_node = tensor_list_stack_cnode->input(1);
  if (!utils::isa<CNodePtr>(tuple_get_item_node)) {
    return nullptr;
  }
  auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item_node);
  auto while_node = tuple_get_item_cnode->input(1);
  if (!utils::isa<CNodePtr>(while_node)) {
    return nullptr;
  }
  auto while_cnode = utils::cast<CNodePtr>(while_node);

  if (while_cnode == nullptr || while_cnode->size() != while_inputs_num_) {
    return nullptr;
  }
  if (!CheckReferencedOutputs(func_graph, while_cnode)) {
    return nullptr;
  }
  auto primitive_vars_cond = std::make_shared<PrimitiveVarMap>();
  MS_CHECK_TRUE_RET(primitive_vars_cond != nullptr, nullptr);
  auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond);
  MS_CHECK_TRUE_RET(cond_graph_pattern != nullptr, nullptr);
  auto cond_equiv =
    CheckSubGraph(cond_graph_pattern, primitive_vars_cond, while_cnode->input(1), cond_cnodes_num_, cond_nodes_num_);
  if (cond_equiv == nullptr || cond_equiv->empty()) {
    return nullptr;
  }
  auto primitive_vars_body = std::make_shared<PrimitiveVarMap>();
  MS_CHECK_TRUE_RET(primitive_vars_body != nullptr, nullptr);
  auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body);
  MS_CHECK_TRUE_RET(body_graph_pattern != nullptr, nullptr);
  auto body_equiv =
    CheckSubGraph(body_graph_pattern, primitive_vars_body, while_cnode->input(2), body_cnodes_num_, body_nodes_num_);
  if (body_equiv == nullptr || body_equiv->empty()) {
    return nullptr;
  }
  float zoneout_cell = 0.0f;
  float zoneout_hidden = 0.0f;
  if (!CheckBodyGraph(body_equiv, &zoneout_cell, &zoneout_hidden)) {
    return nullptr;
  }
  const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope();
  auto lstm_node = CreateLSTMNode(func_graph, equiv, body_equiv, lstm_name, zoneout_cell, zoneout_hidden);
  if (lstm_node == nullptr) {
    return nullptr;
  }
  auto status = SetAbstractTuple(lstm_node, kLSTMOutputNum);
  if (status != RET_OK) {
    return nullptr;
  }

  auto get_item_node = CreateOutputGetItem(func_graph, lstm_node, 0);
  if (get_item_node == nullptr) {
    MS_LOG(DEBUG) << "create lstm output get_item node failed";
    return nullptr;
  }

  status = AdjustOtherGetItems(func_graph, while_cnode, lstm_node, tuple_get_item_cnode);
  if (status != RET_OK) {
    return nullptr;
  }

  std::vector<int> squeeze_axis{1};  // our LSTM output 0 has an extra axis that tflite does not have; squeeze it away
  auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis);
  MS_CHECK_TRUE_MSG(squeeze_node != nullptr, nullptr, "create a squeeze node failed.");

  auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1);
  MS_CHECK_TRUE_RET(cond_cnode_index_pair != nullptr, nullptr);
  func_graph->DropFuncGraphCNodeIndex(cond_cnode_index_pair);
  auto body_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 2);
  MS_CHECK_TRUE_RET(body_cnode_index_pair != nullptr, nullptr);
  func_graph->DropFuncGraphCNodeIndex(body_cnode_index_pair);
  MS_LOG(INFO) << "lstm node:" << lstm_node->fullname_with_scope() << " fusion success";
  return squeeze_node;
}
}  // namespace opt
}  // namespace mindspore