/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/ir/ir_pass.h"
#include <memory>
#include <vector>
#include <functional>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include "pipeline/pynative/pynative_utils.h"
#include "ops/sequence_ops.h"
#include "ops/nn_ops.h"
#include "ops/op_utils.h"
#include "include/backend/optimizer/helper.h"
#include "include/common/utils/hook.h"
#include "runtime/pynative/op_function/pyboost_grad_functions.h"

namespace mindspore {
namespace pynative {
namespace bprop_pass {
namespace {
constexpr auto kTupleToMakeTuple = "tuple_to_make_tuple";

mindspore::HashMap<AnfNodePtr, std::vector<std::pair<size_t, AnfNodePtr>>> node_attr_value_;

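// Replace the value held by a ValueNode with an equivalent constant tensor and refresh its abstract.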
void CreateTensorByConstantValue(const ValueNodePtr &v_node) {
  MS_EXCEPTION_IF_NULL(v_node);
  const auto &value = v_node->value();
  MS_EXCEPTION_IF_NULL(value);
  auto tensor_ptr = PyNativeAlgo::Common::CreateTensorByConstantValue(value);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  v_node->set_value(tensor_ptr);
  v_node->set_abstract(tensor_ptr->ToAbstract());
}

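// Move the constant value inputs of cnode (at the positions listed in input_to_attr) onto the primitive as
// attributes and rebuild the cnode inputs without them. When grad_by_value is false, the removed value nodes
// are recorded in node_attr_value_ so the reverse pass can restore them.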
void ChangeInputToAttr(const PrimitivePtr &prim, const CNodePtr &cnode, const ValuePtr &input_names,
                       const mindspore::HashSet<size_t> &input_to_attr, bool grad_by_value) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_names);
  const auto &input_names_vec = GetValue<std::vector<std::string>>(input_names);
  AnfNodePtrList new_inputs{NewValueNode(prim)};
  size_t convert_size = 0;
  for (size_t i = 0; i < cnode->size() - 1; ++i) {
    auto input_node = cnode->input(i + 1);
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_node->isa<ValueNode>() && input_to_attr.find(i) != input_to_attr.end()) {
      const auto &value_node = input_node->cast<ValueNodePtr>();
      MS_LOG(DEBUG) << "Start erase input[" << i << "] of cnode[" << cnode->DebugString() << "]";
      if (i >= input_names_vec.size()) {
        MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
      }
      const auto &value = value_node->value();
      if (value->isa<tensor::BaseTensor>()) {
        auto tensor = value->cast<tensor::BaseTensorPtr>();
        if (tensor->data().const_data() == nullptr && !tensor->has_user_data(kTensorValueIsEmpty)) {
          return;
        }
      }
      ++convert_size;
      if (!grad_by_value) {
        auto &pair = node_attr_value_[cnode];
        (void)pair.emplace_back(i, value_node);
      }
      prim->set_attr(input_names_vec[i], value);
    } else {
      (void)new_inputs.emplace_back(input_node);
    }
  }
  if (convert_size > 0) {
    cnode->AddAttr(kAttrConvertAttrNode, MakeValue(convert_size));
  }
  cnode->set_inputs(new_inputs);
}

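// Recursively walk a CNode's inputs and register every Parameter input as a reverse user of that cnode,
// so the parameter replacement information is available when building the reverse graph.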
void SetReverseParameterReplaceInfo(autograd::IrBprop *ir_bprop, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(ir_bprop);
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return;
  }
  const auto &cnode = node->cast<CNodePtr>();
  for (size_t i = 1; i < cnode->size(); ++i) {
    const auto &input = cnode->input(i);
    MS_EXCEPTION_IF_NULL(input);
    if (input->isa<Parameter>()) {
      ir_bprop->AddReverseUser(input, cnode, i);
    } else if (input->isa<CNode>()) {
      SetReverseParameterReplaceInfo(ir_bprop, input);
    }
  }
}

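// Extract a scalar of type T from a ValueNode; returns std::nullopt if the node is not a ValueNode or the
// value cannot be interpreted as T.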
template <typename T>
std::optional<T> GetScalarAnfNodeValue(const AnfNodePtr &anf_node) {
  if (!anf_node->isa<ValueNode>()) {
    return std::nullopt;
  }
  auto value_node = anf_node->cast<ValueNodePtr>();
  auto value_opt = mindspore::ops::GetScalarValue<T>(value_node->value());
  if (!value_opt.has_value()) {
    return std::nullopt;
  }
  return value_opt.value();
}

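// Build a BNInferGrad node from a BatchNormGrad cnode: it reuses the grads, scale and variance inputs,
// copies the is_training and epsilon information as attributes, and registers the new node's users.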
CNodePtr CreateBNInferGrad(autograd::IrBprop *ir_bprop, const CNodePtr &batchnorm_cnode, const AnfNodePtr &node,
                           bool grad_by_value) {
  MS_EXCEPTION_IF_NULL(ir_bprop);
  MS_EXCEPTION_IF_NULL(batchnorm_cnode);
  MS_EXCEPTION_IF_NULL(node);
  constexpr size_t kIdxGrads = 1;
  constexpr size_t kIdxScale = 3;
  constexpr size_t kIdxVariance = 5;
  constexpr size_t kIdxIsTraining = 7;
  constexpr size_t kIdxEpsilon = 8;

  AnfNodePtrList inputs{NewValueNode(prim::kPrimBNInferGrad)};
  (void)inputs.emplace_back(batchnorm_cnode->input(kIdxGrads));
  (void)inputs.emplace_back(batchnorm_cnode->input(kIdxScale));
  (void)inputs.emplace_back(batchnorm_cnode->input(kIdxVariance));
  (void)inputs.emplace_back(batchnorm_cnode->input(kIdxEpsilon));
  auto new_node = ir_bprop->ad_param()->tape_->FuncGraph::NewCNode(inputs);
  new_node->set_abstract(node->abstract());
  new_node->set_scope(batchnorm_cnode->scope());

  if (!grad_by_value) {
    SetReverseParameterReplaceInfo(ir_bprop, batchnorm_cnode->input(kIndex2));
    SetReverseParameterReplaceInfo(ir_bprop, batchnorm_cnode->input(kIndex4));
    SetReverseParameterReplaceInfo(ir_bprop, batchnorm_cnode->input(kIndex6));
  }
  ir_bprop->AddUser(batchnorm_cnode->input(kIdxGrads), new_node, kIndex1);
  ir_bprop->AddUser(batchnorm_cnode->input(kIdxScale), new_node, kIndex2);
  ir_bprop->AddUser(batchnorm_cnode->input(kIdxVariance), new_node, kIndex3);

  auto is_training_opt = GetScalarAnfNodeValue<bool>(batchnorm_cnode->input(kIdxIsTraining));
  if (is_training_opt.has_value()) {
    auto is_training = is_training_opt.value();
    common::AnfAlgo::SetNodeAttr(kAttrIsTraining, MakeValue(is_training), new_node);
  } else {
    MS_LOG(ERROR) << "For BNInferGrad pass, failed to get attr is_training.";
  }

  auto epsilon_opt = GetScalarAnfNodeValue<pyfloat>(batchnorm_cnode->input(kIdxEpsilon));
  float epsilon{1e-5};
  if (epsilon_opt.has_value()) {
    epsilon = epsilon_opt.value();
  } else {
    MS_LOG(ERROR) << "For BNInferGrad pass, failed to get attr epsilon, using default epsilon 1e-5.";
  }
  common::AnfAlgo::SetNodeAttr(kAttrEpsilon, MakeValue(epsilon), new_node);
  return new_node;
}

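// Expands the grad pattern of SparseSoftmaxCrossEntropyWithLogits (a Mul whose first input is the sparse
// softmax node) into an explicit OneHot + SoftmaxCrossEntropyWithLogits + Tile/RealDiv + ExpandDims + Mul +
// Reshape subgraph on the bprop tape.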
class SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR {
 public:
  CNodePtr Run(const CNodePtr &mul_node, const AnfNodePtr &sparse_softmax_node) {
    GetDepthAndBatchSizeFromSparseSoftmaxNode(sparse_softmax_node);

    AnfNodePtrList softmax_node_outputs;
    auto expand_dims_node = CreateMulInput(mul_node, sparse_softmax_node, &softmax_node_outputs);

    AnfNodePtrList new_mul_inputs{NewValueNode(prim::kPrimMul), softmax_node_outputs[kIndex1], expand_dims_node};
    auto new_mul_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(new_mul_inputs);
    new_mul_node->set_abstract(mul_node->abstract());
    new_mul_node->set_scope(mul_node->scope());
    auto is_dynamic = common::AnfAlgo::IsDynamicShape(sparse_softmax_node);
    ShapeVector shape = is_dynamic ? ShapeVector{-1, depth_} : ShapeVector{batch_size_, depth_};
    common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, new_mul_node.get());

    auto logits_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, kIndex0);
    // Reshape the 1D result back to the multi-dim result.
    auto reshape_node = CreateReshape(new_mul_node, logits_shape);
    return reshape_node;
  }

  autograd::IrBprop *ir_bprop_{nullptr};

 private:
  CNodePtr CreateReshape(const AnfNodePtr &input_node, const ShapeVector &shape) {
    MS_EXCEPTION_IF_NULL(input_node);

    auto reshape_primitive = std::make_shared<Primitive>(kReshapeOpName);
    std::vector<std::string> input_names = {"x", "shape"};
    std::vector<std::string> output_names = {"output"};
    reshape_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    reshape_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    auto shape_node = NewValueNode(shape);
    CreateTensorByConstantValue(shape_node);
    AnfNodePtrList reshape_inputs{NewValueNode(reshape_primitive), input_node, shape_node};
    auto reshape_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(reshape_inputs);
    auto data_types = common::AnfAlgo::GetOutputInferDataType(input_node, kIndex0);
    common::AnfAlgo::SetOutputInferTypeAndShape({data_types}, {shape}, reshape_node.get());
    reshape_node->set_scope(input_node->scope());
    constexpr auto kShapeFromTensor = "shape_from_tensor";
    common::AnfAlgo::SetNodeAttr(kShapeFromTensor, MakeValue(true), reshape_node);
    ir_bprop_->AddUser(input_node, reshape_node, kIndex1);
    return reshape_node;
  }

  void GetDepthAndBatchSizeFromSparseSoftmaxNode(const AnfNodePtr &sparse_softmax_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    auto logits_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, kIndex0);
    auto labels_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, kIndex1);
    if (!logits_shape.empty()) {
      size_t index = logits_shape.size() - 1;
      depth_ = logits_shape[index];
    } else {
      MS_LOG(EXCEPTION) << "Logits shape of node [" << sparse_softmax_node->DebugString() << "] is empty"
                        << trace::DumpSourceLines(sparse_softmax_node);
    }
    batch_size_ = std::accumulate(labels_shape.begin(), labels_shape.end(), 1, std::multiplies<int64_t>());
  }

  CNodePtr CreateOneHot(const CNodePtr &sparse_softmax_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);

    auto is_dynamic = common::AnfAlgo::IsDynamicShape(sparse_softmax_node);
    ShapeVector shape = is_dynamic ? ShapeVector{-1} : ShapeVector{batch_size_};

    // Reshape multi-dim labels to 1D labels.
    auto reshape_node = CreateReshape(sparse_softmax_node->input(kIndex2), shape);

    auto value_on = std::make_shared<tensor::Tensor>(1.0, kFloat32);
    auto value_on_node = PyNativeAlgo::Common::CreateValueNodeByValue(value_on);
    auto value_off = std::make_shared<tensor::Tensor>(0.0, kFloat32);
    auto value_off_node = PyNativeAlgo::Common::CreateValueNodeByValue(value_off);
    auto value_axis = MakeValue<int64_t>(-1);
    auto value_axis_node = PyNativeAlgo::Common::CreateValueNodeByValue(value_axis);
    auto one_hot_primitive = std::make_shared<Primitive>(kOneHotOpName);
    std::vector<std::string> input_names = {"indices", "depth", "on_value", "off_value", "axis"};
    std::vector<std::string> output_names = {"output"};
    one_hot_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    one_hot_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    auto depth_node = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue<int64_t>(depth_));
    CreateTensorByConstantValue(depth_node);
    AnfNodePtrList one_hot_inputs{NewValueNode(one_hot_primitive), reshape_node, depth_node,
                                  value_on_node, value_off_node, value_axis_node};
    auto one_hot_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(one_hot_inputs);
    ShapeVector one_hot_shape = {batch_size_, depth_};
    common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {one_hot_shape}, one_hot_node.get());
    one_hot_node->set_scope(sparse_softmax_node->scope());
    ir_bprop_->AddUser(reshape_node, one_hot_node, kIndex1);
    return one_hot_node;
  }

  CNodePtr CreateSoftmaxCrossEntropyWithLogits(const CNodePtr &sparse_softmax_node, const CNodePtr &one_hot_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    MS_EXCEPTION_IF_NULL(one_hot_node);

    auto is_dynamic = common::AnfAlgo::IsDynamicShape(sparse_softmax_node);
    ShapeVector shape = is_dynamic ? ShapeVector{-1, depth_} : ShapeVector{batch_size_, depth_};

    // Reshape multi-dim logits to 2D logits.
    auto reshape_node = CreateReshape(sparse_softmax_node->input(kIndex1), shape);
    AnfNodePtrList inputs{NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
                          reshape_node, one_hot_node};
    auto softmax_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(inputs);
    ShapeVector loss_shape = {batch_size_};
    auto data_types = common::AnfAlgo::GetOutputInferDataType(one_hot_node, kIndex0);
    auto types = {data_types, data_types};
    auto shapes = {loss_shape, shape};
    common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, softmax_node.get());
    softmax_node->set_scope(sparse_softmax_node->scope());
    return softmax_node;
  }

  void CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, AnfNodePtrList *outputs) {
    MS_EXCEPTION_IF_NULL(node);
    MS_EXCEPTION_IF_NULL(outputs);
    MS_EXCEPTION_IF_NULL(node->abstract());
    const auto &abs_seq = node->abstract()->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(abs_seq);
    if (abs_seq->size() != output_num) {
      MS_LOG(EXCEPTION) << "Abstract seq size " << abs_seq->size() << " is not equal to " << output_num;
    }
    for (size_t i = 0; i < output_num; i++) {
      auto idx = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue<int64_t>(SizeToLong(i)));
      auto tuple_getitem =
        ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
      tuple_getitem->set_abstract(abs_seq->elements()[i]);
      (void)outputs->emplace_back(tuple_getitem);
    }
  }

  CNodePtr CreateTile(const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    MS_EXCEPTION_IF_NULL(mul_node);
    if (batch_size_ == 1) {
      return nullptr;
    }
    auto tile_primitive = std::make_shared<Primitive>(kTileOpName);
    std::vector<std::string> input_names = {"x", "multiples"};
    std::vector<std::string> output_names = {"output"};
    tile_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    tile_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    AnfNodePtrList tile_inputs;
    if (batch_size_ < 0) {
      AnfNodePtrList dynamic_shape_inputs{NewValueNode(std::make_shared<Primitive>("DynamicShape")),
                                          sparse_softmax_node->input(kIndex2)};
      auto shape_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(dynamic_shape_inputs);
      auto labels_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, kIndex1);
      ShapeVector tensor_shp({static_cast<int64_t>(labels_shape.size())});
      auto dynamic_shape_abstract =
        std::make_shared<abstract::AbstractTensor>(kInt64, std::make_shared<abstract::Shape>(tensor_shp));
      MS_EXCEPTION_IF_NULL(dynamic_shape_abstract);
      shape_node->set_abstract(dynamic_shape_abstract);
      shape_node->set_scope(mul_node->scope());
      ir_bprop_->AddUser(sparse_softmax_node->input(kIndex2), shape_node, kIndex1);
      tile_inputs = {NewValueNode(tile_primitive), mul_node->input(kIndex2), shape_node};
    } else {
      std::vector<int64_t> multiples_v = {batch_size_};
      auto multiples_node = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue(multiples_v));
      tile_inputs = {NewValueNode(tile_primitive), mul_node->input(kIndex2), multiples_node};
    }

    auto tile_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(tile_inputs);
    ShapeVector tile_shape = {batch_size_};
    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1UL)},
                                                {tile_shape}, tile_node.get());
    tile_node->set_scope(mul_node->scope());
    ir_bprop_->AddUser(mul_node->input(kIndex2), tile_node, kIndex1);
    // Mark which inputs are feature maps.
    std::vector<size_t> feature_map_input_indexs;
    (void)feature_map_input_indexs.emplace_back(0);
    constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
    common::AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), tile_node);
    return tile_node;
  }

  CNodePtr CreateRealDiv(const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    MS_EXCEPTION_IF_NULL(tile_node);
    auto y_value = static_cast<float>(batch_size_);
    auto y = std::make_shared<tensor::Tensor>(y_value, kFloat32);
    auto y_node = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue(y));

    auto real_div_primitive = std::make_shared<Primitive>(kRealDivOpName);
    std::vector<std::string> input_names = {"x", "y"};
    std::vector<std::string> output_names = {"output"};
    real_div_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    real_div_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    AnfNodePtrList real_div_inputs{NewValueNode(real_div_primitive), tile_node, y_node};
    auto real_div_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(real_div_inputs);
    real_div_node->set_abstract(tile_node->abstract());
    real_div_node->set_scope(sparse_softmax_node->scope());
    return real_div_node;
  }

  CNodePtr CreateExpandDims(const CNodePtr &real_div_node) {
    MS_EXCEPTION_IF_NULL(real_div_node);

    constexpr int64_t axis = -1;
    auto axis_abstract = std::make_shared<abstract::AbstractScalar>();
    MS_EXCEPTION_IF_NULL(axis_abstract);
    axis_abstract->set_type(kInt64);
    auto axis_node = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue(axis), axis_abstract);
    MS_EXCEPTION_IF_NULL(axis_node);

    auto expand_dims_primitive = std::make_shared<Primitive>(kExpandDimsOpName);
    std::vector<std::string> input_names = {"x"};
    std::vector<std::string> output_names = {"output"};
    expand_dims_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    expand_dims_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    AnfNodePtrList expand_dims_inputs = {NewValueNode(expand_dims_primitive), real_div_node, axis_node};
    auto expand_dims_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(expand_dims_inputs);
    auto y_shape = common::AnfAlgo::GetOutputInferShape(real_div_node, 0UL);
    (void)y_shape.emplace_back(1);
    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(real_div_node, 0UL)},
                                                {y_shape}, expand_dims_node.get());
    common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), expand_dims_node);
    expand_dims_node->set_scope(real_div_node->scope());
    return expand_dims_node;
  }

  CNodePtr CreateMulInput(const CNodePtr &mul_node, const AnfNodePtr &sparse_softmax_node,
                          AnfNodePtrList *softmax_node_outputs) {
    MS_EXCEPTION_IF_NULL(mul_node);
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    auto sparse_softmax_cnode = sparse_softmax_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(sparse_softmax_cnode);
    auto one_hot_node = CreateOneHot(sparse_softmax_cnode);
    auto softmax_node = CreateSoftmaxCrossEntropyWithLogits(sparse_softmax_cnode, one_hot_node);
    CreateMultipleOutputsOfAnfNode(softmax_node, opt::kSoftmaxCrossEntropyWithLogitsOutputNum, softmax_node_outputs);
    auto tile_node = CreateTile(sparse_softmax_cnode, mul_node);
    CNodePtr real_div_node;
    if (tile_node == nullptr) {
      real_div_node = CreateRealDiv(sparse_softmax_cnode, mul_node->input(kIndex2));
      ir_bprop_->AddUser(mul_node->input(kIndex2), real_div_node, kIndex1);
    } else {
      real_div_node = CreateRealDiv(sparse_softmax_cnode, tile_node);
    }
    auto expand_dims_node = CreateExpandDims(real_div_node);
    return expand_dims_node;
  }

  int64_t batch_size_{0};
  int64_t depth_{0};
};

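// Insert input_node at position index + 1 in both the cnode's inputs and the parallel cnode_inputs list,
// keeping the two lists aligned.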
void AddCNodeInputs(const CNodePtr &cnode, AnfNodePtrList *cnode_inputs, size_t index, const AnfNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode_inputs);
  MS_EXCEPTION_IF_NULL(input_node);
  auto new_inputs = cnode->inputs();
  (void)new_inputs.insert(new_inputs.begin() + SizeToLong(index) + kIndex1, input_node);
  (void)cnode_inputs->insert(cnode_inputs->begin() + SizeToLong(index) + kIndex1, input_node);
  cnode->set_inputs(new_inputs);
}

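// When node is the Mul produced by the SparseSoftmaxCrossEntropyWithLogits grad, expand it via
// SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR; otherwise return the node unchanged.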
AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(const AnfNodePtr &node, const std::string &op_name,
                                                              autograd::IrBprop *ir_bprop) {
  if (op_name != kSparseSoftmaxCrossEntropyWithLogitsOpName) {
    return node;
  }
  MS_EXCEPTION_IF_NULL(node);
  auto mul_node = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(mul_node);
  if (mul_node->HasAttr(kIsKNode) || !IsPrimitiveCNode(mul_node, prim::kPrimMul)) {
    return node;
  }

  auto sparse_softmax_node = mul_node->input(kIndex1);
  if (!common::AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
    return node;
  }
  // Use a static instance so it is created only once.
  static auto sparse_softmax_cross_entropy_with_logits =
    std::make_shared<SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>();
  sparse_softmax_cross_entropy_with_logits->ir_bprop_ = ir_bprop;
  return sparse_softmax_cross_entropy_with_logits->Run(mul_node, sparse_softmax_node);
}
}  // namespace

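// Flatten MakeTuple inputs of a cnode into dynamic inputs: the tuple elements are spliced into the input list
// and the kAttrDynInputSizes attribute records how many elements each original input contributed. Pyboost ops
// are only tagged with kAttrIsPyboostTupleInput and left unchanged.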
void IrPassForward::ConvertMakeTupleInputToDynamicInput(const AnfNodePtr &node, SeenNum seen, bool run_by_single_op) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return;
  }
  auto cnode = node->cast<CNodePtr>();
  bool need_traverse = !grad_by_value_ && cnode->HasAttr(kIsKNode);
  if (need_traverse || cnode->seen_ == seen || IsPrimitiveCNode(cnode, prim::kPrimBpropCut) ||
      !IsPrimitiveCNode(cnode) || IsPrimitiveCNode(cnode, prim::kPrimMakeDict)) {
    return;
  }
  cnode->seen_ = seen;
  if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
    ConvertMakeTupleInputToDynamicInput(cnode->input(kIndex1), seen, run_by_single_op);
    return;
  }
  for (size_t i = 1; i < cnode->size(); ++i) {
    ConvertMakeTupleInputToDynamicInput(cnode->input(i), seen, run_by_single_op);
  }

  if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) &&
      std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) {
        MS_EXCEPTION_IF_NULL(node->abstract());
        return node->abstract()->isa<abstract::AbstractSequence>();
      })) {
    AnfNodePtrList plant_inputs;
    std::vector<int64_t> dyn_input_sizes;
    (void)plant_inputs.emplace_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode));
    for (size_t i = 1; i < cnode->size(); ++i) {
      const auto &input_node = cnode->input(i);
      if (common::AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
        auto dyn_input_size = opt::SplitTupleInputs(ir_bprop_->ad_param()->tape_, input_node, &plant_inputs);
        (void)dyn_input_sizes.emplace_back(dyn_input_size);
      } else {
        (void)plant_inputs.emplace_back(input_node);
        (void)dyn_input_sizes.emplace_back(-1);
      }
    }
    // If there is a dynamic input, set dyn_input_sizes as an attribute and update the inputs.
    if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
      // Pyboost ops do not need their tuple inputs flattened.
      auto prim = GetCNodePrimitive(cnode);
      MS_EXCEPTION_IF_NULL(prim);
      MS_LOG(DEBUG) << "Get run by single op " << run_by_single_op;
      if (run_by_single_op && runtime::PyBoostOpExecute::GetInstance().IsPyBoostOpRegistered(prim->name())) {
        cnode->AddAttr(kAttrIsPyboostTupleInput, MakeValue(true));
        return;
      }
      cnode->AddAttr(kTupleToMakeTuple, MakeValue(true));
      common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
      MS_LOG(DEBUG) << "Change node to dynamic len " << cnode->DebugString();
      cnode->set_inputs(plant_inputs);
      for (size_t i = 1; i < plant_inputs.size(); ++i) {
        ir_bprop_->AddUser(plant_inputs[i], cnode, i);
      }
    }
  }
}

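// Wrap grad_node with a bprop_cut node for every Python backward hook registered on the tensor, so the hook
// functions run when the gradient flows through. Hooks already deleted on the Python side are skipped, and
// all hooks are cleared from the tensor afterwards.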
AnfNodePtr IrPassForward::PassBackwardHook(const ValuePtr &value, const AnfNodePtr &grad_node) {
  MS_EXCEPTION_IF_NULL(value);
  MS_EXCEPTION_IF_NULL(grad_node);
  auto tensor = value->cast<tensor::BaseTensorPtr>();
  if (tensor == nullptr) {
    MS_LOG(DEBUG) << "Hooks only work on tensors; unsupported value " << value->ToString();
    return grad_node;
  }
  auto auto_grad_meta = tensor->auto_grad_meta_data();
  MS_EXCEPTION_IF_NULL(auto_grad_meta);
  if (auto_grad_meta->backward_hooks().empty()) {
    MS_LOG(DEBUG) << "Get empty backward hooks for tensor id " << tensor->id();
    return grad_node;
  }
  AnfNodePtr res = grad_node;
  for (const auto &[id, hook] : auto_grad_meta->backward_hooks()) {
    if (hook->hook_map_.size() != kSizeOne) {
      MS_LOG(EXCEPTION) << "Tensor hooks only work on a single tensor value; value sequences are not supported";
    }
    auto hook_fn = hook->hook_map_.begin()->second;
    if (hook_fn.ptr() == nullptr) {
      MS_LOG(DEBUG) << "Hook id " << id << " has been deleted by Python";
      continue;
    }
    MS_LOG(DEBUG) << "Insert bprop cut fn " << ConvertPyObjToString(hook_fn) << " for tensor " << value->ToString()
                  << " with id " << tensor->id();
    auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
    bprop_cut->AddAttr("tensor_hook", MakeValue(true));
    bprop_cut->AddBackwardHookFn(kIndex0, hook_fn);
    // The bprop run needs input, out and dout; currently a fake output is used.
    AnfNodePtrList inputs{NewValueNode(bprop_cut), grad_node, NewValueNode(MakeValue("FakeOutput")), res};
    res = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(inputs);
    // The abstract needs to be updated after execution.
    res->set_abstract(grad_node->abstract());

    // Make the graph run by single op.
    ir_bprop_->ad_param()->tape_->set_flag(kFlagPyNativeBpropGraphWithBpropCut, true);
    ir_bprop_->set_bprop_graph_run_by_single_op(true);
  }
  auto_grad_meta->ClearBackwardHooks();
  return res;
}

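// Convert constant inputs of the cnode into primitive attributes according to the device-specific
// const-input-to-attr table, then recurse into CNode inputs. Pyboost ops are only tagged with
// kAttrConvertAttrNode and left unchanged.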
CNodePtr IrPassForward::ConvertConstInputToAttr(const CNodePtr &cnode, bool is_dynamic_shape) {
  MS_EXCEPTION_IF_NULL(cnode);
  const auto &prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(DEBUG) << "Get cnode not primitive " << cnode->DebugString();
    return cnode;
  }
  // Pyboost ops do not need input-to-attr conversion.
  if (runtime::PyBoostOpExecute::GetInstance().IsPyBoostOpRegistered(prim->name())) {
    cnode->AddAttr(kAttrConvertAttrNode, MakeValue(true));
    return cnode;
  }
  auto TraverseCNode = [this, is_dynamic_shape](const CNodePtr &cnode) {
    for (size_t i = 1; i < cnode->size(); ++i) {
      // Avoid infinite loops.
      if (!cnode->HasAttr(kIsKNode) && cnode->input(i)->isa<CNode>()) {
        cnode->set_input(i, ConvertConstInputToAttr(cnode->input(i)->cast<CNodePtr>(), is_dynamic_shape));
      }
    }
  };

  mindspore::HashSet<size_t> input_to_attr = {};
  PyNativeAlgo::Common::GetConstInputToAttr(prim, prim->name(), device_target_, is_dynamic_shape, &input_to_attr);
  if (input_to_attr.empty()) {
    TraverseCNode(cnode);
    return cnode;
  }
  const auto &input_names = prim->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(DEBUG) << "input_names are nullptr";
    return cnode;
  }

  ChangeInputToAttr(prim, cnode, input_names, input_to_attr, grad_by_value_);

  // An input may itself be a CNode (e.g. a Cast feeding a Cast), so traverse the inputs as well.
  TraverseCNode(cnode);
  return cnode;
}

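// Fusion pass: when node is a TupleGetItem taking the 0th output of a BatchNormGrad whose is_training is
// false, replace it with a BNInferGrad node built by CreateBNInferGrad; otherwise return the node as-is.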
AnfNodePtr IrPassForward::BatchNormGradToBNInferGrad(const AnfNodePtr &node, const std::string &op_name) {
  if (op_name != kBatchNormOpName) {
    return node;
  }
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->HasAttr(kIsKNode) || !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
    return cnode;
  }
  auto batchnorm_grad_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(batchnorm_grad_node);
  if (!IsPrimitiveCNode(batchnorm_grad_node, prim::kPrimBatchNormGrad)) {
    return cnode;
  }
  AnfNodePtr index_node = cnode->input(kInputNodeOutputIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(index_node);
  auto value_node = index_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto index = GetValue<int64_t>(value_node->value());
  if (index != 0) {
    MS_LOG(DEBUG) << "TupleGetItem must take the 0th output of BatchNormGrad";
    return cnode;
  }
  auto batchnorm_grad_cnode = batchnorm_grad_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(batchnorm_grad_cnode);
  constexpr size_t kIdxIsTraining = 7;
  auto is_training_opt = GetScalarAnfNodeValue<bool>(batchnorm_grad_cnode->input(kIdxIsTraining));
  if (!is_training_opt.has_value()) {
    return cnode;
  }
  if (is_training_opt.value()) {
    MS_LOG(DEBUG) << "Attr 'is_training' is true, no need to do fusion";
    return cnode;
  }

  need_reverse_graph_ = true;
  auto new_cnode = CreateBNInferGrad(ir_bprop_, batchnorm_grad_cnode, node, grad_by_value_);
  auto &pair = node_attr_value_[new_cnode];
  (void)pair.emplace_back(UINT32_MAX, node);
  return new_cnode;
}

void IrPassForward::ReverseConstantToAttrNode(const CNodePtr &cnode, ValuePtrList *inputs_value,
                                              AnfNodePtrList *cnode_inputs) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!cnode->HasAttr(kAttrConvertAttrNode)) {
    return;
  }
  ReverseCNodeInputs(cnode, cnode_inputs, inputs_value);
}

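// Undo ConvertMakeTupleInputToDynamicInput for a cnode tagged with kTupleToMakeTuple: the flattened inputs
// are compressed back into MakeTuple nodes, and the parallel k-node inputs and input values are updated to
// match.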
void IrPassForward::ReverseMakeTupleNode(const CNodePtr &cnode, ValuePtrList *inputs_value,
                                         AnfNodePtrList *cnode_inputs) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(inputs_value);
  MS_EXCEPTION_IF_NULL(cnode_inputs);
  if (!cnode->HasAttr(kTupleToMakeTuple)) {
    return;
  }
  AnfNodePtrList new_inputs{cnode->input(kIndex0)};
  const auto &dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrDynInputSizes);
  for (size_t i = 0; i < dyn_input_sizes.size(); ++i) {
    if (dyn_input_sizes[i] >= 0) {
      // Compress the flattened inputs back into a MakeTuple.
      AnfNodePtrList cnode_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
      AnfNodePtrList knode_inputs{NewValueNode(prim::kPrimMakeTuple)};
      ValuePtrList value_tuple;
      abstract::AbstractBasePtrList abs_list;
      for (int64_t j = 0; j < dyn_input_sizes[i]; ++j) {
        auto input = cnode->input(i + j + kIndex1);
        (void)cnode_tuple_inputs.emplace_back(input);
        (void)knode_inputs.emplace_back(cnode_inputs->at(i + j + kIndex1));
        (void)value_tuple.emplace_back(inputs_value->at(i + j));
        (void)abs_list.emplace_back(input->abstract());
      }
      // Rebuild the cnode input as a MakeTuple.
      auto cnode_graph = cnode->func_graph();
      MS_EXCEPTION_IF_NULL(cnode_graph);
      auto cnode_tuple = cnode_graph->NewCNode(cnode_tuple_inputs);
      auto abs = std::make_shared<abstract::AbstractTuple>(abs_list);
      cnode_tuple->set_abstract(abs);
      (void)new_inputs.emplace_back(cnode_tuple);

      // Update knode inputs.
      auto knode_input = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(knode_inputs);
      knode_input->set_abstract(abs);
      size_t begin_index = i + kIndex1;
      auto it = cnode_inputs->erase(cnode_inputs->begin() + SizeToLong(begin_index),
                                    cnode_inputs->begin() + SizeToLong(begin_index) + dyn_input_sizes[i]);
      (void)cnode_inputs->insert(it, knode_input);

      // Update input value.
      auto item = inputs_value->erase(inputs_value->begin() + SizeToLong(kIndex0),
                                      inputs_value->begin() + SizeToLong(kIndex0) + dyn_input_sizes[i]);
      (void)inputs_value->insert(item, std::make_shared<ValueTuple>(value_tuple));
    } else {
      auto last_index = (i == 0 ? 0 : i - 1);
      auto skip_index = (dyn_input_sizes[last_index] == -1 ? 1 : dyn_input_sizes[last_index]);
      (void)new_inputs.emplace_back(cnode->input(i + skip_index));
    }
  }
  cnode->set_inputs(new_inputs);
  (void)cnode->EraseAttr(kTupleToMakeTuple);
}

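// Undo the BNInferGrad fusion recorded in node_attr_value_ by replacing the BNInferGrad cnode with the
// original node through the graph manager.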
void IrPassForward::ReverseBNInfer(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!IsPrimitiveCNode(cnode, prim::kPrimBNInferGrad)) {
    return;
  }
  const auto item = node_attr_value_.find(cnode);
  if (item == node_attr_value_.end()) {
    return;
  }
  auto func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, false);
  }
  if (item->second.size() != kIndex1) {
    MS_LOG(EXCEPTION) << "Replace item size " << item->second.size() << " is not equal to " << kIndex1;
  }
  if (!manager->Replace(cnode, item->second[kIndex0].second)) {
    MS_LOG(EXCEPTION) << "Replace failed. cnode " << cnode->DebugString() << " to cnode "
                      << item->second[kIndex0].second->DebugString();
  }
  (void)node_attr_value_.erase(item);
}

void IrPassForward::ReverseCNodeInputs(const CNodePtr &cnode, AnfNodePtrList *cnode_inputs,
                                       ValuePtrList *inputs_value) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(inputs_value);
  MS_EXCEPTION_IF_NULL(cnode_inputs);
  const auto item = node_attr_value_.find(cnode);
  if (item == node_attr_value_.end()) {
    return;
  }
  for (const auto &t : item->second) {
    if (t.second->isa<ValueNode>()) {
      auto vnode = t.second->cast<ValueNodePtr>();
      auto v = vnode->value();
      (void)PyNativeAlgo::Common::SetValueGradInfo(v, nullptr, InputType::kConstant);
      AddCNodeInputs(cnode, cnode_inputs, t.first, PyNativeAlgo::Common::CreateValueNodeByValue(v, nullptr));
      (void)inputs_value->insert(inputs_value->begin() + SizeToLong(t.first), v);
    } else if (t.second->isa<Parameter>()) {
      const auto it = ir_bprop_->ad_param()->anfnode_to_variable_adjoint_.find(t.second);
      if (it == ir_bprop_->ad_param()->anfnode_to_variable_adjoint_.end()) {
        MS_LOG(EXCEPTION) << "Can not find " << t.second << " in anfnode_to_variable_adjoint";
      }
      AddCNodeInputs(cnode, cnode_inputs, t.first, it->second->k_node());
      (void)inputs_value->insert(inputs_value->begin() + SizeToLong(t.first), it->second->out_value());
    } else {
      MS_LOG(EXCEPTION) << "No scenario for " << t.second->DebugString();
    }
  }
  (void)node_attr_value_.erase(item);
}

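// Run the reverse passes over the whole bprop graph; currently this only undoes the BNInferGrad fusion,
// which applies on Ascend only.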
void IrPassForward::ReversePassFuncGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &order = TopoSort(func_graph->output());
  for (const auto &node : order) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // The BN fusion is Ascend only.
    if (device_target_ == kAscendDevice) {
      ReverseBNInfer(cnode);
    }
  }
  need_reverse_graph_ = false;
  PyNativeAlgo::Common::DumpGraphIR("reverse_cnode_graph.ir", func_graph);
}

void IrPassForward::ReversePassCNode(const CNodePtr &cnode, ValuePtrList *inputs_value, AnfNodePtrList *cnode_inputs) {
  // Note: the reverse steps run in the opposite order of the forward pass.
  auto tape_graph = ir_bprop_->ad_param()->tape_;
  MS_EXCEPTION_IF_NULL(tape_graph);

  ReverseMakeTupleNode(cnode, inputs_value, cnode_inputs);
  ReverseConstantToAttrNode(cnode, inputs_value, cnode_inputs);
}

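// Entry point for the forward-direction din passes: const-input-to-attr conversion plus, on Ascend, the
// BNInferGrad and SparseSoftmaxCrossEntropyWithLogits rewrites.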
CNodePtr IrPassForward::PassForDin(const CNodePtr &cnode, const std::string &op_name, bool is_dynamic_shape) {
  // If you want to add a pass here, take care of higher-order grad.
  MS_EXCEPTION_IF_NULL(ir_bprop_);
  AnfNodePtr new_din = ConvertConstInputToAttr(cnode, is_dynamic_shape);

  // Ascend only.
  if (device_target_ == kAscendDevice) {
    new_din = BatchNormGradToBNInferGrad(new_din, op_name);
    new_din = GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(new_din, op_name, ir_bprop_);
  }
  return new_din->cast<CNodePtr>();
}

bool IrPassForward::need_reverse_graph_ = false;

void ClearCache() { node_attr_value_.clear(); }
}  // namespace bprop_pass
}  // namespace pynative
}  // namespace mindspore