1
2 /**
3 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
4 *
5 * Copyright 2019-2021 Huawei Technologies Co., Ltd
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19
20 #include "frontend/operator/composite/composite.h"
21 #include <algorithm>
22 #include <utility>
23 #include <sstream>
24
25 #include "ir/anf.h"
26 #include "ir/func_graph.h"
27 #include "abstract/abstract_value.h"
28 #include "abstract/abstract_function.h"
29 #include "abstract/dshape.h"
30 #include "abstract/param_validator.h"
31 #include "frontend/operator/cc_implementations.h"
32 #include "frontend/optimizer/opt.h"
33 #include "utils/symbolic.h"
34 #include "pybind_api/api_register.h"
35 #include "ir/signature.h"
36 #include "debug/trace.h"
37 #include "utils/ms_context.h"
38 #include "utils/utils.h"
39
40 namespace mindspore {
41 // namespace to support composite operators definition
42 namespace prim {
43 using AbstractTensor = mindspore::abstract::AbstractTensor;
44 using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
45
46 using mindspore::abstract::AbstractAttribute;
47 using mindspore::abstract::AbstractBase;
48 using mindspore::abstract::AbstractClass;
49 using mindspore::abstract::AbstractDictionary;
50 using mindspore::abstract::AbstractDictionaryPtr;
51 using mindspore::abstract::AbstractEllipsis;
52 using mindspore::abstract::AbstractEllipsisPtr;
53 using mindspore::abstract::AbstractFunction;
54 using mindspore::abstract::AbstractFunctionPtr;
55 using mindspore::abstract::AbstractList;
56 using mindspore::abstract::AbstractNone;
57 using mindspore::abstract::AbstractScalar;
58 using mindspore::abstract::AbstractSlice;
59 using mindspore::abstract::AbstractTuple;
60
61 ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
62 {"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod},
63 {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt},
64 {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe},
65 {"__ge__", kPrimScalarGe}};
66
67 ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();
68
Init()69 void HyperMap::Init() {
70 if (fn_leaf_) {
71 name_ = "hyper_map[" + fn_leaf_->name() + "]";
72 }
73 signatures_ =
74 // def hypermap(func:read, *args:ref):
75 std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
76 {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
77 }
78
HyperMap(bool reverse,const std::shared_ptr<MultitypeFuncGraph> & fn_leaf)79 HyperMap::HyperMap(bool reverse, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf)
80 : MetaFuncGraph("hyper_map"),
81 fn_leaf_(fn_leaf),
82 reverse_(reverse),
83 broadcast_(false),
84 nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
85 Init();
86 }
87
HyperMap(const HyperMap & h)88 HyperMap::HyperMap(const HyperMap &h)
89 : MetaFuncGraph("hyper_map"),
90 fn_leaf_(h.fn_leaf_),
91 reverse_(h.reverse_),
92 broadcast_(h.broadcast_),
93 nonleaf_(h.nonleaf_) {
94 Init();
95 }
96
// Leaf case of HyperMap: emits a single call fn(arg0, arg1, ...), where fn is `fn_arg` when one was
// supplied and otherwise the bound leaf MultitypeFuncGraph.
AnfNodePtr HyperMap::FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(arg_map.size() + 1);
  if (fn_arg != nullptr) {
    inputs.push_back(fn_arg);
  } else {
    inputs.push_back(NewValueNode(fn_leaf_));
  }

  // Bind to the real ArgsPairList element type so no temporary pair<AnfNodePtr, Any> is built per element.
  (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs),
                       [](const std::pair<AnfNodePtr, TypePtr> &item) { return item.first; });
  return func_graph->NewCNodeInOrder(inputs);
}
110
// List case of HyperMap. Every argument must be a List with as many elements as `type`; the result is
//   make_list(hyper_map(fn, arg0[k], arg1[k], ...) for k = 0 .. size-1)
// where `hyper_map` is a copy of this HyperMap (so nested containers are mapped recursively) and `fn`
// is passed through when `fn_arg` is supplied. With reverse_ set, the per-element calls are created
// from the last index to the first, while the resulting list keeps the original element order.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  size_t size = type->elements().size();
  size_t num = 0;
  bool is_not_same =
    std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
      num++;
      MS_EXCEPTION_IF_NULL(item.second);
      // dyn_cast yields nullptr for a non-List type. static_pointer_cast never does, which made this
      // wrong-type check unreachable and reinterpreted mismatched types as List.
      auto lhs = dyn_cast<List>(item.second);
      if (lhs == nullptr) {
        MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a List, but got "
                          << item.second->ToString();
      }
      if (lhs->elements().size() != size) {
        MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
                      << lhs->elements().size();
        return true;
      }
      return false;
    });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "List in HyperMap should have same length";
  }

  // Cannot use shared_from_base() (i.e. `this`): it would create a reference cycle between this
  // HyperMap and the generated graph and leak memory. Recurse through a copy instead.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  constexpr size_t kPrimHoldLen = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + kPrimHoldLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeList));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeList for the " << i << "th element of the target, reverse_: " << reverse_;
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }
    size_t pos = (reverse_ ? (size - 1 - i) : i);
    (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
                         [&func_graph, pos](const std::pair<AnfNodePtr, TypePtr> &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });
    inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
  }
  // In reverse mode the calls were appended last-index-first. Restore ascending element order with one
  // O(n) reverse instead of repeated front insertions.
  if (reverse_) {
    std::reverse(inputs.begin() + kPrimHoldLen, inputs.end());
  }
  return func_graph->NewCNodeInOrder(inputs);
}
168
// Tuple case of HyperMap. Every argument must be a Tuple with as many elements as `type`; the result is
//   make_tuple(hyper_map(fn, arg0[k], arg1[k], ...) for k = 0 .. size-1)
// where `hyper_map` is a copy of this HyperMap (so nested containers are mapped recursively) and `fn`
// is passed through when `fn_arg` is supplied. With reverse_ set, the per-element calls are created
// from the last index to the first, while the resulting tuple keeps the original element order.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  size_t size = type->elements().size();
  size_t num = 0;
  bool is_not_same =
    std::any_of(arg_map.begin(), arg_map.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
      num++;
      MS_EXCEPTION_IF_NULL(item.second);
      // dyn_cast yields nullptr for a non-Tuple type. static_pointer_cast never does, which made this
      // wrong-type check unreachable and reinterpreted mismatched types as Tuple.
      auto lhs = dyn_cast<Tuple>(item.second);
      if (lhs == nullptr) {
        MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a Tuple, but got "
                          << item.second->ToString();
      }
      if (lhs->elements().size() != size) {
        MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
                      << lhs->elements().size();
        return true;
      }
      return false;
    });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "Tuple in HyperMap should have same length";
  }

  // Cannot use shared_from_base() (i.e. `this`): it would create a reference cycle between this
  // HyperMap and the generated graph and leak memory. Recurse through a copy instead.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  constexpr size_t kPrimHoldLen = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + kPrimHoldLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeTuple));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th element of the target, reverse_: " << reverse_;
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }
    size_t pos = (reverse_ ? (size - 1 - i) : i);
    // Take each pair by const reference (the original lambda copied a converted pair per element).
    (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs2),
                         [&func_graph, pos](const std::pair<AnfNodePtr, TypePtr> &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });
    inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
  }
  // In reverse mode the calls were appended last-index-first. Restore ascending element order with one
  // O(n) reverse instead of repeated front insertions.
  if (reverse_) {
    std::reverse(inputs.begin() + kPrimHoldLen, inputs.end());
  }
  return func_graph->NewCNodeInOrder(inputs);
}
226
// Class case of HyperMap. Builds
//   make_record(type, hyper_map(fn, arg0.attr[k], arg1.attr[k], ...) for each attribute k)
// where `hyper_map` is a copy of this HyperMap and `fn` is passed through when `fn_arg` is supplied.
// Each per-attribute call reads the SAME attribute index from every argument, and arguments keep their
// original order. reverse_ only changes the order in which attribute calls are created; the record
// still lists attributes in ascending order.
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(func_graph);

  std::size_t attrSize = type->GetAttributes().size();
  constexpr size_t kPrimAndTypeLen = 2;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(attrSize + kPrimAndTypeLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  inputs.push_back(NewValueNode(type));

  // Cannot use shared_from_base() (i.e. `this`): it would create a reference cycle between this
  // HyperMap and the generated graph and leak memory. Recurse through a copy instead.
  auto fn_rec = NewValueNode(std::make_shared<HyperMap>(*this));
  for (std::size_t i = 0; i < attrSize; i++) {
    MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the target, reverse_: " << reverse_;
    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn_rec);
    if (fn_arg) {
      inputs2.push_back(fn_arg);
    }

    // Attribute handled by this call. Previously the argument position was used as the attribute index
    // (so every call was identical), and reverse_ reordered the arguments passed to fn.
    size_t attr_pos = (reverse_ ? (attrSize - 1 - i) : i);
    for (const auto &item : arg_map) {
      inputs2.push_back(func_graph->NewCNodeInOrder(
        {NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(SizeToLong(attr_pos))}));
    }

    inputs.push_back(func_graph->NewCNodeInOrder(inputs2));
  }
  // In reverse mode the calls were appended last-attribute-first; restore ascending attribute order.
  if (reverse_) {
    std::reverse(inputs.begin() + kPrimAndTypeLen, inputs.end());
  }
  return func_graph->NewCNodeInOrder(inputs);
}
267
// Dispatch on argument types. If any argument's type id is in nonleaf_ (list/tuple/class), every
// other argument (with a different node) must have that same type id, and the matching container
// FullMake overload is used. Otherwise the leaf FullMake emits a direct call.
AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) {
  const auto first_nonleaf =
    std::find_if(arg_map.begin(), arg_map.end(), [this](const std::pair<AnfNodePtr, TypePtr> &item) {
      return nonleaf_.count(item.second->type_id()) != 0;
    });
  const bool found = (first_nonleaf != arg_map.end());

  // Representative argument: the first non-leaf one, or the last argument when there is none.
  std::pair<AnfNodePtr, TypePtr> pair;
  TypeId id = kObjectTypeEnd;
  if (!arg_map.empty()) {
    pair = found ? *first_nonleaf : arg_map.back();
    id = pair.second->type_id();
  }

  if (found) {
    const bool mismatch =
      std::any_of(arg_map.begin(), arg_map.end(), [&pair](const std::pair<AnfNodePtr, TypePtr> &item) {
        return item.first != pair.first && item.second->type_id() != pair.second->type_id();
      });
    if (mismatch) {
      std::ostringstream oss;
      oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
      int64_t idx = 0;
      for (const auto &item : arg_map) {
        oss << ++idx << ": " << item.second->ToString() << "\n";
      }
      MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str();
    }
  }

  if (id == kObjectTypeList) {
    return FullMake(std::static_pointer_cast<List>(pair.second), func_graph, fn_arg, arg_map);
  }
  if (id == kObjectTypeTuple) {
    return FullMake(std::static_pointer_cast<Tuple>(pair.second), func_graph, fn_arg, arg_map);
  }
  if (id == kObjectTypeClass) {
    return FullMake(std::static_pointer_cast<Class>(pair.second), func_graph, fn_arg, arg_map);
  }
  return FullMake(func_graph, fn_arg, arg_map);
}
318
// When broadcast_ is enabled and at least one argument is a tensor, every non-tensor argument is
// wrapped in ScalarToArray and retyped as TensorType(original type). Otherwise the arguments are
// returned unchanged.
ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) {
  TypePtr type_tensor = std::make_shared<TensorType>();
  auto is_tensor = [&type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) {
    return IsSubType(item.second, type_tensor);
  };
  const bool has_tensor = std::any_of(args_spec_list.begin(), args_spec_list.end(), is_tensor);
  if (!has_tensor || !broadcast_) {
    return args_spec_list;
  }

  ArgsPairList harmonized;
  for (const auto &arg : args_spec_list) {
    if (IsSubType(arg.second, type_tensor)) {
      harmonized.push_back(arg);
      continue;
    }
    TypePtr lifted_type = std::make_shared<TensorType>(arg.second);
    auto lifted_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimScalarToArray), arg.first});
    harmonized.emplace_back(lifted_node, lifted_type);
  }
  return harmonized;
}
339
GenerateFromTypes(const TypePtrList & args_spec_list)340 FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
341 FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
342 ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
343 ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
344 ptr_graph->debug_info()->set_name("hyper_map");
345
346 AnfNodePtr ptrFnArg = nullptr;
347 std::size_t i = 0;
348 ArgsPairList argmap;
349 ArgsPairList argmap2;
350 if (fn_leaf_ == nullptr) {
351 ptrFnArg = ptr_graph->add_parameter();
352 i = 1;
353 }
354
355 std::size_t size = args_spec_list.size();
356 for (; i < size; ++i) {
357 argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
358 }
359
360 argmap2 = Harmonize(ptr_graph, argmap);
361 ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
362 return ptr_graph;
363 }
364
NormalizeArgs(const AbstractBasePtrList & args_spec_list) const365 abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
366 if (fn_leaf_ == nullptr) {
367 if (args_spec_list.empty()) {
368 MS_LOG(EXCEPTION) << "The args spec list is empty.";
369 }
370 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
371 // Assert that hypermap's function param does not contain free variables
372 if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
373 auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
374 auto func_graph = graph_func->func_graph();
375 if (func_graph->parent() != nullptr) {
376 MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet.";
377 }
378 }
379 }
380
381 AbstractBasePtrList broadened;
382 (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
383 [](const AbstractBasePtr &arg) -> AbstractBasePtr {
384 MS_EXCEPTION_IF_NULL(arg);
385 return arg->Broaden();
386 });
387 return broadened;
388 }
389
__anon6b0a12740902(const py::module *m) 390 REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
391 (void)py::class_<HyperMapPy, MetaFuncGraph, std::shared_ptr<HyperMapPy>>(*m, "HyperMap_")
392 .def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"),
393 py::arg("ops"))
394 .def(py::init<bool>(), py::arg("reverse"));
395 }));
396
CheckSequenceAllTensor(const abstract::AbstractTuplePtr & tuple)397 bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
398 for (size_t i = 0; i < tuple->size(); ++i) {
399 if (!(*tuple)[i]->isa<abstract::AbstractUndetermined>() &&
400 !((*tuple)[i]->isa<abstract::AbstractTuple>() &&
401 CheckSequenceAllTensor((*tuple)[i]->cast<abstract::AbstractTuplePtr>()))) {
402 return false;
403 }
404 }
405 return true;
406 }
407
CheckTailGradFristSequence(const abstract::AbstractSequeuePtr & sequeue,bool enable_tuple_grad)408 bool CheckTailGradFristSequence(const abstract::AbstractSequeuePtr &sequeue, bool enable_tuple_grad) {
409 return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
410 ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
411 (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
412 (*sequeue)[1]->BuildType()->isa<Number>()) ||
413 ((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
414 CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
415 }
416
GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr & sequeue) const417 FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {
418 MS_EXCEPTION_IF_NULL(sequeue);
419
420 FuncGraphPtr ret = std::make_shared<FuncGraph>();
421 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
422 ret->debug_info()->set_name("tail");
423 AnfNodePtr ptrTup = ret->add_parameter();
424
425 std::vector<AnfNodePtr> elems;
426 PrimitivePtr op = nullptr;
427 if (sequeue->isa<AbstractTuple>()) {
428 elems.push_back(NewValueNode(prim::kPrimMakeTuple));
429 op = prim::kPrimTupleGetItem;
430 } else {
431 elems.push_back(NewValueNode(prim::kPrimMakeList));
432 op = prim::kPrimListGetItem;
433 }
434
435 if (tail_type_ == kGradFirst) {
436 if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
437 ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
438 } else {
439 ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
440 }
441
442 return ret;
443 }
444
445 for (size_t i = 1; i < sequeue->size(); ++i) {
446 if (tail_type_ == kGradAll) {
447 MS_EXCEPTION_IF_NULL((*sequeue)[i]);
448 if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() ||
449 (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[i]->BuildType() != nullptr &&
450 (*sequeue)[i]->BuildType()->isa<Number>())) {
451 elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
452 }
453 } else {
454 elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))}));
455 }
456 }
457
458 ret->set_output(ret->NewCNodeInOrder(elems));
459 return ret;
460 }
461
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)462 FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
463 if (args_spec_list.size() != 1) {
464 MS_LOG(EXCEPTION) << "Tail requires a non-empty tuple.";
465 }
466
467 AbstractBasePtr a = args_spec_list[0];
468 if (a->isa<AbstractTuple>() || a->isa<AbstractList>()) {
469 return GenerateSequeueFuncGraph(a->cast<abstract::AbstractSequeuePtr>());
470 }
471
472 MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString();
473 }
474
475 REGISTER_PYBIND_DEFINE(
__anon6b0a12740a02(const py::module *m) 476 Tail_, ([](const py::module *m) {
477 (void)py::class_<Tail, MetaFuncGraph, std::shared_ptr<Tail>>(*m, "Tail_").def(py::init<std::string &>());
478 }));
479
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)480 FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
481 int64_t tuple_size = SizeToLong(args_spec_list.size());
482
483 std::ostringstream ss;
484 ss << "▶make_tuple_" << tuple_size;
485 FuncGraphPtr fg = std::make_shared<FuncGraph>();
486 fg->debug_info()->set_name(ss.str());
487
488 std::vector<AnfNodePtr> params;
489 params.push_back(NewValueNode(prim::kPrimMakeTuple));
490 for (int64_t i = 0; i < tuple_size; ++i) {
491 params.push_back(fg->add_parameter());
492 }
493
494 // make fprob first result, maketuple's forward result.
495 AnfNodePtr out = fg->NewCNodeInOrder(params);
496
497 // make fprob second result, maketuple's backward function.
498 FuncGraphPtr b = std::make_shared<FuncGraph>();
499
500 ss.clear();
501 ss << "◀make_tuple_" << tuple_size;
502 b->debug_info()->set_name(ss.str());
503 AnfNodePtr dout = b->add_parameter();
504
505 std::vector<AnfNodePtr> grads;
506 grads.push_back(NewValueNode(prim::kPrimMakeTuple));
507 grads.push_back(NewValueNode(newenv));
508 for (int64_t i = 0; i < tuple_size; ++i) {
509 grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
510 }
511
512 b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
513 b->set_output(b->NewCNodeInOrder(grads));
514
515 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
516 fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
517 (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
518 return fg;
519 }
520
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)521 FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
522 int64_t list_size = SizeToLong(args_spec_list.size());
523
524 std::ostringstream ss;
525 ss << "▶make_list_" << list_size;
526 FuncGraphPtr fg = std::make_shared<FuncGraph>();
527 fg->debug_info()->set_name(ss.str());
528
529 std::vector<AnfNodePtr> params;
530 params.push_back(NewValueNode(prim::kPrimMakeList));
531 for (int64_t i = 0; i < list_size; ++i) {
532 params.push_back(fg->add_parameter());
533 }
534
535 // make fprob first result, maketuple's forward result.
536 AnfNodePtr out = fg->NewCNodeInOrder(params);
537
538 // make fprob second result, maketuple's backward function.
539 FuncGraphPtr b = std::make_shared<FuncGraph>();
540
541 ss.clear();
542 ss << "◀make_list_" << list_size;
543 b->debug_info()->set_name(ss.str());
544 AnfNodePtr dout = b->add_parameter();
545
546 std::vector<AnfNodePtr> grads;
547 grads.push_back(NewValueNode(prim::kPrimMakeTuple));
548 grads.push_back(NewValueNode(newenv));
549 for (int64_t i = 0; i < list_size; ++i) {
550 grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
551 }
552
553 b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
554 b->set_output(b->NewCNodeInOrder(grads));
555
556 fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
557 fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
558 (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
559 return fg;
560 }
561
GradOperation(const std::string & name,bool get_all,bool get_by_list,bool sens_param)562 GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
563 : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
564 if (get_by_list) {
565 signatures_ =
566 // def grad(func:read, weight_list:ref):
567 std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
568 {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}});
569 }
570 }
571
GetGrad(const AnfNodePtr & k,const AnfNodePtr & weights,const std::vector<AnfNodePtr> & forward_graph_params,bool enable_tuple_grad,const std::vector<AnfNodePtr> & weight_args)572 FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
573 const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
574 const std::vector<AnfNodePtr> &weight_args) {
575 FuncGraphPtr k_child = std::make_shared<FuncGraph>();
576 k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
577
578 AnfNodePtr weights_node = nullptr;
579 if (weights != nullptr) {
580 weights_node = weights;
581 } else if (!weight_args.empty()) {
582 weights_node = k_child->NewCNodeInOrder(weight_args);
583 }
584
585 std::vector<AnfNodePtr> inputs;
586 inputs.push_back(k);
587 for (size_t i = 0; i < forward_graph_params.size(); ++i) {
588 inputs.push_back(k_child->add_parameter());
589 }
590 auto k_app = k_child->NewCNodeInOrder(inputs);
591
592 auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem);
593 auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
594 auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
595
596 GradByParameter(k_child, f_app, bprop, weights_node, enable_tuple_grad);
597 return k_child;
598 }
599
600 // Do grad by the parameter of GradOperation.
GradByParameter(const FuncGraphPtr & k_child,const AnfNodePtr & f_app,const AnfNodePtr & bprop,const AnfNodePtr & weights,bool enable_tuple_grad)601 void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
602 const AnfNodePtr &weights, bool enable_tuple_grad) {
603 MS_EXCEPTION_IF_NULL(k_child);
604
605 AnfNodePtr bprop_arg = nullptr;
606 if (sens_param_) {
607 bprop_arg = k_child->add_parameter();
608 } else {
609 auto ones_like = prim::GetPythonOps("ones_like");
610 bprop_arg = k_child->NewCNodeInOrder({NewValueNode(ones_like), f_app});
611 }
612
613 AnfNodePtr b_app = k_child->NewCNodeInOrder({bprop, bprop_arg});
614
615 CNodePtr fv_bprop = nullptr;
616 if (get_by_list_) {
617 // python code: grads = hyper_map(F.partial(env_get, env), weights)
618 AnfNodePtr env =
619 k_child->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))});
620 AnfNodePtr partial_env_get =
621 k_child->NewCNodeInOrder({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
622 MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
623 fv_bprop = k_child->NewCNodeInOrder({NewValueNode(hyper_map), partial_env_get, weights});
624 }
625
626 CNodePtr inputs_bprop = nullptr;
627 if (get_all_) {
628 TailPtr tail_grad_all = std::make_shared<Tail>("tail_grad_all", kGradAll);
629 inputs_bprop = k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app});
630 }
631
632 // Gradients wrt inputs and parameters
633 if (fv_bprop != nullptr && inputs_bprop != nullptr) {
634 k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
635 return;
636 }
637
638 // Gradients wrt parameters
639 if (fv_bprop != nullptr) {
640 k_child->set_output(fv_bprop);
641 return;
642 }
643
644 // Gradients wrt inputs
645 if (inputs_bprop != nullptr) {
646 k_child->set_output(inputs_bprop);
647 return;
648 }
649 // Gradients wrt first input.
650 // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
651 // so obtain first input grad by setting tail_type of Tail to kGradFirst.
652 TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
653 tail_grad_first->set_enable_tuple_grad(enable_tuple_grad);
654 k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
655 }
656
657 // Generate the graph.
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)658 FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
659 if (args_spec_list.empty()) {
660 MS_LOG(EXCEPTION)
661 << "'GradOperation' requires a forward network or function as an input, while the input is empty.";
662 }
663
664 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
665 AbstractFunctionPtr fn = dyn_cast<AbstractFunction>(args_spec_list[0]);
666 if (fn == nullptr) {
667 MS_LOG(EXCEPTION) << "'GradOperation' arg0 must be a 'Function' or 'Cell', but got "
668 << args_spec_list[0]->ToString();
669 }
670
671 // Waiting for implementation.
672 auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
673 MS_EXCEPTION_IF_NULL(real_fn);
674
675 FuncGraphPtr forward_graph = real_fn->func_graph();
676 MS_EXCEPTION_IF_NULL(forward_graph);
677 forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
678 FuncGraphPtr grad_fg = nullptr;
679 {
680 TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
681 grad_fg = std::make_shared<FuncGraph>();
682 }
683 auto nparam = forward_graph->parameters().size();
684
685 std::ostringstream ss;
686 ss << "grad{" << nparam << "}";
687 grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
688 grad_fg->debug_info()->set_name(ss.str());
689 ParameterPtr param_graph = grad_fg->add_parameter();
690
691 AnfNodePtr weights = nullptr;
692 if (get_by_list_) {
693 weights = grad_fg->add_parameter();
694 }
695
696 std::vector<AnfNodePtr> inputs;
697 inputs.push_back(NewValueNode(prim::kPrimJ));
698 inputs.push_back(param_graph);
699 auto j = grad_fg->NewCNodeInOrder(inputs);
700 // df is checked in GetGrad
701 FuncGraphPtr k_child = nullptr;
702 {
703 TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
704 k_child = GetGrad(j, weights, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad"));
705 }
706 grad_fg->set_output(NewValueNode(k_child));
707
708 return grad_fg;
709 }
710
__anon6b0a12740b02(const py::module *m) 711 REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
712 (void)py::class_<GradOperation, MetaFuncGraph, std::shared_ptr<GradOperation>>(
713 *m, "GradOperation_")
714 .def(py::init<std::string &>(), py::arg("fn"))
715 .def(py::init<std::string &, bool, bool, bool>(), py::arg("fn"), py::arg("get_all"),
716 py::arg("get_by_list"), py::arg("sens_param"));
717 }));
718
719 // Generate the ListMap func graph.
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)720 FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
721 size_t args_num = args_spec_list.size();
722 // args: fn, list1, list2, ...
723 if (args_num < 2) {
724 MS_LOG(EXCEPTION) << "list_map takes at least two arguments";
725 }
726
727 for (size_t i = 1; i < args_num; ++i) {
728 if (typeid(args_spec_list[i]) != typeid(AbstractBase)) {
729 // The function currently not be use
730 MS_LOG(EXCEPTION) << "list_map requires lists, not {t}'";
731 }
732 }
733
734 FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
735 fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
736 fg_ptr->debug_info()->set_name("list_map");
737 AnfNodePtr fn = fg_ptr->add_parameter();
738
739 std::vector<AnfNodePtr> lists;
740 for (size_t i = 1; i < args_num; ++i) {
741 lists.push_back(fg_ptr->add_parameter());
742 }
743
744 std::vector<AnfNodePtr> iters;
745 (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
746 return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("list_iter")), item});
747 });
748
749 std::vector<AnfNodePtr> nexts;
750 (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
751 return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item});
752 });
753
754 std::vector<AnfNodePtr> values;
755 (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
756 return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item});
757 });
758
759 (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
760 return fg_ptr->NewCNodeInOrder(
761 {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast<int64_t>(1))});
762 });
763
764 (void)values.insert(values.begin(), fn);
765 AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values);
766 AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimMakeList), cnode_graph});
767
768 FuncGraphPtr fgnext_ptr = std::make_shared<FuncGraph>();
769 fgnext_ptr->debug_info()->set_name("body");
770
771 FuncGraphPtr fgcond_ptr = std::make_shared<FuncGraph>();
772 fgcond_ptr->debug_info()->set_name("cond");
773
774 MakeCond(lists, fgnext_ptr, fgcond_ptr);
775 MakeNext(lists, fgcond_ptr, fgnext_ptr);
776
777 CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl});
778
779 auto inputs = output_cnode->inputs();
780 (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
781 output_cnode->set_inputs(inputs);
782
783 fg_ptr->set_output(output_cnode);
784 return fg_ptr;
785 }
786
MakeCond(const std::vector<AnfNodePtr> & lists,const FuncGraphPtr & fgnext_ptr,const FuncGraphPtr & fg_ptr)787 void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgnext_ptr,
788 const FuncGraphPtr &fg_ptr) {
789 MS_EXCEPTION_IF_NULL(fg_ptr);
790
791 AnfNodePtr fn = fg_ptr->add_parameter();
792 AnfNodePtr resl = fg_ptr->add_parameter();
793
794 std::vector<AnfNodePtr> iters;
795 (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
796 [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
797
798 std::vector<AnfNodePtr> hasnexts;
799 (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) {
800 return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("hasnext")), item});
801 });
802
803 // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
804 FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
805 fgtrue_ptr->debug_info()->set_name("ftrue");
806 fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
807
808 CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNodeInOrder({NewValueNode(fgnext_ptr), fn, resl});
809 auto inputs = fgtrue_output_cnode->inputs();
810 (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
811 fgtrue_output_cnode->set_inputs(inputs);
812 fgtrue_ptr->set_output(fgtrue_output_cnode);
813
814 FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
815 fgfalse_ptr->debug_info()->set_name("ffalse");
816 fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
817 fgfalse_ptr->set_output(resl);
818
819 AnfNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
820 NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)});
821 fgtrue_ptr->set_output(output_cnode);
822 }
823
MakeNext(const std::vector<AnfNodePtr> & lists,const FuncGraphPtr & fgcond_ptr,const FuncGraphPtr & fg_ptr)824 void ListMap::MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &fgcond_ptr,
825 const FuncGraphPtr &fg_ptr) {
826 MS_EXCEPTION_IF_NULL(fg_ptr);
827 AnfNodePtr fn = fg_ptr->add_parameter();
828
829 std::vector<AnfNodePtr> iters;
830 (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters),
831 [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); });
832
833 std::vector<AnfNodePtr> nexts;
834 (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) {
835 return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item});
836 });
837
838 std::vector<AnfNodePtr> values;
839 (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) {
840 return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item, nullptr});
841 });
842
843 iters.clear();
844 (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) {
845 return fg_ptr->NewCNodeInOrder(
846 {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast<int64_t>(1))});
847 });
848
849 (void)values.insert(values.begin(), fn);
850 AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values);
851 AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimListAppend), cnode_graph});
852 CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl});
853
854 auto inputs = output_cnode->inputs();
855 (void)inputs.insert(inputs.end(), iters.begin(), iters.end());
856 output_cnode->set_inputs(inputs);
857 fg_ptr->set_output(output_cnode);
858 }
859
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)860 FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
861 // args: tuple1, tuple2
862 abstract::CheckArgsSize("TupleAdd", args_spec_list, 2);
863 AbstractBasePtr abs_a = args_spec_list[0];
864 AbstractBasePtr abs_b = args_spec_list[1];
865
866 abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
867 abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
868 if (a_tuple == nullptr || b_tuple == nullptr) {
869 TypePtrList types;
870 (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
871 [](const AbstractBasePtr &arg) -> TypePtr {
872 MS_EXCEPTION_IF_NULL(arg);
873 return arg->BuildType();
874 });
875 auto stub = GenerateStubFunc(types);
876 if (stub != nullptr) {
877 MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
878 << ", function: " << stub->ToString();
879 return stub;
880 }
881 MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple, but " << args_spec_list[0]->ToString() << ", "
882 << args_spec_list[1]->ToString();
883 }
884
885 FuncGraphPtr ret = std::make_shared<FuncGraph>();
886 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
887 AnfNodePtr p_tup_a = ret->add_parameter();
888 AnfNodePtr p_tup_b = ret->add_parameter();
889
890 std::vector<AnfNodePtr> elems;
891 elems.push_back(NewValueNode(prim::kPrimMakeTuple));
892
893 int64_t tuple_size = SizeToLong(a_tuple->size());
894 for (int64_t i = 0; i < tuple_size; ++i) {
895 elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)}));
896 }
897
898 tuple_size = SizeToLong(b_tuple->size());
899 for (int64_t i = 0; i < tuple_size; ++i) {
900 elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)}));
901 }
902
903 ret->set_output(ret->NewCNodeInOrder(elems));
904 return ret;
905 }
906
GetArgScalarValue(const abstract::AbstractScalarPtr & scalar,const std::string &)907 int64_t GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) {
908 MS_EXCEPTION_IF_NULL(scalar);
909 return GetValue<int64_t>(scalar->BuildValue());
910 }
911
/**
 * Convert a Python-style negative index by adding `length` once.
 * Non-negative indices are returned unchanged. No clamping is done, so an index below -length stays negative.
 */
