/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/operator/composite/composite.h"
#include <algorithm>
#include <tuple>
#include <regex>
#include "ops/structure_ops.h"
#include "ops/sequence_ops.h"
#include "ops/framework_ops.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "abstract/abstract_value.h"
#include "abstract/abstract_function.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "frontend/operator/cc_implementations.h"
#include "frontend/optimizer/opt.h"
#include "utils/symbolic.h"
#include "include/common/fallback.h"
#include "include/common/pybind_api/api_register.h"
#include "ir/signature.h"
#include "pipeline/jit/ps/fallback.h"
#include "pipeline/jit/ps/debug/trace.h"
#include "utils/interpret_node_recorder.h"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
#include "pipeline/jit/ps/parse/resolve.h"

namespace mindspore {
// Namespace to support composite operator definitions.
namespace prim {
constexpr auto kStepDefault = 1;

using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractEllipsis;
using mindspore::abstract::AbstractEllipsisPtr;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractFunctionPtr;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractNone;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSequence;
using mindspore::abstract::AbstractSequencePtr;
using mindspore::abstract::AbstractSlice;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
using mindspore::abstract::AbstractUndetermined;
using mindspore::abstract::EnvSetSparseResultMgr;
using mindspore::abstract::FuncGraphAbstractClosure;
using mindspore::abstract::PartialAbstractClosure;

void HyperMap::Init() {
  if (fn_leaf_) {
    name_ = "hyper_map[" + fn_leaf_->name() + "]";
  }
  signatures_ =
    // def hypermap(func:read, *args:ref):
    std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                            {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}

HyperMap::HyperMap(bool reverse, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
    : MetaFuncGraph("hyper_map"),
      fn_leaf_(fn_leaf),
      reverse_(reverse),
      nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeDictionary}) {
  Init();
}

HyperMap::HyperMap(const HyperMap &h)
    : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), reverse_(h.reverse_), nonleaf_(h.nonleaf_) {
  Init();
}

void HyperMap::SetObjectForFnLeaf(const py::object &leaf_object) {
  if (fn_leaf_ != nullptr) {
    fn_leaf_->set_meta_obj(leaf_object);
  }
}

AnfNodePtr HyperMap::FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                              const ArgsPairList &arg_map) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> inputs;
  if (fn_arg != nullptr) {
    inputs.push_back(fn_arg);
  } else {
    inputs.push_back(NewValueNode(fn_leaf_));
  }

  (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
                       [](const std::pair<AnfNodePtr, Any> &item) { return item.first; });
  return func_graph->NewCNodeInOrder(inputs);
}
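
// Note: in this leaf case the generated node is a direct call fn(arg1, ..., argn),
// where fn is either the explicitly passed fn_arg or the registered
// MultitypeFuncGraph fn_leaf_; e.g. for two leaf arguments x and y the node
// built above is simply fn(x, y).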

std::pair<std::string, std::string> HyperMap::GetHyperMapInputIndex(size_t num) const {
  std::string error_index;
  std::string next_index;
  const size_t first_index = 1;
  const size_t second_index = 2;
  if (num == first_index) {
    // The first element in HyperMap is func_graph.
    error_index = "first";
    next_index = "second";
  } else if (num == second_index) {
    error_index = "second";
    next_index = "third";
  } else {
    error_index = std::to_string(num) + "th";
    next_index = std::to_string(num + 1) + "th";
  }
  return std::pair<std::string, std::string>(error_index, next_index);
}

AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  size_t size = type->elements().size();
  size_t num = 0;
  std::ostringstream oss;
  bool is_not_same = false;
  for (auto &item : arg_map) {
    num++;
    // Use a checked cast so that the type-mismatch branch below can actually fire.
    auto lhs = dyn_cast<List>(item.second);
    auto [error_index, next_index] = GetHyperMapInputIndex(num);
    if (lhs == nullptr) {
      MS_LOG(EXCEPTION) << "The " << error_index << " element in HyperMap has wrong type, expected a List, but got "
                        << item.second->ToString() << ".";
    }
    if (lhs->elements().size() != size) {
      oss << "\nThe length of the " << error_index << " element in HyperMap is " << size << ", but the length of the "
          << next_index << " element in HyperMap is " << lhs->elements().size() << ".\n";
      is_not_same = true;
      break;
    }
  }
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "The lists in HyperMap should have the same length. " << oss.str();
  }

  // Cannot use shared_from_base() (i.e. `this`), as it would create a reference cycle
  // between the hypermap and the generated graph, causing a memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  constexpr size_t kPrimHoldLen = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + kPrimHoldLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeList));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeList for the " << i << "th element of the target, reverse_: " << reverse_;
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }
    size_t pos = (reverse_ ? (size - 1 - i) : i);
    (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
                         [&func_graph, pos](const std::pair<AnfNodePtr, Any> &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });

    auto call_node = func_graph->NewCNodeInOrder(inputs2);
    if (reverse_) {
      (void)inputs.insert(inputs.cbegin() + 1, call_node);
    } else {
      inputs.emplace_back(call_node);
    }
  }
  return func_graph->NewCNodeInOrder(inputs);
}
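
// Conceptually, the list case above expands hyper_map(fn, [a, b], [c, d]) into
// [fn(a, c), fn(b, d)], with each per-element call going through fn_rec so that
// nested sequences recurse. When reverse_ is set, elements are visited from the
// back, and inserting at inputs.cbegin() + 1 keeps the result in source order.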

AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  size_t size = type->elements().size();
  size_t num = 0;
  std::ostringstream oss;
  bool is_not_same = false;
  for (auto &item : arg_map) {
    num++;
    // Use a checked cast so that the type-mismatch branch below can actually fire.
    auto lhs = dyn_cast<Tuple>(item.second);
    auto [error_index, next_index] = GetHyperMapInputIndex(num);
    if (lhs == nullptr) {
      MS_LOG(EXCEPTION) << "The " << error_index << " element in HyperMap has wrong type, expected a Tuple, but got "
                        << item.second->ToString() << ".";
    }
    if (lhs->elements().size() != size) {
      oss << "\nThe length of the " << error_index << " element in HyperMap is " << size << ", but the length of the "
          << next_index << " element in HyperMap is " << lhs->elements().size() << ".\n";
      is_not_same = true;
      break;
    }
  }
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "The lengths of tuples in HyperMap must be the same. " << oss.str();
  }

  // Cannot use shared_from_base() (i.e. `this`), as it would create a reference cycle
  // between the hypermap and the generated graph, causing a memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  constexpr size_t kPrimHoldLen = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + kPrimHoldLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeTuple));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th element of the target, reverse_: " << reverse_;
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }
    size_t pos = (reverse_ ? (size - 1 - i) : i);
    (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
                         [&func_graph, &pos](const std::pair<AnfNodePtr, Any> &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });

    auto call_node = func_graph->NewCNodeInOrder(inputs2);
    if (reverse_) {
      (void)inputs.insert(inputs.cbegin() + 1, call_node);
    } else {
      inputs.emplace_back(call_node);
    }
  }

  if (inputs.size() > 1) {
    return func_graph->NewCNodeInOrder(inputs);
  }
  // Empty tuple.
  auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
  auto empty_tuple = NewValueNode(empty_tuple_value);
  return empty_tuple;
}

AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Dictionary> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  size_t size = type->key_values().size();
  size_t num = 0;
  std::ostringstream oss;
  bool is_not_same = false;
  for (auto &item : arg_map) {
    num++;
    // Use a checked cast so that the type-mismatch branch below can actually fire.
    auto lhs = dyn_cast<Dictionary>(item.second);
    auto [error_index, next_index] = GetHyperMapInputIndex(num);
    if (lhs == nullptr) {
      MS_LOG(EXCEPTION) << "The " << error_index
                        << " element in HyperMap has wrong type, expected a Dictionary, but got "
                        << item.second->ToString() << ".";
    }
    if (lhs->key_values().size() != size) {
      oss << "\nThe length of the " << error_index << " element in HyperMap is " << size << ", but the length of the "
          << next_index << " element in HyperMap is " << lhs->key_values().size() << ".\n";
      is_not_same = true;
      break;
    }
  }
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "The lengths of dicts in HyperMap must be the same. " << oss.str();
  }

  // Cannot use shared_from_base() (i.e. `this`), as it would create a reference cycle
  // between the hypermap and the generated graph, causing a memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  std::vector<AnfNodePtr> key_inputs{NewValueNode(prim::kPrimMakeTuple)};
  std::vector<AnfNodePtr> value_inputs{NewValueNode(prim::kPrimMakeTuple)};

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeDict for the " << i << "th element of the target.";
    auto key = type->key_values()[i].first;
    (void)key_inputs.emplace_back(NewValueNode(key));
    std::vector<AnfNodePtr> inputs;
    (void)inputs.emplace_back(fn_rec);
    if (fn_arg != nullptr) {
      (void)inputs.emplace_back(fn_arg);
    }
    (void)std::transform(
      arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
      [&func_graph, &key](const std::pair<AnfNodePtr, TypePtr> &item) {
        return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimDictGetItem), item.first, NewValueNode(key)});
      });
    auto call_node = func_graph->NewCNodeInOrder(inputs);
    (void)value_inputs.emplace_back(call_node);
  }
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeDict), func_graph->NewCNodeInOrder(key_inputs),
                                 func_graph->NewCNodeInOrder(value_inputs)};
  return func_graph->NewCNodeInOrder(inputs);
}
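
// Conceptually, the dictionary case above expands hyper_map(fn, {k: v1}, {k: v2})
// into make_dict(make_tuple(k, ...), make_tuple(fn(v1, v2), ...)): the keys come
// from the target Dictionary type, and each value is mapped recursively.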

AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const {
  bool is_leaf = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  for (auto &item : arg_map) {
    pair = item;
    id = item.second->type_id();
    // The graph building reaches the leaf situation when there exists a type that cannot be divided any further.
    if (nonleaf_.count(id) == 0) {
      is_leaf = true;
      break;
    }
  }

  if (!is_leaf) {
    // In a nonleaf situation, all arguments must have the same generic type.
    bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
      if (item.first != pair.first) {
        return item.second->type_id() != pair.second->type_id();
      }
      return false;
    });
    if (is_not_same) {
      std::ostringstream oss;
      oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfoStr(func_graph->debug_info()) << "\n";
      int64_t idx = 0;
      std::string str_index = "first";
      const int64_t diff_index = 2;
      for (auto &item : arg_map) {
        // The first element in HyperMap is func_graph.
        if (idx == 0) {
          str_index = "second";
        } else if (idx == 1) {
          str_index = "third";
        } else {
          str_index = std::to_string(idx + diff_index) + "th";
        }
        ++idx;
        oss << "The type of the " << str_index << " argument in HyperMap is " << item.second->ToString() << ".\n";
      }
      MS_LOG(EXCEPTION) << "In a nonleaf situation, the types of arguments in HyperMap must be consistent, "
                        << "but the types of arguments are inconsistent.\n"
                        << oss.str();
    }
  }

  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeDictionary: {
      auto type = std::static_pointer_cast<Dictionary>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    default:
      return FullMake(func_graph, fn_arg, arg_map);
  }
}
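
// Dispatch summary for Make: if any argument's type is not a List, Tuple or
// Dictionary, the call is a leaf and the overload above emits a direct call;
// otherwise all arguments must share the same sequence type and the matching
// typed FullMake overload recurses element by element.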

FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_abs_list) {
  FuncGraphPtr res_fg = std::make_shared<FuncGraph>();
  res_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  res_fg->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
  res_fg->debug_info()->set_name("hyper_map");

  AnfNodePtr fn_param = nullptr;
  std::size_t i = 0;
  ArgsPairList argmap;
  if (fn_leaf_ == nullptr) {
    fn_param = res_fg->add_parameter();
    i = 1;
  }

  std::size_t size = args_abs_list.size();
  for (; i < size; ++i) {
    argmap.push_back(std::make_pair(res_fg->add_parameter(), args_abs_list[i]));
  }

  res_fg->set_output(Make(res_fg, fn_param, argmap));
  return res_fg;
}

abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_abs_list) const {
  if (fn_leaf_ == nullptr) {
    if (args_abs_list.empty()) {
      MS_LOG(EXCEPTION) << "The arguments list should not be empty, but got 0 arguments.";
    }
    MS_EXCEPTION_IF_NULL(args_abs_list[0]);
    // Assert that hypermap's function param does not contain free variables.
    if (args_abs_list[0]->isa<FuncGraphAbstractClosure>()) {
      auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_abs_list[0]);
      auto func_graph = graph_func->func_graph();
      if (func_graph->parent() != nullptr) {
        MS_LOG(EXCEPTION) << "HyperMap does not support closures with free variables yet.";
      }
    }
  }

  AbstractBasePtrList broadened;
  (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(broadened),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         return arg->Broaden();
                       });
  return broadened;
}
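
// NormalizeArgs broadens each argument's abstract (widening concrete values to
// their types), so the generated hyper_map graph is keyed on argument structure
// rather than on particular values.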

FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  int64_t tuple_size = SizeToLong(args_abs_list.size());

  std::ostringstream ss;
  // ▶make_tuple_
  ss << "\u25B8make_tuple_" << tuple_size;
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->debug_info()->set_name(ss.str());

  std::vector<AnfNodePtr> params;
  params.push_back(NewValueNode(prim::kPrimMakeTuple));
  for (int64_t i = 0; i < tuple_size; ++i) {
    params.push_back(fg->add_parameter());
  }

  // Make fprop's first result: make_tuple's forward result.
  AnfNodePtr out = fg->NewCNodeInOrder(params);

  // Make fprop's second result: make_tuple's backward function.
  FuncGraphPtr bprop = std::make_shared<FuncGraph>();

  ss.str(std::string());
  ss.clear();
  // ◀make_tuple_
  ss << "\u25C2make_tuple_" << tuple_size;
  bprop->debug_info()->set_name(ss.str());
  AnfNodePtr dout = bprop->add_parameter();

  std::vector<AnfNodePtr> grads;
  grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  grads.push_back(NewEnviron(bprop));
  for (int64_t i = 0; i < tuple_size; ++i) {
    grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
  }

  bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop->set_output(bprop->NewCNodeInOrder(grads));

  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
  return fg;
}
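
// The graph built above follows the fprop convention used throughout this file:
// fg(x1, ..., xn) returns (make_tuple(x1, ..., xn), bprop), where bprop(dout)
// yields (Environ, dout[0], ..., dout[n-1]), i.e. the gradient of each tuple
// element is the corresponding element of the incoming dout.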

FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  int64_t list_size = SizeToLong(args_abs_list.size());

  std::ostringstream ss;
  // ▶make_list_
  ss << "\u25B8make_list_" << list_size;
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->debug_info()->set_name(ss.str());

  std::vector<AnfNodePtr> params;
  params.push_back(NewValueNode(prim::kPrimMakeList));
  for (int64_t i = 0; i < list_size; ++i) {
    params.push_back(fg->add_parameter());
  }

  // Make fprop's first result: make_list's forward result.
  AnfNodePtr out = fg->NewCNodeInOrder(params);

  // Make fprop's second result: make_list's backward function.
  FuncGraphPtr bprop = std::make_shared<FuncGraph>();

  ss.str(std::string());
  ss.clear();
  // ◀make_list_
  ss << "\u25C2make_list_" << list_size;
  bprop->debug_info()->set_name(ss.str());
  AnfNodePtr dout = bprop->add_parameter();

  std::vector<AnfNodePtr> grads;
  grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  grads.push_back(NewEnviron(bprop));
  for (int64_t i = 0; i < list_size; ++i) {
    grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
  }

  bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop->set_output(bprop->NewCNodeInOrder(grads));

  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
  return fg;
}

FuncGraphPtr MakeDictGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  constexpr size_t input_size = 2;
  CheckArgsSize("MakeDict", args_abs_list, input_size);
  std::ostringstream ss;
  // ▶make_dict_
  ss << "\u25B8make_dict_" << input_size;
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->debug_info()->set_name(ss.str());

  std::vector<AnfNodePtr> params{NewValueNode(prim::kPrimMakeDict)};
  for (size_t i = 0; i < input_size; ++i) {
    (void)params.emplace_back(fg->add_parameter());
  }

  // Make fprop's first result: make_dict's forward result.
  AnfNodePtr out = fg->NewCNodeInOrder(params);

  // Make fprop's second result: make_dict's backward function.
  FuncGraphPtr bprop = std::make_shared<FuncGraph>();

  ss.str(std::string());
  ss.clear();
  // ◀make_dict_
  ss << "\u25C2make_dict_" << input_size;
  bprop->debug_info()->set_name(ss.str());
  AnfNodePtr dout = bprop->add_parameter();

  std::vector<AnfNodePtr> grads{NewValueNode(prim::kPrimMakeTuple)};
  (void)grads.emplace_back(NewEnviron(bprop));

  auto abs0_tuple = dyn_cast_ptr<AbstractTuple>(args_abs_list[0]);
  if (abs0_tuple == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "The first input of make_dict should be a tuple, but got abstract: "
                               << args_abs_list[0]->ToString();
  }
  // Add gradients of the keys tuple and the values tuple.
  std::vector<AnfNodePtr> keys_grads_inputs{NewValueNode(kPrimMakeTuple)};
  std::vector<AnfNodePtr> values_grads_inputs{NewValueNode(kPrimMakeTuple)};
  for (size_t i = 0; i < abs0_tuple->size(); ++i) {
    auto key_item =
      bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[1], NewValueNode(SizeToLong(i))});
    (void)keys_grads_inputs.emplace_back(key_item);
    (void)values_grads_inputs.emplace_back(
      bprop->NewCNodeInOrder({NewValueNode(prim::kPrimDictGetItem), dout, key_item}));
  }
  (void)grads.emplace_back(bprop->NewCNodeInOrder(keys_grads_inputs));
  (void)grads.emplace_back(bprop->NewCNodeInOrder(values_grads_inputs));

  bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop->set_output(bprop->NewCNodeInOrder(grads));

  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeDict));
  return fg;
}

FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  int64_t args_size = SizeToLong(args_abs_list.size());
  constexpr auto py_execute_grad_input_count = 3;
  if (args_size < py_execute_grad_input_count) {
    MS_LOG(INTERNAL_EXCEPTION) << "The inputs size of PyExecuteGradient should not be less than "
                               << py_execute_grad_input_count;
  }

  std::ostringstream ss;
  // ▶PyExecute
  ss << "\u25B8PyExecute_" << args_size;
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->debug_info()->set_name(ss.str());

  std::vector<AnfNodePtr> params;
  (void)params.emplace_back(NewValueNode(prim::kPrimPyExecute));
  for (int64_t i = 0; i < args_size; ++i) {
    (void)params.emplace_back(fg->add_parameter());
  }

  // Make fprop's first result: PyExecute's forward result.
  AnfNodePtr out = fg->NewCNodeInOrder(params);
  InterpretNodeRecorder::GetInstance().PushPyExecuteNode(out);

  // Make fprop's second result: PyExecute's backward function.
  FuncGraphPtr bprop = std::make_shared<FuncGraph>();

  ss.str(std::string());
  ss.clear();
  // ◀PyExecute
  ss << "\u25C2PyExecute_" << args_size;
  bprop->debug_info()->set_name(ss.str());
  (void)bprop->add_parameter();

  std::vector<AnfNodePtr> grads;
  (void)grads.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  (void)grads.emplace_back(NewEnviron(bprop));
  // Propagate for the script string.
  (void)grads.emplace_back(params[1]);
  // Propagate for the local dict keys.
  const auto &local_key_args = dyn_cast<abstract::AbstractTuple>(args_abs_list[1]);
  MS_EXCEPTION_IF_NULL(local_key_args);
  std::vector<AnfNodePtr> keys;
  (void)keys.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t i = 0; i < local_key_args->size(); ++i) {
    constexpr auto keys_num = 2;
    const auto &key_item =
      bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[keys_num], NewValueNode(SizeToLong(i))});
    const auto &element = local_key_args->elements()[i];
    const auto &str_element = dyn_cast<abstract::AbstractScalar>(element);
    if (str_element != nullptr && str_element->BuildType()->isa<String>()) {
      (void)keys.emplace_back(key_item);
    } else {
      (void)keys.emplace_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), key_item}));
    }
  }
  (void)grads.emplace_back(bprop->NewCNodeInOrder(keys));
  // Propagate for the local dict values.
  constexpr auto values_arg_num = 2;
  const auto &local_value_args = dyn_cast<abstract::AbstractTuple>(args_abs_list[values_arg_num]);
  MS_EXCEPTION_IF_NULL(local_value_args);
  std::vector<AnfNodePtr> values;
  (void)values.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t i = 0; i < local_value_args->size(); ++i) {
    constexpr auto values_num = 3;
    const auto &value_item =
      bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[values_num], NewValueNode(SizeToLong(i))});
    const auto &element = local_value_args->elements()[i];
    const auto &str_element = dyn_cast<abstract::AbstractScalar>(element);
    if (str_element != nullptr && str_element->BuildType()->isa<String>()) {
      (void)values.emplace_back(value_item);
    } else {
      (void)values.emplace_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), value_item}));
    }
  }
  (void)grads.emplace_back(bprop->NewCNodeInOrder(values));

  // Add gradients for extra monads.
  for (size_t i = py_execute_grad_input_count; i < args_abs_list.size(); ++i) {
    if (args_abs_list[i]->isa<abstract::AbstractUMonad>()) {
      (void)grads.emplace_back(NewValueNode(kUMonad));
    } else if (args_abs_list[i]->isa<abstract::AbstractIOMonad>()) {
      (void)grads.emplace_back(NewValueNode(kIOMonad));
    } else {
      (void)grads.emplace_back(NewValueNode(kValueAny));
    }
  }

  bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop->set_output(bprop->NewCNodeInOrder(grads));

  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimPyExecute));
  return fg;
}

FuncGraphPtr MutableGradient::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  constexpr size_t min_input_size = 1;
  constexpr size_t max_input_size = 2;
  auto input_size = args_abs_list.size();
  if (input_size != min_input_size && input_size != max_input_size) {
    MS_LOG(EXCEPTION) << "The number of inputs to mutable must be " << min_input_size << " or " << max_input_size
                      << ", but got: " << input_size;
  }
  std::ostringstream ss;
  // ▶mutable_
  ss << "\u25B8mutable_" << input_size;
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->debug_info()->set_name(ss.str());

  std::vector<AnfNodePtr> params;
  params.push_back(NewValueNode(prim::kPrimMutable));
  for (size_t i = 0; i < input_size; ++i) {
    params.push_back(fg->add_parameter());
  }

  // Make fprop's first result: mutable's forward result.
  AnfNodePtr out = fg->NewCNodeInOrder(params);

  // Make fprop's second result: mutable's backward function.
  FuncGraphPtr bprop = std::make_shared<FuncGraph>();

  ss.str(std::string());
  ss.clear();
  // ◀mutable_
  ss << "\u25C2mutable_" << input_size;
  bprop->debug_info()->set_name(ss.str());
  AnfNodePtr dout = bprop->add_parameter();

  std::vector<AnfNodePtr> grads;
  grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  grads.push_back(NewEnviron(bprop));
  grads.push_back(dout);
  if (input_size == max_input_size) {
    grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), params[2]}));
  }

  bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  bprop->set_output(bprop->NewCNodeInOrder(grads));

  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMutable));
  return fg;
}

namespace {
bool IsTupleAllTensor(const AbstractTuplePtr &tuple_arg) {
  MS_EXCEPTION_IF_NULL(tuple_arg);
  for (size_t i = 0; i < tuple_arg->size(); ++i) {
    if (!(*tuple_arg)[i]->isa<AbstractUndetermined>() &&
        !((*tuple_arg)[i]->isa<AbstractTuple>() && IsTupleAllTensor((*tuple_arg)[i]->cast<AbstractTuplePtr>()))) {
      return false;
    }
  }
  return true;
}

bool EnableGradFirstForTuple(const AbstractTuplePtr &tuple_arg, bool enable_tuple_grad) {
  return tuple_arg->size() > 1 && (*tuple_arg)[1]->isa<AbstractTuple>() && enable_tuple_grad &&
         IsTupleAllTensor((*tuple_arg)[1]->cast<AbstractTuplePtr>());
}

bool EnableGradForScalar(const AbstractBasePtr &abs) {
  return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
         abs->BuildType()->isa<Number>();
}

bool CanGradArgument(const AbstractTuplePtr &tuple_arg, size_t pos) {
  MS_EXCEPTION_IF_NULL(tuple_arg);
  return tuple_arg->size() > pos && (*tuple_arg)[pos] != nullptr &&
         ((*tuple_arg)[pos]->BuildValue()->ContainsValueAny() || EnableGradForScalar((*tuple_arg)[pos]));
}

void GenerateFuncGraphByPosition(const FuncGraphPtr &fg, const AbstractTuplePtr &tuple_arg, const AbstractTuplePtr &pos,
                                 bool return_ids = false) {
  if (pos == nullptr) {
    MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!";
  }
  if (pos->empty()) {
    MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position.";
  }
  AnfNodePtr tuple_parameter = fg->add_parameter();
  (void)fg->add_parameter();  // The 'grad_position' parameter.
  // Collect all parameters by 'grad_position'.
  std::vector<AnfNodePtr> pos_elements = {NewValueNode(prim::kPrimMakeTuple)};
  CNodePtr current_element = nullptr;
  for (size_t i = 0; i < pos->size(); ++i) {
    auto val = pos->elements()[i]->BuildValue();
    MS_EXCEPTION_IF_NULL(val);
    auto int_val = LongToSize(dyn_cast<Int64Imm>(val)->value());
    ++int_val;  // Ignore the env position.
    if (int_val >= tuple_arg->size()) {
      MS_EXCEPTION(IndexError) << "Position index " << (int_val - 1) << " exceeds the input size.";
    }
    if (!CanGradArgument(tuple_arg, int_val)) {
      continue;
    }
    current_element =
      fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(int_val))});
    if (return_ids) {
      current_element =
        fg->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), NewValueNode(SizeToLong(int_val) - 1), current_element});
    }
    pos_elements.push_back(current_element);
  }

  // The returned result may vary with the number of gradient result elements:
  // a single value for one result, a tuple for multiple results, or an empty tuple for no result.
  //
  // Notice that even if the user set 'grad_position' as multiple choices,
  // 'CanGradArgument' may reduce it to only one choice or no choice at all.
  constexpr size_t args_least_size = 2;
  if (pos_elements.size() == args_least_size) {
    fg->set_output(current_element);
  } else if (pos_elements.size() > args_least_size) {
    fg->set_output(fg->NewCNodeInOrder(pos_elements));
  } else {  // No gradable position was collected; return an empty tuple.
    auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
    auto empty_tuple = NewValueNode(empty_tuple_value);
    fg->set_output(empty_tuple);
  }
}
}  // namespace

FuncGraphPtr Tail::GenerateTailFuncGraph(const AbstractSequencePtr &sequence_arg) const {
  MS_EXCEPTION_IF_NULL(sequence_arg);
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->debug_info()->set_name("tail");

  AnfNodePtr tuple_parameter = fg->add_parameter();
  std::vector<AnfNodePtr> elements;
  PrimitivePtr op = nullptr;
  if (sequence_arg->isa<AbstractTuple>()) {
    (void)elements.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    op = prim::kPrimTupleGetItem;
  } else {
    (void)elements.emplace_back(NewValueNode(prim::kPrimMakeList));
    op = prim::kPrimListGetItem;
  }

  // Remove the first element to make a new sequence.
  for (size_t i = 1; i < sequence_arg->size(); ++i) {
    elements.push_back(fg->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
  }
  if (elements.size() > 1) {
    fg->set_output(fg->NewCNodeInOrder(elements));
    return fg;
  }

  // No element left, return an empty tuple.
  if (sequence_arg->isa<AbstractTuple>()) {
    fg->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
    return fg;
  }
  // No element left, return an empty list.
  fg->set_output(NewValueNode(std::make_shared<ValueList>(ValuePtrList())));
  return fg;
}
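
// For example, tail of the tuple (a, b, c) is (b, c), and tail of a
// single-element sequence is the corresponding empty tuple or empty list.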

FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, const AbstractTuplePtr &position) const {
  MS_EXCEPTION_IF_NULL(tuple_arg);
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->debug_info()->set_name("grad_tail");

  if (tail_type_ == kGradFirst) {
    AnfNodePtr tuple_parameter = fg->add_parameter();
    if (CanGradArgument(tuple_arg, 1) || EnableGradFirstForTuple(tuple_arg, enable_tuple_grad_first_)) {
      fg->set_output(
        fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(1))}));
    } else {
      fg->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
    }
    return fg;
  }

  if (tail_type_ == kGradByPosition) {
    GenerateFuncGraphByPosition(fg, tuple_arg, position, return_ids_);
    return fg;
  }

  if (tail_type_ == kGradAll) {
    AnfNodePtr tuple_parameter = fg->add_parameter();
    std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
    for (size_t i = 1; i < tuple_arg->size(); ++i) {
      MS_EXCEPTION_IF_NULL((*tuple_arg)[i]);
      if (CanGradArgument(tuple_arg, i)) {
        elements.push_back(
          fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))}));
      }
    }

    // We should deal with 'get_all=True' like the other options later:
    // "The returned result may vary with the number of gradient result elements:
    // a single value for one result, a tuple for multiple results, or an empty tuple for no result.
    //
    // Notice that even if the user sets 'get_all=True' and passes multiple inputs,
    // 'CanGradArgument' may reduce it to only one gradient output or none."
    constexpr size_t args_least_size = 2;
    if (elements.size() >= args_least_size) {
      fg->set_output(fg->NewCNodeInOrder(elements));
      return fg;
    }
    // Empty tuple.
    auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
    auto empty_tuple = NewValueNode(empty_tuple_value);
    fg->set_output(empty_tuple);
    return fg;
  }
  MS_LOG(INTERNAL_EXCEPTION) << "'tail_type_' is not for GradOperation, but " << tail_type_;
}

FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  // To handle the normal tail.
  if (args_abs_list.size() < 1) {
    MS_LOG(EXCEPTION) << "'Tail' requires at least 1 argument, but got " << args_abs_list.size();
  }
  if (tail_type_ >= kNotGrad) {
    AbstractSequencePtr sequence_arg = dyn_cast<AbstractSequence>(args_abs_list[0]);
    if (sequence_arg == nullptr) {
      MS_LOG(EXCEPTION) << "'Tail' arg0 must be a tuple or list, but got " << args_abs_list[0]->ToString();
    }
    return GenerateTailFuncGraph(sequence_arg);
  }

  // To handle tail for GradOperation.
  constexpr size_t args_max_size = 2;
  if (args_abs_list.size() > args_max_size) {
    MS_LOG(EXCEPTION) << "'Tail' requires at most 2 arguments for GradOperation, but got " << args_abs_list.size();
  }
  AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_abs_list[0]);
  if (tuple_arg == nullptr) {
    MS_LOG(EXCEPTION) << "'Tail' arg0 must be a tuple, but got " << args_abs_list[0]->ToString();
  }
  if (args_abs_list.size() == args_max_size) {
    AbstractTuplePtr pos = dyn_cast<AbstractTuple>(args_abs_list[1]);
    if (pos == nullptr) {
      MS_LOG(EXCEPTION) << "'Tail' arg1 'position' must be a tuple, but got " << args_abs_list[1]->ToString();
    }
    return GenerateGradFuncGraph(tuple_arg, pos);
  }
  return GenerateGradFuncGraph(tuple_arg);
}
namespace {
AnfNodePtr CreateGradOutputs(const FuncGraphPtr &k_child, const AnfNodePtr &gradient, const AnfNodePtr &f_app,
                             bool has_aux, bool get_value) {
  if (get_value) {
    return k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), f_app, gradient});
  }
  if (!has_aux) {
    return gradient;
  }
  PrimitivePtr get_tuple_item_op = prim::kPrimTupleGetItem;
  PrimitivePtr make_tuple_op = prim::kPrimMakeTuple;
  std::vector<AnfNodePtr> elements = {NewValueNode(make_tuple_op)};
  (void)elements.emplace_back(
    k_child->NewCNodeInOrder({NewValueNode(get_tuple_item_op), f_app, NewValueNode(static_cast<int64_t>(1))}));
  auto aux_output = k_child->NewCNodeInOrder(elements);
  auto unpack_node =
    k_child->NewCNodeInOrder({NewValueNode(get_tuple_item_op), aux_output, NewValueNode(static_cast<int64_t>(0))});
  return k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), gradient, unpack_node});
}
}  // namespace

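// Output layout produced by CreateGradOutputs above: with get_value the result
// is (forward_value, grads); otherwise, with has_aux it is (grads, aux_output);
// otherwise it is just grads.
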
// When has_aux is set True, for out1, out2, out3 = fn(inputs), only the first output out1
// contributes to the differentiation of fn.
FuncGraphPtr GradAux::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
  AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_abs_list[0]);
  if (tuple_arg == nullptr) {
    MS_LOG(EXCEPTION) << "When has_aux is True, origin fn requires more than one output.\n"
                      << "'GradAux' arg0 must be a tuple, but got " << args_abs_list[0]->ToString();
  }
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  AnfNodePtr tuple_parameter = fg->add_parameter();
  // The get_value flag.
  (void)fg->add_parameter();

  AbstractScalarPtr get_value_ptr = dyn_cast<AbstractScalar>(args_abs_list[1]);
  MS_EXCEPTION_IF_NULL(get_value_ptr);
  bool get_value_flag = GetValue<bool>(get_value_ptr->BuildValue());
  std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
  elements.push_back(
    fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(0))}));
  if (get_value_flag) {
    for (size_t i = 1; i < tuple_arg->size(); i++) {
      auto aux_node =
        fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))});
      auto stop_gradient_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimStopGradient), aux_node});
      elements.push_back(stop_gradient_node);
    }
  } else {
    std::vector<AnfNodePtr> aux_elements = {NewValueNode(prim::kPrimMakeTuple)};
    for (size_t i = 1; i < tuple_arg->size(); i++) {
      auto aux_node =
        fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))});
      auto stop_gradient_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimStopGradient), aux_node});
      aux_elements.push_back(stop_gradient_node);
    }
    elements.push_back(fg->NewCNodeInOrder(aux_elements));
  }

  constexpr size_t args_least_size = 2;
  if (elements.size() < args_least_size) {
    MS_LOG(EXCEPTION) << "When has_aux is True, origin fn requires more than one output, but got " << elements.size()
                      << " outputs.\n"
                      << trace::GetDebugInfoStr(fg->debug_info());
  }
  fg->set_output(fg->NewCNodeInOrder(elements));
  return fg;
}
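
// E.g. with fn returning (loss, logits) and has_aux set: when get_value is false
// the graph above yields (loss, (stop_gradient(logits),)), so only loss drives
// differentiation; when get_value is true the aux outputs are appended flat, as
// (loss, stop_gradient(logits)).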

GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param,
                             bool get_by_position, bool has_aux, bool get_value, bool return_ids, bool merge_forward)
    : MetaFuncGraph(name),
      get_all_(get_all),
      get_by_list_(get_by_list),
      sens_param_(sens_param),
      get_by_position_(get_by_position),
      has_aux_(has_aux),
      get_value_(get_value),
      return_ids_(return_ids),
      merge_forward_(merge_forward) {
  if (get_by_position) {
    signatures_ =
      // def grad(func:read, weight_list:ref, position_list:ref):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault},
                              {"position_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
  } else if (get_by_list) {
    signatures_ =
      // def grad(func:read, weight_list:ref):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
  }
}

FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
                                    const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
                                    bool is_weights_none) const {
  FuncGraphPtr k_child = std::make_shared<FuncGraph>();
  k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  k_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true);

  AnfNodePtr position_node = nullptr;
  if (position != nullptr) {
    position_node = position;
  }

  std::vector<AnfNodePtr> inputs;
  inputs.push_back(j);
  for (size_t i = 0; i < forward_graph_params.size(); ++i) {
    inputs.push_back(k_child->add_parameter());
  }
  auto k_app = k_child->NewCNodeInOrder(inputs);

  auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem);
  auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
  auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});

  GradByParameter(k_child, f_app, bprop, weights, position_node, enable_tuple_grad, is_weights_none);
  return k_child;
}
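
// The k_child graph built above computes (forward_out, bprop) = j(x1, ..., xn),
// then GradByParameter applies bprop and selects which gradients (first input,
// all inputs, weights, or positions) to return.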

CNodePtr GradOperation::SetNodeByParameter(const CNodePtr &grad, const FuncGraphPtr &fg) const {
  CNodePtr fv_bprop;
  if (!weight_value_->isa<AbstractTuple>()) {
    auto weight_ref = dyn_cast<abstract::AbstractRefTensor>(weight_value_);
    if (weight_ref != nullptr) {
      auto weight_key = weight_ref->ref_key_value()->cast<RefKeyPtr>();
      auto param_name = weight_key->value();
      fv_bprop = fg->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), NewValueNode(param_name), grad});
    } else {
      MS_LOG(INTERNAL_EXCEPTION) << "Abstract of parameter should be AbstractRefTensor, but got "
                                 << weight_value_->ToString();
    }
  } else {
    std::vector<AnfNodePtr> params;
    AbstractTuplePtr weight_tuple = weight_value_->cast<AbstractTuplePtr>();
    const AbstractBasePtrList &elements = weight_tuple->elements();
    params.push_back(NewValueNode(prim::kPrimMakeTuple));
    for (size_t i = 0; i < weight_tuple->size(); i++) {
      auto weight_ref = dyn_cast<abstract::AbstractRefTensor>(elements[i]);
      if (weight_ref != nullptr) {
        auto weight_key = weight_ref->ref_key_value()->cast<RefKeyPtr>();
        auto param_name = weight_key->value();
        auto grad_value =
          fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), grad, NewValueNode(static_cast<int64_t>(i))});
        fv_bprop = fg->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), NewValueNode(param_name), grad_value});
        params.push_back(fv_bprop);
      } else {
        MS_LOG(INTERNAL_EXCEPTION) << "Abstract of parameter should be AbstractRefTensor, but got "
                                   << weight_value_->ToString();
      }
    }
    fv_bprop = fg->NewCNodeInOrder(params);
  }
  return fv_bprop;
}

// Do grad by the parameter of GradOperation.
void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
                                    const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad,
                                    bool is_weights_none) const {
  MS_EXCEPTION_IF_NULL(k_child);

  AnfNodePtr bprop_arg = nullptr;
  if (sens_param_) {
    bprop_arg = k_child->add_parameter();
  } else {
    auto ones_like = prim::GetPythonOps("ones_like");
    bprop_arg = k_child->NewCNodeInOrder({NewValueNode(ones_like), f_app});
  }
  AnfNodePtr b_app = k_child->NewCNodeInOrder({bprop, bprop_arg});
  // Add the sens parameter flag for bound_node_.
  if (b_app->isa<CNode>() && sens_param_) {
    b_app->cast<CNodePtr>()->AddAttr("sens_param_", MakeValue(true));
  }

  CNodePtr fv_bprop = nullptr;
  if (get_by_list_) {
    if (is_weights_none) {
      fv_bprop = k_child->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple)});
    } else {
      // Python code: grads = hyper_map(F.partial(env_get, env), weights)
      AnfNodePtr env =
        k_child->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))});
      AnfNodePtr partial_env_get =
        k_child->NewCNodeInOrder({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
      MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
      fv_bprop = k_child->NewCNodeInOrder({NewValueNode(hyper_map), partial_env_get, weights});
      if (return_ids_) {
        fv_bprop = SetNodeByParameter(fv_bprop, k_child);
      }
    }
  }

  CNodePtr inputs_bprop = nullptr;
  if (get_by_position_) {
    TailPtr tail_grad_by_position = std::make_shared<Tail>("tail_grad_by_position", kGradByPosition, return_ids_);
    inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_by_position), b_app, position});
  } else if (get_all_) {
    TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
    inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app});
  }

  // Gradients wrt inputs and parameters.
  if (fv_bprop != nullptr && inputs_bprop != nullptr) {
    auto make_tuple = k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop});
    k_child->set_output(CreateGradOutputs(k_child, make_tuple, f_app, has_aux_, get_value_));
    return;
  }

  // Gradients wrt parameters.
  if (fv_bprop != nullptr) {
    k_child->set_output(CreateGradOutputs(k_child, fv_bprop, f_app, has_aux_, get_value_));
    return;
  }

  // Gradients wrt inputs.
  if (inputs_bprop != nullptr) {
    k_child->set_output(CreateGradOutputs(k_child, inputs_bprop, f_app, has_aux_, get_value_));
    return;
  }
  // Gradients wrt the first input.
  // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
  // so obtain the first input grad by setting the tail_type of Tail to kGradFirst.
  TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
  tail_grad_first->set_enable_tuple_grad_first(enable_tuple_grad);
  auto tail_grad_first_cnode = k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app});
  k_child->set_output(CreateGradOutputs(k_child, tail_grad_first_cnode, f_app, has_aux_, get_value_));
}

namespace {
// Check whether the primal func graph contains a primitive that returns a sparse result in its bprop().
void CheckPrimBpropReturnSparse(const FuncGraphPtr &primal_graph) {
  bool has_sparse_bprop_prim = false;
  (void)TopoSort(primal_graph->return_node(), SuccDeeperSimple,
                 [&has_sparse_bprop_prim](const AnfNodePtr &node) -> IncludeType {
                   MS_EXCEPTION_IF_NULL(node);
                   if (has_sparse_bprop_prim) {
                     return EXCLUDE;
                   }
                   PrimitivePtr prim = nullptr;
                   if (node->isa<CNode>()) {
                     prim = GetCNodePrimitiveWithoutDoSignature(node);
                   } else {
                     prim = GetPrimitiveWithoutDoSignature(node);
                   }
                   if (prim != nullptr) {
                     bool sparse_bprop = GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE);
                     if (sparse_bprop) {
                       MS_LOG(DEBUG) << "prim: " << prim->ToString() << " has attr 'bprop_return_sparse'";
                       has_sparse_bprop_prim = true;
                       return EXCLUDE;
                     }
                   }
                   return FOLLOW;
                 });
  if (has_sparse_bprop_prim) {
    primal_graph->set_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP, true);
    EnvSetSparseResultMgr::GetInstance().Set(true);
  }
}
}  // namespace
1189 
1190 // Generate the graph.
GenerateFuncGraph(const AbstractBasePtrList & args_abs_list)1191 FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
1192   if (args_abs_list.empty()) {
1193     MS_LOG(EXCEPTION)
1194       << "'GradOperation' requires a forward network or function as an input, while the input is empty.";
1195   }
1196 
1197   constexpr size_t fn_index = 0;
1198   auto fn_abs = args_abs_list[fn_index];
1199   constexpr size_t len_with_weight = 2;
1200   constexpr size_t weights_index = 1;
1201   if (return_ids_ && args_abs_list.size() >= len_with_weight) {
1202     weight_value_ = args_abs_list[weights_index];
1203   }
1204   MS_EXCEPTION_IF_NULL(fn_abs);
1205   if (fn_abs->isa<AbstractClass>()) {
1206     auto class_abs = dyn_cast<AbstractClass>(fn_abs);
1207     auto class_val = class_abs->BuildValue();
1208     MS_EXCEPTION_IF_NULL(class_val);
1209     auto class_obj = class_val->cast<parse::MsClassObjectPtr>();
1210     MS_EXCEPTION_IF_NULL(class_obj);
1211     auto obj_name = std::regex_replace(class_obj->name(), std::regex("MsClassObject:"), "");
1212     MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell' type "
1213                       << "object, but got an object of jit_class type: " << obj_name << ".";
1214   }
1215   AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(fn_abs);
1216   if (fn == nullptr) {
1217     MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell', but got "
1218                       << args_abs_list[0]->ToString();
1219   }
1220 
1221   auto real_fn = fn->cast_ptr<FuncGraphAbstractClosure>();
1222   if (real_fn == nullptr) {
1223     MS_LOG(EXCEPTION) << "For 'GradOperation', the first argument must be a 'Function' or 'Cell', but got "
1224                       << fn->ToString();
1225   }
1226   FuncGraphPtr forward_graph = real_fn->func_graph();
1227   MS_EXCEPTION_IF_NULL(forward_graph);
1228 
1229   if (has_aux_) {
1230     GradAuxPtr aux_fn = std::make_shared<GradAux>("aux_fn");
1231     auto output_cnode = forward_graph->output();
1232     auto aux_fn_cnode = forward_graph->NewCNodeInOrder({NewValueNode(aux_fn), output_cnode, NewValueNode(get_value_)});
1233     forward_graph->set_output(aux_fn_cnode);
1234   }
1235 
1236   forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
1237 
1238   // Check whether the primal func graph contains a primitive whose bprop() returns a sparse result.
1239   CheckPrimBpropReturnSparse(forward_graph);
1240 
1241   FuncGraphPtr grad_fg = nullptr;
1242   {
1243     TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
1244     grad_fg = std::make_shared<FuncGraph>();
1245   }
1246   auto nparam = forward_graph->parameters().size();
1247 
1248   std::ostringstream ss;
1249   ss << "grad{" << nparam << "}";
1250   grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1251   grad_fg->debug_info()->set_name(ss.str());
1252   ParameterPtr param_graph = grad_fg->add_parameter();
1253 
1254   bool is_weights_empty_or_none = false;
1255   AnfNodePtr weights = nullptr;
1256   AnfNodePtr position = nullptr;
1257   if (args_abs_list.size() > weights_index) {
1258     auto weights_abs = args_abs_list[weights_index];
1259     MS_EXCEPTION_IF_NULL(weights_abs);
1260     if (weights_abs->isa<AbstractSequence>()) {
1261       if (weights_abs->cast<AbstractSequencePtr>()->empty()) {
1262         is_weights_empty_or_none = true;
1263       }
1264     }
1265   }
1266   if (get_by_position_) {
1267     weights = grad_fg->add_parameter();
1268     position = grad_fg->add_parameter();
1269   } else if (get_by_list_) {
1270     weights = grad_fg->add_parameter();
1271     // Check if weights is None.
1272     if (!is_weights_empty_or_none && args_abs_list.size() > weights_index) {
1273       auto weights_abs = args_abs_list[weights_index];
1274       MS_EXCEPTION_IF_NULL(weights_abs);
1275       if (weights_abs->isa<AbstractNone>()) {
1276         is_weights_empty_or_none = true;
1277       }
1278     }
1279   }
1280 
1281   std::vector<AnfNodePtr> inputs;
1282   inputs.push_back(NewValueNode(prim::kPrimJ));
1283   inputs.push_back(param_graph);
1284   auto j = grad_fg->NewCNodeInOrder(inputs);
1285   if (merge_forward_) {
1286     j->set_user_data<bool>("merge_forward", std::make_shared<bool>(true));
1287   }
1288   // df is checked in GetGrad
1289   FuncGraphPtr k_child = nullptr;
1290   {
1291     TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
1292     k_child = GetGrad(j, weights, position, forward_graph->parameters(),
1293                       forward_graph->has_flag("enable_tuple_grad_first"), is_weights_empty_or_none);
1294     k_child->set_flag(FUNC_GRAPH_FLAG_ARGS_NO_EXPAND, true);
1295   }
1296   grad_fg->set_output(NewValueNode(k_child));
1297 
1298   return grad_fg;
1299 }
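// Informal shape of the generated graph (annotation; the parameter list depends on
// get_by_list_ and get_by_position_):
//   grad_fg(fn[, weights[, position]]):
//     j = J(fn)       // mark fn for automatic differentiation
//     return k_child  // k_child(inputs...) runs j and applies the bprop via GetGrad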
1300 
1301 // Generate the vmap_graph.
1302 VmapOperation::VmapOperation(const std::string &name) : MetaFuncGraph(name) {
1303   auto default_zero = std::make_shared<Int64Imm>(static_cast<int64_t>(0));
1304   signatures_ =
1305     // def vmap(func:read, in_axes:ref, out_axes:ref):
1306     std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
1307                             {"in_axes", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault, default_zero,
1308                              SignatureEnumDType::kDTypeEmptyDefaultValue},
1309                             {"out_axes", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault, default_zero,
1310                              SignatureEnumDType::kDTypeEmptyDefaultValue}});
1311 }
1312 
1313 FuncGraphPtr VmapOperation::GetVmap(const AnfNodePtr &vmap, int param_number) const {
1314   FuncGraphPtr vmap_child = std::make_shared<FuncGraph>();
1315   vmap_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1316   vmap_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true);
1317 
1318   std::vector<AnfNodePtr> inputs;
1319   inputs.push_back(vmap);
1320   for (int i = 0; i < param_number; ++i) {
1321     inputs.push_back(vmap_child->add_parameter());
1322   }
1323   auto vmap_app = vmap_child->NewCNodeInOrder(inputs);
1324   vmap_child->set_output(vmap_app);
1325 
1326   return vmap_child;
1327 }
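// Illustratively, for param_number == 2 the child graph is simply
//   vmap_child(x, y) = <vmap>(x, y)
// i.e. it just forwards its parameters to the transformed function.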
1328 
1329 namespace {
1330 bool IsAxesAllNone(const ValuePtr &axes) {
1331   MS_EXCEPTION_IF_NULL(axes);
1332   ValueSequencePtr axes_seq = dyn_cast<ValueSequence>(axes);
1333   auto axes_seq_value = axes_seq->value();
1334   if (std::all_of(axes_seq_value.begin(), axes_seq_value.end(), [](const ValuePtr &axes_value_ptr) {
1335         if (axes_value_ptr->isa<ValueSequence>()) {
1336           return IsAxesAllNone(axes_value_ptr);
1337         }
1338         if (!axes_value_ptr->isa<None>()) {
1339           return false;
1340         }
1341         return true;
1342       })) {
1343     return true;
1344   }
1345   return false;
1346 }
1347 
1348 ValuePtr CheckAxes(const AbstractBasePtr &axes_abs, bool is_in_axes = false, int nparam = 0, size_t cell_size = 0) {
1349   ValuePtr axes_value = nullptr;
1350   auto axes_name = is_in_axes ? "in_axes" : "out_axes";
1351 
1352   auto axes_abs_sequence = dyn_cast<AbstractSequence>(axes_abs);
1353   if (axes_abs_sequence != nullptr) {
1354     axes_value = axes_abs->cast<AbstractSequencePtr>()->ElementsBuildValue<ValueTuple>();
1355     MS_EXCEPTION_IF_NULL(axes_value);
1356     if (is_in_axes) {
1357       ValueSequencePtr in_axes_seq = dyn_cast<ValueSequence>(axes_value);
1358       int in_axes_size = SizeToInt(in_axes_seq->size());
1359       if (nparam != in_axes_size) {
1360         MS_LOG(EXCEPTION) << "When vmap's '" << axes_name
1361                           << "' is a tuple or list, its size must be equal to the number of arguments of 'fn': "
1362                           << nparam << ", but got size: " << in_axes_size << ".";
1363       }
1364     }
1365     bool elem_all_none = IsAxesAllNone(axes_value);
1366     if (elem_all_none && cell_size == 0) {
1367       MS_LOG(EXCEPTION) << "The '" << axes_name
1368                         << "' of 'vmap' cannot be all None while 'fn' is not a 'CellList', but got "
1369                         << axes_value->ToString() << ".";
1370     }
1371   } else {
1372     axes_value = axes_abs->BuildValue();
1373     MS_EXCEPTION_IF_NULL(axes_value);
1374     if (axes_value->isa<None>() && cell_size == 0) {
1375       MS_LOG(EXCEPTION) << "The '" << axes_name
1376                         << "' of 'vmap' cannot be a single None while 'fn' is not a 'CellList'.";
1377     } else if (!axes_value->isa<None>() && !axes_value->isa<Int64Imm>()) {
1378       MS_LOG(EXCEPTION) << "The axis in vmap's '" << axes_name << "' can only be of type Int or None, but got "
1379                         << axes_abs->ToString() << ".";
1380     }
1381   }
1382   return axes_value;
1383 }
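// Worked example (informal), for 'fn' taking nparam == 2 arguments and cell_size == 0:
//   in_axes = (0, None)    -> accepted: size matches and not all None
//   in_axes = (None, None) -> rejected: all axes are None outside the CellList case
//   in_axes = 0            -> accepted: a single Int64 axis
//   in_axes = 1.5          -> rejected: an axis can only be Int or None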
1384 
1385 DebugInfoPtr CheckVmapFunc(const AbstractBasePtr &fn_arg, int *nparam, size_t *cell_size) {
1386   DebugInfoPtr origin_graph_info = nullptr;
1387   // In the model ensembling parallel training scenario, fn is a CellList.
1388   AbstractTuplePtr cell_list = dyn_cast<AbstractTuple>(fn_arg);
1389   if (cell_list != nullptr) {
1390     *cell_size = cell_list->size();
1391     if (*cell_size <= 1) {
1392       MS_LOG(EXCEPTION) << "In the model ensembling parallel training scenario ('VmapOperation' arg0 is a 'CellList'),"
1393                         << " the size of 'CellList' must be greater than 1, but got " << *cell_size << ".";
1394     }
1395     const AbstractBasePtrList &cell_list_fns = cell_list->elements();
1396     for (auto fn_abs : cell_list_fns) {
1397       MS_EXCEPTION_IF_NULL(fn_abs);
1398       AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(fn_abs);
1399       if (fn == nullptr) {
1400         MS_LOG(EXCEPTION) << "'VmapOperation' arg0 is a 'CellList', whose elements must be 'Cell', but got "
1401                           << fn_abs->ToString() << ".";
1402       }
1403       auto partial_fn = dyn_cast<PartialAbstractClosure>(fn_abs);
1404       if (partial_fn != nullptr) {
1405         fn = partial_fn->fn();
1406       }
1407       auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
1408       if (real_fn == nullptr) {
1409         MS_LOG(EXCEPTION) << "'VmapOperation' arg0 is a 'CellList', whose element " << fn->ToString()
1410                           << " cannot be cast to 'FuncGraphAbstractClosure'.";
1411       }
1412 
1413       FuncGraphPtr orig_graph = real_fn->func_graph();
1414       MS_EXCEPTION_IF_NULL(orig_graph);
1415       orig_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
1416       int fn_nparam =
1417         SizeToInt(orig_graph->parameters().size() - (partial_fn != nullptr ? partial_fn->args().size() : 0));
1418       if (*nparam == -1) {
1419         origin_graph_info = orig_graph->debug_info();
1420         *nparam = fn_nparam;
1421       } else if (*nparam != fn_nparam) {
1422         MS_LOG(EXCEPTION) << "'VmapOperation' arg0 is a 'CellList', whose elements must take the same number of inputs.";
1423       }
1424     }
1425   } else {
1426     AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(fn_arg);
1427     if (fn == nullptr) {
1428       MS_LOG(EXCEPTION) << "'VmapOperation' arg0 must be a 'Function' or 'Cell', but got " << fn_arg->ToString() << ".";
1429     }
1430     auto partial_fn = dyn_cast<PartialAbstractClosure>(fn);
1431     if (partial_fn != nullptr) {
1432       fn = partial_fn->fn();
1433     }
1434     auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
1435     if (real_fn == nullptr) {
1436       MS_LOG(EXCEPTION) << "'VmapOperation' arg0 " << fn->ToString() << " cannot be cast to 'FuncGraphAbstractClosure'.";
1437     }
1438 
1439     FuncGraphPtr orig_graph = real_fn->func_graph();
1440     MS_EXCEPTION_IF_NULL(orig_graph);
1441     orig_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
1442     *nparam = SizeToInt(orig_graph->parameters().size() - (partial_fn != nullptr ? partial_fn->args().size() : 0));
1443     origin_graph_info = orig_graph->debug_info();
1444   }
1445   return origin_graph_info;
1446 }
1447 }  // namespace
1448 
1449 FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
1450   if (args_abs_list.empty()) {
1451     MS_LOG(EXCEPTION) << "'VmapOperation' requires a network or function as an input, while the input is empty.";
1452   }
1453 
1454   constexpr auto vmap_operation_input_num = 3;
1455   const std::string op_name = "vmap";
1456   CheckArgsSize(op_name, args_abs_list, vmap_operation_input_num);
1457 
1458   auto fn_arg = args_abs_list[0];
1459   auto in_axes_arg = args_abs_list[1];
1460   auto out_axes_arg = args_abs_list[2];
1461 
1462   int nparam = -1;
1463   size_t cell_size = 0;
1464   DebugInfoPtr origin_graph_info = CheckVmapFunc(fn_arg, &nparam, &cell_size);
1465 
1466   FuncGraphPtr vmap_fg = nullptr;
1467   {
1468     TraceGuard guard(std::make_shared<TraceVmapOperation>(origin_graph_info));
1469     vmap_fg = std::make_shared<FuncGraph>();
1470   }
1471 
1472   std::ostringstream ss;
1473   ss << "vmap{" << nparam << "}";
1474   vmap_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1475   vmap_fg->debug_info()->set_name(ss.str());
1476 
1477   // Add parameter for `fn`, `in_axes` and `out_axes` respectively.
1478   ParameterPtr param_graph = vmap_fg->add_parameter();
1479   (void)vmap_fg->add_parameter();
1480   (void)vmap_fg->add_parameter();
1481 
1482   // Validity verification of in_axes and out_axes
1483   ValuePtr in_axes = CheckAxes(in_axes_arg, true, nparam, cell_size);
1484   ValuePtr out_axes = CheckAxes(out_axes_arg);
1485 
1486   PrimitivePtr kprim_vmap = std::make_shared<Primitive>(kVmapOpName, kSideEffectPropagate);
1487   kprim_vmap->set_attr("in_axes", in_axes);
1488   kprim_vmap->set_attr("out_axes", out_axes);
1489   kprim_vmap->set_attr("cell_size", MakeValue(cell_size));
1490 
1491   std::vector<AnfNodePtr> inputs;
1492   inputs.push_back(NewValueNode(kprim_vmap));
1493   inputs.push_back(param_graph);
1494   auto vmap = vmap_fg->NewCNodeInOrder(inputs);
1495 
1496   FuncGraphPtr vmap_child = nullptr;
1497   {
1498     TraceGuard guard(std::make_shared<TraceVmapOperation>(origin_graph_info));
1499     vmap_child = GetVmap(vmap, nparam);
1500   }
1501 
1502   vmap_fg->set_output(NewValueNode(vmap_child));
1503   return vmap_fg;
1504 }
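// Informal shape of the generated graph (annotation):
//   vmap_fg(fn, in_axes, out_axes):
//     vmap = Vmap[in_axes, out_axes, cell_size](fn)
//     return vmap_child  // vmap_child(args...) = vmap(args...)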
1505 
1506 TaylorOperation::TaylorOperation(const std::string &name) : MetaFuncGraph(name) {
1507   // def Taylor(func:read):
1508   signatures_ = std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}});
1509 }
1510 
1511 FuncGraphPtr TaylorOperation::GetTaylorGrad(const AnfNodePtr &k,
1512                                             const std::vector<AnfNodePtr> &forward_graph_params) const {
1513   FuncGraphPtr k_child = std::make_shared<FuncGraph>();
1514   k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1515 
1516   std::vector<AnfNodePtr> inputs;
1517   inputs.push_back(k);
1518   MS_LOG(INFO) << "TaylorOperation forward input size " << forward_graph_params.size();
1519   for (size_t i = 0; i < forward_graph_params.size(); ++i) {
1520     inputs.push_back(k_child->add_parameter());
1521   }
1522   // Taylor(fn)(input params)
1523   auto k_app = k_child->NewCNodeInOrder(inputs);
1524 
1525   k_child->set_output(k_app);
1526   return k_child;
1527 }
1528 
1529 // Generate the graph to calculate higher order derivatives.
1530 FuncGraphPtr TaylorOperation::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
1531   if (args_abs_list.empty()) {
1532     MS_LOG(EXCEPTION)
1533       << "'TaylorOperation' requires a forward network or function as an input, while the input is empty.";
1534   }
1535 
1536   MS_EXCEPTION_IF_NULL(args_abs_list[0]);
1537   AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_abs_list[0]);
1538   if (fn == nullptr) {
1539     MS_LOG(EXCEPTION) << "'TaylorOperation' arg0 must be a 'Function' or 'Cell', but got "
1540                       << args_abs_list[0]->ToString();
1541   }
1542 
1543   auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
1544   MS_EXCEPTION_IF_NULL(real_fn);
1545 
1546   FuncGraphPtr forward_graph = real_fn->func_graph();
1547   MS_EXCEPTION_IF_NULL(forward_graph);
1548   forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
1549   FuncGraphPtr grad_fg = nullptr;
1550   MS_LOG(INFO) << "'TaylorOperation' forward_graph: " << forward_graph->debug_info();
1551   grad_fg = std::make_shared<FuncGraph>();
1552   auto nparam = forward_graph->parameters().size();
1553 
1554   std::ostringstream ss;
1555   ss << "taylorgrad{" << nparam << "}";
1556   grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1557   grad_fg->debug_info()->set_name(ss.str());
1558   ParameterPtr param_graph = grad_fg->add_parameter();
1559 
1560   std::vector<AnfNodePtr> inputs;
1561   inputs.push_back(NewValueNode(prim::kPrimTaylor));
1562   inputs.push_back(param_graph);
1563   // Taylor(fn)
1564   auto mark_taylor = grad_fg->NewCNodeInOrder(inputs);
1565   FuncGraphPtr k_child = nullptr;
1566   {
1567     TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
1568     k_child = GetTaylorGrad(mark_taylor, forward_graph->parameters());
1569   }
1570   grad_fg->set_output(NewValueNode(k_child));
1571   // return Taylor(fn)(inputs)
1572   return grad_fg;
1573 }
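// Illustratively (annotation): grad_fg(fn) returns k_child, where
//   k_child(args...) = Taylor(fn)(args...)
// and the Taylor-marked call is expanded by later passes into the
// higher-order-derivative computation.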
1574 
1575 FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
1576   // args: tuple1, tuple2
1577   abstract::CheckArgsSize("TupleAdd", args_abs_list, 2);
1578   AbstractBasePtr abs_a = args_abs_list[0];
1579   AbstractBasePtr abs_b = args_abs_list[1];
1580 
1581   AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
1582   AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
1583   if (a_tuple == nullptr || b_tuple == nullptr) {
1584     TypePtrList types;
1585     (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(types),
1586                          [](const AbstractBasePtr &arg) -> TypePtr {
1587                            MS_EXCEPTION_IF_NULL(arg);
1588                            return arg->BuildType();
1589                          });
1590     auto stub = GenerateStubFunc(types);
1591     if (stub != nullptr) {
1592       MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd"
1593                     << ", function: " << stub->ToString();
1594       return stub;
1595     }
1596     MS_LOG(EXCEPTION) << "The type of argument in TupleAdd operator should be tuple, but the first argument is "
1597                       << args_abs_list[0]->ToString() << ", the second argument is " << args_abs_list[1]->ToString();
1598   }
1599 
1600   FuncGraphPtr ret = std::make_shared<FuncGraph>();
1601   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1602   AnfNodePtr p_tup_a = ret->add_parameter();
1603   AnfNodePtr p_tup_b = ret->add_parameter();
1604 
1605   std::vector<AnfNodePtr> elems;
1606   elems.push_back(NewValueNode(prim::kPrimMakeTuple));
1607 
1608   int64_t tuple_size = SizeToLong(a_tuple->size());
1609   for (int64_t i = 0; i < tuple_size; ++i) {
1610     elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
1611   }
1612 
1613   tuple_size = SizeToLong(b_tuple->size());
1614   for (int64_t i = 0; i < tuple_size; ++i) {
1615     elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
1616   }
1617 
1618   ret->set_output(ret->NewCNodeInOrder(elems));
1619   return ret;
1620 }
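// Worked example (informal): for abstract tuples of sizes 2 and 1 the generated
// graph is
//   ret(a, b) = MakeTuple(a[0], a[1], b[0])
// where x[i] stands for TupleGetItem(x, i).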
1621 
1622 FuncGraphPtr ListAdd::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
1623   // args: list1, list2
1624   abstract::CheckArgsSize("ListAdd", args_abs_list, 2);
1625   AbstractBasePtr abs_a = args_abs_list[0];
1626   AbstractBasePtr abs_b = args_abs_list[1];
1627 
1628   AbstractListPtr a_list = dyn_cast<AbstractList>(abs_a);
1629   AbstractListPtr b_list = dyn_cast<AbstractList>(abs_b);
1630   if (a_list == nullptr || b_list == nullptr) {
1631     TypePtrList types;
1632     (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(types),
1633                          [](const AbstractBasePtr &arg) -> TypePtr {
1634                            MS_EXCEPTION_IF_NULL(arg);
1635                            return arg->BuildType();
1636                          });
1637     auto stub = GenerateStubFunc(types);
1638     if (stub != nullptr) {
1639       MS_LOG(DEBUG) << "GenerateStubFunc for ListAdd"
1640                     << ", function: " << stub->ToString();
1641       return stub;
1642     }
1643     MS_LOG(EXCEPTION) << "The type of argument in ListAdd operator should be list, but the first argument is "
1644                       << args_abs_list[0]->ToString() << ", the second argument is " << args_abs_list[1]->ToString();
1645   }
1646 
1647   FuncGraphPtr ret = std::make_shared<FuncGraph>();
1648   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1649   AnfNodePtr p_list_a = ret->add_parameter();
1650   AnfNodePtr p_list_b = ret->add_parameter();
1651 
1652   std::vector<AnfNodePtr> elems;
1653   elems.push_back(NewValueNode(prim::kPrimMakeList));
1654 
1655   int64_t tuple_size = SizeToLong(a_list->size());
1656   for (int64_t i = 0; i < tuple_size; ++i) {
1657     elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), p_list_a, NewValueNode(i)}));
1658   }
1659 
1660   tuple_size = SizeToLong(b_list->size());
1661   for (int64_t i = 0; i < tuple_size; ++i) {
1662     elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), p_list_b, NewValueNode(i)}));
1663   }
1664 
1665   ret->set_output(ret->NewCNodeInOrder(elems));
1666   return ret;
1667 }
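// ListAdd mirrors TupleAdd: ret(a, b) = MakeList(a[0], ..., b[0], ...), with
// ListGetItem in place of TupleGetItem.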
1668 
1669 int64_t GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
1670   MS_EXCEPTION_IF_NULL(scalar);
1671   return GetValue<int64_t>(scalar->BuildValue());
1672 }
1673 
1674 int64_t GetPositiveIndex(int64_t index, int64_t length) {
1675   if (index < 0) {
1676     index += length;
1677   }
1678   return index;
1679 }
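// E.g. GetPositiveIndex(-2, 5) == 3, while non-negative indices pass through unchanged.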
1680 
1681 int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, const std::string &member_name) {
1682   MS_EXCEPTION_IF_NULL(member);
1683 
1684   if (member->isa<AbstractScalar>()) {
1685     return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
1686   }
1687 
1688   if (member->isa<AbstractNone>()) {
1689     return default_value;
1690   }
1691 
1692   if (member->isa<AbstractTensor>()) {
1693     MS_EXCEPTION(TypeError)
1694       << "The argument of SliceMember operator must be a Scalar, None or a constant Tensor, but got a variable Tensor";
1695   }
1696   MS_EXCEPTION(TypeError)
1697     << "The argument of SliceMember operator must be a Scalar, None or a constant Tensor, but got "
1698     << member->BuildType()->ToString();
1699 }
1700 
1701 std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const AbstractSequencePtr &sequence,
1702                                                                   const AbstractSlicePtr &slice) {
1703   MS_EXCEPTION_IF_NULL(sequence);
1704   MS_EXCEPTION_IF_NULL(slice);
1705   int64_t start_index;
1706   int64_t stop_index;
1707   int64_t step_value;
1708 
1709   const std::string start_name("Slice start index");
1710   const std::string stop_name("Slice stop index");
1711   const std::string step_name("Slice step value");
1712 
1713   int64_t tuple_size = SizeToLong(sequence->size());
1714   int64_t start_default = 0;
1715   int64_t stop_default = tuple_size;
1716   int64_t step_default = kStepDefault;
1717 
1718   step_value = CheckSliceMember(slice->step(), step_default, step_name);
1719   if (step_value == 0) {
1720     MS_EXCEPTION(ValueError) << "Slice step cannot be zero.";
1721   }
1722 
1723   if (step_value < 0) {
1724     start_default = tuple_size - 1;
1725     stop_default = ((-tuple_size) - 1);
1726   }
1727 
1728   start_index = CheckSliceMember(slice->start(), start_default, start_name);
1729   stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
1730 
1731   if (start_index < -tuple_size) {
1732     start_index = 0;
1733   }
1734 
1735   if (stop_index > tuple_size) {
1736     stop_index = tuple_size;
1737   }
1738 
1739   if (start_index > tuple_size) {
1740     start_index = tuple_size;
1741   }
1742 
1743   if (stop_index < ((-tuple_size) - 1)) {
1744     stop_index = 0;
1745   }
1746 
1747   start_index = GetPositiveIndex(start_index, tuple_size);
1748 
1749   stop_index = GetPositiveIndex(stop_index, tuple_size);
1750 
1751   return std::make_tuple(start_index, stop_index, step_value);
1752 }
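// Worked example (informal): a sequence of size 5 with slice [::-1]:
//   step_value = -1, so start_default = 4 and stop_default = -6;
//   clamping changes nothing, GetPositiveIndex yields start = 4, stop = -1,
//   and the caller then walks indices 4, 3, 2, 1, 0.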
1753 
1754 void SequenceSliceGetItem::CheckArgs(const AbstractBasePtrList &args_abs_list) {
1755   constexpr size_t arg_size = 2;
1756   abstract::CheckArgsSize(this->name(), args_abs_list, arg_size);
1757   sequence_ = abstract::CheckArg<AbstractSequence>(this->name(), args_abs_list, 0);
1758   slice_ = abstract::CheckArg<AbstractSlice>(this->name(), args_abs_list, 1);
1759 }
1760 
1761 FuncGraphPtr SequenceSliceGetItem::BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) {
1762   FuncGraphPtr ret = std::make_shared<FuncGraph>();
1763   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1764   AnfNodePtr p_seq = ret->add_parameter();
1765   (void)ret->add_parameter();
1766 
1767   std::vector<AnfNodePtr> elems;
1768   elems.push_back(NewValueNode(prim_));
1769   if (step_value > 0) {
1770     for (int64_t index = start_index; index < stop_index; index = index + step_value) {
1771       elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_seq, NewValueNode(index)}));
1772     }
1773   } else {
1774     for (int64_t index = start_index; index > stop_index; index = index + step_value) {
1775       elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_seq, NewValueNode(index)}));
1776     }
1777   }
1778 
1779   ret->set_output(ret->NewCNodeInOrder(elems));
1780   return ret;
1781 }
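// E.g. start = 1, stop = 6, step = 2 produces
//   ret(seq, _) = prim_(get_item_(seq, 1), get_item_(seq, 3), get_item_(seq, 5))
// where prim_/get_item_ are the make/get primitives of the sequence kind.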
1782 
1783 FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
1784   // select indexed item
1785   // args: tuple of items, index
1786   const std::string op_name = std::string("TupleGetItemTensor");
1787   const size_t inputs_size = 2;
1788   abstract::CheckArgsSize(op_name, args_abs_list, inputs_size);
1789   auto ret_graph = std::make_shared<FuncGraph>();
1790   ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1791   auto tuple = ret_graph->add_parameter();
1792   auto index = ret_graph->add_parameter();
1793 
1794   constexpr size_t tuple_index = 0;
1795   auto abs = args_abs_list[tuple_index];
1796   MS_EXCEPTION_IF_NULL(abs);
1797   auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
1798   MS_EXCEPTION_IF_NULL(tuple_abs);
1799   if (!tuple_abs->dynamic_len()) {
1800     const auto &elements = tuple_abs->elements();
1801     if (std::all_of(elements.begin(), elements.end(), [](const AbstractBasePtr &e) {
1802           MS_EXCEPTION_IF_NULL(e);
1803           return e->isa<abstract::FuncGraphAbstractClosure>() || e->isa<abstract::PartialAbstractClosure>() ||
1804                  e->isa<abstract::PrimitiveAbstractClosure>();
1805         })) {
1806       ret_graph->set_output(ret_graph->NewCNodeInOrder({NewValueNode(prim::kPrimSwitchLayer), index, tuple}));
1807       return ret_graph;
1808     }
1809   }
1810 
1811   const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1812   if (!allow_fallback_runtime) {
1813     MS_EXCEPTION(TypeError) << "When JIT_SYNTAX_LEVEL is STRICT, using a Tensor index to get a value from a tuple "
1814                             << "requires that all elements in the tuple be functions, but got tuple abstract: "
1815                             << tuple_abs->ToString();
1816   }
1817   // Script
1818   constexpr auto internal_tuple_input = "__internal_tuple_input__";
1819   constexpr auto internal_index_input = "__internal_index_input__";
1820   std::stringstream script_buffer;
1821   script_buffer << internal_tuple_input << "[" << internal_index_input << "]";
1822   const std::string &script = script_buffer.str();
1823   const auto script_str = std::make_shared<StringImm>(script);
1824   // Key
1825   std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1826   (void)key_value_names_list.emplace_back(NewValueNode(internal_tuple_input));
1827   (void)key_value_names_list.emplace_back(NewValueNode(internal_index_input));
1828   const auto key_value_name_tuple = ret_graph->NewCNode(key_value_names_list);
1829   // Value
1830   std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1831   (void)key_value_list.emplace_back(tuple);
1832   (void)key_value_list.emplace_back(index);
1833   const auto key_value_tuple = ret_graph->NewCNode(key_value_list);
1834   auto res =
1835     fallback::CreatePyExecuteCNode(ret_graph, NewValueNode(script_str), key_value_name_tuple, key_value_tuple, nullptr);
1836   ret_graph->set_output(res);
1837   return ret_graph;
1838 }
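// Under JIT fallback the emitted PyExecute script is, informally,
//   __internal_tuple_input__[__internal_index_input__]
// with the real tuple and index nodes bound to those two keys at runtime.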
1839 
1840 namespace {
1841 FuncGraphPtr GetShard(const AnfNodePtr &shard, const std::vector<AnfNodePtr> &origin_graph_params) {
1842   FuncGraphPtr shard_child = std::make_shared<FuncGraph>();
1843   shard_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1844 
1845   std::vector<AnfNodePtr> inputs;
1846   inputs.reserve(origin_graph_params.size() + 1);
1847   (void)inputs.emplace_back(shard);
1848   for (size_t i = 0; i < origin_graph_params.size(); ++i) {
1849     (void)inputs.emplace_back(shard_child->add_parameter());
1850   }
1851   auto shard_app = shard_child->NewCNodeInOrder(std::move(inputs));
1852 
1853   shard_child->set_output(shard_app);
1854   return shard_child;
1855 }
1856 }  // namespace
1857 
1858 FuncGraphPtr Shard::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
1859   if (args_abs_list.size() != kShardInputSize) {
1860     MS_LOG(EXCEPTION) << "'Shard' requires " << kShardInputSize
1861                       << " inputs, including a Cell or function, in_axes, out_axes, parameter_plan, device and level.";
1862   }
1863 
1864   MS_EXCEPTION_IF_NULL(args_abs_list[0]);
1865   AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_abs_list[0]);
1866   if (fn == nullptr) {
1867     MS_LOG(EXCEPTION) << "'Shard' arg0 must be a 'Function' or 'Cell', but got " << args_abs_list[0]->ToString() << ".";
1868   }
1869 
1870   auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
1871   MS_EXCEPTION_IF_NULL(real_fn);
1872   FuncGraphPtr origin_graph = real_fn->func_graph();
1873   MS_EXCEPTION_IF_NULL(origin_graph);
1874   auto execution_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE);
1875   if (execution_mode == kPynativeMode) {
1876     origin_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
1877   }
1878   FuncGraphPtr shard_fg = nullptr;
1879   {
1880     TraceGuard g(std::make_shared<TraceShard>(origin_graph->debug_info()));
1881     shard_fg = std::make_shared<FuncGraph>();
1882   }
1883   // Create the debug info
1884   auto parameter_size = origin_graph->parameters().size();
1885   std::ostringstream ss;
1886   ss << "shard{" << parameter_size << "}";
1887   shard_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1888   shard_fg->debug_info()->set_name(ss.str());
1889   // Make the Shard node.
1890   std::vector<AnfNodePtr> inputs;
1891   inputs.reserve(args_abs_list.size() + 1);
1892   (void)inputs.emplace_back(NewValueNode(prim::kPrimShard));
1893   for (size_t i = 0; i < args_abs_list.size(); ++i) {
1894     (void)inputs.emplace_back(shard_fg->add_parameter());
1895   }
1896   auto shard = shard_fg->NewCNodeInOrder(std::move(inputs));
1897 
1898   FuncGraphPtr shard_child = nullptr;
1899   {
1900     TraceGuard guard(std::make_shared<TraceShard>(shard_fg->debug_info()));
1901     shard_child = GetShard(shard, origin_graph->parameters());
1902   }
1903   shard_fg->set_output(NewValueNode(shard_child));
1904   return shard_fg;
1905 }
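// Informal shape of the generated graph (annotation):
//   shard_fg(fn, in_axes, out_axes, parameter_plan, device, level):
//     shard = Shard(fn, in_axes, out_axes, parameter_plan, device, level)
//     return shard_child  // shard_child(args...) = shard(args...)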
1906 
1907 void ListSliceSetItem::CheckArgs(const AbstractBasePtrList &args_abs_list) {
1908   constexpr size_t kSliceSetItemArgsSize = 3;
1909   constexpr size_t kSliceSetItemListIndex = 0;
1910   constexpr size_t kSliceSetItemSliceIndex = 1;
1911   constexpr size_t kSliceSetItemValueIndex = 2;
1912   abstract::CheckArgsSize("list_slice_set_item", args_abs_list, kSliceSetItemArgsSize);
1913   this->sequence_ = abstract::CheckArg<AbstractList>("list_slice_set_item", args_abs_list, kSliceSetItemListIndex);
1914   this->slice_ = abstract::CheckArg<AbstractSlice>("list_slice_set_item", args_abs_list, kSliceSetItemSliceIndex);
1915   this->value_list_ = abstract::CheckArg<AbstractList>("list_slice_set_item", args_abs_list, kSliceSetItemValueIndex);
1916 }
1917 
1918 FuncGraphPtr ListSliceSetItem::BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) {
1919   // Build the graph from the inputs: list_node, slice, assign_node.
1920   CheckAssignRange(start_index, stop_index, step_value);
1921   auto graph = std::make_shared<FuncGraph>();
1922   graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1923   auto list_node = graph->add_parameter();
1924   (void)graph->add_parameter();
1925   auto assign_parameter = graph->add_parameter();
1926   auto assign_node = GetAssignNode(graph, assign_parameter, step_value);
1927   std::vector<AnfNodePtr> elems = {NewValueNode(prim::kPrimMakeList)};
1928   int64_t list_index = 0;
1929   // Check whether the index falls inside the slice range.
1930   auto check_in_range = [start_index, stop_index, step_value](int64_t index) -> bool {
1931     if (step_value > 0) {
1932       return (index >= start_index && index < stop_index);
1933     }
1934     return (index <= start_index && index > stop_index);
1935   };
1936   int64_t list_size = SizeToLong(sequence_->size());
1937   int64_t assign_index = 0;
1938   int64_t value_size = SizeToLong(value_list_->size());
1939   while (list_index < list_size || assign_index < value_size) {
1940     if (!check_in_range(list_index)) {
1941       // The slice selects nothing here (start >= stop with step 1): insert all assign nodes at start_index.
1942       while (assign_index < value_size && list_index == start_index) {
1943         (void)elems.emplace_back(
1944           graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
1945       }
1946       if (list_index < list_size) {
1947         (void)elems.emplace_back(
1948           graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), list_node, NewValueNode(list_index++)}));
1949       }
1950     } else {
1951       if (((list_index - start_index) % step_value) == 0) {
1952         ++list_index;
1953         if (assign_index >= value_size) {
1954           continue;
1955         }
1956         (void)elems.emplace_back(
1957           graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
1958       } else {
1959         (void)elems.emplace_back(
1960           graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), list_node, NewValueNode(list_index++)}));
1961       }
1962       // The assign list is longer than the slice range: append the remaining assign nodes.
1963       while (!check_in_range(list_index) && assign_index < value_size) {
1964         (void)elems.emplace_back(
1965           graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
1966       }
1967     }
1968   }
1969 
1970   graph->set_output(graph->NewCNodeInOrder(elems));
1971   return graph;
1972 }
1973 
1974 void ListSliceSetItem::CheckAssignRange(int64_t start_index, int64_t stop_index, int64_t step_value) {
1975   if (step_value != kStepDefault) {
1976     auto range = stop_index - start_index;
1977     int include_start = (range % step_value) == 0 ? 0 : 1;
1978     auto assign_size = (range / step_value) + include_start;
1979     assign_size = assign_size > 0 ? assign_size : 0;
1980     if (assign_size != SizeToLong(value_list_->size())) {
1981       MS_EXCEPTION(ValueError) << "attempt to assign sequence of size " << value_list_->size()
1982                                << " to extended slice of size " << assign_size;
1983     }
1984   }
1985 }
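// Worked example (informal): l[0:5:2] selects indices 0, 2 and 4, so
//   range = 5, assign_size = 5 / 2 + 1 = 3,
// and assigning a value list of any size other than 3 raises the ValueError above.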
1986 
1987 AnfNodePtr ListSliceSetItem::GetAssignNode(const FuncGraphPtr &func_graph, const AnfNodePtr &assign_node,
1988                                            int64_t step_value) {
1989   if (step_value > 0) {
1990     return assign_node;
1991   }
1992   std::vector<AnfNodePtr> elems = {NewValueNode(prim::kPrimMakeList)};
1993   for (int64_t i = SizeToInt(value_list_->size()) - 1; i >= 0; --i) {
1994     (void)elems.emplace_back(
1995       func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), assign_node, NewValueNode(i)}));
1996   }
1997   return func_graph->NewCNodeInOrder(elems);
1998 }
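// E.g. with a negative step the assigned values are consumed in reverse:
//   l = [x, y, z]; l[::-1] = [a, b, c]  ->  l becomes [c, b, a]
// which is why the value list is re-emitted back to front here.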
1999 
2000 FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2001   this->CheckArgs(args_abs_list);
2002   auto [start, stop, step] = GenerateTupleSliceParameter(sequence_, slice_);
2003   return this->BuildFuncGraph(start, stop, step);
2004 }
2005 
2006 FuncGraphPtr ZerosLike::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2007   constexpr auto input_size = 1;
2008   abstract::CheckArgsSize("ZerosLike", args_abs_list, input_size);
2009 
2010   auto x = args_abs_list[0];
2011   MS_EXCEPTION_IF_NULL(x);
2012   auto type = x->BuildType();
2013   MS_EXCEPTION_IF_NULL(type);
2014   if (type->type_id() == kTuple->type_id() || type->type_id() == kList->type_id()) {
2015     auto abs_seq = x->cast<AbstractSequencePtr>();
2016     MS_EXCEPTION_IF_NULL(abs_seq);
2017     if (abs_seq->dynamic_len()) {
2018       FuncGraphPtr res_graph = std::make_shared<FuncGraph>();
2019       res_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
2020       res_graph->debug_info()->set_name("zeros_like");
2021       auto x_parameter = res_graph->add_parameter();
2022       res_graph->set_output(res_graph->NewCNodeInOrder({NewValueNode(prim::kPrimSequenceZerosLike), x_parameter}));
2023       return res_graph;
2024     }
2025   }
2026 
2027   HyperMap hyper_map(false, fn_leaf_);
2028   TypePtrList types;
2029   (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(types),
2030                        [](const AbstractBasePtr &arg) -> TypePtr {
2031                          MS_EXCEPTION_IF_NULL(arg);
2032                          return arg->BuildType();
2033                        });
2034   return hyper_map.GenerateFromTypes(types);
2035 }
2036 
2037 // IterConverter is used when the input needs to be converted to an iterable object.
2038 FuncGraphPtr IterConverter::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2039   constexpr auto input_size = 1;
2040   abstract::CheckArgsSize("IterConverter", args_abs_list, input_size);
2041   auto fg = std::make_shared<FuncGraph>();
2042   fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
2043   auto input_abs = args_abs_list[0];
2044   MS_EXCEPTION_IF_NULL(input_abs);
2045   if (input_abs->isa<abstract::AbstractAny>() || input_abs->BuildValue()->isa<parse::InterpretedObject>()) {
2046     const std::vector<std::string> funcs_str{"tuple"};
2047     auto ret_node = fallback::GeneratePyInterpretWithAbstract(fg, funcs_str, input_size);
2048     fg->set_output(ret_node);
2049     return fg;
2050   }
2051 
2052   auto input_type = input_abs->BuildType();
2053   MS_EXCEPTION_IF_NULL(input_type);
2054   auto type_id = input_type->type_id();
2055   std::vector<int64_t> iterable_valid_types{
2056     TypeId::kObjectTypeString,     TypeId::kObjectTypeTuple,    TypeId::kObjectTypeList,  TypeId::kObjectTypeDictionary,
2057     TypeId::kObjectTypeTensorType, TypeId::kObjectTypeFunction, TypeId::kMetaTypeExternal};
2058   bool iterable = std::any_of(iterable_valid_types.begin(), iterable_valid_types.end(),
2059                               [type_id](int64_t valid_type) { return valid_type == type_id; });
2060   if (!iterable) {
2061     MS_EXCEPTION(TypeError) << "'" << TypeIdToString(type_id, true) << "' object is not iterable";
2062   }
2063 
2064   auto input = fg->add_parameter();
2065   if (input_abs->isa<AbstractDictionary>()) {
2066     auto ret_node = fg->NewCNode({NewValueNode(prim::kPrimDictGetKeys), input});
2067     fg->set_output(ret_node);
2068     return fg;
2069   }
2070   fg->set_output(input);
2071   return fg;
2072 }
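// E.g. a dict input is lowered to DictGetKeys(input), matching Python's iter(dict)
// iterating over keys; other iterable inputs are returned unchanged.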
2073 
2074 // HasNext is used to check whether the input has a next element.
2075 FuncGraphPtr HasNext::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2076   constexpr auto input_size = 1;
2077   abstract::CheckArgsSize("HasNext", args_abs_list, input_size);
2078   auto fg = std::make_shared<FuncGraph>();
2079   fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
2080   auto input_abs = args_abs_list[0];
2081   MS_EXCEPTION_IF_NULL(input_abs);
2082   auto input = fg->add_parameter();
2083   if (input_abs->isa<abstract::AbstractAny>() || input_abs->BuildValue()->isa<parse::InterpretedObject>()) {
2084     AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2085     AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2086     std::stringstream script_buffer;
2087     script_buffer << "__import__('mindspore').common._utils._jit_fallback_has_next_func(";
2088     const std::string data_str = "__data__";
2089     script_buffer << data_str << ")";
2090     (void)local_key_inputs.emplace_back(NewValueNode(data_str));
2091     (void)local_value_inputs.emplace_back(input);
2092     const auto &script = script_buffer.str();
2093     auto local_key_node = fg->NewCNode(local_key_inputs);
2094     auto local_value_node = fg->NewCNode(local_value_inputs);
2095     auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
2096     auto ret = fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node);
2097     fg->set_output(ret);
2098     return fg;
2099   }
2100   const std::string module = "mindspore._extends.parse.standard_method";
2101   const std::string func_name = "ms_hasnext";
2102   py::function fn = python_adapter::GetPyFn(module, func_name);
2103   auto prim_func = parse::ParsePythonCode(fn);
2104   auto ret = fg->NewCNode({NewValueNode(prim_func), input});
2105   fg->set_output(ret);
2106   return fg;
2107 }
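// For abstract-any inputs HasNext degrades to an interpreted call, informally
//   __import__('mindspore').common._utils._jit_fallback_has_next_func(__data__)
// with __data__ bound to the input node through the local dict.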
2108 
2109 // Next is used to get the next element of the input.
2110 FuncGraphPtr Next::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2111   constexpr auto input_size = 1;
2112   abstract::CheckArgsSize("Next", args_abs_list, input_size);
2113   auto fg = std::make_shared<FuncGraph>();
2114   fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
2115   auto input_abs = args_abs_list[0];
2116   MS_EXCEPTION_IF_NULL(input_abs);
2117   auto input = fg->add_parameter();
2118   if (input_abs->isa<abstract::AbstractAny>() || input_abs->BuildValue()->isa<parse::InterpretedObject>()) {
2119     AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2120     AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2121     std::stringstream script_buffer;
2122     script_buffer << "__import__('mindspore').common._utils._jit_fallback_next_func(";
2123     const std::string data_str = "__data__";
2124     script_buffer << data_str << ")";
2125     (void)local_key_inputs.emplace_back(NewValueNode(data_str));
2126     (void)local_value_inputs.emplace_back(input);
2127     const auto &script = script_buffer.str();
2128     auto local_key_node = fg->NewCNode(local_key_inputs);
2129     auto local_value_node = fg->NewCNode(local_value_inputs);
2130     auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
2131     auto ret = fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node);
2132     fg->set_output(ret);
2133     return fg;
2134   }
2135   const std::string module = "mindspore._extends.parse.standard_method";
2136   const std::string func_name = input_abs->isa<abstract::AbstractDictionary>() ? "dict_next" : "ms_next";
2137   py::function fn = python_adapter::GetPyFn(module, func_name);
2138   auto prim_func = parse::ParsePythonCode(fn);
2139   auto ret = fg->NewCNode({NewValueNode(prim_func), input});
2140   fg->set_output(ret);
2141   return fg;
2142 }
2143 
2144 FuncGraphPtr TupleFunc::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2145   if (args_abs_list.size() > 1) {
2146     MS_LOG(EXCEPTION) << "For 'TupleFunc', the number of inputs should be 0 or 1, but got " << args_abs_list.size();
2147   }
2148   auto fg = std::make_shared<FuncGraph>();
2149   fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
2150   if (args_abs_list.size() == 0) {
2151     auto ret = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple)});
2152     fg->set_output(ret);
2153     return fg;
2154   }
2155 
2156   auto input_abs = args_abs_list[0];
2157   MS_EXCEPTION_IF_NULL(input_abs);
2158   auto input = fg->add_parameter();
2159   if (fallback::ContainsSequenceAnyType(input_abs)) {
2160     AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2161     AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2162     std::stringstream script_buffer;
2163     script_buffer << "tuple(";
2164     const std::string data_str = "__data__";
2165     script_buffer << data_str << ")";
2166     (void)local_key_inputs.emplace_back(NewValueNode(data_str));
2167     (void)local_value_inputs.emplace_back(input);
2168     const auto &script = script_buffer.str();
2169     auto local_key_node = fg->NewCNode(local_key_inputs);
2170     auto local_value_node = fg->NewCNode(local_value_inputs);
2171     auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
2172     auto ret = fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node);
2173     fg->set_output(ret);
2174     return fg;
2175   } else if (input_abs->isa<abstract::AbstractTuple>()) {
2176     fg->set_output(input);
2177     return fg;
2178   } else if (input_abs->isa<abstract::AbstractList>()) {
2179     // list to tuple
2180     if (fallback::SequenceAllElementsIsScalar(input_abs)) {
2181       auto prim = std::make_shared<Primitive>("ListToTuple");
2182       auto list_to_tuple = fg->NewCNode({NewValueNode(prim), input});
2183       fg->set_output(list_to_tuple);
2184       return fg;
2185     }
2186   }
2187   const std::string module = "mindspore._extends.parse.standard_method";
2188   const std::string func_name = "tuple_func";
2189   py::function fn = python_adapter::GetPyFn(module, func_name);
2190   auto prim_func = parse::ParsePythonCode(fn);
2191   auto ret = fg->NewCNode({NewValueNode(prim_func), input});
2192   fg->set_output(ret);
2193   return fg;
2194 }
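// Behavior summary (informal):
//   tuple()                          -> MakeTuple()
//   tuple(t) for a tuple t           -> t unchanged
//   tuple(l) for a scalar-only list  -> ListToTuple(l)
//   input containing Any             -> interpreted "tuple(__data__)"
//   anything else                    -> standard_method tuple_func via the parser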
2195 
2196 FuncGraphPtr ListFunc::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
2197   if (args_abs_list.size() > 1) {
2198     MS_LOG(EXCEPTION) << "For 'ListFunc', the number of inputs should be 0 or 1, but got " << args_abs_list.size();
2199   }
2200   auto fg = std::make_shared<FuncGraph>();
2201   fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
2202   if (args_abs_list.size() == 0) {
2203     auto ret = fg->NewCNode({NewValueNode(prim::kPrimMakeList)});
2204     fg->set_output(ret);
2205     return fg;
2206   }
2207 
2208   auto input_abs = args_abs_list[0];
2209   MS_EXCEPTION_IF_NULL(input_abs);
2210   auto input = fg->add_parameter();
2211   if (fallback::ContainsSequenceAnyType(input_abs)) {
2212     AnfNodePtrList local_key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2213     AnfNodePtrList local_value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
2214     std::stringstream script_buffer;
2215     script_buffer << "list(";
2216     const std::string data_str = "__data__";
2217     script_buffer << data_str << ")";
2218     (void)local_key_inputs.emplace_back(NewValueNode(data_str));
2219     (void)local_value_inputs.emplace_back(input);
2220     const auto &script = script_buffer.str();
2221     auto local_key_node = fg->NewCNode(local_key_inputs);
2222     auto local_value_node = fg->NewCNode(local_value_inputs);
2223     auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_node, local_value_node});
2224     auto ret = fallback::CreatePyInterpretCNode(fg, script, py::dict(), local_dict_node);
2225     fg->set_output(ret);
2226     return fg;
2227   } else if (input_abs->isa<abstract::AbstractList>()) {
2228     fg->set_output(input);
2229     return fg;
2230   } else if (input_abs->isa<abstract::AbstractTuple>()) {
2231     // tuple to list
2232     if (fallback::SequenceAllElementsIsScalar(input_abs)) {
2233       auto prim = std::make_shared<Primitive>("TupleToList");
2234       auto tuple_to_list = fg->NewCNode({NewValueNode(prim), input});
2235       fg->set_output(tuple_to_list);
2236       return fg;
2237     }
2238   }
2239   const std::string module = "mindspore._extends.parse.standard_method";
2240   const std::string func_name = "list_func";
2241   py::function fn = python_adapter::GetPyFn(module, func_name);
2242   auto prim_func = parse::ParsePythonCode(fn);
2243   auto ret = fg->NewCNode({NewValueNode(prim_func), input});
2244   fg->set_output(ret);
2245   return fg;
2246 }
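// ListFunc mirrors TupleFunc, using MakeList, TupleToList and the "list(__data__)" script.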
2247 }  // namespace prim
2248 }  // namespace mindspore
2249