/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/ir/ir_pass.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "pipeline/pynative/pynative_utils.h"
#include "ops/sequence_ops.h"
#include "ops/nn_ops.h"
#include "ops/op_utils.h"
#include "include/backend/optimizer/helper.h"
#include "include/common/utils/hook.h"
#include "runtime/pynative/op_function/pyboost_grad_functions.h"

namespace mindspore {
namespace pynative {
namespace bprop_pass {
namespace {
constexpr auto kTupleToMakeTuple = "tuple_to_make_tuple";

mindspore::HashMap<AnfNodePtr, std::vector<std::pair<size_t, AnfNodePtr>>> node_attr_value_;

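// Convert the constant value held by a ValueNode into a tensor and update the node's value and abstract
// accordingly.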
void CreateTensorByConstantValue(const ValueNodePtr &v_node) {
  MS_EXCEPTION_IF_NULL(v_node);
  const auto &value = v_node->value();
  MS_EXCEPTION_IF_NULL(value);
  auto tensor_ptr = PyNativeAlgo::Common::CreateTensorByConstantValue(value);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  v_node->set_value(tensor_ptr);
  v_node->set_abstract(tensor_ptr->ToAbstract());
}

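// Move the constant inputs listed in input_to_attr into attributes of prim and rebuild the cnode inputs.
// When grad is not by value, the erased value nodes are recorded in node_attr_value_ so that the reverse
// pass can restore them later.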
void ChangeInputToAttr(const PrimitivePtr &prim, const CNodePtr &cnode, const ValuePtr &input_names,
                       const mindspore::HashSet<size_t> &input_to_attr, bool grad_by_value) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_names);
  const auto &input_names_vec = GetValue<std::vector<std::string>>(input_names);
  AnfNodePtrList new_inputs{NewValueNode(prim)};
  size_t convert_size = 0;
  for (size_t i = 0; i < cnode->size() - 1; ++i) {
    auto input_node = cnode->input(i + 1);
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_node->isa<ValueNode>() && input_to_attr.find(i) != input_to_attr.end()) {
      const auto &value_node = input_node->cast<ValueNodePtr>();
      MS_LOG(DEBUG) << "Start to erase input[" << i << "] of cnode[" << cnode->DebugString() << "]";
      if (i >= input_names_vec.size()) {
        MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
      }
      const auto &value = value_node->value();
      if (value->isa<tensor::BaseTensor>()) {
        auto tensor = value->cast<tensor::BaseTensorPtr>();
        if (tensor->data().const_data() == nullptr && !tensor->has_user_data(kTensorValueIsEmpty)) {
          return;
        }
      }
      ++convert_size;
      if (!grad_by_value) {
        auto &pair = node_attr_value_[cnode];
        (void)pair.emplace_back(i, value_node);
      }
      prim->set_attr(input_names_vec[i], value);
    } else {
      (void)new_inputs.emplace_back(input_node);
    }
  }
  if (convert_size > 0) {
    cnode->AddAttr(kAttrConvertAttrNode, MakeValue(convert_size));
  }
  cnode->set_inputs(new_inputs);
}

void SetReverseParameterReplaceInfo(autograd::IrBprop *ir_bprop, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(ir_bprop);
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return;
  }
  const auto &cnode = node->cast<CNodePtr>();
  for (size_t i = 1; i < cnode->size(); ++i) {
    const auto &input = cnode->input(i);
    MS_EXCEPTION_IF_NULL(input);
    if (input->isa<Parameter>()) {
      ir_bprop->AddReverseUser(input, cnode, i);
    } else if (input->isa<CNode>()) {
      SetReverseParameterReplaceInfo(ir_bprop, input);
    }
  }
}

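// Try to read a scalar of type T from a ValueNode; returns std::nullopt for non-value nodes or values that
// cannot be converted to a scalar.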
template <typename T>
std::optional<T> GetScalarAnfNodeValue(const AnfNodePtr &anf_node) {
  if (!anf_node->isa<ValueNode>()) {
    return std::nullopt;
  }
  auto value_node = anf_node->cast<ValueNodePtr>();
  auto value_opt = mindspore::ops::GetScalarValue<T>(value_node->value());
  if (!value_opt.has_value()) {
    return std::nullopt;
  }
  return value_opt.value();
}

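// Build a BNInferGrad cnode from a BatchNormGrad cnode, forwarding grads, scale and variance, and carrying
// the is_training and epsilon attributes of the original node.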
CNodePtr CreateBNInferGrad(autograd::IrBprop *ir_bprop, const CNodePtr &batchnorm_cnode, const AnfNodePtr &node,
                           bool grad_by_value) {
  MS_EXCEPTION_IF_NULL(ir_bprop);
  MS_EXCEPTION_IF_NULL(batchnorm_cnode);
  MS_EXCEPTION_IF_NULL(node);
  constexpr size_t kIdxGrads = 1;
  constexpr size_t kIdxScale = 3;
  constexpr size_t kIdxVariance = 5;
  constexpr size_t kIdxIsTraining = 7;
  constexpr size_t kIdxEpsilon = 8;

  AnfNodePtrList inputs{NewValueNode(prim::kPrimBNInferGrad)};
  (void)inputs.emplace_back(batchnorm_cnode->input(kIdxGrads));
  (void)inputs.emplace_back(batchnorm_cnode->input(kIdxScale));
  (void)inputs.emplace_back(batchnorm_cnode->input(kIdxVariance));
  (void)inputs.emplace_back(batchnorm_cnode->input(kIdxEpsilon));
  auto new_node = ir_bprop->ad_param()->tape_->FuncGraph::NewCNode(inputs);
  new_node->set_abstract(node->abstract());
  new_node->set_scope(batchnorm_cnode->scope());

  if (!grad_by_value) {
    SetReverseParameterReplaceInfo(ir_bprop, batchnorm_cnode->input(kIndex2));
    SetReverseParameterReplaceInfo(ir_bprop, batchnorm_cnode->input(kIndex4));
    SetReverseParameterReplaceInfo(ir_bprop, batchnorm_cnode->input(kIndex6));
  }
  ir_bprop->AddUser(batchnorm_cnode->input(kIdxGrads), new_node, kIndex1);
  ir_bprop->AddUser(batchnorm_cnode->input(kIdxScale), new_node, kIndex2);
  ir_bprop->AddUser(batchnorm_cnode->input(kIdxVariance), new_node, kIndex3);

  auto is_training_opt = GetScalarAnfNodeValue<bool>(batchnorm_cnode->input(kIdxIsTraining));
  if (is_training_opt.has_value()) {
    auto is_training = is_training_opt.value();
    common::AnfAlgo::SetNodeAttr(kAttrIsTraining, MakeValue(is_training), new_node);
  } else {
    MS_LOG(ERROR) << "For BNInferGrad pass, failed to get attr is_training.";
  }

  auto epsilon_opt = GetScalarAnfNodeValue<pyfloat>(batchnorm_cnode->input(kIdxEpsilon));
  float epsilon{1e-5};
  if (epsilon_opt.has_value()) {
    epsilon = epsilon_opt.value();
  } else {
    MS_LOG(ERROR) << "For BNInferGrad pass, failed to get attr epsilon, use default epsilon: 1e-5.";
  }
  common::AnfAlgo::SetNodeAttr(kAttrEpsilon, MakeValue(epsilon), new_node);
  return new_node;
}

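// Expand the grad of SparseSoftmaxCrossEntropyWithLogits on the bprop tape into OneHot +
// SoftmaxCrossEntropyWithLogits followed by Tile/RealDiv/ExpandDims/Mul/Reshape nodes.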
class SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR {
 public:
  CNodePtr Run(const CNodePtr &mul_node, const AnfNodePtr &sparse_softmax_node) {
    GetDepthAndBatchSizeFromSparseSoftmaxNode(sparse_softmax_node);

    AnfNodePtrList softmax_node_outputs;
    auto expand_dims_node = CreateMulInput(mul_node, sparse_softmax_node, &softmax_node_outputs);

    AnfNodePtrList new_mul_inputs{NewValueNode(prim::kPrimMul), softmax_node_outputs[kIndex1], expand_dims_node};
    auto new_mul_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(new_mul_inputs);
    new_mul_node->set_abstract(mul_node->abstract());
    new_mul_node->set_scope(mul_node->scope());
    auto is_dynamic = common::AnfAlgo::IsDynamicShape(sparse_softmax_node);
    ShapeVector shape = is_dynamic ? ShapeVector{-1, depth_} : ShapeVector{batch_size_, depth_};
    common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {shape}, new_mul_node.get());

    auto logits_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, kIndex0);
    // Reshape 1D result to multi-dim result.
    auto reshape_node = CreateReshape(new_mul_node, logits_shape);
    return reshape_node;
  }

  autograd::IrBprop *ir_bprop_{nullptr};

 private:
  CNodePtr CreateReshape(const AnfNodePtr &input_node, const ShapeVector &shape) {
    MS_EXCEPTION_IF_NULL(input_node);

    auto reshape_primitive = std::make_shared<Primitive>(kReshapeOpName);
    std::vector<std::string> input_names = {"x", "shape"};
    std::vector<std::string> output_names = {"output"};
    reshape_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    reshape_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    auto shape_node = NewValueNode(shape);
    CreateTensorByConstantValue(shape_node);
    AnfNodePtrList reshape_inputs{NewValueNode(reshape_primitive), input_node, shape_node};
    auto reshape_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(reshape_inputs);
    auto data_types = common::AnfAlgo::GetOutputInferDataType(input_node, kIndex0);
    common::AnfAlgo::SetOutputInferTypeAndShape({data_types}, {shape}, reshape_node.get());
    reshape_node->set_scope(input_node->scope());
    constexpr auto kShapeFromTensor = "shape_from_tensor";
    common::AnfAlgo::SetNodeAttr(kShapeFromTensor, MakeValue(true), reshape_node);
    ir_bprop_->AddUser(input_node, reshape_node, kIndex1);
    return reshape_node;
  }

  void GetDepthAndBatchSizeFromSparseSoftmaxNode(const AnfNodePtr &sparse_softmax_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    auto logits_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, kIndex0);
    auto labels_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, kIndex1);
    if (!logits_shape.empty()) {
      size_t index = logits_shape.size() - 1;
      depth_ = logits_shape[index];
    } else {
      MS_LOG(EXCEPTION) << "The logits shape of node [" << sparse_softmax_node->DebugString() << "] is empty"
                        << trace::DumpSourceLines(sparse_softmax_node);
    }
    batch_size_ = std::accumulate(labels_shape.begin(), labels_shape.end(), 1, std::multiplies<int64_t>());
  }

  CNodePtr CreateOneHot(const CNodePtr &sparse_softmax_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);

    auto is_dynamic = common::AnfAlgo::IsDynamicShape(sparse_softmax_node);
    ShapeVector shape = is_dynamic ? ShapeVector{-1} : ShapeVector{batch_size_};

    // Reshape multi-dim labels to 1D labels.
    auto reshape_node = CreateReshape(sparse_softmax_node->input(kIndex2), shape);

    auto value_on = std::make_shared<tensor::Tensor>(1.0, kFloat32);
    auto value_on_node = PyNativeAlgo::Common::CreateValueNodeByValue(value_on);
    auto value_off = std::make_shared<tensor::Tensor>(0.0, kFloat32);
    auto value_off_node = PyNativeAlgo::Common::CreateValueNodeByValue(value_off);
    auto value_axis = MakeValue<int64_t>(-1);
    auto value_axis_node = PyNativeAlgo::Common::CreateValueNodeByValue(value_axis);
    auto one_hot_primitive = std::make_shared<Primitive>(kOneHotOpName);
    std::vector<std::string> input_names = {"indices", "depth", "on_value", "off_value", "axis"};
    std::vector<std::string> output_names = {"output"};
    one_hot_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    one_hot_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    auto depth_node = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue<int64_t>(depth_));
    CreateTensorByConstantValue(depth_node);
    AnfNodePtrList one_hot_inputs{
      NewValueNode(one_hot_primitive), reshape_node, depth_node, value_on_node, value_off_node, value_axis_node};
    auto one_hot_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(one_hot_inputs);
    ShapeVector one_hot_shape = {batch_size_, depth_};
    common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {one_hot_shape}, one_hot_node.get());
    one_hot_node->set_scope(sparse_softmax_node->scope());
    ir_bprop_->AddUser(reshape_node, one_hot_node, kIndex1);
    return one_hot_node;
  }

  CNodePtr CreateSoftmaxCrossEntropyWithLogits(const CNodePtr &sparse_softmax_node, const CNodePtr &one_hot_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    MS_EXCEPTION_IF_NULL(one_hot_node);

    auto is_dynamic = common::AnfAlgo::IsDynamicShape(sparse_softmax_node);
    ShapeVector shape = is_dynamic ? ShapeVector{-1, depth_} : ShapeVector{batch_size_, depth_};

    // Reshape multi-dim logits to 2D logits.
    auto reshape_node = CreateReshape(sparse_softmax_node->input(kIndex1), shape);
    AnfNodePtrList inputs{NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)), reshape_node,
                          one_hot_node};
    auto softmax_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(inputs);
    ShapeVector loss_shape = {batch_size_};
    auto data_types = common::AnfAlgo::GetOutputInferDataType(one_hot_node, kIndex0);
    auto types = {data_types, data_types};
    auto shapes = {loss_shape, shape};
    common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, softmax_node.get());
    softmax_node->set_scope(sparse_softmax_node->scope());
    return softmax_node;
  }

  void CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t output_num, AnfNodePtrList *outputs) {
    MS_EXCEPTION_IF_NULL(node);
    MS_EXCEPTION_IF_NULL(outputs);
    MS_EXCEPTION_IF_NULL(node->abstract());
    const auto &abs_seq = node->abstract()->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(abs_seq);
    if (abs_seq->size() != output_num) {
      MS_LOG(EXCEPTION) << "Abstract seq size " << abs_seq->size() << " is not equal to " << output_num;
    }
    for (size_t i = 0; i < output_num; i++) {
      auto idx = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue<int64_t>(SizeToLong(i)));
      auto tuple_getitem =
        ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
      tuple_getitem->set_abstract(abs_seq->elements()[i]);
      (void)outputs->emplace_back(tuple_getitem);
    }
  }

  CNodePtr CreateTile(const CNodePtr &sparse_softmax_node, const CNodePtr &mul_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    MS_EXCEPTION_IF_NULL(mul_node);
    if (batch_size_ == 1) {
      return nullptr;
    }
    auto tile_primitive = std::make_shared<Primitive>(kTileOpName);
    std::vector<std::string> input_names = {"x", "multiples"};
    std::vector<std::string> output_names = {"output"};
    tile_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    tile_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    AnfNodePtrList tile_inputs;
    if (batch_size_ < 0) {
      AnfNodePtrList dynamic_shape_inputs{NewValueNode(std::make_shared<Primitive>("DynamicShape")),
                                          sparse_softmax_node->input(kIndex2)};
      auto shape_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(dynamic_shape_inputs);
      auto labels_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, kIndex1);
      ShapeVector tensor_shp({static_cast<int64_t>(labels_shape.size())});
      auto dynamic_shape_abstract =
        std::make_shared<abstract::AbstractTensor>(kInt64, std::make_shared<abstract::Shape>(tensor_shp));
      MS_EXCEPTION_IF_NULL(dynamic_shape_abstract);
      shape_node->set_abstract(dynamic_shape_abstract);
      shape_node->set_scope(mul_node->scope());
      ir_bprop_->AddUser(sparse_softmax_node->input(kIndex2), shape_node, kIndex1);
      tile_inputs = {NewValueNode(tile_primitive), mul_node->input(kIndex2), shape_node};
    } else {
      std::vector<int64_t> multiples_v = {batch_size_};
      auto multiples_node = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue(multiples_v));
      tile_inputs = {NewValueNode(tile_primitive), mul_node->input(kIndex2), multiples_node};
    }

    auto tile_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(tile_inputs);
    ShapeVector tile_shape = {batch_size_};
    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(mul_node, 1UL)},
                                                {tile_shape}, tile_node.get());
    tile_node->set_scope(mul_node->scope());
    ir_bprop_->AddUser(mul_node->input(kIndex2), tile_node, kIndex1);
    // Mark input 0 as a feature map input.
    std::vector<size_t> feature_map_input_indexs;
    (void)feature_map_input_indexs.emplace_back(0);
    constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
    common::AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), tile_node);
    return tile_node;
  }

  CNodePtr CreateRealDiv(const CNodePtr &sparse_softmax_node, const AnfNodePtr &tile_node) {
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    MS_EXCEPTION_IF_NULL(tile_node);
    auto y_value = static_cast<float>(batch_size_);
    auto y = std::make_shared<tensor::Tensor>(y_value, kFloat32);
    auto y_node = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue(y));

    auto real_div_primitive = std::make_shared<Primitive>(kRealDivOpName);
    std::vector<std::string> input_names = {"x", "y"};
    std::vector<std::string> output_names = {"output"};
    real_div_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    real_div_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    AnfNodePtrList real_div_inputs{NewValueNode(real_div_primitive), tile_node, y_node};
    auto real_div_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(real_div_inputs);
    real_div_node->set_abstract(tile_node->abstract());
    real_div_node->set_scope(sparse_softmax_node->scope());
    return real_div_node;
  }

  CNodePtr CreateExpandDims(const CNodePtr &real_div_node) {
    MS_EXCEPTION_IF_NULL(real_div_node);

    constexpr int64_t axis = -1;
    auto axis_abstract = std::make_shared<abstract::AbstractScalar>();
    MS_EXCEPTION_IF_NULL(axis_abstract);
    axis_abstract->set_type(kInt64);
    auto axis_node = PyNativeAlgo::Common::CreateValueNodeByValue(MakeValue(axis), axis_abstract);
    MS_EXCEPTION_IF_NULL(axis_node);

    auto expand_dims_primitive = std::make_shared<Primitive>(kExpandDimsOpName);
    std::vector<std::string> input_names = {"x"};
    std::vector<std::string> output_names = {"output"};
    expand_dims_primitive->set_attr(kAttrInputNames, MakeValue(input_names));
    expand_dims_primitive->set_attr(kAttrOutputNames, MakeValue(output_names));

    AnfNodePtrList expand_dims_inputs = {NewValueNode(expand_dims_primitive), real_div_node, axis_node};
    auto expand_dims_node = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(expand_dims_inputs);
    auto y_shape = common::AnfAlgo::GetOutputInferShape(real_div_node, 0UL);
    (void)y_shape.emplace_back(1);
    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(real_div_node, 0UL)},
                                                {y_shape}, expand_dims_node.get());
    common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), expand_dims_node);
    expand_dims_node->set_scope(real_div_node->scope());
    return expand_dims_node;
  }

  CNodePtr CreateMulInput(const CNodePtr &mul_node, const AnfNodePtr &sparse_softmax_node,
                          AnfNodePtrList *softmax_node_outputs) {
    MS_EXCEPTION_IF_NULL(mul_node);
    MS_EXCEPTION_IF_NULL(sparse_softmax_node);
    auto sparse_softmax_cnode = sparse_softmax_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(sparse_softmax_cnode);
    auto one_hot_node = CreateOneHot(sparse_softmax_cnode);
    auto softmax_node = CreateSoftmaxCrossEntropyWithLogits(sparse_softmax_cnode, one_hot_node);
    CreateMultipleOutputsOfAnfNode(softmax_node, opt::kSoftmaxCrossEntropyWithLogitsOutputNum, softmax_node_outputs);
    auto tile_node = CreateTile(sparse_softmax_cnode, mul_node);
    CNodePtr real_div_node;
    if (tile_node == nullptr) {
      real_div_node = CreateRealDiv(sparse_softmax_cnode, mul_node->input(kIndex2));
      ir_bprop_->AddUser(mul_node->input(kIndex2), real_div_node, kIndex1);
    } else {
      real_div_node = CreateRealDiv(sparse_softmax_cnode, tile_node);
    }
    auto expand_dims_node = CreateExpandDims(real_div_node);
    return expand_dims_node;
  }

  int64_t batch_size_{0};
  int64_t depth_{0};
};

void AddCNodeInputs(const CNodePtr &cnode, AnfNodePtrList *cnode_inputs, size_t index, const AnfNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode_inputs);
  MS_EXCEPTION_IF_NULL(input_node);
  auto new_inputs = cnode->inputs();
  (void)new_inputs.insert(new_inputs.begin() + SizeToLong(index) + kIndex1, input_node);
  (void)cnode_inputs->insert(cnode_inputs->begin() + SizeToLong(index) + kIndex1, input_node);
  cnode->set_inputs(new_inputs);
}

AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(const AnfNodePtr &node, const std::string &op_name,
                                                              autograd::IrBprop *ir_bprop) {
  if (op_name != kSparseSoftmaxCrossEntropyWithLogitsOpName) {
    return node;
  }
  MS_EXCEPTION_IF_NULL(node);
  auto mul_node = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(mul_node);
  if (mul_node->HasAttr(kIsKNode) || !IsPrimitiveCNode(mul_node, prim::kPrimMul)) {
    return node;
  }

  auto sparse_softmax_node = mul_node->input(kIndex1);
  if (!common::AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
    return node;
  }
  // Use a static object so it is created only once.
  static auto sparse_softmax_cross_entropy_with_logits =
    std::make_shared<SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>();
  sparse_softmax_cross_entropy_with_logits->ir_bprop_ = ir_bprop;
  return sparse_softmax_cross_entropy_with_logits->Run(mul_node, sparse_softmax_node);
}
}  // namespace

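// Recursively flatten MakeTuple inputs of a cnode into dynamic inputs and record dyn_input_sizes. Pyboost ops
// executed by single op keep their tuple inputs and are only marked with kAttrIsPyboostTupleInput.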
void IrPassForward::ConvertMakeTupleInputToDynamicInput(const AnfNodePtr &node, SeenNum seen, bool run_by_single_op) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return;
  }
  auto cnode = node->cast<CNodePtr>();
  bool need_traverse = !grad_by_value_ && cnode->HasAttr(kIsKNode);
  if (need_traverse || cnode->seen_ == seen || IsPrimitiveCNode(cnode, prim::kPrimBpropCut) ||
      !IsPrimitiveCNode(cnode) || IsPrimitiveCNode(cnode, prim::kPrimMakeDict)) {
    return;
  }
  cnode->seen_ = seen;
  if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
    ConvertMakeTupleInputToDynamicInput(cnode->input(kIndex1), seen, run_by_single_op);
    return;
  }
  for (size_t i = 1; i < cnode->size(); ++i) {
    ConvertMakeTupleInputToDynamicInput(cnode->input(i), seen, run_by_single_op);
  }

  if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) &&
      std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) {
        MS_EXCEPTION_IF_NULL(node->abstract());
        return node->abstract()->isa<abstract::AbstractSequence>();
      })) {
    AnfNodePtrList plant_inputs;
    std::vector<int64_t> dyn_input_sizes;
    (void)plant_inputs.emplace_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode));
    for (size_t i = 1; i < cnode->size(); ++i) {
      const auto &input_node = cnode->input(i);
      if (common::AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
        auto dyn_input_size = opt::SplitTupleInputs(ir_bprop_->ad_param()->tape_, input_node, &plant_inputs);
        (void)dyn_input_sizes.emplace_back(dyn_input_size);
      } else {
        (void)plant_inputs.emplace_back(input_node);
        (void)dyn_input_sizes.emplace_back(-1);
      }
    }
    // If there is a dynamic input, set dyn_input_sizes as an attribute and update the inputs.
    if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
      // Pyboost ops do not need tuple inputs to be flattened.
      auto prim = GetCNodePrimitive(cnode);
      MS_EXCEPTION_IF_NULL(prim);
      MS_LOG(DEBUG) << "Get run by single op " << run_by_single_op;
      if (run_by_single_op && runtime::PyBoostOpExecute::GetInstance().IsPyBoostOpRegistered(prim->name())) {
        cnode->AddAttr(kAttrIsPyboostTupleInput, MakeValue(true));
        return;
      }
      cnode->AddAttr(kTupleToMakeTuple, MakeValue(true));
      common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
      MS_LOG(DEBUG) << "Change node to dynamic len " << cnode->DebugString();
      cnode->set_inputs(plant_inputs);
      for (size_t i = 1; i < plant_inputs.size(); ++i) {
        ir_bprop_->AddUser(plant_inputs[i], cnode, i);
      }
    }
  }
}

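// Wrap grad_node with a bprop_cut cnode for every registered tensor backward hook so that the hook functions
// run during backward; returns the (possibly wrapped) grad node.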
AnfNodePtr IrPassForward::PassBackwardHook(const ValuePtr &value, const AnfNodePtr &grad_node) {
  MS_EXCEPTION_IF_NULL(value);
  MS_EXCEPTION_IF_NULL(grad_node);
  auto tensor = value->cast<tensor::BaseTensorPtr>();
  if (tensor == nullptr) {
    MS_LOG(DEBUG) << "Hook only works on tensors, unsupported value " << value->ToString();
    return grad_node;
  }
  auto auto_grad_meta = tensor->auto_grad_meta_data();
  MS_EXCEPTION_IF_NULL(auto_grad_meta);
  if (auto_grad_meta->backward_hooks().empty()) {
    MS_LOG(DEBUG) << "Get empty backward hooks for tensor id " << tensor->id();
    return grad_node;
  }
  AnfNodePtr res = grad_node;
  for (const auto &[id, hook] : auto_grad_meta->backward_hooks()) {
    if (hook->hook_map_.size() != kSizeOne) {
      MS_LOG(EXCEPTION) << "Tensor hook only works on a single tensor value, value sequences are not supported";
    }
    auto hook_fn = hook->hook_map_.begin()->second;
    if (hook_fn.ptr() == nullptr) {
      MS_LOG(DEBUG) << "Hook id " << id << " has been deleted by Python";
      continue;
    }
    MS_LOG(DEBUG) << "Insert bprop cut fn " << ConvertPyObjToString(hook_fn) << " for tensor " << value->ToString()
                  << " with id " << tensor->id();
    auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
    bprop_cut->AddAttr("tensor_hook", MakeValue(true));
    bprop_cut->AddBackwardHookFn(kIndex0, hook_fn);
    // The bprop run needs input, out and dout; currently a fake output is used.
    AnfNodePtrList inputs{NewValueNode(bprop_cut), grad_node, NewValueNode(MakeValue("FakeOutput")), res};
    res = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(inputs);
    // Needs to be updated after execution.
    res->set_abstract(grad_node->abstract());

    // For running the graph by single op.
    ir_bprop_->ad_param()->tape_->set_flag(kFlagPyNativeBpropGraphWithBpropCut, true);
    ir_bprop_->set_bprop_graph_run_by_single_op(true);
  }
  auto_grad_meta->ClearBackwardHooks();
  return res;
}

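// Convert constant inputs of the cnode into primitive attributes (skipped for registered pyboost ops) and
// recursively process its cnode inputs.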
CNodePtr IrPassForward::ConvertConstInputToAttr(const CNodePtr &cnode, bool is_dynamic_shape) {
  MS_EXCEPTION_IF_NULL(cnode);
  const auto &prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(DEBUG) << "CNode has no primitive: " << cnode->DebugString();
    return cnode;
  }
  // Pyboost ops do not need input-to-attr conversion.
  if (runtime::PyBoostOpExecute::GetInstance().IsPyBoostOpRegistered(prim->name())) {
    cnode->AddAttr(kAttrConvertAttrNode, MakeValue(true));
    return cnode;
  }
  auto TraverseCNode = [this, is_dynamic_shape](const CNodePtr &cnode) {
    for (size_t i = 1; i < cnode->size(); ++i) {
      // Avoiding infinite loops
      if (!cnode->HasAttr(kIsKNode) && cnode->input(i)->isa<CNode>()) {
        cnode->set_input(i, ConvertConstInputToAttr(cnode->input(i)->cast<CNodePtr>(), is_dynamic_shape));
      }
    }
  };

  mindspore::HashSet<size_t> input_to_attr = {};
  PyNativeAlgo::Common::GetConstInputToAttr(prim, prim->name(), device_target_, is_dynamic_shape, &input_to_attr);
  if (input_to_attr.empty()) {
    TraverseCNode(cnode);
    return cnode;
  }
  const auto &input_names = prim->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(DEBUG) << "input_names is nullptr";
    return cnode;
  }

  ChangeInputToAttr(prim, cnode, input_names, input_to_attr, grad_by_value_);

  // The inputs may themselves be cnodes that need conversion, e.g. a cast whose input is also a cast.
  TraverseCNode(cnode);
  return cnode;
}

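// Fuse TupleGetItem(BatchNormGrad, 0) into a BNInferGrad node when is_training is false. Only called for the
// Ascend backend (see PassForDin).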
AnfNodePtr IrPassForward::BatchNormGradToBNInferGrad(const AnfNodePtr &node, const std::string &op_name) {
  if (op_name != kBatchNormOpName) {
    return node;
  }
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->HasAttr(kIsKNode) || !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
    return cnode;
  }
  auto batchnorm_grad_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(batchnorm_grad_node);
  if (!IsPrimitiveCNode(batchnorm_grad_node, prim::kPrimBatchNormGrad)) {
    return cnode;
  }
  AnfNodePtr index_node = cnode->input(kInputNodeOutputIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(index_node);
  auto value_node = index_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto index = GetValue<int64_t>(value_node->value());
  if (index != 0) {
    MS_LOG(DEBUG) << "TupleGetItem must be the 0th output of BatchNormGrad";
    return cnode;
  }
  auto batchnorm_grad_cnode = batchnorm_grad_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(batchnorm_grad_cnode);
  constexpr size_t kIdxIsTraining = 7;
  auto is_training_opt = GetScalarAnfNodeValue<bool>(batchnorm_grad_cnode->input(kIdxIsTraining));
  if (!is_training_opt.has_value()) {
    return cnode;
  }
  if (is_training_opt.value()) {
    MS_LOG(DEBUG) << "Attr 'is_training' is true, no need to do fusion";
    return cnode;
  }

  need_reverse_graph_ = true;
  auto new_cnode = CreateBNInferGrad(ir_bprop_, batchnorm_grad_cnode, node, grad_by_value_);
  auto &pair = node_attr_value_[new_cnode];
  (void)pair.emplace_back(UINT32_MAX, node);
  return new_cnode;
}

void IrPassForward::ReverseConstantToAttrNode(const CNodePtr &cnode, ValuePtrList *inputs_value,
                                              AnfNodePtrList *cnode_inputs) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!cnode->HasAttr(kAttrConvertAttrNode)) {
    return;
  }
  ReverseCNodeInputs(cnode, cnode_inputs, inputs_value);
}

void IrPassForward::ReverseMakeTupleNode(const CNodePtr &cnode, ValuePtrList *inputs_value,
                                         AnfNodePtrList *cnode_inputs) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(inputs_value);
  MS_EXCEPTION_IF_NULL(cnode_inputs);
  if (!cnode->HasAttr(kTupleToMakeTuple)) {
    return;
  }
  AnfNodePtrList new_inputs{cnode->input(kIndex0)};
  const auto &dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrDynInputSizes);
  for (size_t i = 0; i < dyn_input_sizes.size(); ++i) {
    if (dyn_input_sizes[i] >= 0) {
      // Compress input
      AnfNodePtrList cnode_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
      AnfNodePtrList knode_inputs{NewValueNode(prim::kPrimMakeTuple)};
      ValuePtrList value_tuple;
      abstract::AbstractBasePtrList abs_list;
      for (int64_t j = 0; j < dyn_input_sizes[i]; ++j) {
        auto input = cnode->input(i + j + kIndex1);
        (void)cnode_tuple_inputs.emplace_back(input);
        (void)knode_inputs.emplace_back(cnode_inputs->at(i + j + kIndex1));
        (void)value_tuple.emplace_back(inputs_value->at(i + j));
        (void)abs_list.emplace_back(input->abstract());
      }
      // Update knode inputs to make tuple inputs
      auto cnode_graph = cnode->func_graph();
      MS_EXCEPTION_IF_NULL(cnode_graph);
      auto cnode_tuple = cnode_graph->NewCNode(cnode_tuple_inputs);
      auto abs = std::make_shared<abstract::AbstractTuple>(abs_list);
      cnode_tuple->set_abstract(abs);
      (void)new_inputs.emplace_back(cnode_tuple);

      // Update knode inputs
      auto knode_input = ir_bprop_->ad_param()->tape_->FuncGraph::NewCNode(knode_inputs);
      knode_input->set_abstract(abs);
      size_t begin_index = i + kIndex1;
      auto it = cnode_inputs->erase(cnode_inputs->begin() + SizeToLong(begin_index),
                                    cnode_inputs->begin() + SizeToLong(begin_index) + dyn_input_sizes[i]);
      (void)cnode_inputs->insert(it, knode_input);

      // Update input value
      auto item = inputs_value->erase(inputs_value->begin() + SizeToLong(kIndex0),
                                      inputs_value->begin() + SizeToLong(kIndex0) + dyn_input_sizes[i]);
      (void)inputs_value->insert(item, std::make_shared<ValueTuple>(value_tuple));
    } else {
      auto last_index = (i == 0 ? 0 : i - 1);
      auto skip_index = (dyn_input_sizes[last_index] == -1 ? 1 : dyn_input_sizes[last_index]);
      (void)new_inputs.emplace_back(cnode->input(i + skip_index));
    }
  }
  cnode->set_inputs(new_inputs);
  (void)cnode->EraseAttr(kTupleToMakeTuple);
}

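// Undo the BNInferGrad fusion by replacing the fused cnode with the original node recorded in node_attr_value_.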
void IrPassForward::ReverseBNInfer(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!IsPrimitiveCNode(cnode, prim::kPrimBNInferGrad)) {
    return;
  }
  const auto item = node_attr_value_.find(cnode);
  if (item == node_attr_value_.end()) {
    return;
  }
  auto func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, false);
  }
  if (item->second.size() != kIndex1) {
    MS_LOG(EXCEPTION) << "Replace item size " << item->second.size() << " is not equal to " << kIndex1;
  }
  if (!manager->Replace(cnode, item->second[kIndex0].second)) {
    MS_LOG(EXCEPTION) << "Replace failed. cnode " << cnode->DebugString() << " to cnode "
                      << item->second[kIndex0].second->DebugString();
  }
  (void)node_attr_value_.erase(item);
}

void IrPassForward::ReverseCNodeInputs(const CNodePtr &cnode, AnfNodePtrList *cnode_inputs,
                                       ValuePtrList *inputs_value) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(inputs_value);
  MS_EXCEPTION_IF_NULL(cnode_inputs);
  const auto item = node_attr_value_.find(cnode);
  if (item == node_attr_value_.end()) {
    return;
  }
  for (const auto &t : item->second) {
    if (t.second->isa<ValueNode>()) {
      auto vnode = t.second->cast<ValueNodePtr>();
      auto v = vnode->value();
      (void)PyNativeAlgo::Common::SetValueGradInfo(v, nullptr, InputType::kConstant);
      AddCNodeInputs(cnode, cnode_inputs, t.first, PyNativeAlgo::Common::CreateValueNodeByValue(v, nullptr));
      (void)inputs_value->insert(inputs_value->begin() + SizeToLong(t.first), v);
    } else if (t.second->isa<Parameter>()) {
      const auto it = ir_bprop_->ad_param()->anfnode_to_variable_adjoint_.find(t.second);
      if (it == ir_bprop_->ad_param()->anfnode_to_variable_adjoint_.end()) {
        MS_LOG(EXCEPTION) << "Can not find " << t.second << " in anfnode_to_variable_adjoint";
      }
      AddCNodeInputs(cnode, cnode_inputs, t.first, it->second->k_node());
      (void)inputs_value->insert(inputs_value->begin() + SizeToLong(t.first), it->second->out_value());
    } else {
      MS_LOG(EXCEPTION) << "No scenario for " << t.second->DebugString();
    }
  }
  (void)node_attr_value_.erase(item);
}

void IrPassForward::ReversePassFuncGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &order = TopoSort(func_graph->output());
  for (const auto &node : order) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // The BN fusion is Ascend only.
    if (device_target_ == kAscendDevice) {
      ReverseBNInfer(cnode);
    }
  }
  need_reverse_graph_ = false;
  PyNativeAlgo::Common::DumpGraphIR("reverse_cnode_graph.ir", func_graph);
}

void IrPassForward::ReversePassCNode(const CNodePtr &cnode, ValuePtrList *inputs_value, AnfNodePtrList *cnode_inputs) {
  // Note: the reverse steps run in the opposite order of the forward pass.
  auto tape_graph = ir_bprop_->ad_param()->tape_;
  MS_EXCEPTION_IF_NULL(tape_graph);

  ReverseMakeTupleNode(cnode, inputs_value, cnode_inputs);
  ReverseConstantToAttrNode(cnode, inputs_value, cnode_inputs);
}

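// Entry point of the forward-direction passes applied to each din cnode; the Ascend-specific fusions are
// applied last.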
CNodePtr IrPassForward::PassForDin(const CNodePtr &cnode, const std::string &op_name, bool is_dynamic_shape) {
  // If you want to add a pass here, please take care of higher-order grad.
  MS_EXCEPTION_IF_NULL(ir_bprop_);
  AnfNodePtr new_din = ConvertConstInputToAttr(cnode, is_dynamic_shape);

  // Ascend only
  if (device_target_ == kAscendDevice) {
    new_din = BatchNormGradToBNInferGrad(new_din, op_name);
    new_din = GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(new_din, op_name, ir_bprop_);
  }
  return new_din->cast<CNodePtr>();
}

bool IrPassForward::need_reverse_graph_ = false;

void ClearCache() { node_attr_value_.clear(); }
}  // namespace bprop_pass
}  // namespace pynative
}  // namespace mindspore