1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
19 #include <memory>
20 #include <functional>
21 #include "mindspore/core/ops/structure_ops.h"
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/math_ops.h"
24 #include "mindspore/core/ops/lite_ops.h"
25 #include "mindspore/core/ops/comparison_ops.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "tools/optimizer/common/helper.h"
29 #include "ops/gru.h"
30 #include "ops/squeeze.h"
31 #include "ops/stack.h"
32 #include "ops/auto_generate/gen_lite_ops.h"
33 #include "src/common/utils.h"
34 #include "tools/common/tensor_util.h"
35 #include "include/common/utils/utils.h"
36 #include "nnacl/op_base.h"
37 #include "ops/op_utils.h"
38
39 namespace mindspore {
40 namespace opt {
41 namespace {
// Offset used to skip the cond/body entries (indices 0 and 1) of fw_vars_/bw_vars_
// when appending the remaining while-loop vars to a pattern.
constexpr int kOffsetTwo = 2;
// Number of parameter placeholders created for the while-body graph pattern.
constexpr int kReservedParamNodesNum = 13;
// Expected node counts of the matched cond/body sub-graphs.
// NOTE(review): kCondNodesNum, kBodyNodesNum and kBodyCNodesNum are not referenced
// in this chunk — presumably used elsewhere in the file; confirm before removing.
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
constexpr size_t kBodyNodesNum = 69;
constexpr size_t kBodyCNodesNum = 25;
// The TF GRU gate kernel/bias pack two gates (update and reset) together.
constexpr auto kGateNum = 2;
// Shorthand for std::bind's first placeholder, used when building CondVars.
const auto &p1 = std::placeholders::_1;
GenerateBodyGraphHiddenPattern(const BaseRef & sigmoid1,const BaseRef & get_item,const std::vector<CondVarPtr> & placeholders)50 VectorRef GenerateBodyGraphHiddenPattern(const BaseRef &sigmoid1, const BaseRef &get_item,
51 const std::vector<CondVarPtr> &placeholders) {
52 MS_CHECK_TRUE_RET(placeholders.size() >= kCondCNodesNum, {});
53 auto is_var_split = std::make_shared<Var>("Split");
54 MS_CHECK_TRUE_RET(is_var_split != nullptr, {});
55 VectorRef split = VectorRef({is_var_split, sigmoid1});
56 auto is_var_tuple_getitem1 = std::make_shared<Var>("TupleGetItem");
57 MS_CHECK_TRUE_RET(is_var_tuple_getitem1 != nullptr, {});
58 auto is_var4 = std::make_shared<Var>();
59 MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
60 VectorRef get_item1 = VectorRef({is_var_tuple_getitem1, split, is_var4});
61 auto is_var_tuple_getitem2 = std::make_shared<Var>("TupleGetItem");
62 MS_CHECK_TRUE_RET(is_var_tuple_getitem2 != nullptr, {});
63 auto is_var5 = std::make_shared<Var>();
64 MS_CHECK_TRUE_RET(is_var5 != nullptr, {});
65 VectorRef get_item2 = VectorRef({is_var_tuple_getitem2, split, is_var5});
66
67 auto is_var_mul1 = std::make_shared<Var>("Mul");
68 MS_CHECK_TRUE_RET(is_var_mul1 != nullptr, {});
69 VectorRef pre_reset = VectorRef({is_var_mul1, get_item1, placeholders[4]});
70 auto is_var_concat = std::make_shared<Var>("Concat");
71 MS_CHECK_TRUE_RET(is_var_concat != nullptr, {});
72 VectorRef concat2 = VectorRef({is_var_concat, get_item, pre_reset});
73 auto is_var_matmul2 = std::make_shared<Var>("Matmul");
74 MS_CHECK_TRUE_RET(is_var_matmul2 != nullptr, {});
75 VectorRef matmul2 = VectorRef({is_var_matmul2, concat2, placeholders[10]});
76 auto is_var_biasadd2 = std::make_shared<Var>("BiasAdd");
77 MS_CHECK_TRUE_RET(is_var_biasadd2 != nullptr, {});
78 VectorRef biasadd2 = VectorRef({is_var_biasadd2, matmul2, placeholders[11]});
79 auto is_var_tanh = std::make_shared<Var>("Tanh");
80 MS_CHECK_TRUE_RET(is_var_tanh != nullptr, {});
81 VectorRef tanh = VectorRef({is_var_tanh, biasadd2});
82
83 auto is_var_mul2 = std::make_shared<Var>("Mul");
84 MS_CHECK_TRUE_RET(is_var_mul2 != nullptr, {});
85 VectorRef update_hidden = VectorRef({is_var_mul2, get_item2, placeholders[4]});
86 auto is_var_sub = std::make_shared<Var>("Sub");
87 MS_CHECK_TRUE_RET(is_var_sub != nullptr, {});
88 auto is_param = std::make_shared<CondVar>(IsParameterNode);
89 MS_CHECK_TRUE_RET(is_param != nullptr, {});
90 VectorRef minus_update = VectorRef({is_var_sub, is_param, get_item2});
91 auto is_var_mul3 = std::make_shared<Var>("Mul");
92 MS_CHECK_TRUE_RET(is_var_mul3 != nullptr, {});
93 VectorRef updated = VectorRef({is_var_mul3, minus_update, tanh});
94
95 auto is_var_add = std::make_shared<Var>("Add");
96 MS_CHECK_TRUE_RET(is_var_add != nullptr, {});
97 VectorRef new_hidden = VectorRef({is_var_add, update_hidden, updated});
98
99 return new_hidden;
100 }
101 } // namespace
102
Init() const103 bool TfBidirectionGruFusion::Init() const {
104 /*
105 * vars for while input
106 * fw_while_inputs:
107 * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
108 * bw_while_inputs:
109 * 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
110 */
111 for (int i = 0; i < num_fw_vars_; ++i) {
112 auto is_var = std::make_shared<Var>();
113 MS_CHECK_TRUE_RET(is_var != nullptr, false);
114 fw_vars_.emplace_back(is_var);
115 }
116 for (int i = 0; i < num_bw_vars_; ++i) {
117 auto is_var = std::make_shared<Var>();
118 MS_CHECK_TRUE_RET(is_var != nullptr, false);
119 bw_vars_.emplace_back(is_var);
120 }
121 input_ = std::make_shared<Var>();
122 MS_CHECK_TRUE_RET(input_ != nullptr, false);
123 input_length_ = std::make_shared<Var>();
124 MS_CHECK_TRUE_RET(input_length_ != nullptr, false);
125 transpose_input_ = std::make_shared<Var>();
126 MS_CHECK_TRUE_RET(transpose_input_ != nullptr, false);
127 fw_init_state_ = std::make_shared<Var>();
128 MS_CHECK_TRUE_RET(fw_init_state_ != nullptr, false);
129 bw_init_state_ = std::make_shared<Var>();
130 MS_CHECK_TRUE_RET(bw_init_state_ != nullptr, false);
131 return true;
132 }
133
// Builds the match pattern for the forward-direction half of the TF
// bidirectional GRU: the sequence-length computation (ReduceFusion/Maximum and
// Shape/StridedSlice/Minimum), the TensorListReserve/TensorListFromTensor
// inputs of the forward While loop, and the While -> TupleGetItem ->
// TensorListStack -> Transpose chain that yields the forward output.
// Returns an empty VectorRef when any pattern node cannot be created.
const VectorRef TfBidirectionGruFusion::DefineFowardPattern() const {
  // fw_max = Maximum(param, ReduceFusion(input_length, param))
  auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion));
  MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
  auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param1 != nullptr, {});
  auto fw_reduce = VectorRef({is_reduce, input_length_, is_param1});
  auto is_maximum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum));
  MS_CHECK_TRUE_RET(is_maximum != nullptr, {});
  auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param2 != nullptr, {});
  auto fw_max = VectorRef({is_maximum, is_param2, fw_reduce});

  // fw_min = Minimum(StridedSlice(Shape(transposed input), ...), fw_max)
  auto is_shape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape));
  MS_CHECK_TRUE_RET(is_shape != nullptr, {});
  auto fw_shape = VectorRef({is_shape, transpose_input_});
  auto is_strided_slice = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice));
  MS_CHECK_TRUE_RET(is_strided_slice != nullptr, {});
  auto is_seq_var = std::make_shared<SeqVar>();
  MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
  auto fw_stride = VectorRef({is_strided_slice, fw_shape, is_seq_var});
  auto is_minimum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum));
  MS_CHECK_TRUE_RET(is_minimum != nullptr, {});
  auto fw_min = VectorRef({is_minimum, fw_stride, fw_max});
  // Tensor-list inputs of the While loop.
  auto is_tensor_list_reserve = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve));
  MS_CHECK_TRUE_RET(is_tensor_list_reserve != nullptr, {});
  auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
  auto fw_reserve = VectorRef({is_tensor_list_reserve, is_param3, fw_stride});
  auto is_tensor_list_from_tensor = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListFromTensor));
  MS_CHECK_TRUE_RET(is_tensor_list_from_tensor != nullptr, {});
  auto is_param4 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param4 != nullptr, {});
  auto fw_from_tensor = VectorRef({is_tensor_list_from_tensor, transpose_input_, is_param4});
  // The forward While node; fw_vars_[0]/[1] are cond/body, the remaining vars
  // (kernel/bias of both gates) are appended after the fixed inputs.
  auto is_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimWhile));
  MS_CHECK_TRUE_RET(is_while != nullptr, {});
  auto is_param5 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param5 != nullptr, {});
  auto is_param6 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
  auto fw_while = VectorRef({is_while, fw_vars_[0], fw_vars_[1], is_param5, fw_stride, is_param6, fw_reserve,
                             fw_init_state_, fw_min, fw_from_tensor, input_length_});
  fw_while.insert(fw_while.end(), fw_vars_.begin() + kOffsetTwo, fw_vars_.end());
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
  fw_while.emplace_back(is_var1);
  // Extract and stack the forward output, then transpose it back.
  auto is_tuple_getitem = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem));
  MS_CHECK_TRUE_RET(is_tuple_getitem != nullptr, {});
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
  auto fw_get_item = VectorRef({is_tuple_getitem, fw_while, is_var2});
  auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListStack));
  MS_CHECK_TRUE_RET(is_tensor_list_stack != nullptr, {});
  auto is_param7 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param7 != nullptr, {});
  auto fw_stack = VectorRef({is_tensor_list_stack, fw_get_item, is_param7});
  auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
  MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
  auto is_var3 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
  auto fw_out_trans = VectorRef({is_transpose, fw_stack, is_var3});
  return fw_out_trans;
}
196
// Builds the match pattern for the backward-direction half of the TF
// bidirectional GRU. The structure mirrors DefineFowardPattern(), except the
// input is first reversed via ReverseSequence and transposed, and the final
// output is reversed again after the TensorListStack/Transpose chain.
// Returns an empty VectorRef when any pattern node cannot be created.
const VectorRef TfBidirectionGruFusion::DefinebackwardPattern() const {
  // Reverse the input along the sequence axis before feeding the loop.
  auto is_reverse_sequence = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReverseSequence));
  MS_CHECK_TRUE_RET(is_reverse_sequence != nullptr, {});
  auto bw_reverse_seq = VectorRef({is_reverse_sequence, input_, input_length_});
  // bw_max2 = Maximum(param, ReduceFusion(input_length, param))
  auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion));
  MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
  auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param1 != nullptr, {});
  auto bw_max1 = VectorRef({is_reduce, input_length_, is_param1});
  auto is_maximum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum));
  MS_CHECK_TRUE_RET(is_maximum != nullptr, {});
  auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param2 != nullptr, {});
  auto bw_max2 = VectorRef({is_maximum, is_param2, bw_max1});
  auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
  MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
  auto bw_trans = VectorRef({is_transpose, bw_reverse_seq, is_var1});
  auto is_shape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape));
  MS_CHECK_TRUE_RET(is_shape != nullptr, {});
  auto bw_shape = VectorRef({is_shape, bw_trans});
  auto is_strided_slice = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice));
  MS_CHECK_TRUE_RET(is_strided_slice != nullptr, {});
  auto is_seq_var = std::make_shared<SeqVar>();
  MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
  auto bw_stride = VectorRef({is_strided_slice, bw_shape, is_seq_var});
  auto is_minimum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum));
  MS_CHECK_TRUE_RET(is_minimum != nullptr, {});
  auto bw_min = VectorRef({is_minimum, bw_stride, bw_max2});
  // Tensor-list inputs of the backward While loop.
  auto is_tensor_list_reserve = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve));
  MS_CHECK_TRUE_RET(is_tensor_list_reserve != nullptr, {});
  auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param3 != nullptr, {});
  auto bw_reserve = VectorRef({is_tensor_list_reserve, is_param3, bw_stride});
  auto is_tensor_list_from_tensor = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListFromTensor));
  MS_CHECK_TRUE_RET(is_tensor_list_from_tensor != nullptr, {});
  auto is_param4 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param4 != nullptr, {});
  auto bw_from_tensor = VectorRef({is_tensor_list_from_tensor, bw_trans, is_param4});
  auto is_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimWhile));
  MS_CHECK_TRUE_RET(is_while != nullptr, {});
  auto is_param5 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param5 != nullptr, {});
  auto is_param6 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
  auto bw_while = VectorRef({is_while, bw_vars_[0], bw_vars_[1], is_param5, bw_stride, is_param6, bw_reserve,
                             bw_init_state_, bw_min, bw_from_tensor, input_length_});
  bw_while.insert(bw_while.end(), bw_vars_.begin() + kOffsetTwo, bw_vars_.end());
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
  bw_while.emplace_back(is_var2);
  auto is_tuple_getitem = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem));
  MS_CHECK_TRUE_RET(is_tuple_getitem != nullptr, {});
  auto is_var3 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var3 != nullptr, {});
  auto bw_get_item = VectorRef({is_tuple_getitem, bw_while, is_var3});
  auto is_tensor_list_stack = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListStack));
  MS_CHECK_TRUE_RET(is_tensor_list_stack != nullptr, {});
  auto is_param7 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param7 != nullptr, {});
  auto bw_stack = VectorRef({is_tensor_list_stack, bw_get_item, is_param7});
  auto is_var4 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var4 != nullptr, {});
  // NOTE(review): is_transpose and is_reverse_sequence CondVars are reused here
  // for the output transpose/reverse (they already appear in bw_trans and
  // bw_reverse_seq above) — confirm the pattern engine allows one var to bind
  // two distinct graph nodes; DefineFowardPattern() creates a fresh var instead.
  auto bw_out_trans = VectorRef({is_transpose, bw_stack, is_var4});
  auto bw_reverse1 = VectorRef({is_reverse_sequence, bw_out_trans, input_length_});
  return bw_reverse1;
}
265
DefinePattern() const266 const BaseRef TfBidirectionGruFusion::DefinePattern() const {
267 if (!Init()) {
268 MS_LOG(ERROR) << "initial member failed.";
269 return {};
270 }
271
272 // forward
273 auto fw_out_trans = DefineFowardPattern();
274 MS_CHECK_TRUE_RET(!fw_out_trans.empty(), {});
275
276 // backward
277 auto bw_reverse1 = DefinebackwardPattern();
278 MS_CHECK_TRUE_RET(!bw_reverse1.empty(), {});
279
280 auto is_concat = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimConcat));
281 MS_CHECK_TRUE_RET(is_concat != nullptr, {});
282 auto concat = VectorRef({is_concat, fw_out_trans, bw_reverse1});
283 return concat;
284 }
285
GetCondGraphPattern(const PrimitiveVarMapPtr & primitive_vars) const286 AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
287 MS_ASSERT(primitive_vars != nullptr);
288 auto is_less1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
289 MS_CHECK_TRUE_RET(is_less1 != nullptr, nullptr);
290 auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
291 MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
292 auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
293 MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
294 VectorRef less1_ref = VectorRef({is_less1, is_param1, is_param2});
295 auto is_less2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess));
296 MS_CHECK_TRUE_RET(is_less2 != nullptr, nullptr);
297 auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
298 MS_CHECK_TRUE_RET(is_param3 != nullptr, nullptr);
299 auto is_param4 = std::make_shared<CondVar>(IsParameterNode);
300 MS_CHECK_TRUE_RET(is_param4 != nullptr, nullptr);
301 VectorRef less2_ref = VectorRef({is_less2, is_param3, is_param4});
302 auto is_logical_and = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLogicalAnd));
303 MS_CHECK_TRUE_RET(is_logical_and != nullptr, nullptr);
304 VectorRef logicaland_ref = VectorRef({is_logical_and, less1_ref, less2_ref});
305 auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
306 MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
307 VectorRef return_ref = VectorRef({is_return, logicaland_ref});
308 VarPtr is_fg = std::make_shared<Var>("RootG");
309 MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
310 auto pattern = Helper::SexpToNode(return_ref, is_fg, primitive_vars.get(), true);
311 return pattern;
312 }
313
// Builds the pattern node for the while-loop body graph of one GRU direction:
// counter increments, gate MatMul/BiasAdd/Sigmoid, the candidate/hidden-state
// sub-pattern (GenerateBodyGraphHiddenPattern), the length-masked selects, and
// the final MakeTuple/Return. placeholders[0..12] stand for the body graph's
// formal parameters (counters, tensor lists, hidden state, weights, biases).
// Returns nullptr when any pattern node cannot be created.
AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primitive_vars) const {
  MS_ASSERT(primitive_vars != nullptr);
  std::vector<CondVarPtr> placeholders;
  for (int i = 0; i < kReservedParamNodesNum; ++i) {
    auto is_param_placeholder = std::make_shared<CondVar>(IsParameterNode);
    MS_CHECK_TRUE_RET(is_param_placeholder != nullptr, nullptr);
    placeholders.emplace_back(is_param_placeholder);
  }
  // Loop-counter updates: placeholders[2] and placeholders[0] each advanced by
  // a parameter (presumably the constant 1 step — cannot be confirmed here).
  auto is_var1 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var1 != nullptr, nullptr);
  auto is_param1 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param1 != nullptr, nullptr);
  VectorRef add = VectorRef({is_var1, placeholders[2], is_param1});
  auto is_var2 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var2 != nullptr, nullptr);
  auto is_param2 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param2 != nullptr, nullptr);
  VectorRef add1 = VectorRef({is_var2, placeholders[0], is_param2});

  // Current time step: GetItem(tensor_list=placeholders[6], index=placeholders[2]),
  // concatenated with the previous hidden state (placeholders[4]).
  auto is_getitem = std::make_shared<Var>("GetItem");
  MS_CHECK_TRUE_RET(is_getitem != nullptr, nullptr);
  auto is_param3 = std::make_shared<CondVar>(IsParameterNode);
  MS_CHECK_TRUE_RET(is_param3 != nullptr, nullptr);
  VectorRef get_item = VectorRef({is_getitem, placeholders[6], placeholders[2], is_param3});
  auto is_var3 = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var3 != nullptr, nullptr);
  VectorRef concat_input_h = VectorRef({is_var3, get_item, placeholders[4]});

  // Gate computation: Sigmoid(BiasAdd(MatMul(concat, gate_kernel), gate_bias)).
  auto is_var_matmul1 = std::make_shared<Var>("Matmul");
  MS_CHECK_TRUE_RET(is_var_matmul1 != nullptr, nullptr);
  VectorRef matmul1 = VectorRef({is_var_matmul1, concat_input_h, placeholders[8]});
  auto is_var_biasadd1 = std::make_shared<Var>("BiasAdd");
  MS_CHECK_TRUE_RET(is_var_biasadd1 != nullptr, nullptr);
  VectorRef biasadd1 = VectorRef({is_var_biasadd1, matmul1, placeholders[9]});
  auto is_var_sigmoid = std::make_shared<Var>("Sigmoid");
  MS_CHECK_TRUE_RET(is_var_sigmoid != nullptr, nullptr);
  VectorRef sigmoid1 = VectorRef({is_var_sigmoid, biasadd1});

  // Candidate + hidden-state blend (see GenerateBodyGraphHiddenPattern).
  auto new_hidden = GenerateBodyGraphHiddenPattern(sigmoid1, get_item, placeholders);
  MS_CHECK_TRUE_RET(!new_hidden.empty(), nullptr);

  // Mask out steps past the per-sample sequence length (placeholders[7]).
  auto is_var_ge = std::make_shared<Var>("GreaterEqual");
  MS_CHECK_TRUE_RET(is_var_ge != nullptr, nullptr);
  VectorRef greater_equal = VectorRef({is_var_ge, placeholders[2], placeholders[7]});

  auto is_var_switch1 = std::make_shared<Var>("Switch");
  MS_CHECK_TRUE_RET(is_var_switch1 != nullptr, {});
  VectorRef select_output = VectorRef({is_var_switch1, greater_equal, placeholders[12], new_hidden});
  auto is_var_setitem = std::make_shared<Var>("SetItem");
  MS_CHECK_TRUE_RET(is_var_setitem != nullptr, {});
  VectorRef output = VectorRef({is_var_setitem, placeholders[3], placeholders[2], select_output});

  auto is_var_switch2 = std::make_shared<Var>("Switch");
  MS_CHECK_TRUE_RET(is_var_switch2 != nullptr, {});
  VectorRef select_hidden = VectorRef({is_var_switch2, greater_equal, placeholders[4], new_hidden});

  // Body outputs: updated counters, written output list, selected hidden state,
  // then the remaining loop-invariant parameters passed through unchanged.
  auto is_make_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMakeTuple));
  MS_CHECK_TRUE_RET(is_make_tuple != nullptr, nullptr);
  std::vector<BaseRef> outputs = {is_make_tuple, add1, placeholders[1], add,
                                  output, select_hidden, placeholders[5], placeholders[6],
                                  placeholders[7]};
  outputs.insert(outputs.end(), placeholders.begin() + 8, placeholders.end());
  VectorRef make_tuple_node = VectorRef(outputs);
  auto is_return = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReturn));
  MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
  VectorRef return_node = VectorRef({is_return, make_tuple_node});

  VarPtr is_fg = std::make_shared<Var>("RootG");
  MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
  auto pattern = Helper::SexpToNode(return_node, is_fg, primitive_vars.get(), true);
  return pattern;
}
386
GetDefaultTensorInfo(const AnfNodePtr & parameter_anf)387 tensor::TensorPtr TfBidirectionGruFusion::GetDefaultTensorInfo(const AnfNodePtr ¶meter_anf) {
388 MS_ASSERT(parameter_anf != nullptr);
389 if (!utils::isa<ParameterPtr>(parameter_anf)) {
390 MS_LOG(DEBUG) << "parameter_anf is not ParameterPtr";
391 return nullptr;
392 }
393 auto parameter = utils::cast<ParameterPtr>(parameter_anf);
394 if (!parameter->has_default()) {
395 MS_LOG(DEBUG) << "parameter not have default value";
396 return nullptr;
397 }
398 auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param());
399 return tensor_info;
400 }
401
GetInputAndHiddenSize(const AnfNodePtr & fw_cand_kernel_anf,const AnfNodePtr & bw_cand_kernel_anf,int * input_size,int * hidden_size)402 STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_kernel_anf,
403 const AnfNodePtr &bw_cand_kernel_anf, int *input_size,
404 int *hidden_size) {
405 MS_ASSERT(fw_cand_kernel_anf != nullptr);
406 MS_ASSERT(bw_cand_kernel_anf != nullptr);
407 MS_ASSERT(input_size != nullptr);
408 MS_ASSERT(hidden_size != nullptr);
409 auto fw_cand_kernel_value = GetDefaultTensorInfo(fw_cand_kernel_anf);
410 if (fw_cand_kernel_value == nullptr) {
411 return RET_ERROR;
412 }
413 auto fw_cand_kernel_shape = fw_cand_kernel_value->shape();
414 if (fw_cand_kernel_shape.size() != kInputSizeTwo) {
415 return RET_ERROR;
416 }
417 auto bw_cand_kernel_value = GetDefaultTensorInfo(bw_cand_kernel_anf);
418 if (bw_cand_kernel_value == nullptr) {
419 return RET_ERROR;
420 }
421 auto bw_cand_kernel_shape = bw_cand_kernel_value->shape();
422 if (bw_cand_kernel_shape.size() != kInputSizeTwo) {
423 return RET_ERROR;
424 }
425 if (fw_cand_kernel_shape != bw_cand_kernel_shape) {
426 return RET_ERROR;
427 }
428 if (fw_cand_kernel_shape[1] <= 0 || fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1] <= 0) {
429 MS_LOG(DEBUG) << "gru input size or hidden size illegal";
430 return RET_ERROR;
431 }
432 *hidden_size = fw_cand_kernel_shape[1];
433 *input_size = fw_cand_kernel_shape[0] - fw_cand_kernel_shape[1];
434 return RET_OK;
435 }
436
AddDefaultParameter(const FuncGraphPtr & func_graph,const std::string & name,const std::vector<int> & shape,const TypeId type,void ** tensor_data)437 ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &func_graph, const std::string &name,
438 const std::vector<int> &shape, const TypeId type,
439 void **tensor_data) {
440 MS_ASSERT(func_graph != nullptr);
441 MS_ASSERT(tensor_data != nullptr);
442 auto parameter = func_graph->add_parameter();
443 MS_CHECK_TRUE_RET(parameter != nullptr, nullptr);
444 parameter->set_name(name);
445 std::vector<int64_t> shape_vector(shape.begin(), shape.end());
446 auto abstract = lite::CreateTensorAbstract(shape_vector, type);
447 if (abstract == nullptr) {
448 MS_LOG(ERROR) << "Create tensor abstarct failed";
449 return nullptr;
450 }
451 parameter->set_abstract(abstract);
452
453 auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector);
454 if (gate_weight_default == nullptr) {
455 MS_LOG(ERROR) << "gate_weight_default is nullptr";
456 return nullptr;
457 }
458
459 *tensor_data = gate_weight_default->data_c();
460 parameter->set_default_param(gate_weight_default);
461 return parameter;
462 }
463
CopyFlattenMatData(const float * mat,const int C,const int r0,const int r1,const int c0,const int c1,float * data,bool t)464 void TfBidirectionGruFusion::CopyFlattenMatData(const float *mat, const int C, const int r0, const int r1, const int c0,
465 const int c1, float *data, bool t) {
466 MS_ASSERT(mat != nullptr);
467 MS_ASSERT(data != nullptr);
468 MS_ASSERT(r0 >= 0 && r0 < r1);
469 MS_ASSERT(c0 >= 0 && c0 < c1 && c1 <= C);
470 const int RT = r1 - r0;
471 const int CT = c1 - c0;
472 for (int i = r0; i < r1; ++i) {
473 for (int j = c0; j < c1; ++j) {
474 if (t) {
475 data[(j - c0) * RT + (i - r0)] = mat[i * C + j];
476 } else {
477 data[(i - r0) * CT + (j - c0)] = mat[i * C + j];
478 }
479 }
480 }
481 }
482
// Repacks the TF GRU weights into the fused GRU layout for one direction.
// TF stores the gate kernel as {input_size + hidden_size, 2 * hidden_size}
// (rows = concat(input, hidden), columns = concat(reset, update) gates) and
// the candidate kernel as {2 * hidden_size, hidden_size}.
// The fused op wants transposed per-gate slabs:
//   gate_tensor_data: [input_update | input_reset | input_hidden], and
//   recu_tensor_data: [state_update | state_reset | state_hidden].
// Returns RET_ERROR when either weight lacks a default tensor, has null data,
// or has an unexpected shape.
STATUS TfBidirectionGruFusion::ConvertWeightData(const AnfNodePtr &gate_weight, const AnfNodePtr &cand_weight,
                                                 const int input_size, const int hidden_size, float *gate_tensor_data,
                                                 float *recu_tensor_data) {
  MS_ASSERT(gate_weight != nullptr);
  MS_ASSERT(cand_weight != nullptr);
  MS_ASSERT(gate_tensor_data != nullptr);
  MS_ASSERT(recu_tensor_data != nullptr);
  const std::vector<int64_t> gate_shape{input_size + hidden_size, hidden_size * kGateNum};
  const std::vector<int64_t> cand_shape{hidden_size * kGateNum, hidden_size};
  auto gate_weight_value = GetDefaultTensorInfo(gate_weight);
  if (gate_weight_value == nullptr) {
    return RET_ERROR;
  }
  auto gate_weight_data = reinterpret_cast<float *>(gate_weight_value->data_c());
  if (gate_weight_data == nullptr) {
    return RET_ERROR;
  }
  auto gate_weight_shape = gate_weight_value->shape();

  auto cand_weight_value = GetDefaultTensorInfo(cand_weight);
  if (cand_weight_value == nullptr) {
    return RET_ERROR;
  }
  auto cand_weight_data = reinterpret_cast<float *>(cand_weight_value->data_c());
  if (cand_weight_data == nullptr) {
    return RET_ERROR;
  }
  auto cand_weight_shape = cand_weight_value->shape();

  if (gate_weight_shape != gate_shape || cand_weight_shape != cand_shape) {
    return RET_ERROR;
  }

  // Input-to-gate slabs come from the first `input_size` rows; each copy is
  // transposed (t=true) into a {hidden_size, input_size} slab.
  // input_update_weight: columns [hidden_size, 2*hidden_size)
  CopyFlattenMatData(gate_weight_data, hidden_size * kGateNum, 0, input_size, hidden_size, hidden_size * kGateNum,
                     gate_tensor_data, true);
  // input_reset_weight: columns [0, hidden_size)
  CopyFlattenMatData(gate_weight_data, hidden_size * kGateNum, 0, input_size, 0, hidden_size,
                     gate_tensor_data + input_size * hidden_size, true);
  // input_hidden_weight: from the candidate kernel's input rows
  CopyFlattenMatData(cand_weight_data, hidden_size, 0, input_size, 0, hidden_size,
                     gate_tensor_data + input_size * hidden_size * kGateNum, true);

  // Recurrent (state-to-gate) slabs come from rows [input_size, input_size+hidden_size).
  // state_update_weight
  CopyFlattenMatData(gate_weight_data, hidden_size * kGateNum, input_size, input_size + hidden_size, hidden_size,
                     hidden_size * kGateNum, recu_tensor_data, true);
  // state_reset_weight
  CopyFlattenMatData(gate_weight_data, hidden_size * kGateNum, input_size, input_size + hidden_size, 0, hidden_size,
                     recu_tensor_data + hidden_size * hidden_size, true);
  // state_hidden_weight
  CopyFlattenMatData(cand_weight_data, hidden_size, input_size, input_size + hidden_size, 0, hidden_size,
                     recu_tensor_data + hidden_size * hidden_size * kGateNum, true);
  return RET_OK;
}
537
ConvertBiasData(const AnfNodePtr & gate_bias,const AnfNodePtr & cand_bias,const int hidden_size,float * tensor_data)538 STATUS TfBidirectionGruFusion::ConvertBiasData(const AnfNodePtr &gate_bias, const AnfNodePtr &cand_bias,
539 const int hidden_size, float *tensor_data) {
540 MS_ASSERT(gate_bias != nullptr && cand_bias != nullptr);
541 MS_ASSERT(tensor_data != nullptr);
542 std::vector<int64_t> gate_shape{hidden_size * kGateNum};
543 std::vector<int64_t> cand_shape{hidden_size};
544 auto gate_bias_value = GetDefaultTensorInfo(gate_bias);
545 if (gate_bias_value == nullptr) {
546 return RET_ERROR;
547 }
548 auto gate_bias_data = reinterpret_cast<float *>(gate_bias_value->data_c());
549 auto gate_bias_shape = gate_bias_value->shape();
550 auto cand_bias_value = GetDefaultTensorInfo(cand_bias);
551 if (cand_bias_value == nullptr) {
552 return RET_ERROR;
553 }
554 auto cand_bias_data = reinterpret_cast<float *>(cand_bias_value->data_c());
555 auto cand_bias_shape = cand_bias_value->shape();
556 if (gate_bias_shape != gate_shape || cand_bias_shape != cand_shape) {
557 return RET_ERROR;
558 }
559
560 // update_gate bias
561 CopyFlattenMatData(gate_bias_data, hidden_size * kGateNum, 0, 1, hidden_size, hidden_size * kGateNum, tensor_data,
562 false);
563 // reset_gate bias
564 CopyFlattenMatData(gate_bias_data, hidden_size * kGateNum, 0, 1, 0, hidden_size, tensor_data + hidden_size, false);
565 // hidden_gate bias
566 CopyFlattenMatData(cand_bias_data, hidden_size, 0, 1, 0, hidden_size, tensor_data + hidden_size * kGateNum, false);
567
568 return RET_OK;
569 }
570
GetStackedHiddenState(const FuncGraphPtr & func_graph,const AnfNodePtr & fw_init_state,const AnfNodePtr & bw_init_state,const std::string & base_name)571 CNodePtr TfBidirectionGruFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
572 const AnfNodePtr &bw_init_state, const std::string &base_name) {
573 MS_ASSERT(func_graph != nullptr);
574 MS_ASSERT(fw_init_state != nullptr);
575 MS_ASSERT(bw_init_state != nullptr);
576 auto stack_prim = std::make_shared<ops::Stack>();
577 MS_CHECK_TRUE_RET(stack_prim != nullptr, nullptr);
578 auto stack_prim_c = stack_prim->GetPrim();
579 MS_CHECK_TRUE_RET(stack_prim_c != nullptr, nullptr);
580 stack_prim->set_axis(0);
581 auto value_node = NewValueNode(stack_prim_c);
582 MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
583 std::vector<AnfNodePtr> new_node_inputs = {value_node, fw_init_state, bw_init_state};
584 auto new_node = func_graph->NewCNode(new_node_inputs);
585 MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
586 if (fw_init_state->abstract() != nullptr) {
587 new_node->set_abstract(fw_init_state->abstract()->Clone());
588 }
589 new_node->set_fullname_with_scope("stack_hidden_" + base_name);
590 return new_node;
591 }
592
CreateBiDirectionGruNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input,const EquivPtr & equiv,const std::string & base_name,int var_offset) const593 CNodePtr TfBidirectionGruFusion::CreateBiDirectionGruNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
594 const EquivPtr &equiv, const std::string &base_name,
595 int var_offset) const {
596 MS_ASSERT(func_graph != nullptr);
597 MS_ASSERT(input != nullptr);
598 MS_ASSERT(equiv != nullptr);
599 auto gru_prim = std::make_shared<ops::GRU>();
600 MS_CHECK_TRUE_RET(gru_prim != nullptr, nullptr);
601 auto gru_prim_c = gru_prim->GetPrim();
602 MS_CHECK_TRUE_RET(gru_prim_c != nullptr, nullptr);
603 gru_prim->set_bidirectional(true);
604 auto value_node = NewValueNode(gru_prim_c);
605 MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
606
607 auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset]]);
608 MS_ASSERT(fw_gate_kernel != nullptr);
609 auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 1]]);
610 MS_ASSERT(fw_gate_bias != nullptr);
611 auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 2]]);
612 MS_ASSERT(fw_cand_kernel != nullptr);
613 auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[var_offset + 3]]);
614 MS_ASSERT(fw_cand_bias != nullptr);
615
616 auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset]]);
617 MS_ASSERT(bw_gate_kernel != nullptr);
618 auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 1]]);
619 MS_ASSERT(bw_gate_bias != nullptr);
620 auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 2]]);
621 MS_ASSERT(bw_cand_kernel != nullptr);
622 auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[var_offset + 3]]);
623 MS_ASSERT(bw_cand_bias != nullptr);
624
625 auto fw_init_state = utils::cast<AnfNodePtr>((*equiv)[fw_init_state_]);
626 MS_ASSERT(fw_init_state != nullptr);
627 auto bw_init_state = utils::cast<AnfNodePtr>((*equiv)[bw_init_state_]);
628 MS_ASSERT(bw_init_state != nullptr);
629 auto stacked_hidden = GetStackedHiddenState(func_graph, fw_init_state, bw_init_state, base_name);
630 if (stacked_hidden == nullptr) {
631 return nullptr;
632 }
633 auto input_length = utils::cast<AnfNodePtr>((*equiv)[input_length_]);
634 MS_ASSERT(input_length != nullptr);
635
636 int input_size = 0;
637 int hidden_size = 0;
638 auto status = GetInputAndHiddenSize(fw_cand_kernel, bw_cand_kernel, &input_size, &hidden_size);
639 if (status != RET_OK) {
640 return nullptr;
641 }
642 std::vector<int> gate_weight_shape{2, hidden_size * 3, input_size};
643 float *gate_tensor_data = nullptr;
644 auto gate_weight = AddDefaultParameter(func_graph, base_name + "_gate_weight", gate_weight_shape, kNumberTypeFloat32,
645 reinterpret_cast<void **>(&gate_tensor_data));
646 if (gate_weight == nullptr) {
647 return nullptr;
648 }
649 std::vector<int> recu_weight_shape{2, hidden_size * 3, hidden_size};
650 float *recu_tensor_data = nullptr;
651 auto recu_weight = AddDefaultParameter(func_graph, base_name + "_cand_weight", recu_weight_shape, kNumberTypeFloat32,
652 reinterpret_cast<void **>(&recu_tensor_data));
653 if (recu_weight == nullptr || recu_tensor_data == nullptr) {
654 return nullptr;
655 }
656 std::vector<int> bias_shape{2, hidden_size * 6};
657 float *bias_tensor_data = nullptr;
658 auto bias = AddDefaultParameter(func_graph, base_name + "_bias", bias_shape, kNumberTypeFloat32,
659 reinterpret_cast<void **>(&bias_tensor_data));
660 if (bias == nullptr || bias_tensor_data == nullptr) {
661 return nullptr;
662 }
663 for (int i = 0; i < 2 * hidden_size * 6; ++i) {
664 bias_tensor_data[i] = 0.0f;
665 }
666
667 if (ConvertWeightData(fw_gate_kernel, fw_cand_kernel, input_size, hidden_size, gate_tensor_data, recu_tensor_data) !=
668 RET_OK) {
669 return nullptr;
670 }
671 auto gate_data_diff = hidden_size * input_size * 3;
672 auto recu_data_diff = hidden_size * hidden_size * 3;
673 if (ConvertWeightData(bw_gate_kernel, bw_cand_kernel, input_size, hidden_size, gate_tensor_data + gate_data_diff,
674 recu_tensor_data + recu_data_diff) != RET_OK) {
675 return nullptr;
676 }
677
678 if (ConvertBiasData(fw_gate_bias, fw_cand_bias, hidden_size, bias_tensor_data) != RET_OK) {
679 return nullptr;
680 }
681 auto bias_data_diff = hidden_size * 6;
682 if (ConvertBiasData(bw_gate_bias, bw_cand_bias, hidden_size, bias_tensor_data + bias_data_diff) != RET_OK) {
683 return nullptr;
684 }
685 std::vector<AnfNodePtr> new_node_inputs = {value_node, input, gate_weight, recu_weight,
686 bias, stacked_hidden, input_length};
687 auto new_node = func_graph->NewCNode(new_node_inputs);
688 MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
689 auto prim = GetValueNode<PrimitivePtr>(new_node->input(0));
690 MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
691 prim->AddAttr(ops::kFormat, MakeValue<int64_t>(Format::NHWC));
692 new_node->set_fullname_with_scope(base_name);
693 return new_node;
694 }
695
// Appends split -> getitem x2 -> concat -> squeeze -> transpose after the fused GRU
// output so the result matches the layout produced by the original subgraph.
// Returns the final transpose node, or nullptr on failure.
CNodePtr TfBidirectionGruFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
                                                    const std::string &base_name) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(gru_output != nullptr);
  // Split the bidirectional output into two parts along axis 1.
  auto split_op = std::make_shared<ops::Split>();
  MS_CHECK_TRUE_RET(split_op != nullptr, nullptr);
  auto split_primc = split_op->GetPrim();
  MS_CHECK_TRUE_RET(split_primc != nullptr, nullptr);
  split_op->set_output_num(2);
  split_op->set_axis(1);
  auto split_anchor = NewValueNode(split_primc);
  MS_CHECK_TRUE_RET(split_anchor != nullptr, nullptr);
  auto split_cnode = func_graph->NewCNode({split_anchor, gru_output});
  MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
  split_cnode->set_fullname_with_scope("split_" + base_name);
  if (TfliteLstmCellFusion::SetAbstractTuple(split_cnode, 2) != RET_OK) {
    return nullptr;
  }

  // Extract both halves of the split tuple.
  auto first_half = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_cnode, 0);
  if (first_half == nullptr) {
    return nullptr;
  }
  auto second_half = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, split_cnode, 1);
  if (second_half == nullptr) {
    return nullptr;
  }

  // Re-join the two halves along axis 3.
  auto concat_op = std::make_shared<ops::Concat>();
  MS_CHECK_TRUE_RET(concat_op != nullptr, nullptr);
  auto concat_primc = concat_op->GetPrim();
  MS_CHECK_TRUE_RET(concat_primc != nullptr, nullptr);
  concat_op->set_axis(3);
  auto concat_anchor = NewValueNode(concat_primc);
  MS_CHECK_TRUE_RET(concat_anchor != nullptr, nullptr);
  auto concat_cnode = func_graph->NewCNode({concat_anchor, first_half, second_half});
  MS_CHECK_TRUE_RET(concat_cnode != nullptr, nullptr);
  concat_cnode->set_fullname_with_scope("concat_" + base_name);
  if (gru_output->abstract() != nullptr) {
    concat_cnode->set_abstract(gru_output->abstract()->Clone());
  }

  // Drop the now-size-1 axis 1.
  auto squeeze_op = std::make_shared<ops::Squeeze>();
  MS_CHECK_TRUE_RET(squeeze_op != nullptr, nullptr);
  auto squeeze_primc = squeeze_op->GetPrim();
  MS_CHECK_TRUE_RET(squeeze_primc != nullptr, nullptr);
  squeeze_op->set_axis(std::vector<int64_t>{1});
  auto squeeze_anchor = NewValueNode(squeeze_primc);
  MS_CHECK_TRUE_RET(squeeze_anchor != nullptr, nullptr);
  auto squeeze_cnode = func_graph->NewCNode({squeeze_anchor, concat_cnode});
  MS_CHECK_TRUE_RET(squeeze_cnode != nullptr, nullptr);
  squeeze_cnode->set_fullname_with_scope("squeeze_" + base_name);
  if (gru_output->abstract() != nullptr) {
    squeeze_cnode->set_abstract(gru_output->abstract()->Clone());
  }

  // Swap the first two axes with a {1, 0, 2} permutation.
  auto transpose_op = std::make_shared<ops::Transpose>();
  MS_CHECK_TRUE_RET(transpose_op != nullptr, nullptr);
  auto perm_param = BuildIntVecParameterNode(func_graph, {1, 0, 2}, "transpose_" + base_name + "_perm");
  MS_CHECK_TRUE_RET(perm_param != nullptr, nullptr);
  auto transpose_primc = transpose_op->GetPrim();
  MS_CHECK_TRUE_RET(transpose_primc != nullptr, nullptr);
  auto transpose_cnode = func_graph->NewCNode(transpose_primc, {squeeze_cnode, perm_param});
  MS_CHECK_TRUE_RET(transpose_cnode != nullptr, nullptr);
  transpose_cnode->set_fullname_with_scope("transpose_" + base_name);
  if (gru_output->abstract() != nullptr) {
    transpose_cnode->set_abstract(gru_output->abstract()->Clone());
  }

  return transpose_cnode;
}
770
// Entry point of the fusion pass: validates the matched forward/backward while-loop
// cond/body subgraphs against the expected GRU patterns, then replaces the whole
// structure with a single bidirectional GRU node plus post-processing ops.
// Returns the new output node, or nullptr if the pattern does not match.
const AnfNodePtr TfBidirectionGruFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
                                                 const EquivPtr &equiv) const {
  // `equiv` is dereferenced below, so it must be null-checked along with the others
  // (previously only func_graph and concat_node were checked).
  if (func_graph == nullptr || concat_node == nullptr || equiv == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }

  auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
  MS_ASSERT(transpose_input != nullptr);
  if (!utils::isa<CNodePtr>(transpose_input) || !CheckPrimitiveType(transpose_input, prim::kPrimTranspose)) {
    return nullptr;
  }

  // Checks one matched while-loop subgraph (cond or body) against its expected
  // pattern; factored out of the four previously duplicated stanzas.
  auto loop_graph_matches = [this](const AnfNodePtr &graph_node, bool is_body) -> bool {
    MS_ASSERT(graph_node != nullptr);
    auto primitive_vars = std::make_shared<PrimitiveVarMap>();
    MS_CHECK_TRUE_RET(primitive_vars != nullptr, false);
    AnfNodePtr pattern = nullptr;
    if (is_body) {
      pattern = GetBodyGraphPattern(primitive_vars);
    } else {
      pattern = GetCondGraphPattern(primitive_vars);
    }
    MS_CHECK_TRUE_RET(pattern != nullptr, false);
    auto sub_equiv = TfliteLstmCellFusion::CheckSubGraph(pattern, primitive_vars, graph_node,
                                                         is_body ? kBodyCNodesNum : kCondCNodesNum,
                                                         is_body ? kBodyNodesNum : kCondNodesNum);
    return sub_equiv != nullptr && !sub_equiv->empty();
  };

  // fw_vars_[0]/bw_vars_[0] hold the cond graphs, fw_vars_[1]/bw_vars_[1] the bodies.
  auto fw_cond = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[0]]);
  MS_ASSERT(fw_cond != nullptr);
  auto bw_cond = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[0]]);
  MS_ASSERT(bw_cond != nullptr);
  auto fw_body = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[1]]);
  MS_ASSERT(fw_body != nullptr);
  auto bw_body = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[1]]);
  MS_ASSERT(bw_body != nullptr);
  if (!loop_graph_matches(fw_cond, false) || !loop_graph_matches(bw_cond, false) ||
      !loop_graph_matches(fw_body, true) || !loop_graph_matches(bw_body, true)) {
    return nullptr;
  }

  const std::string gru_name = "gru_" + concat_node->fullname_with_scope();
  // var_offset 2: the weight vars start right after the cond/body graph vars.
  auto gru_node = CreateBiDirectionGruNode(func_graph, transpose_input, equiv, gru_name, 2);
  if (gru_node == nullptr) {
    return nullptr;
  }
  if (TfliteLstmCellFusion::SetAbstractTuple(gru_node, 2) != RET_OK) {
    return nullptr;
  }

  auto get_item_node = TfliteLstmCellFusion::CreateOutputGetItem(func_graph, gru_node, 0);
  if (get_item_node == nullptr) {
    return nullptr;
  }

  auto output_node = GetPostProcessNode(func_graph, get_item_node, gru_node->fullname_with_scope());
  // Don't claim success if post-processing failed (previously logged unconditionally).
  MS_CHECK_TRUE_RET(output_node != nullptr, nullptr);
  MS_LOG(INFO) << "gru node:" << gru_node->fullname_with_scope() << " fusion success";
  return output_node;
}
850 } // namespace opt
851 } // namespace mindspore
852