/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
#include <vector>
#include <memory>
#include <string>
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "include/common/utils/utils.h"
#include "include/errorcode.h"

namespace mindspore {
namespace opt {
constexpr size_t kWhileUniqInputsLength = 6;
// Fuses a TF 2.x bidirectional GRU subgraph into a single MSLITE GRU node.
class TfBidirectionGruFusion : public LitePatternProcessPass {
 public:
  explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength,
                                  const std::string &name = "TfBidirectionGruFusion", bool multi_graph = true)
      : LitePatternProcessPass(name, multi_graph), num_fw_vars_(num_fw_vars), num_bw_vars_(num_bw_vars) {}

  ~TfBidirectionGruFusion() override = default;

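  // Standard pattern-pass hooks: DefinePattern() describes the subgraph to match and
  // Process() is called for each match to build and return the fused replacement node.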
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
  const BaseRef DefinePattern() const override;

 protected:
  bool Init() const;

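  // Pattern for the graph computed inside the while-loop body (a single-direction GRU cell);
  // virtual so derived fusions can match a different body graph.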
  virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;

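  // Builds the fused bidirectional GRU CNode from a matched subgraph; var_offset selects the
  // forward or backward set of captured pattern variables in equiv.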
  CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
                                    const std::string &base_name, int var_offset) const;

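  // Appends the nodes that convert the fused GRU output back to the layout produced by the
  // original subgraph.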
  static CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
                                     const std::string &base_name);

 private:
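  // Patterns for the forward and backward while-loops and their shared loop-condition graph.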
  const VectorRef DefineFowardPattern() const;

  const VectorRef DefinebackwardPattern() const;

  AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;

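  // Fetches the tensor held as the parameter node's default value.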
  static tensor::TensorPtr GetDefaultTensorInfo(const AnfNodePtr &parameter_anf);

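  // Derives input_size and hidden_size from the forward/backward candidate-kernel weights.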
  static lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf,
                                            int *input_size, int *hidden_size);

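  // Adds a parameter holding a newly allocated tensor of the given shape and type; the raw
  // buffer is returned through tensor_data so the caller can fill it.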
  static ParameterPtr AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
                                          const std::vector<int> &shape, TypeId type, void **tensor_data);

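  // Repacks the TF gate and candidate kernels into the fused GRU's gate and recurrent weight
  // buffers.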
  static lite::STATUS ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, int input_size,
                                        int hidden_size, float *gate_tensor_data, float *recu_tensor_data);

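  // Repacks the TF gate and candidate biases into the fused GRU's bias buffer.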
  static lite::STATUS ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, int hidden_size,
                                      float *tensor_data);

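  // Copies the sub-block of rows r0..r1 and columns c0..c1 of a row-major matrix with C columns
  // into data, optionally transposed (t).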
  static void CopyFlattenMatData(const float *mat, int C, int r0, int r1, int c0, int c1, float *data, bool t = false);

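  // Stacks the forward and backward initial hidden states into the single initial-state tensor
  // expected by the fused GRU.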
  static CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
                                        const AnfNodePtr &bw_init_state, const std::string &base_name);

 protected:
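  // Pattern variables bound during matching: the per-direction while-loop inputs plus the shared
  // input, sequence-length, transposed-input and initial-state variables.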
  mutable std::vector<VarPtr> fw_vars_;
  mutable std::vector<VarPtr> bw_vars_;
  mutable VarPtr input_;
  mutable VarPtr input_length_;
  mutable VarPtr transpose_input_;
  mutable VarPtr fw_init_state_;
  mutable VarPtr bw_init_state_;

 private:
  int num_fw_vars_{0};
  int num_bw_vars_{0};
};
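// A rough usage sketch (illustrative only; the pass-manager names below are assumptions, not
// part of this header):
//   auto pm = std::make_shared<opt::PassManager>("fusion pass manager", false);
//   pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
//   // Running the optimizer then replaces each matched TF bidirectional GRU while-loop pair
//   // with a single fused GRU node.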
inline bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_