/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/pynative/grad/function/func_pass.h"
#include <memory>
#include <vector>
#include <functional>
#include <algorithm>
#include <numeric>
#include <string>
#include "pipeline/pynative/pynative_utils.h"
#include "ops/sequence_ops.h"
#include "ops/nn_ops.h"
#include "ops/op_utils.h"
#include "include/backend/optimizer/helper.h"

namespace mindspore {
namespace pynative {
namespace bprop_pass {
namespace {
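// Converts the constant inputs of prim whose indices appear in input_to_attr into primitive attributes
// (named after the matching entry of kAttrInputNames) and returns the remaining inputs. The number of
// converted inputs is recorded in the kAttrConvertAttrNode attribute.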
NodePtrList ChangeInputToAttr(const PrimitivePtr &prim, const NodePtrList &inputs, const ValuePtr &input_names,
                              const mindspore::HashSet<size_t> &input_to_attr) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(input_names);
  const auto &input_names_vec = GetValue<std::vector<std::string>>(input_names);
  NodePtrList new_inputs{};
  size_t convert_size = 0;
  size_t input_size = inputs.size();
  for (size_t i = 0; i < input_size; ++i) {
    auto value = inputs[i]->Value();
    if (value->isa<Scalar>() && input_to_attr.find(i) != input_to_attr.end()) {
      MS_LOG(DEBUG) << "start erase input[" << i << "] of op[" + prim->name() + "]";
      if (i >= input_names_vec.size()) {
        MS_LOG(EXCEPTION) << "Index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
      }
      if (value->isa<tensor::BaseTensor>()) {
        auto tensor = value->cast<tensor::BaseTensorPtr>();
        if (tensor->data().const_data() == nullptr && !tensor->has_user_data(kTensorValueIsEmpty)) {
          return inputs;
        }
      }
      ++convert_size;
      prim->set_attr(input_names_vec[i], value);
    } else {
      (void)new_inputs.emplace_back(inputs[i]);
    }
  }
  if (convert_size > 0) {
    (void)prim->AddAttr(kAttrConvertAttrNode, MakeValue(convert_size));
  }
  return new_inputs;
}

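// Expands the Ascend backward pass of the fused SparseSoftmaxCrossEntropyWithLogits into basic ops:
// OneHot + SoftmaxCrossEntropyWithLogits recompute the softmax, then Tile/RealDiv/ExpandDims/Mul/Reshape
// scale the result by the incoming gradient, mirroring the graph-mode UnifyMindIR pass of the same name.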
class SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR {
 public:
  NodePtr Run(const NodePtrList &inputs, const NodePtr &dout) {
    GetDepthAndBatchSizeFromSparseSoftmaxNode(inputs);

    NodePtrList softmax_node_outputs;
    auto expand_dims_node = CreateMulInput(inputs, dout, &softmax_node_outputs);

    NodePtr new_mul_node =
      func_builder_->EmitOp(func_builder_->NewPrimitive(kMulOpName), {softmax_node_outputs[kIndex1], expand_dims_node});
    // Reshape 1D result to multi-dim result.
    auto reshape_node = CreateReshape(new_mul_node, BaseShapeToShape(inputs[kIndex0]->GetShape()));
    return reshape_node;
  }

  autograd::FuncBuilder *func_builder_{nullptr};

 private:
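  // Emits a Reshape of input_node to the given constant shape; the shape is passed as a tensor input
  // with the shape_from_tensor attribute set.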
  NodePtr CreateReshape(const NodePtr &input_node, const ShapeVector &shape) {
    MS_EXCEPTION_IF_NULL(input_node);
    std::vector<std::string> input_names = {"x", "shape"};
    std::vector<std::string> output_names = {"output"};
    auto prim = func_builder_->NewPrimitive(
      kReshapeOpName, {{kAttrInputNames, MakeValue(input_names)}, {kAttrOutputNames, MakeValue(output_names)}});
    constexpr auto kShapeFromTensor = "shape_from_tensor";
    prim->set_attr(kShapeFromTensor, MakeValue(true));
    auto shape_node = func_builder_->NewFuncNode(PyNativeAlgo::Common::CreateTensorByConstantValue(MakeValue(shape)),
                                                 nullptr, InputType::kConstant);
    shape_node->set_abstract(shape_node->Value()->ToAbstract());
    return func_builder_->EmitOp(prim, {input_node, shape_node});
  }

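  // Reads depth_ from the last dimension of the logits shape and batch_size_ from the product of the
  // labels shape dimensions.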
  void GetDepthAndBatchSizeFromSparseSoftmaxNode(const NodePtrList &inputs) {
    auto logits_shape = BaseShapeToShape(inputs[kIndex0]->GetShape());
    auto labels_shape = BaseShapeToShape(inputs[kIndex1]->GetShape());
    if (!logits_shape.empty()) {
      size_t index = logits_shape.size() - 1;
      depth_ = logits_shape[index];
    } else {
      MS_LOG(EXCEPTION) << "Logits shape of node SparseSoftmaxCrossEntropyWithLogits is empty";
    }
    batch_size_ = std::accumulate(labels_shape.begin(), labels_shape.end(), 1, std::multiplies<int64_t>());
  }

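  // Flattens the labels to 1D and one-hot encodes them to [batch_size_, depth_] with on/off values 1.0/0.0.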
  NodePtr CreateOneHot(const NodePtrList &inputs) {
    ShapeVector shape = ShapeVector{batch_size_};

    // Reshape multi-dim labels to 1D labels.
    auto reshape_node = CreateReshape(inputs[kIndex1], shape);

    auto value_on = std::make_shared<tensor::Tensor>(1.0, kFloat32);
    auto value_off = std::make_shared<tensor::Tensor>(0.0, kFloat32);
    auto value_axis = MakeValue<int64_t>(-1);
    std::vector<std::string> input_names = {"indices", "depth", "on_value", "off_value", "axis"};
    std::vector<std::string> output_names = {"output"};
    auto one_hot_primitive = func_builder_->NewPrimitive(
      kOneHotOpName, {{kAttrInputNames, MakeValue(input_names)}, {kAttrOutputNames, MakeValue(output_names)}});
    auto depth_node = func_builder_->NewFuncNode(MakeValue<int64_t>(depth_), nullptr, InputType::kConstant);
    depth_node->set_abstract(depth_node->Value()->ToAbstract());
    auto value_on_node = func_builder_->NewFuncNode(value_on, nullptr, InputType::kConstant);
    value_on_node->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(value_on_node->Value()->ToAbstract()));
    auto value_off_node = func_builder_->NewFuncNode(value_off, nullptr, InputType::kConstant);
    value_off_node->set_abstract(value_off_node->Value()->ToAbstract());
    auto value_axis_node = func_builder_->NewFuncNode(value_axis, nullptr, InputType::kConstant);
    value_axis_node->set_abstract(
      PyNativeAlgo::Common::SetAbstractValueToAnyValue(value_axis_node->Value()->ToAbstract()));
    NodePtrList one_hot_inputs{reshape_node, depth_node, value_on_node, value_off_node, value_axis_node};
    return func_builder_->EmitOp(one_hot_primitive, one_hot_inputs);
  }

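  // Recomputes SoftmaxCrossEntropyWithLogits on the logits reshaped to [batch_size_, depth_] and the
  // one-hot labels.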
  NodePtr CreateSoftmaxCrossEntropyWithLogits(const NodePtrList &inputs, const NodePtr &one_hot_node) {
    MS_EXCEPTION_IF_NULL(one_hot_node);
    ShapeVector shape = ShapeVector{batch_size_, depth_};
    // Reshape multi-dim logits to 2D logits.
    auto reshape_node = CreateReshape(inputs[kIndex0], shape);
    auto softmax_prim = func_builder_->NewPrimitive(kSoftmaxCrossEntropyWithLogitsOpName);
    return func_builder_->EmitOp(softmax_prim, {reshape_node, one_hot_node});
  }

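  // Unpacks a tuple-output node into output_num TupleGetItem nodes appended to *outputs.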
  void CreateMultipleOutputsOfAnfNode(const NodePtr &node, size_t output_num, NodePtrList *outputs) {
    MS_EXCEPTION_IF_NULL(node);
    MS_EXCEPTION_IF_NULL(outputs);
    MS_EXCEPTION_IF_NULL(node->abstract());
    const auto &abs_seq = node->abstract()->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(abs_seq);
    if (abs_seq->size() != output_num) {
      MS_LOG(EXCEPTION) << "Abstract seq size " << abs_seq->size() << " is not equal to " << output_num;
    }
    for (size_t i = 0; i < output_num; i++) {
      (void)outputs->emplace_back(func_builder_->TupleGetItem(node, i));
    }
  }

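  // Tiles dout to the flattened label batch size. Returns nullptr when batch_size_ is 1; when the batch
  // size is unknown (< 0) the multiples are taken from DynamicShape of the labels.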
  NodePtr CreateTile(const NodePtrList &inputs, const NodePtr &dout) {
    if (batch_size_ == 1) {
      return nullptr;
    }
    std::vector<std::string> input_names = {"x", "multiples"};
    std::vector<std::string> output_names = {"output"};
    auto tile_primitive = func_builder_->NewPrimitive(
      kTileOpName, {{kAttrInputNames, MakeValue(input_names)}, {kAttrOutputNames, MakeValue(output_names)}});
    NodePtrList tile_inputs;
    if (batch_size_ < 0) {
      auto shape_node = func_builder_->EmitOp(func_builder_->NewPrimitive("DynamicShape"), {inputs[kIndex1]});
      tile_inputs = {dout, shape_node};
    } else {
      std::vector<int64_t> multiples_v = {batch_size_};
      auto multiples_node = func_builder_->NewFuncNode(MakeValue(multiples_v), nullptr, InputType::kConstant);
      multiples_node->set_abstract(multiples_node->Value()->ToAbstract());
      tile_inputs = {dout, multiples_node};
    }
    auto tile_node = func_builder_->EmitOp(tile_primitive, tile_inputs);
    // feature map set
    std::vector<size_t> feature_map_input_indexs;
    (void)feature_map_input_indexs.emplace_back(0);
    constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
    tile_primitive->set_attr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs));
    return tile_node;
  }

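  // Divides the (tiled) gradient by batch_size_, supplied as a float32 scalar tensor.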
  NodePtr CreateRealDiv(const NodePtr &tile_node) {
    MS_EXCEPTION_IF_NULL(tile_node);
    auto y_value = static_cast<float>(batch_size_);
    auto y = std::make_shared<tensor::Tensor>(y_value, kFloat32);
    auto y_node = func_builder_->NewFuncNode(y, nullptr, InputType::kConstant);
    y_node->set_abstract(PyNativeAlgo::Common::SetAbstractValueToAnyValue(y_node->Value()->ToAbstract()));
    std::vector<std::string> input_names = {"x", "y"};
    std::vector<std::string> output_names = {"output"};
    auto real_div_primitive = func_builder_->NewPrimitive(
      kRealDivOpName, {{kAttrInputNames, MakeValue(input_names)}, {kAttrOutputNames, MakeValue(output_names)}});
    return func_builder_->EmitOp(real_div_primitive, {tile_node, y_node});
  }

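  // Appends a trailing axis (-1) to the RealDiv result so it broadcasts against the softmax output in the
  // final Mul.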
  NodePtr CreateExpandDims(const NodePtr &real_div_node) {
    MS_EXCEPTION_IF_NULL(real_div_node);
    constexpr int64_t axis = -1;
    auto axis_v = MakeValue(axis);
    auto axis_node = func_builder_->NewFuncNode(axis_v, nullptr, InputType::kConstant);
    axis_node->set_abstract(axis_v->ToAbstract());
    std::vector<std::string> input_names = {"x"};
    std::vector<std::string> output_names = {"output"};
    auto expand_dims_primitive = func_builder_->NewPrimitive(
      kExpandDimsOpName, {{kAttrInputNames, MakeValue(input_names)}, {kAttrOutputNames, MakeValue(output_names)}});
    expand_dims_primitive->set_attr(kAttrAxis, axis_v);
    return func_builder_->EmitOp(expand_dims_primitive, {real_div_node, axis_node});
  }

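  // Builds the gradient factor for the final Mul: recomputes the softmax outputs, then applies
  // Tile -> RealDiv -> ExpandDims to dout.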
  NodePtr CreateMulInput(const NodePtrList &inputs, const NodePtr &dout, NodePtrList *softmax_node_outputs) {
    MS_EXCEPTION_IF_NULL(softmax_node_outputs);
    auto one_hot_node = CreateOneHot(inputs);
    auto softmax_node = CreateSoftmaxCrossEntropyWithLogits(inputs, one_hot_node);
    CreateMultipleOutputsOfAnfNode(softmax_node, opt::kSoftmaxCrossEntropyWithLogitsOutputNum, softmax_node_outputs);
    auto tile_node = CreateTile(inputs, dout);
    NodePtr real_div_node;
    if (tile_node == nullptr) {
      real_div_node = CreateRealDiv(dout);
    } else {
      real_div_node = CreateRealDiv(tile_node);
    }
    auto expand_dims_node = CreateExpandDims(real_div_node);
    return expand_dims_node;
  }

  int64_t batch_size_{0};
  int64_t depth_{0};
};

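// Flattens a tuple input into its elements, appending one FuncNode per element to plant_inputs, and
// returns the number of elements.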
size_t SplitTupleInputs(autograd::FuncBuilder *func_builder, const NodePtr &input, NodePtrList *plant_inputs) {
  MS_EXCEPTION_IF_NULL(func_builder);
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(plant_inputs);
  MS_EXCEPTION_IF_NULL(input->Value());
  auto input_abs = input->abstract();
  auto value_seq = input->Value()->cast<ValueSequencePtr>()->value();
  auto abs_seq = input_abs->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(abs_seq);
  size_t input_size = value_seq.size();
  for (size_t i = 0; i < input_size; ++i) {
    const auto &value = value_seq[i];
    const auto &abs = abs_seq->elements()[i];
    (void)plant_inputs->emplace_back(func_builder->NewFuncNode(value, abs, input->input_type()));
  }
  return input_size;
}
}  // namespace

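// Expands tuple inputs of prim into their elements and records the original tuple lengths in the
// kAttrDynInputSizes attribute (-1 marks inputs that were not tuples). MakeTuple itself is left untouched.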
NodePtrList FuncPassForward::ConvertMakeTupleInputToDynamicInput(const PrimitivePtr &prim, const NodePtrList &inputs) {
  MS_EXCEPTION_IF_NULL(prim);
  if (!IsPrimitiveEquals(prim, prim::kPrimMakeTuple) &&
      std::any_of(inputs.begin(), inputs.end(),
                  [](const NodePtr &node) { return node->Value()->isa<ValueSequence>(); })) {
    NodePtrList plant_inputs;
    std::vector<int64_t> dyn_input_sizes;
    for (const auto &input : inputs) {
      MS_EXCEPTION_IF_NULL(input->Value());
      if (input->Value()->isa<ValueSequence>()) {
        auto dyn_input_size = SplitTupleInputs(func_builder_, input, &plant_inputs);
        (void)dyn_input_sizes.emplace_back(dyn_input_size);
      } else {
        (void)plant_inputs.emplace_back(input);
        (void)dyn_input_sizes.emplace_back(-1);
      }
    }
    // If there is a dynamic input, set dyn_input_sizes as an attribute and update the inputs.
    if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
      prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_input_sizes));
      MS_LOG(DEBUG) << "Change node to dynamic len " << prim->name();
    }
    return plant_inputs;
  }
  return inputs;
}

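// Converts constant inputs of prim into attributes according to the per-device const-input-to-attr table;
// returns the inputs unchanged when no conversion applies.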
NodePtrList FuncPassForward::ConvertConstInputToAttr(const PrimitivePtr &prim, const NodePtrList &inputs) {
  MS_EXCEPTION_IF_NULL(prim);
  mindspore::HashSet<size_t> input_to_attr = {};
  PyNativeAlgo::Common::GetConstInputToAttr(prim, prim->name(), device_target_, false, &input_to_attr);
  if (input_to_attr.empty()) {
    return inputs;
  }
  const auto &input_names = prim->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(DEBUG) << "input_names are nullptr";
    return inputs;
  }
  return ChangeInputToAttr(prim, inputs, input_names, input_to_attr);
}

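// On Ascend, an inference-mode BatchNormGrad (is_training == false) that does not need scale/bias gradients
// is rewritten to BNInferGrad for dx, with the scale and bias gradients returned as zeros. In all other
// cases the original BatchNormGrad is emitted.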
NodePtr FuncPassForward::BatchNormGradToBNInferGrad(const NodePtrList &inputs, bool is_scale_or_bias_grad) {
  if (device_target_ != kAscendDevice || is_scale_or_bias_grad) {
    return func_builder_->Emit(kBatchNormGradOpName, inputs);
  }
  constexpr size_t kIdxIsTraining = 6;
  auto is_training_opt = mindspore::ops::GetScalarValue<bool>(inputs[kIdxIsTraining]->Value());
  if (!is_training_opt.has_value()) {
    MS_LOG(DEBUG) << "Can not find attr 'is_training' in training input";
    return func_builder_->Emit(kBatchNormGradOpName, inputs);
  }
  if (is_training_opt.value()) {
    MS_LOG(DEBUG) << "Attr 'is_training' is true, no need to do fusion";
    return func_builder_->Emit(kBatchNormGradOpName, inputs);
  }

  auto bn_infer_grad_prim = func_builder_->NewPrimitive(kBNInferGradOpName);
  constexpr size_t kIdxGrads = 0;
  constexpr size_t kIdxScale = 2;
  constexpr size_t kIdxVariance = 4;
  constexpr size_t kIdxEpsilon = 7;
  NodePtrList new_inputs{inputs[kIdxGrads], inputs[kIdxScale], inputs[kIdxVariance], inputs[kIdxEpsilon]};

  auto epsilon_opt = mindspore::ops::GetScalarValue<pyfloat>(inputs[kIdxEpsilon]->Value());
  float epsilon = 1e-5;
  if (epsilon_opt.has_value()) {
    epsilon = epsilon_opt.value();
  } else {
    MS_LOG(ERROR) << "For BNInferGrad pass, failed to get attr epsilon, use default epsilon: 1e-5.";
  }
  bn_infer_grad_prim->set_attr(kAttrEpsilon, MakeValue(epsilon));
  bn_infer_grad_prim->set_attr(kAttrIsTraining, MakeValue(is_training_opt.value()));
  auto dx = func_builder_->EmitOp(bn_infer_grad_prim, new_inputs);
  return func_builder_->MakeTuple(
    {dx, func_builder_->OutZeros(inputs[kIdxScale]), func_builder_->OutZeros(inputs[kIdxScale])});
}

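// Backward expansion for the fused SparseSoftmaxCrossEntropyWithLogits. On non-Ascend targets the grad op
// is emitted directly and multiplied by dout; on Ascend it is expanded through
// SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR.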
NodePtr FuncPassForward::GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(const NodePtrList &inputs,
                                                                            const expander::DAttr &attrs,
                                                                            const NodePtr &out, const NodePtr &dout,
                                                                            bool is_graph_mode) {
  if (device_target_ != kAscendDevice) {
    auto grad = func_builder_->Emit(kSparseSoftmaxCrossEntropyWithLogitsOpName, inputs, attrs);
    if (is_graph_mode) {
      grad = func_builder_->Depend(grad, out);
    }
    grad = func_builder_->Emit(kMulOpName, {grad, dout});
    return grad;
  }

  // Use a static instance so the helper is created only once.
  static auto sparse_softmax_cross_entropy_with_logits =
    std::make_shared<SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR>();
  sparse_softmax_cross_entropy_with_logits->func_builder_ = func_builder_;
  return sparse_softmax_cross_entropy_with_logits->Run(inputs, dout);
}

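// Entry point applied to an op's inputs before backward expansion: when a primitive is given, runs
// const-input-to-attr conversion followed by tuple-input flattening; otherwise the inputs pass through.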
NodePtrList FuncPassForward::PassForOpInput(const PrimitivePtr &prim, const NodePtrList &inputs) {
  MS_EXCEPTION_IF_NULL(func_builder_);
  if (prim != nullptr) {
    NodePtrList new_inputs = ConvertConstInputToAttr(prim, inputs);
    return ConvertMakeTupleInputToDynamicInput(prim, new_inputs);
  }
  return inputs;
}
}  // namespace bprop_pass
}  // namespace pynative
}  // namespace mindspore