• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/pynative/grad/function/func_builder.h"
20 #include <algorithm>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include <set>
26 #include "runtime/pynative/op_function/pyboost_grad_functions.h"
27 #include "include/backend/optimizer/helper.h"
28 #include "include/backend/optimizer/op_adaptation_info_factory.h"
29 #include "pipeline/pynative/pynative_utils.h"
30 #include "mindspore/core/ops/op_utils.h"
31 #include "frontend/operator/cc_implementations.h"
32 
33 namespace mindspore::pynative::autograd {
34 namespace {
35 template <typename T>
PrintDebugInfo(std::vector<T> items,const std::string & info_header="")36 std::string PrintDebugInfo(std::vector<T> items, const std::string &info_header = "") {
37   static constexpr size_t end_char_size = 2;
38   std::ostringstream buf;
39   buf << info_header;
40   for (size_t i = 0; i < items.size(); ++i) {
41     if (items[i] == nullptr) {
42       MS_LOG(DEBUG) << "The " << i << "'th item is nullptr!";
43       continue;
44     }
45     if (items[i]->template isa<tensor::BaseTensor>()) {
46       auto tensor = items[i]->template cast<tensor::BaseTensorPtr>();
47       auto grad = std::make_shared<tensor::Tensor>(*tensor);
48       grad->data_sync();
49       buf << i << "th: "
50           << "ptr " << items[i].get() << ", " << grad->ToStringRepr() << ", ";
51     } else {
52       buf << i << "th: "
53           << "ptr " << items[i].get() << ", " << items[i]->ToString() << ", ";
54     }
55   }
56   return buf.str().erase(buf.str().size() - end_char_size);
57 }
58 
GetValueDependArgIndices(const PrimitivePtr & primitive,const NodePtrList & inputs)59 std::set<int64_t> GetValueDependArgIndices(const PrimitivePtr &primitive, const NodePtrList &inputs) {
60   auto depend_list = ops::GetInputDependValueList(primitive);
61   auto attr = primitive->GetAttr(kAttrDynInputSizes);
62   if (attr == nullptr) {
63     return depend_list;
64   }
65   // mapping from input prototype index to corresponding start index of real input
66   std::vector<int64_t> dyn_input_sizes = GetValue<std::vector<int64_t>>(attr);
67   if (!dyn_input_sizes.empty()) {
68     auto temp_depend_list = depend_list;
69     depend_list.clear();
70     for (const auto item : temp_depend_list) {
71       int64_t offset = 0;
72       for (int64_t i = 0; i < item; i++) {
73         auto idx = static_cast<size_t>(i);
74         if (dyn_input_sizes[idx] == -1) {
75           offset += 1;
76         } else {
77           offset += dyn_input_sizes[idx];
78         }
79       }
80       depend_list.emplace(offset);
81       MS_LOG(DEBUG) << "Adjust depend list from " << item << " to " << offset << " for op: " << primitive->name();
82     }
83   }
84   return depend_list;
85 }
86 
SetDependValue(const PrimitivePtr & primitive,const NodePtrList & inputs)87 void SetDependValue(const PrimitivePtr &primitive, const NodePtrList &inputs) {
88   auto depend_list = GetValueDependArgIndices(primitive, inputs);
89   if (depend_list.empty()) {
90     return;
91   }
92   int64_t input_size = inputs.size();
93   for (const auto index : depend_list) {
94     if (index >= input_size) {
95       MS_LOG(EXCEPTION) << "For depend list index should be less than inputs size: " << input_size
96                         << ", but got index: " << index;
97     }
98     const auto abstract = inputs[index]->abstract();
99     const auto value = inputs[index]->Value();
100     auto tensor = value->cast<tensor::BaseTensorPtr>();
101     if (tensor != nullptr) {
102       tensor->data_sync();
103     }
104     abstract->set_value(value);
105   }
106 }
107 
BuildShape(const abstract::AbstractBasePtr & abs)108 std::vector<int64_t> BuildShape(const abstract::AbstractBasePtr &abs) {
109   auto base_shape = abs->BuildShape();
110   if (base_shape->isa<abstract::NoShape>()) {
111     return {};
112   }
113   auto shape = base_shape->cast<abstract::ShapePtr>();
114   MS_EXCEPTION_IF_NULL(shape);
115   return shape->shape();
116 }
117 
ParseCond(const NodePtr & cond)118 bool ParseCond(const NodePtr &cond) {
119   auto cond_val = cond->Value();
120   if (cond_val->isa<BoolImm>()) {
121     return GetValue<bool>(cond_val);
122   }
123   if (cond_val->isa<tensor::BaseTensor>()) {
124     auto tensor = cond_val->cast<tensor::BaseTensorPtr>();
125     tensor->data_sync();
126     size_t data_size = tensor->DataSize();
127     auto tensor_type = tensor->Dtype();
128     if (tensor_type->type_id() == kNumberTypeBool) {
129       auto data_c = reinterpret_cast<bool *>(tensor->data_c());
130       MS_EXCEPTION_IF_NULL(data_c);
131       return std::all_of(data_c, data_c + data_size, [](const bool &data) { return static_cast<bool>(data); });
132     }
133   }
134   MS_LOG(EXCEPTION) << "For control flow, the cond should be Tensor[bool] or bool, but got: " << cond_val->ToString();
135 }
136 }  // namespace
137 
// Executes a primitive eagerly through the PyBoost runtime and wraps the
// result as a func node. Inputs are preprocessed (pass rewrites, None->zeros),
// value-depend inputs are synced, and output abstracts are cached when the
// runtime provides simple-info.
NodePtr FuncBuilder::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) {
  MS_LOG(DEBUG) << "Emit op " << prim->name();
  // Let the forward pass rewrite/expand the raw inputs for this primitive.
  auto real_inputs = pass_forward_->PassForOpInput(prim, inputs);
  std::vector<ValuePtr> op_inputs;
  op_inputs.reserve(real_inputs.size());
  abstract::AbstractBasePtrList input_abs;
  input_abs.reserve(real_inputs.size());
  std::vector<InputType> input_mask;
  input_mask.reserve(real_inputs.size());
  // Attach concrete values to abstracts for inputs the op's inference depends on.
  // NOTE: uses the original `inputs`, not `real_inputs` -- the depend indices
  // refer to the pre-pass layout.
  SetDependValue(prim, inputs);
  for (const auto &input : real_inputs) {
    auto abs = input->abstract();
    // Replace None placeholders with zero-filled tensors matching the abstract.
    auto value = FillZeros(input->Value(), abs);
    (void)op_inputs.emplace_back(value);
    (void)input_abs.emplace_back(abs);
    (void)input_mask.emplace_back(input->input_type());
  }
  MS_LOG(DEBUG) << "Get input value size " << op_inputs.size() << ", "
                << PyNativeAlgo::Common::PrintDebugInfo(op_inputs);
  MS_LOG(DEBUG) << "Get input abs size " << input_abs.size() << ", " << PyNativeAlgo::Common::PrintDebugInfo(input_abs);
  VectorRef outputs;
  runtime::OpRunnerInfo op_runner_info{prim, device_target_, op_inputs, input_abs, input_mask, nullptr};
  runtime::PyBoostOpExecute::GetInstance().Execute(&op_runner_info, &outputs);
  auto real_outputs = common::AnfAlgo::TransformVectorRefToMultiValue(outputs);
  MS_LOG(DEBUG) << "Get output value size " << real_outputs.size() << ", "
                << PyNativeAlgo::Common::PrintDebugInfo(real_outputs);
  if (op_runner_info.output_value_simple_info != nullptr) {
    // Get output abstract
    op_runner_info.output_abs = TransformValueSimpleInfoToAbstract(*op_runner_info.output_value_simple_info);
  }
  ValuePtr value_result;
  MS_EXCEPTION_IF_NULL(op_runner_info.output_abs);
  // Single non-sequence output is unwrapped; otherwise outputs form a tuple.
  if (real_outputs.size() == kSizeOne && !op_runner_info.output_abs->isa<abstract::AbstractSequence>()) {
    value_result = real_outputs[kIndex0];
  } else {
    value_result = std::make_shared<ValueTuple>(std::move(real_outputs));
  }
  // Set abstract to tensor cache
  if (op_runner_info.output_value_simple_info != nullptr) {
    PyNativeAlgo::AutoGrad::CacheOutputAbstract(value_result, op_runner_info.output_abs);
  }
  auto result = NewFuncNode(value_result, op_runner_info.output_abs, InputType::kOpOutput);
  return result;
}
182 
EmitValue(const ValuePtr & value)183 NodePtr FuncBuilder::EmitValue(const ValuePtr &value) {
184   // For constant value, its abstract may not use, we delay set abs, if op use its abstract, we can get abstract
185   // from FuncBuilder::abstract()
186   auto node = NewFuncNode(value, nullptr, InputType::kConstant);
187   return node;
188 }
189 
// Stacks the elements of a sequence node along the given axis.
NodePtr FuncBuilder::Stack(const NodePtr &x, const ValuePtr &axis_value) {
  const auto axis = GetValue<int64_t>(axis_value);
  // Unpack the sequence into individual nodes, then delegate to the list overload.
  return Stack(FlattenNode(x), axis);
}
195 
Stack(const NodePtrList & x,int64_t axis)196 NodePtr FuncBuilder::Stack(const NodePtrList &x, int64_t axis) {
197   std::vector<int64_t> dyn_size{static_cast<int64_t>(x.size()), -1};
198   expander::DAttr attrs{std::make_pair(kAttrDynInputSizes, MakeValue(dyn_size)),
199                         std::make_pair("axis", MakeValue(axis))};
200   return Emit(kStackOpName, x, attrs);
201 }
202 
BatchNormGrad(const NodePtrList & inputs,bool is_scale_or_bias_grad)203 NodePtr FuncBuilder::BatchNormGrad(const NodePtrList &inputs, bool is_scale_or_bias_grad) {
204   return pass_forward_->BatchNormGradToBNInferGrad(inputs, is_scale_or_bias_grad);
205 }
206 
SparseSoftmaxCrossEntropyWithLogits(const NodePtrList & inputs,const expander::DAttr & attrs,const NodePtr & out,const NodePtr & dout,bool is_graph_mode)207 NodePtr FuncBuilder::SparseSoftmaxCrossEntropyWithLogits(const NodePtrList &inputs, const expander::DAttr &attrs,
208                                                          const NodePtr &out, const NodePtr &dout, bool is_graph_mode) {
209   return pass_forward_->GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(inputs, attrs, out, dout, is_graph_mode);
210 }
211 
Depend(const NodePtr & value,const NodePtr & expr)212 NodePtr FuncBuilder::Depend(const NodePtr &value, const NodePtr &expr) { return value; }
213 
// Returns the i-th element of a sequence node as a new func node.
// Raises if the input value is not a sequence or the index is out of range.
NodePtr FuncBuilder::TupleGetItem(const NodePtr &input, size_t i) {
  auto value = input->Value();
  if (!value->isa<ValueSequence>()) {
    // Fixed message: the two literals previously concatenated to "sequencebut got".
    MS_LOG(EXCEPTION) << "Input value should be sequence, "
                      << "but got " << value->ToString();
  }
  auto seq = value->cast<ValueSequencePtr>();
  if (seq->size() <= i) {
    MS_LOG(EXCEPTION) << "Input value sequence size should > " << i << " but got " << value->ToString();
  }
  // Propagate the element abstract only when one exists and its arity matches
  // the value sequence; constants (e.g. from EmitValue) may carry no abstract.
  abstract::AbstractBasePtr item_abs = nullptr;
  const auto &input_abs = input->abstract();
  if (input_abs != nullptr) {
    auto seq_abs = input_abs->cast<abstract::AbstractSequencePtr>();
    if (seq_abs != nullptr && seq_abs->size() == seq->size()) {
      item_abs = seq_abs->elements()[i];
    }
  }
  return NewFuncNode(seq->value()[i], item_abs, input->input_type());
}
231 
// Produces the "no gradient" placeholder for a node: a single None constant,
// or a tuple of None matching the arity of a tensor sequence.
NodePtr FuncBuilder::OutZeros(const NodePtr &node) {
  const auto make_none = [this]() { return NewFuncNode(kNone, nullptr, InputType::kConstant); };
  const auto &node_value = node->Value();
  if (!node_value->isa<ValueSequence>()) {
    return make_none();
  }
  auto val_seq = node_value->cast<ValueSequencePtr>();
  if (val_seq->size() == kSizeZero) {
    return make_none();
  }
  // Only the first element is inspected to decide whether this is a tensor sequence.
  if (!val_seq->value()[kIndexZero]->isa<tensor::Tensor>()) {
    return make_none();
  }
  ValuePtrList nones(val_seq->size(), kNone);
  return NewFuncNode(std::make_shared<ValueTuple>(nones), nullptr, InputType::kConstant);
}
248 
// Returns a value of ones with the same shape/dtype as `value` via OnesLike.
ValuePtr FuncBuilder::Ones(const ValuePtr &value) {
  auto any_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(value->ToAbstract());
  auto input_node = NewFuncNode(value, any_abs, InputType::kOpOutput);
  NodePtrList ones_inputs{input_node};
  return EmitOp(prim::kPrimOnesLike, ones_inputs)->Value();
}
254 
// Returns a value of zeros with the same shape/dtype as `value` via ZerosLike.
ValuePtr FuncBuilder::Zeros(const ValuePtr &value) {
  auto any_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(value->ToAbstract());
  return ZerosLike(NewFuncNode(value, any_abs, InputType::kOpOutput))->Value();
}
260 
// Element-wise addition of two values through the Add op.
ValuePtr FuncBuilder::Add(const ValuePtr &input, const ValuePtr &other) {
  // Both operands are wrapped as op-output nodes with any-value abstracts.
  const auto to_node = [this](const ValuePtr &v) {
    auto any_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(v->ToAbstract());
    return NewFuncNode(v, any_abs, InputType::kOpOutput);
  };
  return Emit(mindspore::kAddOpName, {to_node(input), to_node(other)})->Value();
}
268 
// Overload taking the index as a node holding an int64 value.
NodePtr FuncBuilder::TupleGetItem(const NodePtr &input, const NodePtr &index) {
  const auto idx = GetValue<int64_t>(index->Value());
  return TupleGetItem(input, static_cast<size_t>(idx));
}
274 
MakeTuple(const NodePtrList & inputs)275 NodePtr FuncBuilder::MakeTuple(const NodePtrList &inputs) {
276   ValuePtrList values;
277   AbstractBasePtrList abs;
278   std::transform(inputs.begin(), inputs.end(), std::back_inserter(values),
279                  [](const NodePtr &node) { return node->Value(); });
280   auto value = std::make_shared<ValueTuple>(values);
281   auto tuple_node = NewFuncNode(value, nullptr, InputType::kOpOutput);
282   return tuple_node;
283 }
284 
MakeList(const NodePtrList & inputs)285 NodePtr FuncBuilder::MakeList(const NodePtrList &inputs) { return MakeTuple(inputs); }
286 
// Eagerly evaluates the condition and runs exactly one branch. A single
// result is returned as-is; multiple results are packed into a tuple.
NodePtr FuncBuilder::Conditional(const NodePtr &cond, const expander::Emitter::BlockFunc &true_case,
                                 const expander::Emitter::BlockFunc &false_case) {
  auto branch_outputs = ParseCond(cond) ? true_case(this) : false_case(this);
  return branch_outputs.size() == kSizeOne ? branch_outputs[kIndex0] : MakeTuple(branch_outputs);
}
300 
// Scalar equality. Bool operands are compared directly; other scalar kinds
// are dispatched to the prim::ScalarEq implementation. `dst_type` is unused.
NodePtr FuncBuilder::ScalarEq(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) {
  const auto lhs_val = lhs->Value();
  const auto rhs_val = rhs->Value();
  const bool both_bool = lhs_val->isa<BoolImm>() && rhs_val->isa<BoolImm>();
  ValuePtr result = both_bool ? MakeValue(GetValue<bool>(lhs_val) == GetValue<bool>(rhs_val))
                              : prim::ScalarEq({lhs_val, rhs_val});
  MS_LOG(DEBUG) << "ScalarEq op: lhs " << lhs_val->ToString() << ", rhs " << rhs_val->ToString();
  return NewFuncNode(result, nullptr, InputType::kOpOutput);
}
313 
SetInputs(std::string instance_name,const std::vector<NodePtr> * inputs,mindspore::HashMap<std::string,ValuePtr> * attrs_ptr)314 void FuncBuilder::SetInputs(std::string instance_name, const std::vector<NodePtr> *inputs,
315                             mindspore::HashMap<std::string, ValuePtr> *attrs_ptr) {
316   instance_name_ = std::move(instance_name);
317   inputs_ptr_ = inputs;
318   attrs_ptr_ = attrs_ptr;
319 }
320 
FlattenNode(const NodePtr & input)321 NodePtrList FuncBuilder::FlattenNode(const NodePtr &input) {
322   if (!input->Value()->isa<ValueSequence>()) {
323     return {input};
324   }
325   auto value_seq = input->Value()->cast<ValueSequencePtr>()->value();
326   auto value_abs = input->abstract()->cast<abstract::AbstractSequencePtr>();
327   MS_EXCEPTION_IF_NULL(value_abs);
328   NodePtrList flattenNodes;
329   flattenNodes.reserve(value_seq.size());
330   for (size_t i = 0; i < value_seq.size(); ++i) {
331     auto &value = value_seq[i];
332     (void)flattenNodes.emplace_back(NewFuncNode(value, value_abs->elements()[i], input->input_type()));
333   }
334   return flattenNodes;
335 }
336 
// Replaces None placeholders in `value` with zero-filled tensors described by
// `abs`, recursing element-wise into value sequences. Non-None scalars and
// tensors pass through unchanged.
ValuePtr FuncBuilder::FillZeros(const ValuePtr &value, const abstract::AbstractBasePtr &abs) {
  auto convert_value = value;
  if (value->isa<None>()) {
    if (abs->isa<abstract::AbstractTensor>()) {
      auto tensor_dtype = abs->BuildType()->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_dtype);
      auto dtype = tensor_dtype->element();
      auto shape = BuildShape(abs);
      // This tensor only supplies shape/dtype; ZerosLike produces the actual
      // zero-filled result.
      auto out_tensor = std::make_shared<tensor::Tensor>(dtype->type_id(), shape);
      auto zero_node = ZerosLike(NewFuncNode(out_tensor, abs, InputType::kOpOutput));
      convert_value = zero_node->Value();
    } else {
      // None paired with a non-tensor abstract: nothing to fill, keep None.
      MS_LOG(DEBUG) << "None value abstract got None abstract!";
    }
  } else if (value->isa<ValueSequence>()) {
    auto seq = value->cast<ValueSequencePtr>();
    auto abs_list = abs->cast<abstract::AbstractSequencePtr>();
    MS_EXCEPTION_IF_NULL(abs_list);
    std::vector<ValuePtr> value_list;
    value_list.reserve(seq->value().size());
    // NOTE(review): assumes abs_list->elements() has at least as many entries
    // as the value sequence -- no size check here; confirm callers guarantee it.
    for (size_t i = 0; i < seq->value().size(); ++i) {
      const auto &val = seq->value()[i];
      const auto &temp_abs = abs_list->elements()[i];
      auto convert = FillZeros(val, temp_abs);
      (void)value_list.emplace_back(convert);
    }
    // Result is always a ValueTuple, even if the input was a ValueList.
    convert_value = std::make_shared<ValueTuple>(value_list);
  }
  return convert_value;
}
367 }  // namespace mindspore::pynative::autograd
368