1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2024 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "frontend/operator/composite/composite.h"
20 #include <algorithm>
21 #include <tuple>
22 #include <regex>
23 #include "ops/structure_ops.h"
24 #include "ops/sequence_ops.h"
25 #include "ops/framework_ops.h"
26 #include "ir/anf.h"
27 #include "ir/func_graph.h"
28 #include "abstract/abstract_value.h"
29 #include "abstract/abstract_function.h"
30 #include "abstract/dshape.h"
31 #include "abstract/param_validator.h"
32 #include "frontend/operator/cc_implementations.h"
33 #include "frontend/optimizer/opt.h"
34 #include "utils/symbolic.h"
35 #include "include/common/fallback.h"
36 #include "include/common/pybind_api/api_register.h"
37 #include "ir/signature.h"
38 #include "pipeline/jit/ps/fallback.h"
39 #include "pipeline/jit/ps/debug/trace.h"
40 #include "utils/interpret_node_recorder.h"
41 #include "utils/ms_context.h"
42 #include "include/common/utils/utils.h"
43 #include "pipeline/jit/ps/parse/resolve.h"
44
45 namespace mindspore {
46 // namespace to support composite operators definition
47 namespace prim {
48 constexpr auto kStepDefault = 1;
49
50 using mindspore::abstract::AbstractBase;
51 using mindspore::abstract::AbstractBasePtr;
52 using mindspore::abstract::AbstractClass;
53 using mindspore::abstract::AbstractDictionary;
54 using mindspore::abstract::AbstractDictionaryPtr;
55 using mindspore::abstract::AbstractElementPair;
56 using mindspore::abstract::AbstractEllipsis;
57 using mindspore::abstract::AbstractEllipsisPtr;
58 using mindspore::abstract::AbstractFunction;
59 using mindspore::abstract::AbstractFunctionPtr;
60 using mindspore::abstract::AbstractList;
61 using mindspore::abstract::AbstractListPtr;
62 using mindspore::abstract::AbstractNone;
63 using mindspore::abstract::AbstractScalar;
64 using mindspore::abstract::AbstractSequence;
65 using mindspore::abstract::AbstractSequencePtr;
66 using mindspore::abstract::AbstractSlice;
67 using mindspore::abstract::AbstractTensor;
68 using mindspore::abstract::AbstractTuple;
69 using mindspore::abstract::AbstractTuplePtr;
70 using mindspore::abstract::AbstractUndetermined;
71 using mindspore::abstract::EnvSetSparseResultMgr;
72 using mindspore::abstract::FuncGraphAbstractClosure;
73 using mindspore::abstract::PartialAbstractClosure;
74
Init()75 void HyperMap::Init() {
76 if (fn_leaf_) {
77 name_ = "hyper_map[" + fn_leaf_->name() + "]";
78 }
79 signatures_ =
80 // def hypermap(func:read, *args:ref):
81 std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
82 {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
83 }
84
HyperMap(bool reverse,const std::shared_ptr<MultitypeFuncGraph> & fn_leaf)85 HyperMap::HyperMap(bool reverse, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
86 : MetaFuncGraph("hyper_map"),
87 fn_leaf_(fn_leaf),
88 reverse_(reverse),
89 nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeDictionary}) {
90 Init();
91 }
92
// Copy constructor: clone the leaf function, traversal direction and the
// nonleaf type set, then re-derive name and signatures via Init().
HyperMap::HyperMap(const HyperMap &h)
    : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), reverse_(h.reverse_), nonleaf_(h.nonleaf_) {
  Init();
}
97
SetObjectForFnLeaf(const py::object & leaf_object)98 void HyperMap::SetObjectForFnLeaf(const py::object &leaf_object) {
99 if (fn_leaf_ != nullptr) {
100 fn_leaf_->set_meta_obj(leaf_object);
101 }
102 }
103
// Leaf case: none of the arguments is a container any more, so emit a single
// call `fn(arg1, ..., argn)` where fn is either the supplied function node or
// the registered leaf MultitypeFuncGraph.
AnfNodePtr HyperMap::FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                              const ArgsPairList &arg_map) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> call_inputs;
  if (fn_arg == nullptr) {
    call_inputs.push_back(NewValueNode(fn_leaf_));
  } else {
    call_inputs.push_back(fn_arg);
  }
  for (const auto &item : arg_map) {
    call_inputs.push_back(item.first);
  }
  return func_graph->NewCNodeInOrder(call_inputs);
}
118
GetHyperMapInputIndex(size_t num) const119 std::pair<std::string, std::string> HyperMap::GetHyperMapInputIndex(size_t num) const {
120 std::string error_index;
121 std::string next_index;
122 const size_t first_index = 1;
123 const size_t second_index = 2;
124 if (num == first_index) {
125 // The first element in HyperMap is func_graph
126 error_index = "first";
127 next_index = "second";
128 } else if (num == second_index) {
129 error_index = "second";
130 next_index = "third";
131 } else {
132 error_index = std::to_string(num) + "th";
133 next_index = std::to_string(num + 1) + "th";
134 }
135 return std::pair<std::string, std::string>(error_index, next_index);
136 }
137
// List case: apply the mapped function element-wise over the input lists and
// pack the results back with make_list. All argument lists must have the same
// length as the reference `type`. When reverse_ is set the elements are
// evaluated last-to-first while the output keeps the original order.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  size_t size = type->elements().size();
  size_t num = 0;
  std::ostringstream oss;
  bool is_not_same = false;
  for (auto &item : arg_map) {
    num++;
    // Use dyn_cast (not static_pointer_cast) so that a non-List argument
    // yields nullptr and triggers the type error below instead of being
    // silently treated as a List.
    auto lhs = dyn_cast<List>(item.second);
    auto [error_index, next_index] = GetHyperMapInputIndex(num);
    if (lhs == nullptr) {
      MS_LOG(EXCEPTION) << "The " << error_index << " element in HyperMap has wrong type, expected a List, but got "
                        << item.second->ToString() << ".";
    }
    if (lhs->elements().size() != size) {
      oss << "\nThe length of the " << error_index << " element in HyperMap is " << size << ", but the length of the "
          << next_index << " element in HyperMap is " << lhs->elements().size() << ".\n";
      is_not_same = true;
      break;
    }
  }
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "The lists in HyperMap should have the same length. " << oss.str();
  }

  // Cannot use shared_from_base() also known as this, as it will make a reference cycle on
  // hypermap and graph generated, it will cause memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  constexpr size_t kPrimHoldLen = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + kPrimHoldLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeList));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeList for the " << i << "th element of the target, reverse_: " << reverse_;
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }
    size_t pos = (reverse_ ? (size - 1 - i) : i);
    // Take each pair by const reference with its real type (TypePtr) to avoid
    // a per-element copy/conversion.
    (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
                         [&func_graph, pos](const std::pair<AnfNodePtr, TypePtr> &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });

    auto call_node = func_graph->NewCNodeInOrder(inputs2);
    if (reverse_) {
      // Prepend after the primitive so the output order stays forward.
      (void)inputs.insert(inputs.cbegin() + 1, call_node);
    } else {
      inputs.emplace_back(call_node);
    }
  }
  return func_graph->NewCNodeInOrder(inputs);
}
197
// Tuple case: apply the mapped function element-wise over the input tuples
// and pack the results back with make_tuple. All argument tuples must have
// the same length as the reference `type`; an empty tuple input produces an
// empty ValueTuple constant.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  size_t size = type->elements().size();
  size_t num = 0;
  std::ostringstream oss;
  bool is_not_same = false;
  for (auto &item : arg_map) {
    num++;
    // Use dyn_cast (not static_pointer_cast) so that a non-Tuple argument
    // yields nullptr and triggers the type error below instead of being
    // silently treated as a Tuple.
    auto lhs = dyn_cast<Tuple>(item.second);
    auto [error_index, next_index] = GetHyperMapInputIndex(num);
    if (lhs == nullptr) {
      MS_LOG(EXCEPTION) << "The " << error_index << " element in HyperMap has wrong type, expected a Tuple, but got "
                        << item.second->ToString() << ".";
    }
    if (lhs->elements().size() != size) {
      oss << "\nThe length of the " << error_index << " element in HyperMap is " << size << ", but the length of the "
          << next_index << " element in HyperMap is " << lhs->elements().size() << ".\n";
      is_not_same = true;
      break;
    }
  }
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "The length of tuples in HyperMap must be the same. " << oss.str();
  }

  // Cannot use shared_from_base() also known as this, as it will make a reference cycle on
  // hypermap and graph generated, it will cause memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  constexpr size_t kPrimHoldLen = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + kPrimHoldLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeTuple));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th element of the target, reverse_: " << reverse_;
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }
    size_t pos = (reverse_ ? (size - 1 - i) : i);
    // Take each pair by const reference with its real type (TypePtr) instead
    // of by-value std::pair<AnfNodePtr, Any>, avoiding a copy per element.
    (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
                         [&func_graph, pos](const std::pair<AnfNodePtr, TypePtr> &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });

    auto call_node = func_graph->NewCNodeInOrder(inputs2);
    if (reverse_) {
      // Prepend after the primitive so the output order stays forward.
      (void)inputs.insert(inputs.cbegin() + 1, call_node);
    } else {
      inputs.emplace_back(call_node);
    }
  }

  if (inputs.size() > 1) {
    return func_graph->NewCNodeInOrder(inputs);
  }
  // Empty tuple.
  auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
  auto empty_tuple = NewValueNode(empty_tuple_value);
  return empty_tuple;
}
264
// Dictionary case: apply the mapped function to each key's value across all
// argument dicts and rebuild the result as make_dict(keys, mapped_values).
// All argument dicts must have the same number of entries as `type`.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Dictionary> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  size_t size = type->key_values().size();
  size_t num = 0;
  std::ostringstream oss;
  bool is_not_same = false;
  for (auto &item : arg_map) {
    num++;
    // Use dyn_cast (not static_pointer_cast) so that a non-Dictionary
    // argument yields nullptr and triggers the type error below instead of
    // being silently treated as a Dictionary.
    auto lhs = dyn_cast<Dictionary>(item.second);
    auto [error_index, next_index] = GetHyperMapInputIndex(num);
    if (lhs == nullptr) {
      MS_LOG(EXCEPTION) << "The " << error_index
                        << " element in HyperMap has wrong type, expected a Dictionary, but got "
                        << item.second->ToString() << ".";
    }
    if (lhs->key_values().size() != size) {
      oss << "\nThe length of the " << error_index << " element in HyperMap is " << size << ", but the length of the "
          << next_index << " element in HyperMap is " << lhs->key_values().size() << ".\n";
      is_not_same = true;
      break;
    }
  }
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "The length of dict in HyperMap must be the same. " << oss.str();
  }

  // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  // hypermap and graph generated, it will cause memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  std::vector<AnfNodePtr> key_inputs{NewValueNode(prim::kPrimMakeTuple)};
  std::vector<AnfNodePtr> value_inputs{NewValueNode(prim::kPrimMakeTuple)};

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeDict for the " << i << "th element of the target.";
    auto key = type->key_values()[i].first;
    (void)key_inputs.emplace_back(NewValueNode(key));
    // Build the recursive HyperMap call over dict_getitem(arg, key) for every argument.
    std::vector<AnfNodePtr> inputs;
    (void)inputs.emplace_back(fn_rec);
    if (fn_arg != nullptr) {
      (void)inputs.emplace_back(fn_arg);
    }
    (void)std::transform(
      arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
      [&func_graph, &key](const std::pair<AnfNodePtr, TypePtr> &item) {
        return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimDictGetItem), item.first, NewValueNode(key)});
      });
    auto call_node = func_graph->NewCNodeInOrder(inputs);
    (void)value_inputs.emplace_back(call_node);
  }
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeDict), func_graph->NewCNodeInOrder(key_inputs),
                                 func_graph->NewCNodeInOrder(value_inputs)};
  return func_graph->NewCNodeInOrder(inputs);
}
321
// Dispatch on the argument types: when any argument is not a List/Tuple/Dict
// (a "leaf"), apply the leaf function directly; otherwise all arguments must
// share the same container category and the matching FullMake overload
// recurses into the structure.
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const {
  bool reached_leaf = false;
  TypeId type_id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> chosen;
  for (const auto &item : arg_map) {
    chosen = item;
    type_id = item.second->type_id();
    // The graph building reaches the leaf situation when there exists type that can not be divided any more.
    if (nonleaf_.count(type_id) == 0) {
      reached_leaf = true;
      break;
    }
  }

  if (!reached_leaf) {
    // In a nonleaf situation, all arguments must have the same generic.
    auto type_differs = [&chosen](const std::pair<AnfNodePtr, TypePtr> &item) {
      if (item.first == chosen.first) {
        return false;
      }
      return item.second->type_id() != chosen.second->type_id();
    };
    if (std::any_of(arg_map.begin(), arg_map.end(), type_differs)) {
      std::ostringstream oss;
      oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfoStr(func_graph->debug_info()) << "\n";
      int64_t idx = 0;
      const int64_t diff_index = 2;
      for (const auto &item : arg_map) {
        // The first element in HyperMap is func_graph, so ordinals shift by one.
        std::string str_index;
        if (idx == 0) {
          str_index = "second";
        } else if (idx == 1) {
          str_index = "third";
        } else {
          str_index = std::to_string(idx + diff_index) + "th";
        }
        ++idx;
        oss << "The type of the " << str_index << " argument in HyperMap is " << item.second->ToString() << ".\n";
      }
      MS_LOG(EXCEPTION) << "In a nonleaf situation, the types of arguments in HyperMap must be consistent, "
                        << "but the types of arguments are inconsistent.\n"
                        << oss.str();
    }
  }

  switch (type_id) {
    case kObjectTypeList:
      return FullMake(std::static_pointer_cast<List>(chosen.second), func_graph, fn_arg, arg_map);
    case kObjectTypeTuple:
      return FullMake(std::static_pointer_cast<Tuple>(chosen.second), func_graph, fn_arg, arg_map);
    case kObjectTypeDictionary:
      return FullMake(std::static_pointer_cast<Dictionary>(chosen.second), func_graph, fn_arg, arg_map);
    default:
      return FullMake(func_graph, fn_arg, arg_map);
  }
}
386
GenerateFromTypes(const TypePtrList & args_abs_list)387 FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_abs_list) {
388 FuncGraphPtr res_fg = std::make_shared<FuncGraph>();
389 res_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
390 res_fg->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
391 res_fg->debug_info()->set_name("hyper_map");
392
393 AnfNodePtr fn_param = nullptr;
394 std::size_t i = 0;
395 ArgsPairList argmap;
396 if (fn_leaf_ == nullptr) {
397 fn_param = res_fg->add_parameter();
398 i = 1;
399 }
400
401 std::size_t size = args_abs_list.size();
402 for (; i < size; ++i) {
403 argmap.push_back(std::make_pair(res_fg->add_parameter(), args_abs_list[i]));
404 }
405
406 res_fg->set_output(Make(res_fg, fn_param, argmap));
407 return res_fg;
408 }
409
NormalizeArgs(const AbstractBasePtrList & args_abs_list) const410 abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_abs_list) const {
411 if (fn_leaf_ == nullptr) {
412 if (args_abs_list.empty()) {
413 MS_LOG(EXCEPTION) << "The size of arguments in list should not be empty. But the size of arguments is 0.";
414 }
415 MS_EXCEPTION_IF_NULL(args_abs_list[0]);
416 // Assert that hypermap's function param does not contain free variables
417 if (args_abs_list[0]->isa<FuncGraphAbstractClosure>()) {
418 auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_abs_list[0]);
419 auto func_graph = graph_func->func_graph();
420 if (func_graph->parent() != nullptr) {
421 MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet.";
422 }
423 }
424 }
425
426 AbstractBasePtrList broadened;
427 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(broadened),
428 [](const AbstractBasePtr &arg) -> AbstractBasePtr {
429 MS_EXCEPTION_IF_NULL(arg);
430 return arg->Broaden();
431 });
432 return broadened;
433 }
434
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)435 FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
436 int64_t tuple_size = SizeToLong(args_abs_list.size());
437
438 std::ostringstream ss;
439 // ▶make_tuple_
440 ss << "\u25B8make_tuple_" << tuple_size;
441 FuncGraphPtr fg = std::make_shared<FuncGraph>();
442 fg->debug_info()->set_name(ss.str());
443
444 std::vector<AnfNodePtr> params;
445 params.push_back(NewValueNode(prim::kPrimMakeTuple));
446 for (int64_t i = 0; i < tuple_size; ++i) {
447 params.push_back(fg->add_parameter());
448 }
449
450 // Make fprop first result, make_tuple's forward result.
451 AnfNodePtr out = fg->NewCNodeInOrder(params);
452
453 // Make fprop second result, make_tuple's backward function.
454 FuncGraphPtr bprop = std::make_shared<FuncGraph>();
455
456 ss.str(std::string());
457 ss.clear();
458 // ◀make_tuple_
459 ss << "\u25C2make_tuple_" << tuple_size;
460 bprop->debug_info()->set_name(ss.str());
461 AnfNodePtr dout = bprop->add_parameter();
462
463 std::vector<AnfNodePtr> grads;
464 grads.push_back(NewValueNode(prim::kPrimMakeTuple));
465 grads.push_back(NewEnviron(bprop));
466 for (int64_t i = 0; i < tuple_size; ++i) {
467 grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
468 }
469
470 bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
471 bprop->set_output(bprop->NewCNodeInOrder(grads));
472
473 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
474 fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
475 (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
476 return fg;
477 }
478
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)479 FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
480 int64_t list_size = SizeToLong(args_abs_list.size());
481
482 std::ostringstream ss;
483 // ▶make_list_
484 ss << "\u25B8make_list_" << list_size;
485 FuncGraphPtr fg = std::make_shared<FuncGraph>();
486 fg->debug_info()->set_name(ss.str());
487
488 std::vector<AnfNodePtr> params;
489 params.push_back(NewValueNode(prim::kPrimMakeList));
490 for (int64_t i = 0; i < list_size; ++i) {
491 params.push_back(fg->add_parameter());
492 }
493
494 // Make fprop first result, make_list's forward result.
495 AnfNodePtr out = fg->NewCNodeInOrder(params);
496
497 // Make fprop second result, make_list's backward function.
498 FuncGraphPtr bprop = std::make_shared<FuncGraph>();
499
500 ss.str(std::string());
501 ss.clear();
502 // ◀make_list_
503 ss << "\u25C2make_list_" << list_size;
504 bprop->debug_info()->set_name(ss.str());
505 AnfNodePtr dout = bprop->add_parameter();
506
507 std::vector<AnfNodePtr> grads;
508 grads.push_back(NewValueNode(prim::kPrimMakeTuple));
509 grads.push_back(NewEnviron(bprop));
510 for (int64_t i = 0; i < list_size; ++i) {
511 grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
512 }
513
514 bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
515 bprop->set_output(bprop->NewCNodeInOrder(grads));
516
517 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
518 fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
519 (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
520 return fg;
521 }
522
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)523 FuncGraphPtr MakeDictGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
524 constexpr size_t input_size = 2;
525 CheckArgsSize("MakeDict", args_abs_list, input_size);
526 std::ostringstream ss;
527 // ▶make_dict_
528 ss << "\u25B8make_dict_" << input_size;
529 FuncGraphPtr fg = std::make_shared<FuncGraph>();
530 fg->debug_info()->set_name(ss.str());
531
532 std::vector<AnfNodePtr> params{NewValueNode(prim::kPrimMakeDict)};
533 for (size_t i = 0; i < input_size; ++i) {
534 (void)params.emplace_back(fg->add_parameter());
535 }
536
537 // Make fprop first result, make_dict's forward result.
538 AnfNodePtr out = fg->NewCNodeInOrder(params);
539
540 // Make fprop second result, make_dict's backward function.
541 FuncGraphPtr bprop = std::make_shared<FuncGraph>();
542
543 ss.str(std::string());
544 ss.clear();
545 // ◀make_dict_
546 ss << "\u25C2make_dict_" << input_size;
547 bprop->debug_info()->set_name(ss.str());
548 AnfNodePtr dout = bprop->add_parameter();
549
550 std::vector<AnfNodePtr> grads{NewValueNode(prim::kPrimMakeTuple)};
551 (void)grads.emplace_back(NewEnviron(bprop));
552
553 auto abs0_tuple = dyn_cast_ptr<AbstractTuple>(args_abs_list[0]);
554 if (abs0_tuple == nullptr) {
555 MS_LOG(INTERNAL_EXCEPTION) << "The first input of make_dict should be a tuple, but got abstract: "
556 << args_abs_list[0]->ToString();
557 }
558 // Add gradients of keys tuple and values tuple.
559 std::vector<AnfNodePtr> keys_grads_inputs{NewValueNode(kPrimMakeTuple)};
560 std::vector<AnfNodePtr> values_grads_inputs{NewValueNode(kPrimMakeTuple)};
561 for (size_t i = 0; i < abs0_tuple->size(); ++i) {
562 auto key_item =
563 bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[1], NewValueNode(SizeToLong(i))});
564 (void)keys_grads_inputs.emplace_back(key_item);
565 (void)values_grads_inputs.emplace_back(
566 bprop->NewCNodeInOrder({NewValueNode(prim::kPrimDictGetItem), dout, key_item}));
567 }
568 (void)grads.emplace_back(bprop->NewCNodeInOrder(keys_grads_inputs));
569 (void)grads.emplace_back(bprop->NewCNodeInOrder(values_grads_inputs));
570
571 bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
572 bprop->set_output(bprop->NewCNodeInOrder(grads));
573
574 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
575 fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
576 (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeDict));
577 return fg;
578 }
579
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)580 FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
581 int64_t args_size = SizeToLong(args_abs_list.size());
582 constexpr auto py_execute_grad_input_count = 3;
583 if (args_size < py_execute_grad_input_count) {
584 MS_LOG(INTERNAL_EXCEPTION) << "The inputs size of PyExecuteGradient should not less than "
585 << py_execute_grad_input_count;
586 }
587
588 std::ostringstream ss;
589 // ▶PyExecute
590 ss << "\u25B8PyExecute_" << args_size;
591 FuncGraphPtr fg = std::make_shared<FuncGraph>();
592 fg->debug_info()->set_name(ss.str());
593
594 std::vector<AnfNodePtr> params;
595 (void)params.emplace_back(NewValueNode(prim::kPrimPyExecute));
596 for (int64_t i = 0; i < args_size; ++i) {
597 (void)params.emplace_back(fg->add_parameter());
598 }
599
600 // Make fprop first result, PyExecute's forward result.
601 AnfNodePtr out = fg->NewCNodeInOrder(params);
602 InterpretNodeRecorder::GetInstance().PushPyExecuteNode(out);
603
604 // Make fprop second result, PyExecute's backward function.
605 FuncGraphPtr bprop = std::make_shared<FuncGraph>();
606
607 ss.str(std::string());
608 ss.clear();
609 // ◀PyExecute
610 ss << "\u25C2PyExecute_" << args_size;
611 bprop->debug_info()->set_name(ss.str());
612 (void)bprop->add_parameter();
613
614 std::vector<AnfNodePtr> grads;
615 (void)grads.emplace_back(NewValueNode(prim::kPrimMakeTuple));
616 (void)grads.emplace_back(NewEnviron(bprop));
617 // Propagate for script string.
618 (void)grads.emplace_back(params[1]);
619 // Propagate for local dict keys.
620 const auto &local_key_args = dyn_cast<abstract::AbstractTuple>(args_abs_list[1]);
621 MS_EXCEPTION_IF_NULL(local_key_args);
622 std::vector<AnfNodePtr> keys;
623 (void)keys.emplace_back(NewValueNode(prim::kPrimMakeTuple));
624 for (size_t i = 0; i < local_key_args->size(); ++i) {
625 constexpr auto keys_num = 2;
626 const auto &key_item =
627 bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[keys_num], NewValueNode(SizeToLong(i))});
628 const auto &element = local_key_args->elements()[i];
629 const auto &str_element = dyn_cast<abstract::AbstractScalar>(element);
630 if (str_element != nullptr && str_element->BuildType()->isa<String>()) {
631 (void)keys.emplace_back(key_item);
632 } else {
633 (void)keys.emplace_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), key_item}));
634 }
635 }
636 (void)grads.emplace_back(bprop->NewCNodeInOrder(keys));
637 // Propagate for local dict values.
638 constexpr auto values_arg_num = 2;
639 const auto &local_value_args = dyn_cast<abstract::AbstractTuple>(args_abs_list[values_arg_num]);
640 MS_EXCEPTION_IF_NULL(local_value_args);
641 std::vector<AnfNodePtr> values;
642 (void)values.emplace_back(NewValueNode(prim::kPrimMakeTuple));
643 for (size_t i = 0; i < local_value_args->size(); ++i) {
644 constexpr auto values_num = 3;
645 const auto &value_item =
646 bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[values_num], NewValueNode(SizeToLong(i))});
647 const auto &element = local_value_args->elements()[i];
648 const auto &str_element = dyn_cast<abstract::AbstractScalar>(element);
649 if (str_element != nullptr && str_element->BuildType()->isa<String>()) {
650 (void)values.emplace_back(value_item);
651 } else {
652 (void)values.emplace_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), value_item}));
653 }
654 }
655 (void)grads.emplace_back(bprop->NewCNodeInOrder(values));
656
657 // Add gradients for extra monad.
658 for (size_t i = py_execute_grad_input_count; i < args_abs_list.size(); ++i) {
659 if (args_abs_list[i]->isa<abstract::AbstractUMonad>()) {
660 (void)grads.emplace_back(NewValueNode(kUMonad));
661 } else if (args_abs_list[i]->isa<abstract::AbstractIOMonad>()) {
662 (void)grads.emplace_back(NewValueNode(kIOMonad));
663 } else {
664 (void)grads.emplace_back(NewValueNode(kValueAny));
665 }
666 }
667
668 bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
669 bprop->set_output(bprop->NewCNodeInOrder(grads));
670
671 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
672 fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
673 (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimPyExecute));
674 return fg;
675 }
676
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)677 FuncGraphPtr MutableGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
678 constexpr size_t min_input_size = 1;
679 constexpr size_t max_input_size = 2;
680 auto input_size = args_abs_list.size();
681 if (input_size != min_input_size && input_size != max_input_size) {
682 MS_LOG(EXCEPTION) << "The number of input to mutable must be " << min_input_size << " or " << max_input_size
683 << ", but got: " << input_size;
684 }
685 std::ostringstream ss;
686 // ▶mutable_
687 ss << "\u25B8mutable_" << input_size;
688 FuncGraphPtr fg = std::make_shared<FuncGraph>();
689 fg->debug_info()->set_name(ss.str());
690
691 std::vector<AnfNodePtr> params;
692 params.push_back(NewValueNode(prim::kPrimMutable));
693 for (size_t i = 0; i < input_size; ++i) {
694 params.push_back(fg->add_parameter());
695 }
696
697 // Make fprop first result, mutable's forward result.
698 AnfNodePtr out = fg->NewCNodeInOrder(params);
699
700 // Make fprop second result, mutable's backward function.
701 FuncGraphPtr bprop = std::make_shared<FuncGraph>();
702
703 ss.str(std::string());
704 ss.clear();
705 // ◀mutable_
706 ss << "\u25C2mutable_" << input_size;
707 bprop->debug_info()->set_name(ss.str());
708 AnfNodePtr dout = bprop->add_parameter();
709
710 std::vector<AnfNodePtr> grads;
711 grads.push_back(NewValueNode(prim::kPrimMakeTuple));
712 grads.push_back(NewEnviron(bprop));
713 grads.push_back(dout);
714 if (input_size == max_input_size) {
715 grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), params[2]}));
716 }
717
718 bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
719 bprop->set_output(bprop->NewCNodeInOrder(grads));
720
721 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
722 fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
723 (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMutable));
724 return fg;
725 }
726
727 namespace {
IsTupleAllTensor(const AbstractTuplePtr & tuple_arg)728 bool IsTupleAllTensor(const AbstractTuplePtr &tuple_arg) {
729 MS_EXCEPTION_IF_NULL(tuple_arg);
730 for (size_t i = 0; i < tuple_arg->size(); ++i) {
731 if (!(*tuple_arg)[i]->isa<AbstractUndetermined>() &&
732 !((*tuple_arg)[i]->isa<AbstractTuple>() && IsTupleAllTensor((*tuple_arg)[i]->cast<AbstractTuplePtr>()))) {
733 return false;
734 }
735 }
736 return true;
737 }
738
EnableGradFirstForTuple(const AbstractTuplePtr & tuple_arg,bool enable_tuple_grad)739 bool EnableGradFirstForTuple(const AbstractTuplePtr &tuple_arg, bool enable_tuple_grad) {
740 return tuple_arg->size() > 1 && (*tuple_arg)[1]->isa<AbstractTuple>() && enable_tuple_grad &&
741 IsTupleAllTensor((*tuple_arg)[1]->cast<AbstractTuplePtr>());
742 }
743
EnableGradForScalar(const AbstractBasePtr & abs)744 bool EnableGradForScalar(const AbstractBasePtr &abs) {
745 return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
746 abs->BuildType()->isa<Number>();
747 }
748
CanGradArgument(const AbstractTuplePtr & tuple_arg,size_t pos)749 bool CanGradArgument(const AbstractTuplePtr &tuple_arg, size_t pos) {
750 MS_EXCEPTION_IF_NULL(tuple_arg);
751 return tuple_arg->size() > pos && (*tuple_arg)[pos] != nullptr &&
752 ((*tuple_arg)[pos]->BuildValue()->ContainsValueAny() || EnableGradForScalar((*tuple_arg)[pos]));
753 }
754
GenerateFuncGraphByPosition(const FuncGraphPtr & fg,const AbstractTuplePtr & tuple_arg,const AbstractTuplePtr & pos,bool return_ids=false)755 void GenerateFuncGraphByPosition(const FuncGraphPtr &fg, const AbstractTuplePtr &tuple_arg, const AbstractTuplePtr &pos,
756 bool return_ids = false) {
757 if (pos == nullptr) {
758 MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!";
759 }
760 if (pos->empty()) {
761 MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position.";
762 }
763 AnfNodePtr tuple_parameter = fg->add_parameter();
764 (void)fg->add_parameter(); // The 'grad_position' parameter.
765 // Collect all parameters by 'grad_position'.
766 std::vector<AnfNodePtr> pos_elements = {NewValueNode(prim::kPrimMakeTuple)};
767 CNodePtr current_element = nullptr;
768 for (size_t i = 0; i < pos->size(); ++i) {
769 auto val = pos->elements()[i]->BuildValue();
770 MS_EXCEPTION_IF_NULL(val);
771 auto int_val = LongToSize(dyn_cast<Int64Imm>(val)->value());
772 ++int_val; // Ignore the env position.
773 if (int_val >= tuple_arg->size()) {
774 MS_EXCEPTION(IndexError) << "Position index " << (int_val - 1) << " is exceed input size.";
775 }
776 if (!CanGradArgument(tuple_arg, int_val)) {
777 continue;
778 }
779 current_element =
780 fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(int_val))});
781 if (return_ids) {
782 current_element =
783 fg->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), NewValueNode(SizeToLong(int_val) - 1), current_element});
784 }
785 pos_elements.push_back(current_element);
786 }
787
788 // The returned result may vary for grad result element number.
789 // A single value if only one result, a tuple for multiple results, or a empty tuple for no result.
790 //
791 // Notice that even if the user set 'grad_position' as multiple choices,
792 // the 'CanGradArgument' may change it to only one choice or none choice.
793 constexpr size_t args_least_size = 2;
794 if (pos_elements.size() == args_least_size) {
795 fg->set_output(current_element);
796 } else if (pos_elements.size() > args_least_size) {
797 fg->set_output(fg->NewCNodeInOrder(pos_elements));
798 } else { // The 'pos' is empty AbstractTuple.
799 auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
800 auto empty_tuple = NewValueNode(empty_tuple_value);
801 fg->set_output(empty_tuple);
802 }
803 }
804 } // namespace
805
GenerateTailFuncGraph(const AbstractSequencePtr & sequence_arg) const806 FuncGraphPtr Tail::GenerateTailFuncGraph(const AbstractSequencePtr &sequence_arg) const {
807 MS_EXCEPTION_IF_NULL(sequence_arg);
808 FuncGraphPtr fg = std::make_shared<FuncGraph>();
809 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
810 fg->debug_info()->set_name("tail");
811
812 AnfNodePtr tuple_parameter = fg->add_parameter();
813 std::vector<AnfNodePtr> elements;
814 PrimitivePtr op = nullptr;
815 if (sequence_arg->isa<AbstractTuple>()) {
816 (void)elements.emplace_back(NewValueNode(prim::kPrimMakeTuple));
817 op = prim::kPrimTupleGetItem;
818 } else {
819 (void)elements.emplace_back(NewValueNode(prim::kPrimMakeList));
820 op = prim::kPrimListGetItem;
821 }
822
823 // Remove the first element to make a new sequence.
824 for (size_t i = 1; i < sequence_arg->size(); ++i) {
825 elements.push_back(fg->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
826 }
827 if (elements.size() > 1) {
828 fg->set_output(fg->NewCNodeInOrder(elements));
829 return fg;
830 }
831
832 // No element left, return empty tuple.
833 if (sequence_arg->isa<AbstractTuple>()) {
834 auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
835 auto empty_tuple = NewValueNode(empty_tuple_value);
836 fg->set_output(empty_tuple);
837 }
838 // No element left, return empty list.
839 auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
840 auto empty_tuple = NewValueNode(empty_tuple_value);
841 fg->set_output(empty_tuple);
842 return fg;
843 }
844
// Build the gradient-selection graph for GradOperation. 'tuple_arg' is the abstract of the
// bprop result tuple (env at index 0, gradients of inputs from index 1 on); 'position' is
// only used for kGradByPosition. Which slice is returned depends on 'tail_type_'.
FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, const AbstractTuplePtr &position) const {
  MS_EXCEPTION_IF_NULL(tuple_arg);
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->debug_info()->set_name("grad_tail");

  // kGradFirst: return only the gradient of the first network input (tuple index 1).
  if (tail_type_ == kGradFirst) {
    AnfNodePtr tuple_parameter = fg->add_parameter();
    if (CanGradArgument(tuple_arg, 1) || EnableGradFirstForTuple(tuple_arg, enable_tuple_grad_first_)) {
      fg->set_output(
        fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(1))}));
    } else {
      // Nothing differentiable: return an empty tuple.
      fg->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
    }
    return fg;
  }

  // kGradByPosition: select the gradients at the indices given in 'position'.
  if (tail_type_ == kGradByPosition) {
    GenerateFuncGraphByPosition(fg, tuple_arg, position, return_ids_);
    return fg;
  }

  // kGradAll: collect the gradient of every differentiable input (skipping env at index 0).
  if (tail_type_ == kGradAll) {
    AnfNodePtr tuple_parameter = fg->add_parameter();
    std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
    for (size_t i = 1; i < tuple_arg->size(); ++i) {
      MS_EXCEPTION_IF_NULL((*tuple_arg)[i]);
      if (CanGradArgument(tuple_arg, i)) {
        elements.push_back(
          fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))}));
      }
    }

    // We should deal with 'get_all=True' as other options later:
    // "The returned result may vary for grad result element number.
    // A single value if only one result, a tuple for multiple results, or a empty tuple for no result.
    //
    // Notice that even if the user set 'get_all=True' and pass multiple inputs,
    // the 'CanGradArgument' may change it to only one gradient output or no gradient."
    constexpr size_t args_least_size = 2;
    if (elements.size() >= args_least_size) {
      fg->set_output(fg->NewCNodeInOrder(elements));
      return fg;
    }
    // Empty tuple.
    auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
    auto empty_tuple = NewValueNode(empty_tuple_value);
    fg->set_output(empty_tuple);
    return fg;
  }
  // Any other tail type (e.g. kNotGrad) must go through GenerateTailFuncGraph instead.
  MS_LOG(INTERNAL_EXCEPTION) << "'tail_type_' is not for GradOperation, but " << tail_type_;
}
897
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)898 FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
899 // To handle normal tail.
900 if (args_abs_list.size() < 1) {
901 MS_LOG(EXCEPTION) << "'Tail' requires at least 1 argument, but got " << args_abs_list.size();
902 }
903 if (tail_type_ >= kNotGrad) {
904 AbstractSequencePtr sequence_arg = dyn_cast<AbstractSequence>(args_abs_list[0]);
905 if (sequence_arg == nullptr) {
906 MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple or list, but got " << args_abs_list[0]->ToString();
907 }
908 return GenerateTailFuncGraph(sequence_arg);
909 }
910
911 // To handle for GradOperation tail.
912 constexpr size_t args_max_size = 2;
913 if (args_abs_list.size() > args_max_size) {
914 MS_LOG(EXCEPTION) << "'Tail' requires at most 2 arguments for GradOperation, but got " << args_abs_list.size();
915 }
916 AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_abs_list[0]);
917 if (tuple_arg == nullptr) {
918 MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple, but got " << args_abs_list[0]->ToString();
919 }
920 if (args_abs_list.size() == args_max_size) {
921 AbstractTuplePtr pos = dyn_cast<AbstractTuple>(args_abs_list[1]);
922 if (pos == nullptr) {
923 MS_LOG(EXCEPTION) << "'Tail' arg1 'position' must be tuple, but got " << args_abs_list[1]->ToString();
924 }
925 return GenerateGradFuncGraph(tuple_arg, pos);
926 }
927 return GenerateGradFuncGraph(tuple_arg);
928 }
929 namespace {
// Assemble the final output node of the grad graph 'k_child' from the gradient node and
// the forward result 'f_app':
//   - get_value: return (forward_value, gradient);
//   - has_aux:   return (gradient, aux) where aux is element 1 of the forward result;
//   - otherwise: return the gradient alone.
AnfNodePtr CreateGradOutputs(const FuncGraphPtr &k_child, const AnfNodePtr &gradient, const AnfNodePtr &f_app,
                             bool has_aux, bool get_value) {
  if (get_value) {
    return k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), f_app, gradient});
  }
  if (!has_aux) {
    return gradient;
  }
  PrimitivePtr get_tuple_item_op = prim::kPrimTupleGetItem;
  PrimitivePtr make_tuple_op = prim::kPrimMakeTuple;
  // Only f_app[1] is collected here; it is first wrapped into a one-element tuple and
  // immediately unwrapped again below — presumably kept this way for graph-normalization
  // reasons, TODO confirm whether the wrap/unwrap pair is required.
  std::vector<AnfNodePtr> elements = {NewValueNode(make_tuple_op)};
  (void)elements.emplace_back(
    k_child->NewCNodeInOrder({NewValueNode(get_tuple_item_op), f_app, NewValueNode(static_cast<int64_t>(1))}));
  auto aux_output = k_child->NewCNodeInOrder(elements);
  auto unpack_node =
    k_child->NewCNodeInOrder({NewValueNode(get_tuple_item_op), aux_output, NewValueNode(static_cast<int64_t>(0))});
  return k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), gradient, unpack_node});
}
948 } // namespace
949
950 // When set aux True, for out1, out2, out3 = fn(inputs), only first out1 contributes to differentiation of fn.
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)951 FuncGraphPtr GradAux::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
952 AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_abs_list[0]);
953 if (tuple_arg == nullptr) {
954 MS_LOG(EXCEPTION) << "When has_aux is True, origin fn requires more than one outputs.\n"
955 << "'GradAux' arg0 must be tuple, but got " << args_abs_list[0]->ToString();
956 }
957 FuncGraphPtr fg = std::make_shared<FuncGraph>();
958 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
959 AnfNodePtr tuple_parameter = fg->add_parameter();
960 // get_value flag
961 (void)fg->add_parameter();
962
963 AbstractScalarPtr get_value_ptr = dyn_cast<AbstractScalar>(args_abs_list[1]);
964 bool get_value_flag = GetValue<bool>(get_value_ptr->BuildValue());
965 std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
966 elements.push_back(
967 fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(0))}));
968 if (get_value_flag) {
969 for (size_t i = 1; i < tuple_arg->size(); i++) {
970 auto aux_node =
971 fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))});
972 auto stop_gradient_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimStopGradient), aux_node});
973 elements.push_back(stop_gradient_node);
974 }
975 } else {
976 std::vector<AnfNodePtr> aux_elements = {NewValueNode(prim::kPrimMakeTuple)};
977 for (size_t i = 1; i < tuple_arg->size(); i++) {
978 auto aux_node =
979 fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))});
980 auto stop_gradient_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimStopGradient), aux_node});
981 aux_elements.push_back(stop_gradient_node);
982 }
983 elements.push_back(fg->NewCNodeInOrder(aux_elements));
984 }
985
986 constexpr size_t args_least_size = 2;
987 if (elements.size() < args_least_size) {
988 MS_LOG(EXCEPTION) << "When has_aux is True, origin fn requires more than one outputs, but got " << elements.size()
989 << " outputs.\n"
990 << trace::GetDebugInfoStr(fg->debug_info());
991 }
992 fg->set_output(fg->NewCNodeInOrder(elements));
993 return fg;
994 }
995
GradOperation(const std::string & name,bool get_all,bool get_by_list,bool sens_param,bool get_by_position,bool has_aux,bool get_value,bool return_ids,bool merge_forward)996 GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param,
997 bool get_by_position, bool has_aux, bool get_value, bool return_ids, bool merge_forward)
998 : MetaFuncGraph(name),
999 get_all_(get_all),
1000 get_by_list_(get_by_list),
1001 sens_param_(sens_param),
1002 get_by_position_(get_by_position),
1003 has_aux_(has_aux),
1004 get_value_(get_value),
1005 return_ids_(return_ids),
1006 merge_forward_(merge_forward) {
1007 if (get_by_position) {
1008 signatures_ =
1009 // def grad(func:read, weight_list:ref, position_list:ref):
1010 std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
1011 {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault},
1012 {"position_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
1013 } else if (get_by_list) {
1014 signatures_ =
1015 // def grad(func:read, weight_list:ref):
1016 std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
1017 {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
1018 }
1019 }
1020
GetGrad(const AnfNodePtr & j,const AnfNodePtr & weights,const AnfNodePtr & position,const std::vector<AnfNodePtr> & forward_graph_params,bool enable_tuple_grad,bool is_weights_none) const1021 FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
1022 const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
1023 bool is_weights_none) const {
1024 FuncGraphPtr k_child = std::make_shared<FuncGraph>();
1025 k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1026 k_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true);
1027
1028 AnfNodePtr position_node = nullptr;
1029 if (position != nullptr) {
1030 position_node = position;
1031 }
1032
1033 std::vector<AnfNodePtr> inputs;
1034 inputs.push_back(j);
1035 for (size_t i = 0; i < forward_graph_params.size(); ++i) {
1036 inputs.push_back(k_child->add_parameter());
1037 }
1038 auto k_app = k_child->NewCNodeInOrder(inputs);
1039
1040 auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem);
1041 auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
1042 auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
1043
1044 GradByParameter(k_child, f_app, bprop, weights, position_node, enable_tuple_grad, is_weights_none);
1045 return k_child;
1046 }
1047
// Pair each gradient with the name of the parameter it belongs to, producing (name, grad)
// tuples; used when GradOperation is constructed with return_ids=True. 'weight_value_' is
// the abstract of the weights argument captured in GenerateFuncGraph.
CNodePtr GradOperation::SetNodeByParameter(const CNodePtr &grad, const FuncGraphPtr &fg) const {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(weight_value_);
  CNodePtr fv_bprop = nullptr;
  if (!weight_value_->isa<AbstractTuple>()) {
    // Single weight: emit one (name, grad) pair.
    auto weight_ref = dyn_cast<abstract::AbstractRefTensor>(weight_value_);
    if (weight_ref == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "Abstract of parameter should be AbstractRefTensor, but got "
                                 << weight_value_->ToString();
    }
    auto ref_key = weight_ref->ref_key_value();
    // Guard: previously a missing or non-RefKey key dereferenced a null pointer.
    MS_EXCEPTION_IF_NULL(ref_key);
    auto weight_key = ref_key->cast<RefKeyPtr>();
    MS_EXCEPTION_IF_NULL(weight_key);
    auto param_name = weight_key->value();
    fv_bprop = fg->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), NewValueNode(param_name), grad});
    return fv_bprop;
  }
  // Tuple of weights: emit a tuple of (name, grad) pairs, one per weight.
  std::vector<AnfNodePtr> params;
  AbstractTuplePtr weight_tuple = weight_value_->cast<AbstractTuplePtr>();
  const AbstractBasePtrList &elements = weight_tuple->elements();
  params.push_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t i = 0; i < weight_tuple->size(); i++) {
    auto weight_ref = dyn_cast<abstract::AbstractRefTensor>(elements[i]);
    if (weight_ref == nullptr) {
      // Fix: report the offending element rather than the whole tuple.
      MS_LOG(INTERNAL_EXCEPTION) << "Abstract of parameter should be AbstractRefTensor, but got "
                                 << elements[i]->ToString();
    }
    auto ref_key = weight_ref->ref_key_value();
    MS_EXCEPTION_IF_NULL(ref_key);
    auto weight_key = ref_key->cast<RefKeyPtr>();
    MS_EXCEPTION_IF_NULL(weight_key);
    auto param_name = weight_key->value();
    auto grad_value =
      fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), grad, NewValueNode(static_cast<int64_t>(i))});
    fv_bprop = fg->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), NewValueNode(param_name), grad_value});
    params.push_back(fv_bprop);
  }
  fv_bprop = fg->NewCNodeInOrder(params);
  return fv_bprop;
}
1083
// Do grad by the parameter of GradOperation.
// Applies 'bprop' to a sensitivity value (user-supplied when sens_param_, otherwise
// ones_like(forward_result)), then selects gradients wrt weights and/or inputs according
// to the get_by_list_/get_by_position_/get_all_ flags and sets k_child's output.
void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
                                    const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad,
                                    bool is_weights_none) const {
  MS_EXCEPTION_IF_NULL(k_child);

  // The sensitivity: either an extra graph parameter or ones_like of the forward output.
  AnfNodePtr bprop_arg = nullptr;
  if (sens_param_) {
    bprop_arg = k_child->add_parameter();
  } else {
    auto ones_like = prim::GetPythonOps("ones_like");
    bprop_arg = k_child->NewCNodeInOrder({NewValueNode(ones_like), f_app});
  }
  // b_app = bprop(sens): a tuple (env-with-weight-grads, grad_input0, grad_input1, ...).
  AnfNodePtr b_app = k_child->NewCNodeInOrder({bprop, bprop_arg});
  // Add sense parameter flag for bound_node_.
  if (b_app->isa<CNode>() && sens_param_) {
    b_app->cast<CNodePtr>()->AddAttr("sens_param_", MakeValue(true));
  }

  // Gradients wrt the weights (get_by_list_).
  CNodePtr fv_bprop = nullptr;
  if (get_by_list_) {
    if (is_weights_none) {
      // Empty/None weights: the weight-gradient part is an empty tuple.
      fv_bprop = k_child->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple)});
    } else {
      // Python code: grads = hyper_map(F.partial(env_get, env), weights)
      AnfNodePtr env =
        k_child->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))});
      AnfNodePtr partial_env_get =
        k_child->NewCNodeInOrder({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
      MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
      fv_bprop = k_child->NewCNodeInOrder({NewValueNode(hyper_map), partial_env_get, weights});
      if (return_ids_) {
        // Attach parameter names to each weight gradient.
        fv_bprop = SetNodeByParameter(fv_bprop, k_child);
      }
    }
  }

  // Gradients wrt the inputs, selected by position or all of them.
  CNodePtr inputs_bprop = nullptr;
  if (get_by_position_) {
    TailPtr tail_grad_by_position = std::make_shared<Tail>("tail_grad_by_position", kGradByPosition, return_ids_);
    inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_by_position), b_app, position});
  } else if (get_all_) {
    TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
    inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app});
  }

  // Gradients wrt inputs and parameters
  if (fv_bprop != nullptr && inputs_bprop != nullptr) {
    auto make_tuple = k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop});
    k_child->set_output(CreateGradOutputs(k_child, make_tuple, f_app, has_aux_, get_value_));
    return;
  }

  // Gradients wrt parameters
  if (fv_bprop != nullptr) {
    k_child->set_output(CreateGradOutputs(k_child, fv_bprop, f_app, has_aux_, get_value_));
    return;
  }

  // Gradients wrt inputs
  if (inputs_bprop != nullptr) {
    k_child->set_output(CreateGradOutputs(k_child, inputs_bprop, f_app, has_aux_, get_value_));
    return;
  }
  // Gradients wrt first input.
  // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
  // so obtain first input grad by setting tail_type of Tail to kGradFirst.
  TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
  tail_grad_first->set_enable_tuple_grad_first(enable_tuple_grad);
  auto tail_grad_first_cnode = k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app});
  k_child->set_output(CreateGradOutputs(k_child, tail_grad_first_cnode, f_app, has_aux_, get_value_));
}
1156
1157 namespace {
// Check if primal func graph has the primitive returned sparse result in its bprop().
// Walks the primal graph (descending into sub-graphs via SuccDeeperSimple); if any primitive
// carries GRAPH_FLAG_BPROP_RETURN_SPARSE, marks the graph with FUNC_GRAPH_FLAG_SPARSE_BPROP
// and flips the global EnvSetSparseResultMgr flag.
void CheckPrimBpropReturnSparse(const FuncGraphPtr &primal_graph) {
  bool has_sparse_bprop_prim = false;
  // The include-filter returns EXCLUDE once a match is found, which stops TopoSort from
  // exploring further successors.
  (void)TopoSort(primal_graph->return_node(), SuccDeeperSimple,
                 [&has_sparse_bprop_prim](const AnfNodePtr &node) -> IncludeType {
                   MS_EXCEPTION_IF_NULL(node);
                   if (has_sparse_bprop_prim) {
                     return EXCLUDE;
                   }
                   PrimitivePtr prim = nullptr;
                   if (node->isa<CNode>()) {
                     prim = GetCNodePrimitiveWithoutDoSignature(node);
                   } else {
                     prim = GetPrimitiveWithoutDoSignature(node);
                   }
                   if (prim != nullptr) {
                     bool sparse_bprop = GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE);
                     if (sparse_bprop) {
                       MS_LOG(DEBUG) << "prim: " << prim->ToString() << " has attr 'bprop_return_sparse'";
                       has_sparse_bprop_prim = true;
                       return EXCLUDE;
                     }
                   }
                   return FOLLOW;
                 });
  if (has_sparse_bprop_prim) {
    primal_graph->set_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP, true);
    EnvSetSparseResultMgr::GetInstance().Set(true);
  }
}
1188 } // namespace
1189
// Generate the graph.
// Top-level entry for GradOperation: validates the forward function argument, optionally
// rewires its output through GradAux (has_aux_), wraps it with J, and returns a graph whose
// output is the gradient-computing child graph built by GetGrad.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  if (args_abs_list.empty()) {
    MS_LOG(EXCEPTION)
      << "'GradOperation' requires a forward network or function as an input, while the input is empty.";
  }

  constexpr size_t fn_index = 0;
  auto fn_abs = args_abs_list[fn_index];
  constexpr size_t len_with_weight = 2;
  constexpr size_t weights_index = 1;
  // Remember the weights abstract so SetNodeByParameter can recover parameter names later.
  if (return_ids_ && args_abs_list.size() >= len_with_weight) {
    weight_value_ = args_abs_list[weights_index];
  }
  MS_EXCEPTION_IF_NULL(fn_abs);
  // A jit_class instance is not callable for grad; raise a tailored error message.
  if (fn_abs->isa<AbstractClass>()) {
    auto class_abs = dyn_cast<AbstractClass>(fn_abs);
    auto class_val = class_abs->BuildValue();
    MS_EXCEPTION_IF_NULL(class_val);
    auto class_obj = class_val->cast<parse::MsClassObjectPtr>();
    MS_EXCEPTION_IF_NULL(class_obj);
    auto obj_name = std::regex_replace(class_obj->name(), std::regex("MsClassObject:"), "");
    MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell' type "
                      << "object, but got object with jit_class type" << obj_name << ".";
  }
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(fn_abs);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell', but got "
                      << args_abs_list[0]->ToString();
  }

  auto real_fn = fn->cast_ptr<FuncGraphAbstractClosure>();
  if (real_fn == nullptr) {
    MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell', but got "
                      << fn->ToString();
  }
  FuncGraphPtr forward_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(forward_graph);

  // has_aux: rewrap the forward output so that only output 0 is differentiated; the rest
  // go through StopGradient inside the GradAux graph.
  if (has_aux_) {
    GradAuxPtr aux_fn = std::make_shared<GradAux>("aux_fn");
    auto output_cnode = forward_graph->output();
    auto aux_fn_cnode = forward_graph->NewCNodeInOrder({NewValueNode(aux_fn), output_cnode, NewValueNode(get_value_)});
    forward_graph->set_output(aux_fn_cnode);
  }

  forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);

  // Check if primal func graph has the primitive returned sparse result in its bprop().
  CheckPrimBpropReturnSparse(forward_graph);

  FuncGraphPtr grad_fg = nullptr;
  {
    // Scope the TraceGuard so only the graph construction inherits the debug info.
    TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    grad_fg = std::make_shared<FuncGraph>();
  }
  auto nparam = forward_graph->parameters().size();

  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  grad_fg->debug_info()->set_name(ss.str());
  ParameterPtr param_graph = grad_fg->add_parameter();

  // Detect an empty weights sequence (or, below, None) so GradByParameter can emit an
  // empty tuple instead of mapping env_get over the weights.
  bool is_weights_empty_or_none = false;
  AnfNodePtr weights = nullptr;
  AnfNodePtr position = nullptr;
  if (args_abs_list.size() > weights_index) {
    auto weights_abs = args_abs_list[weights_index];
    MS_EXCEPTION_IF_NULL(weights_abs);
    if (weights_abs->isa<AbstractSequence>()) {
      if (weights_abs->cast<AbstractSequencePtr>()->empty()) {
        is_weights_empty_or_none = true;
      }
    }
  }
  // Parameter-addition order must match the signatures declared in the constructor.
  if (get_by_position_) {
    weights = grad_fg->add_parameter();
    position = grad_fg->add_parameter();
  } else if (get_by_list_) {
    weights = grad_fg->add_parameter();
    // Check if weights is None.
    if (!is_weights_empty_or_none && args_abs_list.size() > weights_index) {
      auto weights_abs = args_abs_list[weights_index];
      MS_EXCEPTION_IF_NULL(weights_abs);
      if (weights_abs->isa<AbstractNone>()) {
        is_weights_empty_or_none = true;
      }
    }
  }

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimJ));
  inputs.push_back(param_graph);
  auto j = grad_fg->NewCNodeInOrder(inputs);
  if (merge_forward_) {
    j->set_user_data<bool>("merge_forward", std::make_shared<bool>(true));
  }
  // df is checked in GetGrad
  FuncGraphPtr k_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    k_child = GetGrad(j, weights, position, forward_graph->parameters(),
                      forward_graph->has_flag("enable_tuple_grad_first"), is_weights_empty_or_none);
    k_child->set_flag(FUNC_GRAPH_FLAG_ARGS_NO_EXPAND, true);
  }
  grad_fg->set_output(NewValueNode(k_child));

  return grad_fg;
}
1300
1301 // Generate the vmap_graph.
VmapOperation(const std::string & name)1302 VmapOperation::VmapOperation(const std::string &name) : MetaFuncGraph(name) {
1303 auto default_zero = std::make_shared<Int64Imm>(static_cast<int64_t>(0));
1304 signatures_ =
1305 // def vmap(func:read, in_axes:ref, out_axes:ref):
1306 std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
1307 {"in_axes", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault, default_zero,
1308 SignatureEnumDType::kDTypeEmptyDefaultValue},
1309 {"out_axes", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault, default_zero,
1310 SignatureEnumDType::kDTypeEmptyDefaultValue}});
1311 }
1312
GetVmap(const AnfNodePtr & vmap,int param_number) const1313 FuncGraphPtr VmapOperation::GetVmap(const AnfNodePtr &vmap, int param_number) const {
1314 FuncGraphPtr vmap_child = std::make_shared<FuncGraph>();
1315 vmap_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1316 vmap_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true);
1317
1318 std::vector<AnfNodePtr> inputs;
1319 inputs.push_back(vmap);
1320 for (int i = 0; i < param_number; ++i) {
1321 inputs.push_back(vmap_child->add_parameter());
1322 }
1323 auto vmap_app = vmap_child->NewCNodeInOrder(inputs);
1324 vmap_child->set_output(vmap_app);
1325
1326 return vmap_child;
1327 }
1328
1329 namespace {
IsAxesAllNone(const ValuePtr & axes)1330 bool IsAxesAllNone(const ValuePtr &axes) {
1331 MS_EXCEPTION_IF_NULL(axes);
1332 ValueSequencePtr axes_seq = dyn_cast<ValueSequence>(axes);
1333 auto axes_seq_value = axes_seq->value();
1334 if (std::all_of(axes_seq_value.begin(), axes_seq_value.end(), [](const ValuePtr &axes_value_ptr) {
1335 if (axes_value_ptr->isa<ValueSequence>()) {
1336 return IsAxesAllNone(axes_value_ptr);
1337 }
1338 if (!axes_value_ptr->isa<None>()) {
1339 return false;
1340 }
1341 return true;
1342 })) {
1343 return true;
1344 }
1345 return false;
1346 }
1347
// Validate the 'in_axes'/'out_axes' abstract argument of vmap and return its built value.
// For sequence axes of 'in_axes' the size must match the number of fn parameters ('nparam');
// a fully-None axes spec is only legal in the CellList scenario ('cell_size' > 0).
// Scalar axes must be either None or an Int64.
ValuePtr CheckAxes(const AbstractBasePtr &axes_abs, bool is_in_axes = false, int nparam = 0, size_t cell_size = 0) {
  ValuePtr axes_value = nullptr;
  auto axes_name = is_in_axes ? "in_axes" : "out_axes";

  auto axes_abs_sequence = dyn_cast<AbstractSequence>(axes_abs);
  if (axes_abs_sequence != nullptr) {
    // Sequence case: materialize the element values as a ValueTuple.
    axes_value = axes_abs->cast<AbstractSequencePtr>()->ElementsBuildValue<ValueTuple>();
    MS_EXCEPTION_IF_NULL(axes_value);
    if (is_in_axes) {
      ValueSequencePtr in_axes_seq = dyn_cast<ValueSequence>(axes_value);
      int in_axes_size = SizeToInt(in_axes_seq->size());
      if (nparam != in_axes_size) {
        MS_LOG(EXCEPTION) << "When vmap`s '" << axes_name
                          << "' is a tuple or list, and its size must be equal to the number of arguments of 'fn': "
                          << nparam << ", but got size: " << in_axes_size << ".";
      }
    }
    bool elem_all_none = IsAxesAllNone(axes_value);
    if (elem_all_none && cell_size == 0) {
      MS_LOG(EXCEPTION) << "The '" << axes_name
                        << "' of 'vmap' cannot be all None while 'fn' is not a 'CellList', but got "
                        << axes_value->ToString() << ".";
    }
  } else {
    // Scalar case: None or Int64 only.
    axes_value = axes_abs->BuildValue();
    MS_EXCEPTION_IF_NULL(axes_value);
    if (axes_value->isa<None>() && cell_size == 0) {
      MS_LOG(EXCEPTION) << "The '" << axes_name
                        << "' of 'vmap' cannot be a single None while 'fn' is not a 'CellList'.";
    } else if (!axes_value->isa<None>() && !axes_value->isa<Int64Imm>()) {
      MS_LOG(EXCEPTION) << "The axis in vmap`s '" << axes_name << "' can only be of type Int or None, but got "
                        << axes_abs->ToString() << ".";
    }
  }
  return axes_value;
}
1384
CheckVmapFunc(const AbstractBasePtr & fn_arg,int * nparam,size_t * cell_size)1385 DebugInfoPtr CheckVmapFunc(const AbstractBasePtr &fn_arg, int *nparam, size_t *cell_size) {
1386 DebugInfoPtr origin_graph_info = nullptr;
1387 // In the model ensembling parallel training scenario, fn is a CellList.
1388 AbstractTuplePtr cell_list = dyn_cast<AbstractTuple>(fn_arg);
1389 if (cell_list != nullptr) {
1390 *cell_size = cell_list->size();
1391 if (*cell_size <= 1) {
1392 MS_LOG(EXCEPTION) << "In the model ensembling parallel training scenario ('VmapOperation' arg0 is a 'CellList'),"
1393 << " the size of 'CellList' must be greater than 1, but got " << *cell_size << ".";
1394 }
1395 const AbstractBasePtrList &cell_list_fns = cell_list->elements();
1396 for (auto fn_abs : cell_list_fns) {
1397 MS_EXCEPTION_IF_NULL(fn_abs);
1398 AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(fn_abs);
1399 if (fn == nullptr) {
1400 MS_LOG(EXCEPTION) << "'VmapOperation' arg0 is a 'CellList', whose elements must be 'Cell', but got "
1401 << fn_abs->ToString() << ".";
1402 }
1403 auto partial_fn = dyn_cast<PartialAbstractClosure>(fn_abs);
1404 if (partial_fn != nullptr) {
1405 fn = partial_fn->fn();
1406 }
1407 auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
1408 if (real_fn == nullptr) {
1409 MS_LOG(EXCEPTION) << "'VmapOperation' arg0 is a 'CellList', whose element " << fn->ToString()
1410 << " cast to 'FuncGraphAbstractClosure' failed.";
1411 }
1412
1413 FuncGraphPtr orig_graph = real_fn->func_graph();
1414 MS_EXCEPTION_IF_NULL(orig_graph);
1415 orig_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
1416 int fn_nparam =
1417 SizeToInt(orig_graph->parameters().size() - (partial_fn != nullptr ? partial_fn->args().size() : 0));
1418 if (*nparam == -1) {
1419 origin_graph_info = orig_graph->debug_info();
1420 *nparam = fn_nparam;
1421 } else if (*nparam != fn_nparam) {
1422 MS_LOG(EXCEPTION) << "'VmapOperation' arg0 is a CellList, whose elements's inputs should be consistent.";
1423 }
1424 }
1425 } else {
1426 AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(fn_arg);
1427 if (fn == nullptr) {
1428 MS_LOG(EXCEPTION) << "'VmapOperation' arg0 must be a 'Function' or 'Cell', but got " << fn_arg->ToString() << ".";
1429 }
1430 auto partial_fn = dyn_cast<PartialAbstractClosure>(fn);
1431 if (partial_fn != nullptr) {
1432 fn = partial_fn->fn();
1433 }
1434 auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
1435 if (real_fn == nullptr) {
1436 MS_LOG(EXCEPTION) << "'VmapOperation' arg0 " << fn->ToString() << " cast to 'FuncGraphAbstractClosure' failed.";
1437 }
1438
1439 FuncGraphPtr orig_graph = real_fn->func_graph();
1440 MS_EXCEPTION_IF_NULL(orig_graph);
1441 orig_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
1442 *nparam = SizeToInt(orig_graph->parameters().size() - (partial_fn != nullptr ? partial_fn->args().size() : 0));
1443 origin_graph_info = orig_graph->debug_info();
1444 }
1445 return origin_graph_info;
1446 }
1447 } // namespace
1448
// Top-level entry for vmap: validates fn/in_axes/out_axes, builds the kVmapOpName primitive
// node carrying the axes as attributes, and returns a graph whose output is the child graph
// (GetVmap) that applies the transformed function to its arguments.
FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  if (args_abs_list.empty()) {
    MS_LOG(EXCEPTION) << "'VmapOperation' requires a network or function as an input, while the input is empty.";
  }

  constexpr auto vmap_operation_input_num = 3;
  const std::string op_name = "vmap";
  CheckArgsSize(op_name, args_abs_list, vmap_operation_input_num);

  auto fn_arg = args_abs_list[0];
  auto in_axes_arg = args_abs_list[1];
  auto out_axes_arg = args_abs_list[2];

  // nparam == -1 means "not yet determined"; CheckVmapFunc fills both out-params.
  int nparam = -1;
  size_t cell_size = 0;
  DebugInfoPtr origin_graph_info = CheckVmapFunc(fn_arg, &nparam, &cell_size);

  FuncGraphPtr vmap_fg = nullptr;
  {
    // Scope the TraceGuard so only the graph construction inherits the debug info.
    TraceGuard guard(std::make_shared<TraceVmapOperation>(origin_graph_info));
    vmap_fg = std::make_shared<FuncGraph>();
  }

  std::ostringstream ss;
  ss << "vmap{" << nparam << "}";
  vmap_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  vmap_fg->debug_info()->set_name(ss.str());

  // Add parameter for `fn`, `in_axes` and `out_axes` respectively.
  ParameterPtr param_graph = vmap_fg->add_parameter();
  (void)vmap_fg->add_parameter();
  (void)vmap_fg->add_parameter();

  // Validity verification of in_axes and out_axes
  ValuePtr in_axes = CheckAxes(in_axes_arg, true, nparam, cell_size);
  ValuePtr out_axes = CheckAxes(out_axes_arg);

  // The axes and the CellList size travel as attributes of the vmap primitive.
  PrimitivePtr kprim_vmap = std::make_shared<Primitive>(kVmapOpName, kSideEffectPropagate);
  kprim_vmap->set_attr("in_axes", in_axes);
  kprim_vmap->set_attr("out_axes", out_axes);
  kprim_vmap->set_attr("cell_size", MakeValue(cell_size));

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(kprim_vmap));
  inputs.push_back(param_graph);
  auto vmap = vmap_fg->NewCNodeInOrder(inputs);

  FuncGraphPtr vmap_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceVmapOperation>(origin_graph_info));
    vmap_child = GetVmap(vmap, nparam);
  }

  vmap_fg->set_output(NewValueNode(vmap_child));
  return vmap_fg;
}
1505
TaylorOperation(const std::string & name)1506 TaylorOperation::TaylorOperation(const std::string &name) : MetaFuncGraph(name) {
1507 // def Taylor(func:read):
1508 signatures_ = std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}});
1509 }
1510
GetTaylorGrad(const AnfNodePtr & k,const std::vector<AnfNodePtr> & forward_graph_params) const1511 FuncGraphPtr TaylorOperation::GetTaylorGrad(const AnfNodePtr &k,
1512 const std::vector<AnfNodePtr> &forward_graph_params) const {
1513 FuncGraphPtr k_child = std::make_shared<FuncGraph>();
1514 k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1515
1516 std::vector<AnfNodePtr> inputs;
1517 inputs.push_back(k);
1518 MS_LOG(INFO) << "TaylorOperation forward input size " << forward_graph_params.size();
1519 for (size_t i = 0; i < forward_graph_params.size(); ++i) {
1520 inputs.push_back(k_child->add_parameter());
1521 }
1522 // Taylor(fn)(input params)
1523 auto k_app = k_child->NewCNodeInOrder(inputs);
1524
1525 k_child->set_output(k_app);
1526 return k_child;
1527 }
1528
1529 // Generate the graph to calculate higher order derivatives.
// Generate the graph to calculate higher order derivatives.
FuncGraphPtr TaylorOperation::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  if (args_abs_list.empty()) {
    MS_LOG(EXCEPTION)
      << "'TaylorOperation' requires a forward network or function as an input, while the input is empty.";
  }

  // arg0 must resolve to a concrete func-graph closure.
  MS_EXCEPTION_IF_NULL(args_abs_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_abs_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "'TaylorOperation' arg0 must be a 'Function' or 'Cell', but got "
                      << args_abs_list[0]->ToString();
  }

  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);

  FuncGraphPtr forward_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(forward_graph);
  // Keep the forward graph un-inlined so the Taylor transform sees it whole.
  forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  FuncGraphPtr grad_fg = nullptr;
  MS_LOG(INFO) << "'TaylorOperation' forward_graph" << forward_graph->debug_info();
  grad_fg = std::make_shared<FuncGraph>();
  auto nparam = forward_graph->parameters().size();

  std::ostringstream ss;
  ss << "taylorgrad{" << nparam << "}";
  grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  grad_fg->debug_info()->set_name(ss.str());
  // Single parameter of the outer graph: the forward function itself.
  ParameterPtr param_graph = grad_fg->add_parameter();

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimTaylor));
  inputs.push_back(param_graph);
  // Taylor(fn)
  auto mark_taylor = grad_fg->NewCNodeInOrder(inputs);
  FuncGraphPtr k_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    k_child = GetTaylorGrad(mark_taylor, forward_graph->parameters());
  }
  grad_fg->set_output(NewValueNode(k_child));
  // return Taylor(fn)(inputs)
  return grad_fg;
}
1574
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)1575 FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
1576 // args: tuple1, tuple2
1577 abstract::CheckArgsSize("TupleAdd", args_abs_list, 2);
1578 AbstractBasePtr abs_a = args_abs_list[0];
1579 AbstractBasePtr abs_b = args_abs_list[1];
1580
1581 AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
1582 AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
1583 if (a_tuple == nullptr || b_tuple == nullptr) {
1584 TypePtrList types;
1585 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(types),
1586 [](const AbstractBasePtr &arg) -> TypePtr {
1587 MS_EXCEPTION_IF_NULL(arg);
1588 return arg->BuildType();
1589 });
1590 auto stub = GenerateStubFunc(types);
1591 if (stub != nullptr) {
1592 MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
1593 << ", function: " << stub->ToString();
1594 return stub;
1595 }
1596 MS_LOG(EXCEPTION) << "The type of argument in TupleAdd operator should be tuple, but the first argument is "
1597 << args_abs_list[0]->ToString() << ", the second argument is " << args_abs_list[1]->ToString();
1598 }
1599
1600 FuncGraphPtr ret = std::make_shared<FuncGraph>();
1601 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1602 AnfNodePtr p_tup_a = ret->add_parameter();
1603 AnfNodePtr p_tup_b = ret->add_parameter();
1604
1605 std::vector<AnfNodePtr> elems;
1606 elems.push_back(NewValueNode(prim::kPrimMakeTuple));
1607
1608 int64_t tuple_size = SizeToLong(a_tuple->size());
1609 for (int64_t i = 0; i < tuple_size; ++i) {
1610 elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
1611 }
1612
1613 tuple_size = SizeToLong(b_tuple->size());
1614 for (int64_t i = 0; i < tuple_size; ++i) {
1615 elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
1616 }
1617
1618 ret->set_output(ret->NewCNodeInOrder(elems));
1619 return ret;
1620 }
1621
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)1622 FuncGraphPtr ListAdd::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
1623 // args: list1, list2
1624 abstract::CheckArgsSize("ListAdd", args_abs_list, 2);
1625 AbstractBasePtr abs_a = args_abs_list[0];
1626 AbstractBasePtr abs_b = args_abs_list[1];
1627
1628 AbstractListPtr a_list = dyn_cast<AbstractList>(abs_a);
1629 AbstractListPtr b_list = dyn_cast<AbstractList>(abs_b);
1630 if (a_list == nullptr || b_list == nullptr) {
1631 TypePtrList types;
1632 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(types),
1633 [](const AbstractBasePtr &arg) -> TypePtr {
1634 MS_EXCEPTION_IF_NULL(arg);
1635 return arg->BuildType();
1636 });
1637 auto stub = GenerateStubFunc(types);
1638 if (stub != nullptr) {
1639 MS_LOG(DEBUG) << "GenerateStubFunc for ListAdd "
1640 << ", function: " << stub->ToString();
1641 return stub;
1642 }
1643 MS_LOG(EXCEPTION) << "The type of argument in ListAdd operator should be list, but the first argument is "
1644 << args_abs_list[0]->ToString() << ", the second argument is " << args_abs_list[1]->ToString();
1645 }
1646
1647 FuncGraphPtr ret = std::make_shared<FuncGraph>();
1648 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1649 AnfNodePtr p_list_a = ret->add_parameter();
1650 AnfNodePtr p_list_b = ret->add_parameter();
1651
1652 std::vector<AnfNodePtr> elems;
1653 elems.push_back(NewValueNode(prim::kPrimMakeList));
1654
1655 int64_t tuple_size = SizeToLong(a_list->size());
1656 for (int64_t i = 0; i < tuple_size; ++i) {
1657 elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), p_list_a, NewValueNode(i)}));
1658 }
1659
1660 tuple_size = SizeToLong(b_list->size());
1661 for (int64_t i = 0; i < tuple_size; ++i) {
1662 elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), p_list_b, NewValueNode(i)}));
1663 }
1664
1665 ret->set_output(ret->NewCNodeInOrder(elems));
1666 return ret;
1667 }
1668
// Extract a slice member given as an abstract scalar into a concrete int64.
// The unnamed string parameter carries the member's name for callers but is
// currently unused here.
int64_t GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
  MS_EXCEPTION_IF_NULL(scalar);
  // NOTE(review): assumes the built value holds an int64 — GetValue raises
  // otherwise; confirm callers only pass integer slice members.
  return GetValue<int64_t>(scalar->BuildValue());
}
1673
// Map a possibly-negative (Python-style) index into non-negative space by
// shifting negative values up by the sequence length.
int64_t GetPositiveIndex(int64_t index, int64_t length) {
  return index < 0 ? index + length : index;
}
1680
// Resolve one slice member (start/stop/step) to a concrete int64:
// a scalar yields its value, None yields `default_value`, anything else raises.
int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, const std::string &member_name) {
  MS_EXCEPTION_IF_NULL(member);

  if (member->isa<AbstractScalar>()) {
    return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
  }

  // An omitted member (None) falls back to the caller-provided default.
  if (member->isa<AbstractNone>()) {
    return default_value;
  }

  // A variable tensor cannot be folded into a compile-time slice bound.
  if (member->isa<AbstractTensor>()) {
    MS_EXCEPTION(TypeError)
      << "The argument of SliceMember operator must be a Scalar or None or constant Tensor, but got a variable Tensor";
  }
  MS_EXCEPTION(TypeError)
    << "The argument of SliceMember operator must be a Scalar or None or constant Tensor, but got "
    << member->BuildType()->ToString();
}
1700
// Normalize a Python-style slice against `sequence` into concrete
// (start, stop, step): None members fall back to direction-dependent defaults,
// out-of-range bounds are clamped, and negative indices are shifted to
// non-negative ones. Roughly mirrors Python slice normalization semantics.
std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const AbstractSequencePtr &sequence,
                                                                  const AbstractSlicePtr &slice) {
  MS_EXCEPTION_IF_NULL(sequence);
  MS_EXCEPTION_IF_NULL(slice);
  int64_t start_index;
  int64_t stop_index;
  int64_t step_value;

  const std::string start_name("Slice start index");
  const std::string stop_name("Slice stop index");
  const std::string step_name("Slice step value");

  int64_t tuple_size = SizeToLong(sequence->size());
  int64_t start_default = 0;
  int64_t stop_default = tuple_size;
  int64_t step_default = kStepDefault;

  step_value = CheckSliceMember(slice->step(), step_default, step_name);
  if (step_value == 0) {
    MS_EXCEPTION(ValueError) << "Slice step cannot be zero.";
  }

  // A negative step walks backwards: defaults become "last element" down past
  // the first element (stop default -size-1 marks "before the beginning").
  if (step_value < 0) {
    start_default = tuple_size - 1;
    stop_default = ((-tuple_size) - 1);
  }

  start_index = CheckSliceMember(slice->start(), start_default, start_name);
  stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);

  // Clamp bounds pointing beyond either end of the sequence. Order matters:
  // the later GetPositiveIndex calls assume the values are already in range.
  if (start_index < -tuple_size) {
    start_index = 0;
  }

  if (stop_index > tuple_size) {
    stop_index = tuple_size;
  }

  if (start_index > tuple_size) {
    start_index = tuple_size;
  }

  if (stop_index < ((-tuple_size) - 1)) {
    stop_index = 0;
  }

  // Shift any remaining negative index into non-negative space.
  start_index = GetPositiveIndex(start_index, tuple_size);

  stop_index = GetPositiveIndex(stop_index, tuple_size);

  return std::make_tuple(start_index, stop_index, step_value);
}
1753
CheckArgs(const AbstractBasePtrList & args_abs_list)1754 void SequenceSliceGetItem::CheckArgs(const AbstractBasePtrList &args_abs_list) {
1755 constexpr size_t arg_size = 2;
1756 abstract::CheckArgsSize(this->name(), args_abs_list, arg_size);
1757 sequence_ = abstract::CheckArg<AbstractSequence>(this->name(), args_abs_list, 0);
1758 slice_ = abstract::CheckArg<AbstractSlice>(this->name(), args_abs_list, 1);
1759 }
1760
BuildFuncGraph(int64_t start_index,int64_t stop_index,int64_t step_value)1761 FuncGraphPtr SequenceSliceGetItem::BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) {
1762 FuncGraphPtr ret = std::make_shared<FuncGraph>();
1763 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1764 AnfNodePtr p_seq = ret->add_parameter();
1765 (void)ret->add_parameter();
1766
1767 std::vector<AnfNodePtr> elems;
1768 elems.push_back(NewValueNode(prim_));
1769 if (step_value > 0) {
1770 for (int64_t index = start_index; index < stop_index; index = index + step_value) {
1771 elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_seq, NewValueNode(index)}));
1772 }
1773 } else {
1774 for (int64_t index = start_index; index > stop_index; index = index + step_value) {
1775 elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_seq, NewValueNode(index)}));
1776 }
1777 }
1778
1779 ret->set_output(ret->NewCNodeInOrder(elems));
1780 return ret;
1781 }
1782
// Build the graph for tuple[tensor_index]. If every tuple element is a
// function-like closure, lower to switch_layer; otherwise (when the JIT syntax
// level permits) fall back to a runtime PyExecute of "tuple[index]".
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  // select indexed item
  // args: tuple of items, index
  const std::string op_name = std::string("TupleGetItemTensor");
  const size_t inputs_size = 2;
  abstract::CheckArgsSize(op_name, args_abs_list, inputs_size);
  auto ret_graph = std::make_shared<FuncGraph>();
  ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  auto tuple = ret_graph->add_parameter();
  auto index = ret_graph->add_parameter();

  constexpr size_t tuple_index = 0;
  auto abs = args_abs_list[tuple_index];
  MS_EXCEPTION_IF_NULL(abs);
  auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abs);
  // Fast path: a statically-sized tuple of closures becomes switch_layer(index, tuple).
  if (!tuple_abs->dynamic_len()) {
    const auto &elements = tuple_abs->elements();
    if (std::all_of(elements.begin(), elements.end(), [](const AbstractBasePtr &e) {
          MS_EXCEPTION_IF_NULL(e);
          return e->isa<abstract::FuncGraphAbstractClosure>() || e->isa<abstract::PartialAbstractClosure>() ||
                 e->isa<abstract::PrimitiveAbstractClosure>();
        })) {
      ret_graph->set_output(ret_graph->NewCNodeInOrder({NewValueNode(prim::kPrimSwitchLayer), index, tuple}));
      return ret_graph;
    }
  }

  // Non-function elements require the runtime fallback, which STRICT mode forbids.
  const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
  if (!allow_fallback_runtime) {
    MS_EXCEPTION(TypeError) << "When JIT_SYNTAX_LEVEL is STRICT, using Tensor index to get value from tuple requires "
                            << "that all elements in tuple should be function but got tuple abstract: "
                            << tuple_abs->ToString();
  }
  // Script
  constexpr auto internal_tuple_input = "__internal_tuple_input__";
  constexpr auto internal_index_input = "__internal_index_input__";
  std::stringstream script_buffer;
  script_buffer << internal_tuple_input << "[" << internal_index_input << "]";
  const std::string &script = script_buffer.str();
  const auto script_str = std::make_shared<StringImm>(script);
  // Key: the placeholder names referenced by the script.
  std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
  (void)key_value_names_list.emplace_back(NewValueNode(internal_tuple_input));
  (void)key_value_names_list.emplace_back(NewValueNode(internal_index_input));
  const auto key_value_name_tuple = ret_graph->NewCNode(key_value_names_list);
  // Value: the graph inputs bound to those placeholders.
  std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
  (void)key_value_list.emplace_back(tuple);
  (void)key_value_list.emplace_back(index);
  const auto key_value_tuple = ret_graph->NewCNode(key_value_list);
  auto res =
    fallback::CreatePyExecuteCNode(ret_graph, NewValueNode(script_str), key_value_name_tuple, key_value_tuple, nullptr);
  ret_graph->set_output(res);
  return ret_graph;
}
1839
1840 namespace {
GetShard(const AnfNodePtr & shard,const std::vector<AnfNodePtr> & origin_graph_params)1841 FuncGraphPtr GetShard(const AnfNodePtr &shard, const std::vector<AnfNodePtr> &origin_graph_params) {
1842 FuncGraphPtr shard_child = std::make_shared<FuncGraph>();
1843 shard_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1844
1845 std::vector<AnfNodePtr> inputs;
1846 inputs.reserve(origin_graph_params.size() + 1);
1847 (void)inputs.emplace_back(shard);
1848 for (size_t i = 0; i < origin_graph_params.size(); ++i) {
1849 (void)inputs.emplace_back(shard_child->add_parameter());
1850 }
1851 auto shard_app = shard_child->NewCNodeInOrder(std::move(inputs));
1852
1853 shard_child->set_output(shard_app);
1854 return shard_child;
1855 }
1856 } // namespace
1857
// Build the wrapper graph for ops.shard: the returned graph takes the function
// plus its shard configuration, marks the call with the Shard primitive, and
// returns a child graph that applies the sharded function to the real inputs.
FuncGraphPtr Shard::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  if (args_abs_list.size() != kShardInputSize) {
    MS_LOG(EXCEPTION) << "'Shard' requires " << kShardInputSize
                      << " inputs. Includes a Cell or function, in_axes, out_axes, parameter_plan, device and level.";
  }

  // arg0 must resolve to a concrete func-graph closure.
  MS_EXCEPTION_IF_NULL(args_abs_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_abs_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "'Shard' arg0 must be a 'Function' or 'Cell', but got " << args_abs_list[0]->ToString() << ".";
  }

  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);
  FuncGraphPtr origin_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(origin_graph);
  // In PyNative mode keep the original graph un-inlined so the shard transform
  // can still see it as a whole.
  auto execution_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
  if (execution_mode == kPynativeMode) {
    origin_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  }
  FuncGraphPtr shard_fg = nullptr;
  {
    TraceGuard g(std::make_shared<TraceShard>(origin_graph->debug_info()));
    shard_fg = std::make_shared<FuncGraph>();
  }
  // Create the debug info
  auto parameter_size = origin_graph->parameters().size();
  std::ostringstream ss;
  ss << "shard{" << parameter_size << "}";
  shard_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  shard_fg->debug_info()->set_name(ss.str());
  // Make the Shard node: one outer parameter per shard argument (fn + config).
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(args_abs_list.size() + 1);
  (void)inputs.emplace_back(NewValueNode(prim::kPrimShard));
  for (size_t i = 0; i < args_abs_list.size(); ++i) {
    (void)inputs.emplace_back(shard_fg->add_parameter());
  }
  auto shard = shard_fg->NewCNodeInOrder(std::move(inputs));

  // Child graph forwards its own parameters into Shard(fn, ...)(...).
  FuncGraphPtr shard_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceShard>(shard_fg->debug_info()));
    shard_child = GetShard(shard, origin_graph->parameters());
  }
  shard_fg->set_output(NewValueNode(shard_child));
  return shard_fg;
}
1906
CheckArgs(const AbstractBasePtrList & args_abs_list)1907 void ListSliceSetItem::CheckArgs(const AbstractBasePtrList &args_abs_list) {
1908 constexpr size_t kSliceSetItemArgsSizeargs_size = 3;
1909 constexpr size_t kSliceSetItemListIndex = 0;
1910 constexpr size_t kSliceSetItemSliceIndex = 1;
1911 constexpr size_t kSliceSetItemValueIndex = 2;
1912 abstract::CheckArgsSize("list_slice_set_item", args_abs_list, kSliceSetItemArgsSizeargs_size);
1913 this->sequence_ = abstract::CheckArg<AbstractList>("list_slice_set_item", args_abs_list, kSliceSetItemListIndex);
1914 this->slice_ = abstract::CheckArg<AbstractSlice>("list_slice_set_item", args_abs_list, kSliceSetItemSliceIndex);
1915 this->value_list_ = abstract::CheckArg<AbstractList>("list_slice_set_item", args_abs_list, kSliceSetItemValueIndex);
1916 }
1917
// Build the graph computing list[start:stop:step] = value_list: walk the
// original list and the assigned values in lockstep and emit every surviving
// element into a single MakeList node.
FuncGraphPtr ListSliceSetItem::BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) {
  // Init graph with the input list_node slice assign_node
  CheckAssignRange(start_index, stop_index, step_value);
  auto graph = std::make_shared<FuncGraph>();
  graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  auto list_node = graph->add_parameter();
  (void)graph->add_parameter();  // The slice parameter is unused: its bounds are already baked in.
  auto assign_parameter = graph->add_parameter();
  // For a negative step GetAssignNode yields the value list reversed.
  auto assign_node = GetAssignNode(graph, assign_parameter, step_value);
  std::vector<AnfNodePtr> elems = {NewValueNode(prim::kPrimMakeList)};
  int64_t list_index = 0;
  // check the index is in the slice range
  auto check_in_range = [start_index, stop_index, step_value](int64_t index) -> bool {
    if (step_value > 0) {
      return (index >= start_index && index < stop_index);
    }
    return (index <= start_index && index > stop_index);
  };
  int64_t list_size = SizeToLong(sequence_->size());
  int64_t assign_index = 0;
  int64_t value_size = SizeToLong(value_list_->size());
  while (list_index < list_size || assign_index < value_size) {
    if (!check_in_range(list_index)) {
      // list start <= stop && step = 1 insert the assign node to target node
      while (assign_index < value_size && list_index == start_index) {
        (void)elems.emplace_back(
          graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
      }
      // Outside the slice: keep the original element as-is.
      if (list_index < list_size) {
        (void)elems.emplace_back(
          graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), list_node, NewValueNode(list_index++)}));
      }
    } else {
      if (((list_index - start_index) % step_value) == 0) {
        // Position hit by the step: replace the element with the next value.
        ++list_index;
        if (assign_index >= value_size) {
          continue;
        }
        (void)elems.emplace_back(
          graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
      } else {
        // Position skipped by the step: keep the original element.
        (void)elems.emplace_back(
          graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), list_node, NewValueNode(list_index++)}));
      }
      // the assign node's len is larger than the range
      while (!check_in_range(list_index) && assign_index < value_size) {
        (void)elems.emplace_back(
          graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
      }
    }
  }

  graph->set_output(graph->NewCNodeInOrder(elems));
  return graph;
}
1973
CheckAssignRange(int64_t start_index,int64_t stop_index,int64_t step_value)1974 void ListSliceSetItem::CheckAssignRange(int64_t start_index, int64_t stop_index, int64_t step_value) {
1975 if (step_value != kStepDefault) {
1976 auto range = stop_index - start_index;
1977 int include_start = (range % step_value) == 0 ? 0 : 1;
1978 auto assign_size = (range / step_value) + include_start;
1979 assign_size = assign_size > 0 ? assign_size : 0;
1980 if (assign_size != SizeToLong(value_list_->size())) {
1981 MS_EXCEPTION(ValueError) << "attempt to assign sequence of size " << value_list_->size()
1982 << " to extended slice of size " << assign_size;
1983 }
1984 }
1985 }
1986
// Return the node whose elements will be consumed front-to-back when splicing:
// the value list itself for a positive step, or a reversed copy for a negative
// step (so the caller can still read it in increasing index order).
AnfNodePtr ListSliceSetItem::GetAssignNode(const FuncGraphPtr &func_graph, const AnfNodePtr &assign_node,
                                           int64_t step_value) {
  if (step_value > 0) {
    return assign_node;
  }
  std::vector<AnfNodePtr> elems = {NewValueNode(prim::kPrimMakeList)};
  // Fix: use SizeToLong (as the rest of this file does) instead of SizeToInt so
  // the size_t never narrows through int before landing in the int64_t index.
  for (int64_t i = SizeToLong(value_list_->size()) - 1; i >= 0; --i) {
    (void)elems.emplace_back(
      func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), assign_node, NewValueNode(i)}));
  }
  return func_graph->NewCNodeInOrder(elems);
}
1999
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)2000 FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2001 this->CheckArgs(args_abs_list);
2002 auto [start, stop, step] = GenerateTupleSliceParameter(sequence_, slice_);
2003 return this->BuildFuncGraph(start, stop, step);
2004 }
2005
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)2006 FuncGraphPtr ZerosLike::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2007 constexpr auto input_size = 1;
2008 abstract::CheckArgsSize("ZerosLike", args_abs_list, input_size);
2009
2010 auto x = args_abs_list[0];
2011 MS_EXCEPTION_IF_NULL(x);
2012 auto type = x->BuildType();
2013 MS_EXCEPTION_IF_NULL(type);
2014 if (type->type_id() == kTuple->type_id() || type->type_id() == kList->type_id()) {
2015 auto abs_seq = x->cast<AbstractSequencePtr>();
2016 MS_EXCEPTION_IF_NULL(abs_seq);
2017 if (abs_seq->dynamic_len()) {
2018 FuncGraphPtr res_graph = std::make_shared<FuncGraph>();
2019 res_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
2020 res_graph->debug_info()->set_name("zeros_like");
2021 auto x_parameter = res_graph->add_parameter();
2022 res_graph->set_output(res_graph->NewCNodeInOrder({NewValueNode(prim::kPrimSequenceZerosLike), x_parameter}));
2023 return res_graph;
2024 }
2025 }
2026
2027 HyperMap hyper_map(false, fn_leaf_);
2028 TypePtrList types;
2029 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(types),
2030 [](const AbstractBasePtr &arg) -> TypePtr {
2031 MS_EXCEPTION_IF_NULL(arg);
2032 return arg->BuildType();
2033 });
2034 return hyper_map.GenerateFromTypes(types);
2035 }
2036
// IterConverter is used when the input needs to be converted to an iterable object.
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)2038 FuncGraphPtr IterConverter::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2039 constexpr auto input_size = 1;
2040 abstract::CheckArgsSize("IterConverter", args_abs_list, input_size);
2041 auto fg = std::make_shared<FuncGraph>();
2042 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
2043 auto input_abs = args_abs_list[0];
2044 MS_EXCEPTION_IF_NULL(input_abs);
2045 if (input_abs->isa<abstract::AbstractAny>() || input_abs->BuildValue()->isa<parse::InterpretedObject>()) {
2046 const std::vector<std::string> funcs_str{"tuple"};
2047 auto ret_node = fallback::GeneratePyInterpretWithAbstract(fg, funcs_str, input_size);
2048 fg->set_output(ret_node);
2049 return fg;
2050 }
2051
2052 auto input_type = input_abs->BuildType();
2053 MS_EXCEPTION_IF_NULL(input_type);
2054 auto type_id = input_type->type_id();
2055 std::vector<int64_t> iterable_valid_types{
2056 TypeId::kObjectTypeString, TypeId::kObjectTypeTuple, TypeId::kObjectTypeList, TypeId::kObjectTypeDictionary,
2057 TypeId::kObjectTypeTensorType, TypeId::kObjectTypeFunction, TypeId::kMetaTypeExternal};
2058 bool iterable = std::any_of(iterable_valid_types.begin(), iterable_valid_types.end(),
2059 [type_id](int64_t valid_type) { return valid_type == type_id; });
2060 if (!iterable) {
2061 MS_EXCEPTION(TypeError) << "'" << TypeIdToString(type_id, true) << "' object is not iterable";
2062 }
2063
2064 auto input = fg->add_parameter();
2065 if (input_abs->isa<AbstractDictionary>()) {
2066 auto ret_node = fg->NewCNode({NewValueNode(prim::kPrimDictGetKeys), input});
2067 fg->set_output(ret_node);
2068 return fg;
2069 }
2070 fg->set_output(input);
2071 return fg;
2072 }
2073
// HasNext is used to check whether the input has a next element.
FuncGraphPtr HasNext::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  constexpr auto input_size = 1;
  abstract::CheckArgsSize("HasNext", args_abs_list, input_size);
  auto fg = std::make_shared<FuncGraph>();
  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  auto input_abs = args_abs_list[0];
  MS_EXCEPTION_IF_NULL(input_abs);
  auto input = fg->add_parameter();
  // JIT-fallback path: for interpreted/Any objects delegate the check to the
  // Python helper _jit_fallback_has_next_func via a PyInterpret node.
  if (input_abs->isa<abstract::AbstractAny>() || input_abs->BuildValue()->isa<parse::InterpretedObject>()) {
    AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
    AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
    std::stringstream script_buffer;
    script_buffer << "__import__('mindspore').common._utils._jit_fallback_has_next_func(";
    const std::string data_str = "__data__";
    script_buffer << data_str << ")";
    // The script's local dict binds "__data__" to the graph input.
    (void)local_key_inputs.emplace_back(NewValueNode(data_str));
    (void)local_value_inputs.emplace_back(input);
    const auto &script = script_buffer.str();
    auto local_key_node = fg->NewCNode(local_key_inputs);
    auto local_value_node = fg->NewCNode(local_value_inputs);
    auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
    auto ret = fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node);
    fg->set_output(ret);
    return fg;
  }
  // Otherwise call the parsed Python standard method ms_hasnext on the input.
  const std::string module = "mindspore._extends.parse.standard_method";
  const std::string func_name = "ms_hasnext";
  py::function fn = python_adapter::GetPyFn(module, func_name);
  auto prim_func = parse::ParsePythonCode(fn);
  auto ret = fg->NewCNode({NewValueNode(prim_func), input});
  fg->set_output(ret);
  return fg;
}
2108
// Next is used to fetch the next element from the input.
FuncGraphPtr Next::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  constexpr auto input_size = 1;
  abstract::CheckArgsSize("Next", args_abs_list, input_size);
  auto fg = std::make_shared<FuncGraph>();
  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  auto input_abs = args_abs_list[0];
  MS_EXCEPTION_IF_NULL(input_abs);
  auto input = fg->add_parameter();
  // JIT-fallback path: for interpreted/Any objects delegate iteration to the
  // Python helper _jit_fallback_next_func via a PyInterpret node.
  if (input_abs->isa<abstract::AbstractAny>() || input_abs->BuildValue()->isa<parse::InterpretedObject>()) {
    AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
    AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
    std::stringstream script_buffer;
    script_buffer << "__import__('mindspore').common._utils._jit_fallback_next_func(";
    const std::string data_str = "__data__";
    script_buffer << data_str << ")";
    // The script's local dict binds "__data__" to the graph input.
    (void)local_key_inputs.emplace_back(NewValueNode(data_str));
    (void)local_value_inputs.emplace_back(input);
    const auto &script = script_buffer.str();
    auto local_key_node = fg->NewCNode(local_key_inputs);
    auto local_value_node = fg->NewCNode(local_value_inputs);
    auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
    auto ret = fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node);
    fg->set_output(ret);
    return fg;
  }
  // Otherwise call the parsed Python standard method: dict_next for dicts,
  // ms_next for every other supported iterable.
  const std::string module = "mindspore._extends.parse.standard_method";
  const std::string func_name = input_abs->isa<abstract::AbstractDictionary>() ? "dict_next" : "ms_next";
  py::function fn = python_adapter::GetPyFn(module, func_name);
  auto prim_func = parse::ParsePythonCode(fn);
  auto ret = fg->NewCNode({NewValueNode(prim_func), input});
  fg->set_output(ret);
  return fg;
}
2143
// Build the graph for the builtin tuple(): with no argument emit an empty
// tuple; sequences containing Any fall back to a Python interpret node; tuples
// pass through; scalar-only lists use ListToTuple; everything else goes through
// the Python standard method tuple_func.
FuncGraphPtr TupleFunc::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  if (args_abs_list.size() > 1) {
    MS_LOG(EXCEPTION) << "For 'TupleFunc', the number of input should be 0 or 1, but got " << args_abs_list.size();
  }
  auto fg = std::make_shared<FuncGraph>();
  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  // tuple() with no argument yields an empty tuple.
  if (args_abs_list.size() == 0) {
    auto ret = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple)});
    fg->set_output(ret);
    return fg;
  }

  auto input_abs = args_abs_list[0];
  MS_EXCEPTION_IF_NULL(input_abs);
  auto input = fg->add_parameter();
  if (fallback::ContainsSequenceAnyType(input_abs)) {
    // JIT-fallback path: evaluate "tuple(__data__)" with the input bound into
    // the script's local dict.
    AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
    AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
    std::stringstream script_buffer;
    script_buffer << "tuple(";
    const std::string data_str = "__data__";
    script_buffer << data_str << ")";
    (void)local_key_inputs.emplace_back(NewValueNode(data_str));
    (void)local_value_inputs.emplace_back(input);
    const auto &script = script_buffer.str();
    auto local_key_node = fg->NewCNode(local_key_inputs);
    auto local_value_node = fg->NewCNode(local_value_inputs);
    auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
    auto ret = fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node);
    fg->set_output(ret);
    return fg;
  } else if (input_abs->isa<abstract::AbstractTuple>()) {
    // Already a tuple: identity.
    fg->set_output(input);
    return fg;
  } else if (input_abs->isa<abstract::AbstractList>()) {
    // list to tuple
    if (fallback::SequenceAllElementsIsScalar(input_abs)) {
      auto prim = std::make_shared<Primitive>("ListToTuple");
      auto list_to_tuple = fg->NewCNode({NewValueNode(prim), input});
      fg->set_output(list_to_tuple);
      return fg;
    }
  }
  // General case: call the parsed Python standard method tuple_func.
  const std::string module = "mindspore._extends.parse.standard_method";
  const std::string func_name = "tuple_func";
  py::function fn = python_adapter::GetPyFn(module, func_name);
  auto prim_func = parse::ParsePythonCode(fn);
  auto ret = fg->NewCNode({NewValueNode(prim_func), input});
  fg->set_output(ret);
  return fg;
}
2195
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)2196 FuncGraphPtr ListFunc::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2197 if (args_abs_list.size() > 1) {
2198 MS_LOG(EXCEPTION) << "For 'ListFunc', the number of input should be 0 or 1, but got " << args_abs_list.size();
2199 }
2200 auto fg = std::make_shared<FuncGraph>();
2201 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
2202 if (args_abs_list.size() == 0) {
2203 auto ret = fg->NewCNode({NewValueNode(prim::kPrimMakeList)});
2204 fg->set_output(ret);
2205 return fg;
2206 }
2207
2208 auto input_abs = args_abs_list[0];
2209 MS_EXCEPTION_IF_NULL(input_abs);
2210 auto input = fg->add_parameter();
2211 if (fallback::ContainsSequenceAnyType(input_abs)) {
2212 AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2213 AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2214 std::stringstream script_buffer;
2215 script_buffer << "list(";
2216 const std::string data_str = "__data__";
2217 script_buffer << data_str << ")";
2218 (void)local_key_inputs.emplace_back(NewValueNode(data_str));
2219 (void)local_value_inputs.emplace_back(input);
2220 const auto &script = script_buffer.str();
2221 auto local_key_node = fg->NewCNode(local_key_inputs);
2222 auto local_value_node = fg->NewCNode(local_value_inputs);
2223 auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
2224 auto ret = fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node);
2225 fg->set_output(ret);
2226 return fg;
2227 } else if (input_abs->isa<abstract::AbstractList>()) {
2228 fg->set_output(input);
2229 return fg;
2230 } else if (input_abs->isa<abstract::AbstractTuple>()) {
2231 // tuple to list
2232 if (fallback::SequenceAllElementsIsScalar(input_abs)) {
2233 auto prim = std::make_shared<Primitive>("TupleToList");
2234 auto tuple_to_list = fg->NewCNode({NewValueNode(prim), input});
2235 fg->set_output(tuple_to_list);
2236 return fg;
2237 }
2238 }
2239 const std::string module = "mindspore._extends.parse.standard_method";
2240 const std::string func_name = "list_func";
2241 py::function fn = python_adapter::GetPyFn(module, func_name);
2242 auto prim_func = parse::ParsePythonCode(fn);
2243 auto ret = fg->NewCNode({NewValueNode(prim_func), input});
2244 fg->set_output(ret);
2245 return fg;
2246 }
2247 } // namespace prim
2248 } // namespace mindspore
2249