1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
17 #include <memory>
18 #include <set>
19 #include <functional>
20 #include "src/common/utils.h"
21 #include "utils/utils.h"
22 #include "tools/optimizer/common/gllo_utils.h"
23 #include "securec/include/securec.h"
24 #include "tools/converter/ops/ops_def.h"
25 #include "nnacl/op_base.h"
26
27 namespace mindspore {
28 namespace opt {
29 namespace {
// Variable count passed to the TfBidirectionGruFusion constructor for the forward direction.
constexpr size_t kNumFwVars{4};
// Variable count passed to the TfBidirectionGruFusion constructor for the backward direction.
constexpr size_t kNumBwVars{4};
// File-local shorthand for the first std::bind placeholder (used by GetPrim).
const auto &p1 = std::placeholders::_1;
GetPrim(const PrimitivePtr & prim)33 BaseRef GetPrim(const PrimitivePtr &prim) {
34 auto ptr = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim));
35 MS_CHECK_TRUE_MSG(ptr != nullptr, nullptr, "is nullptr.");
36 return ptr;
37 }
38
GetPrim(const std::string & prim_name)39 BaseRef GetPrim(const std::string &prim_name) {
40 auto prim = std::make_shared<Primitive>(prim_name);
41 MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "is nullptr.");
42 return GetPrim(prim);
43 }
44 } // namespace
45
TfBidirectionGruCfFusion(const std::string & name,bool multi_graph)46 TfBidirectionGruCfFusion::TfBidirectionGruCfFusion(const std::string &name, bool multi_graph)
47 : TfBidirectionGruFusion(kNumFwVars, kNumBwVars, name, multi_graph) {
48 /*
49 * vars for fw/bw input
50 * fw:
51 * 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
52 * bw:
53 * 0:kernel_gate 1:bias_gate 2:cand_kernel 3:cand_bias
54 */
55 }
56
// Builds the sub-pattern describing one GRU cell step inside the loop body.
//   in_ta_read   : pattern of the value read from the input TensorArray for this step.
//   switch3_true : pattern of the hidden state entering this step.
//   vars         : weight placeholders, indexed as [0] kernel_gate, [1] bias_gate,
//                  [2] cand_kernel, [3] cand_bias. Elements 0..3 are accessed without a size check.
// Returns the pattern of the final AddFusion node.
BaseRef TfBidirectionGruCfFusion::DefineGruCellPattern(const BaseRef &in_ta_read, const BaseRef &switch3_true,
                                                       const std::vector<VarPtr> &vars) const {
  // Wildcard used for the TupleGetItem index operand.
  auto any_node = []() { return std::make_shared<Var>(); };

  // Gate branch: Concat(input, state) -> MatMul -> BiasAdd -> Activation -> Split.
  auto gate_concat = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, switch3_true});
  auto gate_kernel_enter = VectorRef({GetPrim(lite::kNameEnter), vars[0]});  // kernel_gate
  auto gate_matmul = VectorRef({GetPrim(prim::kPrimMatMul), gate_concat, gate_kernel_enter});
  auto gate_bias_enter = VectorRef({GetPrim(lite::kNameEnter), vars[1]});  // bias_gate
  auto gate_bias_add = VectorRef({GetPrim(prim::kPrimBiasAdd), gate_matmul, gate_bias_enter});
  // Only the op type is matched here; the concrete activation function is not checked.
  auto gate_activation = VectorRef({GetPrim(prim::kPrimActivation), gate_bias_add});
  auto gate_split = VectorRef({GetPrim(prim::kPrimSplit), gate_activation});
  // Two outputs of the Split; the item index is a wildcard, so the pattern does not pin which is which.
  auto rt = VectorRef({GetPrim(prim::kPrimTupleGetItem), gate_split, any_node()});
  auto zt = VectorRef({GetPrim(prim::kPrimTupleGetItem), gate_split, any_node()});

  // Candidate branch: Concat(input, rt * state) -> MatMul -> BiasAdd -> Activation.
  auto gated_state = VectorRef({GetPrim(prim::kPrimMulFusion), rt, switch3_true});
  auto cand_concat = VectorRef({GetPrim(prim::kPrimConcat), in_ta_read, gated_state});
  auto cand_kernel_enter = VectorRef({GetPrim(lite::kNameEnter), vars[2]});  // cand_kernel
  auto cand_matmul = VectorRef({GetPrim(prim::kPrimMatMul), cand_concat, cand_kernel_enter});
  auto cand_bias_enter = VectorRef({GetPrim(lite::kNameEnter), vars[3]});  // cand_bias
  auto cand_bias_add = VectorRef({GetPrim(prim::kPrimBiasAdd), cand_matmul, cand_bias_enter});
  auto cand_activation = VectorRef({GetPrim(prim::kPrimActivation), cand_bias_add});

  // Blend: zt * state + (<parameter> - zt) * candidate.
  auto complement = VectorRef({GetPrim(prim::kPrimSubFusion), std::make_shared<CondVar>(IsParameterNode), zt});
  auto cand_term = VectorRef({GetPrim(prim::kPrimMulFusion), complement, cand_activation});
  auto state_term = VectorRef({GetPrim(prim::kPrimMulFusion), zt, switch3_true});
  auto new_state = VectorRef({GetPrim(prim::kPrimAddFusion), state_term, cand_term});
  return new_state;
}
81
// Builds the pattern of one unrolled-by-while-loop RNN direction (used for both forward and backward).
//   input      : pattern of the tensor that is scattered into the input TensorArray.
//   vars       : GRU cell weight placeholders, forwarded to DefineGruCellPattern.
//   init_state : placeholder for the initial hidden state entering the loop.
// Returns the pattern of the final Transpose applied to the gathered output TensorArray.
//
// The matched graph contains cycles (Merge <- NextIteration). To keep the pattern acyclic, the
// operand that would close a cycle is replaced by a SeqVar at the places noted below.
const BaseRef TfBidirectionGruCfFusion::DefineBidirectionRnnPattern(const BaseRef &input,
                                                                    const std::vector<VarPtr> &vars,
                                                                    const VarPtr &init_state) const {
  // Placeholder factories; each call creates a fresh pattern variable.
  auto any_node = []() { return std::make_shared<Var>(); };
  auto any_seq = []() { return std::make_shared<SeqVar>(); };
  auto param_node = []() { return std::make_shared<CondVar>(IsParameterNode); };

  // First loop bound: Minimum(StridedSlice(Shape(input)), Maximum(<parameter>, Reduce(input_length_))).
  auto shape_of_input = VectorRef({GetPrim(prim::kPrimShape), input});
  auto shape_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), shape_of_input, any_seq()});
  auto length_reduce = VectorRef({GetPrim(prim::kPrimReduceFusion), input_length_, any_node()});
  auto length_maximum = VectorRef({GetPrim(prim::kPrimMaximum), param_node(), length_reduce});
  auto length_minimum = VectorRef({GetPrim(prim::kPrimMinimum), shape_slice, length_maximum});
  auto counter_bound_enter = VectorRef({GetPrim(lite::kNameEnter), length_minimum});
  // SeqVar stands in for counter_merge1 (defined further down, part of a cycle).
  auto counter_less = VectorRef({GetPrim(prim::kPrimLess), any_seq(), counter_bound_enter});

  // Second loop bound, driven by its own Merge/NextIteration counter.
  // SeqVar stands in for the Switch operands (the merge and loop_cond).
  auto iter_switch = VectorRef({GetPrim(prim::kPrimSwitch), any_seq()});
  auto iter_identity = VectorRef({GetPrim(prim::kPrimTupleGetItem), iter_switch, any_node()});  // identity
  auto iter_add = VectorRef({GetPrim(prim::kPrimAddFusion), iter_identity, param_node()});
  auto iter_next = VectorRef({GetPrim(lite::kNameNextIteration), iter_add});
  auto iter_enter = VectorRef({GetPrim(lite::kNameEnter), param_node()});
  auto iter_merge = VectorRef({GetPrim(prim::kPrimMerge), iter_enter, iter_next});
  auto iter_bound_enter = VectorRef({GetPrim(lite::kNameEnter), shape_slice});
  auto iter_less = VectorRef({GetPrim(prim::kPrimLess), iter_merge, iter_bound_enter});

  // Loop condition: LoopCond(LogicalAnd(both Less comparisons)).
  auto loop_and = VectorRef({GetPrim(prim::kPrimLogicalAnd), iter_less, counter_less});
  auto loop_cond = VectorRef({GetPrim(lite::kNameLoopCond), loop_and});

  // Index range used to scatter `input` into the input TensorArray.
  auto unstack_shape = VectorRef({GetPrim(prim::kPrimShape), input});
  auto unstack_slice = VectorRef({GetPrim(prim::kPrimStridedSlice), unstack_shape, any_seq()});
  auto unstack_range = VectorRef({GetPrim(prim::kPrimRange), param_node(), unstack_slice, param_node()});

  // Step counter. SeqVar in the Add stands in for switch1_true (cycle).
  auto counter_add = VectorRef({GetPrim(prim::kPrimAddFusion), any_seq(), param_node()});
  auto counter_zero = VectorRef({GetPrim(lite::kNameEnter), param_node()});
  auto counter_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), counter_add});
  auto counter_merge1 = VectorRef({GetPrim(prim::kPrimMerge), counter_zero, counter_next_iter});
  auto counter_switch1 = VectorRef({GetPrim(prim::kPrimSwitch), counter_merge1, loop_cond});
  auto switch1_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), counter_switch1, any_node()});  // identity1

  // Input TensorArray: create, scatter `input` into it, read one element per step.
  auto in_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), shape_slice});
  auto in_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, any_node()});
  auto in_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), in_ta, any_node()});
  auto in_ta_scatter =
    VectorRef({GetPrim(lite::kNameTensorArrayScatterV3), in_ta_handle, unstack_range, input, in_ta_flow});
  auto in_ta_scatter_enter = VectorRef({GetPrim(lite::kNameEnter), in_ta_scatter});
  auto in_ta_handle_enter = VectorRef({GetPrim(lite::kNameEnter), in_ta_handle});
  auto in_ta_read =
    VectorRef({GetPrim(lite::kNameTensorArrayReadV3), in_ta_handle_enter, switch1_true, in_ta_scatter_enter});

  // Step-vs-length comparison shared by both Select nodes.
  auto greater_equal_enter = VectorRef({GetPrim(lite::kNameEnter), input_length_});
  auto greater_equal = VectorRef({GetPrim(prim::kPrimGreaterEqual), switch1_true, greater_equal_enter});
  // Select for the hidden state; its remaining operands are left as a SeqVar.
  auto select_h = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, any_seq()});

  // Hidden-state loop variable: Merge(Enter(init_state), NextIteration(select_h)) -> Switch.
  auto state_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), select_h});
  auto state_enter = VectorRef({GetPrim(lite::kNameEnter), init_state});
  auto state_merge = VectorRef({GetPrim(prim::kPrimMerge), state_enter, state_next_iter});
  auto state_switch = VectorRef({GetPrim(prim::kPrimSwitch), state_merge, loop_cond});
  auto switch3_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), state_switch, any_node()});  // identity3

  auto rnn_cell_out = DefineGruCellPattern(in_ta_read, switch3_true, vars);

  // Output TensorArray.
  auto out_ta = VectorRef({GetPrim(lite::kNameTensorArrayV3), shape_slice});
  auto out_ta_handle = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, any_node()});
  auto out_ta_flow = VectorRef({GetPrim(prim::kPrimTupleGetItem), out_ta, any_node()});
  auto out_ta_enter = VectorRef({GetPrim(lite::kNameEnter), out_ta_handle});

  // Flow input of the write; its source is part of a cycle, hence a SeqVar.
  auto switch2_true = VectorRef({GetPrim(prim::kPrimTupleGetItem), any_seq()});

  // Select for the step output: Fill(<parameter>, Concat(...)) vs. the cell output.
  auto fill_shape = VectorRef({GetPrim(prim::kPrimConcat), any_seq()});
  auto fill_value = VectorRef({GetPrim(prim::kPrimFill), param_node(), fill_shape});
  auto fill_enter = VectorRef({GetPrim(lite::kNameEnter), fill_value});
  auto select_x = VectorRef({GetPrim(prim::kPrimSelect), greater_equal, fill_enter, rnn_cell_out});
  auto ta_write =
    VectorRef({GetPrim(lite::kNameTensorArrayWriteV3), out_ta_enter, switch1_true, select_x, switch2_true});

  // Output flow loop variable: Merge(Enter(out_ta_flow), NextIteration(ta_write)) -> Switch -> Exit.
  auto flow_enter = VectorRef({GetPrim(lite::kNameEnter), out_ta_flow});
  auto flow_next_iter = VectorRef({GetPrim(lite::kNameNextIteration), ta_write});
  auto flow_merge = VectorRef({GetPrim(prim::kPrimMerge), flow_enter, flow_next_iter});
  auto flow_switch = VectorRef({GetPrim(prim::kPrimSwitch), flow_merge, loop_cond});
  auto switch2_false = VectorRef({GetPrim(prim::kPrimTupleGetItem), flow_switch, any_node()});

  // After the loop: gather every written element and transpose the result.
  auto exit2 = VectorRef({GetPrim(lite::kNameExit), switch2_false});
  auto ta_size = VectorRef({GetPrim(lite::kNameTensorArraySizeV3), out_ta_handle, exit2});
  auto gather_range = VectorRef({GetPrim(prim::kPrimRange), any_node(), ta_size, any_node()});
  auto ta_gather = VectorRef({GetPrim(lite::kNameTensorArrayGatherV3), out_ta_handle, gather_range, exit2});
  auto perm_range = VectorRef({GetPrim(prim::kPrimRange), any_seq()});
  auto perm = VectorRef({GetPrim(prim::kPrimConcat), param_node(), perm_range});
  auto out_transpose = VectorRef({GetPrim(prim::kPrimTranspose), ta_gather, perm});
  return out_transpose;
}
173
DefinePattern() const174 const BaseRef TfBidirectionGruCfFusion::DefinePattern() const {
175 if (!Init()) {
176 MS_LOG(ERROR) << "initial member failed.";
177 return {};
178 }
179
180 const auto fw_out_trans = DefineBidirectionRnnPattern(transpose_input_, fw_vars_, fw_init_state_);
181
182 auto bw_reverse_in = VectorRef({GetPrim(prim::kPrimReverseSequence), input_, input_length_});
183 auto bw_range = VectorRef({GetPrim(prim::kPrimRange), std::make_shared<SeqVar>()});
184 auto bw_concat = VectorRef({GetPrim(prim::kPrimConcat), std::make_shared<CondVar>(IsParameterNode), bw_range});
185 auto bw_transpose = VectorRef({GetPrim(prim::kPrimTranspose), bw_reverse_in, bw_concat});
186 auto bw_out_trans = DefineBidirectionRnnPattern(bw_transpose, bw_vars_, bw_init_state_);
187 auto bw_reverse_out = VectorRef({GetPrim(prim::kPrimReverseSequence), bw_out_trans, input_length_});
188 auto concat = VectorRef({GetPrim(prim::kPrimConcat), fw_out_trans, bw_reverse_out});
189 return concat;
190 }
191
// Replaces a matched bidirectional GRU control-flow sub-graph with a single fused GRU node.
//   func_graph  : graph being rewritten.
//   concat_node : root of the match (the final Concat); its scoped name is used to name the GRU node.
//   equiv       : bindings produced by the pattern match.
// Returns the post-processed output node that replaces `concat_node`, or nullptr when any step fails.
// A null argument additionally records lite::RET_NULL_PTR in the global return code.
const AnfNodePtr TfBidirectionGruCfFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
                                                   const EquivPtr &equiv) const {
  if (func_graph == nullptr || concat_node == nullptr || equiv == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }

  auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
  // Guard with the same runtime check used for the other nodes below (rather than an assert only),
  // because the node is handed straight to CreateBiDirectionGruNode.
  MS_CHECK_TRUE_MSG(transpose_input != nullptr, nullptr, "transpose_input is nullptr.");

  const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
  auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 0);
  MS_CHECK_TRUE_MSG(gru_node != nullptr, nullptr, "gru_node is nullptr.");

  // NOTE(review): 2 is presumably the number of outputs of the fused GRU node — confirm against SetAbstractTuple.
  if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) {
    MS_LOG(ERROR) << "set abstract tuple for gru node failed.";
    return nullptr;
  }

  // Only output 0 of the GRU node is forwarded to the post-processing step.
  auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0);
  MS_CHECK_TRUE_MSG(get_item_node != nullptr, nullptr, "get_item_node is nullptr.");

  auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope());
  MS_CHECK_TRUE_MSG(output_node != nullptr, nullptr, "output_node is nullptr.");
  MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success";
  return output_node;
}
218 } // namespace opt
219 } // namespace mindspore
220