1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
17 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
18 #include <vector>
19 #include <memory>
20 #include <string>
21 #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
22 #include "tools/optimizer/common/gllo_utils.h"
23 #include "schema/inner/model_generated.h"
24 #include "tools/optimizer/common/pattern_process_pass_extends.h"
25 #include "include/common/utils/utils.h"
26 #include "include/errorcode.h"
27
28 namespace mindspore {
29 namespace opt {
// Default element count used for both `num_fw_vars` and `num_bw_vars` of TfBidirectionGruFusion.
// NOTE(review): the name suggests this is the number of unique inputs of the matched TF while loop
// per GRU direction; confirm against the pattern definitions in the .cc file.
constexpr size_t kWhileUniqInputsLength{6U};
31 // fuse tf 2.x bidirection_gru into MSLITE GRU
32 class TfBidirectionGruFusion : public LitePatternProcessPass {
33 public:
34 explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength,
35 const std::string &name = "TfBidirectionGruFusion", bool multi_graph = true)
LitePatternProcessPass(name,multi_graph)36 : LitePatternProcessPass(name, multi_graph), num_fw_vars_(num_fw_vars), num_bw_vars_(num_bw_vars) {}
37
38 ~TfBidirectionGruFusion() override = default;
39
40 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
41 const BaseRef DefinePattern() const override;
42
43 protected:
44 bool Init() const;
45
46 virtual AnfNodePtr GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
47
48 CNodePtr CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const EquivPtr &equiv,
49 const std::string &base_name, int var_offset) const;
50
51 static CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
52 const std::string &base_name);
53
54 private:
55 const VectorRef DefineFowardPattern() const;
56
57 const VectorRef DefinebackwardPattern() const;
58
59 AnfNodePtr GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const;
60
61 static tensor::TensorPtr GetDefaultTensorInfo(const AnfNodePtr ¶meter_anf);
62
63 static lite::STATUS GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf, const AnfNodePtr &bw_cand_kernel_anf,
64 int *input_size, int *hidden_size);
65
66 static ParameterPtr AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
67 const std::vector<int> &shape, TypeId type, void **tensor_data);
68
69 static lite::STATUS ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight, int input_size,
70 int hidden_size, float *gate_tensor_data, float *recu_tensor_data);
71
72 static lite::STATUS ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias, int hidden_size,
73 float *tensor_data);
74
75 static void CopyFlattenMatData(const float *mat, int C, int r0, int r1, int c0, int c1, float *data, bool t = false);
76
77 static CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
78 const AnfNodePtr &bw_init_state, const std::string &base_name);
79
80 protected:
81 mutable std::vector<VarPtr> fw_vars_;
82 mutable std::vector<VarPtr> bw_vars_;
83 mutable VarPtr input_;
84 mutable VarPtr input_length_;
85 mutable VarPtr transpose_input_;
86 mutable VarPtr fw_init_state_;
87 mutable VarPtr bw_init_state_;
88
89 private:
90 int num_fw_vars_{0};
91 int num_bw_vars_{0};
92 };
IsParameterNode(const BaseRef & n)93 inline bool IsParameterNode(const BaseRef &n) { return utils::isa<ParameterPtr>(n); }
94 } // namespace opt
95 } // namespace mindspore
96
97 #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_BIDIRECTION_GRU_FUSION_H_
98