1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/pynative/grad/function/func_builder.h"
20 #include <algorithm>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include <set>
26 #include "runtime/pynative/op_function/pyboost_grad_functions.h"
27 #include "include/backend/optimizer/helper.h"
28 #include "include/backend/optimizer/op_adaptation_info_factory.h"
29 #include "pipeline/pynative/pynative_utils.h"
30 #include "mindspore/core/ops/op_utils.h"
31 #include "frontend/operator/cc_implementations.h"
32
33 namespace mindspore::pynative::autograd {
34 namespace {
35 template <typename T>
PrintDebugInfo(std::vector<T> items,const std::string & info_header="")36 std::string PrintDebugInfo(std::vector<T> items, const std::string &info_header = "") {
37 static constexpr size_t end_char_size = 2;
38 std::ostringstream buf;
39 buf << info_header;
40 for (size_t i = 0; i < items.size(); ++i) {
41 if (items[i] == nullptr) {
42 MS_LOG(DEBUG) << "The " << i << "'th item is nullptr!";
43 continue;
44 }
45 if (items[i]->template isa<tensor::BaseTensor>()) {
46 auto tensor = items[i]->template cast<tensor::BaseTensorPtr>();
47 auto grad = std::make_shared<tensor::Tensor>(*tensor);
48 grad->data_sync();
49 buf << i << "th: "
50 << "ptr " << items[i].get() << ", " << grad->ToStringRepr() << ", ";
51 } else {
52 buf << i << "th: "
53 << "ptr " << items[i].get() << ", " << items[i]->ToString() << ", ";
54 }
55 }
56 return buf.str().erase(buf.str().size() - end_char_size);
57 }
58
GetValueDependArgIndices(const PrimitivePtr & primitive,const NodePtrList & inputs)59 std::set<int64_t> GetValueDependArgIndices(const PrimitivePtr &primitive, const NodePtrList &inputs) {
60 auto depend_list = ops::GetInputDependValueList(primitive);
61 auto attr = primitive->GetAttr(kAttrDynInputSizes);
62 if (attr == nullptr) {
63 return depend_list;
64 }
65 // mapping from input prototype index to corresponding start index of real input
66 std::vector<int64_t> dyn_input_sizes = GetValue<std::vector<int64_t>>(attr);
67 if (!dyn_input_sizes.empty()) {
68 auto temp_depend_list = depend_list;
69 depend_list.clear();
70 for (const auto item : temp_depend_list) {
71 int64_t offset = 0;
72 for (int64_t i = 0; i < item; i++) {
73 auto idx = static_cast<size_t>(i);
74 if (dyn_input_sizes[idx] == -1) {
75 offset += 1;
76 } else {
77 offset += dyn_input_sizes[idx];
78 }
79 }
80 depend_list.emplace(offset);
81 MS_LOG(DEBUG) << "Adjust depend list from " << item << " to " << offset << " for op: " << primitive->name();
82 }
83 }
84 return depend_list;
85 }
86
SetDependValue(const PrimitivePtr & primitive,const NodePtrList & inputs)87 void SetDependValue(const PrimitivePtr &primitive, const NodePtrList &inputs) {
88 auto depend_list = GetValueDependArgIndices(primitive, inputs);
89 if (depend_list.empty()) {
90 return;
91 }
92 int64_t input_size = inputs.size();
93 for (const auto index : depend_list) {
94 if (index >= input_size) {
95 MS_LOG(EXCEPTION) << "For depend list index should be less than inputs size: " << input_size
96 << ", but got index: " << index;
97 }
98 const auto abstract = inputs[index]->abstract();
99 const auto value = inputs[index]->Value();
100 auto tensor = value->cast<tensor::BaseTensorPtr>();
101 if (tensor != nullptr) {
102 tensor->data_sync();
103 }
104 abstract->set_value(value);
105 }
106 }
107
BuildShape(const abstract::AbstractBasePtr & abs)108 std::vector<int64_t> BuildShape(const abstract::AbstractBasePtr &abs) {
109 auto base_shape = abs->BuildShape();
110 if (base_shape->isa<abstract::NoShape>()) {
111 return {};
112 }
113 auto shape = base_shape->cast<abstract::ShapePtr>();
114 MS_EXCEPTION_IF_NULL(shape);
115 return shape->shape();
116 }
117
ParseCond(const NodePtr & cond)118 bool ParseCond(const NodePtr &cond) {
119 auto cond_val = cond->Value();
120 if (cond_val->isa<BoolImm>()) {
121 return GetValue<bool>(cond_val);
122 }
123 if (cond_val->isa<tensor::BaseTensor>()) {
124 auto tensor = cond_val->cast<tensor::BaseTensorPtr>();
125 tensor->data_sync();
126 size_t data_size = tensor->DataSize();
127 auto tensor_type = tensor->Dtype();
128 if (tensor_type->type_id() == kNumberTypeBool) {
129 auto data_c = reinterpret_cast<bool *>(tensor->data_c());
130 MS_EXCEPTION_IF_NULL(data_c);
131 return std::all_of(data_c, data_c + data_size, [](const bool &data) { return static_cast<bool>(data); });
132 }
133 }
134 MS_LOG(EXCEPTION) << "For control flow, the cond should be Tensor[bool] or bool, but got: " << cond_val->ToString();
135 }
136 } // namespace
137
EmitOp(const PrimitivePtr & prim,const NodePtrList & inputs)138 NodePtr FuncBuilder::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) {
139 MS_LOG(DEBUG) << "Emit op " << prim->name();
140 auto real_inputs = pass_forward_->PassForOpInput(prim, inputs);
141 std::vector<ValuePtr> op_inputs;
142 op_inputs.reserve(real_inputs.size());
143 abstract::AbstractBasePtrList input_abs;
144 input_abs.reserve(real_inputs.size());
145 std::vector<InputType> input_mask;
146 input_mask.reserve(real_inputs.size());
147 SetDependValue(prim, inputs);
148 for (const auto &input : real_inputs) {
149 auto abs = input->abstract();
150 auto value = FillZeros(input->Value(), abs);
151 (void)op_inputs.emplace_back(value);
152 (void)input_abs.emplace_back(abs);
153 (void)input_mask.emplace_back(input->input_type());
154 }
155 MS_LOG(DEBUG) << "Get input value size " << op_inputs.size() << ", "
156 << PyNativeAlgo::Common::PrintDebugInfo(op_inputs);
157 MS_LOG(DEBUG) << "Get input abs size " << input_abs.size() << ", " << PyNativeAlgo::Common::PrintDebugInfo(input_abs);
158 VectorRef outputs;
159 runtime::OpRunnerInfo op_runner_info{prim, device_target_, op_inputs, input_abs, input_mask, nullptr};
160 runtime::PyBoostOpExecute::GetInstance().Execute(&op_runner_info, &outputs);
161 auto real_outputs = common::AnfAlgo::TransformVectorRefToMultiValue(outputs);
162 MS_LOG(DEBUG) << "Get output value size " << real_outputs.size() << ", "
163 << PyNativeAlgo::Common::PrintDebugInfo(real_outputs);
164 if (op_runner_info.output_value_simple_info != nullptr) {
165 // Get output abstract
166 op_runner_info.output_abs = TransformValueSimpleInfoToAbstract(*op_runner_info.output_value_simple_info);
167 }
168 ValuePtr value_result;
169 MS_EXCEPTION_IF_NULL(op_runner_info.output_abs);
170 if (real_outputs.size() == kSizeOne && !op_runner_info.output_abs->isa<abstract::AbstractSequence>()) {
171 value_result = real_outputs[kIndex0];
172 } else {
173 value_result = std::make_shared<ValueTuple>(std::move(real_outputs));
174 }
175 // Set abstract to tensor cache
176 if (op_runner_info.output_value_simple_info != nullptr) {
177 PyNativeAlgo::AutoGrad::CacheOutputAbstract(value_result, op_runner_info.output_abs);
178 }
179 auto result = NewFuncNode(value_result, op_runner_info.output_abs, InputType::kOpOutput);
180 return result;
181 }
182
EmitValue(const ValuePtr & value)183 NodePtr FuncBuilder::EmitValue(const ValuePtr &value) {
184 // For constant value, its abstract may not use, we delay set abs, if op use its abstract, we can get abstract
185 // from FuncBuilder::abstract()
186 auto node = NewFuncNode(value, nullptr, InputType::kConstant);
187 return node;
188 }
189
// Stack the elements of `x` along the axis stored in `axis_value`.
//
// `x` is flattened into its element nodes first (a non-sequence node becomes a one-element list), then the
// int64 axis is extracted and the list overload of Stack does the work.
NodePtr FuncBuilder::Stack(const NodePtr &x, const ValuePtr &axis_value) {
  const NodePtrList elements = FlattenNode(x);
  const auto axis = GetValue<int64_t>(axis_value);
  return Stack(elements, axis);
}
195
Stack(const NodePtrList & x,int64_t axis)196 NodePtr FuncBuilder::Stack(const NodePtrList &x, int64_t axis) {
197 std::vector<int64_t> dyn_size{static_cast<int64_t>(x.size()), -1};
198 expander::DAttr attrs{std::make_pair(kAttrDynInputSizes, MakeValue(dyn_size)),
199 std::make_pair("axis", MakeValue(axis))};
200 return Emit(kStackOpName, x, attrs);
201 }
202
BatchNormGrad(const NodePtrList & inputs,bool is_scale_or_bias_grad)203 NodePtr FuncBuilder::BatchNormGrad(const NodePtrList &inputs, bool is_scale_or_bias_grad) {
204 return pass_forward_->BatchNormGradToBNInferGrad(inputs, is_scale_or_bias_grad);
205 }
206
SparseSoftmaxCrossEntropyWithLogits(const NodePtrList & inputs,const expander::DAttr & attrs,const NodePtr & out,const NodePtr & dout,bool is_graph_mode)207 NodePtr FuncBuilder::SparseSoftmaxCrossEntropyWithLogits(const NodePtrList &inputs, const expander::DAttr &attrs,
208 const NodePtr &out, const NodePtr &dout, bool is_graph_mode) {
209 return pass_forward_->GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(inputs, attrs, out, dout, is_graph_mode);
210 }
211
// Depend is a pass-through in this builder: `value` is returned unchanged and `expr` is not used.
NodePtr FuncBuilder::Depend(const NodePtr &value, const NodePtr &expr) {
  return value;
}
213
// Return a node for element `i` of the sequence held by `input`.
//
// The new node keeps the input type of `input`. Its abstract is the matching element of the input's sequence
// abstract when that abstract exists and has the same length as the value; otherwise the abstract is left null.
//
// Raises when the input value is not a ValueSequence or when `i` is not smaller than the sequence size.
NodePtr FuncBuilder::TupleGetItem(const NodePtr &input, size_t i) {
  const auto value = input->Value();
  if (!value->isa<ValueSequence>()) {
    // The two literals used to be concatenated without a separator ("sequencebut got").
    MS_LOG(EXCEPTION) << "Input value should be sequence, but got " << value->ToString();
  }
  const auto seq = value->cast<ValueSequencePtr>();
  if (seq->size() <= i) {
    MS_LOG(EXCEPTION) << "Input value sequence size should > " << i << " but got " << value->ToString();
  }
  abstract::AbstractBasePtr item_abs = nullptr;
  // A missing element abstract is tolerated below, so a missing input abstract is tolerated the same way.
  const auto input_abs = input->abstract();
  if (input_abs != nullptr) {
    const auto seq_abs = input_abs->cast<abstract::AbstractSequencePtr>();
    if (seq_abs != nullptr && seq_abs->size() == seq->size()) {
      item_abs = seq_abs->elements()[i];
    }
  }
  return NewFuncNode(seq->value()[i], item_abs, input->input_type());
}
231
// Build the placeholder gradient for `node`, expressed with kNone instead of real zero tensors.
//
// Returns a constant node holding a tuple of kNone (one per element) when the node's value is a non-empty
// sequence whose first element is a tensor::Tensor; in every other case returns a constant node holding kNone.
NodePtr FuncBuilder::OutZeros(const NodePtr &node) {
  const auto make_none_node = [this]() { return NewFuncNode(kNone, nullptr, InputType::kConstant); };

  const auto node_value = node->Value();
  if (!node_value->isa<ValueSequence>()) {
    return make_none_node();
  }
  const auto val_seq = node_value->cast<ValueSequencePtr>();
  if (val_seq->size() == kSizeZero) {
    return make_none_node();
  }
  const auto &first = val_seq->value()[kIndexZero];
  if (first->isa<tensor::Tensor>()) {
    ValuePtrList nones(val_seq->size(), kNone);
    return NewFuncNode(std::make_shared<ValueTuple>(nones), nullptr, InputType::kConstant);
  }
  return make_none_node();
}
248
// Produce the OnesLike result for `value`.
//
// The value is wrapped in an op-output node whose abstract is the value's abstract passed through
// SetAbstractValueToAnyValue, OnesLike is executed on it, and the resulting value is returned.
ValuePtr FuncBuilder::Ones(const ValuePtr &value) {
  const auto any_value_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(value->ToAbstract());
  const auto source_node = NewFuncNode(value, any_value_abs, InputType::kOpOutput);
  NodePtrList op_inputs{source_node};
  const auto ones_node = EmitOp(prim::kPrimOnesLike, op_inputs);
  return ones_node->Value();
}
254
// Produce the ZerosLike result for `value`.
//
// The value is wrapped in an op-output node whose abstract is the value's abstract passed through
// SetAbstractValueToAnyValue, ZerosLike is applied to it, and the resulting value is returned.
ValuePtr FuncBuilder::Zeros(const ValuePtr &value) {
  const auto any_value_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(value->ToAbstract());
  const auto source_node = NewFuncNode(value, any_value_abs, InputType::kOpOutput);
  const auto zeros_node = ZerosLike(source_node);
  return zeros_node->Value();
}
260
// Add two values by emitting the Add op and return the resulting value.
//
// Both operands are wrapped in op-output nodes whose abstracts come from the operand's own abstract passed
// through SetAbstractValueToAnyValue.
ValuePtr FuncBuilder::Add(const ValuePtr &input, const ValuePtr &other) {
  const auto lhs_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(input->ToAbstract());
  const auto rhs_abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(other->ToAbstract());
  const auto lhs_node = NewFuncNode(input, lhs_abs, InputType::kOpOutput);
  const auto rhs_node = NewFuncNode(other, rhs_abs, InputType::kOpOutput);
  const auto sum_node = Emit(mindspore::kAddOpName, {lhs_node, rhs_node});
  return sum_node->Value();
}
268
// Return element `index` of the sequence held by `input`, where the index is itself given as a node.
//
// The index node's value is read as int64 and converted to size_t before delegating to the size_t overload.
NodePtr FuncBuilder::TupleGetItem(const NodePtr &input, const NodePtr &index) {
  const auto index_value = index->Value();
  const auto position = static_cast<size_t>(GetValue<int64_t>(index_value));
  return TupleGetItem(input, position);
}
274
MakeTuple(const NodePtrList & inputs)275 NodePtr FuncBuilder::MakeTuple(const NodePtrList &inputs) {
276 ValuePtrList values;
277 AbstractBasePtrList abs;
278 std::transform(inputs.begin(), inputs.end(), std::back_inserter(values),
279 [](const NodePtr &node) { return node->Value(); });
280 auto value = std::make_shared<ValueTuple>(values);
281 auto tuple_node = NewFuncNode(value, nullptr, InputType::kOpOutput);
282 return tuple_node;
283 }
284
MakeList(const NodePtrList & inputs)285 NodePtr FuncBuilder::MakeList(const NodePtrList &inputs) { return MakeTuple(inputs); }
286
// Evaluate `cond` immediately and run exactly one of the two branch builders with this emitter.
//
// Returns the branch's only node when it produced exactly one, otherwise a tuple node built from all of the
// branch's nodes. Raises if `cond` is neither a bool nor a bool tensor (see ParseCond).
NodePtr FuncBuilder::Conditional(const NodePtr &cond, const expander::Emitter::BlockFunc &true_case,
                                 const expander::Emitter::BlockFunc &false_case) {
  const auto &chosen_branch = ParseCond(cond) ? true_case : false_case;
  const NodePtrList branch_outputs = chosen_branch(this);
  if (branch_outputs.size() != kSizeOne) {
    return MakeTuple(branch_outputs);
  }
  return branch_outputs[kIndex0];
}
300
// Compare the values of two scalar nodes for equality and return the result as an op-output node.
//
// Two BoolImm operands are compared directly; any other combination is delegated to prim::ScalarEq.
// `dst_type` is not used by this implementation. The result node carries no abstract.
NodePtr FuncBuilder::ScalarEq(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) {
  const auto lhs_val = lhs->Value();
  const auto rhs_val = rhs->Value();
  const bool both_bool = lhs_val->isa<BoolImm>() && rhs_val->isa<BoolImm>();
  ValuePtr equal_value;
  if (!both_bool) {
    equal_value = prim::ScalarEq({lhs_val, rhs_val});
  } else {
    const bool same = GetValue<bool>(lhs_val) == GetValue<bool>(rhs_val);
    equal_value = MakeValue(same);
  }
  MS_LOG(DEBUG) << "ScalarEq op: lhs " << lhs_val->ToString() << ", rhs " << rhs_val->ToString();
  return NewFuncNode(equal_value, nullptr, InputType::kOpOutput);
}
313
SetInputs(std::string instance_name,const std::vector<NodePtr> * inputs,mindspore::HashMap<std::string,ValuePtr> * attrs_ptr)314 void FuncBuilder::SetInputs(std::string instance_name, const std::vector<NodePtr> *inputs,
315 mindspore::HashMap<std::string, ValuePtr> *attrs_ptr) {
316 instance_name_ = std::move(instance_name);
317 inputs_ptr_ = inputs;
318 attrs_ptr_ = attrs_ptr;
319 }
320
FlattenNode(const NodePtr & input)321 NodePtrList FuncBuilder::FlattenNode(const NodePtr &input) {
322 if (!input->Value()->isa<ValueSequence>()) {
323 return {input};
324 }
325 auto value_seq = input->Value()->cast<ValueSequencePtr>()->value();
326 auto value_abs = input->abstract()->cast<abstract::AbstractSequencePtr>();
327 MS_EXCEPTION_IF_NULL(value_abs);
328 NodePtrList flattenNodes;
329 flattenNodes.reserve(value_seq.size());
330 for (size_t i = 0; i < value_seq.size(); ++i) {
331 auto &value = value_seq[i];
332 (void)flattenNodes.emplace_back(NewFuncNode(value, value_abs->elements()[i], input->input_type()));
333 }
334 return flattenNodes;
335 }
336
// Replace None placeholders in `value` with zero tensors described by `abs`.
//
// - None with a tensor abstract: a tensor is created from the abstract's element dtype and shape, ZerosLike is
//   applied to it, and the resulting value is returned.
// - None with any other abstract: the None is returned unchanged (a DEBUG line is logged).
// - ValueSequence: every element is processed recursively against the matching element abstract and the results
//   are returned as a ValueTuple; `abs` must be a sequence abstract, otherwise an exception is raised.
// - Anything else is returned unchanged.
ValuePtr FuncBuilder::FillZeros(const ValuePtr &value, const abstract::AbstractBasePtr &abs) {
  if (value->isa<None>()) {
    if (!abs->isa<abstract::AbstractTensor>()) {
      MS_LOG(DEBUG) << "None value abstract got None abstract!";
      return value;
    }
    const auto tensor_type = abs->BuildType()->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_type);
    const auto element_type = tensor_type->element();
    const auto dims = BuildShape(abs);
    const auto like_tensor = std::make_shared<tensor::Tensor>(element_type->type_id(), dims);
    const auto zeros_node = ZerosLike(NewFuncNode(like_tensor, abs, InputType::kOpOutput));
    return zeros_node->Value();
  }

  if (!value->isa<ValueSequence>()) {
    return value;
  }

  const auto value_seq = value->cast<ValueSequencePtr>();
  const auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(abs_seq);
  const size_t element_num = value_seq->value().size();
  std::vector<ValuePtr> filled_values;
  filled_values.reserve(element_num);
  for (size_t pos = 0; pos < element_num; ++pos) {
    const auto &element = value_seq->value()[pos];
    const auto &element_abs = abs_seq->elements()[pos];
    (void)filled_values.emplace_back(FillZeros(element, element_abs));
  }
  return std::make_shared<ValueTuple>(filled_values);
}
367 } // namespace mindspore::pynative::autograd
368