1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
19 #include <memory>
20 #include "mindspore/core/ops/structure_ops.h"
21 #include "mindspore/core/ops/sequence_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "ops/lstm.h"
24 #include "src/common/utils.h"
25 #include "tools/common/tensor_util.h"
26 #include "include/common/utils/utils.h"
27 #include "tools/optimizer/common/gllo_utils.h"
28 #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
29 #include "tools/optimizer/common/helper.h"
30 #include "nnacl/op_base.h"
31
32 namespace mindspore {
33 namespace opt {
34 namespace {
// Number of loop-carried placeholder parameters in the while-body pattern.
constexpr int kNumInPlaceHolder = 10;
// Number of TupleGetItem outputs taken from the 4-way gate split.
constexpr int kNumGetItem = 4;
// Expected input count / var count of the matched while node, and the
// node/CNode counts of the cond and body subgraphs (used by the base class
// to reject non-matching while loops cheaply).
constexpr size_t kLstmInputsLength = 13;
constexpr size_t kLstmInputsVarNum = 11;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 82;
constexpr size_t kBodyCNodesNum = 30;
// Gate counts: 4 gates (i,o,f,c) for one direction, 8 for bidirectional.
constexpr auto kUnidirectionalGateNum = 4;
constexpr auto kBidirectionalGateNum = 8;
const auto &p1 = std::placeholders::_1;
// Predicate for CondVar: matches any Parameter node.
bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
47
// Builds the sub-pattern of the while-body graph that computes the LSTM cell
// state and hidden state for one step:
//   concat(x_t, h_{t-1}) -> matmul -> bias_add -> split into 4 gates
//   c_t = sigmoid(i) * tanh(g) + sigmoid(f + forget_bias) * c_{t-1}
//   h_t = sigmoid(o) * tanh(c_t)
// Returns {input_forget_cell (c_t), new_hidden (h_t)} on success, or an
// empty vector on failure. The VectorRef nesting order below mirrors the
// exact producer/consumer wiring of the TF-exported graph, so statement
// order here is significant.
std::vector<VectorRef> GenerateBodyGraphHiddenPattern(const VarPtr &forget_bias_input,
                                                      const std::vector<CondVarPtr> &placeholders) {
  MS_CHECK_TRUE_RET(placeholders.size() >= kNumInPlaceHolder, {});
  // x_t = GetItem(input_tensor_list[7], index[2])
  auto is_var_getitem = std::make_shared<Var>("GetItem");
  MS_CHECK_TRUE_RET(is_var_getitem != nullptr, {});
  auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
  VectorRef get_item = VectorRef({is_var_getitem, placeholders[7], placeholders[2], is_param3});
  // concat(x_t, h_{t-1}) where placeholders[5] carries the previous hidden.
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
  VectorRef concat_input_h = VectorRef({is_var1, get_item, placeholders[5]});

  // matmul with the fused kernel (placeholders[8]) then bias add (placeholders[9]).
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
  VectorRef matmul = VectorRef({is_var2, concat_input_h, placeholders[8]});
  auto is_var3 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
  VectorRef bias = VectorRef({is_var3, matmul, placeholders[9]});
  // Split the pre-activation into the 4 gate slices.
  auto is_var4 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
  VectorRef split = VectorRef({is_var4, bias});

  // One GetItem per gate; TF gate order here is i, g(cell-candidate), f, o.
  std::vector<VectorRef> get_items;
  for (int i = 0; i < kNumGetItem; ++i) {
    auto is_var_loop1 = std::make_shared<Var>();
    MS_CHECK_TRUE_RET(is_var_loop1 != nullptr, {});
    auto is_var_loop2 = std::make_shared<Var>();
    MS_CHECK_TRUE_RET(is_var_loop2 != nullptr, {});
    VectorRef get_item_loop = VectorRef({is_var_loop1, split, is_var_loop2});
    get_items.push_back(get_item_loop);
  }

  // Gate activations: sigmoid(i), tanh(g), sigmoid(f + forget_bias), sigmoid(o).
  auto is_var_sigmoid1 = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid1 != nullptr, {});
  VectorRef input_gate = VectorRef({is_var_sigmoid1, get_items[0]});
  auto is_var_tanh1 = std::make_shared<Var>("Tanh");
  MS_CHECK_TRUE_RET(is_var_tanh1 != nullptr, {});
  VectorRef input_to_cell = VectorRef({is_var_tanh1, get_items[1]});
  auto is_var_add1 = std::make_shared<Var>("Add");
  MS_CHECK_TRUE_RET(is_var_add1 != nullptr, {});
  VectorRef forget_bias = VectorRef({is_var_add1, get_items[2], forget_bias_input});
  auto is_var_sigmoid2 = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid2 != nullptr, {});
  VectorRef forget_gate = VectorRef({is_var_sigmoid2, forget_bias});
  auto is_var_sigmoid3 = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid3 != nullptr, {});
  VectorRef output_gate = VectorRef({is_var_sigmoid3, get_items[3]});

  // c_t = f * c_{t-1} + i * g  (placeholders[4] carries the previous cell).
  auto is_var5 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var5 != nullptr, {});
  VectorRef forgetted_cell = VectorRef({is_var5, forget_gate, placeholders[4]});
  auto is_var6 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var6 != nullptr, {});
  VectorRef inputted_cell = VectorRef({is_var6, input_gate, input_to_cell});
  auto is_var_add2 = std::make_shared<Var>("Add");
  MS_CHECK_TRUE_RET(is_var_add2 != nullptr, {});
  VectorRef input_forget_cell = VectorRef({is_var_add2, forgetted_cell, inputted_cell});
  // h_t = o * tanh(c_t)
  auto is_var_tanh2 = std::make_shared<Var>("Tanh");
  MS_CHECK_TRUE_RET(is_var_tanh2 != nullptr, {});
  VectorRef to_new_hidden = VectorRef({is_var_tanh2, input_forget_cell});
  auto is_var_mul = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul != nullptr, {});
  VectorRef new_hidden = VectorRef({is_var_mul, output_gate, to_new_hidden});
  return {input_forget_cell, new_hidden};
}
113 } // namespace
114
// Constructs the TF LSTM-cell fusion pass. All matching machinery lives in
// the TFLite base class; this ctor only supplies the TF-specific while-loop
// sizing constants (input count, var count, cond/body node counts).
TfLstmCellFusion::TfLstmCellFusion(const std::string &name, bool multigraph)
    : TfliteLstmCellFusion(name, multigraph, kLstmInputsLength, kLstmInputsVarNum, kCondNodesNum, kCondCNodesNum,
                           kBodyNodesNum, kBodyCNodesNum) {
  /*
   * vars for lstm cell input
   * 0:cond 1:body 2:index 3:limit1 4:output 5:cell 6:hidden 7:limit2 8:input 9:kernel 10:bias
   */
}
123
GetBodyGraphPattern(const PrimitiveVarMapPtr & primitive_vars) const124 AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
125 std::vector<CondVarPtr> placeholders;
126 for (int i = 0; i < kNumInPlaceHolder; ++i) {
127 auto is_param_holder = std::make_shared<CondVar>(IsParameterNode);
128 MS_CHECK_TRUE_RET(is_param_holder != nullptr, nullptr);
129 placeholders.emplace_back(is_param_holder);
130 }
131 auto is_var1 = std::make_shared<Var>();
132 MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
133 auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
134 MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
135 VectorRef add2 = VectorRef({is_var1, placeholders[2], is_param1});
136 auto is_var2 = std::make_shared<Var>();
137 MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
138 auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
139 MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
140 VectorRef add3 = VectorRef({is_var2, placeholders[0], is_param2});
141
142 forget_bias_ = std::make_shared<Var>();
143 MS_CHECK_TRUE_RET(forget_bias_ != nullptr, nullptr);
144 auto hidden_cells = GenerateBodyGraphHiddenPattern(forget_bias_, placeholders);
145 MS_CHECK_TRUE_RET(hidden_cells.size() == kInputSizeTwo, {});
146
147 auto is_var_mul1 = std::make_shared<Var>("Mul");
148 MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
149 auto input_forget_cell = hidden_cells[0];
150 MS_CHECK_TRUE_RET(!input_forget_cell.empty(), {});
151 VectorRef new_to_cell = VectorRef({is_var_mul1, cell_zoneout_new_, input_forget_cell});
152 auto is_var_mul2 = std::make_shared<Var>("Mul");
153 MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
154 VectorRef old_to_cell = VectorRef({is_var_mul2, cell_zoneout_old_, placeholders[4]});
155 auto is_var_add1 = std::make_shared<Var>("Add");
156 MS_CHECK_TRUE_RET(is_var_add1 != nullptr, {});
157 VectorRef output_cell = VectorRef({is_var_add1, new_to_cell, old_to_cell});
158
159 auto new_hidden = hidden_cells[1];
160 MS_CHECK_TRUE_RET(!new_hidden.empty(), {});
161 auto is_var_mul3 = std::make_shared<Var>("Mul");
162 MS_CHECK_TRUE_RET(is_var_mul3 != nullptr, {});
163 VectorRef new_to_hidden = VectorRef({is_var_mul3, hidden_zoneout_new_, new_hidden});
164 auto is_var_mul4 = std::make_shared<Var>("Mul");
165 MS_CHECK_TRUE_RET(is_var_mul4 != nullptr, {});
166 VectorRef old_to_hidden = VectorRef({is_var_mul4, hidden_zoneout_old_, placeholders[5]});
167 auto is_var_add2 = std::make_shared<Var>("Add");
168 MS_CHECK_TRUE_RET(is_var_add2 != nullptr, {});
169 VectorRef output_hidden = VectorRef({is_var_add2, new_to_hidden, old_to_hidden});
170
171 auto is_var3 = std::make_shared<Var>();
172 MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
173 VectorRef set_item = VectorRef({is_var3, placeholders[3], placeholders[2], new_hidden});
174
175 auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
176 MS_CHECK_TRUE_RET(is_make_tuple != nullptr, nullptr);
177 std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, output_cell, output_hidden};
178 outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
179 VectorRef make_tuple_node = VectorRef(outputs);
180 auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
181 MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
182 VectorRef return_node = VectorRef({is_return, make_tuple_node});
183
184 VarPtr is_fg = std::make_shared<Var>("RootG");
185 MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
186 auto pattern = Helper::SexpToNode(return_node, is_fg, primitive_vars.get(), true);
187 return pattern;
188 }
189
SetWeightAbstractAndDefault(const ParameterPtr & weight,const std::vector<int64_t> & shape,const float * const data_ptr,const int hidden_size)190 STATUS TfLstmCellFusion::SetWeightAbstractAndDefault(const ParameterPtr &weight, const std::vector<int64_t> &shape,
191 const float *const data_ptr, const int hidden_size) {
192 MS_ASSERT(weight != nullptr);
193 MS_ASSERT(data_ptr != nullptr);
194 if (shape.size() != kInputSizeThree) {
195 MS_LOG(ERROR) << "lstm weight shape must have 3 dims";
196 return RET_ERROR;
197 }
198 const auto param_num = shape[0] * shape[1] * shape[kInputIndexTwo];
199 auto tensor_data = new (std::nothrow) float[static_cast<size_t>(param_num) * sizeof(float)];
200 std::vector<int> data_diff{0, 3, 2, 1};
201 if (tensor_data == nullptr) {
202 MS_LOG(DEBUG) << "new data failed";
203 return RET_ERROR;
204 }
205 for (int i = 0; i < 4; ++i) {
206 for (int j = 0; j < hidden_size; ++j) {
207 for (int t = 0; t < shape[2]; ++t) {
208 tensor_data[(i * hidden_size + j) * shape[2] + t] = data_ptr[t * shape[1] + data_diff[i] * hidden_size + j];
209 }
210 }
211 }
212 auto tensor_info =
213 lite::CreateTensorInfo(tensor_data, static_cast<size_t>(param_num) * sizeof(float), shape, kNumberTypeFloat32);
214 delete[] tensor_data;
215 if (tensor_info == nullptr) {
216 MS_LOG(ERROR) << "create tensor info failed.";
217 return RET_ERROR;
218 }
219 auto status = lite::InitParameterFromTensorInfo(weight, tensor_info);
220 if (status != RET_OK) {
221 MS_LOG(ERROR) << "init parameter from tensor info failed";
222 return RET_ERROR;
223 }
224 return RET_OK;
225 }
226
SplitWeights(const AnfNodePtr & weight,const ParameterPtr & weight_i,const ParameterPtr & weight_c,int hidden_size)227 STATUS TfLstmCellFusion::SplitWeights(const AnfNodePtr &weight, const ParameterPtr &weight_i,
228 const ParameterPtr &weight_c, int hidden_size) {
229 // split input_size and hidden_size at dim 0
230 // transform i,c,f,o to i,o,f,c at dim 1
231 MS_ASSERT(weight != nullptr);
232 MS_ASSERT(weight_i != nullptr);
233 MS_ASSERT(weight_c != nullptr);
234 if (!utils::isa<ParameterPtr>(weight)) {
235 return RET_ERROR;
236 }
237 auto weight_param = utils::cast<ParameterPtr>(weight);
238 if (!weight_param->has_default() || weight_param->default_param() == nullptr) {
239 MS_LOG(DEBUG) << "weight not have default value";
240 return RET_ERROR;
241 }
242 if (!utils::isa<tensor::TensorPtr>(weight_param->default_param())) {
243 MS_LOG(DEBUG) << "default value is not tensor::Tensor";
244 return RET_FAILED;
245 }
246 auto origin_tensor = std::dynamic_pointer_cast<tensor::Tensor>(weight_param->default_param());
247 if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) {
248 MS_LOG(DEBUG) << "origin_tensor is not float32 type";
249 return RET_ERROR;
250 }
251 auto data_ptr = reinterpret_cast<float *>(origin_tensor->data_c());
252 auto data_shape = origin_tensor->shape();
253 if (data_shape.size() != kInputSizeTwo) {
254 MS_LOG(ERROR) << "weight data shape invalid";
255 return RET_ERROR;
256 }
257 if (data_shape[0] <= hidden_size) {
258 MS_LOG(ERROR) << "weight data shape[0] invalid";
259 return RET_ERROR;
260 }
261 if (hidden_size * 4 != data_shape[1]) {
262 MS_LOG(ERROR) << "weight data shape[1] invalid";
263 return RET_ERROR;
264 }
265 const auto input_size = data_shape[0] - hidden_size;
266
267 std::vector<int64_t> shape_i{1, kUnidirectionalGateNum * hidden_size, input_size};
268 if (SetWeightAbstractAndDefault(weight_i, shape_i, data_ptr, hidden_size) != RET_OK) {
269 MS_LOG(ERROR) << "get weight_i failed";
270 return RET_ERROR;
271 }
272
273 std::vector<int64_t> shape_c{1, kUnidirectionalGateNum * hidden_size, hidden_size};
274 if (SetWeightAbstractAndDefault(weight_c, shape_c, data_ptr + input_size * data_shape[1], hidden_size) != RET_OK) {
275 MS_LOG(ERROR) << "get weight_i failed";
276 return RET_ERROR;
277 }
278 return RET_OK;
279 }
280
PopulateBiasNode(const EquivPtr & body_equiv,const ParameterPtr & new_bias,const AnfNodePtr & old_bias,const int hidden_size) const281 STATUS TfLstmCellFusion::PopulateBiasNode(const EquivPtr &body_equiv, const ParameterPtr &new_bias,
282 const AnfNodePtr &old_bias, const int hidden_size) const {
283 MS_ASSERT(body_equiv != nullptr);
284 MS_ASSERT(new_bias != nullptr);
285 MS_ASSERT(old_bias != nullptr);
286 if (!utils::isa<ParameterPtr>(old_bias)) {
287 MS_LOG(DEBUG) << "old_bias is not parameter";
288 return RET_ERROR;
289 }
290 auto old_bias_param = utils::cast<ParameterPtr>(old_bias);
291 if (!old_bias_param->has_default() || old_bias_param->default_param() == nullptr) {
292 MS_LOG(DEBUG) << "bias not have default value";
293 return RET_ERROR;
294 }
295 if (!utils::isa<tensor::TensorPtr>(old_bias_param->default_param())) {
296 MS_LOG(DEBUG) << "default value is not tensor::Tensor";
297 return RET_FAILED;
298 }
299 auto origin_tensor = std::dynamic_pointer_cast<tensor::Tensor>(old_bias_param->default_param());
300 MS_CHECK_TRUE_RET(origin_tensor != nullptr, RET_ERROR);
301 if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) {
302 MS_LOG(DEBUG) << "origin_tensor is not float32 type";
303 return RET_ERROR;
304 }
305 auto data_ptr = reinterpret_cast<float *>(origin_tensor->data_c());
306 MS_CHECK_TRUE_RET(data_ptr != nullptr, RET_ERROR);
307 auto data_shape = origin_tensor->shape();
308 MS_CHECK_GE(hidden_size, 0, RET_ERROR);
309 if (data_shape.size() != 1 || data_shape[0] != 4 * hidden_size) {
310 MS_LOG(DEBUG) << "bias data shape illegal";
311 return RET_ERROR;
312 }
313
314 std::vector<int64_t> shape{1, kBidirectionalGateNum * hidden_size};
315 auto tensor_data = std::make_unique<float[]>(static_cast<size_t>(hidden_size) * 8);
316 MS_CHECK_TRUE_RET(tensor_data != nullptr, lite::RET_ERROR);
317 auto forget_bias_node = utils::cast<AnfNodePtr>((*body_equiv)[forget_bias_]);
318 if (forget_bias_node == nullptr) {
319 MS_LOG(ERROR) << "forget bias node is nullptr";
320 return RET_ERROR;
321 }
322 float forget_bias_value = 0.0f;
323 if (GetFloatScalarFromTensorInfo(forget_bias_node, &forget_bias_value) != RET_OK) {
324 return RET_ERROR;
325 }
326
327 std::vector<int> data_diff{0, 3, 2, 1};
328 for (int i = 0; i < 8; ++i) {
329 for (int j = 0; j < hidden_size; ++j) {
330 if (i < 4) {
331 tensor_data[i * hidden_size + j] = data_ptr[data_diff[i] * hidden_size + j];
332 if (i == 2) { // forget bias
333 tensor_data[i * hidden_size + j] += forget_bias_value;
334 }
335 } else {
336 tensor_data[i * hidden_size + j] = 0.0f;
337 }
338 }
339 }
340
341 auto tensor_info =
342 lite::CreateTensorInfo(tensor_data.get(), static_cast<size_t>(hidden_size) * kBidirectionalGateNum * sizeof(float),
343 shape, kNumberTypeFloat32);
344 if (tensor_info == nullptr) {
345 MS_LOG(ERROR) << "create tensor info failed.";
346 return RET_ERROR;
347 }
348
349 auto status = lite::InitParameterFromTensorInfo(new_bias, tensor_info);
350 if (status != RET_OK) {
351 MS_LOG(ERROR) << "init parameter from tensor info failed";
352 return RET_ERROR;
353 }
354
355 return RET_OK;
356 }
357
CreateLSTMNode(const FuncGraphPtr & func_graph,const EquivPtr & equiv,const EquivPtr & body_equiv,const std::string & base_name,const float zoneout_cell,const float zoneout_hidden) const358 CNodePtr TfLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
359 const EquivPtr &body_equiv, const std::string &base_name,
360 const float zoneout_cell, const float zoneout_hidden) const {
361 MS_ASSERT(func_graph != nullptr);
362 MS_ASSERT(equiv != nullptr);
363 auto lstm_prim = std::make_shared<ops::LSTM>();
364 MS_CHECK_TRUE_RET(lstm_prim != nullptr, nullptr);
365 auto lstm_prim_c = lstm_prim->GetPrim();
366 MS_CHECK_TRUE_RET(lstm_prim_c != nullptr, nullptr);
367 lstm_prim->set_bidirectional(false);
368 lstm_prim->set_zoneout_cell(zoneout_cell);
369 lstm_prim->set_zoneout_hidden(zoneout_hidden);
370 auto value_node = NewValueNode(lstm_prim_c);
371 MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
372
373 auto &vars = while_input_vars_;
374 auto weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
375 MS_ASSERT(weight);
376 auto bias = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
377 MS_ASSERT(bias);
378 auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]);
379 MS_ASSERT(input);
380 auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]);
381 MS_ASSERT(cell);
382 auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]);
383 MS_ASSERT(hidden);
384
385 if (!utils::isa<ParameterPtr>(hidden)) {
386 MS_LOG(DEBUG) << "hidden is not parameter";
387 return nullptr;
388 }
389 auto hidden_param = utils::cast<ParameterPtr>(hidden);
390 if (!utils::isa<abstract::AbstractTensorPtr>(hidden_param->abstract())) {
391 MS_LOG(DEBUG) << "hidden abstract is not AbstractTensor";
392 return nullptr;
393 }
394 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(hidden_param->abstract());
395 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "Cast to abstract tensor failed!");
396 auto hidden_shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
397 if (hidden_shape.empty()) {
398 MS_LOG(DEBUG) << "can't get hidden shape";
399 return nullptr;
400 }
401
402 auto i_weight = func_graph->add_parameter();
403 MS_CHECK_TRUE_RET(i_weight != nullptr, nullptr);
404 i_weight->set_name(base_name + "_weight_i");
405 if (weight->abstract() != nullptr) {
406 i_weight->set_abstract(weight->abstract()->Clone());
407 }
408
409 auto c_weight = func_graph->add_parameter();
410 MS_CHECK_TRUE_RET(c_weight != nullptr, nullptr);
411 c_weight->set_name(base_name + "_weight_c");
412 if (weight->abstract() != nullptr) {
413 c_weight->set_abstract(weight->abstract()->Clone());
414 }
415
416 if (SplitWeights(weight, i_weight, c_weight, static_cast<int>(hidden_shape.back())) != RET_OK) {
417 MS_LOG(DEBUG) << "split weight to i_weight and c_weight failed";
418 return nullptr;
419 }
420
421 auto bias_node = func_graph->add_parameter();
422 MS_CHECK_TRUE_RET(bias_node != nullptr, nullptr);
423 MS_CHECK_TRUE_RET(bias_node != nullptr, nullptr);
424 bias_node->set_name(base_name + "_bias");
425 if (bias->abstract() != nullptr) {
426 bias_node->set_abstract(bias->abstract()->Clone());
427 }
428
429 if (PopulateBiasNode(body_equiv, bias_node, bias, static_cast<int>(hidden_shape.back())) != RET_OK) {
430 MS_LOG(DEBUG) << "reorder bias failed";
431 return nullptr;
432 }
433
434 if (!utils::isa<CNodePtr>(input) || !CheckPrimitiveType(input, prim::kPrimTensorListFromTensor)) {
435 MS_LOG(DEBUG) << "input is not tensorlistfromtensor op";
436 return nullptr;
437 }
438 auto tensor_list_cnode = utils::cast<CNodePtr>(input);
439 auto input_tensor_node = tensor_list_cnode->input(1);
440
441 std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias_node, hidden,
442 cell};
443 auto new_node = func_graph->NewCNode(new_node_inputs);
444 MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
445 new_node->set_fullname_with_scope(base_name);
446 return new_node;
447 }
448 } // namespace opt
449 } // namespace mindspore
450