/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include <memory>
#include <algorithm>
#include <functional>
#include "ops/lstm.h"
#include "ops/squeeze.h"
#include "tools/converter/ops/ops_def.h"
#include "src/common/utils.h"
#include "tools/common/tensor_util.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kWhileInputsLength = 23;
constexpr size_t kWhileInputsVarNum = 21;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 95;
constexpr size_t kBodyCNodesNum = 34;
constexpr size_t kLSTMOutputNum = 3;
constexpr auto kUnidirectionalGateNum = 4;
const auto &p1 = std::placeholders::_1;
constexpr float EPSILON = 1e-5;
bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }

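// Builds the pattern for the LSTM cell computation inside the while-body graph: per-gate
// concatenation of input and recurrent weights, GetItem of the current time-step input,
// concatenation with the hidden state, per-gate MatMul and bias-add, and the activations that
// produce the new cell state. Returns {output-gate pre-activation, new cell state}.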
std::vector<VectorRef> GenerateBodyGraphCellPattern(const std::vector<CondVarPtr> &placeholders) {
  MS_CHECK_TRUE_RET(placeholders.size() > 19, {});
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
  VectorRef concat_i_w = VectorRef({is_var1, placeholders[8], placeholders[12]});
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
  VectorRef concat_f_w = VectorRef({is_var2, placeholders[9], placeholders[13]});
  auto is_var3 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
  VectorRef concat_c_w = VectorRef({is_var3, placeholders[10], placeholders[14]});
  auto is_var4 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
  VectorRef concat_o_w = VectorRef({is_var4, placeholders[11], placeholders[15]});

  auto is_var_getitem = std::make_shared<Var>("GetItem");
  MS_CHECK_TRUE_RET(is_var_getitem != nullptr, {});
  auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
  VectorRef get_item = VectorRef({is_var_getitem, placeholders[7], placeholders[2], is_param3});
  auto is_var5 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var5 != nullptr, {});
  VectorRef concat_input_h = VectorRef({is_var5, get_item, placeholders[5]});

  auto is_var6 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var6 != nullptr, {});
  VectorRef matmul_input = VectorRef({is_var6, concat_input_h, concat_i_w});
  auto is_var7 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var7 != nullptr, {});
  VectorRef matmul_forget = VectorRef({is_var7, concat_input_h, concat_f_w});
  auto is_var8 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var8 != nullptr, {});
  VectorRef matmul_cell = VectorRef({is_var8, concat_input_h, concat_c_w});
  auto is_var9 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var9 != nullptr, {});
  VectorRef matmul_output = VectorRef({is_var9, concat_input_h, concat_o_w});

  auto is_var10 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var10 != nullptr, {});
  VectorRef bias_input = VectorRef({is_var10, matmul_input, placeholders[16]});
  auto is_var11 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var11 != nullptr, {});
  VectorRef bias_forget = VectorRef({is_var11, matmul_forget, placeholders[17]});
  auto is_var12 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var12 != nullptr, {});
  VectorRef bias_cell = VectorRef({is_var12, matmul_cell, placeholders[18]});
  auto is_var13 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var13 != nullptr, {});
  VectorRef bias_output = VectorRef({is_var13, matmul_output, placeholders[19]});

  auto is_var_tanh = std::make_shared<Var>("Tanh");
  MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
  VectorRef cell = VectorRef({is_var_tanh, bias_cell});
  auto is_var_sigmoid1 = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid1 != nullptr, {});
  VectorRef input_gate = VectorRef({is_var_sigmoid1, bias_input});
  auto is_var_mul1 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
  VectorRef cell_input = VectorRef({is_var_mul1, input_gate, cell});
  auto is_var_sigmoid2 = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid2 != nullptr, {});
  VectorRef forget_gate = VectorRef({is_var_sigmoid2, bias_forget});
  auto is_var_mul2 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
  VectorRef cell_forgeted = VectorRef({is_var_mul2, forget_gate, placeholders[4]});
  auto is_var_add = std::make_shared<Var>("Add");
  MS_CHECK_TRUE_RET(is_var_add != nullptr, {});
  VectorRef cell_new = VectorRef({is_var_add, cell_forgeted, cell_input});
  return {bias_output, cell_new};
}
}  // namespace

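// Reads a scalar float out of a Parameter node's default tensor (used for the zoneout factors).
// Fails if the node is not a Parameter, has no default value, or the value is not a float scalar.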
STATUS TfliteLstmCellFusion::GetFloatScalarFromTensorInfo(const AnfNodePtr &tensor_info, float *v) {
  if (tensor_info == nullptr || v == nullptr) {
    MS_LOG(ERROR) << "tensor_info or v is nullptr";
    return RET_ERROR;
  }
  if (!utils::isa<ParameterPtr>(tensor_info)) {
    MS_LOG(DEBUG) << "tensor_info is not a Parameter node";
    return RET_ERROR;
  }
  auto param_ptr = utils::cast<ParameterPtr>(tensor_info);
  if (!param_ptr->has_default() || param_ptr->default_param() == nullptr) {
    MS_LOG(DEBUG) << "param does not have a default value";
    return RET_ERROR;
  }
  auto default_param = param_ptr->default_param();
  if (!utils::isa<tensor::TensorPtr>(default_param)) {
    MS_LOG(DEBUG) << "default_param is not tensor::TensorPtr";
    return RET_ERROR;
  }
  auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
  auto tensor_shape = default_param_ptr->shape();
  if (!(tensor_shape.empty() || (tensor_shape.size() == 1 && tensor_shape[0] == 1))) {
    MS_LOG(DEBUG) << "default param is not scalar";
    return RET_ERROR;
  }
  if (default_param_ptr->data_type() != kNumberTypeFloat32 && default_param_ptr->data_type() != kNumberTypeFloat) {
    MS_LOG(DEBUG) << "default param is not float";
    return RET_ERROR;
  }
  *v = *(reinterpret_cast<float *>(default_param_ptr->data_c()));
  return RET_OK;
}

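// Creates the pattern Vars for the while-node inputs and the four zoneout coefficients
// (old/new for both the cell and the hidden state) that the body pattern binds to.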
bool TfliteLstmCellFusion::Init() const {
  for (size_t i = 0; i < this->while_input_var_num_; ++i) {
    auto is_var = std::make_shared<Var>();
    MS_CHECK_TRUE_RET(is_var != nullptr, false);
    while_input_vars_.emplace_back(is_var);
  }
  cell_zoneout_old_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(cell_zoneout_old_ != nullptr, false);
  cell_zoneout_new_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(cell_zoneout_new_ != nullptr, false);
  hidden_zoneout_old_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(hidden_zoneout_old_ != nullptr, false);
  hidden_zoneout_new_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(hidden_zoneout_new_ != nullptr, false);
  return true;
}

TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num,
                                           int cond_nodes_num, int cond_cnodes_num, int body_nodes_num,
                                           int body_cnodes_num)
    : PatternProcessPass(name, multigraph) {
  /*
   * input vars for lstm while node
   * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
   * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
   */
  this->while_inputs_num_ = input_length == 0 ? kWhileInputsLength : input_length;
  this->while_input_var_num_ = var_num == 0 ? kWhileInputsVarNum : var_num;
  this->cond_nodes_num_ = cond_nodes_num == 0 ? kCondNodesNum : cond_nodes_num;
  this->cond_cnodes_num_ = cond_cnodes_num == 0 ? kCondCNodesNum : cond_cnodes_num;
  this->body_nodes_num_ = body_nodes_num == 0 ? kBodyNodesNum : body_nodes_num;
  this->body_cnodes_num_ = body_cnodes_num == 0 ? kBodyCNodesNum : body_cnodes_num;
}

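// Pattern for the while condition subgraph: Return(LogicalAnd(Less(param, param), Less(param, param))).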
AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) {
  MS_ASSERT(primitive_vars != nullptr);
  auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter1 != nullptr, nullptr);
  auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter2 != nullptr, nullptr);
  auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter3 != nullptr, nullptr);
  auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter4 != nullptr, nullptr);
  auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
  MS_CHECK_TRUE_RET(is_less1 != nullptr, nullptr);
  auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
  MS_CHECK_TRUE_RET(is_less2 != nullptr, nullptr);
  auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
  MS_CHECK_TRUE_RET(is_logical_and != nullptr, nullptr);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
  VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
  VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
  VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
  VectorRef return_ref = VectorRef({is_return, logicaland_ref});
  VarPtr fg = std::make_shared<Var>("RootG");
  MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
  auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true);
  return pattern;
}

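// Pattern for the while body subgraph: counter increments, the cell computation from
// GenerateBodyGraphCellPattern, zoneout blending of the old and new cell/hidden states, a SetItem
// that stores the step output, and the MakeTuple/Return of all loop-carried values.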
AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
  std::vector<CondVarPtr> placeholders;
  for (int i = 0; i < 20; ++i) {
    placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode));
  }
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
  auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
  VectorRef add2 = VectorRef({is_var1, placeholders[2], is_param1});
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
  auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
  VectorRef add3 = VectorRef({is_var2, placeholders[0], is_param2});

  auto hidden_cells = GenerateBodyGraphCellPattern(placeholders);
  MS_CHECK_TRUE_RET(hidden_cells.size() == kInputSizeTwo, {});
  auto is_var_mul1 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
  VectorRef zoneout_cell_old = VectorRef({is_var_mul1, cell_zoneout_old_, placeholders[4]});
  auto is_var_mul2 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
  auto cell_new = hidden_cells[1];
  MS_CHECK_TRUE_RET(!cell_new.empty(), {});
  VectorRef zoneout_cell_new = VectorRef({is_var_mul2, cell_zoneout_new_, cell_new});
  auto is_var_add1 = std::make_shared<Var>("Add");
  MS_CHECK_TRUE_RET(is_var_add1 != nullptr, {});
  VectorRef cell_output = VectorRef({is_var_add1, zoneout_cell_new, zoneout_cell_old});

  auto is_var_sigmoid = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid != nullptr, {});
  auto bias_output = hidden_cells[0];
  MS_CHECK_TRUE_RET(!bias_output.empty(), {});
  VectorRef output_gate = VectorRef({is_var_sigmoid, bias_output});
  auto is_var_tanh = std::make_shared<Var>("Tanh");
  MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
  VectorRef cell_to_output = VectorRef({is_var_tanh, cell_new});
  auto is_var_mul3 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul3 != nullptr, {});
  VectorRef output = VectorRef({is_var_mul3, output_gate, cell_to_output});

  auto is_var_mul4 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul4 != nullptr, {});
  VectorRef zoneout_hidden_old = VectorRef({is_var_mul4, hidden_zoneout_old_, placeholders[5]});
  auto is_var_mul5 = std::make_shared<Var>("Mul");
  MS_CHECK_TRUE_RET(is_var_mul5 != nullptr, {});
  VectorRef zoneout_hidden_new = VectorRef({is_var_mul5, hidden_zoneout_new_, output});
  auto is_var_add2 = std::make_shared<Var>("Add");
  MS_CHECK_TRUE_RET(is_var_add2 != nullptr, {});
  VectorRef hidden_output = VectorRef({is_var_add2, zoneout_hidden_new, zoneout_hidden_old});

  auto is_var_setitem = std::make_shared<Var>("SetItem");
  MS_CHECK_TRUE_RET(is_var_setitem != nullptr, {});
  VectorRef set_item = VectorRef({is_var_setitem, placeholders[3], placeholders[2], output});

  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
  MS_CHECK_TRUE_RET(is_make_tuple != nullptr, nullptr);
  std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output};
  outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
  VectorRef make_tuple_node = VectorRef(outputs);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
  VectorRef return_node = VectorRef({is_return, make_tuple_node});

  VarPtr fg = std::make_shared<Var>("RootG");
  MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
  auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true);
  return pattern;
}

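// Top-level pattern: TensorListStack(TupleGetItem(While(...)), Parameter). The while node takes
// the pattern vars created in Init(), with the time var repeated at position 4.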
const BaseRef TfliteLstmCellFusion::DefinePattern() const {
  if (!Init()) {
    MS_LOG(ERROR) << "initialize members failed.";
    return {};
  }
  auto is_while_node = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimWhile));
  MS_CHECK_TRUE_RET(is_while_node != nullptr, {});
  VectorRef while_node = VectorRef({is_while_node});
  auto while_inputs = while_input_vars_;
  MS_CHECK_TRUE_RET(while_inputs.size() > kInputSizeThree, {});
  while_inputs.insert(while_inputs.begin() + 4, while_input_vars_[2]);
  while_node.insert(while_node.end(), while_inputs.begin(), while_inputs.end());

  auto is_tuple_get_item = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem));
  MS_CHECK_TRUE_RET(is_tuple_get_item != nullptr, {});
  auto is_var = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var != nullptr, {});
  VectorRef while_output = VectorRef({is_tuple_get_item, while_node, is_var});

  auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListStack));
  MS_CHECK_TRUE_RET(is_tensor_list_stack != nullptr, {});
  auto is_parameter = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_parameter != nullptr, {});
  VectorRef tensor_list_stack_node = VectorRef({is_tensor_list_stack, while_output, is_parameter});

  return tensor_list_stack_node;
}

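// Runs the pattern engine against the return node of a cond/body subgraph and returns the bindings.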
EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars,
                                          const AnfNodePtr &pattern) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(pattern != nullptr);
  auto return_node = func_graph->get_return();
  auto visitor = std::make_shared<Visitor>();
  MS_CHECK_TRUE_RET(visitor != nullptr, nullptr);
  PatternEngine pattern_engine(visitor);
  auto empty_equiv = std::make_shared<Equiv>();
  MS_CHECK_TRUE_RET(empty_equiv != nullptr, nullptr);
  EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv);
  return equiv;
}

// make sure that only outputs 3, 4 and 5 of the while node are referenced
bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(while_cnode != nullptr);
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr";
    return false;
  }
  auto while_node_users = manager->node_users()[while_cnode];
  std::vector<size_t> valid_indexes{3, 4, 5};
  for (auto &node_user : while_node_users) {
    if (!utils::isa<CNodePtr>(node_user.first)) {
      return false;
    }
    auto cnode = utils::cast<CNodePtr>(node_user.first);
    if (IsMarkedTrainOp(cnode)) {
      return false;
    }
    if (!CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
      return false;
    }
    auto index = GetTupleGetItemOutIndex(cnode);
    if (!lite::IsContain(valid_indexes, index)) {
      return false;
    }
  }
  return true;
}

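// Verifies that a cond/body subgraph has the expected cnode and total node counts, then matches it
// against the given pattern; returns the bindings or nullptr on mismatch.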
EquivPtr TfliteLstmCellFusion::CheckSubGraph(const AnfNodePtr &pattern, const PrimitiveVarMapPtr &primitive_vars,
                                             const AnfNodePtr &anf_sub_graph, const size_t cnode_num,
                                             const size_t all_node_num) {
  MS_ASSERT(pattern != nullptr);
  MS_ASSERT(primitive_vars != nullptr);
  MS_ASSERT(anf_sub_graph != nullptr);
  auto sub_graph = GetValueNode<FuncGraphPtr>(anf_sub_graph);
  MS_CHECK_TRUE_RET(sub_graph != nullptr, nullptr);
  auto nodes = TopoSort(sub_graph->get_return());
  auto cnodes = sub_graph->GetOrderedCnodes();
  if (cnodes.size() != cnode_num || nodes.size() != all_node_num) {
    MS_LOG(DEBUG) << "sub graph nodes num not match";
    return nullptr;
  }
  return MatchGraph(sub_graph, primitive_vars, pattern);
}

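// Extracts the four zoneout coefficients bound by the body pattern and validates them: each must
// lie in [0, 1], old + new must sum to 1, and the cell and hidden factors must match.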
bool TfliteLstmCellFusion::CheckBodyGraph(const EquivPtr &equiv, float *zoneout_cell, float *zoneout_hidden) const {
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(zoneout_cell != nullptr);
  MS_ASSERT(zoneout_hidden != nullptr);

  auto cell_zoneout_old_node = utils::cast<AnfNodePtr>((*equiv)[cell_zoneout_old_]);
  MS_ASSERT(cell_zoneout_old_node != nullptr);
  auto cell_zoneout_new_node = utils::cast<AnfNodePtr>((*equiv)[cell_zoneout_new_]);
  MS_ASSERT(cell_zoneout_new_node != nullptr);
  auto hidden_zoneout_old_node = utils::cast<AnfNodePtr>((*equiv)[hidden_zoneout_old_]);
  MS_ASSERT(hidden_zoneout_old_node != nullptr);
  auto hidden_zoneout_new_node = utils::cast<AnfNodePtr>((*equiv)[hidden_zoneout_new_]);
  MS_ASSERT(hidden_zoneout_new_node != nullptr);

  float cell_old, cell_new, hidden_old, hidden_new;
  if (GetFloatScalarFromTensorInfo(cell_zoneout_old_node, &cell_old) != RET_OK) {
    return false;
  }
  if (GetFloatScalarFromTensorInfo(cell_zoneout_new_node, &cell_new) != RET_OK) {
    return false;
  }
  if (GetFloatScalarFromTensorInfo(hidden_zoneout_old_node, &hidden_old) != RET_OK) {
    return false;
  }
  if (GetFloatScalarFromTensorInfo(hidden_zoneout_new_node, &hidden_new) != RET_OK) {
    return false;
  }
  if (cell_old < 0.0f || cell_old > 1.0f || cell_new < 0.0f || cell_new > 1.0f) {
    MS_LOG(DEBUG) << "cell zoneout value illegal";
    return false;
  }
  if (hidden_old < 0.0f || hidden_old > 1.0f || hidden_new < 0.0f || hidden_new > 1.0f) {
    MS_LOG(DEBUG) << "hidden zoneout value illegal";
    return false;
  }
  if (std::abs(cell_old + cell_new - 1.0f) > EPSILON || std::abs(hidden_old + hidden_new - 1.0f) > EPSILON ||
      std::abs(cell_old - hidden_old) > EPSILON) {
    MS_LOG(DEBUG) << "zoneout value illegal";
    return false;
  }
  *zoneout_cell = cell_old;
  *zoneout_hidden = hidden_old;
  return true;
}

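// Concatenates the four per-gate weight or bias Parameters into the single tensor layout expected
// by the fused LSTM op: weights become {1, 4 * hidden, input_size}; the bias buffer holds
// 8 * hidden floats with the four gate biases in the first half and zeros in the second half.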
STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> &params, const ParameterPtr &new_param,
                                              bool is_bias) {
  MS_ASSERT(new_param != nullptr);
  MS_ASSERT(params.size() == 4);
  std::vector<float *> data_ptrs;
  std::vector<std::vector<int64_t>> data_shapes;
  for (auto &param : params) {
    if (!utils::isa<ParameterPtr>(param)) {
      MS_LOG(DEBUG) << "param is not Parameter node";
      return RET_FAILED;
    }
    auto param_t = utils::cast<ParameterPtr>(param);
    if (!param_t->has_default() || param_t->default_param() == nullptr) {
      MS_LOG(DEBUG) << "param does not have a default value";
      return RET_FAILED;
    }
    if (!utils::isa<tensor::TensorPtr>(param_t->default_param())) {
      MS_LOG(DEBUG) << "default value is not tensor::Tensor";
      return RET_FAILED;
    }
    auto origin_tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_t->default_param());
    if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) {
      MS_LOG(DEBUG) << "origin_tensor is not float32 type";
      return RET_FAILED;
    }
    auto data_ptr = reinterpret_cast<float *>(origin_tensor->data_c());
    auto data_shape = origin_tensor->shape();
    data_ptrs.push_back(data_ptr);
    data_shapes.push_back(data_shape);
  }

  for (size_t i = 1; i < data_shapes.size(); ++i) {
    if (data_shapes[i] != data_shapes[0]) {
      MS_LOG(DEBUG) << "data shape not same";
      return RET_FAILED;
    }
  }
  std::vector<int64_t> new_shape;
  int step = 0;
  int data_size = 0;
  MS_ASSERT(!data_shapes.empty());
  if (is_bias) {
    if (data_shapes[0].size() != 1) {
      MS_LOG(ERROR) << "bias data shape error";
      return RET_ERROR;
    }
    step = static_cast<int>(data_shapes[0][0]);
    MS_CHECK_INT_MUL_NOT_OVERFLOW(C8NUM, step, RET_ERROR);
    data_size = C8NUM * step;
    new_shape = std::vector<int64_t>({1, data_size});

  } else {
    if (data_shapes[0].size() != 2) {
      MS_LOG(ERROR) << "weight data shape error";
      return RET_ERROR;
    }
    new_shape = std::vector<int64_t>({1, data_shapes[0][0] * kUnidirectionalGateNum, data_shapes[0][1]});
    MS_CHECK_INT_MUL_NOT_OVERFLOW(data_shapes[0][0], data_shapes[0][1], RET_ERROR);
    step = static_cast<int>(data_shapes[0][0] * data_shapes[0][1]);
    MS_CHECK_INT_MUL_NOT_OVERFLOW(C4NUM, step, RET_ERROR);
    data_size = C4NUM * step;
  }

  auto tensor_info = lite::CreateTensorInfo(nullptr, 0, new_shape, kNumberTypeFloat32);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "create tensor info failed.";
    return RET_ERROR;
  }

  auto tensor_data = static_cast<float *>(tensor_info->data_c());
  for (int i = 0; i < data_size; ++i) {  // for biases only the first 4 * hidden_size elements are filled below; the rest stays 0
    tensor_data[i] = 0.0f;
  }

  for (size_t i = 0; i < data_ptrs.size(); ++i) {
    auto source_len = std::accumulate(data_shapes[i].begin(), data_shapes[i].end(), 1, std::multiplies<int>());
    auto ret = memcpy_s(tensor_data + i * step, step * sizeof(float), data_ptrs[i], source_len * sizeof(float));
    if (ret != EOK) {
      MS_LOG(ERROR) << "memcpy_s error";
      return RET_ERROR;
    }
  }

  auto status = lite::InitParameterFromTensorInfo(new_param, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return RET_ERROR;
  }

  return RET_OK;
}

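// Builds the fused ops::LSTM node from the matched while loop: concatenates the per-gate weights
// and biases (gate order: input, output, forget, cell), takes the real input tensor from the
// TensorListFromTensor node feeding the loop, and wires in the initial hidden and cell states.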
CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                              const EquivPtr &body_equiv, const std::string &base_name,
                                              const float zoneout_cell, const float zoneout_hidden) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(body_equiv != nullptr);
  /*
   * input vars for while node
   * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
   * 9:i2i_ 10:i2f_ 11:i2c_ 12:i2o_ 13:c2i_ 14:c2f_ 15:c2c_ 16:c2o_ 17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
   */
  auto lstm_prim = std::make_shared<ops::LSTM>();
  MS_CHECK_TRUE_RET(lstm_prim != nullptr, nullptr);
  lstm_prim->set_bidirectional(false);
  lstm_prim->set_zoneout_cell(zoneout_cell);
  lstm_prim->set_zoneout_hidden(zoneout_hidden);
  auto value_node = NewValueNode(lstm_prim);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);

  auto &vars = while_input_vars_;
  auto i2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
  MS_ASSERT(i2i_weight);
  auto i2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
  MS_ASSERT(i2f_weight);
  auto i2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[11]]);
  MS_ASSERT(i2c_weight);
  auto i2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[12]]);
  MS_ASSERT(i2o_weight);

  auto c2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[13]]);
  MS_ASSERT(c2i_weight);
  auto c2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[14]]);
  MS_ASSERT(c2f_weight);
  auto c2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[15]]);
  MS_ASSERT(c2c_weight);
  auto c2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[16]]);
  MS_ASSERT(c2o_weight);

  auto i_bias = utils::cast<AnfNodePtr>((*equiv)[vars[17]]);
  MS_ASSERT(i_bias);
  auto f_bias = utils::cast<AnfNodePtr>((*equiv)[vars[18]]);
  MS_ASSERT(f_bias);
  auto c_bias = utils::cast<AnfNodePtr>((*equiv)[vars[19]]);
  MS_ASSERT(c_bias);
  auto o_bias = utils::cast<AnfNodePtr>((*equiv)[vars[20]]);
  MS_ASSERT(o_bias);

  auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]);
  MS_ASSERT(input);
  auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]);
  MS_ASSERT(cell);
  auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]);
  MS_ASSERT(hidden);

  std::vector<AnfNodePtr> i_weights{i2i_weight, i2o_weight, i2f_weight, i2c_weight};
  auto i_weight = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(i_weight != nullptr, nullptr);
  auto status = GetConcatedParam(i_weights, i_weight, false);
  if (status != RET_OK) {
    return nullptr;
  }
  i_weight->set_name(base_name + "_weight_i");

  std::vector<AnfNodePtr> c_weights{c2i_weight, c2o_weight, c2f_weight, c2c_weight};
  auto c_weight = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(c_weight != nullptr, nullptr);
  status = GetConcatedParam(c_weights, c_weight, false);
  if (status != RET_OK) {
    return nullptr;
  }
  c_weight->set_name(base_name + "_weight_c");

  std::vector<AnfNodePtr> biases{i_bias, o_bias, f_bias, c_bias};
  auto bias = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(bias != nullptr, nullptr);
  status = GetConcatedParam(biases, bias, true);
  if (status != RET_OK) {
    return nullptr;
  }
  bias->set_name(base_name + "_bias");

  if (!utils::isa<CNodePtr>(input) || !CheckPrimitiveType(input, prim::kPrimTensorListFromTensor)) {
    MS_LOG(DEBUG) << "input is not tensorlistfromtensor op";
    return nullptr;
  }
  auto tensor_list_cnode = utils::cast<CNodePtr>(input);
  auto input_tensor_node = tensor_list_cnode->input(1);

  std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias, hidden, cell};
  auto new_node = func_graph->NewCNode(new_node_inputs);
  MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
  new_node->set_fullname_with_scope(base_name);
  return new_node;
}

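// Creates a TupleGetItem(node, item_index) cnode with a float32 tensor abstract, used to pick a
// single output of the fused LSTM node.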
CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                                   const int item_index) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto tuple_get_item_prim = std::make_shared<lite::TupleGetItem>();
  auto get_item_value = NewValueNode(MakeValue<int>(item_index));
  if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
    MS_LOG(ERROR) << "NewValueNode is nullptr";
    return nullptr;
  }
  CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim, {node, get_item_value});
  MS_CHECK_TRUE_RET(get_item_cnode != nullptr, nullptr);
  auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Create tensor abstract failed";
    return nullptr;
  }
  get_item_cnode->set_abstract(abstract);
  get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
                                          std::to_string(item_index));
  return get_item_cnode;
}

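// Redirects the remaining TupleGetItem users of the while node to the fused LSTM node, remapping
// output index 4 to LSTM output 2 and the other referenced index to output 1, and inserts a
// Squeeze on axis 0 after each redirected getitem.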
STATUS TfliteLstmCellFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode,
                                                 const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) {
  MS_ASSERT(func_graph != nullptr && while_cnode != nullptr);
  MS_ASSERT(lstm_cnode != nullptr && output_get_item != nullptr);
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr";
    return RET_ERROR;
  }
  auto while_node_users = manager->node_users()[while_cnode];
  for (auto &node_user : while_node_users) {
    if (node_user.first == output_get_item) {
      continue;
    }
    if (!utils::isa<CNodePtr>(node_user.first)) {
      return RET_ERROR;
    }
    auto get_item = utils::cast<CNodePtr>(node_user.first);
    if (!CheckPrimitiveType(get_item, prim::kPrimTupleGetItem)) {
      return RET_ERROR;
    }
    auto new_inputs = get_item->inputs();
    if (new_inputs.size() != 3) {
      return RET_ERROR;
    }
    new_inputs[1] = lstm_cnode;
    auto index_vnode = get_item->input(2);
    if (!utils::isa<ValueNode>(index_vnode)) {
      MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node";
      return RET_ERROR;
    }
    auto value_node = utils::cast<ValueNodePtr>(index_vnode);
    if (value_node == nullptr) {
      MS_LOG(ERROR) << "cast to ValueNode failed";
      return RET_ERROR;
    }
    auto origin_index = GetValue<int>(value_node->value());
    int new_index = origin_index == 4 ? 2 : 1;
    auto new_index_vnode = NewValueNode(MakeValue<int>(new_index));
    MS_CHECK_TRUE_RET(new_index_vnode != nullptr, RET_ERROR);
    new_inputs[2] = new_index_vnode;
    get_item->set_inputs(new_inputs);
    get_item->set_fullname_with_scope(lstm_cnode->fullname_with_scope() + "_getitem_" + std::to_string(new_index));
    if (get_item->abstract() == nullptr) {
      MS_LOG(ERROR) << "get_item's abstract is nullptr";
      return RET_ERROR;
    }

    std::vector<int> squeeze_axis{0};
    auto squeeze_node = CreateSqueezeNode(func_graph, get_item, squeeze_axis);
    if (squeeze_node == nullptr) {
      return RET_ERROR;
    }

    auto get_item_users = manager->node_users()[get_item];
    for (auto &get_item_user : get_item_users) {
      manager->SetEdge(get_item_user.first, get_item_user.second, squeeze_node);
    }
  }
  return RET_OK;
}

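// Attaches an AbstractTuple of output_num float32 tensor abstracts to the cnode.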
STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) {
  MS_ASSERT(cnode != nullptr);
  AbstractBasePtrList abstract_list;
  for (int i = 0; i < output_num; ++i) {
    auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
    if (abstract == nullptr) {
      MS_LOG(ERROR) << "Create tensor abstract failed";
      return RET_ERROR;
    }
    abstract_list.emplace_back(abstract);
  }
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  if (abstract_tuple == nullptr) {
    MS_LOG(ERROR) << "create abstract_tuple failed";
    return RET_ERROR;
  }
  cnode->set_abstract(abstract_tuple);
  return RET_OK;
}

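// Creates a Squeeze cnode over input_node with the given axes, cloning the input's abstract.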
CNodePtr TfliteLstmCellFusion::CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node,
                                                 const std::vector<int> &axis) {
  MS_ASSERT(func_graph != nullptr && input_node != nullptr);
  auto squeeze_prim = std::make_shared<ops::Squeeze>();
  MS_CHECK_TRUE_RET(squeeze_prim != nullptr, nullptr);
  std::vector<int64_t> axis_vec;
  std::transform(axis.begin(), axis.end(), std::back_inserter(axis_vec),
                 [](int val) { return static_cast<int64_t>(val); });
  squeeze_prim->set_axis(axis_vec);
  auto squeeze_cnode = func_graph->NewCNode(squeeze_prim, {input_node});
  MS_CHECK_TRUE_RET(squeeze_cnode != nullptr, nullptr);
  if (input_node->abstract() != nullptr) {
    squeeze_cnode->set_abstract(input_node->abstract()->Clone());
  }
  squeeze_cnode->set_fullname_with_scope("squeeze_" + input_node->fullname_with_scope());
  return squeeze_cnode;
}

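// Entry point of the fusion: validates the matched TensorListStack/TupleGetItem/While chain,
// checks the cond and body subgraphs against the LSTM patterns, reads the zoneout factors,
// creates the fused LSTM node and its output getitems, and returns a Squeeze of output 0 to
// replace the original TensorListStack node.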
const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }

  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  auto tensor_list_stack_cnode = utils::cast<CNodePtr>(node);
  auto tuple_get_item_node = tensor_list_stack_cnode->input(1);
  if (!utils::isa<CNodePtr>(tuple_get_item_node)) {
    return nullptr;
  }
  auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item_node);
  auto while_node = tuple_get_item_cnode->input(1);
  if (!utils::isa<CNodePtr>(while_node)) {
    return nullptr;
  }
  auto while_cnode = utils::cast<CNodePtr>(while_node);

  if (while_cnode == nullptr || while_cnode->size() != while_inputs_num_) {
    return nullptr;
  }
  if (!CheckReferencedOutputs(func_graph, while_cnode)) {
    return nullptr;
  }
  auto primitive_vars_cond = std::make_shared<PrimitiveVarMap>();
  MS_CHECK_TRUE_RET(primitive_vars_cond != nullptr, nullptr);
  auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond);
  MS_CHECK_TRUE_RET(cond_graph_pattern != nullptr, nullptr);
  auto cond_equiv =
    CheckSubGraph(cond_graph_pattern, primitive_vars_cond, while_cnode->input(1), cond_cnodes_num_, cond_nodes_num_);
  if (cond_equiv == nullptr || cond_equiv->empty()) {
    return nullptr;
  }
  auto primitive_vars_body = std::make_shared<PrimitiveVarMap>();
  MS_CHECK_TRUE_RET(primitive_vars_body != nullptr, nullptr);
  auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body);
  MS_CHECK_TRUE_RET(body_graph_pattern != nullptr, nullptr);
  auto body_equiv =
    CheckSubGraph(body_graph_pattern, primitive_vars_body, while_cnode->input(2), body_cnodes_num_, body_nodes_num_);
  if (body_equiv == nullptr || body_equiv->empty()) {
    return nullptr;
  }
  float zoneout_cell = 0.0f;
  float zoneout_hidden = 0.0f;
  if (!CheckBodyGraph(body_equiv, &zoneout_cell, &zoneout_hidden)) {
    return nullptr;
  }
  const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope();
  auto lstm_node = CreateLSTMNode(func_graph, equiv, body_equiv, lstm_name, zoneout_cell, zoneout_hidden);
  if (lstm_node == nullptr) {
    return nullptr;
  }
  auto status = SetAbstractTuple(lstm_node, kLSTMOutputNum);
  if (status != RET_OK) {
    return nullptr;
  }

  auto get_item_node = CreateOutputGetItem(func_graph, lstm_node, 0);
  if (get_item_node == nullptr) {
    MS_LOG(DEBUG) << "create lstm output get_item node failed";
    return nullptr;
  }

  status = AdjustOtherGetItems(func_graph, while_cnode, lstm_node, tuple_get_item_cnode);
  if (status != RET_OK) {
    return nullptr;
  }

  std::vector<int> squeeze_axis{1};  // our LSTM output 0 has an extra axis that tflite does not have; squeeze it
  auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis);
  if (squeeze_node == nullptr) {
    return nullptr;
  }

  auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1);
  MS_CHECK_TRUE_RET(cond_cnode_index_pair != nullptr, nullptr);
  func_graph->DropFuncGraphCNodeIndex(cond_cnode_index_pair);
  auto body_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 2);
  MS_CHECK_TRUE_RET(body_cnode_index_pair != nullptr, nullptr);
  func_graph->DropFuncGraphCNodeIndex(body_cnode_index_pair);
  MS_LOG(INFO) << "lstm node:" << lstm_node->fullname_with_scope() << " fusion success";
  return squeeze_node;
}
}  // namespace opt
}  // namespace mindspore