1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "frontend/optimizer/clean.h"
20 #include <string>
21 #include <vector>
22 #include <algorithm>
23 #include "debug/trace.h"
24 #include "frontend/operator/composite/composite.h"
25 #include "pipeline/jit/parse/resolve.h"
26
27 namespace mindspore {
28 /* namespace to support opt */
29 namespace opt {
30 using mindspore::abstract::AbstractAttribute;
31 using mindspore::abstract::AbstractClass;
32 using mindspore::abstract::AbstractDictionary;
33 using mindspore::abstract::AbstractJTagged;
34 using mindspore::abstract::AbstractList;
35 using mindspore::abstract::AbstractRowTensor;
36 using mindspore::abstract::AbstractScalar;
37 using mindspore::abstract::AbstractSparseTensor;
38 using mindspore::abstract::AbstractTuple;
39 using mindspore::abstract::AbstractUndetermined;
40
CheckInputsSize(size_t actual_size,size_t expect_size,const std::string & op_name)41 inline void CheckInputsSize(size_t actual_size, size_t expect_size, const std::string &op_name) {
42 if (actual_size != expect_size) {
43 MS_LOG(EXCEPTION) << op_name << " should have " << expect_size << " inputs, but got " << actual_size;
44 }
45 }
46
// Maps class and dictionary abstracts to an AbstractTuple holding only the value half of each (name, abstract)
// pair, in the original order; attribute names / keys are dropped.
// Returns nullptr for nullptr input and for any other kind of abstract.
static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
  if (t == nullptr) {
    return nullptr;
  }

  // Builds the tuple straight from the (name, abstract) list without first copying that list.
  auto values_to_tuple = [](const auto &items) {
    AbstractBasePtrList baselist;
    baselist.reserve(items.size());
    (void)std::transform(items.begin(), items.end(), std::back_inserter(baselist),
                         [](const AbstractAttribute &item) { return item.second; });
    return std::make_shared<AbstractTuple>(baselist);
  };

  if (auto abs_class = dyn_cast<AbstractClass>(t)) {
    return values_to_tuple(abs_class->attributes());
  }
  if (auto abs_dict = dyn_cast<AbstractDictionary>(t)) {
    return values_to_tuple(abs_dict->elements());
  }
  return nullptr;
}
71
// Converts an AbstractList into an AbstractTuple with the same elements. Elements that are themselves
// AbstractLists are converted recursively; this mirrors ConvertValueListToValueTuple, which also recurses only
// through nested lists. Any other abstract, including nullptr, is returned unchanged.
static AbstractBasePtr ListAbsToTupleAbs(const AbstractBasePtr &abs) {
  if (abs == nullptr || !abs->isa<AbstractList>()) {
    return abs;
  }
  auto abs_list = dyn_cast<AbstractList>(abs);
  const auto &list_elements = abs_list->elements();
  AbstractBasePtrList tuple_elements;
  tuple_elements.reserve(list_elements.size());
  for (const auto &elem : list_elements) {
    tuple_elements.push_back(ListAbsToTupleAbs(elem));
  }
  return std::make_shared<AbstractTuple>(tuple_elements);
}

// Maps list / sparse-tensor / row-tensor abstracts to the AbstractTuple form used once CleanAfterOptA has lowered
// those values:
//   - a list becomes a tuple of its elements, with nested lists converted too;
//   - a sparse or row tensor becomes (indices, values, dense_shape).
// Returns nullptr for nullptr input and for any other kind of abstract.
static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
  if (t == nullptr) {
    return nullptr;
  }

  if (t->isa<AbstractList>()) {
    // Convert nested lists as well. A shallow conversion would leave AbstractList elements inside the tuple, while
    // ConvertValueListToValueTuple turns the matching nested ValueLists into ValueTuples.
    return ListAbsToTupleAbs(t);
  }

  if (t->isa<AbstractSparseTensor>()) {
    auto abs_sparse = dyn_cast<AbstractSparseTensor>(t);
    std::vector<AbstractBasePtr> abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()};
    return std::make_shared<AbstractTuple>(abstract_list);
  }

  if (t->isa<AbstractRowTensor>()) {
    auto abs_row_tensor = dyn_cast<AbstractRowTensor>(t);
    std::vector<AbstractBasePtr> abstract_list{abs_row_tensor->indices(), abs_row_tensor->values(),
                                               abs_row_tensor->dense_shape()};
    return std::make_shared<AbstractTuple>(abstract_list);
  }

  return nullptr;
}
97
ConvertGetAttrToTupleGetItem(const CNodePtr & node)98 AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
99 MS_EXCEPTION_IF_NULL(node);
100 MS_EXCEPTION_IF_NULL(node->func_graph());
101
102 const auto &inputs = node->inputs();
103 // Inputs should be [getattr, data, attribute]
104 const size_t expect_inputs_size = 3;
105 CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
106
107 constexpr size_t data_index = 1;
108 constexpr size_t attribute_index = 2;
109 AnfNodePtr data = inputs[data_index];
110 AnfNodePtr cons = inputs[attribute_index];
111 MS_EXCEPTION_IF_NULL(data);
112 MS_EXCEPTION_IF_NULL(cons);
113
114 auto dt = data->abstract();
115 if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
116 return nullptr;
117 }
118
119 if (!dt->isa<AbstractClass>()) {
120 MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
121 }
122
123 auto cons_is_str = IsValueNode<StringImm>(cons);
124 auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
125
126 auto ct = dyn_cast<AbstractClass>(dt);
127 const auto &cmap = ct->attributes();
128 int64_t count = 0;
129 for (auto &item : cmap) {
130 if (cons_is_str && item.first == cons_str) {
131 break;
132 }
133 count++;
134 }
135
136 auto idx_c = NewValueNode(count);
137 AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(count));
138 idx_c->set_abstract(aptr);
139
140 return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
141 }
142
ConvertDictGetItemToTupleGetItem(const CNodePtr & node)143 AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
144 MS_EXCEPTION_IF_NULL(node);
145 MS_EXCEPTION_IF_NULL(node->func_graph());
146
147 // Inputs should be [dict_getitem, dict, item]
148 const auto &inputs = node->inputs();
149 const size_t expect_inputs_size = 3;
150 CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
151
152 constexpr size_t data_index = 1;
153 constexpr size_t cons_index = 2;
154 AnfNodePtr data = inputs[data_index];
155 AnfNodePtr cons = inputs[cons_index];
156 MS_EXCEPTION_IF_NULL(data);
157 MS_EXCEPTION_IF_NULL(cons);
158
159 auto dt = data->abstract();
160 MS_EXCEPTION_IF_NULL(dt);
161 if (!dt->isa<abstract::AbstractDictionary>()) {
162 MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name();
163 }
164 auto cons_is_str = IsValueNode<StringImm>(cons);
165 auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
166
167 auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
168 const auto &cmap = ct->elements();
169 int64_t count = 0;
170 for (auto &item : cmap) {
171 if (cons_is_str && item.first == cons_str) {
172 break;
173 }
174 count++;
175 }
176
177 auto idx_c = NewValueNode(count);
178 AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(count));
179 idx_c->set_abstract(aptr);
180 return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
181 }
182
ConvertDictSetItemToTupleSetItem(const CNodePtr & node)183 AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
184 MS_EXCEPTION_IF_NULL(node);
185 MS_EXCEPTION_IF_NULL(node->func_graph());
186
187 // Inputs should be [dict_setitem, dict, item, value]
188 const auto &inputs = node->inputs();
189 const size_t expect_inputs_size = 4;
190 CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
191
192 const size_t data_index = 1;
193 const size_t cons_index = 2;
194 const size_t item_value_index = 3;
195 AnfNodePtr data = inputs[data_index];
196 AnfNodePtr cons = inputs[cons_index];
197 AnfNodePtr item_value = inputs[item_value_index];
198 MS_EXCEPTION_IF_NULL(data);
199 MS_EXCEPTION_IF_NULL(cons);
200
201 auto dt = data->abstract();
202 MS_EXCEPTION_IF_NULL(dt);
203 if (!dt->isa<abstract::AbstractDictionary>()) {
204 MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name();
205 }
206 auto cons_is_str = IsValueNode<StringImm>(cons);
207 auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
208
209 auto ct = dyn_cast<abstract::AbstractDictionary>(dt);
210 const auto &cmap = ct->elements();
211 int64_t count = 0;
212 for (auto &item : cmap) {
213 if (cons_is_str && item.first == cons_str) {
214 break;
215 }
216 count++;
217 }
218 if (LongToSize(count) >= cmap.size()) {
219 // for dictionary set, if the key does not exist, we should create a new item
220 auto tuple_add_op = std::make_shared<prim::TupleAdd>("tuple_add");
221 auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value});
222 return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item});
223 }
224 auto idx_c = NewValueNode(count);
225 AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(count));
226 idx_c->set_abstract(aptr);
227 return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value});
228 }
229
ConvertMakeRecordToMakeTuple(const CNodePtr & node)230 AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
231 MS_EXCEPTION_IF_NULL(node);
232 MS_EXCEPTION_IF_NULL(node->func_graph());
233
234 std::vector<AnfNodePtr> inputs;
235 inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
236 // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr;
237 (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end());
238 return node->func_graph()->NewCNode(inputs);
239 }
240
ErasePartialNode(const CNodePtr & node)241 AnfNodePtr ErasePartialNode(const CNodePtr &node) {
242 MS_EXCEPTION_IF_NULL(node);
243 MS_EXCEPTION_IF_NULL(node->func_graph());
244
245 const auto &inputs = node->inputs();
246 // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
247 const size_t min_inputs_size = 2;
248 if (inputs.size() < min_inputs_size) {
249 MS_LOG(EXCEPTION) << "Partial should have at least 2 inputs, but got " << inputs.size();
250 }
251
252 std::vector<AnfNodePtr> args(inputs.begin() + 2, inputs.end());
253 auto oper = inputs[1];
254 if (IsPrimitive(oper, prim::kPrimMakeRecord)) {
255 if (args.size() == 1) {
256 return NewValueNode(prim::kPrimMakeTuple);
257 }
258
259 if (args.size() > 1) {
260 std::vector<AnfNodePtr> new_inputs;
261 new_inputs.emplace_back(NewValueNode(prim::kPrimPartial));
262 new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
263 (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end());
264
265 MS_EXCEPTION_IF_NULL(node->func_graph());
266 return node->func_graph()->NewCNode(new_inputs);
267 }
268 }
269 return nullptr;
270 }
271
ConvertMakeListToMakeTuple(const CNodePtr & node)272 AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
273 MS_EXCEPTION_IF_NULL(node);
274 MS_EXCEPTION_IF_NULL(node->func_graph());
275
276 std::vector<AnfNodePtr> inputs;
277 inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
278 // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items;
279 (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
280 return node->func_graph()->NewCNode(inputs);
281 }
282
ConvertListGetItemToTupleGetItem(const CNodePtr & node)283 AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
284 MS_EXCEPTION_IF_NULL(node);
285 MS_EXCEPTION_IF_NULL(node->func_graph());
286
287 const auto &inputs = node->inputs();
288 // Inputs should be [list_getitem, list, item]
289 constexpr size_t expect_input_size = 3;
290 CheckInputsSize(inputs.size(), expect_input_size, GetCNodeFuncName(node));
291 constexpr size_t real_input_index = 1;
292 constexpr size_t index_input_index = 2;
293 AnfNodePtr data = inputs[real_input_index];
294 AnfNodePtr cons = inputs[index_input_index];
295 MS_EXCEPTION_IF_NULL(data);
296 MS_EXCEPTION_IF_NULL(cons);
297
298 auto cons_node = cons->cast<ValueNodePtr>();
299 return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
300 }
301
ConvertListSetItemToTupleSetItem(const CNodePtr & node)302 AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
303 MS_EXCEPTION_IF_NULL(node);
304 MS_EXCEPTION_IF_NULL(node->func_graph());
305
306 const auto &inputs = node->inputs();
307 // Inputs should be [list_setitem, list, index, item]
308 const size_t expect_inputs_size = 4;
309 CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
310
311 const size_t data_index = 1;
312 const size_t cons_index = 2;
313 const size_t value_index = 3;
314 AnfNodePtr data = inputs[data_index];
315 AnfNodePtr cons = inputs[cons_index];
316 AnfNodePtr value = inputs[value_index];
317
318 return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
319 }
320
EraseMakeDictNode(const CNodePtr & node)321 AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
322 MS_EXCEPTION_IF_NULL(node);
323 const auto &inputs = node->inputs();
324 const size_t expect_inputs_size = 3;
325 CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
326 return inputs[2];
327 }
328
EraseDictGetValues(const CNodePtr & node)329 AnfNodePtr EraseDictGetValues(const CNodePtr &node) {
330 MS_EXCEPTION_IF_NULL(node);
331 const auto &inputs = node->inputs();
332 const size_t expect_inputs_size = 2;
333 CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
334 return inputs[1];
335 }
336
EraseMakeKeywordArgNode(const CNodePtr & node)337 AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
338 MS_EXCEPTION_IF_NULL(node);
339 const auto &inputs = node->inputs();
340 // Inputs should be [make_keyword_arg, key, value]
341 constexpr size_t expect_input_size = 3;
342 constexpr size_t value_inputs_index = 2;
343 CheckInputsSize(inputs.size(), expect_input_size, GetCNodeFuncName(node));
344 return inputs[value_inputs_index];
345 }
346
EraseExtractKeywordArg(const CNodePtr & node)347 AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) {
348 MS_EXCEPTION_IF_NULL(node);
349 const auto &inputs = node->inputs();
350 // Inputs should be [extract_keyword_arg, arg, key]
351 const size_t expect_inputs_size = 3;
352 CheckInputsSize(inputs.size(), expect_inputs_size, GetCNodeFuncName(node));
353 constexpr size_t key_index = 2;
354 return inputs[key_index];
355 }
356
ConvertValueListToValueTuple(const ValueListPtr & value_list,int64_t depth)357 ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int64_t depth) {
358 const int64_t DEPTH_MAX = 5;
359 if (depth > DEPTH_MAX) {
360 MS_LOG(EXCEPTION) << "List nesting is not allowed more than 6 levels.";
361 }
362 std::vector<ValuePtr> elements;
363 for (const auto &it : value_list->value()) {
364 ValuePtr value = nullptr;
365 if (it->isa<ValueList>()) {
366 value = ConvertValueListToValueTuple(it->cast<ValueListPtr>(), depth + 1);
367 } else {
368 value = it;
369 }
370 elements.push_back(value);
371 }
372 return std::make_shared<ValueTuple>(elements);
373 }
374
ConvertValueListNodeToValueTupleNode(const ValueNodePtr & node)375 AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
376 MS_EXCEPTION_IF_NULL(node);
377 ValuePtr value = node->value();
378 auto value_list = value->cast<ValueListPtr>();
379 MS_EXCEPTION_IF_NULL(value_list);
380 int64_t depth = 0;
381 return std::make_shared<ValueNode>(ConvertValueListToValueTuple(value_list, depth));
382 }
383
384 // Convert class to Tuple
385 // Convert getattr to getitem
386 // Convert make_record to make_tuple
SimplifyDataStructures(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)387 bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
388 MS_EXCEPTION_IF_NULL(manager);
389 manager->AddFuncGraph(root);
390
391 bool changed = false;
392
393 // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
394 AnfNodeSet all_node = manager->all_nodes();
395 for (auto &node : all_node) {
396 MS_EXCEPTION_IF_NULL(node);
397 auto cnode = node->cast<CNodePtr>();
398 AnfNodePtr new_node = nullptr;
399 if (IsValueNode<parse::ClassObject>(node)) {
400 new_node = NewValueNode(prim::kPrimMakeTuple);
401 } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) {
402 new_node = ConvertGetAttrToTupleGetItem(cnode);
403 } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) {
404 new_node = ConvertMakeRecordToMakeTuple(cnode);
405 } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
406 new_node = ErasePartialNode(cnode);
407 } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
408 new_node = ConvertDictGetItemToTupleGetItem(cnode);
409 } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
410 new_node = ConvertDictSetItemToTupleSetItem(cnode);
411 } else if (IsPrimitiveCNode(node, prim::kPrimDictGetValues)) {
412 new_node = EraseDictGetValues(cnode);
413 } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
414 new_node = EraseMakeDictNode(cnode);
415 } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
416 new_node = EraseMakeKeywordArgNode(cnode);
417 } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
418 new_node = EraseExtractKeywordArg(cnode);
419 }
420
421 if (new_node != nullptr) {
422 new_node->set_abstract(node->abstract());
423 MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
424 (void)manager->Replace(node, new_node);
425 changed = true;
426 }
427 }
428
429 for (auto &node : manager->all_nodes()) {
430 auto ret = Reabs(node->abstract());
431 if (ret) {
432 MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
433 << ret->ToString();
434 node->set_abstract(ret);
435 if (ret->cast<abstract::AbstractTuplePtr>()->size() > 0) {
436 changed = true;
437 }
438 }
439 }
440 return changed;
441 }
442
ConvertMakeSparseToMakeTuple(const CNodePtr & node)443 AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) {
444 MS_EXCEPTION_IF_NULL(node);
445 MS_EXCEPTION_IF_NULL(node->func_graph());
446
447 std::vector<AnfNodePtr> inputs;
448 inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
449 // Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
450 (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
451 return node->func_graph()->NewCNode(inputs);
452 }
453
ConvertSparseGetAttrToTupleGetItem(const CNodePtr & node,const int64_t & index)454 AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node, const int64_t &index) {
455 MS_EXCEPTION_IF_NULL(node);
456 MS_EXCEPTION_IF_NULL(node->func_graph());
457
458 const auto &inputs = node->inputs();
459 // Inputs should be [sparse_getattr, sparse]
460 constexpr size_t expect_input_index = 2;
461 CheckInputsSize(inputs.size(), expect_input_index, GetCNodeFuncName(node));
462 constexpr size_t sparse_index = 1;
463 AnfNodePtr sparse = inputs[sparse_index];
464 MS_EXCEPTION_IF_NULL(sparse);
465 auto cons_node = NewValueNode(index);
466 AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(index));
467 cons_node->set_abstract(aptr);
468
469 return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, cons_node});
470 }
471
CleanAfterOptA(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)472 bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
473 MS_EXCEPTION_IF_NULL(manager);
474 manager->AddFuncGraph(root);
475
476 bool changed = false;
477
478 // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
479 auto all_node = manager->all_nodes();
480 for (auto &node : all_node) {
481 MS_EXCEPTION_IF_NULL(node);
482 auto cnode = node->cast<CNodePtr>();
483 AnfNodePtr new_node = nullptr;
484 if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
485 new_node = ConvertMakeListToMakeTuple(cnode);
486 } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
487 new_node = ConvertListGetItemToTupleGetItem(cnode);
488 } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) {
489 new_node = ConvertListSetItemToTupleSetItem(cnode);
490 } else if (IsValueNode<ValueList>(node)) {
491 new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
492 } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
493 IsPrimitiveCNode(node, prim::kPrimMakeRowTensor)) {
494 new_node = ConvertMakeSparseToMakeTuple(cnode);
495 } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
496 IsPrimitiveCNode(node, prim::kPrimRowTensorGetIndices)) {
497 constexpr int64_t indices_index = 0;
498 new_node = ConvertSparseGetAttrToTupleGetItem(cnode, indices_index);
499 } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
500 IsPrimitiveCNode(node, prim::kPrimRowTensorGetValues)) {
501 constexpr int64_t value_index = 1;
502 new_node = ConvertSparseGetAttrToTupleGetItem(cnode, value_index);
503 } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
504 IsPrimitiveCNode(node, prim::kPrimRowTensorGetDenseShape)) {
505 constexpr int64_t shape_index = 2;
506 new_node = ConvertSparseGetAttrToTupleGetItem(cnode, shape_index);
507 }
508
509 if (new_node != nullptr) {
510 new_node->set_abstract(node->abstract());
511 MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
512 (void)manager->Replace(node, new_node);
513 changed = true;
514 }
515 }
516
517 for (auto &node : manager->all_nodes()) {
518 auto ret = AdaptAbs(node->abstract());
519 if (ret) {
520 MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
521 << ret->ToString();
522 node->set_abstract(ret);
523 changed = true;
524 }
525 }
526 return changed;
527 }
528
529 // expand tuples in graph parameters
ExpandTuplesP(const FuncGraphManagerPtr & mng,const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & params)530 static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
531 const std::vector<AnfNodePtr> ¶ms) {
532 MS_EXCEPTION_IF_NULL(mng);
533 MS_EXCEPTION_IF_NULL(func_graph);
534
535 std::vector<AnfNodePtr> new_params;
536 for (const auto ¶m : params) {
537 MS_EXCEPTION_IF_NULL(param);
538 auto param_abs = param->abstract();
539 MS_EXCEPTION_IF_NULL(param_abs);
540
541 if (param_abs->isa<AbstractJTagged>()) {
542 MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info());
543 }
544
545 if (!param_abs->isa<AbstractTuple>()) {
546 new_params.emplace_back(param);
547 continue;
548 }
549
550 std::vector<AnfNodePtr> new_param;
551 std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
552 auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
553 for (auto &elem : abs_tuple->elements()) {
554 auto np = std::make_shared<Parameter>(func_graph);
555 np->set_abstract(elem);
556 new_param.emplace_back(np);
557 }
558 (void)inputs.insert(inputs.end(), new_param.begin(), new_param.end());
559 auto new_tuple = func_graph->NewCNode(inputs);
560 (void)mng->Replace(param, new_tuple);
561
562 auto expand_param = ExpandTuplesP(mng, func_graph, new_param);
563 (void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end());
564 }
565 return new_params;
566 }
567
568 // expand tuples in graph applies
ExpandTuplesC(const FuncGraphPtr & graph,const std::vector<AnfNodePtr> & inputs)569 static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
570 MS_EXCEPTION_IF_NULL(graph);
571
572 std::vector<AnfNodePtr> new_inputs;
573 for (const auto &input : inputs) {
574 MS_EXCEPTION_IF_NULL(input);
575
576 auto input_abs = input->abstract();
577 MS_EXCEPTION_IF_NULL(input_abs);
578
579 if (input_abs->isa<AbstractJTagged>()) {
580 auto abstract_tag = dyn_cast<AbstractJTagged>(input_abs);
581 if (abstract_tag->element()->isa<AbstractTuple>()) {
582 MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info());
583 }
584 }
585
586 if (!input_abs->isa<AbstractTuple>()) {
587 new_inputs.emplace_back(input);
588 continue;
589 }
590
591 int64_t idx = 0;
592 std::vector<AnfNodePtr> new_input;
593 auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
594 for (auto &elem : abs_tuple->elements()) {
595 auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
596 AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int64Imm>(idx));
597 constexpr size_t scalar_index = 2;
598 c_node->input(scalar_index)->set_abstract(aptr);
599 c_node->set_abstract(elem);
600 new_input.emplace_back(c_node);
601 idx++;
602 }
603
604 auto expand_tuple = ExpandTuplesC(graph, new_input);
605 (void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end());
606 }
607
608 return new_inputs;
609 }
610
611 // remove most uses of tuples from the graph parameters & apply inputs
612 // tuples that are returned will be kept
613 // tuples in CNode's inputs: AbstractTuple (a, b ,c) -->
614 // CNode("tuple_getitem", (a,b,c), 0)
615 // CNode("tuple_getitem", (a,b,c), 1)
616 // CNode("tuple_getitem", (a,b,c), 2)
617 // tuples in Graph's parameters: AbstractTuple (a, b, c) -->
618 // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
619 // cppcheck-suppress unusedFunction
EraseTuple(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)620 void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
621 MS_EXCEPTION_IF_NULL(manager);
622 manager->AddFuncGraph(root);
623
624 // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
625 AnfNodeSet all_node = manager->all_nodes();
626 for (auto &node : all_node) {
627 auto cnode = node->cast<CNodePtr>();
628 if (cnode == nullptr) {
629 continue;
630 }
631
632 const auto &inputs = cnode->inputs();
633
634 // Bypass the first input in inputs as it's fn.
635 if (!IsValueNode<Primitive>(inputs[0])) {
636 std::vector<AnfNodePtr> expand_inputs;
637 (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end());
638
639 auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
640 if (new_inputs != expand_inputs) {
641 std::vector<AnfNodePtr> cnode_inputs{inputs[0]};
642 (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
643
644 MS_EXCEPTION_IF_NULL(node->func_graph());
645 auto new_node = node->func_graph()->NewCNode(cnode_inputs);
646 new_node->set_abstract(node->abstract());
647
648 (void)manager->Replace(node, new_node);
649 }
650 // Bypass the first 2 inputs in inputs as it's [partial, fn].
651 } else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode<Primitive>(inputs[1])) {
652 std::vector<AnfNodePtr> expand_inputs;
653 (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end());
654
655 auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs);
656 if (new_inputs != expand_inputs) {
657 std::vector<AnfNodePtr> cnode_inputs{inputs[0], inputs[1]};
658 (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end());
659
660 MS_EXCEPTION_IF_NULL(cnode->func_graph());
661 auto new_node = cnode->func_graph()->NewCNode(cnode_inputs);
662 new_node->set_abstract(cnode->abstract());
663
664 (void)manager->Replace(node, new_node);
665 }
666 }
667 }
668
669 FuncGraphSet all_graph = manager->func_graphs();
670 for (auto &func_graph : all_graph) {
671 MS_EXCEPTION_IF_NULL(func_graph);
672 auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters());
673 manager->SetParameters(func_graph, expand_p);
674 }
675 }
676 } // namespace opt
677 } // namespace mindspore
678