/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
#include <vector>
#include <memory>
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/trace_base.h"
#include "utils/tensor_construct_utils.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kDynamicRNNGradInputNum = 16;
constexpr size_t kSplitVOutputNum = 2;
constexpr size_t kBasicCellOutputNum = 2;
constexpr size_t kBasicLstmCStateGradOutput0DimNum = 3;
constexpr int64_t kAttrNValue = 2;
constexpr int64_t kAttrDynInputSizesValue = 2;
constexpr int64_t kAttrAxis2Value = 2;
constexpr int64_t kAttrNumSplitValue = 2;
constexpr int64_t kAttrSplitDimValue = 2;
constexpr size_t kDimMultiNum = 4;

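// Pre-builds, for each of the t time steps, a BasicLSTMCellCStateGradV2 node, a MatMul node and
// a SplitV node that only carry inferred output types/shapes and attrs. AddLSTMInputGradNode
// later copies these abstracts/attrs onto the real per-step nodes once their input edges exist.
// The first output's gate dimension is rounded up to a multiple of kCubeSize and multiplied by
// kDimMultiNum for the four LSTM gates.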
void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                     std::vector<std::vector<AnfNodePtr>> *result_nodes) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  MS_EXCEPTION_IF_NULL(result_nodes);
  std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_nodes;
  std::vector<AnfNodePtr> matmul_nodes;
  std::vector<AnfNodePtr> split_nodes;
  // Get the size of t
  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex11), 0);
  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex9), 0)[0];
  auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex12), 0);

  for (size_t i = 0; i < t_size; ++i) {
    // Create basic_lstm_cell_c_state_grad
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);

    std::vector<size_t> output0_dims{
      origin_input9_shape[kDim0],
      kDimMultiNum * (((origin_input9_shape[kDim1] + kCubeSize - 1) / kCubeSize) * kCubeSize)};
    std::vector<size_t> output1_dims{input_i_shape[kDim1], input_i_shape[kDim2]};
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
                                        basic_lstm_cell_c_state_grad.get());
    AnfAlgo::SetNodeAttr("forget_bias", MakeValue(1.0f), basic_lstm_cell_c_state_grad);
    AnfAlgo::SetNodeAttr("activation", MakeValue("Tanh"), basic_lstm_cell_c_state_grad);

    // Create matmul
    auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex2), 0);
    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
    auto matmul = func_graph->NewCNode(matmul_inputs);
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}},
                                        matmul.get());
    AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
    AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);

    // Create split
    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
    auto split_v = func_graph->NewCNode(splitv_input);
    auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
    auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
    std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
    std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
                                        {split_v_output0_shape, split_v_output1_shape}, split_v.get());

    AnfAlgo::SetNodeAttr(kAttrSizeSplits,
                         MakeValue(std::vector<int64_t>{
                           SizeToLong((origin_output2_shape[kDim2] + kCubeSize - 1) / kCubeSize * kCubeSize),
                           SizeToLong((origin_output3_shape[kDim1] + kCubeSize - 1) / kCubeSize * kCubeSize)}),
                         split_v);
    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(kAttrSplitDimValue)), split_v);
    AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(kAttrNumSplitValue)), split_v);

    basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
    matmul_nodes.emplace_back(matmul);
    split_nodes.emplace_back(split_v);
  }
  result_nodes->emplace_back(basic_lstm_cell_c_state_grad_nodes);
  result_nodes->emplace_back(matmul_nodes);
  result_nodes->emplace_back(split_nodes);
}

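// Helper that creates a SplitV slicing `input` into num_split_x single-step pieces along dim 0.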
AnfNodePtr CreateLSTMSPlitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                            const std::vector<std::vector<size_t>> &split_shapes,
                            const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
                            size_t num_split_x) {
  std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                              input};
  auto lstm_split = func_graph->NewCNode(lstm_split_input);
  AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get());
  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split);
  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), lstm_split);
  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(num_split_x)), lstm_split);
  return lstm_split;
}

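// Expands the fused LSTM backward pass into per-time-step nodes. Assuming the usual
// DynamicRNNGrad input order (x, w, b, y, init_h, init_c, h, c, dy, dh, dc, i, j, f, o, tanhct),
// it splits c, dy and the gate outputs along the time axis, then walks the steps in reverse,
// feeding each BasicLSTMCellCStateGradV2 the dh/dc produced by the step after it. Each step's
// gate gradient goes through MatMul (with w transposed) and SplitV to separate dx from dh.
// The dx slices and gate-gradient slices are re-concatenated along the time axis; the function
// returns the gate-gradient concat and appends dx, d(init_h) and d(init_c) to `outputs`.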
AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                                std::vector<AnfNodePtr> *outputs) {
  std::vector<std::vector<AnfNodePtr>> result_nodes;
  CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);

  auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
  std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};

  auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
  size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
  std::vector<std::vector<size_t>> split_shapes;
  std::vector<TypeId> split_types;
  std::vector<int64_t> size_split;
  for (size_t i = 0; i < num_split_x; ++i) {
    split_shapes.emplace_back(split_c_dims);
    split_types.emplace_back(kNumberTypeFloat32);
    size_split.emplace_back(1);
  }
  // Create lstm_split_c
  auto lstm_split_c = CreateLSTMSPlitV(func_graph, origin_input7, split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_c_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);

  // Create lstm_split_dy
  auto lstm_split_dy = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex9), split_shapes, split_types,
                                        size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_dy_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);

  // Create lstm_split_i
  auto lstm_split_i = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex12), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_i_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);

  // Create lstm_split_j
  auto lstm_split_j = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex13), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_j_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);

  // Create lstm_split_f
  auto lstm_split_f = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex14), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_f_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);

  // Create lstm_split_o
  auto lstm_split_o = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex15), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_o_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);

  // Create lstm_split_tanh
  auto lstm_split_tanh = CreateLSTMSPlitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex16), split_shapes,
                                          split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_tanh_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);

  // Add edges
  std::vector<AnfNodePtr> pre_basic_lstm_cell_c_state_grad_outputs;
  std::vector<AnfNodePtr> pre_split_outputs;
  auto basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
  auto matmul_nodes = result_nodes[kIndex1];
  auto split_nodes = result_nodes[kIndex2];
  std::vector<AnfNodePtr> lstm_x_concat_input(num_split_x + 1);
  lstm_x_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
  std::vector<AnfNodePtr> lstm_gage_concat_input(num_split_x + 1);
  lstm_gage_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));

  for (size_t i = 0; i < num_split_x; ++i) {
    size_t idx = num_split_x - i - 1;
    // Create basic_lstm_cell_c_state_grad
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
    if (i == num_split_x - 1) {
      std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                                dynamic_rnn_grad_cnode->input(kIndex6)};
      auto reshape = func_graph->NewCNode(reshape_inputs);
      auto reshape_out_shape = {IntToSize(1),
                                AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[0],
                                AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[1]};
      AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {reshape_out_shape}, reshape.get());
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(reshape);
    } else {
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_c_outputs[idx - 1]);
    }
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_dy_outputs[idx]);
    if (i == 0) {
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex10));
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex11));
    } else {
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_split_outputs[1]);
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
    }
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_i_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_j_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_f_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
    MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
    basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
    // Create outputs for current basic_lstm_cell_c_state_grad node
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_outputs;
    CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, kBasicCellOutputNum,
                                   &basic_lstm_cell_c_state_grad_outputs);
    pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;

    // Create MatMul
    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
    (void)matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
    (void)matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex2));
    auto matmul = func_graph->NewCNode(matmul_inputs);
    MS_EXCEPTION_IF_NULL(matmul);
    matmul->set_abstract(matmul_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);

    // Create splitv
    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                            matmul};
    auto split_v = func_graph->NewCNode(splitv_input);
    MS_EXCEPTION_IF_NULL(split_v);
    split_v->set_abstract(split_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);

    // Create outputs for current split node
    std::vector<AnfNodePtr> split_outputs;
    CreateMultipleOutputsOfAnfNode(func_graph, split_v, kSplitVOutputNum, &split_outputs);
    pre_split_outputs = split_outputs;

    lstm_x_concat_input[idx + 1] = split_outputs[0];

    auto basic_lstm_cell_c_state_grad_outputs_0_shape =
      AnfAlgo::GetOutputInferShape(basic_lstm_cell_c_state_grad_outputs[0], 0);
    std::vector<size_t> temp_shape;
    if (basic_lstm_cell_c_state_grad_outputs_0_shape.size() == kBasicLstmCStateGradOutput0DimNum) {
      temp_shape = basic_lstm_cell_c_state_grad_outputs_0_shape;
    } else {
      temp_shape = {1, basic_lstm_cell_c_state_grad_outputs_0_shape[0],
                    basic_lstm_cell_c_state_grad_outputs_0_shape[1]};
    }
    std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                             basic_lstm_cell_c_state_grad_outputs[0]};
    auto reshape = func_graph->NewCNode(reshape_input);
    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(basic_lstm_cell_c_state_grad_outputs[0], 0)},
                                        {temp_shape}, reshape.get());
    lstm_gage_concat_input[idx + 1] = reshape;
  }

  // Create lstm_x_concat
  auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 2)},
                                      lstm_x_concat.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_x_concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_x_concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), lstm_x_concat);

  // Create lstm_gage_concat
  auto lstm_gage_concat = func_graph->NewCNode(lstm_gage_concat_input);
  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
  AnfAlgo::SetOutputInferTypeAndShape(
    {kNumberTypeFloat16},
    {{origin_input7_shape[kDim0], origin_input7_shape[kDim1], kDimMultiNum * origin_input7_shape[kDim2]}},
    lstm_gage_concat.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gage_concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gage_concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), lstm_gage_concat);

  outputs->emplace_back(lstm_x_concat);
  outputs->emplace_back(pre_split_outputs[1]);
  outputs->emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
  return lstm_gage_concat;
}

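// Splits the forward h sequence ([t, batch, hidden]) along dim 0 into h[0:t-1] and h[t-1].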
AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input6 = dynamic_rnn_grad_cnode->input(kIndex7);
  std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                          origin_input6};
  auto split_v = func_graph->NewCNode(splitv_input);
  // Set infer data type and shape
  auto dtypes = {AnfAlgo::GetOutputInferDataType(origin_input6, 0), AnfAlgo::GetOutputInferDataType(origin_input6, 0)};
  auto origin_input6_shape = AnfAlgo::GetOutputInferShape(origin_input6, 0);
  std::vector<size_t> shape1 = {origin_input6_shape[kDim0] - 1, origin_input6_shape[kDim1], origin_input6_shape[kDim2]};
  std::vector<size_t> shape2 = {1, origin_input6_shape[kDim1], origin_input6_shape[kDim2]};
  std::vector<std::vector<size_t>> shapes = {shape1, shape2};
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(SizeToLong(0)), split_v);
  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(kAttrNumSplitValue)), split_v);
  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int64_t>{SizeToLong(origin_input6_shape[0] - 1), 1}),
                       split_v);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
  return split_v;
}

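// Builds the previous-hidden-state sequence: init_h (reshaped to [1, batch, hidden] when it is
// 2-D) concatenated with the first t-1 steps of h from CreateSplitV, giving [t, batch, hidden].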
AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                         const AnfNodePtr &splitv) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  MS_EXCEPTION_IF_NULL(splitv);
  // Create node
  std::vector<AnfNodePtr> splitv_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, splitv, kSplitVOutputNum, &splitv_outputs);
  if (splitv_outputs.size() != kSplitVOutputNum) {
    MS_LOG(EXCEPTION) << "Create outputs of node " << splitv->DebugString() << " failed"
                      << " trace: " << trace::DumpSourceLines(dynamic_rnn_grad_cnode);
  }
  auto origin_input4 = dynamic_rnn_grad_cnode->input(kIndex5);
  auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
  // Create reshape to change shape
  std::vector<size_t> shape_tmp;
  if (origin_input4_shape.size() == kShape3dDims) {
    shape_tmp = origin_input4_shape;
  } else {
    shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
  }
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           origin_input4};
  auto reshape = func_graph->NewCNode(reshape_input);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           reshape, splitv_outputs[0]};
  auto concat = func_graph->NewCNode(concat_inputs);
  // Set infer data type and shape
  auto splitv_output0_shape = AnfAlgo::GetOutputInferShape(splitv, 0);
  std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]};
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape}, concat.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kAttrNValue)), concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kAttrDynInputSizesValue}), concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
  return concat;
}

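// Concatenates x with the shifted hidden states along dim 2 to form the [x, h_prev] matrix
// that CreateBatchMatMul multiplies against the gate gradients to get dw.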
AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                        const AnfNodePtr &h_concat) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input0 = dynamic_rnn_grad_cnode->input(kIndex1);
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           origin_input0, h_concat};
  auto concat = func_graph->NewCNode(concat_inputs);
  // Set infer data type and shape
  auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
  auto h_concat_output_shape = AnfAlgo::GetOutputInferShape(h_concat, 0);
  std::vector<size_t> shape = {origin_input0_shape[kDim0], origin_input0_shape[kDim1],
                               origin_input0_shape[kDim2] + h_concat_output_shape[kDim2]};
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kAttrNValue)), concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kAttrDynInputSizesValue}), concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(kAttrAxis2Value)), concat);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
  return concat;
}

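// t == 1 variant of CreateSplitV/CreateHConcat/CreateConcat: with a single step there is no h
// sequence to shift, so x is concatenated directly with the reshaped init_h along dim 2.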
AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input0 = dynamic_rnn_grad_cnode->input(kIndex1);
  auto origin_input4 = dynamic_rnn_grad_cnode->input(kIndex5);
  auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
  // Create reshape to change shape
  std::vector<size_t> shape_tmp;
  if (origin_input4_shape.size() == kShape3dDims) {
    shape_tmp = origin_input4_shape;
  } else {
    shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
  }
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           origin_input4};
  auto reshape = func_graph->NewCNode(reshape_input);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());

  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           origin_input0, reshape};
  auto concat = func_graph->NewCNode(concat_inputs);
  // Set infer data type and shape
  auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
  std::vector<size_t> shape = {origin_input0_shape[kDim0], origin_input0_shape[kDim1],
                               origin_input0_shape[kDim2] + shape_tmp[kDim2]};
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kAttrNValue)), concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kAttrDynInputSizesValue}), concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(kAttrAxis2Value)), concat);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
  return concat;
}

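// dw = [x, h_prev]^T * dgate: a BatchMatMul with transpose_x1 over t batched multiplications of
// [batch, input_size + hidden] and [batch, 4 * hidden], yielding [t, input_size + hidden, 4 * hidden].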
AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                             const AnfNodePtr &concat) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           concat, lstm_input_grad};
  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
  // Set infer data type and shape
  auto concat_shape = AnfAlgo::GetOutputInferShape(concat, 0);
  auto lstm_input_grad_shape = AnfAlgo::GetOutputInferShape(lstm_input_grad, 0);
  std::vector<size_t> shape = {concat_shape[kDim0], concat_shape[kDim2], lstm_input_grad_shape[kDim2]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, batch_matmul.get());
  // Set attr
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
  return batch_matmul;
}

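// Multiplies an all-ones [t, 1, batch] tensor (see CreateValueNode) with the gate gradients
// [t, batch, 4 * hidden], summing them over the batch dimension into [t, 1, 4 * hidden].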
AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                              const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           node, lstm_input_grad};
  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
  // Set infer data type and shape
  auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex0], IntToSize(1),
                    AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex2]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, batch_matmul.get());
  // Set attr
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
  return batch_matmul;
}

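// Sums the per-step dw contributions over the time axis (dim 0) to get the final dw output.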
AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                             const AnfNodePtr &batch_matmul) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               batch_matmul};
  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
  // Set infer data type and shape
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
  return reduce_sum;
}

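// t == 1 variant of CreateDwReduceSum: no time reduction is needed, only a Reshape to the
// expected dw shape.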
AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                           const AnfNodePtr &batch_matmul) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            batch_matmul};
  auto reshape = func_graph->NewCNode(reshape_inputs);
  // Set infer data type and shape
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
  return reshape;
}

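// Creates a constant all-ones tensor of shape [t, 1, batch], the left operand CreateBatchMatMul2
// uses to reduce the gate gradients over the batch dimension.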
AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
  auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
  auto t_size = origin_input7_shape[0];
  auto n_size = origin_input7_shape[1];

  std::vector<size_t> shape = {t_size, IntToSize(1), n_size};
  std::vector<int64_t> output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)};
  std::vector<int64_t> output_tensor = {SizeToLong(t_size) * SizeToLong(n_size)};
  auto tensor = TensorConstructUtils::CreateOnesTensor(kFloat32, output_tensor);
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
  kernel_graph->AddValueNodeToGraph(value_node);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, value_node.get());
  return value_node;
}

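// db: batch-multiplies the ones tensor with the gate gradients and sums the [t, 1, 4 * hidden]
// result over dim 0, leaving the [4 * hidden] bias gradient.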
AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, const AnfNodePtr &lstm_input_grad,
                             const AnfNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  auto batch_matmul = CreateBatchMatMul2(func_graph, lstm_input_grad, value_node);
  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               batch_matmul};
  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
  // Set infer data type and shape
  auto out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kDim2]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
  return reduce_sum;
}
}  // namespace

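// Matches any DynamicRNNGrad node regardless of its inputs.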
const BaseRef DynamicRnnGradFissionV2::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
}

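// Entry point of the fission. Bails out on unexpected input counts or dynamic shapes, builds the
// per-step LSTM input-grad subgraph, assembles dw (BatchMatMul plus ReduceSum or Reshape) and db
// (ones-tensor BatchMatMul plus ReduceSum), and returns a MakeTuple that replaces the original
// DynamicRNNGrad outputs (dw, db, dx, dh, dc).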
const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto dynamic_rnn_grad_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  if (dynamic_rnn_grad_cnode->inputs().size() < kDynamicRNNGradInputNum + 1) {
    MS_LOG(INFO) << "The node " << dynamic_rnn_grad_cnode->DebugString() << " has less than "
                 << (kDynamicRNNGradInputNum + 1) << " inputs";
    return nullptr;
  }
  if (AnfAlgo::IsDynamicShape(node)) {
    MS_LOG(INFO) << "DynamicRnnGrad is dynamic shape, can not do fission.";
    return nullptr;
  }
  std::vector<AnfNodePtr> new_outputs;
  auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);

  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[0];
  size_t hidden_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[kDim2];
  if (hidden_size % kCubeSize != 0) {
    MS_LOG(EXCEPTION) << "`hidden_size` in this node should be multiple of 16, but got " << hidden_size << ". "
                      << dynamic_rnn_grad_cnode->DebugString();
  }
  AnfNodePtr concat = nullptr;
  if (t_size != 1) {
    auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);
    auto h_concat = CreateHConcat(func_graph, dynamic_rnn_grad_cnode, splitv);
    concat = CreateConcat(func_graph, dynamic_rnn_grad_cnode, h_concat);
  } else {
    concat = CreateConcatNodeT1(func_graph, dynamic_rnn_grad_cnode);
  }

  auto batch_matmul = CreateBatchMatMul(func_graph, lstm_input_grad, concat);
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  if (t_size != 1) {
    auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
    (void)make_tuple_inputs.emplace_back(dw_reduce_sum);
  } else {
    auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
    (void)make_tuple_inputs.emplace_back(dw_reshape);
  }

  auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);
  // create reduce_sum_2
  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node);
  (void)make_tuple_inputs.emplace_back(db_reduce_sum);
  make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
}  // namespace opt
}  // namespace mindspore