• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
19 #include <memory>
20 #include "mindspore/core/ops/structure_ops.h"
21 #include "mindspore/core/ops/sequence_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "ops/lstm.h"
24 #include "src/common/utils.h"
25 #include "tools/common/tensor_util.h"
26 #include "include/common/utils/utils.h"
27 #include "tools/optimizer/common/gllo_utils.h"
28 #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
29 #include "tools/optimizer/common/helper.h"
30 #include "nnacl/op_base.h"
31 
32 namespace mindspore {
33 namespace opt {
34 namespace {
35 constexpr int kNumInPlaceHolder = 10;
36 constexpr int kNumGetItem = 4;
37 constexpr size_t kLstmInputsLength = 13;
38 constexpr size_t kLstmInputsVarNum = 11;
39 constexpr size_t kCondNodesNum = 12;
40 constexpr size_t kCondCNodesNum = 4;
41 constexpr size_t kBodyNodesNum = 82;
42 constexpr size_t kBodyCNodesNum = 30;
43 constexpr auto kUnidirectionalGateNum = 4;
44 constexpr auto kBidirectionalGateNum = 8;
45 const auto &p1 = std::placeholders::_1;
IsParameterNode(const BaseRef & n)46 bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
47 
GenerateBodyGraphHiddenPattern(const VarPtr & forget_bias_input,const std::vector<CondVarPtr> & placeholders)48 std::vector<VectorRef> GenerateBodyGraphHiddenPattern(const VarPtr &forget_bias_input,
49                                                       const std::vector<CondVarPtr> &placeholders) {
50   MS_CHECK_TRUE_RET(placeholders.size() >= kNumInPlaceHolder, {});
51   auto is_var_getitem = std::make_shared<Var>("GetItem");
52   MS_CHECK_TRUE_RET(is_var_getitem != nullptr, {});
53   auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
54   MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
55   VectorRef get_item = VectorRef({is_var_getitem, placeholders[7], placeholders[2], is_param3});
56   auto is_var1 = std::make_shared<Var>();
57   MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
58   VectorRef concat_input_h = VectorRef({is_var1, get_item, placeholders[5]});
59 
60   auto is_var2 = std::make_shared<Var>();
61   MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
62   VectorRef matmul = VectorRef({is_var2, concat_input_h, placeholders[8]});
63   auto is_var3 = std::make_shared<Var>();
64   MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
65   VectorRef bias = VectorRef({is_var3, matmul, placeholders[9]});
66   auto is_var4 = std::make_shared<Var>();
67   MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
68   VectorRef split = VectorRef({is_var4, bias});
69 
70   std::vector<VectorRef> get_items;
71   for (int i = 0; i < kNumGetItem; ++i) {
72     auto is_var_loop1 = std::make_shared<Var>();
73     MS_CHECK_TRUE_RET(is_var_loop1 != nullptr, {});
74     auto is_var_loop2 = std::make_shared<Var>();
75     MS_CHECK_TRUE_RET(is_var_loop2 != nullptr, {});
76     VectorRef get_item_loop = VectorRef({is_var_loop1, split, is_var_loop2});
77     get_items.push_back(get_item_loop);
78   }
79 
80   auto is_var_sigmoid1 = std::make_shared<Var>("Sigmoid");
81   MS_CHECK_TRUE_RET(is_var_sigmoid1 != nullptr, {});
82   VectorRef input_gate = VectorRef({is_var_sigmoid1, get_items[0]});
83   auto is_var_tanh1 = std::make_shared<Var>("Tanh");
84   MS_CHECK_TRUE_RET(is_var_tanh1 != nullptr, {});
85   VectorRef input_to_cell = VectorRef({is_var_tanh1, get_items[1]});
86   auto is_var_add1 = std::make_shared<Var>("Add");
87   MS_CHECK_TRUE_RET(is_var_add1 != nullptr, {});
88   VectorRef forget_bias = VectorRef({is_var_add1, get_items[2], forget_bias_input});
89   auto is_var_sigmoid2 = std::make_shared<Var>("Sigmoid");
90   MS_CHECK_TRUE_RET(is_var_sigmoid2 != nullptr, {});
91   VectorRef forget_gate = VectorRef({is_var_sigmoid2, forget_bias});
92   auto is_var_sigmoid3 = std::make_shared<Var>("Sigmoid");
93   MS_CHECK_TRUE_RET(is_var_sigmoid3 != nullptr, {});
94   VectorRef output_gate = VectorRef({is_var_sigmoid3, get_items[3]});
95 
96   auto is_var5 = std::make_shared<Var>();
97   MS_CHECK_TRUE_RET(is_var5 != nullptr, {});
98   VectorRef forgetted_cell = VectorRef({is_var5, forget_gate, placeholders[4]});
99   auto is_var6 = std::make_shared<Var>();
100   MS_CHECK_TRUE_RET(is_var6 != nullptr, {});
101   VectorRef inputted_cell = VectorRef({is_var6, input_gate, input_to_cell});
102   auto is_var_add2 = std::make_shared<Var>("Add");
103   MS_CHECK_TRUE_RET(is_var_add2 != nullptr, {});
104   VectorRef input_forget_cell = VectorRef({is_var_add2, forgetted_cell, inputted_cell});
105   auto is_var_tanh2 = std::make_shared<Var>("Tanh");
106   MS_CHECK_TRUE_RET(is_var_tanh2 != nullptr, {});
107   VectorRef to_new_hidden = VectorRef({is_var_tanh2, input_forget_cell});
108   auto is_var_mul = std::make_shared<Var>("Mul");
109   MS_CHECK_TRUE_RET(is_var_mul != nullptr, {});
110   VectorRef new_hidden = VectorRef({is_var_mul, output_gate, to_new_hidden});
111   return {input_forget_cell, new_hidden};
112 }
113 }  // namespace
114 
TfLstmCellFusion(const std::string & name,bool multigraph)115 TfLstmCellFusion::TfLstmCellFusion(const std::string &name, bool multigraph)
116     : TfliteLstmCellFusion(name, multigraph, kLstmInputsLength, kLstmInputsVarNum, kCondNodesNum, kCondCNodesNum,
117                            kBodyNodesNum, kBodyCNodesNum) {
118   /*
119    * vars for lstm cell input
120    * 0:cond 1:body 2:index 3:limit1 4:output 5:cell 6:hidden 7:limit2 8:input 9:kernel 10:bias
121    */
122 }
123 
GetBodyGraphPattern(const PrimitiveVarMapPtr & primitive_vars) const124 AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
125   std::vector<CondVarPtr> placeholders;
126   for (int i = 0; i < kNumInPlaceHolder; ++i) {
127     auto is_param_holder = std::make_shared<CondVar>(IsParameterNode);
128     MS_CHECK_TRUE_RET(is_param_holder != nullptr, nullptr);
129     placeholders.emplace_back(is_param_holder);
130   }
131   auto is_var1 = std::make_shared<Var>();
132   MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
133   auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
134   MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
135   VectorRef add2 = VectorRef({is_var1, placeholders[2], is_param1});
136   auto is_var2 = std::make_shared<Var>();
137   MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
138   auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
139   MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
140   VectorRef add3 = VectorRef({is_var2, placeholders[0], is_param2});
141 
142   forget_bias_ = std::make_shared<Var>();
143   MS_CHECK_TRUE_RET(forget_bias_ != nullptr, nullptr);
144   auto hidden_cells = GenerateBodyGraphHiddenPattern(forget_bias_, placeholders);
145   MS_CHECK_TRUE_RET(hidden_cells.size() == kInputSizeTwo, {});
146 
147   auto is_var_mul1 = std::make_shared<Var>("Mul");
148   MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
149   auto input_forget_cell = hidden_cells[0];
150   MS_CHECK_TRUE_RET(!input_forget_cell.empty(), {});
151   VectorRef new_to_cell = VectorRef({is_var_mul1, cell_zoneout_new_, input_forget_cell});
152   auto is_var_mul2 = std::make_shared<Var>("Mul");
153   MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
154   VectorRef old_to_cell = VectorRef({is_var_mul2, cell_zoneout_old_, placeholders[4]});
155   auto is_var_add1 = std::make_shared<Var>("Add");
156   MS_CHECK_TRUE_RET(is_var_add1 != nullptr, {});
157   VectorRef output_cell = VectorRef({is_var_add1, new_to_cell, old_to_cell});
158 
159   auto new_hidden = hidden_cells[1];
160   MS_CHECK_TRUE_RET(!new_hidden.empty(), {});
161   auto is_var_mul3 = std::make_shared<Var>("Mul");
162   MS_CHECK_TRUE_RET(is_var_mul3 != nullptr, {});
163   VectorRef new_to_hidden = VectorRef({is_var_mul3, hidden_zoneout_new_, new_hidden});
164   auto is_var_mul4 = std::make_shared<Var>("Mul");
165   MS_CHECK_TRUE_RET(is_var_mul4 != nullptr, {});
166   VectorRef old_to_hidden = VectorRef({is_var_mul4, hidden_zoneout_old_, placeholders[5]});
167   auto is_var_add2 = std::make_shared<Var>("Add");
168   MS_CHECK_TRUE_RET(is_var_add2 != nullptr, {});
169   VectorRef output_hidden = VectorRef({is_var_add2, new_to_hidden, old_to_hidden});
170 
171   auto is_var3 = std::make_shared<Var>();
172   MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
173   VectorRef set_item = VectorRef({is_var3, placeholders[3], placeholders[2], new_hidden});
174 
175   auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
176   MS_CHECK_TRUE_RET(is_make_tuple != nullptr, nullptr);
177   std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden};
178   outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
179   VectorRef make_tuple_node = VectorRef(outputs);
180   auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
181   MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
182   VectorRef return_node = VectorRef({is_return, make_tuple_node});
183 
184   VarPtr is_fg = std::make_shared<Var>("RootG");
185   MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
186   auto pattern = Helper::SexpToNode(return_node, is_fg, primitive_vars.get(), true);
187   return pattern;
188 }
189 
SetWeightAbstractAndDefault(const ParameterPtr & weight,const std::vector<int64_t> & shape,const float * const data_ptr,const int hidden_size)190 STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector<int64_t> &shape,
191                                                      const float *const data_ptr, const int hidden_size) {
192   MS_ASSERT(weight != nullptr);
193   MS_ASSERT(data_ptr != nullptr);
194   if (shape.size() != kInputSizeThree) {
195     MS_LOG(ERROR) << "lstm weight shape must have 3 dims";
196     return RET_ERROR;
197   }
198   const auto param_num = shape[0] * shape[1] * shape[kInputIndexTwo];
199   auto tensor_data = new (std::nothrow) float[static_cast<size_t>(param_num) * sizeof(float)];
200   std::vector<int> data_diff{0, 3, 2, 1};
201   if (tensor_data == nullptr) {
202     MS_LOG(DEBUG) << "new data failed";
203     return RET_ERROR;
204   }
205   for (int i = 0; i < 4; ++i) {
206     for (int j = 0; j < hidden_size; ++j) {
207       for (int t = 0; t < shape[2]; ++t) {
208         tensor_data[(i * hidden_size + j) * shape[2] + t] = data_ptr[t * shape[1] + data_diff[i] * hidden_size + j];
209       }
210     }
211   }
212   auto tensor_info =
213     lite::CreateTensorInfo(tensor_data, static_cast<size_t>(param_num) * sizeof(float), shape, kNumberTypeFloat32);
214   delete[] tensor_data;
215   if (tensor_info == nullptr) {
216     MS_LOG(ERROR) << "create tensor info failed.";
217     return RET_ERROR;
218   }
219   auto status = lite::InitParameterFromTensorInfo(weight, tensor_info);
220   if (status != RET_OK) {
221     MS_LOG(ERROR) << "init parameter from tensor info failed";
222     return RET_ERROR;
223   }
224   return RET_OK;
225 }
226 
SplitWeights(const AnfNodePtr & weight,const ParameterPtr & weight_i,const ParameterPtr & weight_c,int hidden_size)227 STATUS TfLstmCellFusion::SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i,
228                                       const ParameterPtr &weight_c, int hidden_size) {
229   // split input_size and hidden_size at dim 0
230   // transform i,c,f,o to i,o,f,c at dim 1
231   MS_ASSERT(weight != nullptr);
232   MS_ASSERT(weight_i != nullptr);
233   MS_ASSERT(weight_c != nullptr);
234   if (!utils::isa<ParameterPtr>(weight)) {
235     return RET_ERROR;
236   }
237   auto weight_param = utils::cast<ParameterPtr>(weight);
238   if (!weight_param->has_default() || weight_param->default_param() == nullptr) {
239     MS_LOG(DEBUG) << "weight not have default value";
240     return RET_ERROR;
241   }
242   if (!utils::isa<tensor::TensorPtr>(weight_param->default_param())) {
243     MS_LOG(DEBUG) << "default value is not tensor::Tensor";
244     return RET_FAILED;
245   }
246   auto origin_tensor = std::dynamic_pointer_cast<tensor::Tensor>(weight_param->default_param());
247   if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) {
248     MS_LOG(DEBUG) << "origin_tensor is not float32 type";
249     return RET_ERROR;
250   }
251   auto data_ptr = reinterpret_cast<float *>(origin_tensor->data_c());
252   auto data_shape = origin_tensor->shape();
253   if (data_shape.size() != kInputSizeTwo) {
254     MS_LOG(ERROR) << "weight data shape invalid";
255     return RET_ERROR;
256   }
257   if (data_shape[0] <= hidden_size) {
258     MS_LOG(ERROR) << "weight data shape[0] invalid";
259     return RET_ERROR;
260   }
261   if (hidden_size * 4 != data_shape[1]) {
262     MS_LOG(ERROR) << "weight data shape[1] invalid";
263     return RET_ERROR;
264   }
265   const auto input_size = data_shape[0] - hidden_size;
266 
267   std::vector<int64_t> shape_i{1, kUnidirectionalGateNum * hidden_size, input_size};
268   if (SetWeightAbstractAndDefault(weight_i, shape_i, data_ptr, hidden_size) != RET_OK) {
269     MS_LOG(ERROR) << "get weight_i failed";
270     return RET_ERROR;
271   }
272 
273   std::vector<int64_t> shape_c{1, kUnidirectionalGateNum * hidden_size, hidden_size};
274   if (SetWeightAbstractAndDefault(weight_c, shape_c, data_ptr + input_size * data_shape[1], hidden_size) != RET_OK) {
275     MS_LOG(ERROR) << "get weight_i failed";
276     return RET_ERROR;
277   }
278   return RET_OK;
279 }
280 
PopulateBiasNode(const EquivPtr & body_equiv,const ParameterPtr & new_bias,const AnfNodePtr & old_bias,const int hidden_size) const281 STATUS TfLstmCellFusion::PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias,
282                                           const AnfNodePtr &old_bias, const int hidden_size) const {
283   MS_ASSERT(body_equiv != nullptr);
284   MS_ASSERT(new_bias != nullptr);
285   MS_ASSERT(old_bias != nullptr);
286   if (!utils::isa<ParameterPtr>(old_bias)) {
287     MS_LOG(DEBUG) << "old_bias is not parameter";
288     return RET_ERROR;
289   }
290   auto old_bias_param = utils::cast<ParameterPtr>(old_bias);
291   if (!old_bias_param->has_default() || old_bias_param->default_param() == nullptr) {
292     MS_LOG(DEBUG) << "bias not have default value";
293     return RET_ERROR;
294   }
295   if (!utils::isa<tensor::TensorPtr>(old_bias_param->default_param())) {
296     MS_LOG(DEBUG) << "default value is not tensor::Tensor";
297     return RET_FAILED;
298   }
299   auto origin_tensor = std::dynamic_pointer_cast<tensor::Tensor>(old_bias_param->default_param());
300   MS_CHECK_TRUE_RET(origin_tensor != nullptr, RET_ERROR);
301   if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) {
302     MS_LOG(DEBUG) << "origin_tensor is not float32 type";
303     return RET_ERROR;
304   }
305   auto data_ptr = reinterpret_cast<float *>(origin_tensor->data_c());
306   MS_CHECK_TRUE_RET(data_ptr != nullptr, RET_ERROR);
307   auto data_shape = origin_tensor->shape();
308   MS_CHECK_GE(hidden_size, 0, RET_ERROR);
309   if (data_shape.size() != 1 || data_shape[0] != 4 * hidden_size) {
310     MS_LOG(DEBUG) << "bias data shape illegal";
311     return RET_ERROR;
312   }
313 
314   std::vector<int64_t> shape{1, kBidirectionalGateNum * hidden_size};
315   auto tensor_data = std::make_unique<float[]>(static_cast<size_t>(hidden_size) * 8);
316   MS_CHECK_TRUE_RET(tensor_data != nullptr, lite::RET_ERROR);
317   auto forget_bias_node = utils::cast<AnfNodePtr>((*body_equiv)[forget_bias_]);
318   if (forget_bias_node == nullptr) {
319     MS_LOG(ERROR) << "forget bias node is nullptr";
320     return RET_ERROR;
321   }
322   float forget_bias_value = 0.0f;
323   if (GetFloatScalarFromTensorInfo(forget_bias_node, &forget_bias_value) != RET_OK) {
324     return RET_ERROR;
325   }
326 
327   std::vector<int> data_diff{0, 3, 2, 1};
328   for (int i = 0; i < 8; ++i) {
329     for (int j = 0; j < hidden_size; ++j) {
330       if (i < 4) {
331         tensor_data[i * hidden_size + j] = data_ptr[data_diff[i] * hidden_size + j];
332         if (i == 2) {  // forget bias
333           tensor_data[i * hidden_size + j] += forget_bias_value;
334         }
335       } else {
336         tensor_data[i * hidden_size + j] = 0.0f;
337       }
338     }
339   }
340 
341   auto tensor_info =
342     lite::CreateTensorInfo(tensor_data.get(), static_cast<size_t>(hidden_size) * kBidirectionalGateNum * sizeof(float),
343                            shape, kNumberTypeFloat32);
344   if (tensor_info == nullptr) {
345     MS_LOG(ERROR) << "create tensor info failed.";
346     return RET_ERROR;
347   }
348 
349   auto status = lite::InitParameterFromTensorInfo(new_bias, tensor_info);
350   if (status != RET_OK) {
351     MS_LOG(ERROR) << "init parameter from tensor info failed";
352     return RET_ERROR;
353   }
354 
355   return RET_OK;
356 }
357 
CreateLSTMNode(const FuncGraphPtr & func_graph,const EquivPtr & equiv,const EquivPtr & body_equiv,const std::string & base_name,const float zoneout_cell,const float zoneout_hidden) const358 CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
359                                           const EquivPtr &body_equiv, const std::string &base_name,
360                                           const float zoneout_cell, const float zoneout_hidden) const {
361   MS_ASSERT(func_graph != nullptr);
362   MS_ASSERT(equiv != nullptr);
363   auto lstm_prim = std::make_shared<ops::LSTM>();
364   MS_CHECK_TRUE_RET(lstm_prim != nullptr, nullptr);
365   auto lstm_prim_c = lstm_prim->GetPrim();
366   MS_CHECK_TRUE_RET(lstm_prim_c != nullptr, nullptr);
367   lstm_prim->set_bidirectional(false);
368   lstm_prim->set_zoneout_cell(zoneout_cell);
369   lstm_prim->set_zoneout_hidden(zoneout_hidden);
370   auto value_node = NewValueNode(lstm_prim_c);
371   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
372 
373   auto &vars = while_input_vars_;
374   auto weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
375   MS_ASSERT(weight);
376   auto bias = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
377   MS_ASSERT(bias);
378   auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]);
379   MS_ASSERT(input);
380   auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]);
381   MS_ASSERT(cell);
382   auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]);
383   MS_ASSERT(hidden);
384 
385   if (!utils::isa<ParameterPtr>(hidden)) {
386     MS_LOG(DEBUG) << "hidden is not parameter";
387     return nullptr;
388   }
389   auto hidden_param = utils::cast<ParameterPtr>(hidden);
390   if (!utils::isa<abstract::AbstractTensorPtr>(hidden_param->abstract())) {
391     MS_LOG(DEBUG) << "hidden abstract is not AbstractTensor";
392     return nullptr;
393   }
394   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract());
395   MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "Cast to abstract tensor failed!");
396   auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
397   if (hidden_shape.empty()) {
398     MS_LOG(DEBUG) << "can't get hidden shape";
399     return nullptr;
400   }
401 
402   auto i_weight = func_graph->add_parameter();
403   MS_CHECK_TRUE_RET(i_weight != nullptr, nullptr);
404   i_weight->set_name(base_name + "_weight_i");
405   if (weight->abstract() != nullptr) {
406     i_weight->set_abstract(weight->abstract()->Clone());
407   }
408 
409   auto c_weight = func_graph->add_parameter();
410   MS_CHECK_TRUE_RET(c_weight != nullptr, nullptr);
411   c_weight->set_name(base_name + "_weight_c");
412   if (weight->abstract() != nullptr) {
413     c_weight->set_abstract(weight->abstract()->Clone());
414   }
415 
416   if (SplitWeights(weight, i_weight, c_weight, static_cast<int>(hidden_shape.back())) != RET_OK) {
417     MS_LOG(DEBUG) << "split weight to i_weight and c_weight failed";
418     return nullptr;
419   }
420 
421   auto bias_node = func_graph->add_parameter();
422   MS_CHECK_TRUE_RET(bias_node != nullptr, nullptr);
423   MS_CHECK_TRUE_RET(bias_node != nullptr, nullptr);
424   bias_node->set_name(base_name + "_bias");
425   if (bias->abstract() != nullptr) {
426     bias_node->set_abstract(bias->abstract()->Clone());
427   }
428 
429   if (PopulateBiasNode(body_equiv, bias_node, bias, static_cast<int>(hidden_shape.back())) != RET_OK) {
430     MS_LOG(DEBUG) << "reorder bias failed";
431     return nullptr;
432   }
433 
434   if (!utils::isa<CNodePtr>(input) || !CheckPrimitiveType(input, prim::kPrimTensorListFromTensor)) {
435     MS_LOG(DEBUG) << "input is not tensorlistfromtensor op";
436     return nullptr;
437   }
438   auto tensor_list_cnode = utils::cast<CNodePtr>(input);
439   auto input_tensor_node = tensor_list_cnode->input(1);
440 
441   std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias_node, hidden,
442                                              cell};
443   auto new_node = func_graph->NewCNode(new_node_inputs);
444   MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
445   new_node->set_fullname_with_scope(base_name);
446   return new_node;
447 }
448 }  // namespace opt
449 }  // namespace mindspore
450