int64_t GetPositiveIndex(int64_t index, int64_t length) { return index >= 0 ? index : index + length; }
918
CheckSliceMember(const AbstractBasePtr & member,int64_t default_value,const std::string & member_name)919 int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, const std::string &member_name) {
920 MS_EXCEPTION_IF_NULL(member);
921
922 if (member->isa<AbstractScalar>()) {
923 return GetArgScalarValue(dyn_cast<AbstractScalar>(member), member_name);
924 }
925
926 if (member->isa<AbstractNone>()) {
927 return default_value;
928 }
929
930 MS_LOG(EXCEPTION) << member_name << " should be a AbstractScalar or AbstractNone, but got " << member->ToString();
931 }
932
GenerateTupleSliceParameter(const AbstractTuplePtr & tuple,const AbstractSlicePtr & slice,int64_t * start_index,int64_t * stop_index,int64_t * step_value)933 void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int64_t *start_index,
934 int64_t *stop_index, int64_t *step_value) {
935 MS_EXCEPTION_IF_NULL(tuple);
936 MS_EXCEPTION_IF_NULL(slice);
937 MS_EXCEPTION_IF_NULL(start_index);
938 MS_EXCEPTION_IF_NULL(stop_index);
939 MS_EXCEPTION_IF_NULL(step_value);
940
941 const std::string start_name("Slice start index");
942 const std::string stop_name("Slice stop index");
943 const std::string step_name("Slice step value");
944
945 int64_t tuple_size = SizeToLong(tuple->size());
946 int64_t start_default = 0;
947 int64_t stop_default = tuple_size;
948 int64_t step_default = 1;
949
950 *step_value = CheckSliceMember(slice->step(), step_default, step_name);
951 if (*step_value == 0) {
952 MS_EXCEPTION(ValueError) << "TupleSlice require the step value could not be 0, but got 0.";
953 }
954
955 if (*step_value < 0) {
956 start_default = tuple_size - 1;
957 stop_default = -1;
958 }
959
960 *start_index = CheckSliceMember(slice->start(), start_default, start_name);
961 *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
962
963 if (*start_index < -tuple_size) *start_index = 0;
964 if (*stop_index > tuple_size) *stop_index = tuple_size;
965 if (*start_index > tuple_size || *stop_index < -tuple_size) {
966 *start_index = 0;
967 *stop_index = 0;
968 }
969
970 *start_index = GetPositiveIndex(*start_index, tuple_size);
971 if (!slice->stop()->isa<AbstractNone>()) {
972 *stop_index = GetPositiveIndex(*stop_index, tuple_size);
973 }
974 }
975
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)976 FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
977 // slice a tuple
978 // args: tuple, start index, end index, step
979 const std::string op_name("TupleSlice");
980 constexpr size_t arg_size = 2;
981 abstract::CheckArgsSize(op_name, args_spec_list, arg_size);
982 AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
983 AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
984
985 int64_t start_index;
986 int64_t stop_index;
987 int64_t step_value;
988 GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
989
990 FuncGraphPtr ret = std::make_shared<FuncGraph>();
991 ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
992 AnfNodePtr p_tuple = ret->add_parameter();
993 (void)ret->add_parameter();
994
995 std::vector<AnfNodePtr> elems;
996 elems.push_back(NewValueNode(prim::kPrimMakeTuple));
997 if (step_value > 0) {
998 for (int64_t index = start_index; index < stop_index; index = index + step_value) {
999 elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
1000 }
1001 } else {
1002 for (int64_t index = start_index; index > stop_index; index = index + step_value) {
1003 elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
1004 }
1005 }
1006
1007 ret->set_output(ret->NewCNodeInOrder(elems));
1008 return ret;
1009 }
1010
GenerateFuncGraph(const AbstractBasePtrList & args_spec_list)1011 FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
1012 // select indexed item
1013 // args: tuple of items, index
1014 const std::string op_name = std::string("TupleGetItemTensor");
1015 const size_t inputs_size = 2;
1016 abstract::CheckArgsSize(op_name, args_spec_list, inputs_size);
1017 auto ret_graph = std::make_shared<FuncGraph>();
1018 ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
1019 auto functions = ret_graph->add_parameter();
1020 auto index = ret_graph->add_parameter();
1021
1022 ret_graph->set_output(ret_graph->NewCNodeInOrder({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
1023 return ret_graph;
1024 }
1025
// Python bindings for the tuple meta func graphs defined above. Each registration exposes the C++ class to
// Python under the given name (TupleAdd_, TupleSlice_, TupleGetItemTensor_), with MetaFuncGraph as its base.
// Each class is constructible from a std::string; presumably this is the op name passed to the MetaFuncGraph
// constructor (TODO confirm against composite.h).
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                         (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                           .def(py::init<std::string &>());
                       }));

REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
                         (void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
                           .def(py::init<std::string &>());
                       }));

REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
                           *m, "TupleGetItemTensor_")
                           .def(py::init<std::string &>());
                       }));
1041 } // namespace prim
1042 } // namespace mindspore
1043