• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 /**
3  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
4  *
5  * Copyright 2019-2021 Huawei Technologies Co., Ltd
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  * http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #include "frontend/operator/composite/composite.h"
21 #include <algorithm>
22 #include <utility>
23 #include <sstream>
24 
25 #include "ir/anf.h"
26 #include "ir/func_graph.h"
27 #include "abstract/abstract_value.h"
28 #include "abstract/abstract_function.h"
29 #include "abstract/dshape.h"
30 #include "abstract/param_validator.h"
31 #include "frontend/operator/cc_implementations.h"
32 #include "frontend/optimizer/opt.h"
33 #include "utils/symbolic.h"
34 #include "pybind_api/api_register.h"
35 #include "ir/signature.h"
36 #include "debug/trace.h"
37 #include "utils/ms_context.h"
38 #include "utils/utils.h"
39 
40 namespace mindspore {
41 // namespace to support composite operators definition
42 namespace prim {
43 using AbstractTensor = mindspore::abstract::AbstractTensor;
44 using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
45 
46 using mindspore::abstract::AbstractAttribute;
47 using mindspore::abstract::AbstractBase;
48 using mindspore::abstract::AbstractClass;
49 using mindspore::abstract::AbstractDictionary;
50 using mindspore::abstract::AbstractDictionaryPtr;
51 using mindspore::abstract::AbstractEllipsis;
52 using mindspore::abstract::AbstractEllipsisPtr;
53 using mindspore::abstract::AbstractFunction;
54 using mindspore::abstract::AbstractFunctionPtr;
55 using mindspore::abstract::AbstractList;
56 using mindspore::abstract::AbstractNone;
57 using mindspore::abstract::AbstractScalar;
58 using mindspore::abstract::AbstractSlice;
59 using mindspore::abstract::AbstractTuple;
60 
// Maps Python dunder operator names to their scalar primitive implementations.
// Entries mapped to nullptr (__truediv__, __floordiv__) have no direct scalar
// primitive registered here.
ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
                            {"__truediv__", nullptr},    {"__floordiv__", nullptr},   {"__mod__", kPrimScalarMod},
                            {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq},   {"__lt__", kPrimScalarLt},
                            {"__gt__", kPrimScalarGt},   {"__ne__", kPrimScalarNe},   {"__le__", kPrimScalarLe},
                            {"__ge__", kPrimScalarGe}};

// Shared default-constructed HyperMap instance exposed as a composite value.
ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();
68 
// Shared constructor helper: derives the operator name from the bound leaf
// function (if any) and declares the call signatures.
void HyperMap::Init() {
  // When a leaf multitype function is bound, reflect it in the operator name.
  if (fn_leaf_) {
    name_ = "hyper_map[" + fn_leaf_->name() + "]";
  }
  signatures_ =
    // def hypermap(func:read, *args:ref):
    std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                            {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
78 
// Constructs a HyperMap, optionally bound to a fixed leaf multitype function.
// reverse: when true, per-element calls are generated in reverse order.
HyperMap::HyperMap(bool reverse, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
    : MetaFuncGraph("hyper_map"),
      fn_leaf_(fn_leaf),
      reverse_(reverse),
      broadcast_(false),
      // Container type ids that are expanded recursively rather than mapped as leaves.
      nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
  Init();
}
87 
// Copy constructor: duplicates configuration and re-derives name/signatures via Init().
HyperMap::HyperMap(const HyperMap &h)
    : MetaFuncGraph("hyper_map"),
      fn_leaf_(h.fn_leaf_),
      reverse_(h.reverse_),
      broadcast_(h.broadcast_),
      nonleaf_(h.nonleaf_) {
  Init();
}
96 
// Leaf case of HyperMap expansion: no container types remain, so the mapped
// function (explicit fn_arg, or the bound leaf function) is applied directly
// to the argument nodes.
AnfNodePtr HyperMap::FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> call_inputs;
  if (fn_arg == nullptr) {
    call_inputs.push_back(NewValueNode(fn_leaf_));
  } else {
    call_inputs.push_back(fn_arg);
  }

  for (const auto &entry : arg_map) {
    call_inputs.push_back(entry.first);
  }
  return func_graph->NewCNodeInOrder(call_inputs);
}
110 
FullMake(const std::shared_ptr<List> & type,const FuncGraphPtr & func_graph,const AnfNodePtr & fn_arg,const ArgsPairList & arg_map)111 AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
112                               const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
113   MS_EXCEPTION_IF_NULL(func_graph);
114   MS_EXCEPTION_IF_NULL(type);
115 
116   size_t size = type->elements().size();
117   size_t num = 0;
118   bool is_not_same =
119     std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
120       num++;
121       auto lhs = std::static_pointer_cast<List>(item.second);
122       if (lhs == nullptr) {
123         MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a List, but got "
124                           << item.second->ToString();
125       }
126       if (lhs->elements().size() != size) {
127         MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
128                       << lhs->elements().size();
129         return true;
130       }
131       return false;
132     });
133   if (is_not_same) {
134     MS_LOG(EXCEPTION) << "List in HyperMap should have same length";
135   }
136 
137   // cannot use shared_from_base() also known as this, as it will make a reference cycle on
138   // hypermap and graph generated, it will cause memory leak.
139   auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
140   constexpr size_t kPrimHoldLen = 1;
141   std::vector<AnfNodePtr> inputs;
142   inputs.reserve(size + kPrimHoldLen);
143   inputs.push_back(NewValueNode(prim::kPrimMakeList));
144 
145   for (size_t i = 0; i < size; i++) {
146     MS_LOG(DEBUG) << "FullMakeList for the " << i << "th element of the target, reverse_: " << reverse_;
147     std::vector<AnfNodePtr> inputs2;
148     inputs2.push_back(fn_rec);
149     if (fn_arg != nullptr) {
150       inputs2.push_back(fn_arg);
151     }
152     size_t pos = (reverse_ ? (size - 1 - i) : i);
153     (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
154                          [&func_graph, pos](const std::pair<AnfNodePtr, Any> &item) {
155                            return func_graph->NewCNodeInOrder(
156                              {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
157                          });
158 
159     auto call_node = func_graph->NewCNodeInOrder(inputs2);
160     if (reverse_) {
161       inputs.insert(inputs.begin() + 1, call_node);
162     } else {
163       inputs.emplace_back(call_node);
164     }
165   }
166   return func_graph->NewCNodeInOrder(inputs);
167 }
168 
FullMake(const std::shared_ptr<Tuple> & type,const FuncGraphPtr & func_graph,const AnfNodePtr & fn_arg,const ArgsPairList & arg_map)169 AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
170                               const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
171   MS_EXCEPTION_IF_NULL(func_graph);
172   MS_EXCEPTION_IF_NULL(type);
173 
174   size_t size = type->elements().size();
175   size_t num = 0;
176   bool is_not_same =
177     std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
178       num++;
179       auto lhs = std::static_pointer_cast<Tuple>(item.second);
180       if (lhs == nullptr) {
181         MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a Tuple, but got "
182                           << item.second->ToString();
183       }
184       if (lhs->elements().size() != size) {
185         MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
186                       << lhs->elements().size();
187         return true;
188       }
189       return false;
190     });
191   if (is_not_same) {
192     MS_LOG(EXCEPTION) << "Tuple in HyperMap should have same length";
193   }
194 
195   // cannot use shared_from_base() also known as this, as it will make a reference cycle on
196   // hypermap and graph generated, it will cause memory leak.
197   auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
198   constexpr size_t kPrimHoldLen = 1;
199   std::vector<AnfNodePtr> inputs;
200   inputs.reserve(size + kPrimHoldLen);
201   inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
202 
203   for (size_t i = 0; i < size; i++) {
204     MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th element of the target, reverse_: " << reverse_;
205     std::vector<AnfNodePtr> inputs2;
206     inputs2.push_back(fn_rec);
207     if (fn_arg != nullptr) {
208       inputs2.push_back(fn_arg);
209     }
210     size_t pos = (reverse_ ? (size - 1 - i) : i);
211     (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
212                          [&func_graph, &pos](std::pair<AnfNodePtr, Any> item) {
213                            return func_graph->NewCNodeInOrder(
214                              {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
215                          });
216 
217     auto call_node = func_graph->NewCNodeInOrder(inputs2);
218     if (reverse_) {
219       inputs.insert(inputs.begin() + 1, call_node);
220     } else {
221       inputs.emplace_back(call_node);
222     }
223   }
224   return func_graph->NewCNodeInOrder(inputs);
225 }
226 
// Expands HyperMap over Class (record) arguments: applies the mapped function
// per attribute and rebuilds the record with MakeRecord.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(func_graph);

  std::size_t attrSize = type->GetAttributes().size();
  constexpr size_t kPrimAndTypeLen = 2;  // MakeRecord primitive node + class type node
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(attrSize + kPrimAndTypeLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  inputs.push_back(NewValueNode(type));

  // cannot use shared_from_base() also known as this, as it will make a reference cycle on
  // hypermap and graph generated, it will cause memory leak.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  for (std::size_t i = 0; i < attrSize; i++) {
    MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the target, reverse_: " << reverse_;
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg) {
      inputs2.push_back(fn_arg);
    }

    size_t size = arg_map.size();
    for (size_t j = 0; j < size; j++) {
      size_t pos = (reverse_ ? (size - 1 - j) : j);
      auto &item = arg_map[pos];
      // NOTE(review): the GetAttr index here is the argument position `pos`, not
      // the attribute index `i`, so each outer iteration builds the same getattr
      // set — this differs from the List/Tuple overloads, which index elements
      // by i/pos of the container. Confirm this is intended.
      inputs2.push_back(
        func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(pos))}));
    }

    auto call_node = func_graph->NewCNodeInOrder(inputs2);
    if (reverse_) {
      // Reverse mode prepends each call right after the primitive+type nodes.
      inputs.insert(inputs.begin() + kPrimAndTypeLen, call_node);
    } else {
      inputs.emplace_back(call_node);
    }
  }
  return func_graph->NewCNodeInOrder(inputs);
}
267 
// Dispatches HyperMap expansion by argument type: if any argument is a nonleaf
// container (List/Tuple/Class), all container arguments must share that type id
// and the matching FullMake overload recurses into it; otherwise the leaf
// overload applies the mapped function directly.
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  bool found = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  // Find the first argument whose type is a nonleaf container; keep its pair
  // so the matching FullMake overload can be selected below.
  for (auto &item : arg_map) {
    pair = item;
    id = item.second->type_id();
    if (nonleaf_.count(id)) {
      found = true;
      break;
    }
  }

  if (found) {
    // In a nonleaf situation, all arguments must have the same generic.
    bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
      if (item.first != pair.first) {
        return item.second->type_id() != pair.second->type_id();
      }
      return false;
    });
    if (is_not_same) {
      // Build a per-argument type listing for the error message.
      std::ostringstream oss;
      oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
      int64_t idx = 0;
      for (auto &item : arg_map) {
        oss << ++idx << ": " << item.second->ToString() << "\n";
      }
      MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
    }
  }

  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    case kObjectTypeClass: {
      auto type = std::static_pointer_cast<Class>(pair.second);
      return FullMake(type, func_graph, fn_arg, arg_map);
    }
    default:
      // No nonleaf container found: apply the function at the leaf level.
      return FullMake(func_graph, fn_arg, arg_map);
  }
}
318 
// When broadcasting is enabled and at least one argument is a tensor, wraps
// every non-tensor argument in ScalarToArray (with a corresponding TensorType)
// so all arguments become tensors; otherwise returns the arguments unchanged.
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
  TypePtr tensor_type = std::make_shared<TensorType>();
  bool has_tensor = false;
  for (const auto &entry : args_spec_list) {
    if (IsSubType(entry.second, tensor_type)) {
      has_tensor = true;
      break;
    }
  }
  if (!has_tensor || !broadcast_) {
    return args_spec_list;
  }

  ArgsPairList harmonized;
  for (const auto &entry : args_spec_list) {
    if (IsSubType(entry.second, tensor_type)) {
      harmonized.push_back(std::make_pair(entry.first, entry.second));
    } else {
      // Promote the scalar to a tensor node and record its element type.
      TypePtr promoted_type = std::make_shared<TensorType>(entry.second);
      auto promoted_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimScalarToArray), entry.first});
      harmonized.push_back(std::make_pair(promoted_node, promoted_type));
    }
  }
  return harmonized;
}
339 
GenerateFromTypes(const TypePtrList & args_spec_list)340 FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
341   FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
342   ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
343   ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
344   ptr_graph->debug_info()->set_name("hyper_map");
345 
346   AnfNodePtr ptrFnArg = nullptr;
347   std::size_t i = 0;
348   ArgsPairList argmap;
349   ArgsPairList argmap2;
350   if (fn_leaf_ == nullptr) {
351     ptrFnArg = ptr_graph->add_parameter();
352     i = 1;
353   }
354 
355   std::size_t size = args_spec_list.size();
356   for (; i < size; ++i) {
357     argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
358   }
359 
360   argmap2 = Harmonize(ptr_graph, argmap);
361   ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
362   return ptr_graph;
363 }
364 
// Broadens every argument abstract before specialization, and rejects a mapped
// closure that captures free variables when no leaf function is bound.
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  if (fn_leaf_ == nullptr) {
    if (args_spec_list.empty()) {
      MS_LOG(EXCEPTION) << "The args spec list is empty.";
    }
    MS_EXCEPTION_IF_NULL(args_spec_list[0]);
    // Assert that hypermap's function param does not contain free variables
    if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
      auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
      auto func_graph = graph_func->func_graph();
      // A non-null parent graph indicates the closure captures outer-scope values.
      if (func_graph->parent() != nullptr) {
        MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet.";
      }
    }
  }

  // Broaden each abstract (per AbstractBase::Broaden) so specialization keys on
  // types rather than concrete values.
  AbstractBasePtrList broadened;
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         return arg->Broaden();
                       });
  return broadened;
}
389 
// Exposes HyperMapPy to Python as "HyperMap_", with (reverse, ops) and
// (reverse)-only constructors.
REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                         (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
                           .def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"),
                                py::arg("ops"))
                           .def(py::init<bool>(), py::arg("reverse"));
                       }));
396 
CheckSequenceAllTensor(const abstract::AbstractTuplePtr & tuple)397 bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
398   for (size_t i = 0; i < tuple->size(); ++i) {
399     if (!(*tuple)[i]->isa<abstract::AbstractUndetermined>() &&
400         !((*tuple)[i]->isa<abstract::AbstractTuple>() &&
401           CheckSequenceAllTensor((*tuple)[i]->cast<abstract::AbstractTuplePtr>()))) {
402       return false;
403     }
404   }
405   return true;
406 }
407 
// Decides whether Tail (kGradFirst) may return element 1 of the bprop result,
// i.e. the gradient of the first network input. True when that element exists
// and is tensor-like, a Number while grad-for-scalar is enabled, or (with
// enable_tuple_grad) a tuple composed entirely of tensor-like elements.
bool CheckTailGradFristSequence(const abstract::AbstractSequeuePtr &sequeue, bool enable_tuple_grad) {
  return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
         ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
          (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
           (*sequeue)[1]->BuildType()->isa<Number>()) ||
          ((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
           CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
}
416 
// Builds the "tail" graph over a tuple/list abstract. Depending on tail_type_:
// - kGradFirst: return only element 1 (grad of the first input) when it
//   qualifies, otherwise an empty tuple;
// - kGradAll:   return elements [1..n) that are tensor-like (plus Numbers when
//   grad-for-scalar is enabled);
// - otherwise:  return all elements [1..n), i.e. drop the head.
FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {
  MS_EXCEPTION_IF_NULL(sequeue);

  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("tail");
  AnfNodePtr ptrTup = ret->add_parameter();

  // Select make/getitem primitives matching the concrete sequence kind.
  std::vector<AnfNodePtr> elems;
  PrimitivePtr op = nullptr;
  if (sequeue->isa<AbstractTuple>()) {
    elems.push_back(NewValueNode(prim::kPrimMakeTuple));
    op = prim::kPrimTupleGetItem;
  } else {
    elems.push_back(NewValueNode(prim::kPrimMakeList));
    op = prim::kPrimListGetItem;
  }

  if (tail_type_ == kGradFirst) {
    if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
      ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
    } else {
      // First-input grad is unavailable: return an empty tuple value.
      ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
    }

    return ret;
  }

  for (size_t i = 1; i < sequeue->size(); ++i) {
    if (tail_type_ == kGradAll) {
      // Keep only tensor-like elements, plus Numbers when grad-for-scalar is on.
      MS_EXCEPTION_IF_NULL((*sequeue)[i]);
      if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
          (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
           (*sequeue)[i]->BuildType()->isa<Number>())) {
        elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
      }
    } else {
      elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
    }
  }

  ret->set_output(ret->NewCNodeInOrder(elems));
  return ret;
}
461 
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)462 FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
463   if (args_spec_list.size() != 1) {
464     MS_LOG(EXCEPTION) << "Tail requires a non-empty tuple.";
465   }
466 
467   AbstractBasePtr a = args_spec_list[0];
468   if (a->isa<AbstractTuple>() || a->isa<AbstractList>()) {
469     return GenerateSequeueFuncGraph(a->cast<abstract::AbstractSequeuePtr>());
470   }
471 
472   MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString();
473 }
474 
// Exposes Tail to Python as "Tail_", constructed from its name string.
REGISTER_PYBIND_DEFINE(
  Tail_, ([](const py::module *m) {
    (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
  }));
479 
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)480 FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
481   int64_t tuple_size = SizeToLong(args_spec_list.size());
482 
483   std::ostringstream ss;
484   ss << "▶make_tuple_" << tuple_size;
485   FuncGraphPtr fg = std::make_shared<FuncGraph>();
486   fg->debug_info()->set_name(ss.str());
487 
488   std::vector<AnfNodePtr> params;
489   params.push_back(NewValueNode(prim::kPrimMakeTuple));
490   for (int64_t i = 0; i < tuple_size; ++i) {
491     params.push_back(fg->add_parameter());
492   }
493 
494   // make fprob first result, maketuple's forward result.
495   AnfNodePtr out = fg->NewCNodeInOrder(params);
496 
497   // make fprob second result, maketuple's backward function.
498   FuncGraphPtr b = std::make_shared<FuncGraph>();
499 
500   ss.clear();
501   ss << "◀make_tuple_" << tuple_size;
502   b->debug_info()->set_name(ss.str());
503   AnfNodePtr dout = b->add_parameter();
504 
505   std::vector<AnfNodePtr> grads;
506   grads.push_back(NewValueNode(prim::kPrimMakeTuple));
507   grads.push_back(NewValueNode(newenv));
508   for (int64_t i = 0; i < tuple_size; ++i) {
509     grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
510   }
511 
512   b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
513   b->set_output(b->NewCNodeInOrder(grads));
514 
515   fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
516   fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
517   (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
518   return fg;
519 }
520 
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)521 FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
522   int64_t list_size = SizeToLong(args_spec_list.size());
523 
524   std::ostringstream ss;
525   ss << "▶make_list_" << list_size;
526   FuncGraphPtr fg = std::make_shared<FuncGraph>();
527   fg->debug_info()->set_name(ss.str());
528 
529   std::vector<AnfNodePtr> params;
530   params.push_back(NewValueNode(prim::kPrimMakeList));
531   for (int64_t i = 0; i < list_size; ++i) {
532     params.push_back(fg->add_parameter());
533   }
534 
535   // make fprob first result, maketuple's forward result.
536   AnfNodePtr out = fg->NewCNodeInOrder(params);
537 
538   // make fprob second result, maketuple's backward function.
539   FuncGraphPtr b = std::make_shared<FuncGraph>();
540 
541   ss.clear();
542   ss << "◀make_list_" << list_size;
543   b->debug_info()->set_name(ss.str());
544   AnfNodePtr dout = b->add_parameter();
545 
546   std::vector<AnfNodePtr> grads;
547   grads.push_back(NewValueNode(prim::kPrimMakeTuple));
548   grads.push_back(NewValueNode(newenv));
549   for (int64_t i = 0; i < list_size; ++i) {
550     grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
551   }
552 
553   b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
554   b->set_output(b->NewCNodeInOrder(grads));
555 
556   fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
557   fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
558   (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
559   return fg;
560 }
561 
// Constructs a GradOperation.
// get_all:     also return gradients w.r.t. all network inputs.
// get_by_list: also return gradients w.r.t. the supplied weight list (adds the
//              weight_list signature below).
// sens_param:  caller supplies the sensitivity (initial output gradient).
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
    : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
  if (get_by_list) {
    signatures_ =
      // def grad(func:read, weight_list:ref):
      std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
                              {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
  }
}
571 
// Builds the child graph that calls the J-transformed function `k`, splits its
// result into the forward output and the bprop closure, and delegates gradient
// selection to GradByParameter.
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
                                    const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
                                    const std::vector<AnfNodePtr> &weight_args) {
  FuncGraphPtr k_child = std::make_shared<FuncGraph>();
  k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);

  // Weights either come in as a single node or are packed from weight_args;
  // nullptr means no weight gradients are requested.
  AnfNodePtr weights_node = nullptr;
  if (weights != nullptr) {
    weights_node = weights;
  } else if (!weight_args.empty()) {
    weights_node = k_child->NewCNodeInOrder(weight_args);
  }

  // Mirror the forward graph's parameters and call k with them.
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(k);
  for (size_t i = 0; i < forward_graph_params.size(); ++i) {
    inputs.push_back(k_child->add_parameter());
  }
  auto k_app = k_child->NewCNodeInOrder(inputs);

  // k(...) yields a pair: element 0 is the forward result, element 1 the bprop.
  auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem);
  auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
  auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});

  GradByParameter(k_child, f_app, bprop, weights_node, enable_tuple_grad);
  return k_child;
}
599 
// Do grad by the parameter of GradOperation.
// Sets k_child's output according to get_all_/get_by_list_: gradients w.r.t.
// inputs, weights, both, or (default) just the first input.
void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
                                    const AnfNodePtr &weights, bool enable_tuple_grad) {
  MS_EXCEPTION_IF_NULL(k_child);

  // Sensitivity: either an extra caller-supplied parameter (sens_param_) or
  // ones_like(forward output).
  AnfNodePtr bprop_arg = nullptr;
  if (sens_param_) {
    bprop_arg = k_child->add_parameter();
  } else {
    auto ones_like = prim::GetPythonOps("ones_like");
    bprop_arg = k_child->NewCNodeInOrder({NewValueNode(ones_like), f_app});
  }

  // Apply the bprop closure to the sensitivity to obtain all gradients.
  AnfNodePtr b_app = k_child->NewCNodeInOrder({bprop, bprop_arg});

  CNodePtr fv_bprop = nullptr;
  if (get_by_list_) {
    // python code: grads = hyper_map(F.partial(env_get, env), weights)
    AnfNodePtr env =
      k_child->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))});
    AnfNodePtr partial_env_get =
      k_child->NewCNodeInOrder({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
    MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
    fv_bprop = k_child->NewCNodeInOrder({NewValueNode(hyper_map), partial_env_get, weights});
  }

  CNodePtr inputs_bprop = nullptr;
  if (get_all_) {
    // Tail(kGradAll) selects the gradients of all network inputs from b_app.
    TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
    inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app});
  }

  // Gradients wrt inputs and parameters
  if (fv_bprop != nullptr && inputs_bprop != nullptr) {
    k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
    return;
  }

  // Gradients wrt parameters
  if (fv_bprop != nullptr) {
    k_child->set_output(fv_bprop);
    return;
  }

  // Gradients wrt inputs
  if (inputs_bprop != nullptr) {
    k_child->set_output(inputs_bprop);
    return;
  }
  // Gradients wrt first input.
  // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
  // so obtain first input grad by setting tail_type of Tail to kGradFirst.
  TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
  tail_grad_first->set_enable_tuple_grad(enable_tuple_grad);
  k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
}
656 
// Generate the graph.
// Wraps the given forward function graph with the J transform and returns a
// graph whose output is the gradient-computing child graph built by GetGrad.
FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION)
      << "'GradOperation' requires a forward network or function as an input, while the input is empty.";
  }

  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
  if (fn == nullptr) {
    MS_LOG(EXCEPTION) << "'GradOperation' arg0 must be a 'Function' or 'Cell', but got "
                      << args_spec_list[0]->ToString();
  }

  // Waiting for implementation.
  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  MS_EXCEPTION_IF_NULL(real_fn);

  FuncGraphPtr forward_graph = real_fn->func_graph();
  MS_EXCEPTION_IF_NULL(forward_graph);
  forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
  FuncGraphPtr grad_fg = nullptr;
  {
    // Attribute the generated graph to the forward graph's debug location.
    TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    grad_fg = std::make_shared<FuncGraph>();
  }
  auto nparam = forward_graph->parameters().size();

  std::ostringstream ss;
  ss << "grad{" << nparam << "}";
  grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  grad_fg->debug_info()->set_name(ss.str());
  ParameterPtr param_graph = grad_fg->add_parameter();

  // Optional second parameter: the weight list, present only for get_by_list_.
  AnfNodePtr weights = nullptr;
  if (get_by_list_) {
    weights = grad_fg->add_parameter();
  }

  // Apply the J transform to the forward function parameter.
  std::vector<AnfNodePtr> inputs;
  inputs.push_back(NewValueNode(prim::kPrimJ));
  inputs.push_back(param_graph);
  auto j = grad_fg->NewCNodeInOrder(inputs);
  // df is checked in GetGrad
  FuncGraphPtr k_child = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
    k_child = GetGrad(j, weights, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad"));
  }
  grad_fg->set_output(NewValueNode(k_child));

  return grad_fg;
}
710 
// Exposes GradOperation to Python as "GradOperation_", with a name-only
// constructor and a full (fn, get_all, get_by_list, sens_param) constructor.
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
                         (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
                           *m, "GradOperation_")
                           .def(py::init<std::string &>(), py::arg("fn"))
                           .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
                                py::arg("get_by_list"), py::arg("sens_param"));
                       }));
718 
719 // Generate the ListMap func graph.
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)720 FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
721   size_t args_num = args_spec_list.size();
722   // args: fn, list1, list2, ...
723   if (args_num < 2) {
724     MS_LOG(EXCEPTION) << "list_map takes at least two arguments";
725   }
726 
727   for (size_t i = 1; i < args_num; ++i) {
728     if (typeid(args_spec_list[i]) != typeid(AbstractBase)) {
729       // The function currently not be use
730       MS_LOG(EXCEPTION) << "list_map requires lists, not {t}'";
731     }
732   }
733 
734   FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
735   fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
736   fg_ptr->debug_info()->set_name("list_map");
737   AnfNodePtr fn = fg_ptr->add_parameter();
738 
739   std::vector<AnfNodePtr> lists;
740   for (size_t i = 1; i < args_num; ++i) {
741     lists.push_back(fg_ptr->add_parameter());
742   }
743 
744   std::vector<AnfNodePtr> iters;
745   (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
746     return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("list_iter")), item});
747   });
748 
749   std::vector<AnfNodePtr> nexts;
750   (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
751     return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item});
752   });
753 
754   std::vector<AnfNodePtr> values;
755   (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
756     return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item});
757   });
758 
759   (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
760     return fg_ptr->NewCNodeInOrder(
761       {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast<int64_t>(1))});
762   });
763 
764   (void)values.insert(values.begin(), fn);
765   AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values);
766   AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimMakeList), cnode_graph});
767 
768   FuncGraphPtr fgnext_ptr = std::make_shared<FuncGraph>();
769   fgnext_ptr->debug_info()->set_name("body");
770 
771   FuncGraphPtr fgcond_ptr = std::make_shared<FuncGraph>();
772   fgcond_ptr->debug_info()->set_name("cond");
773 
774   MakeCond(lists, fgnext_ptr, fgcond_ptr);
775   MakeNext(lists, fgcond_ptr, fgnext_ptr);
776 
777   CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl});
778 
779   auto inputs = output_cnode->inputs();
780   (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
781   output_cnode->set_inputs(inputs);
782 
783   fg_ptr->set_output(output_cnode);
784   return fg_ptr;
785 }
786 
// Build the loop-condition graph for ListMap. `fg_ptr` is the condition graph
// being populated; `fgnext_ptr` is the loop-body graph invoked when iteration
// continues.
// NOTE(review): like the rest of ListMap this appears to be incomplete dead
// code -- see the notes on the switch construction at the bottom.
void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
                       const FuncGraphPtr &fg_ptr) {
  MS_EXCEPTION_IF_NULL(fg_ptr);

  // Parameters mirror the caller's call: fn, accumulated result, then one
  // iterator per input list.
  AnfNodePtr fn = fg_ptr->add_parameter();
  AnfNodePtr resl = fg_ptr->add_parameter();

  std::vector<AnfNodePtr> iters;
  (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
                       [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });

  // hasnext(iter_i) for every iterator. The results are built but never
  // combined -- the reduce described in the comment below is not implemented.
  std::vector<AnfNodePtr> hasnexts;
  (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) {
    return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("hasnext")), item});
  });

  // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
  // True branch: recurse into the body graph with fn, resl and all iterators.
  FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
  fgtrue_ptr->debug_info()->set_name("ftrue");
  fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);

  CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNodeInOrder({NewValueNode(fgnext_ptr), fn, resl});
  auto inputs = fgtrue_output_cnode->inputs();
  (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
  fgtrue_output_cnode->set_inputs(inputs);
  fgtrue_ptr->set_output(fgtrue_output_cnode);

  // False branch: iteration finished, return the accumulated result.
  FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
  fgfalse_ptr->debug_info()->set_name("ffalse");
  fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fgfalse_ptr->set_output(resl);

  // NOTE(review): two apparent defects here. (1) The switch condition is a
  // placeholder string "cond" instead of a reduction over `hasnexts`.
  // (2) The switch node is created inside fg_ptr but assigned as the output of
  // fgtrue_ptr, overwriting the output set just above, while fg_ptr itself
  // never receives an output -- presumably this was meant to be
  // fg_ptr->set_output(...). Confirm intent before reusing this code.
  AnfNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
                                                     NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)});
  fgtrue_ptr->set_output(output_cnode);
}
823 
MakeNext(const std::vector<AnfNodePtr> & lists,const FuncGraphPtr & fgcond_ptr,const FuncGraphPtr & fg_ptr)824 void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
825                        const FuncGraphPtr &fg_ptr) {
826   MS_EXCEPTION_IF_NULL(fg_ptr);
827   AnfNodePtr fn = fg_ptr->add_parameter();
828 
829   std::vector<AnfNodePtr> iters;
830   (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
831                        [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
832 
833   std::vector<AnfNodePtr> nexts;
834   (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
835     return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item});
836   });
837 
838   std::vector<AnfNodePtr> values;
839   (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
840     return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item, nullptr});
841   });
842 
843   iters.clear();
844   (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
845     return fg_ptr->NewCNodeInOrder(
846       {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast<int64_t>(1))});
847   });
848 
849   (void)values.insert(values.begin(), fn);
850   AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values);
851   AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimListAppend), cnode_graph});
852   CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl});
853 
854   auto inputs = output_cnode->inputs();
855   (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
856   output_cnode->set_inputs(inputs);
857   fg_ptr->set_output(output_cnode);
858 }
859 
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)860 FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
861   // args: tuple1, tuple2
862   abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
863   AbstractBasePtr abs_a = args_spec_list[0];
864   AbstractBasePtr abs_b = args_spec_list[1];
865 
866   abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
867   abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
868   if (a_tuple == nullptr || b_tuple == nullptr) {
869     TypePtrList types;
870     (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
871                          [](const AbstractBasePtr &arg) -> TypePtr {
872                            MS_EXCEPTION_IF_NULL(arg);
873                            return arg->BuildType();
874                          });
875     auto stub = GenerateStubFunc(types);
876     if (stub != nullptr) {
877       MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
878                     << ", function: " << stub->ToString();
879       return stub;
880     }
881     MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple, but " << args_spec_list[0]->ToString() << ", "
882                       << args_spec_list[1]->ToString();
883   }
884 
885   FuncGraphPtr ret = std::make_shared<FuncGraph>();
886   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
887   AnfNodePtr p_tup_a = ret->add_parameter();
888   AnfNodePtr p_tup_b = ret->add_parameter();
889 
890   std::vector<AnfNodePtr> elems;
891   elems.push_back(NewValueNode(prim::kPrimMakeTuple));
892 
893   int64_t tuple_size = SizeToLong(a_tuple->size());
894   for (int64_t i = 0; i < tuple_size; ++i) {
895     elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
896   }
897 
898   tuple_size = SizeToLong(b_tuple->size());
899   for (int64_t i = 0; i < tuple_size; ++i) {
900     elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
901   }
902 
903   ret->set_output(ret->NewCNodeInOrder(elems));
904   return ret;
905 }
906 
GetArgScalarValue(const abstract::AbstractScalarPtr & scalar,const std::string &)907 int64_t GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
908   MS_EXCEPTION_IF_NULL(scalar);
909   return GetValue<int64_t>(scalar->BuildValue());
910 }
911 
GetPositiveIndex(int64_t index,int64_t length)912 int64_t GetPositiveIndex(int64_t index, int64_t length) {
913   if (index < 0) {
914     index += length;
915   }
916   return index;
917 }
918 
CheckSliceMember(const AbstractBasePtr & member,int64_t default_value,const std::string & member_name)919 int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, const std::string &member_name) {
920   MS_EXCEPTION_IF_NULL(member);
921 
922   if (member->isa<AbstractScalar>()) {
923     return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
924   }
925 
926   if (member->isa<AbstractNone>()) {
927     return default_value;
928   }
929 
930   MS_LOG(EXCEPTION) << member_name << " should be a AbstractScalar or AbstractNone, but got " << member->ToString();
931 }
932 
// Resolve a slice's start/stop/step against a tuple of statically-known size,
// mirroring Python slice normalization. On return:
//  - *step_value is the (non-zero) step, defaulting to 1;
//  - *start_index and *stop_index are resolved bounds, clamped so that the
//    iteration in TupleSlice::GenerateFuncGraph stays in range.
// Raises ValueError for a zero step.
// The order of the steps below matters: step must be known before the
// start/stop defaults are chosen, and clamping happens before negative
// indices are shifted.
void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int64_t *start_index,
                                 int64_t *stop_index, int64_t *step_value) {
  MS_EXCEPTION_IF_NULL(tuple);
  MS_EXCEPTION_IF_NULL(slice);
  MS_EXCEPTION_IF_NULL(start_index);
  MS_EXCEPTION_IF_NULL(stop_index);
  MS_EXCEPTION_IF_NULL(step_value);

  const std::string start_name("Slice start index");
  const std::string stop_name("Slice stop index");
  const std::string step_name("Slice step value");

  int64_t tuple_size = SizeToLong(tuple->size());
  int64_t start_default = 0;
  int64_t stop_default = tuple_size;
  int64_t step_default = 1;

  *step_value = CheckSliceMember(slice->step(), step_default, step_name);
  if (*step_value == 0) {
    MS_EXCEPTION(ValueError) << "TupleSlice require the step value could not be 0, but got 0.";
  }

  // Reverse iteration: default to traversing the whole tuple backwards, with
  // -1 as the exclusive stop "before index 0".
  if (*step_value < 0) {
    start_default = tuple_size - 1;
    stop_default = -1;
  }

  *start_index = CheckSliceMember(slice->start(), start_default, start_name);
  *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);

  // Clamp out-of-range bounds; a slice entirely out of range becomes empty
  // (start == stop == 0).
  if (*start_index < -tuple_size) *start_index = 0;
  if (*stop_index > tuple_size) *stop_index = tuple_size;
  if (*start_index > tuple_size || *stop_index < -tuple_size) {
    *start_index = 0;
    *stop_index = 0;
  }

  // Shift negative indices into [0, tuple_size). The stop index is shifted
  // only when it was given explicitly: the default -1 chosen above for a
  // negative step means "before index 0" and must remain negative.
  *start_index = GetPositiveIndex(*start_index, tuple_size);
  if (!slice->stop()->isa<AbstractNone>()) {
    *stop_index = GetPositiveIndex(*stop_index, tuple_size);
  }
}
975 
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)976 FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
977   // slice a tuple
978   // args: tuple, start index, end index, step
979   const std::string op_name("TupleSlice");
980   constexpr size_t arg_size = 2;
981   abstract::CheckArgsSize(op_name, args_spec_list, arg_size);
982   AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
983   AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
984 
985   int64_t start_index;
986   int64_t stop_index;
987   int64_t step_value;
988   GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
989 
990   FuncGraphPtr ret = std::make_shared<FuncGraph>();
991   ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
992   AnfNodePtr p_tuple = ret->add_parameter();
993   (void)ret->add_parameter();
994 
995   std::vector<AnfNodePtr> elems;
996   elems.push_back(NewValueNode(prim::kPrimMakeTuple));
997   if (step_value > 0) {
998     for (int64_t index = start_index; index < stop_index; index = index + step_value) {
999       elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
1000     }
1001   } else {
1002     for (int64_t index = start_index; index > stop_index; index = index + step_value) {
1003       elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
1004     }
1005   }
1006 
1007   ret->set_output(ret->NewCNodeInOrder(elems));
1008   return ret;
1009 }
1010 
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)1011 FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
1012   // select indexed item
1013   // args: tuple of items, index
1014   const std::string op_name = std::string("TupleGetItemTensor");
1015   const size_t inputs_size = 2;
1016   abstract::CheckArgsSize(op_name, args_spec_list, inputs_size);
1017   auto ret_graph = std::make_shared<FuncGraph>();
1018   ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1019   auto functions = ret_graph->add_parameter();
1020   auto index = ret_graph->add_parameter();
1021 
1022   ret_graph->set_output(ret_graph->NewCNodeInOrder({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
1023   return ret_graph;
1024 }
1025 
// Expose TupleAdd to Python as `TupleAdd_` (single string-constructor).
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                         (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                           .def(py::init<std::string &>());
                       }));
1030 
// Expose TupleSlice to Python as `TupleSlice_` (single string-constructor).
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
                         (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
                           .def(py::init<std::string &>());
                       }));
1035 
// Expose TupleGetItemTensor to Python as `TupleGetItemTensor_` (single
// string-constructor).
REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
                           *m, "TupleGetItemTensor_")
                           .def(py::init<std::string &>());
                       }));
1041 }  // namespace prim
1042 }  // namespace mindspore
1043