/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
#include <vector>
#include <memory>
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/trace_base.h"
#include "utils/tensor_construct_utils.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kDynamicRNNGradInputNum = 16;
constexpr size_t kSplitVOutputNum = 2;
constexpr size_t kBasicCellOutputNum = 2;
constexpr size_t kBasicLstmCStateGradOutput0DimNum = 3;
constexpr int64_t kAttrNValue = 2;
constexpr int64_t kAttrDynInputSizesValue = 2;
constexpr int64_t kAttrAxis2Value = 2;
constexpr int64_t kAttrNumSplitValue = 2;
constexpr int64_t kAttrSplitDimValue = 2;
constexpr size_t kDimMultiNum = 4;

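// Creates the per-timestep building blocks of the fissioned backward pass: for each of the t steps, one
// BasicLSTMCellCStateGradV2 node, one MatMul node and one SplitV node, with inferred output types/shapes and
// attributes set here; their inputs are wired up later in AddLSTMInputGradNode. The three node lists are
// returned through result_nodes in that order.
// Note: comments in this file interpret the DynamicRNNGrad input indices per the usual op signature
// (x, w, b, y, init_h, init_c, h, c, dy, dh, dc, i, j, f, o, tanhct). That mapping is an assumption made for
// documentation only and does not affect the code.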
void CreateTLoopNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                     std::vector<std::vector<AnfNodePtr>> *result_nodes) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  MS_EXCEPTION_IF_NULL(result_nodes);
  std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_nodes;
  std::vector<AnfNodePtr> matmul_nodes;
  std::vector<AnfNodePtr> split_nodes;
  // Get the size of t
  auto origin_input9_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex11), 0);
  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex9), 0)[0];
  auto input_i_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex12), 0);

  for (size_t i = 0; i < t_size; ++i) {
    // Create basic_lstm_cell_c_state_grad
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);

    std::vector<size_t> output0_dims{
      origin_input9_shape[kDim0],
      kDimMultiNum * (((origin_input9_shape[kDim1] + kCubeSize - 1) / kCubeSize) * kCubeSize)};
    std::vector<size_t> output1_dims{input_i_shape[kDim1], input_i_shape[kDim2]};
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16, kNumberTypeFloat32}, {output0_dims, output1_dims},
                                        basic_lstm_cell_c_state_grad.get());
    AnfAlgo::SetNodeAttr("forget_bias", MakeValue(1.0f), basic_lstm_cell_c_state_grad);
    AnfAlgo::SetNodeAttr("activation", MakeValue("Tanh"), basic_lstm_cell_c_state_grad);

    // Create matmul
    auto origin_input1_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex2), 0);
    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
    auto matmul = func_graph->NewCNode(matmul_inputs);
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32},
                                        {{IntToSize(1), output0_dims[0], origin_input1_shape[0]}}, matmul.get());
    AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), matmul);
    AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(true), matmul);

    // Create split
    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
    auto split_v = func_graph->NewCNode(splitv_input);
    auto origin_output2_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2);
    auto origin_output3_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex3);
    std::vector<size_t> split_v_output0_shape{IntToSize(1), origin_output2_shape[kDim1], origin_output2_shape[kDim2]};
    std::vector<size_t> split_v_output1_shape{IntToSize(1), origin_output3_shape[kDim0], origin_output3_shape[kDim1]};
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32},
                                        {split_v_output0_shape, split_v_output1_shape}, split_v.get());

    AnfAlgo::SetNodeAttr(kAttrSizeSplits,
                         MakeValue(std::vector<int64_t>{
                           SizeToLong((origin_output2_shape[kDim2] + kCubeSize - 1) / kCubeSize * kCubeSize),
                           SizeToLong((origin_output3_shape[kDim1] + kCubeSize - 1) / kCubeSize * kCubeSize)}),
                         split_v);
    AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(kAttrSplitDimValue)), split_v);
    AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(static_cast<int64_t>(kAttrNumSplitValue)), split_v);

    basic_lstm_cell_c_state_grad_nodes.emplace_back(basic_lstm_cell_c_state_grad);
    matmul_nodes.emplace_back(matmul);
    split_nodes.emplace_back(split_v);
  }
  result_nodes->emplace_back(basic_lstm_cell_c_state_grad_nodes);
  result_nodes->emplace_back(matmul_nodes);
  result_nodes->emplace_back(split_nodes);
}

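// Helper that creates a SplitV node splitting `input` into num_split_x slices along dim 0, with the given
// per-slice output shapes, types and size_splits.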
AnfNodePtr CreateLSTMSplitV(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                            const std::vector<std::vector<size_t>> &split_shapes,
                            const std::vector<TypeId> &split_types, const std::vector<int64_t> &size_split,
                            size_t num_split_x) {
  std::vector<AnfNodePtr> lstm_split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                              input};
  auto lstm_split = func_graph->NewCNode(lstm_split_input);
  AnfAlgo::SetOutputInferTypeAndShape(split_types, split_shapes, lstm_split.get());
  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_split), lstm_split);
  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(static_cast<int64_t>(0)), lstm_split);
  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(num_split_x)), lstm_split);
  return lstm_split;
}

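// Expands the recurrent part of DynamicRNNGrad into an explicit loop over time: the per-step tensors
// (c, dy, i, j, f, o, tanhct) are split along the t dimension, then one
// BasicLSTMCellCStateGradV2 -> MatMul -> SplitV chain is wired per step, iterating in reverse time order. The
// dx slices are concatenated into lstm_x_concat and the per-step gate gradients into lstm_gate_concat (the
// return value, consumed by the dw/db paths); dx, dh_prev and dc_prev are appended to `outputs`.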
AnfNodePtr AddLSTMInputGradNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                                std::vector<AnfNodePtr> *outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  MS_EXCEPTION_IF_NULL(outputs);
  std::vector<std::vector<AnfNodePtr>> result_nodes;
  CreateTLoopNode(func_graph, dynamic_rnn_grad_cnode, &result_nodes);

  auto origin_input5_shape = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0);
  std::vector<size_t> split_c_dims{IntToSize(1), origin_input5_shape[0], origin_input5_shape[1]};

  auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
  size_t num_split_x = AnfAlgo::GetOutputInferShape(origin_input7, 0)[0];
  std::vector<std::vector<size_t>> split_shapes;
  std::vector<TypeId> split_types;
  std::vector<int64_t> size_split;
  for (size_t i = 0; i < num_split_x; ++i) {
    split_shapes.emplace_back(split_c_dims);
    split_types.emplace_back(kNumberTypeFloat32);
    size_split.emplace_back(1);
  }
  // Create lstm_split_c
  auto lstm_split_c = CreateLSTMSplitV(func_graph, origin_input7, split_shapes, split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_c_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_c, num_split_x, &lstm_split_c_outputs);

  // Create lstm_split_dy
  auto lstm_split_dy = CreateLSTMSplitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex9), split_shapes, split_types,
                                        size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_dy_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_dy, num_split_x, &lstm_split_dy_outputs);

  // Create lstm_split_i
  auto lstm_split_i = CreateLSTMSplitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex12), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_i_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_i, num_split_x, &lstm_split_i_outputs);

  // Create lstm_split_j
  auto lstm_split_j = CreateLSTMSplitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex13), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_j_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_j, num_split_x, &lstm_split_j_outputs);

  // Create lstm_split_f
  auto lstm_split_f = CreateLSTMSplitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex14), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_f_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_f, num_split_x, &lstm_split_f_outputs);

  // Create lstm_split_o
  auto lstm_split_o = CreateLSTMSplitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex15), split_shapes, split_types,
                                       size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_o_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_o, num_split_x, &lstm_split_o_outputs);

  // Create lstm_split_tanh
  auto lstm_split_tanh = CreateLSTMSplitV(func_graph, dynamic_rnn_grad_cnode->input(kIndex16), split_shapes,
                                          split_types, size_split, num_split_x);
  std::vector<AnfNodePtr> lstm_split_tanh_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, lstm_split_tanh, num_split_x, &lstm_split_tanh_outputs);

  // Add edges
  std::vector<AnfNodePtr> pre_basic_lstm_cell_c_state_grad_outputs;
  std::vector<AnfNodePtr> pre_split_outputs;
  auto basic_lstm_cell_c_state_grad_nodes = result_nodes[kIndex0];
  auto matmul_nodes = result_nodes[kIndex1];
  auto split_nodes = result_nodes[kIndex2];
  std::vector<AnfNodePtr> lstm_x_concat_input(num_split_x + 1);
  lstm_x_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));
  std::vector<AnfNodePtr> lstm_gate_concat_input(num_split_x + 1);
  lstm_gate_concat_input[0] = NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()));

  for (size_t i = 0; i < num_split_x; ++i) {
    size_t idx = num_split_x - i - 1;
    // Create basic_lstm_cell_c_state_grad
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_inputs = {
      NewValueNode(std::make_shared<Primitive>(kBasicLSTMCellCStateGradV2OpName))};
    if (i == num_split_x - 1) {
      std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                                dynamic_rnn_grad_cnode->input(kIndex6)};
      auto reshape = func_graph->NewCNode(reshape_inputs);
      std::vector<size_t> reshape_out_shape = {
        IntToSize(1), AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[0],
        AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex6), 0)[1]};
      AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {reshape_out_shape}, reshape.get());
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(reshape);
    } else {
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_c_outputs[idx - 1]);
    }
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_dy_outputs[idx]);
    if (i == 0) {
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex10));
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex11));
    } else {
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_split_outputs[1]);
      (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
    }
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_i_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_j_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_f_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_o_outputs[idx]);
    (void)basic_lstm_cell_c_state_grad_inputs.emplace_back(lstm_split_tanh_outputs[idx]);
    auto basic_lstm_cell_c_state_grad = func_graph->NewCNode(basic_lstm_cell_c_state_grad_inputs);
    MS_EXCEPTION_IF_NULL(basic_lstm_cell_c_state_grad);
    basic_lstm_cell_c_state_grad->set_abstract(basic_lstm_cell_c_state_grad_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(basic_lstm_cell_c_state_grad_nodes[i], basic_lstm_cell_c_state_grad);
    // Create outputs for current basic_lstm_cell_c_state_grad node
    std::vector<AnfNodePtr> basic_lstm_cell_c_state_grad_outputs;
    CreateMultipleOutputsOfAnfNode(func_graph, basic_lstm_cell_c_state_grad, kBasicCellOutputNum,
                                   &basic_lstm_cell_c_state_grad_outputs);
    pre_basic_lstm_cell_c_state_grad_outputs = basic_lstm_cell_c_state_grad_outputs;

    // Create MatMul
    std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name()))};
    (void)matmul_inputs.emplace_back(basic_lstm_cell_c_state_grad_outputs[0]);
    (void)matmul_inputs.emplace_back(dynamic_rnn_grad_cnode->input(kIndex2));
    auto matmul = func_graph->NewCNode(matmul_inputs);
    MS_EXCEPTION_IF_NULL(matmul);
    matmul->set_abstract(matmul_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(matmul_nodes[i], matmul);

    // Create splitv
    std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                            matmul};
    auto split_v = func_graph->NewCNode(splitv_input);
    MS_EXCEPTION_IF_NULL(split_v);
    split_v->set_abstract(split_nodes[i]->abstract());
    AnfAlgo::CopyNodeAttrs(split_nodes[i], split_v);

    // Create outputs for current split node
    std::vector<AnfNodePtr> split_outputs;
    CreateMultipleOutputsOfAnfNode(func_graph, split_v, kSplitVOutputNum, &split_outputs);
    pre_split_outputs = split_outputs;

    lstm_x_concat_input[idx + 1] = split_outputs[0];

    auto basic_lstm_cell_c_state_grad_outputs_0_shape =
      AnfAlgo::GetOutputInferShape(basic_lstm_cell_c_state_grad_outputs[0], 0);
    std::vector<size_t> temp_shape;
    if (basic_lstm_cell_c_state_grad_outputs_0_shape.size() == kBasicLstmCStateGradOutput0DimNum) {
      temp_shape = basic_lstm_cell_c_state_grad_outputs_0_shape;
    } else {
      temp_shape = {1, basic_lstm_cell_c_state_grad_outputs_0_shape[0],
                    basic_lstm_cell_c_state_grad_outputs_0_shape[1]};
    }
    std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                             basic_lstm_cell_c_state_grad_outputs[0]};
    auto reshape = func_graph->NewCNode(reshape_input);
    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(basic_lstm_cell_c_state_grad_outputs[0], 0)},
                                        {temp_shape}, reshape.get());
    lstm_gate_concat_input[idx + 1] = reshape;
  }

  // Create lstm_x_concat
  auto lstm_x_concat = func_graph->NewCNode(lstm_x_concat_input);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32},
                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, kIndex2)},
                                      lstm_x_concat.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_x_concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_x_concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(static_cast<int64_t>(0)), lstm_x_concat);

  // Create lstm_gate_concat
  auto lstm_gate_concat = func_graph->NewCNode(lstm_gate_concat_input);
  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
  AnfAlgo::SetOutputInferTypeAndShape(
    {kNumberTypeFloat16},
    {{origin_input7_shape[kDim0], origin_input7_shape[kDim1], kDimMultiNum * origin_input7_shape[kDim2]}},
    lstm_gate_concat.get());
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(num_split_x)), lstm_gate_concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(num_split_x)}), lstm_gate_concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), lstm_gate_concat);

  outputs->emplace_back(lstm_x_concat);
  outputs->emplace_back(pre_split_outputs[1]);
  outputs->emplace_back(pre_basic_lstm_cell_c_state_grad_outputs[1]);
  return lstm_gate_concat;
}

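// Splits the hidden-state sequence (origin input 6, presumably h) along dim 0 into its first t-1 steps and
// its last step; the first piece feeds CreateHConcat below.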
AnfNodePtr CreateSplitV(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input6 = dynamic_rnn_grad_cnode->input(kIndex7);
  std::vector<AnfNodePtr> splitv_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name())),
                                          origin_input6};
  auto split_v = func_graph->NewCNode(splitv_input);
  // Set infer data type and shape
  auto dtypes = {AnfAlgo::GetOutputInferDataType(origin_input6, 0), AnfAlgo::GetOutputInferDataType(origin_input6, 0)};
  auto origin_input6_shape = AnfAlgo::GetOutputInferShape(origin_input6, 0);
  std::vector<size_t> shape1 = {origin_input6_shape[kDim0] - 1, origin_input6_shape[kDim1],
                                origin_input6_shape[kDim2]};
  std::vector<size_t> shape2 = {1, origin_input6_shape[kDim1], origin_input6_shape[kDim2]};
  std::vector<std::vector<size_t>> shapes = {shape1, shape2};
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split_v.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(SizeToLong(0)), split_v);
  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(SizeToLong(kAttrNumSplitValue)), split_v);
  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(std::vector<int64_t>{SizeToLong(origin_input6_shape[0] - 1), 1}),
                       split_v);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), split_v);
  return split_v;
}

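// Reshapes origin input 4 (presumably init_h) to rank 3 and concatenates it with the first t-1 steps of h
// from CreateSplitV along dim 0, i.e. builds the sequence of hidden states that preceded each timestep.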
AnfNodePtr CreateHConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                         const AnfNodePtr &splitv) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  MS_EXCEPTION_IF_NULL(splitv);
  // Create node
  std::vector<AnfNodePtr> splitv_outputs;
  CreateMultipleOutputsOfAnfNode(func_graph, splitv, kSplitVOutputNum, &splitv_outputs);
  if (splitv_outputs.size() != kSplitVOutputNum) {
    MS_LOG(EXCEPTION) << "Create outputs of node " << splitv->DebugString() << " failed."
                      << " trace: " << trace::DumpSourceLines(dynamic_rnn_grad_cnode);
  }
  auto origin_input4 = dynamic_rnn_grad_cnode->input(kIndex5);
  auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
  // Create reshape to change shape
  std::vector<size_t> shape_tmp;
  if (origin_input4_shape.size() == kShape3dDims) {
    shape_tmp = origin_input4_shape;
  } else {
    shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
  }
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           origin_input4};
  auto reshape = func_graph->NewCNode(reshape_input);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           reshape, splitv_outputs[0]};
  auto concat = func_graph->NewCNode(concat_inputs);
  // Set infer data type and shape
  auto splitv_output0_shape = AnfAlgo::GetOutputInferShape(splitv, 0);
  std::vector<size_t> shape = {splitv_output0_shape[0] + 1, origin_input4_shape[0], origin_input4_shape[1]};
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape}, concat.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kAttrNValue)), concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kAttrDynInputSizesValue}), concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(0)), concat);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
  return concat;
}

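// Concatenates origin input 0 (presumably x) with the shifted hidden-state sequence from CreateHConcat along
// the last axis, forming the [x, h] matrix that the weight gradient is computed from.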
AnfNodePtr CreateConcat(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                        const AnfNodePtr &h_concat) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input0 = dynamic_rnn_grad_cnode->input(kIndex1);
  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           origin_input0, h_concat};
  auto concat = func_graph->NewCNode(concat_inputs);
  // Set infer data type and shape
  auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
  auto h_concat_output_shape = AnfAlgo::GetOutputInferShape(h_concat, 0);
  std::vector<size_t> shape = {origin_input0_shape[kDim0], origin_input0_shape[kDim1],
                               origin_input0_shape[kDim2] + h_concat_output_shape[kDim2]};
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kAttrNValue)), concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kAttrDynInputSizesValue}), concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(kAttrAxis2Value)), concat);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
  return concat;
}

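// t == 1 variant of the two functions above: with a single timestep there is no hidden-state history, so x is
// concatenated directly with the reshaped init_h.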
AnfNodePtr CreateConcatNodeT1(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  // Create node
  auto origin_input0 = dynamic_rnn_grad_cnode->input(kIndex1);
  auto origin_input4 = dynamic_rnn_grad_cnode->input(kIndex5);
  auto origin_input4_shape = AnfAlgo::GetOutputInferShape(origin_input4, 0);
  // Create reshape to change shape
  std::vector<size_t> shape_tmp;
  if (origin_input4_shape.size() == kShape3dDims) {
    shape_tmp = origin_input4_shape;
  } else {
    shape_tmp = {1, origin_input4_shape[0], origin_input4_shape[1]};
  }
  std::vector<AnfNodePtr> reshape_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                           origin_input4};
  auto reshape = func_graph->NewCNode(reshape_input);
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input4, 0)}, {shape_tmp}, reshape.get());

  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name())),
                                           origin_input0, reshape};
  auto concat = func_graph->NewCNode(concat_inputs);
  // Set infer data type and shape
  auto origin_input0_shape = AnfAlgo::GetOutputInferShape(origin_input0, 0);
  std::vector<size_t> shape = {origin_input0_shape[kDim0], origin_input0_shape[kDim1],
                               origin_input0_shape[kDim2] + shape_tmp[kDim2]};
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_input0, 0)}, {shape}, concat.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(kAttrNValue)), concat);
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{kAttrDynInputSizesValue}), concat);
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(kAttrAxis2Value)), concat);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), concat);
  return concat;
}

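// dw path: batch-multiplies the transposed [x, h] concat with the gate gradients from AddLSTMInputGradNode
// (transpose_x1 = true), yielding one weight-gradient slice per timestep.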
AnfNodePtr CreateBatchMatMul(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                             const AnfNodePtr &concat) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           concat, lstm_input_grad};
  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
  // Set infer data type and shape
  auto concat_shape = AnfAlgo::GetOutputInferShape(concat, 0);
  auto lstm_input_grad_shape = AnfAlgo::GetOutputInferShape(lstm_input_grad, 0);
  std::vector<size_t> shape = {concat_shape[kDim0], concat_shape[kDim2], lstm_input_grad_shape[kDim2]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, batch_matmul.get());
  // Set attr
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
  return batch_matmul;
}

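// db path: multiplies a constant ones tensor of shape [t, 1, n] (see CreateValueNode) with the gate
// gradients, which sums them over the batch dimension for each timestep.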
AnfNodePtr CreateBatchMatMul2(const FuncGraphPtr &func_graph, const AnfNodePtr &lstm_input_grad,
                              const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> matmul_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimBatchMatMul->name())),
                                           node, lstm_input_grad};
  auto batch_matmul = func_graph->NewCNode(matmul_inputs);
  // Set infer data type and shape
  std::vector<size_t> out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex0], IntToSize(1),
                                   AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kIndex2]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, batch_matmul.get());
  // Set attr
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x1", MakeValue(false), batch_matmul);
  AnfAlgo::SetNodeAttr("transpose_x2", MakeValue(false), batch_matmul);
  return batch_matmul;
}

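// Reduces the per-timestep weight-gradient slices over dim 0 to produce the final dw when t > 1.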
AnfNodePtr CreateDwReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                             const AnfNodePtr &batch_matmul) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               batch_matmul};
  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
  // Set infer data type and shape
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reduce_sum.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
  return reduce_sum;
}

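// t == 1 variant of CreateDwReduceSum: a single slice only needs to be reshaped to dw's expected shape.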
AnfNodePtr CreateDwReshape(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode,
                           const AnfNodePtr &batch_matmul) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  std::vector<AnfNodePtr> reshape_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReshape->name())),
                                            batch_matmul};
  auto reshape = func_graph->NewCNode(reshape_inputs);
  // Set infer data type and shape
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(dynamic_rnn_grad_cnode, 0)},
                                      {AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode, 0)}, reshape.get());
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reshape);
  return reshape;
}

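// Creates the constant ones tensor of shape [t, 1, n] consumed by CreateBatchMatMul2.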
AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynamic_rnn_grad_cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  auto origin_input7 = dynamic_rnn_grad_cnode->input(kIndex8);
  auto origin_input7_shape = AnfAlgo::GetOutputInferShape(origin_input7, 0);
  auto t_size = origin_input7_shape[0];
  auto n_size = origin_input7_shape[1];

  std::vector<size_t> shape = {t_size, IntToSize(1), n_size};
  std::vector<int64_t> output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)};
  std::vector<int64_t> output_tensor = {SizeToLong(t_size) * SizeToLong(n_size)};
  auto tensor = TensorConstructUtils::CreateOnesTensor(kFloat32, output_tensor);
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
  kernel_graph->AddValueNodeToGraph(value_node);
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, value_node.get());
  return value_node;
}

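// Builds the bias-gradient path: CreateBatchMatMul2 sums the gate gradients over batch, then ReduceSum over
// dim 0 collapses the time dimension, yielding db.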
AnfNodePtr CreateDbReduceSum(const FuncGraphPtr &func_graph, const CNodePtr &, const AnfNodePtr &lstm_input_grad,
                             const AnfNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  // Create node
  auto batch_matmul = CreateBatchMatMul2(func_graph, lstm_input_grad, value_node);
  std::vector<AnfNodePtr> reduce_sum_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())),
                                               batch_matmul};
  auto reduce_sum = func_graph->NewCNode(reduce_sum_inputs);
  // Set infer data type and shape
  std::vector<size_t> out_shape = {AnfAlgo::GetOutputInferShape(lstm_input_grad, 0)[kDim2]};
  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat16}, {out_shape}, reduce_sum.get());
  // Set attr
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{0}), reduce_sum);
  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(false), reduce_sum);
  AnfAlgo::SetNodeAttr("is_backend_insert", MakeValue(true), reduce_sum);
  return reduce_sum;
}
}  // namespace

const BaseRef DynamicRnnGradFissionV2::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimDynamicRNNGrad, Xs});
}

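// Entry point of the pass: expands a static-shape DynamicRNNGrad into the explicit subgraph built above and
// returns a MakeTuple whose elements replace the original outputs (dw, db, dx, dh_prev, dc_prev).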
const AnfNodePtr DynamicRnnGradFissionV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto dynamic_rnn_grad_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(dynamic_rnn_grad_cnode);
  if (dynamic_rnn_grad_cnode->inputs().size() < kDynamicRNNGradInputNum + 1) {
    MS_LOG(INFO) << "The node " << dynamic_rnn_grad_cnode->DebugString() << " has less than "
                 << (kDynamicRNNGradInputNum + 1) << " inputs";
    return nullptr;
  }
  if (AnfAlgo::IsDynamicShape(node)) {
    MS_LOG(INFO) << "DynamicRNNGrad is dynamic shape, cannot do fission.";
    return nullptr;
  }
  std::vector<AnfNodePtr> new_outputs;
  auto lstm_input_grad = AddLSTMInputGradNode(func_graph, dynamic_rnn_grad_cnode, &new_outputs);

  size_t t_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[0];
  size_t hidden_size = AnfAlgo::GetOutputInferShape(dynamic_rnn_grad_cnode->input(kIndex7), 0)[kDim2];
  if (hidden_size % kCubeSize != 0) {
    MS_LOG(EXCEPTION) << "`hidden_size` in this node should be a multiple of 16, but got " << hidden_size << ". "
                      << dynamic_rnn_grad_cnode->DebugString();
  }
  AnfNodePtr concat = nullptr;
  if (t_size != 1) {
    auto splitv = CreateSplitV(func_graph, dynamic_rnn_grad_cnode);
    auto h_concat = CreateHConcat(func_graph, dynamic_rnn_grad_cnode, splitv);
    concat = CreateConcat(func_graph, dynamic_rnn_grad_cnode, h_concat);
  } else {
    concat = CreateConcatNodeT1(func_graph, dynamic_rnn_grad_cnode);
  }

  auto batch_matmul = CreateBatchMatMul(func_graph, lstm_input_grad, concat);
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  if (t_size != 1) {
    auto dw_reduce_sum = CreateDwReduceSum(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
    (void)make_tuple_inputs.emplace_back(dw_reduce_sum);
  } else {
    auto dw_reshape = CreateDwReshape(func_graph, dynamic_rnn_grad_cnode, batch_matmul);
    (void)make_tuple_inputs.emplace_back(dw_reshape);
  }

  auto value_node = CreateValueNode(func_graph, dynamic_rnn_grad_cnode);
  // Create reduce_sum_2
  auto db_reduce_sum = CreateDbReduceSum(func_graph, dynamic_rnn_grad_cnode, lstm_input_grad, value_node);
  (void)make_tuple_inputs.emplace_back(db_reduce_sum);
  make_tuple_inputs.insert(make_tuple_inputs.end(), new_outputs.begin(), new_outputs.end());
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  return make_tuple;
}
}  // namespace opt
}  // namespace mindspore