• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_
17 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_
18 
19 #include <vector>
20 #include <memory>
21 #include <string>
22 #include "backend/optimizer/common/optimizer.h"
23 #include "utils/utils.h"
24 #include "include/errorcode.h"
25 
26 namespace mindspore {
27 namespace opt {
28 class TfliteLstmCellFusion : public PatternProcessPass {
29  public:
30   explicit TfliteLstmCellFusion(const std::string &name = "TfliteLstmCellFusion", bool multigraph = true,
31                                 int input_length = 0, int var_num = 0, int cond_nodes_num = 0, int cond_cnodes_num = 0,
32                                 int body_nodes_num = 0, int body_cnodes_num = 0);
33 
34   ~TfliteLstmCellFusion() override = default;
35 
36   static EquivPtr MatchGraph(const FuncGraphPtr &func_graph, const PrimitiveVarMapPtr &primitive_vars,
37                              const AnfNodePtr &pattern);
38 
39   static EquivPtr CheckSubGraph(const AnfNodePtr &pattern, const PrimitiveVarMapPtr &primitive_vars,
40                                 const AnfNodePtr &anf_sub_graph, size_t cnode_num, size_t all_node_num);
41 
42   static lite::STATUS SetAbstractTuple(const CNodePtr &cnode, int output_num);
43 
44   static CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, int item_index);
45 
46  protected:
47   bool Init() const;
48 
49   const BaseRef DefinePattern() const override;
50 
51   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
52 
53   static lite::STATUS GetFloatScalarFromTensorInfo(const AnfNodePtr &tensor_info, float *v);
54 
55   static CNodePtr CreateSqueezeNode(const FuncGraphPtr &func_graph, const CNodePtr &input_node,
56                                     const std::vector<int> &axis);
57 
58   static lite::STATUS AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode,
59                                           const CNodePtr &lstm_cnode, const CNodePtr &output_get_item);
60 
61   static AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars);
62 
63   virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
64 
65   virtual CNodePtr CreateLSTMNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const EquivPtr &body_equiv,
66                                   const std::string &base_name, float zoneout_cell, float zoneout_hidden) const;
67 
68  private:
69   CNodePtr GetWhileCnode(const AnfNodePtr &cnode) const;
70   bool CheckBodyGraph(const EquivPtr &equiv, float *zoneout_cell, float *zoneout_hidden) const;
71 
72   static bool CheckReferencedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &while_cnode);
73 
74   static lite::STATUS GetConcatedParam(const std::vector<AnfNodePtr> &params, const ParameterPtr &new_param,
75                                        bool is_bias);
76 
77  protected:
78   mutable VarPtr cell_zoneout_old_ = nullptr;
79   mutable VarPtr cell_zoneout_new_ = nullptr;
80   mutable VarPtr hidden_zoneout_old_ = nullptr;
81   mutable VarPtr hidden_zoneout_new_ = nullptr;
82   mutable std::vector<VarPtr> while_input_vars_;
83 
84  private:
85   size_t while_input_var_num_ = 0;
86   size_t while_inputs_num_ = 0;
87   size_t cond_nodes_num_ = 0;
88   size_t cond_cnodes_num_ = 0;
89   size_t body_nodes_num_ = 0;
90   size_t body_cnodes_num_ = 0;
91 };
92 }  // namespace opt
93 }  // namespace mindspore
94 #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TFLITE_LSTM_CELL_FUSION_H_
95