• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
17 #include <memory>
18 #include <algorithm>
19 #include <functional>
20 #include "ops/lstm.h"
21 #include "ops/squeeze.h"
22 #include "tools/converter/ops/ops_def.h"
23 #include "src/common/utils.h"
24 #include "tools/common/tensor_util.h"
25 #include "utils/utils.h"
26 #include "tools/optimizer/common/gllo_utils.h"
27 #include "securec/include/securec.h"
28 #include "nnacl/op_base.h"
29 
30 namespace mindspore {
31 namespace opt {
32 namespace {
// Default for while_inputs_num_ when the constructor receives 0.
constexpr size_t kWhileInputsLength = 23;
// Default for while_input_var_num_ when the constructor receives 0; Init() creates this many operand Vars.
constexpr size_t kWhileInputsVarNum = 21;
// Defaults for the expected node / cnode counts of the while condition sub-graph.
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
// Defaults for the expected node / cnode counts of the while body sub-graph.
constexpr size_t kBodyNodesNum = 95;
constexpr size_t kBodyCNodesNum = 34;
// NOTE(review): not referenced in this view; presumably the output count of the fused LSTM node.
constexpr size_t kLSTMOutputNum = 3;
// Number of gates concatenated per weight / bias tensor (see GetConcatedParam).
constexpr int kUnidirectionalGateNum = 4;
// Shorthand placeholder used with std::bind(IsOpType, p1, <primitive>) in the pattern builders.
const auto &p1 = std::placeholders::_1;
// Tolerance used when comparing zoneout coefficients.
constexpr float EPSILON = 1e-5;
IsParameterNode(const BaseRef & n)43 bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
44 
GenerateBodyGraphCellPattern(const std::vector<CondVarPtr> & placeholders)45 std::vector<VectorRef> GenerateBodyGraphCellPattern(const std::vector<CondVarPtr> &placeholders) {
46   MS_CHECK_TRUE_RET(placeholders.size() > 19, {});
47   auto is_var1 = std::make_shared<Var>();
48   MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
49   VectorRef concat_i_w = VectorRef({is_var1, placeholders[8], placeholders[12]});
50   auto is_var2 = std::make_shared<Var>();
51   MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
52   VectorRef concat_f_w = VectorRef({is_var2, placeholders[9], placeholders[13]});
53   auto is_var3 = std::make_shared<Var>();
54   MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
55   VectorRef concat_c_w = VectorRef({is_var3, placeholders[10], placeholders[14]});
56   auto is_var4 = std::make_shared<Var>();
57   MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
58   VectorRef concat_o_w = VectorRef({is_var4, placeholders[11], placeholders[15]});
59 
60   auto is_var_getitem = std::make_shared<Var>("GetItem");
61   MS_CHECK_TRUE_RET(is_var_getitem != nullptr, {});
62   auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
63   MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
64   VectorRef get_item = VectorRef({is_var_getitem, placeholders[7], placeholders[2], is_param3});
65   auto is_var5 = std::make_shared<Var>();
66   MS_CHECK_TRUE_RET(is_var5 != nullptr, {});
67   VectorRef concat_input_h = VectorRef({is_var5, get_item, placeholders[5]});
68 
69   auto is_var6 = std::make_shared<Var>();
70   MS_CHECK_TRUE_RET(is_var6 != nullptr, {});
71   VectorRef matmul_input = VectorRef({is_var6, concat_input_h, concat_i_w});
72   auto is_var7 = std::make_shared<Var>();
73   MS_CHECK_TRUE_RET(is_var7 != nullptr, {});
74   VectorRef matmul_forget = VectorRef({is_var7, concat_input_h, concat_f_w});
75   auto is_var8 = std::make_shared<Var>();
76   MS_CHECK_TRUE_RET(is_var8 != nullptr, {});
77   VectorRef matmul_cell = VectorRef({is_var8, concat_input_h, concat_c_w});
78   auto is_var9 = std::make_shared<Var>();
79   MS_CHECK_TRUE_RET(is_var9 != nullptr, {});
80   VectorRef matmul_output = VectorRef({is_var9, concat_input_h, concat_o_w});
81 
82   auto is_var10 = std::make_shared<Var>();
83   MS_CHECK_TRUE_RET(is_var10 != nullptr, {});
84   VectorRef bias_input = VectorRef({is_var10, matmul_input, placeholders[16]});
85   auto is_var11 = std::make_shared<Var>();
86   MS_CHECK_TRUE_RET(is_var11 != nullptr, {});
87   VectorRef bias_forget = VectorRef({is_var11, matmul_forget, placeholders[17]});
88   auto is_var12 = std::make_shared<Var>();
89   MS_CHECK_TRUE_RET(is_var12 != nullptr, {});
90   VectorRef bias_cell = VectorRef({is_var12, matmul_cell, placeholders[18]});
91   auto is_var13 = std::make_shared<Var>();
92   MS_CHECK_TRUE_RET(is_var13 != nullptr, {});
93   VectorRef bias_output = VectorRef({is_var13, matmul_output, placeholders[19]});
94 
95   auto is_var_tanh = std::make_shared<Var>("Tanh");
96   MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
97   VectorRef cell = VectorRef({is_var_tanh, bias_cell});
98   auto is_var_sigmoid1 = std::make_shared<Var>("Sigmoid");
99   MS_CHECK_TRUE_RET(is_var_sigmoid1 != nullptr, {});
100   VectorRef input_gate = VectorRef({is_var_sigmoid1, bias_input});
101   auto is_var_mul1 = std::make_shared<Var>("Mul");
102   MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
103   VectorRef cell_input = VectorRef({is_var_mul1, input_gate, cell});
104   auto is_var_sigmoid2 = std::make_shared<Var>("Sigmoid");
105   MS_CHECK_TRUE_RET(is_var_sigmoid2 != nullptr, {});
106   VectorRef forget_gate = VectorRef({is_var_sigmoid2, bias_forget});
107   auto is_var_mul2 = std::make_shared<Var>("Mul");
108   MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
109   VectorRef cell_forgeted = VectorRef({is_var_mul2, forget_gate, placeholders[4]});
110   auto is_var_add = std::make_shared<Var>("Add");
111   MS_CHECK_TRUE_RET(is_var_add != nullptr, {});
112   VectorRef cell_new = VectorRef({is_var_add, cell_forgeted, cell_input});
113   return {bias_output, cell_new};
114 }
115 }  // namespace
116 
GetFloatScalarFromTensorInfo(const AnfNodePtr & tensor_info,float * v)117 STATUS TfliteLstmCellFusion::GetFloatScalarFromTensorInfo(const AnfNodePtr &tensor_info, float *v) {
118   if (tensor_info == nullptr || v == nullptr) {
119     MS_LOG(ERROR) << "tensor_info or v is nullptr";
120     return RET_ERROR;
121   }
122   if (!utils::isa<ParameterPtr>(tensor_info)) {
123     MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
124     return RET_ERROR;
125   }
126   auto param_ptr = utils::cast<ParameterPtr>(tensor_info);
127   if (!param_ptr->has_default() || param_ptr->default_param() == nullptr) {
128     MS_LOG(DEBUG) << "param not have default";
129     return RET_ERROR;
130   }
131   auto default_param = param_ptr->default_param();
132   if (!utils::isa<tensor::TensorPtr>(default_param)) {
133     MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
134     return RET_ERROR;
135   }
136   auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
137   auto tensor_shape = default_param_ptr->shape();
138   if (!(tensor_shape.empty() || (tensor_shape.size() == 1 && tensor_shape[0] == 1))) {
139     MS_LOG(DEBUG) << "default param is not scalar";
140     return RET_ERROR;
141   }
142   if (default_param_ptr->data_type() != kNumberTypeFloat32 && default_param_ptr->data_type() != kNumberTypeFloat) {
143     MS_LOG(DEBUG) << "default param is not float";
144     return RET_ERROR;
145   }
146   *v = *(reinterpret_cast<float *>(default_param_ptr->data_c()));
147   return RET_OK;
148 }
149 
Init() const150 bool TfliteLstmCellFusion::Init() const {
151   for (size_t i = 0; i < this->while_input_var_num_; ++i) {
152     auto is_var = std::make_shared<Var>();
153     MS_CHECK_TRUE_RET(is_var != nullptr, false);
154     while_input_vars_.emplace_back(is_var);
155   }
156   cell_zoneout_old_ = std::make_shared<Var>();
157   MS_CHECK_TRUE_RET(cell_zoneout_old_ != nullptr, false);
158   cell_zoneout_new_ = std::make_shared<Var>();
159   MS_CHECK_TRUE_RET(cell_zoneout_new_ != nullptr, false);
160   hidden_zoneout_old_ = std::make_shared<Var>();
161   MS_CHECK_TRUE_RET(hidden_zoneout_old_ != nullptr, false);
162   hidden_zoneout_new_ = std::make_shared<Var>();
163   MS_CHECK_TRUE_RET(hidden_zoneout_new_ != nullptr, false);
164   return true;
165 }
166 
TfliteLstmCellFusion(const std::string & name,bool multigraph,int input_length,int var_num,int cond_nodes_num,int cond_cnodes_num,int body_nodes_num,int body_cnodes_num)167 TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num,
168                                            int cond_nodes_num, int cond_cnodes_num, int body_nodes_num,
169                                            int body_cnodes_num)
170     : PatternProcessPass(name, multigraph) {
171   /*
172    * input vars for lstm while node
173    * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
174    * 9:i2i_  10:i2f_ 11:i2c_ 12:i2o_   13:c2i_   14:c2f_ 15:c2c_   16:c2o_   17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
175    */
176   this->while_inputs_num_ = input_length == 0 ? kWhileInputsLength : input_length;
177   this->while_input_var_num_ = var_num == 0 ? kWhileInputsVarNum : var_num;
178   this->cond_nodes_num_ = cond_nodes_num == 0 ? kCondNodesNum : cond_nodes_num;
179   this->cond_cnodes_num_ = cond_cnodes_num == 0 ? kCondCNodesNum : cond_cnodes_num;
180   this->body_nodes_num_ = body_nodes_num == 0 ? kBodyNodesNum : body_nodes_num;
181   this->body_cnodes_num_ = body_cnodes_num == 0 ? kBodyCNodesNum : body_cnodes_num;
182 }
183 
GetCondGraphPattern(const PrimitiveVarMapPtr & primitive_vars)184 AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) {
185   MS_ASSERT(primitive_vars != nullptr);
186   auto is_parameter1 = std::make_shared<CondVar>(IsParameterNode);
187   MS_CHECK_TRUE_RET(is_parameter1 != nullptr, nullptr);
188   auto is_parameter2 = std::make_shared<CondVar>(IsParameterNode);
189   MS_CHECK_TRUE_RET(is_parameter2 != nullptr, nullptr);
190   auto is_parameter3 = std::make_shared<CondVar>(IsParameterNode);
191   MS_CHECK_TRUE_RET(is_parameter3 != nullptr, nullptr);
192   auto is_parameter4 = std::make_shared<CondVar>(IsParameterNode);
193   MS_CHECK_TRUE_RET(is_parameter4 != nullptr, nullptr);
194   auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
195   MS_CHECK_TRUE_RET(is_less1 != nullptr, nullptr);
196   auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
197   MS_CHECK_TRUE_RET(is_less2 != nullptr, nullptr);
198   auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
199   MS_CHECK_TRUE_RET(is_logical_and != nullptr, nullptr);
200   auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
201   MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
202   VectorRef less1_ref = VectorRef({is_less1, is_parameter1, is_parameter2});
203   VectorRef less2_ref = VectorRef({is_less2, is_parameter3, is_parameter4});
204   VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
205   VectorRef return_ref = VectorRef({is_return, logicaland_ref});
206   VarPtr fg = std::make_shared<Var>("RootG");
207   MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
208   auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true);
209   return pattern;
210 }
211 
GetBodyGraphPattern(const PrimitiveVarMapPtr & primitive_vars) const212 AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
213   std::vector<CondVarPtr> placeholders;
214   for (int i = 0; i < 20; ++i) {
215     placeholders.emplace_back(std::make_shared<CondVar>(IsParameterNode));
216   }
217   auto is_var1 = std::make_shared<Var>();
218   MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
219   auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
220   MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
221   VectorRef add2 = VectorRef({is_var1, placeholders[2], is_param1});
222   auto is_var2 = std::make_shared<Var>();
223   MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
224   auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
225   MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
226   VectorRef add3 = VectorRef({is_var2, placeholders[0], is_param2});
227 
228   auto hidden_cells = GenerateBodyGraphCellPattern(placeholders);
229   MS_CHECK_TRUE_RET(hidden_cells.size() == kInputSizeTwo, {});
230   auto is_var_mul1 = std::make_shared<Var>("Mul");
231   MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
232   VectorRef zoneout_cell_old = VectorRef({is_var_mul1, cell_zoneout_old_, placeholders[4]});
233   auto is_var_mul2 = std::make_shared<Var>("Mul");
234   MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
235   auto cell_new = hidden_cells[1];
236   MS_CHECK_TRUE_RET(!cell_new.empty(), {});
237   VectorRef zoneout_cell_new = VectorRef({is_var_mul2, cell_zoneout_new_, cell_new});
238   auto is_var_add1 = std::make_shared<Var>("Add");
239   MS_CHECK_TRUE_RET(is_var_add1 != nullptr, {});
240   VectorRef cell_output = VectorRef({is_var_add1, zoneout_cell_new, zoneout_cell_old});
241 
242   auto is_var_sigmoid = std::make_shared<Var>("Sigmoid");
243   MS_CHECK_TRUE_RET(is_var_sigmoid != nullptr, {});
244   auto bias_output = hidden_cells[0];
245   MS_CHECK_TRUE_RET(!bias_output.empty(), {});
246   VectorRef output_gate = VectorRef({is_var_sigmoid, bias_output});
247   auto is_var_tanh = std::make_shared<Var>("Tanh");
248   MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
249   VectorRef cell_to_output = VectorRef({is_var_tanh, cell_new});
250   auto is_var_mul3 = std::make_shared<Var>("Mul");
251   MS_CHECK_TRUE_RET(is_var_mul3 != nullptr, {});
252   VectorRef output = VectorRef({is_var_mul3, output_gate, cell_to_output});
253 
254   auto is_var_mul4 = std::make_shared<Var>("Mul");
255   MS_CHECK_TRUE_RET(is_var_mul4 != nullptr, {});
256   VectorRef zoneout_hidden_old = VectorRef({is_var_mul4, hidden_zoneout_old_, placeholders[5]});
257   auto is_var_mul5 = std::make_shared<Var>("Mul");
258   MS_CHECK_TRUE_RET(is_var_mul5 != nullptr, {});
259   VectorRef zoneout_hidden_new = VectorRef({is_var_mul5, hidden_zoneout_new_, output});
260   auto is_var_add2 = std::make_shared<Var>("Add");
261   MS_CHECK_TRUE_RET(is_var_add2 != nullptr, {});
262   VectorRef hidden_output = VectorRef({is_var_add2, zoneout_hidden_new, zoneout_hidden_old});
263 
264   auto is_var_setitem = std::make_shared<Var>("SetItem");
265   MS_CHECK_TRUE_RET(is_var_setitem != nullptr, {});
266   VectorRef set_item = VectorRef({is_var_setitem, placeholders[3], placeholders[2], output});
267 
268   auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
269   MS_CHECK_TRUE_RET(is_make_tuple != nullptr, nullptr);
270   std::vector<BaseRef> outputs = {is_make_tuple, add3, placeholders[1], add2, set_item, cell_output, hidden_output};
271   outputs.insert(outputs.end(), placeholders.begin() + 6, placeholders.end());
272   VectorRef make_tuple_node = VectorRef(outputs);
273   auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
274   MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
275   VectorRef return_node = VectorRef({is_return, make_tuple_node});
276 
277   VarPtr fg = std::make_shared<Var>("RootG");
278   MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
279   auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true);
280   return pattern;
281 }
282 
DefinePattern() const283 const BaseRef TfliteLstmCellFusion::DefinePattern() const {
284   if (!Init()) {
285     MS_LOG(ERROR) << "initial member failed.";
286     return {};
287   }
288   auto is_while_node = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimWhile));
289   MS_CHECK_TRUE_RET(is_while_node != nullptr, {});
290   VectorRef while_node = VectorRef({is_while_node});
291   auto while_inputs = while_input_vars_;
292   MS_CHECK_TRUE_RET(while_inputs.size() > kInputSizeThree, {});
293   while_inputs.insert(while_inputs.begin() + 4, while_input_vars_[2]);
294   while_node.insert(while_node.end(), while_inputs.begin(), while_inputs.end());
295 
296   auto is_tuple_get_item = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem));
297   MS_CHECK_TRUE_RET(is_tuple_get_item != nullptr, {});
298   auto is_var = std::make_shared<Var>();
299   MS_CHECK_TRUE_RET(is_var != nullptr, {});
300   VectorRef while_output = VectorRef({is_tuple_get_item, while_node, is_var});
301 
302   auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListStack));
303   MS_CHECK_TRUE_RET(is_tensor_list_stack != nullptr, {});
304   auto is_parameter = std::make_shared<CondVar>(IsParameterNode);
305   MS_CHECK_TRUE_RET(is_parameter != nullptr, {});
306   VectorRef tensor_list_stack_node = VectorRef({is_tensor_list_stack, while_output, is_parameter});
307 
308   return tensor_list_stack_node;
309 }
310 
MatchGraph(const FuncGraphPtr & func_graph,const PrimitiveVarMapPtr & primitive_vars,const AnfNodePtr & pattern)311 EquivPtr TfliteLstmCellFusion::MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars,
312                                           const AnfNodePtr &pattern) {
313   MS_ASSERT(func_graph != nullptr);
314   MS_ASSERT(pattern != nullptr);
315   auto return_node = func_graph->get_return();
316   auto visitor = std::make_shared<Visitor>();
317   MS_CHECK_TRUE_RET(visitor != nullptr, nullptr);
318   PatternEngine pattern_engine(visitor);
319   auto empty_equiv = std::make_shared<Equiv>();
320   MS_CHECK_TRUE_RET(empty_equiv != nullptr, nullptr);
321   EquivPtr equiv = pattern_engine.Match(pattern, return_node, *primitive_vars, empty_equiv);
322   return equiv;
323 }
324 
325 // make sure that only 3,4,5 output of while is referenced
CheckReferencedOutputs(const FuncGraphPtr & func_graph,const CNodePtr & while_cnode)326 bool TfliteLstmCellFusion::CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode) {
327   MS_ASSERT(func_graph != nullptr);
328   MS_ASSERT(while_cnode != nullptr);
329   auto manager = func_graph->manager();
330   if (manager == nullptr) {
331     MS_LOG(ERROR) << "manager is nullptr";
332     return false;
333   }
334   auto while_node_users = manager->node_users()[while_cnode];
335   std::vector<size_t> valid_indexes{3, 4, 5};
336   for (auto &node_user : while_node_users) {
337     if (!utils::isa<CNodePtr>(node_user.first)) {
338       return false;
339     }
340     auto cnode = utils::cast<CNodePtr>(node_user.first);
341     if (IsMarkedTrainOp(cnode)) {
342       return false;
343     }
344     if (!CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
345       return false;
346     }
347     auto index = GetTupleGetItemOutIndex(cnode);
348     if (!lite::IsContain(valid_indexes, index)) {
349       return false;
350     }
351   }
352   return true;
353 }
354 
CheckSubGraph(const AnfNodePtr & pattern,const PrimitiveVarMapPtr & primitive_vars,const AnfNodePtr & anf_sub_graph,const size_t cnode_num,const size_t all_node_num)355 EquivPtr TfliteLstmCellFusion::CheckSubGraph(const AnfNodePtr &pattern, const PrimitiveVarMapPtr &primitive_vars,
356                                              const AnfNodePtr &anf_sub_graph, const size_t cnode_num,
357                                              const size_t all_node_num) {
358   MS_ASSERT(func_graph != nullptr);
359   MS_ASSERT(pattern != nullptr);
360   MS_ASSERT(primitive_vars != nullptr);
361   MS_ASSERT(anf_sub_graph != nullptr);
362   auto sub_graph = GetValueNode<FuncGraphPtr>(anf_sub_graph);
363   MS_CHECK_TRUE_RET(sub_graph != nullptr, nullptr);
364   auto nodes = TopoSort(sub_graph->get_return());
365   auto cnodes = sub_graph->GetOrderedCnodes();
366   if (cnodes.size() != cnode_num || nodes.size() != all_node_num) {
367     MS_LOG(DEBUG) << "sub graph nodes num not match";
368     return nullptr;
369   }
370   return MatchGraph(sub_graph, primitive_vars, pattern);
371 }
372 
CheckBodyGraph(const EquivPtr & equiv,float * zoneout_cell,float * zoneout_hidden) const373 bool TfliteLstmCellFusion::CheckBodyGraph(const EquivPtr &equiv, float *zoneout_cell, float *zoneout_hidden) const {
374   MS_ASSERT(func_graph != nullptr);
375   MS_ASSERT(equiv != nullptr);
376   MS_ASSERT(while_cnode != nullptr);
377   MS_ASSERT(zoneout_cell != nullptr);
378   MS_ASSERT(zoneout_hidden != nullptr);
379 
380   auto cell_zoneout_old_node = utils::cast<AnfNodePtr>((*equiv)[cell_zoneout_old_]);
381   MS_ASSERT(cell_zoneout_old_node != nullptr);
382   auto cell_zoneout_new_node = utils::cast<AnfNodePtr>((*equiv)[cell_zoneout_new_]);
383   MS_ASSERT(cell_zoneout_new_node != nullptr);
384   auto hidden_zoneout_old_node = utils::cast<AnfNodePtr>((*equiv)[hidden_zoneout_old_]);
385   MS_ASSERT(hidden_zoneout_old_node != nullptr);
386   auto hidden_zoneout_new_node = utils::cast<AnfNodePtr>((*equiv)[hidden_zoneout_new_]);
387   MS_ASSERT(hidden_zoneout_new_node != nullptr);
388 
389   float cell_old, cell_new, hidden_old, hidden_new;
390   if (GetFloatScalarFromTensorInfo(cell_zoneout_old_node, &cell_old) != RET_OK) {
391     return false;
392   }
393   if (GetFloatScalarFromTensorInfo(cell_zoneout_new_node, &cell_new) != RET_OK) {
394     return false;
395   }
396   if (GetFloatScalarFromTensorInfo(hidden_zoneout_old_node, &hidden_old) != RET_OK) {
397     return false;
398   }
399   if (GetFloatScalarFromTensorInfo(hidden_zoneout_new_node, &hidden_new) != RET_OK) {
400     return false;
401   }
402   if (cell_old < 0.0f || cell_old > 1.0f || cell_new < 0.0f || cell_new > 1.0f) {
403     MS_LOG(DEBUG) << "cell zoneout value illegal";
404     return false;
405   }
406   if (hidden_old < 0.0f || hidden_old > 1.0f || hidden_new < 0.0f || hidden_new > 1.0f) {
407     MS_LOG(DEBUG) << "hidden zoneout value illegal";
408     return false;
409   }
410   if (std::abs(cell_old + cell_new - 1.0f) > EPSILON || std::abs(hidden_old + hidden_new - 1.0f) > EPSILON ||
411       std::abs(cell_old - hidden_old) > EPSILON) {
412     MS_LOG(DEBUG) << "zoneout value illegal";
413     return false;
414   }
415   *zoneout_cell = cell_old;
416   *zoneout_hidden = hidden_old;
417   return true;
418 }
419 
GetConcatedParam(const std::vector<AnfNodePtr> & params,const ParameterPtr & new_param,bool is_bias)420 STATUS TfliteLstmCellFusion::GetConcatedParam(const std::vector<AnfNodePtr> &params, const ParameterPtr &new_param,
421                                               bool is_bias) {
422   MS_ASSERT(new_param != nullptr);
423   MS_ASSERT(params.size() == 4);
424   std::vector<float *> data_ptrs;
425   std::vector<std::vector<int64_t>> data_shapes;
426   for (auto &param : params) {
427     if (!utils::isa<ParameterPtr>(param)) {
428       MS_LOG(DEBUG) << "param is not Parameter node";
429       return RET_FAILED;
430     }
431     auto param_t = utils::cast<ParameterPtr>(param);
432     if (!param_t->has_default() || param_t->default_param() == nullptr) {
433       MS_LOG(DEBUG) << "param not have default value";
434       return RET_FAILED;
435     }
436     if (!utils::isa<tensor::TensorPtr>(param_t->default_param())) {
437       MS_LOG(DEBUG) << "default value is not tensor::Tensor";
438       return RET_FAILED;
439     }
440     auto origin_tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_t->default_param());
441     if (origin_tensor->data_type() != kNumberTypeFloat32 && origin_tensor->data_type() != kNumberTypeFloat) {
442       MS_LOG(DEBUG) << "origin_tensor is not float32 type";
443       return RET_FAILED;
444     }
445     auto data_ptr = reinterpret_cast<float *>(origin_tensor->data_c());
446     auto data_shape = origin_tensor->shape();
447     data_ptrs.push_back(data_ptr);
448     data_shapes.push_back(data_shape);
449   }
450 
451   for (size_t i = 1; i < data_shapes.size(); ++i) {
452     if (data_shapes[i] != data_shapes[0]) {
453       MS_LOG(DEBUG) << "data shape not same";
454       return RET_FAILED;
455     }
456   }
457   std::vector<int64_t> new_shape;
458   int step = 0;
459   int data_size = 0;
460   MS_ASSERT(!data_shapes.empty());
461   if (is_bias) {
462     if (data_shapes[0].size() != 1) {
463       MS_LOG(ERROR) << "bias data shape error";
464       return RET_ERROR;
465     }
466     step = static_cast<int>(data_shapes[0][0]);
467     MS_CHECK_INT_MUL_NOT_OVERFLOW(C8NUM, step, RET_ERROR);
468     data_size = C8NUM * step;
469     new_shape = std::vector<int64_t>({1, data_size});
470 
471   } else {
472     if (data_shapes[0].size() != 2) {
473       MS_LOG(ERROR) << "weight data shape error";
474       return RET_ERROR;
475     }
476     new_shape = std::vector<int64_t>({1, data_shapes[0][0] * kUnidirectionalGateNum, data_shapes[0][1]});
477     MS_CHECK_INT_MUL_NOT_OVERFLOW(data_shapes[0][0], data_shapes[0][1], RET_ERROR);
478     step = static_cast<int>(data_shapes[0][0] * data_shapes[0][1]);
479     MS_CHECK_INT_MUL_NOT_OVERFLOW(C4NUM, step, RET_ERROR);
480     data_size = C4NUM * step;
481   }
482 
483   auto tensor_info = lite::CreateTensorInfo(nullptr, 0, new_shape, kNumberTypeFloat32);
484   if (tensor_info == nullptr) {
485     MS_LOG(ERROR) << "create tensor info failed.";
486     return RET_ERROR;
487   }
488 
489   auto tensor_data = static_cast<float *>(tensor_info->data_c());
490   for (int i = 0; i < data_size; ++i) {  // bias are stored into first 4*hidden_size buffer, the rest is all 0
491     tensor_data[i] = 0.0f;
492   }
493 
494   for (size_t i = 0; i < data_ptrs.size(); ++i) {
495     auto source_len = std::accumulate(data_shapes[i].begin(), data_shapes[i].end(), 1, std::multiplies<int>());
496     auto ret = memcpy_s(tensor_data + i * step, step * sizeof(float), data_ptrs[i], source_len * sizeof(float));
497     if (ret != EOK) {
498       MS_LOG(ERROR) << "memcpy_s error";
499       return RET_ERROR;
500     }
501   }
502 
503   auto status = lite::InitParameterFromTensorInfo(new_param, tensor_info);
504   if (status != RET_OK) {
505     MS_LOG(ERROR) << "init parameter from tensor info failed";
506     return RET_ERROR;
507   }
508 
509   return RET_OK;
510 }
511 
CreateLSTMNode(const FuncGraphPtr & func_graph,const EquivPtr & equiv,const EquivPtr & body_equiv,const std::string & base_name,const float zoneout_cell,const float zoneout_hidden) const512 CNodePtr TfliteLstmCellFusion::CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
513                                               const EquivPtr &body_equiv, const std::string &base_name,
514                                               const float zoneout_cell, const float zoneout_hidden) const {
515   MS_ASSERT(func_graph != nullptr);
516   MS_ASSERT(equiv != nullptr);
517   MS_ASSERT(body_equiv != nullptr);
518   /*
519    * input vars for while node
520    * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
521    * 9:i2i_  10:i2f_ 11:i2c_ 12:i2o_   13:c2i_   14:c2f_ 15:c2c_   16:c2o_   17:i_bias_ 18:f_bias_ 19:c_bias_ 20:o_bias_
522    */
523   auto lstm_prim = std::make_shared<ops::LSTM>();
524   MS_CHECK_TRUE_RET(lstm_prim != nullptr, nullptr);
525   lstm_prim->set_bidirectional(false);
526   lstm_prim->set_zoneout_cell(zoneout_cell);
527   lstm_prim->set_zoneout_hidden(zoneout_hidden);
528   auto value_node = NewValueNode(lstm_prim);
529   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
530 
531   auto &vars = while_input_vars_;
532   auto i2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[9]]);
533   MS_ASSERT(i2i_weight);
534   auto i2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[10]]);
535   MS_ASSERT(i2f_weight);
536   auto i2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[11]]);
537   MS_ASSERT(i2c_weight);
538   auto i2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[12]]);
539   MS_ASSERT(i2o_weight);
540 
541   auto c2i_weight = utils::cast<AnfNodePtr>((*equiv)[vars[13]]);
542   MS_ASSERT(c2i_weight);
543   auto c2f_weight = utils::cast<AnfNodePtr>((*equiv)[vars[14]]);
544   MS_ASSERT(c2f_weight);
545   auto c2c_weight = utils::cast<AnfNodePtr>((*equiv)[vars[15]]);
546   MS_ASSERT(c2c_weight);
547   auto c2o_weight = utils::cast<AnfNodePtr>((*equiv)[vars[16]]);
548   MS_ASSERT(c2o_weight);
549 
550   auto i_bias = utils::cast<AnfNodePtr>((*equiv)[vars[17]]);
551   MS_ASSERT(i_bias);
552   auto f_bias = utils::cast<AnfNodePtr>((*equiv)[vars[18]]);
553   MS_ASSERT(f_bias);
554   auto c_bias = utils::cast<AnfNodePtr>((*equiv)[vars[19]]);
555   MS_ASSERT(c_bias);
556   auto o_bias = utils::cast<AnfNodePtr>((*equiv)[vars[20]]);
557   MS_ASSERT(o_bias);
558 
559   auto input = utils::cast<AnfNodePtr>((*equiv)[vars[8]]);
560   MS_ASSERT(input);
561   auto cell = utils::cast<AnfNodePtr>((*equiv)[vars[5]]);
562   MS_ASSERT(cell);
563   auto hidden = utils::cast<AnfNodePtr>((*equiv)[vars[6]]);
564   MS_ASSERT(hidden);
565 
566   std::vector<AnfNodePtr> i_weights{i2i_weight, i2o_weight, i2f_weight, i2c_weight};
567   auto i_weight = func_graph->add_parameter();
568   MS_CHECK_TRUE_RET(i_weight != nullptr, nullptr);
569   auto status = GetConcatedParam(i_weights, i_weight, false);
570   if (status != RET_OK) {
571     return nullptr;
572   }
573   i_weight->set_name(base_name + "_weight_i");
574 
575   std::vector<AnfNodePtr> c_weights{c2i_weight, c2o_weight, c2f_weight, c2c_weight};
576   auto c_weight = func_graph->add_parameter();
577   MS_CHECK_TRUE_RET(c_weight != nullptr, nullptr);
578   status = GetConcatedParam(c_weights, c_weight, false);
579   if (status != RET_OK) {
580     return nullptr;
581   }
582   c_weight->set_name(base_name + "_weight_c");
583 
584   std::vector<AnfNodePtr> biases{i_bias, o_bias, f_bias, c_bias};
585   auto bias = func_graph->add_parameter();
586   MS_CHECK_TRUE_RET(bias != nullptr, nullptr);
587   status = GetConcatedParam(biases, bias, true);
588   if (status != RET_OK) {
589     return nullptr;
590   }
591   bias->set_name(base_name + "_bias");
592 
593   if (!utils::isa<CNodePtr>(input) || !CheckPrimitiveType(input, prim::kPrimTensorListFromTensor)) {
594     MS_LOG(DEBUG) << "input is not tensorlistfromtensor op";
595     return nullptr;
596   }
597   auto tensor_list_cnode = utils::cast<CNodePtr>(input);
598   auto input_tensor_node = tensor_list_cnode->input(1);
599 
600   std::vector<AnfNodePtr> new_node_inputs = {value_node, input_tensor_node, i_weight, c_weight, bias, hidden, cell};
601   auto new_node = func_graph->NewCNode(new_node_inputs);
602   MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
603   new_node->set_fullname_with_scope(base_name);
604   return new_node;
605 }
606 
// Creates TupleGetItem(node, item_index) in func_graph with a float32 tensor abstract.
// The new cnode is named "<node full name>_output_getitem_<item_index>".
// Returns nullptr if the primitive, index value node, cnode or abstract cannot be created.
CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                                   const int item_index) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto getitem_prim = std::make_shared<lite::TupleGetItem>();
  auto index_node = NewValueNode(MakeValue<int>(item_index));
  if (getitem_prim == nullptr || index_node == nullptr) {
    MS_LOG(ERROR) << "NewValueNode is nullptr";
    return nullptr;
  }
  auto getitem_cnode = func_graph->NewCNode(getitem_prim, {node, index_node});
  MS_CHECK_TRUE_RET(getitem_cnode != nullptr, nullptr);

  auto tensor_abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
  if (tensor_abstract == nullptr) {
    MS_LOG(ERROR) << "Create tensor abstarct failed";
    return nullptr;
  }
  getitem_cnode->set_abstract(tensor_abstract);

  const std::string scoped_name = node->fullname_with_scope() + "_output_getitem_" + std::to_string(item_index);
  getitem_cnode->set_fullname_with_scope(scoped_name);
  return getitem_cnode;
}
629 
AdjustOtherGetItems(const FuncGraphPtr & func_graph,const CNodePtr & while_cnode,const CNodePtr & lstm_cnode,const CNodePtr & output_get_item)630 STATUS TfliteLstmCellFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode,
631                                                  const CNodePtr &lstm_cnode, const CNodePtr &output_get_item) {
632   MS_ASSERT(func_graph != nullptr && while_cnode != nullptr);
633   MS_ASSERT(lstm_cnode != nullptr && output_get_item != nullptr);
634   auto manager = func_graph->manager();
635   if (manager == nullptr) {
636     MS_LOG(ERROR) << "manager is nullptr";
637     return RET_ERROR;
638   }
639   auto while_node_users = manager->node_users()[while_cnode];
640   for (auto &node_user : while_node_users) {
641     if (node_user.first == output_get_item) {
642       continue;
643     }
644     if (!utils::isa<CNodePtr>(node_user.first)) {
645       return RET_ERROR;
646     }
647     auto get_item = utils::cast<CNodePtr>(node_user.first);
648     if (!CheckPrimitiveType(get_item, prim::kPrimTupleGetItem)) {
649       return RET_ERROR;
650     }
651     auto new_inputs = get_item->inputs();
652     if (new_inputs.size() != 3) {
653       return RET_ERROR;
654     }
655     new_inputs[1] = lstm_cnode;
656     auto index_vnode = get_item->input(2);
657     if (!utils::isa<ValueNode>(index_vnode)) {
658       MS_LOG(ERROR) << "TupleGetItem's input 2 is not value node";
659       return RET_ERROR;
660     }
661     auto value_node = utils::cast<ValueNodePtr>(index_vnode);
662     if (value_node == nullptr) {
663       MS_LOG(ERROR) << "cast to ValueNode failed";
664       return RET_ERROR;
665     }
666     auto origin_index = GetValue<int>(value_node->value());
667     int new_index = origin_index == 4 ? 2 : 1;
668     auto new_index_vnode = NewValueNode(MakeValue<int>(new_index));
669     MS_CHECK_TRUE_RET(new_index_vnode != nullptr, RET_ERROR);
670     new_inputs[2] = new_index_vnode;
671     get_item->set_inputs(new_inputs);
672     get_item->set_fullname_with_scope(lstm_cnode->fullname_with_scope() + "_getitem_" + std::to_string(new_index));
673     if (get_item->abstract() == nullptr) {
674       MS_LOG(ERROR) << "get_item's abstract is nullptr";
675       return RET_ERROR;
676     }
677 
678     std::vector<int> squeeze_axis{0};
679     auto squeeze_node = CreateSqueezeNode(func_graph, get_item, squeeze_axis);
680     if (squeeze_node == nullptr) {
681       return RET_ERROR;
682     }
683 
684     auto get_item_users = manager->node_users()[get_item];
685     for (auto &get_item_user : get_item_users) {
686       manager->SetEdge(get_item_user.first, get_item_user.second, squeeze_node);
687     }
688   }
689   return RET_OK;
690 }
691 
SetAbstractTuple(const CNodePtr & cnode,const int output_num)692 STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) {
693   MS_ASSERT(cnode != nullptr);
694   AbstractBasePtrList abstract_list;
695   for (int i = 0; i < output_num; ++i) {
696     auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
697     if (abstract == nullptr) {
698       MS_LOG(ERROR) << "Create tensor abstarct failed";
699       return RET_ERROR;
700     }
701     abstract_list.emplace_back(abstract);
702   }
703   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
704   if (abstract_tuple == nullptr) {
705     MS_LOG(ERROR) << "create abstract_tuple failed";
706     return RET_ERROR;
707   }
708   cnode->set_abstract(abstract_tuple);
709   return RET_OK;
710 }
711 
// Creates `Squeeze(input_node)` in `func_graph`, squeezing the given `axis` list.
// The new node's abstract is a clone of `input_node`'s abstract (if present); its shape is not recomputed here.
// The node is named "squeeze_<input fullname>". Returns nullptr if the primitive or cnode cannot be created.
CNodePtr TfliteLstmCellFusion::CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node,
                                                 const std::vector<int> &axis) {
  MS_ASSERT(func_graph != nullptr && input_node != nullptr);
  auto prim = std::make_shared<ops::Squeeze>();
  MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
  // The Squeeze primitive takes its axes as int64_t; widen element-wise.
  const std::vector<int64_t> squeeze_axes(axis.begin(), axis.end());
  prim->set_axis(squeeze_axes);

  auto squeeze = func_graph->NewCNode(prim, {input_node});
  MS_CHECK_TRUE_RET(squeeze != nullptr, nullptr);
  auto input_abstract = input_node->abstract();
  if (input_abstract != nullptr) {
    squeeze->set_abstract(input_abstract->Clone());
  }
  squeeze->set_fullname_with_scope("squeeze_" + input_node->fullname_with_scope());
  return squeeze;
}
729 
Process(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr & equiv) const730 const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
731                                                const EquivPtr &equiv) const {
732   if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
733     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
734     return nullptr;
735   }
736 
737   if (!utils::isa<CNodePtr>(node)) {
738     return nullptr;
739   }
740   auto tensor_list_stack_cnode = utils::cast<CNodePtr>(node);
741   auto tuple_get_item_node = tensor_list_stack_cnode->input(1);
742   if (!utils::isa<CNodePtr>(tuple_get_item_node)) {
743     return nullptr;
744   }
745   auto tuple_get_item_cnode = utils::cast<CNodePtr>(tuple_get_item_node);
746   auto while_node = tuple_get_item_cnode->input(1);
747   if (!utils::isa<CNodePtr>(while_node)) {
748     return nullptr;
749   }
750   auto while_cnode = utils::cast<CNodePtr>(while_node);
751 
752   if (while_cnode == nullptr || while_cnode->size() != while_inputs_num_) {
753     return nullptr;
754   }
755   if (!CheckReferencedOutputs(func_graph, while_cnode)) {
756     return nullptr;
757   }
758   auto primitive_vars_cond = std::make_shared<PrimitiveVarMap>();
759   MS_CHECK_TRUE_RET(primitive_vars_cond != nullptr, nullptr);
760   auto cond_graph_pattern = GetCondGraphPattern(primitive_vars_cond);
761   MS_CHECK_TRUE_RET(cond_graph_pattern != nullptr, nullptr);
762   auto cond_equiv =
763     CheckSubGraph(cond_graph_pattern, primitive_vars_cond, while_cnode->input(1), cond_cnodes_num_, cond_nodes_num_);
764   if (cond_equiv == nullptr || cond_equiv->empty()) {
765     return nullptr;
766   }
767   auto primitive_vars_body = std::make_shared<PrimitiveVarMap>();
768   MS_CHECK_TRUE_RET(primitive_vars_body != nullptr, nullptr);
769   auto body_graph_pattern = GetBodyGraphPattern(primitive_vars_body);
770   MS_CHECK_TRUE_RET(body_graph_pattern != nullptr, nullptr);
771   auto body_equiv =
772     CheckSubGraph(body_graph_pattern, primitive_vars_body, while_cnode->input(2), body_cnodes_num_, body_nodes_num_);
773   if (body_equiv == nullptr || body_equiv->empty()) {
774     return nullptr;
775   }
776   float zoneout_cell = 0.0f;
777   float zoneout_hidden = 0.0f;
778   if (!CheckBodyGraph(body_equiv, &zoneout_cell, &zoneout_hidden)) {
779     return nullptr;
780   }
781   const std::string lstm_name = "lstm_" + while_cnode->fullname_with_scope();
782   auto lstm_node = CreateLSTMNode(func_graph, equiv, body_equiv, lstm_name, zoneout_cell, zoneout_hidden);
783   if (lstm_node == nullptr) {
784     return nullptr;
785   }
786   auto status = SetAbstractTuple(lstm_node, kLSTMOutputNum);
787   if (status != RET_OK) {
788     return nullptr;
789   }
790 
791   auto get_item_node = CreateOutputGetItem(func_graph, lstm_node, 0);
792   if (get_item_node == nullptr) {
793     MS_LOG(DEBUG) << "create lstm output get_item node failed";
794     return nullptr;
795   }
796 
797   status = AdjustOtherGetItems(func_graph, while_cnode, lstm_node, tuple_get_item_cnode);
798   if (status != RET_OK) {
799     return nullptr;
800   }
801 
802   std::vector<int> squeeze_axis{1};  // our lstm output:0 have an extra axis that tflite not have, it must be squeezed
803   auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis);
804   if (squeeze_node == nullptr) {
805     return nullptr;
806   }
807 
808   auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1);
809   MS_CHECK_TRUE_RET(cond_cnode_index_pair != nullptr, nullptr);
810   func_graph->DropFuncGraphCNodeIndex(cond_cnode_index_pair);
811   auto body_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 2);
812   MS_CHECK_TRUE_RET(body_cnode_index_pair != nullptr, nullptr);
813   func_graph->DropFuncGraphCNodeIndex(body_cnode_index_pair);
814   MS_LOG(INFO) << "lstm node:" << lstm_node->fullname_with_scope() << " fusion success";
815   return squeeze_node;
816 }
817 }  // namespace opt
818 }  // namespace mindspore
819