• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "frontend/optimizer/clean.h"
20 #include <string>
21 #include <vector>
22 #include <algorithm>
23 #include "debug/trace.h"
24 #include "frontend/operator/composite/composite.h"
25 #include "pipeline/jit/parse/resolve.h"
26 
27 namespace mindspore {
28 /* namespace to support opt */
29 namespace opt {
30 using mindspore::abstract::AbstractAttribute;
31 using mindspore::abstract::AbstractClass;
32 using mindspore::abstract::AbstractDictionary;
33 using mindspore::abstract::AbstractJTagged;
34 using mindspore::abstract::AbstractList;
35 using mindspore::abstract::AbstractRowTensor;
36 using mindspore::abstract::AbstractScalar;
37 using mindspore::abstract::AbstractSparseTensor;
38 using mindspore::abstract::AbstractTuple;
39 using mindspore::abstract::AbstractUndetermined;
40 
CheckInputsSize(size_t actual_size,size_t expect_size,const std::string & op_name)41 inline void CheckInputsSize(size_t actual_size, size_t expect_size, const std::string &op_name) {
42   if (actual_size != expect_size) {
43     MS_LOG(EXCEPTION) << op_name << " should have " << expect_size << " inputs, but got " << actual_size;
44   }
45 }
46 
// Re-expresses a class or dictionary abstract as an AbstractTuple holding only the member
// abstracts, in their stored order. Attribute names and dictionary keys are dropped.
// Returns nullptr for a null input and for any abstract that is neither an AbstractClass nor an
// AbstractDictionary; callers treat nullptr as "nothing to rewrite".
static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
  if (t == nullptr) {
    return nullptr;
  }

  // Keeps the abstract half of every (name, abstract) pair.
  auto values_as_tuple = [](const auto &named_members) {
    AbstractBasePtrList values;
    values.reserve(named_members.size());
    for (const auto &member : named_members) {
      values.push_back(member.second);
    }
    return std::make_shared<AbstractTuple>(values);
  };

  if (t->isa<AbstractClass>()) {
    return values_as_tuple(dyn_cast<AbstractClass>(t)->attributes());
  }
  if (t->isa<AbstractDictionary>()) {
    return values_as_tuple(dyn_cast<AbstractDictionary>(t)->elements());
  }
  return nullptr;
}
71 
// Lowers list and sparse-tensor abstracts into plain AbstractTuple form:
//  - AbstractList         -> tuple with the same elements;
//  - AbstractSparseTensor -> (indices, values, dense_shape);
//  - AbstractRowTensor    -> (indices, values, dense_shape).
// Returns nullptr for a null input or any other abstract, meaning "leave unchanged".
static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
  if (t == nullptr) {
    return nullptr;
  }

  if (t->isa<AbstractList>()) {
    return std::make_shared<AbstractTuple>(dyn_cast<AbstractList>(t)->elements());
  }

  if (t->isa<AbstractSparseTensor>()) {
    auto sparse = dyn_cast<AbstractSparseTensor>(t);
    AbstractBasePtrList parts{sparse->indices(), sparse->values(), sparse->dense_shape()};
    return std::make_shared<AbstractTuple>(parts);
  }

  if (t->isa<AbstractRowTensor>()) {
    auto row = dyn_cast<AbstractRowTensor>(t);
    AbstractBasePtrList parts{row->indices(), row->values(), row->dense_shape()};
    return std::make_shared<AbstractTuple>(parts);
  }

  return nullptr;
}
97 
ConvertGetAttrToTupleGetItem(const CNodePtr & node)98 AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
99   MS_EXCEPTION_IF_NULL(node);
100   MS_EXCEPTION_IF_NULL(node->func_graph());
101 
102   const auto &inputs = node->inputs();
103   // Inputs should be [getattr, data, attribute]
104   const size_t expect_inputs_size = 3;
105   CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
106 
107   constexpr size_t data_index = 1;
108   constexpr size_t attribute_index = 2;
109   AnfNodePtr data = inputs[data_index];
110   AnfNodePtr cons = inputs[attribute_index];
111   MS_EXCEPTION_IF_NULL(data);
112   MS_EXCEPTION_IF_NULL(cons);
113 
114   auto dt = data->abstract();
115   if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
116     return nullptr;
117   }
118 
119   if (!dt->isa<AbstractClass>()) {
120     MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
121   }
122 
123   auto cons_is_str = IsValueNode<StringImm>(cons);
124   auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
125 
126   auto ct = dyn_cast<AbstractClass>(dt);
127   const auto &cmap = ct->attributes();
128   int64_t count = 0;
129   for (auto &item : cmap) {
130     if (cons_is_str && item.first == cons_str) {
131       break;
132     }
133     count++;
134   }
135 
136   auto idx_c = NewValueNode(count);
137   AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(count));
138   idx_c->set_abstract(aptr);
139 
140   return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
141 }
142 
ConvertDictGetItemToTupleGetItem(const CNodePtr & node)143 AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
144   MS_EXCEPTION_IF_NULL(node);
145   MS_EXCEPTION_IF_NULL(node->func_graph());
146 
147   // Inputs should be [dict_getitem, dict, item]
148   const auto &inputs = node->inputs();
149   const size_t expect_inputs_size = 3;
150   CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
151 
152   constexpr size_t data_index = 1;
153   constexpr size_t cons_index = 2;
154   AnfNodePtr data = inputs[data_index];
155   AnfNodePtr cons = inputs[cons_index];
156   MS_EXCEPTION_IF_NULL(data);
157   MS_EXCEPTION_IF_NULL(cons);
158 
159   auto dt = data->abstract();
160   MS_EXCEPTION_IF_NULL(dt);
161   if (!dt->isa<abstract::AbstractDictionary>()) {
162     MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name();
163   }
164   auto cons_is_str = IsValueNode<StringImm>(cons);
165   auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
166 
167   auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
168   const auto &cmap = ct->elements();
169   int64_t count = 0;
170   for (auto &item : cmap) {
171     if (cons_is_str && item.first == cons_str) {
172       break;
173     }
174     count++;
175   }
176 
177   auto idx_c = NewValueNode(count);
178   AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(count));
179   idx_c->set_abstract(aptr);
180   return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
181 }
182 
ConvertDictSetItemToTupleSetItem(const CNodePtr & node)183 AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
184   MS_EXCEPTION_IF_NULL(node);
185   MS_EXCEPTION_IF_NULL(node->func_graph());
186 
187   // Inputs should be [dict_setitem, dict, item, value]
188   const auto &inputs = node->inputs();
189   const size_t expect_inputs_size = 4;
190   CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
191 
192   const size_t data_index = 1;
193   const size_t cons_index = 2;
194   const size_t item_value_index = 3;
195   AnfNodePtr data = inputs[data_index];
196   AnfNodePtr cons = inputs[cons_index];
197   AnfNodePtr item_value = inputs[item_value_index];
198   MS_EXCEPTION_IF_NULL(data);
199   MS_EXCEPTION_IF_NULL(cons);
200 
201   auto dt = data->abstract();
202   MS_EXCEPTION_IF_NULL(dt);
203   if (!dt->isa<abstract::AbstractDictionary>()) {
204     MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name();
205   }
206   auto cons_is_str = IsValueNode<StringImm>(cons);
207   auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
208 
209   auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
210   const auto &cmap = ct->elements();
211   int64_t count = 0;
212   for (auto &item : cmap) {
213     if (cons_is_str && item.first == cons_str) {
214       break;
215     }
216     count++;
217   }
218   if (LongToSize(count) >= cmap.size()) {
219     // for dictionary set, if the key does not exist, we should create a new item
220     auto tuple_add_op = std::make_shared<prim::TupleAdd>("tuple_add");
221     auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value});
222     return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item});
223   }
224   auto idx_c = NewValueNode(count);
225   AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(count));
226   idx_c->set_abstract(aptr);
227   return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value});
228 }
229 
ConvertMakeRecordToMakeTuple(const CNodePtr & node)230 AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
231   MS_EXCEPTION_IF_NULL(node);
232   MS_EXCEPTION_IF_NULL(node->func_graph());
233 
234   std::vector<AnfNodePtr> inputs;
235   inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
236   // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr;
237   (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end());
238   return node->func_graph()->NewCNode(inputs);
239 }
240 
ErasePartialNode(const CNodePtr & node)241 AnfNodePtr ErasePartialNode(const CNodePtr &node) {
242   MS_EXCEPTION_IF_NULL(node);
243   MS_EXCEPTION_IF_NULL(node->func_graph());
244 
245   const auto &inputs = node->inputs();
246   // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
247   const size_t min_inputs_size = 2;
248   if (inputs.size() < min_inputs_size) {
249     MS_LOG(EXCEPTION) << "Partial should have at least 2 inputs, but got " << inputs.size();
250   }
251 
252   std::vector<AnfNodePtr> args(inputs.begin() + 2, inputs.end());
253   auto oper = inputs[1];
254   if (IsPrimitive(oper, prim::kPrimMakeRecord)) {
255     if (args.size() == 1) {
256       return NewValueNode(prim::kPrimMakeTuple);
257     }
258 
259     if (args.size() > 1) {
260       std::vector<AnfNodePtr> new_inputs;
261       new_inputs.emplace_back(NewValueNode(prim::kPrimPartial));
262       new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
263       (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end());
264 
265       MS_EXCEPTION_IF_NULL(node->func_graph());
266       return node->func_graph()->NewCNode(new_inputs);
267     }
268   }
269   return nullptr;
270 }
271 
ConvertMakeListToMakeTuple(const CNodePtr & node)272 AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
273   MS_EXCEPTION_IF_NULL(node);
274   MS_EXCEPTION_IF_NULL(node->func_graph());
275 
276   std::vector<AnfNodePtr> inputs;
277   inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
278   // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items;
279   (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
280   return node->func_graph()->NewCNode(inputs);
281 }
282 
ConvertListGetItemToTupleGetItem(const CNodePtr & node)283 AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
284   MS_EXCEPTION_IF_NULL(node);
285   MS_EXCEPTION_IF_NULL(node->func_graph());
286 
287   const auto &inputs = node->inputs();
288   // Inputs should be [list_getitem, list, item]
289   constexpr size_t expect_input_size = 3;
290   CheckInputsSize(inputs.size(), expect_input_size, GetCNodeFuncName(node));
291   constexpr size_t real_input_index = 1;
292   constexpr size_t index_input_index = 2;
293   AnfNodePtr data = inputs[real_input_index];
294   AnfNodePtr cons = inputs[index_input_index];
295   MS_EXCEPTION_IF_NULL(data);
296   MS_EXCEPTION_IF_NULL(cons);
297 
298   auto cons_node = cons->cast<ValueNodePtr>();
299   return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
300 }
301 
ConvertListSetItemToTupleSetItem(const CNodePtr & node)302 AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
303   MS_EXCEPTION_IF_NULL(node);
304   MS_EXCEPTION_IF_NULL(node->func_graph());
305 
306   const auto &inputs = node->inputs();
307   // Inputs should be [list_setitem, list, index, item]
308   const size_t expect_inputs_size = 4;
309   CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
310 
311   const size_t data_index = 1;
312   const size_t cons_index = 2;
313   const size_t value_index = 3;
314   AnfNodePtr data = inputs[data_index];
315   AnfNodePtr cons = inputs[cons_index];
316   AnfNodePtr value = inputs[value_index];
317 
318   return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
319 }
320 
EraseMakeDictNode(const CNodePtr & node)321 AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
322   MS_EXCEPTION_IF_NULL(node);
323   const auto &inputs = node->inputs();
324   const size_t expect_inputs_size = 3;
325   CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
326   return inputs[2];
327 }
328 
EraseDictGetValues(const CNodePtr & node)329 AnfNodePtr EraseDictGetValues(const CNodePtr &node) {
330   MS_EXCEPTION_IF_NULL(node);
331   const auto &inputs = node->inputs();
332   const size_t expect_inputs_size = 2;
333   CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
334   return inputs[1];
335 }
336 
EraseMakeKeywordArgNode(const CNodePtr & node)337 AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
338   MS_EXCEPTION_IF_NULL(node);
339   const auto &inputs = node->inputs();
340   // Inputs should be [make_keyword_arg, key, value]
341   constexpr size_t expect_input_size = 3;
342   constexpr size_t value_inputs_index = 2;
343   CheckInputsSize(inputs.size(), expect_input_size, GetCNodeFuncName(node));
344   return inputs[value_inputs_index];
345 }
346 
EraseExtractKeywordArg(const CNodePtr & node)347 AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
348   MS_EXCEPTION_IF_NULL(node);
349   const auto &inputs = node->inputs();
350   // Inputs should be [extract_keyword_arg, arg, key]
351   const size_t expect_inputs_size = 3;
352   CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
353   constexpr size_t key_index = 2;
354   return inputs[key_index];
355 }
356 
ConvertValueListToValueTuple(const ValueListPtr & value_list,int64_t depth)357 ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) {
358   const int64_t DEPTH_MAX = 5;
359   if (depth > DEPTH_MAX) {
360     MS_LOG(EXCEPTION) << "List nesting is not allowed more than 6 levels.";
361   }
362   std::vector<ValuePtr> elements;
363   for (const auto &it : value_list->value()) {
364     ValuePtr value = nullptr;
365     if (it->isa<ValueList>()) {
366       value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
367     } else {
368       value = it;
369     }
370     elements.push_back(value);
371   }
372   return std::make_shared<ValueTuple>(elements);
373 }
374 
ConvertValueListNodeToValueTupleNode(const ValueNodePtr & node)375 AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
376   MS_EXCEPTION_IF_NULL(node);
377   ValuePtr value = node->value();
378   auto value_list = value->cast<ValueListPtr>();
379   MS_EXCEPTION_IF_NULL(value_list);
380   int64_t depth = 0;
381   return std::make_shared<ValueNode>(ConvertValueListToValueTuple(value_list, depth));
382 }
383 
384 // Convert class to Tuple
385 // Convert getattr to getitem
386 // Convert make_record to make_tuple
SimplifyDataStructures(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)387 bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
388   MS_EXCEPTION_IF_NULL(manager);
389   manager->AddFuncGraph(root);
390 
391   bool changed = false;
392 
393   // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
394   AnfNodeSet all_node = manager->all_nodes();
395   for (auto &node : all_node) {
396     MS_EXCEPTION_IF_NULL(node);
397     auto cnode = node->cast<CNodePtr>();
398     AnfNodePtr new_node = nullptr;
399     if (IsValueNode<parse::ClassObject>(node)) {
400       new_node = NewValueNode(prim::kPrimMakeTuple);
401     } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) {
402       new_node = ConvertGetAttrToTupleGetItem(cnode);
403     } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) {
404       new_node = ConvertMakeRecordToMakeTuple(cnode);
405     } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
406       new_node = ErasePartialNode(cnode);
407     } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
408       new_node = ConvertDictGetItemToTupleGetItem(cnode);
409     } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
410       new_node = ConvertDictSetItemToTupleSetItem(cnode);
411     } else if (IsPrimitiveCNode(node, prim::kPrimDictGetValues)) {
412       new_node = EraseDictGetValues(cnode);
413     } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
414       new_node = EraseMakeDictNode(cnode);
415     } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
416       new_node = EraseMakeKeywordArgNode(cnode);
417     } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
418       new_node = EraseExtractKeywordArg(cnode);
419     }
420 
421     if (new_node != nullptr) {
422       new_node->set_abstract(node->abstract());
423       MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
424       (void)manager->Replace(node, new_node);
425       changed = true;
426     }
427   }
428 
429   for (auto &node : manager->all_nodes()) {
430     auto ret = Reabs(node->abstract());
431     if (ret) {
432       MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
433                     << ret->ToString();
434       node->set_abstract(ret);
435       if (ret->cast<abstract::AbstractTuplePtr>()->size() > 0) {
436         changed = true;
437       }
438     }
439   }
440   return changed;
441 }
442 
ConvertMakeSparseToMakeTuple(const CNodePtr & node)443 AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) {
444   MS_EXCEPTION_IF_NULL(node);
445   MS_EXCEPTION_IF_NULL(node->func_graph());
446 
447   std::vector<AnfNodePtr> inputs;
448   inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
449   // Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
450   (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
451   return node->func_graph()->NewCNode(inputs);
452 }
453 
ConvertSparseGetAttrToTupleGetItem(const CNodePtr & node,const int64_t & index)454 AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node, const int64_t &index) {
455   MS_EXCEPTION_IF_NULL(node);
456   MS_EXCEPTION_IF_NULL(node->func_graph());
457 
458   const auto &inputs = node->inputs();
459   // Inputs should be [sparse_getattr, sparse]
460   constexpr size_t expect_input_index = 2;
461   CheckInputsSize(inputs.size(), expect_input_index, GetCNodeFuncName(node));
462   constexpr size_t sparse_index = 1;
463   AnfNodePtr sparse = inputs[sparse_index];
464   MS_EXCEPTION_IF_NULL(sparse);
465   auto cons_node = NewValueNode(index);
466   AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(index));
467   cons_node->set_abstract(aptr);
468 
469   return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node});
470 }
471 
CleanAfterOptA(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)472 bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
473   MS_EXCEPTION_IF_NULL(manager);
474   manager->AddFuncGraph(root);
475 
476   bool changed = false;
477 
478   // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
479   auto all_node = manager->all_nodes();
480   for (auto &node : all_node) {
481     MS_EXCEPTION_IF_NULL(node);
482     auto cnode = node->cast<CNodePtr>();
483     AnfNodePtr new_node = nullptr;
484     if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
485       new_node = ConvertMakeListToMakeTuple(cnode);
486     } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
487       new_node = ConvertListGetItemToTupleGetItem(cnode);
488     } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) {
489       new_node = ConvertListSetItemToTupleSetItem(cnode);
490     } else if (IsValueNode<ValueList>(node)) {
491       new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
492     } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
493                IsPrimitiveCNode(node, prim::kPrimMakeRowTensor)) {
494       new_node = ConvertMakeSparseToMakeTuple(cnode);
495     } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
496                IsPrimitiveCNode(node, prim::kPrimRowTensorGetIndices)) {
497       constexpr int64_t indices_index = 0;
498       new_node = ConvertSparseGetAttrToTupleGetItem(cnode, indices_index);
499     } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
500                IsPrimitiveCNode(node, prim::kPrimRowTensorGetValues)) {
501       constexpr int64_t value_index = 1;
502       new_node = ConvertSparseGetAttrToTupleGetItem(cnode, value_index);
503     } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
504                IsPrimitiveCNode(node, prim::kPrimRowTensorGetDenseShape)) {
505       constexpr int64_t shape_index = 2;
506       new_node = ConvertSparseGetAttrToTupleGetItem(cnode, shape_index);
507     }
508 
509     if (new_node != nullptr) {
510       new_node->set_abstract(node->abstract());
511       MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
512       (void)manager->Replace(node, new_node);
513       changed = true;
514     }
515   }
516 
517   for (auto &node : manager->all_nodes()) {
518     auto ret = AdaptAbs(node->abstract());
519     if (ret) {
520       MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
521                     << ret->ToString();
522       node->set_abstract(ret);
523       changed = true;
524     }
525   }
526   return changed;
527 }
528 
529 // expand tuples in graph parameters
ExpandTuplesP(const FuncGraphManagerPtr & mng,const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & params)530 static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
531                                              const std::vector<AnfNodePtr> &params) {
532   MS_EXCEPTION_IF_NULL(mng);
533   MS_EXCEPTION_IF_NULL(func_graph);
534 
535   std::vector<AnfNodePtr> new_params;
536   for (const auto &param : params) {
537     MS_EXCEPTION_IF_NULL(param);
538     auto param_abs = param->abstract();
539     MS_EXCEPTION_IF_NULL(param_abs);
540 
541     if (param_abs->isa<AbstractJTagged>()) {
542       MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info());
543     }
544 
545     if (!param_abs->isa<AbstractTuple>()) {
546       new_params.emplace_back(param);
547       continue;
548     }
549 
550     std::vector<AnfNodePtr> new_param;
551     std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
552     auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
553     for (auto &elem : abs_tuple->elements()) {
554       auto np = std::make_shared<Parameter>(func_graph);
555       np->set_abstract(elem);
556       new_param.emplace_back(np);
557     }
558     (void)inputs.insert(inputs.end(), new_param.begin(), new_param.end());
559     auto new_tuple = func_graph->NewCNode(inputs);
560     (void)mng->Replace(param, new_tuple);
561 
562     auto expand_param = ExpandTuplesP(mng, func_graph, new_param);
563     (void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end());
564   }
565   return new_params;
566 }
567 
568 // expand tuples in graph applies
ExpandTuplesC(const FuncGraphPtr & graph,const std::vector<AnfNodePtr> & inputs)569 static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
570   MS_EXCEPTION_IF_NULL(graph);
571 
572   std::vector<AnfNodePtr> new_inputs;
573   for (const auto &input : inputs) {
574     MS_EXCEPTION_IF_NULL(input);
575 
576     auto input_abs = input->abstract();
577     MS_EXCEPTION_IF_NULL(input_abs);
578 
579     if (input_abs->isa<AbstractJTagged>()) {
580       auto abstract_tag = dyn_cast<AbstractJTagged>(input_abs);
581       if (abstract_tag->element()->isa<AbstractTuple>()) {
582         MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info());
583       }
584     }
585 
586     if (!input_abs->isa<AbstractTuple>()) {
587       new_inputs.emplace_back(input);
588       continue;
589     }
590 
591     int64_t idx = 0;
592     std::vector<AnfNodePtr> new_input;
593     auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
594     for (auto &elem : abs_tuple->elements()) {
595       auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
596       AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(idx));
597       constexpr size_t scalar_index = 2;
598       c_node->input(scalar_index)->set_abstract(aptr);
599       c_node->set_abstract(elem);
600       new_input.emplace_back(c_node);
601       idx++;
602     }
603 
604     auto expand_tuple = ExpandTuplesC(graph, new_input);
605     (void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end());
606   }
607 
608   return new_inputs;
609 }
610 
611 // remove most uses of tuples from the graph parameters & apply inputs
612 // tuples that are returned will be kept
613 // tuples in CNode's inputs: AbstractTuple (a, b ,c) -->
614 //         CNode("tuple_getitem", (a,b,c), 0)
615 //         CNode("tuple_getitem", (a,b,c), 1)
616 //         CNode("tuple_getitem", (a,b,c), 2)
617 // tuples in Graph's parameters: AbstractTuple (a, b, c) -->
618 //         CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
619 // cppcheck-suppress unusedFunction
EraseTuple(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)620 void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
621   MS_EXCEPTION_IF_NULL(manager);
622   manager->AddFuncGraph(root);
623 
624   // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
625   AnfNodeSet all_node = manager->all_nodes();
626   for (auto &node : all_node) {
627     auto cnode = node->cast<CNodePtr>();
628     if (cnode == nullptr) {
629       continue;
630     }
631 
632     const auto &inputs = cnode->inputs();
633 
634     // Bypass the first input in inputs as it's fn.
635     if (!IsValueNode<Primitive>(inputs[0])) {
636       std::vector<AnfNodePtr> expand_inputs;
637       (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end());
638 
639       auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
640       if (new_inputs != expand_inputs) {
641         std::vector<AnfNodePtr> cnode_inputs{inputs[0]};
642         (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
643 
644         MS_EXCEPTION_IF_NULL(node->func_graph());
645         auto new_node = node->func_graph()->NewCNode(cnode_inputs);
646         new_node->set_abstract(node->abstract());
647 
648         (void)manager->Replace(node, new_node);
649       }
650       // Bypass the first 2 inputs in inputs as it's [partial, fn].
651     } else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode<Primitive>(inputs[1])) {
652       std::vector<AnfNodePtr> expand_inputs;
653       (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end());
654 
655       auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
656       if (new_inputs != expand_inputs) {
657         std::vector<AnfNodePtr> cnode_inputs{inputs[0], inputs[1]};
658         (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
659 
660         MS_EXCEPTION_IF_NULL(cnode->func_graph());
661         auto new_node = cnode->func_graph()->NewCNode(cnode_inputs);
662         new_node->set_abstract(cnode->abstract());
663 
664         (void)manager->Replace(node, new_node);
665       }
666     }
667   }
668 
669   FuncGraphSet all_graph = manager->func_graphs();
670   for (auto &func_graph : all_graph) {
671     MS_EXCEPTION_IF_NULL(func_graph);
672     auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
673     manager->SetParameters(func_graph, expand_p);
674   }
675 }
676 }  // namespace opt
677 }  // namespace mindspore
678