1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2024 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "frontend/optimizer/fallback_rewriter.h"
20 #include <iterator>
21 #include <string>
22 #include <algorithm>
23 #include <functional>
24 #include <utility>
25 #include <memory>
26 #include <vector>
27 #include <set>
28 #include <unordered_map>
29 #include "ops/structure_ops.h"
30 #include "ops/sparse_tensor_ops.h"
31 #include "ops/sequence_ops.h"
32 #include "ops/array_ops.h"
33 #include "ops/arithmetic_ops.h"
34 #include "ops/framework_ops.h"
35 #include "ops/auto_generate/gen_ops_primitive.h"
36 #include "ops/op_utils.h"
37 #include "abstract/abstract_value.h"
38 #include "base/base.h"
39 #include "pipeline/jit/ps/debug/trace.h"
40 #include "pipeline/jit/ps/action.h"
41 #include "pipeline/jit/ps/parse/parse_base.h"
42 #include "frontend/optimizer/opt.h"
43 #include "frontend/operator/composite/composite.h"
44 #include "include/common/fallback.h"
45 #include "include/common/utils/convert_utils_py.h"
46 #include "ir/anf.h"
47 #include "ir/value.h"
48 #include "pipeline/jit/ps/fallback.h"
49 #include "pipeline/jit/ps/parse/resolve.h"
50 #include "utils/hash_map.h"
51 #include "utils/anf_utils.h"
52 #include "utils/compile_config.h"
53 #include "utils/check_convert_utils.h"
54 #include "utils/tensor_construct_utils.h"
55
56 namespace mindspore {
57 /* namespace to support opt */
58 namespace opt {
59 using mindspore::abstract::AbstractBase;
60 using mindspore::abstract::AbstractBasePtr;
61 using mindspore::abstract::AbstractDictionary;
62 using mindspore::abstract::AbstractDictionaryPtr;
63 using mindspore::abstract::AbstractElementPair;
64 using mindspore::abstract::AbstractList;
65 using mindspore::abstract::AbstractListPtr;
66 using mindspore::abstract::AbstractRowTensor;
67 using mindspore::abstract::AbstractScalar;
68 using mindspore::abstract::AbstractSequence;
69 using mindspore::abstract::AbstractSequencePtr;
70 using mindspore::abstract::AbstractTuple;
71 using mindspore::abstract::AbstractTuplePtr;
72 using ClassTypePtr = std::shared_ptr<parse::ClassType>;
73 using StringSet = std::set<std::string>;
74 using StringSetPtr = std::shared_ptr<StringSet>;
75
// Reserved internal names for a dict object and its key/value. They are not referenced in this part
// of the file. NOTE(review): presumably used as variable names in generated PyExecute scripts further
// down — confirm against their users.
constexpr auto kInternalDictSelfStr = "__internal_dict_self__";
constexpr auto kInternalDictKeyStr = "__internal_dict_key__";
constexpr auto kInternalDictValueStr = "__internal_dict_value__";
// PyExecute plus the list/dict "Inplace" primitives. Judging by the names these mutate their input in
// place. NOTE(review): not referenced in this part of the file; confirm the intended use at its users.
static const PrimitiveSet inplace_prim_set{prim::kPrimPyExecute, prim::kPrimListInplaceAppend,
                                           prim::kPrimListInplaceReverse, prim::kPrimListInplaceExtend,
                                           prim::kPrimListInplaceInsert, prim::kPrimListInplacePop,
                                           prim::kPrimDictInplaceSetItem};
// The getitem primitives of list, tuple and dict.
static const PrimitiveSet sequence_getitem_prim_set{prim::kPrimListGetItem, prim::kPrimTupleGetItem,
                                                    prim::kPrimDictGetItem};
85
86 namespace {
// Maximum list/tuple/dict nesting depth that recursive converters (e.g. ConvertDictValue and
// ConvertToAbstractSequence) descend into before raising an internal exception.
static constexpr size_t kMaxSeqRecursiveDepth = 6;
CheckInputsSize(const CNodePtr & cnode,size_t expect_size)88 void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) {
89 if (cnode->size() != expect_size) {
90 std::string op_name = GetCNodeFuncName(cnode);
91 MS_LOG(INTERNAL_EXCEPTION) << op_name << " should have " << expect_size << " inputs, but got " << cnode->size();
92 }
93 }
94
95 template <typename T>
GetAbstract(const AnfNodePtr & node)96 std::shared_ptr<T> GetAbstract(const AnfNodePtr &node) {
97 auto abs = node->abstract();
98 if (abs == nullptr) {
99 return nullptr;
100 }
101 return dyn_cast<T>(abs);
102 }
103
CheckContainsDict(const AbstractBasePtr & abs)104 bool CheckContainsDict(const AbstractBasePtr &abs) {
105 if (abs == nullptr) {
106 return false;
107 }
108 if (abs->isa<AbstractDictionary>()) {
109 return true;
110 }
111 auto from_dict = abs->user_data<bool>("from_dict");
112 if (from_dict != nullptr && *from_dict) {
113 return true;
114 }
115 if (abs->isa<AbstractSequence>()) {
116 auto abs_seq = abs->cast<AbstractSequencePtr>();
117 const auto &elements = abs_seq->elements();
118 if (std::any_of(elements.begin(), elements.end(),
119 [](const AbstractBasePtr &element) { return CheckContainsDict(element); })) {
120 return true;
121 }
122 }
123 return false;
124 }
125
126 // ===========================================================================
// BaseRewriter provides a common framework for data structure simplification.
128 // ===========================================================================
129 class BaseRewriter : protected SimpleRewriter {
130 public:
BaseRewriter(const FuncGraphPtr & root_graph,const FuncGraphManagerPtr & manager)131 BaseRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
132 : SimpleRewriter(root_graph, manager) {}
133 ~BaseRewriter() override = default;
134
need_renormalized() const135 bool need_renormalized() const { return need_renormalized_; }
136
set_need_renormalized(bool need_renormalized)137 void set_need_renormalized(bool need_renormalized) { need_renormalized_ = need_renormalized; }
138
Execute()139 virtual bool Execute() {
140 bool changed = Run();
141 if (changed) {
142 UpdateAbstracts();
143 }
144 return changed;
145 }
146
147 protected:
148 virtual AnfNodePtr ConvertPrimitiveCNode(const CNodePtr &cnode) = 0;
149 virtual AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) = 0;
150 virtual AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) = 0;
151
NodeRewrite(const AnfNodePtr & node)152 AnfNodePtr NodeRewrite(const AnfNodePtr &node) override {
153 auto new_node = ConvertNode(node);
154 if (IsPrimitiveCNode(new_node, prim::kPrimPyExecute)) {
155 need_renormalized_ = true;
156 return new_node;
157 }
158 if (new_node != nullptr) {
159 new_node->set_abstract(node->abstract());
160 }
161 return new_node;
162 }
163
ConvertNode(const AnfNodePtr & node)164 AnfNodePtr ConvertNode(const AnfNodePtr &node) {
165 auto cnode = node->cast<CNodePtr>();
166 if (cnode != nullptr) {
167 if (cnode->size() == 0) {
168 return nullptr;
169 }
170 // Call primitive cnode converter.
171 return ConvertPrimitiveCNode(cnode);
172 }
173 auto value_node = node->cast<ValueNodePtr>();
174 if (value_node != nullptr) {
175 const auto &value = value_node->value();
176 if (value == nullptr) {
177 return nullptr;
178 }
179 // Call value node converter.
180 return ConvertValueNode(value_node, value);
181 }
182 return nullptr;
183 }
184
UpdateAbstracts()185 virtual void UpdateAbstracts() {
186 const auto &nodes = manager_->all_nodes();
187 for (const auto &node : nodes) {
188 const auto &abs = node->abstract();
189 if (abs == nullptr) {
190 continue;
191 }
192 bool is_interpret_dict = false;
193 // Do not convert the abstract of Interpret node(AbstractDictionary) to AbstractSequence.
194 if (abs->isa<AbstractDictionary>()) {
195 AbstractDictionaryPtr abs_dict = abs->cast<AbstractDictionaryPtr>();
196 auto &dict_elements = abs_dict->elements();
197 for (auto &element : dict_elements) {
198 TypePtr type = element.second->GetTypeTrack();
199 MS_EXCEPTION_IF_NULL(type);
200 auto value = element.second->BuildValue();
201 MS_EXCEPTION_IF_NULL(value);
202 if (type->type_id() == kMetaTypeExternal && value->isa<parse::InterpretedObject>()) {
203 is_interpret_dict = true;
204 break;
205 }
206 }
207 }
208 if (is_interpret_dict) {
209 continue;
210 }
211 // Call abstract converter.
212 auto new_abs = ConvertAbstract(abs);
213 if (new_abs != nullptr) {
214 node->set_abstract(new_abs);
215 }
216 }
217 }
218
GetElementIndex(const std::vector<AbstractElementPair> & attrs,const AnfNodePtr & name)219 static int64_t GetElementIndex(const std::vector<AbstractElementPair> &attrs, const AnfNodePtr &name) {
220 auto n_attrs = attrs.size();
221 auto name_abstract = GetAbstract<AbstractBase>(name);
222 MS_EXCEPTION_IF_NULL(name_abstract);
223 auto name_value = name_abstract->BuildValue();
224 MS_EXCEPTION_IF_NULL(name_value);
225 for (size_t i = 0; i < n_attrs; ++i) {
226 if (*name_value == *attrs[i].first->BuildValue()) {
227 return SizeToLong(i);
228 }
229 }
230 return SizeToLong(n_attrs);
231 }
232
233 private:
234 bool need_renormalized_{false};
235 };
236
237 // ===========================================================================
// BeforeOptARewriter converts ObjectClass and Dictionary to Tuple.
239 // ===========================================================================
240 class BeforeOptARewriter : public BaseRewriter {
241 public:
242 using ThisClass = BeforeOptARewriter;
BeforeOptARewriter(const FuncGraphPtr & root_graph,const FuncGraphManagerPtr & manager)243 BeforeOptARewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
244 : BaseRewriter(root_graph, manager), is_dict_output_(HasDictOutput()), has_dict_inplace_(HasDictInplace()) {}
245 ~BeforeOptARewriter() override = default;
246
Execute()247 bool Execute() override {
248 bool changed = Run();
249 if (changed) {
250 UpdateAbstracts();
251 }
252 ConvertParameter();
253 return changed;
254 }
255
256 protected:
ConvertParameter()257 void ConvertParameter() {
258 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
259 for (const auto ¶ : root_graph_->parameters()) {
260 auto abs = para->abstract();
261 MS_EXCEPTION_IF_NULL(abs);
262 if (abs->isa<abstract::AbstractKeywordArg>()) {
263 auto kw_abs = abs->cast_ptr<abstract::AbstractKeywordArg>();
264 para->set_abstract(kw_abs->get_arg());
265 }
266 // If the dict input is not used in graph, convert it to tuple directly.
267 auto dict_param_not_used =
268 abs->isa<abstract::AbstractDictionary>() && manager_->node_users().find(para) == manager_->node_users().end();
269 if ((!allow_fallback_runtime || !is_dict_output_) && !dict_param_not_used) {
270 continue;
271 }
272 auto new_node_and_abs = ConvertParameterDictAbstract(para, para->abstract());
273 new_node_and_abs.first->set_abstract(new_node_and_abs.second);
274 if (new_node_and_abs.first == para) {
275 continue;
276 }
277 (void)manager_->Replace(para, new_node_and_abs.first);
278 para->set_abstract(new_node_and_abs.second);
279 }
280 }
281
ConvertParameterDictAbstract(const AnfNodePtr & cur_node,const AbstractBasePtr & cur_abs)282 std::pair<AnfNodePtr, AbstractBasePtr> ConvertParameterDictAbstract(const AnfNodePtr &cur_node,
283 const AbstractBasePtr &cur_abs) {
284 MS_EXCEPTION_IF_NULL(cur_abs);
285 auto seq_abs = cur_abs->cast_ptr<AbstractSequence>();
286 if (seq_abs != nullptr) {
287 bool is_tuple = seq_abs->isa<AbstractTuple>();
288 auto seq_prim = is_tuple ? prim::kPrimMakeTuple : prim::kPrimMakeList;
289 std::vector<AnfNodePtr> seq_inputs{NewValueNode(seq_prim)};
290 AbstractBasePtrList abs_list;
291 for (size_t i = 0; i < seq_abs->elements().size(); ++i) {
292 auto getitem_prim = is_tuple ? prim::kPrimTupleGetItem : prim::kPrimListGetItem;
293 auto next_node =
294 root_graph_->NewCNodeInOrder({NewValueNode(getitem_prim), cur_node, NewValueNode(SizeToLong(i))});
295 auto node_and_abs = ConvertParameterDictAbstract(next_node, seq_abs->elements()[i]);
296 (void)seq_inputs.emplace_back(node_and_abs.first);
297 (void)abs_list.emplace_back(node_and_abs.second);
298 }
299 if (is_tuple) {
300 return std::make_pair(root_graph_->NewCNodeInOrder(seq_inputs), std::make_shared<AbstractTuple>(abs_list));
301 }
302 return std::make_pair(root_graph_->NewCNodeInOrder(seq_inputs), std::make_shared<AbstractList>(abs_list));
303 }
304 auto dict_abs = cur_abs->cast_ptr<AbstractDictionary>();
305 if (dict_abs != nullptr) {
306 std::vector<AnfNodePtr> key_inputs{NewValueNode(prim::kPrimMakeTuple)};
307 std::vector<AnfNodePtr> value_inputs{NewValueNode(prim::kPrimMakeTuple)};
308 AbstractBasePtrList abs_list;
309 for (size_t i = 0; i < dict_abs->elements().size(); ++i) {
310 auto next_node =
311 root_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), cur_node, NewValueNode(SizeToLong(i))});
312 auto node_and_abs = ConvertParameterDictAbstract(next_node, dict_abs->elements()[i].second);
313 (void)key_inputs.emplace_back(NewValueNode(dict_abs->elements()[i].first->BuildValue()));
314 (void)value_inputs.emplace_back(node_and_abs.first);
315 (void)abs_list.emplace_back(node_and_abs.second);
316 }
317 auto make_dict =
318 root_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), root_graph_->NewCNodeInOrder(key_inputs),
319 root_graph_->NewCNodeInOrder(value_inputs)});
320 return std::make_pair(make_dict, std::make_shared<AbstractTuple>(abs_list));
321 }
322 return std::make_pair(cur_node, cur_abs);
323 }
324
GetStringValue(const AnfNodePtr & node)325 static std::string GetStringValue(const AnfNodePtr &node) {
326 auto str = GetValueNode<StringImmPtr>(node);
327 if (str == nullptr) {
328 return "";
329 }
330 return str->value();
331 }
332
NewTupleGetCNode(const AnfNodePtr & cnode,const AnfNodePtr & data_node,const std::vector<AbstractElementPair> & elements,const AnfNodePtr & name_node)333 static CNodePtr NewTupleGetCNode(const AnfNodePtr &cnode, const AnfNodePtr &data_node,
334 const std::vector<AbstractElementPair> &elements, const AnfNodePtr &name_node) {
335 int64_t index = GetElementIndex(elements, name_node);
336 auto index_node = NewValueNode(index);
337 auto prim_node = NewValueNode(prim::kPrimTupleGetItem);
338 return cnode->func_graph()->NewCNode({prim_node, data_node, index_node});
339 }
340
341 // From:
342 // DictGetItem(data:AbstractDictionary, key:AbstractBase)
343 // To:
344 // TupleGetItem(data, index:Int64Imm)
ConvertDictGetItemToTupleGetItem(const CNodePtr & node) const345 AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) const {
346 MS_EXCEPTION_IF_NULL(node);
347 MS_EXCEPTION_IF_NULL(node->func_graph());
348
349 // Inputs should be [dict_getitem, dict, item]
350 const size_t expect_inputs_size = 3;
351 CheckInputsSize(node, expect_inputs_size);
352
353 constexpr size_t data_index = 1;
354 constexpr size_t key_index = 2;
355 const auto &inputs = node->inputs();
356 auto &data = inputs[data_index];
357 auto &key = inputs[key_index];
358 MS_EXCEPTION_IF_NULL(data);
359 MS_EXCEPTION_IF_NULL(key);
360
361 auto abs_dict = GetAbstract<AbstractDictionary>(data);
362 if (abs_dict == nullptr) {
363 return nullptr;
364 }
365 return NewTupleGetCNode(node, data, abs_dict->elements(), key);
366 }
367
ConvertDictGetItem(const CNodePtr & node) const368 AnfNodePtr ConvertDictGetItem(const CNodePtr &node) const {
369 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
370 if (!allow_fallback_runtime || (!is_dict_output_ && !has_dict_inplace_)) {
371 return ConvertDictGetItemToTupleGetItem(node);
372 }
373 return nullptr;
374 }
375
376 // From:
377 // DictSetItem(data:AbstractDictionary, key:AbstractBase, value)
378 // To:
379 // TupleSetItem(data, index:Int64Imm, value)
380 // Or:
381 // tuple_add(data, value)
ConvertDictSetItemToTupleSetItem(const CNodePtr & node) const382 AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) const {
383 MS_EXCEPTION_IF_NULL(node);
384 MS_EXCEPTION_IF_NULL(node->func_graph());
385
386 // Inputs should be [dict_setitem, dict, item, value]
387 const size_t expect_inputs_size = 4;
388 CheckInputsSize(node, expect_inputs_size);
389
390 const size_t data_index = 1;
391 const size_t cons_index = 2;
392 const size_t item_value_index = 3;
393 const auto &inputs = node->inputs();
394 auto &data = inputs[data_index];
395 auto &key = inputs[cons_index];
396 auto &item_value = inputs[item_value_index];
397 MS_EXCEPTION_IF_NULL(data);
398 MS_EXCEPTION_IF_NULL(key);
399
400 auto abs_dict = GetAbstract<AbstractDictionary>(data);
401 if (abs_dict == nullptr) {
402 return nullptr;
403 }
404 int64_t index = GetElementIndex(abs_dict->elements(), key);
405 auto func_graph = node->func_graph();
406 MS_EXCEPTION_IF_NULL(func_graph);
407 if (index >= static_cast<int64_t>(abs_dict->elements().size())) {
408 // For dictionary set, if the key does not exist, we should create a new item.
409 std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
410 for (size_t i = 0; i < abs_dict->elements().size(); ++i) {
411 auto tuple_getitem_i =
412 func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, NewValueNode(SizeToLong(i))});
413 (void)make_tuple_inputs.emplace_back(tuple_getitem_i);
414 }
415 (void)make_tuple_inputs.emplace_back(item_value);
416 auto new_node = func_graph->NewCNode(make_tuple_inputs);
417 new_node->set_debug_info(node->debug_info());
418 return new_node;
419 }
420 auto index_node = NewValueNode(index);
421 auto new_node = func_graph->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, index_node, item_value});
422 new_node->set_debug_info(node->debug_info());
423 return new_node;
424 }
425
HasDictOutput() const426 bool HasDictOutput() const {
427 const AnfNodePtr &output = root_graph_->output();
428 return CheckContainsDict(output->abstract());
429 }
430
HasDictInplace() const431 bool HasDictInplace() const {
432 const auto &all_nodes = manager_->all_nodes();
433 return std::any_of(all_nodes.cbegin(), all_nodes.cend(),
434 [](const auto &node) { return IsPrimitiveCNode(node, prim::kPrimDictInplaceSetItem); });
435 }
436
ConvertDictSetItem(const CNodePtr & node) const437 AnfNodePtr ConvertDictSetItem(const CNodePtr &node) const {
438 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
439 if (!allow_fallback_runtime || ConvertDictToTuple(node, node->func_graph())) {
440 return ConvertDictSetItemToTupleSetItem(node);
441 }
442 return nullptr;
443 }
444
445 // From:
446 // MakeDict(name, input)
447 // To:
448 // input
EraseMakeDictNode(const CNodePtr & node) const449 AnfNodePtr EraseMakeDictNode(const CNodePtr &node) const {
450 MS_EXCEPTION_IF_NULL(node);
451 constexpr size_t expect_inputs_size = 3;
452 constexpr size_t input_index = 2;
453 CheckInputsSize(node, expect_inputs_size);
454 return node->input(input_index);
455 }
456
CheckUserHasPyExecute(const AnfNodePtr & node,const FuncGraphPtr & func) const457 bool CheckUserHasPyExecute(const AnfNodePtr &node, const FuncGraphPtr &func) const {
458 MS_EXCEPTION_IF_NULL(node);
459 MS_EXCEPTION_IF_NULL(func);
460 auto mng = func->manager();
461 auto &users = mng->node_users()[node];
462 for (auto &user : users) {
463 if (IsPrimitiveCNode(user.first, prim::kPrimPyExecute) || IsPrimitiveCNode(user.first, prim::kPrimPyInterpret)) {
464 return true;
465 } else if (IsPrimitiveCNode(user.first, prim::kPrimMakeTuple)) {
466 if (CheckUserHasPyExecute(user.first, user.first->func_graph())) {
467 return true;
468 }
469 }
470 }
471 return false;
472 }
473
CheckDictUserHasFuncGraph(const AnfNodePtr & node,const FuncGraphPtr & func) const474 bool CheckDictUserHasFuncGraph(const AnfNodePtr &node, const FuncGraphPtr &func) const {
475 MS_EXCEPTION_IF_NULL(node);
476 MS_EXCEPTION_IF_NULL(func);
477 if (!IsValueNode<ValueDictionary>(node)) {
478 return false;
479 }
480 auto mng = func->manager();
481 auto &users = mng->node_users()[node];
482 for (auto &user : users) {
483 if (user.first->isa<CNode>()) {
484 auto cnode = user.first->cast<CNodePtr>();
485 auto input = cnode->input(0);
486 if (IsValueNode<FuncGraph>(input)) {
487 return true;
488 }
489 }
490 }
491 return false;
492 }
493
ConvertMakeDict(const CNodePtr & node) const494 AnfNodePtr ConvertMakeDict(const CNodePtr &node) const {
495 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
496 if (!allow_fallback_runtime || ConvertDictToTuple(node, node->func_graph())) {
497 auto new_node = EraseMakeDictNode(node);
498 return new_node;
499 }
500 return nullptr;
501 }
502
503 // From:
504 // DictGetValues(dict:AbstractDictionary)
505 // To:
506 // dict
EraseDictGetValues(const CNodePtr & node) const507 AnfNodePtr EraseDictGetValues(const CNodePtr &node) const {
508 MS_EXCEPTION_IF_NULL(node);
509 constexpr size_t expect_inputs_size = 2;
510 CheckInputsSize(node, expect_inputs_size);
511 auto input = node->input(1);
512 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
513 if (!allow_fallback_runtime || ConvertDictToTuple(node, node->func_graph())) {
514 return input;
515 }
516 auto abs_dict = GetAbstract<AbstractDictionary>(input);
517 if (abs_dict == nullptr) {
518 return nullptr;
519 }
520 const auto &elements = abs_dict->elements();
521 std::vector<AnfNodePtr> new_inputs;
522 new_inputs.reserve(elements.size() + 1);
523 (void)new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
524 auto fg = node->func_graph();
525 MS_EXCEPTION_IF_NULL(fg);
526 for (const auto &element : elements) {
527 MS_EXCEPTION_IF_NULL(element.first->BuildValue());
528 AnfNodePtr value_node =
529 fg->NewCNode({NewValueNode(prim::kPrimDictGetItem), input, NewValueNode(element.first->BuildValue())});
530 (void)new_inputs.emplace_back(value_node);
531 }
532 return fg->NewCNode(std::move(new_inputs));
533 }
534
535 // From:
536 // DictItems(dict:AbstractDictionary)
537 // To:
538 // kPrimMakeList(MakeTuple(key0, TupleGetItem(dict, 0)), ...)
EraseDictItems(const CNodePtr & node) const539 AnfNodePtr EraseDictItems(const CNodePtr &node) const {
540 MS_EXCEPTION_IF_NULL(node);
541 auto fg = node->func_graph();
542 MS_EXCEPTION_IF_NULL(fg);
543 constexpr size_t expect_inputs_size = 2;
544 CheckInputsSize(node, expect_inputs_size);
545
546 const auto &input = node->input(1);
547 auto abs_dict = GetAbstract<AbstractDictionary>(input);
548 if (abs_dict == nullptr) {
549 return nullptr;
550 }
551 const auto &elements = abs_dict->elements();
552 std::vector<AnfNodePtr> new_inputs;
553 new_inputs.reserve(elements.size() + 1);
554 (void)new_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
555 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
556 bool convert_to_tuple = !allow_fallback_runtime || ConvertDictToTuple(node, node->func_graph());
557 for (size_t i = 0; i < elements.size(); ++i) {
558 auto index_node = NewValueNode(static_cast<int64_t>(i));
559 MS_EXCEPTION_IF_NULL(elements[i].first->BuildValue());
560 auto key_node = NewValueNode(elements[i].first->BuildValue());
561 AnfNodePtr value_node;
562 if (convert_to_tuple) {
563 value_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, index_node});
564 } else {
565 value_node =
566 fg->NewCNode({NewValueNode(prim::kPrimDictGetItem), input, NewValueNode(elements[i].first->BuildValue())});
567 }
568 auto tuple_node = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), key_node, value_node});
569 (void)new_inputs.emplace_back(tuple_node);
570 }
571 return fg->NewCNode(std::move(new_inputs));
572 }
573
574 // From:
575 // MakeKeywordArg(key, value)
576 // To:
577 // value
EraseMakeKeywordArgNode(const CNodePtr & node) const578 AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) const {
579 MS_EXCEPTION_IF_NULL(node);
580 // Inputs should be [make_keyword_arg, key, value]
581 constexpr size_t expect_input_size = 3;
582 constexpr size_t value_inputs_index = 2;
583 CheckInputsSize(node, expect_input_size);
584 return node->input(value_inputs_index);
585 }
586
587 // From:
588 // ExtractKeywordArg(arg, key)
589 // To:
590 // key
EraseExtractKeywordArg(const CNodePtr & node) const591 AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) const {
592 MS_EXCEPTION_IF_NULL(node);
593 // Inputs should be [extract_keyword_arg, arg, key]
594 const size_t expect_inputs_size = 3;
595 // Inputs should be [extract_keyword_arg, arg, key, monad]
596 const size_t expect_inputs_has_side_effect_size = 4;
597 if (node->size() != expect_inputs_size && node->size() != expect_inputs_has_side_effect_size) {
598 MS_LOG(INTERNAL_EXCEPTION) << "The extract_keyword_arg should have 3 or 4 inputs, but got " << node->size();
599 }
600 constexpr size_t key_index = 2;
601 return node->input(key_index);
602 }
603
604 using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &) const;
605 using ConverterMap = std::unordered_map<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
606 static inline const ConverterMap converters_{
607 {prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItem},
608 {prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItem},
609 {prim::kPrimDictGetValues, &ThisClass::EraseDictGetValues},
610 {prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
611 {prim::kPrimMakeKeywordArg, &ThisClass::EraseMakeKeywordArgNode},
612 {prim::kPrimExtractKeywordArg, &ThisClass::EraseExtractKeywordArg},
613 {prim::kPrimDictItems, &ThisClass::EraseDictItems},
614 };
615
ConvertPrimitiveCNode(const CNodePtr & cnode)616 AnfNodePtr ConvertPrimitiveCNode(const CNodePtr &cnode) override {
617 // Get primitive from cnode.
618 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
619 if (prim == nullptr) {
620 return nullptr;
621 }
622 // Find cnode converter by primitive.
623 auto iter = converters_.find(prim);
624 if (iter == converters_.end()) {
625 return nullptr;
626 }
627 // Call converter.
628 return (this->*(iter->second))(cnode);
629 }
630
ConvertDictValue(const ValuePtr & value,size_t depth,bool convert_dict,bool * need_convert) const631 ValuePtr ConvertDictValue(const ValuePtr &value, size_t depth, bool convert_dict, bool *need_convert) const {
632 MS_EXCEPTION_IF_NULL(value);
633 if (depth > kMaxSeqRecursiveDepth) {
634 MS_LOG(ERROR) << "value:" << value->ToString();
635 MS_LOG(INTERNAL_EXCEPTION) << "List, tuple and dict nesting is not allowed more than " << kMaxSeqRecursiveDepth
636 << " levels.";
637 }
638 if (value->isa<ValueSequence>()) {
639 auto value_seq = value->cast<ValueSequencePtr>();
640 std::vector<ValuePtr> value_vec;
641 value_vec.reserve(value_seq->size());
642 bool new_need_convert = false;
643 for (const auto &element : value_seq->value()) {
644 (void)value_vec.emplace_back(ConvertDictValue(element, depth + 1, convert_dict, &new_need_convert));
645 }
646 if (!new_need_convert) {
647 return value;
648 }
649 *need_convert = true;
650 if (value->isa<ValueTuple>()) {
651 return std::make_shared<ValueTuple>(value_vec);
652 }
653 return std::make_shared<ValueList>(value_vec);
654 }
655 // dict(k0:v0, k1:v1, ...) --> tuple(v0, v1, ...)
656 if (value->isa<ValueDictionary>() && convert_dict) {
657 *need_convert = true;
658 const auto &keys_values = value->cast<ValueDictionaryPtr>()->value();
659 std::vector<ValuePtr> value_vec;
660 value_vec.reserve(keys_values.size());
661 for (const auto &element : keys_values) {
662 (void)value_vec.emplace_back(ConvertDictValue(element.second, depth + 1, convert_dict, need_convert));
663 }
664 return std::make_shared<ValueTuple>(value_vec);
665 }
666 return value;
667 }
668
ConvertValueNode(const ValueNodePtr & value_node,const ValuePtr & value)669 AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override {
670 // Convert Dictionary value node.
671 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
672 bool convert_dict = !allow_fallback_runtime || ConvertDictToTuple(value_node, root_graph_);
673 bool need_convert = false;
674 auto new_value = ConvertDictValue(value, 0, convert_dict, &need_convert);
675 if (need_convert) {
676 auto new_node = NewValueNode(new_value);
677 new_node->set_debug_info(value_node->debug_info());
678 return new_node;
679 }
680 return nullptr;
681 }
682
MakeAbstractTuple(const std::vector<AbstractElementPair> & attrs)683 static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractElementPair> &attrs) {
684 std::vector<AbstractBasePtr> elements;
685 elements.reserve(attrs.size());
686 (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(elements),
687 [](const auto &item) { return item.second; });
688 return std::make_shared<AbstractTuple>(std::move(elements));
689 }
690
691 // AbstractDictionary --> AbstractSequence.
ConvertToAbstractSequence(const AbstractBasePtr & abs,size_t depth)692 AbstractSequencePtr ConvertToAbstractSequence(const AbstractBasePtr &abs, size_t depth) {
693 if (depth > kMaxSeqRecursiveDepth) {
694 MS_LOG(ERROR) << "abs:" << abs->ToString();
695 MS_LOG(INTERNAL_EXCEPTION) << "List, tuple and dict nesting is not allowed more than " << kMaxSeqRecursiveDepth
696 << " levels.";
697 }
698 auto abs_seq = abs->cast<AbstractSequencePtr>();
699 if (abs_seq != nullptr) {
700 const auto &seq_elements = abs_seq->elements();
701 // First we check if elements should be converted,
702 // changed_elements maps old element to new element.
703 mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
704 for (const auto &element : seq_elements) {
705 auto new_element = ConvertToAbstractSequence(element, depth + 1);
706 if (new_element != nullptr) {
707 (void)changed_elements.emplace(element, new_element);
708 }
709 }
710 if (changed_elements.empty()) {
711 // Here the AbstractList don't need to convert to AbstractTuple.
712 return nullptr;
713 }
714 // Always make new AbstractSequence when elements changed.
715 std::vector<AbstractBasePtr> elements;
716 elements.reserve(seq_elements.size());
717 for (const auto &element : seq_elements) {
718 auto iter = changed_elements.find(element);
719 if (iter != changed_elements.end()) {
720 (void)elements.emplace_back(iter->second);
721 } else {
722 (void)elements.emplace_back(element);
723 }
724 }
725 // Here the AbstractList don't need to convert to AbstractTuple.
726 if (abs_seq->isa<AbstractList>()) {
727 return std::make_shared<AbstractList>(std::move(elements));
728 } else {
729 return std::make_shared<AbstractTuple>(std::move(elements));
730 }
731 }
732 // AbstractDictionary --> AbstractTuple.
733 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
734 bool convert_to_tuple = !allow_fallback_runtime || (!is_dict_output_ && !has_dict_inplace_);
735 auto abs_dict = abs->cast<AbstractDictionaryPtr>();
736 if (abs_dict != nullptr && convert_to_tuple) {
737 const auto &dict_elements = abs_dict->elements();
738 std::vector<AbstractBasePtr> elements;
739 elements.reserve(dict_elements.size());
740 for (const auto &element : dict_elements) {
741 auto new_element = ConvertToAbstractSequence(element.second, depth + 1);
742 if (new_element != nullptr) {
743 (void)elements.emplace_back(new_element);
744 } else {
745 (void)elements.emplace_back(element.second);
746 }
747 }
748 return std::make_shared<AbstractTuple>(elements);
749 }
750 return nullptr;
751 }
752
ConvertAbstract(const AbstractBasePtr & abs)753 AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
754 // AbstractDictionary --> AbstractSequence.
755 return ConvertToAbstractSequence(abs, 0);
756 }
757
ConvertDictToTuple(const AnfNodePtr & node,const FuncGraphPtr & fg) const758 bool ConvertDictToTuple(const AnfNodePtr &node, const FuncGraphPtr &fg) const {
759 return !is_dict_output_ && !has_dict_inplace_ && !CheckUserHasPyExecute(node, fg) &&
760 !CheckDictUserHasFuncGraph(node, fg);
761 }
762
763 private:
764 bool is_dict_output_{false};
765 bool has_dict_inplace_{false};
766 };
767
ExtractKwargsNode(const AnfNodePtr & node)768 std::pair<AnfNodePtr, AnfNodePtr> ExtractKwargsNode(const AnfNodePtr &node) {
769 MS_EXCEPTION_IF_NULL(node);
770 if (node->isa<ValueNode>()) {
771 auto kwargs = GetValueNode<KeywordArgPtr>(node);
772 if (kwargs != nullptr) {
773 auto key = MakeValue(kwargs->get_key());
774 auto arg = kwargs->get_value();
775 return std::make_pair(NewValueNode(key), NewValueNode(arg));
776 }
777 } else if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
778 auto kwarg_node = node->cast_ptr<CNode>();
779 constexpr auto kMakeKwargsKeyIndex = 1;
780 constexpr auto kMakeKwargsArgIndex = 2;
781 return std::make_pair(kwarg_node->input(kMakeKwargsKeyIndex), kwarg_node->input(kMakeKwargsArgIndex));
782 }
783 MS_LOG(EXCEPTION) << "Extract kwargs only can be used to CNode[make_keyword_arg] or ValueNode(KeywordArg), but got "
784 << node->DebugString();
785 }
786
787 // TupleGetItem/ListGetItem(sequence, index) -> PyExecute(sequence[index], ...)
ConvertSequenceGetItemInner(const CNodePtr & node)788 AnfNodePtr ConvertSequenceGetItemInner(const CNodePtr &node) {
789 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
790 if (!allow_fallback_runtime) {
791 return nullptr;
792 }
793
794 constexpr size_t prim_index = 0;
795 constexpr size_t sequence_index = 1;
796 constexpr size_t target_index = 2;
797 constexpr size_t node_inputs_size = 3;
798 const auto &node_inputs = node->inputs();
799 auto prim = GetValueNode<PrimitivePtr>(node_inputs[prim_index]);
800 MS_EXCEPTION_IF_NULL(prim);
801 const auto &prim_name = prim->name();
802 if (node_inputs.size() != node_inputs_size) {
803 MS_LOG(EXCEPTION) << "The size of input to " << prim_name << " should be " << node_inputs_size << " but got "
804 << node_inputs.size();
805 }
806
807 std::vector<AbstractBasePtr> inputs_abs;
808 for (size_t i = 1; i < node_inputs.size(); ++i) {
809 inputs_abs.push_back(node_inputs[i]->abstract());
810 }
811
812 auto output_abs = node->abstract();
813 MS_EXCEPTION_IF_NULL(output_abs);
814
815 auto sequence_node = node_inputs[sequence_index];
816 MS_EXCEPTION_IF_NULL(sequence_node);
817 auto sequence_abs = sequence_node->abstract();
818 // If the sequence is any, then the sequence getitem should be converted to PyExecute node.
819 if (sequence_abs == nullptr || !sequence_abs->isa<abstract::AbstractAny>()) {
820 if (!CheckAndConvertUtils::CheckContainNestedOrIrregularSequence(inputs_abs) &&
821 !output_abs->isa<abstract::AbstractAny>()) {
822 return nullptr;
823 }
824 if (!IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
825 auto target_node = node_inputs[target_index];
826 auto target_abs = target_node->abstract();
827 if (target_abs == nullptr || !target_abs->BuildValue()->ContainsValueAny()) {
828 return nullptr;
829 }
830 }
831 }
832
833 const auto &fg = node->func_graph();
834 MS_EXCEPTION_IF_NULL(fg);
835
836 const std::string internal_sequence_input = "__iternal_sequence_input__";
837 const std::string internal_sequence_target = "__internal_sequence_index__";
838
839 std::stringstream script_buffer;
840 script_buffer << internal_sequence_input << "[" << internal_sequence_target << "]";
841 const std::string &script = script_buffer.str();
842 const auto script_str = std::make_shared<StringImm>(script);
843
844 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
845 (void)key_value_names_list.emplace_back(NewValueNode(internal_sequence_input));
846 (void)key_value_names_list.emplace_back(NewValueNode(internal_sequence_target));
847 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
848 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
849 (void)key_value_list.emplace_back(node_inputs[sequence_index]);
850 (void)key_value_list.emplace_back(node_inputs[target_index]);
851 const auto key_value_tuple = fg->NewCNode(key_value_list);
852 auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
853
854 MS_LOG(DEBUG) << "Convert sequence getitem node to PyExecute node: " << res->DebugString();
855 return res;
856 }
857
858 // ==================================================================
859 // AfterOptARewriter converts List, Sparse, RowTensor to Tuple.
860 // ==================================================================
861 class AfterOptARewriter : public BaseRewriter {
862 public:
863 using ThisClass = AfterOptARewriter;
AfterOptARewriter(const FuncGraphPtr & root_graph,const FuncGraphManagerPtr & manager,const StringSetPtr & value_with_inplace)864 AfterOptARewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager,
865 const StringSetPtr &value_with_inplace)
866 : BaseRewriter(root_graph, manager), data_with_inplace_(value_with_inplace) {
867 auto context = MsContext::GetInstance();
868 MS_EXCEPTION_IF_NULL(context);
869 not_convert_jit_ = context->not_convert_jit();
870 }
871 ~AfterOptARewriter() override = default;
872
873 protected:
874 // From:
875 // MakeSparseTensor(indices, values, dense_shape)
876 // To:
877 // MakeTuple(indices, values, dense_shape)
ConvertMakeSparseToMakeTuple(const CNodePtr & node) const878 AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) const {
879 MS_EXCEPTION_IF_NULL(node);
880 MS_EXCEPTION_IF_NULL(node->func_graph());
881
882 AnfNodeWeakPtrList inputs;
883 inputs.reserve(node->size());
884 const auto make_tuple_node = NewValueNode(prim::kPrimMakeTuple);
885 (void)inputs.emplace_back(make_tuple_node);
886 // Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items.
887 (void)inputs.insert(inputs.cend(), node->weak_inputs().cbegin() + 1, node->weak_inputs().cend());
888 auto new_node = node->func_graph()->NewCNodeWeak(std::move(inputs));
889 new_node->set_abstract(node->abstract());
890 return new_node;
891 }
892
893 static inline const mindspore::HashMap<std::string, int64_t> sparse_attr_map = {
894 {kCSRTensorGetIndptrOpName, 0}, {kCSRTensorGetIndicesOpName, 1}, {kCSRTensorGetValuesOpName, 2},
895 {kCSRTensorGetDenseShapeOpName, 3}, {kCOOTensorGetIndicesOpName, 0}, {kCOOTensorGetValuesOpName, 1},
896 {kCOOTensorGetDenseShapeOpName, 2}, {kRowTensorGetIndicesOpName, 0}, {kRowTensorGetValuesOpName, 1},
897 {kRowTensorGetDenseShapeOpName, 2}};
898
899 // From:
900 // SparseTensorGetXXX(sparse) # index
901 // To:
902 // TupleGetItem(sparse, index)
ConvertSparseGetAttrToTupleGetItem(const CNodePtr & node) const903 AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node) const {
904 MS_EXCEPTION_IF_NULL(node);
905 MS_EXCEPTION_IF_NULL(node->func_graph());
906
907 constexpr size_t kExpectInputSize = 2;
908 constexpr size_t kSparseAttrIndex = 1;
909 CheckInputsSize(node, kExpectInputSize);
910
911 auto prim = GetValueNode<PrimitivePtr>(node->input(0));
912 if (prim != nullptr) {
913 auto iter = sparse_attr_map.find(prim->name());
914 if (iter != sparse_attr_map.end()) {
915 const auto &sparse = node->input(kSparseAttrIndex);
916 auto index_node = NewValueNode(iter->second);
917 auto new_node = node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, index_node});
918 new_node->set_abstract(node->abstract());
919 return new_node;
920 }
921 }
922 return nullptr;
923 }
924
925 // DictGetItem --> PyExecute()
ConvertDictGetItem(const CNodePtr & cnode) const926 AnfNodePtr ConvertDictGetItem(const CNodePtr &cnode) const {
927 if (not_convert_jit_) {
928 return cnode;
929 }
930 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
931 if (!allow_fallback_runtime) {
932 MS_LOG(WARNING) << "When using the DictGetItem statement with some syntaxes that is not supported in graph mode, "
933 << "it is best to set jit_syntax_level to LAX.\n";
934 return nullptr;
935 }
936 MS_EXCEPTION_IF_NULL(cnode);
937 // Inputs should be [dict_setitem, dict, item]
938 const size_t expect_inputs_size = 3;
939 CheckInputsSize(cnode, expect_inputs_size);
940
941 const size_t data_index = 1;
942 const size_t item_key_index = 2;
943 const auto &inputs = cnode->inputs();
944 auto &data = inputs[data_index];
945 auto &key = inputs[item_key_index];
946 MS_EXCEPTION_IF_NULL(data);
947 MS_EXCEPTION_IF_NULL(key);
948
949 auto func_graph = cnode->func_graph();
950 MS_EXCEPTION_IF_NULL(func_graph);
951
952 // Script
953 std::stringstream script_buffer;
954 script_buffer << kInternalDictSelfStr << "[" << kInternalDictKeyStr << "]";
955 const std::string &script = script_buffer.str();
956 const auto script_str = std::make_shared<StringImm>(script);
957
958 // Pack local parameters keys.
959 const auto script_dict_self_name = std::make_shared<StringImm>(kInternalDictSelfStr);
960 const auto script_dict_key_name = std::make_shared<StringImm>(kInternalDictKeyStr);
961 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
962 (void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
963 (void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
964 const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
965
966 // Pack the local parameters values, not support list, tuple, or dict.
967 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
968 (void)key_value_list.emplace_back(data);
969 (void)key_value_list.emplace_back(key);
970 const auto key_value_tuple = func_graph->NewCNode(key_value_list);
971
972 // Build the new dict node.
973 const auto dict_getitem_node =
974 fallback::CreatePyExecuteCNodeInOrder(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
975 auto abs_dict = GetAbstract<AbstractDictionary>(data);
976 if (abs_dict != nullptr) {
977 size_t index = GetElementIndex(abs_dict->elements(), key);
978 const auto &elements = abs_dict->elements();
979 if (elements.size() > index) {
980 const auto &val = elements[index].second;
981 const auto &tensor_val = dyn_cast<abstract::AbstractTensor>(val);
982 if (tensor_val != nullptr) {
983 const auto &tensor_type = tensor_val->element()->BuildType();
984 fallback::SetRealType<AnfNode, Type>(dict_getitem_node, tensor_val->BuildType());
985 const auto &tensor_shape = dyn_cast<abstract::Shape>(tensor_val->BuildShape());
986 MS_EXCEPTION_IF_NULL(tensor_shape);
987 fallback::SetRealShape<AnfNode, abstract::BaseShape>(dict_getitem_node, tensor_shape);
988 MS_LOG(DEBUG) << "key: " << key->abstract()->BuildValue()->ToString() << ", type: " << tensor_type->ToString()
989 << ", shape: " << tensor_shape->ToString() << ", val: " << tensor_val->ToString();
990 }
991 }
992 }
993 MS_LOG(DEBUG) << "Made dict getitem node: " << dict_getitem_node->DebugString();
994 return dict_getitem_node;
995 }
996
997 // DictSetItem --> PyExecute()
ConvertDictSetItem(const CNodePtr & cnode) const998 AnfNodePtr ConvertDictSetItem(const CNodePtr &cnode) const {
999 if (not_convert_jit_) {
1000 return cnode;
1001 }
1002 MS_EXCEPTION_IF_NULL(cnode);
1003 // Inputs should be [dict_setitem, dict, item, value]
1004 const size_t expect_inputs_size = 4;
1005 CheckInputsSize(cnode, expect_inputs_size);
1006
1007 const size_t data_index = 1;
1008 const size_t item_key_index = 2;
1009 const size_t item_value_index = 3;
1010 const auto &inputs = cnode->inputs();
1011 auto &data = inputs[data_index];
1012 auto &key = inputs[item_key_index];
1013 auto &item_value = inputs[item_value_index];
1014 MS_EXCEPTION_IF_NULL(data);
1015 MS_EXCEPTION_IF_NULL(key);
1016
1017 auto abs_dict = GetAbstract<AbstractDictionary>(data);
1018 if (abs_dict == nullptr) {
1019 return nullptr;
1020 }
1021 auto func_graph = cnode->func_graph();
1022 MS_EXCEPTION_IF_NULL(func_graph);
1023
1024 // Script
1025 std::stringstream script_buffer;
1026 script_buffer << "__import__('mindspore').common._jit_fallback_utils.dict_setitem(" << kInternalDictSelfStr << ", "
1027 << kInternalDictKeyStr << ", " << kInternalDictValueStr << ")";
1028 const std::string &script = script_buffer.str();
1029 const auto script_str = std::make_shared<StringImm>(script);
1030
1031 // Pack local parameters keys.
1032 const auto script_dict_self_name = std::make_shared<StringImm>(kInternalDictSelfStr);
1033 const auto script_dict_key_name = std::make_shared<StringImm>(kInternalDictKeyStr);
1034 const auto script_dict_value_name = std::make_shared<StringImm>(kInternalDictValueStr);
1035 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1036 (void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
1037 (void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
1038 (void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
1039 const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
1040
1041 // Pack the local parameters values, not support list, tuple, or dict.
1042 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1043 (void)key_value_list.emplace_back(data);
1044 (void)key_value_list.emplace_back(key);
1045 (void)key_value_list.emplace_back(item_value);
1046 const auto key_value_tuple = func_graph->NewCNode(key_value_list);
1047
1048 // Build the new dict node.
1049 const auto dict_setitem_node =
1050 fallback::CreatePyExecuteCNodeInOrder(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1051 MS_LOG(DEBUG) << "Made dict setitem node: " << dict_setitem_node->DebugString();
1052 return dict_setitem_node;
1053 }
1054
// Builds a PyExecute node whose script is just the variable name "__internal_tuple_keys__", bound to keys_node,
// so evaluating it yields the object produced by keys_node.
AnfNodePtr ConstructInternalTupleKeysNode(const FuncGraphPtr &fg, const AnfNodePtr &keys_node) const {
  constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
  MS_EXCEPTION_IF_NULL(fg);
  // keys_node is dereferenced below for its debug info, so validate it as well.
  MS_EXCEPTION_IF_NULL(keys_node);
  const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
  auto dict_py_exec_key = std::make_shared<ValueTuple>(std::vector<ValuePtr>{script_key_tuple_str});
  auto dict_tuple_key_value = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), keys_node});
  const auto make_key_tuple_node =
    fallback::CreatePyExecuteCNode(fg, NewValueNode(script_key_tuple_str), NewValueNode(dict_py_exec_key),
                                   dict_tuple_key_value, keys_node->debug_info());
  return make_key_tuple_node;
}
1066
// Builds a PyExecute node whose script is just the variable name "__internal_tuple_values__", bound to
// values_node, so evaluating it yields the object produced by values_node.
AnfNodePtr ConstructInternalTupleValueNode(const FuncGraphPtr &fg, const AnfNodePtr &values_node) const {
  constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
  MS_EXCEPTION_IF_NULL(fg);
  // values_node is dereferenced below for its debug info, so validate it as well.
  MS_EXCEPTION_IF_NULL(values_node);
  const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
  auto dict_py_exec_value = std::make_shared<ValueTuple>(std::vector<ValuePtr>{script_value_tuple_str});
  auto dict_tuple_node = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), values_node});
  const auto make_value_tuple_node =
    fallback::CreatePyExecuteCNode(fg, NewValueNode(script_value_tuple_str), NewValueNode(dict_py_exec_value),
                                   dict_tuple_node, values_node->debug_info());
  return make_value_tuple_node;
}
1078
// Builds a PyExecute node running "dict(zip(keys, values),)" over the two given tuple nodes.
AnfNodePtr ConstructNewDictNode(const FuncGraphPtr &fg, const AnfNodePtr &make_key_tuple_node,
                                const AnfNodePtr &make_value_tuple_node) const {
  // Script-local variable names bound to the key tuple and the value tuple.
  constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
  constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";

  // Runtime values, in the same order as the names below.
  const std::vector<AnfNodePtr> value_items{NewValueNode(prim::kPrimMakeTuple), make_key_tuple_node,
                                            make_value_tuple_node};
  const auto key_value_tuple = fg->NewCNode(value_items);

  // Variable names visible to the script.
  const std::vector<AnfNodePtr> name_items{NewValueNode(prim::kPrimMakeTuple),
                                           NewValueNode(std::make_shared<StringImm>(internal_dict_zip_keys_str)),
                                           NewValueNode(std::make_shared<StringImm>(internal_dict_zip_values_str))};
  const auto key_value_name_tuple = fg->NewCNode(name_items);

  const std::string script =
    std::string("dict(zip(") + internal_dict_zip_keys_str + "," + internal_dict_zip_values_str + "),)";
  const auto script_str = std::make_shared<StringImm>(script);

  const auto make_dict_node = fallback::CreatePyExecuteCNodeInOrder(
    fg, NewValueNode(script_str), key_value_name_tuple, key_value_tuple, make_key_tuple_node->debug_info());
  MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
  return make_dict_node;
}
1109
1110 // MakeDict(keys, values) --> PyExecute('dict(zip(keys, values))', ...)
ConvertMakeDict(const CNodePtr & node) const1111 AnfNodePtr ConvertMakeDict(const CNodePtr &node) const {
1112 if (not_convert_jit_) {
1113 return node;
1114 }
1115 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1116 if (!allow_fallback_runtime) {
1117 MS_LOG(WARNING) << "When using the MakeDict statement with some syntaxes that is not supported in graph mode, "
1118 << "it is best to set jit_syntax_level to LAX.\n";
1119 return nullptr;
1120 }
1121 const auto &fg = node->func_graph();
1122 MS_EXCEPTION_IF_NULL(fg);
1123 // Local parameters values.
1124 // Get the key tuple.
1125 constexpr size_t keys_input_index = 1;
1126 auto keys_node = node->input(keys_input_index);
1127 const auto make_key_tuple_node = ConstructInternalTupleKeysNode(fg, keys_node);
1128 make_key_tuple_node->set_debug_info(node->input(keys_input_index)->debug_info());
1129 // Get the value tuple.
1130 constexpr size_t values_input_index = 2;
1131 auto values_node = node->input(values_input_index);
1132 const auto make_value_tuple_node = ConstructInternalTupleValueNode(fg, values_node);
1133 make_value_tuple_node->set_debug_info(node->input(values_input_index)->debug_info());
1134
1135 auto new_dict_node = ConstructNewDictNode(fg, make_key_tuple_node, make_value_tuple_node);
1136 new_dict_node->set_debug_info(node->debug_info());
1137 return new_dict_node;
1138 }
1139
GenerateTupleInput(const CNodePtr & node) const1140 AnfNodePtr GenerateTupleInput(const CNodePtr &node) const {
1141 const auto &fg = node->func_graph();
1142 MS_EXCEPTION_IF_NULL(fg);
1143 const auto &inputs = node->inputs();
1144 constexpr auto internal_element_str_prefix = "__internal_list_element_";
1145 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1146 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1147 std::stringstream script_buffer;
1148 script_buffer << "(";
1149 for (size_t i = 1; i < inputs.size(); ++i) {
1150 if (IsValueNode<None>(inputs[i])) {
1151 script_buffer << "None, ";
1152 continue;
1153 }
1154 std::string cur_element = internal_element_str_prefix + std::to_string(i) + "_";
1155 (void)key_value_names_list.emplace_back(NewValueNode(cur_element));
1156 (void)key_value_list.emplace_back(inputs[i]);
1157 script_buffer << cur_element << ", ";
1158 }
1159 script_buffer << ")";
1160 const std::string &script = script_buffer.str();
1161 const auto script_str = std::make_shared<StringImm>(script);
1162 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1163 const auto key_value_tuple = fg->NewCNode(key_value_list);
1164 auto list_node =
1165 fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1166 return list_node;
1167 }
1168
1169 // MakeList(x1, x2, ...) --> PyExecute('[x1, x2, ...]', ...)
ConvertMakeList(const CNodePtr & node) const1170 AnfNodePtr ConvertMakeList(const CNodePtr &node) const {
1171 if (!fallback::EnableFallbackListDictInplace()) {
1172 return nullptr;
1173 }
1174
1175 const auto &fg = node->func_graph();
1176 MS_EXCEPTION_IF_NULL(fg);
1177
1178 auto list_node_input = GenerateTupleInput(node);
1179
1180 if (!fallback::HasObjInExtraInfoHolder(node->abstract())) {
1181 MS_LOG(EXCEPTION) << "MakeList node: " << node->DebugString() << " do not have python list object.";
1182 }
1183 auto object = fallback::GetObjFromExtraInfoHolder(node->abstract());
1184 if (!py::isinstance<py::list>(object)) {
1185 MS_INTERNAL_EXCEPTION(TypeError) << "For MakeList node: " << node->DebugString()
1186 << ", the corresponding python object should be list but got: " << object;
1187 }
1188 py::list list_object = py::list(object);
1189 const std::string list_obj_str_prefix = "__list_py_object_";
1190 auto list_obj_id = fallback::GetPyObjectPtrStr(list_object);
1191 MS_LOG(DEBUG) << "Current python object id: " << list_obj_id;
1192 auto list_obj_str = list_obj_str_prefix + list_obj_id + "_";
1193 fallback::SetPyObjectToLocalVariable(list_obj_str, list_object);
1194
1195 const auto list_key_input = "__internal_list_key__";
1196 const auto list_value_input = "__internal_list_value__";
1197 std::stringstream script_buffer;
1198 script_buffer << "__import__('mindspore').common._jit_fallback_utils.generate_list(" << list_key_input << ", "
1199 << list_value_input << ")";
1200 const std::string &script = script_buffer.str();
1201 const auto script_str = std::make_shared<StringImm>(script);
1202
1203 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1204 (void)key_value_names_list.emplace_back(NewValueNode(list_key_input));
1205 (void)key_value_names_list.emplace_back(NewValueNode(list_value_input));
1206 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1207 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1208 (void)key_value_list.emplace_back(NewValueNode(list_obj_str));
1209 (void)key_value_list.emplace_back(list_node_input);
1210 const auto key_value_tuple = fg->NewCNode(key_value_list);
1211 auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1212
1213 auto abs = node->abstract();
1214 MS_EXCEPTION_IF_NULL(abs);
1215 auto list_abs = abs->cast<abstract::AbstractListPtr>();
1216 MS_EXCEPTION_IF_NULL(list_abs);
1217
1218 res->set_debug_info(node->debug_info());
1219
1220 MS_LOG(DEBUG) << "Convert make_list node to PyExecute node: " << res->DebugString();
1221 return res;
1222 }
1223
1224 // x.extend(y) --> PyExecute(_jit_fallback_list_inplace_extend(x, y))
ConvertListInplaceExtend(const CNodePtr & node) const1225 AnfNodePtr ConvertListInplaceExtend(const CNodePtr &node) const {
1226 if (!fallback::EnableFallbackListDictInplace()) {
1227 return nullptr;
1228 }
1229
1230 const auto &fg = node->func_graph();
1231 MS_EXCEPTION_IF_NULL(fg);
1232 constexpr auto internal_list_input = "__internal_list_input__";
1233 constexpr auto internal_target_input = "__internal_target_input__";
1234 std::stringstream script_buffer;
1235 script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_extend(" << internal_list_input
1236 << ", " << internal_target_input << ")";
1237 const std::string &script = script_buffer.str();
1238 const auto script_str = std::make_shared<StringImm>(script);
1239 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1240 (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1241 (void)key_value_names_list.emplace_back(NewValueNode(internal_target_input));
1242 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1243
1244 const auto &node_inputs = node->inputs();
1245 constexpr size_t min_node_inputs_size = 3;
1246 constexpr size_t max_node_inputs_size = 4;
1247 size_t inputs_size = node_inputs.size();
1248 if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1249 MS_LOG(EXCEPTION) << "The size of input to ListInplaceExtend should be " << min_node_inputs_size << " or "
1250 << max_node_inputs_size << " but got " << inputs_size;
1251 }
1252 constexpr size_t node_list_index = 1;
1253 constexpr size_t node_target_index = 2;
1254 auto list_input_node = node_inputs[node_list_index];
1255 if (IsPrimitiveCNode(list_input_node, prim::kPrimMakeList)) {
1256 TraceGuard trace_guard(std::make_shared<TraceCopy>(list_input_node->debug_info()));
1257 auto new_node = ConvertMakeList(list_input_node->cast<CNodePtr>());
1258 (void)manager_->Replace(list_input_node, new_node);
1259 list_input_node = new_node;
1260 }
1261 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1262 (void)key_value_list.emplace_back(list_input_node);
1263 (void)key_value_list.emplace_back(node_inputs[node_target_index]);
1264 const auto key_value_tuple = fg->NewCNode(key_value_list);
1265
1266 auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1267
1268 if (inputs_size == max_node_inputs_size) {
1269 res->add_input(node_inputs[max_node_inputs_size - 1]);
1270 }
1271 res->set_debug_info(node->debug_info());
1272
1273 MS_LOG(DEBUG) << "Convert list inplace append node to PyExecute node: " << res->DebugString();
1274 return res;
1275 }
1276
1277 // x.insert(index, y) --> PyExecute(_jit_fallback_list_inplace_insert(x, index, y))
ConvertDictInplaceSetItem(const CNodePtr & node) const1278 AnfNodePtr ConvertDictInplaceSetItem(const CNodePtr &node) const {
1279 if (!fallback::EnableFallbackListDictInplace()) {
1280 return nullptr;
1281 }
1282
1283 const auto &fg = node->func_graph();
1284 MS_EXCEPTION_IF_NULL(fg);
1285 constexpr auto internal_dict_input = "__internal_dict_input__";
1286 constexpr auto internal_key_input = "__internal_key_input__";
1287 constexpr auto internal_target_input = "__internal_target_input__";
1288 std::stringstream script_buffer;
1289 script_buffer << "__import__('mindspore').common._jit_fallback_utils.dict_inplace_setitem(" << internal_dict_input
1290 << ", " << internal_key_input << ", " << internal_target_input << ")";
1291 const std::string &script = script_buffer.str();
1292 const auto script_str = std::make_shared<StringImm>(script);
1293 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1294 (void)key_value_names_list.emplace_back(NewValueNode(internal_dict_input));
1295 (void)key_value_names_list.emplace_back(NewValueNode(internal_key_input));
1296 (void)key_value_names_list.emplace_back(NewValueNode(internal_target_input));
1297 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1298
1299 const auto &node_inputs = node->inputs();
1300 constexpr size_t min_node_inputs_size = 4;
1301 constexpr size_t max_node_inputs_size = 5;
1302 size_t inputs_size = node_inputs.size();
1303 if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1304 MS_LOG(EXCEPTION) << "The size of input to DictInplaceSetItem should be " << min_node_inputs_size << " or "
1305 << max_node_inputs_size << " but got " << inputs_size;
1306 }
1307 constexpr size_t node_list_index = 1;
1308 constexpr size_t node_index_index = 2;
1309 constexpr size_t node_target_index = 3;
1310 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1311 (void)key_value_list.emplace_back(node_inputs[node_list_index]);
1312 (void)key_value_list.emplace_back(node_inputs[node_index_index]);
1313 (void)key_value_list.emplace_back(node_inputs[node_target_index]);
1314 const auto key_value_tuple = fg->NewCNode(key_value_list);
1315
1316 auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1317
1318 if (inputs_size == max_node_inputs_size) {
1319 res->add_input(node_inputs[max_node_inputs_size - 1]);
1320 }
1321
1322 res->set_debug_info(node->debug_info());
1323
1324 MS_LOG(DEBUG) << "Convert dict inplace setitem node to PyExecute node: " << res->DebugString();
1325 return res;
1326 }
1327
1328 // x.pop(index) --> PyExecute(_jit_fallback_list_inplace_pop(x, index, y))
ConvertListInplacePop(const CNodePtr & node) const1329 AnfNodePtr ConvertListInplacePop(const CNodePtr &node) const {
1330 if (!fallback::EnableFallbackListDictInplace()) {
1331 return nullptr;
1332 }
1333
1334 const auto &fg = node->func_graph();
1335 MS_EXCEPTION_IF_NULL(fg);
1336 constexpr auto internal_list_input = "__internal_list_input__";
1337 constexpr auto internal_index_input = "__internal_index_input__";
1338 std::stringstream script_buffer;
1339 script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_pop(" << internal_list_input
1340 << ", " << internal_index_input << ")";
1341 const std::string &script = script_buffer.str();
1342 const auto script_str = std::make_shared<StringImm>(script);
1343 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1344 (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1345 (void)key_value_names_list.emplace_back(NewValueNode(internal_index_input));
1346 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1347
1348 const auto &node_inputs = node->inputs();
1349 constexpr size_t min_node_inputs_size = 3;
1350 constexpr size_t max_node_inputs_size = 4;
1351 size_t inputs_size = node_inputs.size();
1352 if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1353 MS_LOG(EXCEPTION) << "The size of input to ListInplacePop should be " << min_node_inputs_size << " or "
1354 << max_node_inputs_size << " but got " << inputs_size;
1355 }
1356 constexpr size_t node_list_index = 1;
1357 constexpr size_t node_index_index = 2;
1358 auto list_input_node = node_inputs[node_list_index];
1359 if (IsPrimitiveCNode(list_input_node, prim::kPrimMakeList)) {
1360 TraceGuard trace_guard(std::make_shared<TraceCopy>(list_input_node->debug_info()));
1361 auto new_node = ConvertMakeList(list_input_node->cast<CNodePtr>());
1362 (void)manager_->Replace(list_input_node, new_node);
1363 list_input_node = new_node;
1364 }
1365 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1366 (void)key_value_list.emplace_back(list_input_node);
1367 (void)key_value_list.emplace_back(node_inputs[node_index_index]);
1368 const auto key_value_tuple = fg->NewCNode(key_value_list);
1369
1370 auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1371
1372 if (inputs_size == max_node_inputs_size) {
1373 res->add_input(node_inputs[max_node_inputs_size - 1]);
1374 }
1375 res->set_debug_info(node->debug_info());
1376
1377 MS_LOG(DEBUG) << "Convert list inplace pop node to PyExecute node: " << res->DebugString();
1378 return res;
1379 }
1380
1381 // x.reverse() --> PyExecute(_jit_fallback_list_inplace_reverse(x))
ConvertListInplaceReverse(const CNodePtr & node) const1382 AnfNodePtr ConvertListInplaceReverse(const CNodePtr &node) const {
1383 if (!fallback::EnableFallbackListDictInplace()) {
1384 return nullptr;
1385 }
1386
1387 const auto &fg = node->func_graph();
1388 MS_EXCEPTION_IF_NULL(fg);
1389 constexpr auto internal_list_input = "__internal_list_input__";
1390 std::stringstream script_buffer;
1391 script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_reverse(" << internal_list_input
1392 << ")";
1393 const std::string &script = script_buffer.str();
1394 const auto script_str = std::make_shared<StringImm>(script);
1395 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1396 (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1397 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1398
1399 const auto &node_inputs = node->inputs();
1400 constexpr size_t min_node_inputs_size = 2;
1401 constexpr size_t max_node_inputs_size = 3;
1402 size_t inputs_size = node_inputs.size();
1403 if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1404 MS_LOG(EXCEPTION) << "The size of input to ListInplaceAppend should be " << min_node_inputs_size << " or "
1405 << max_node_inputs_size << " but got " << inputs_size;
1406 }
1407 constexpr size_t node_list_index = 1;
1408 auto list_input_node = node_inputs[node_list_index];
1409 if (IsPrimitiveCNode(list_input_node, prim::kPrimMakeList)) {
1410 TraceGuard trace_guard(std::make_shared<TraceCopy>(list_input_node->debug_info()));
1411 auto new_node = ConvertMakeList(list_input_node->cast<CNodePtr>());
1412 (void)manager_->Replace(list_input_node, new_node);
1413 list_input_node = new_node;
1414 }
1415 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1416 (void)key_value_list.emplace_back(list_input_node);
1417 const auto key_value_tuple = fg->NewCNode(key_value_list);
1418 auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1419
1420 if (inputs_size == max_node_inputs_size) {
1421 res->add_input(node_inputs[max_node_inputs_size - 1]);
1422 }
1423 res->set_debug_info(node->debug_info());
1424
1425 MS_LOG(DEBUG) << "Convert list inplace reverse node to PyExecute node: " << res->DebugString();
1426 return res;
1427 }
1428
1429 // x.clear() --> PyExecute(_jit_fallback_list_inplace_clear(x))
ConvertListInplaceClear(const CNodePtr & node) const1430 AnfNodePtr ConvertListInplaceClear(const CNodePtr &node) const {
1431 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1432 if (!allow_fallback_runtime) {
1433 return nullptr;
1434 }
1435 static const auto allow_inplace_ops = common::GetCompileConfig("FALLBACK_SUPPORT_LIST_DICT_INPLACE") == "1";
1436 if (!allow_inplace_ops) {
1437 return nullptr;
1438 }
1439
1440 const auto &fg = node->func_graph();
1441 MS_EXCEPTION_IF_NULL(fg);
1442 constexpr auto internal_list_input = "__internal_list_input__";
1443 std::stringstream script_buffer;
1444 script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_clear(" << internal_list_input
1445 << ")";
1446 const std::string &script = script_buffer.str();
1447 const auto script_str = std::make_shared<StringImm>(script);
1448 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1449 (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1450 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1451
1452 const auto &node_inputs = node->inputs();
1453 constexpr size_t node_inputs_size = 2;
1454 if (node_inputs.size() != node_inputs_size) {
1455 MS_LOG(EXCEPTION) << "The size of input to ListInplaceClear should be " << node_inputs_size << " but got "
1456 << node_inputs.size();
1457 }
1458 constexpr size_t node_list_index = 1;
1459 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1460 (void)key_value_list.emplace_back(node_inputs[node_list_index]);
1461 const auto key_value_tuple = fg->NewCNode(key_value_list);
1462
1463 auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1464 res->set_debug_info(node->debug_info());
1465
1466 MS_LOG(DEBUG) << "Convert list inplace clear node to PyExecute node: " << res->DebugString();
1467 return res;
1468 }
1469
1470 // data[key] = target --> PyExecute(_jit_fallback_dict_inplace_setitem(data, key, target))
ConvertListInplaceInsert(const CNodePtr & node) const1471 AnfNodePtr ConvertListInplaceInsert(const CNodePtr &node) const {
1472 if (!fallback::EnableFallbackListDictInplace()) {
1473 return nullptr;
1474 }
1475
1476 const auto &fg = node->func_graph();
1477 MS_EXCEPTION_IF_NULL(fg);
1478 constexpr auto internal_list_input = "__internal_list_input__";
1479 constexpr auto internal_index_input = "__internal_index_input__";
1480 constexpr auto internal_target_input = "__internal_target_input__";
1481 std::stringstream script_buffer;
1482 script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_insert(" << internal_list_input
1483 << ", " << internal_index_input << ", " << internal_target_input << ")";
1484 const std::string &script = script_buffer.str();
1485 const auto script_str = std::make_shared<StringImm>(script);
1486 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1487 (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1488 (void)key_value_names_list.emplace_back(NewValueNode(internal_index_input));
1489 (void)key_value_names_list.emplace_back(NewValueNode(internal_target_input));
1490 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1491
1492 const auto &node_inputs = node->inputs();
1493 constexpr size_t min_node_inputs_size = 4;
1494 constexpr size_t max_node_inputs_size = 5;
1495 size_t inputs_size = node_inputs.size();
1496 if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1497 MS_LOG(EXCEPTION) << "The size of input to ListInplaceInsert should be " << min_node_inputs_size << " or "
1498 << max_node_inputs_size << " but got " << inputs_size;
1499 }
1500 constexpr size_t node_list_index = 1;
1501 constexpr size_t node_index_index = 2;
1502 constexpr size_t node_target_index = 3;
1503 auto list_input_node = node_inputs[node_list_index];
1504 if (IsPrimitiveCNode(list_input_node, prim::kPrimMakeList)) {
1505 TraceGuard trace_guard(std::make_shared<TraceCopy>(list_input_node->debug_info()));
1506 auto new_node = ConvertMakeList(list_input_node->cast<CNodePtr>());
1507 (void)manager_->Replace(list_input_node, new_node);
1508 list_input_node = new_node;
1509 }
1510 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1511 (void)key_value_list.emplace_back(list_input_node);
1512 (void)key_value_list.emplace_back(node_inputs[node_index_index]);
1513 (void)key_value_list.emplace_back(node_inputs[node_target_index]);
1514 const auto key_value_tuple = fg->NewCNode(key_value_list);
1515
1516 auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1517
1518 if (inputs_size == max_node_inputs_size) {
1519 res->add_input(node_inputs[max_node_inputs_size - 1]);
1520 }
1521 res->set_debug_info(node->debug_info());
1522
1523 MS_LOG(DEBUG) << "Convert list inplace insert node to PyExecute node: " << res->DebugString();
1524 return res;
1525 }
1526
1527 // TupleGetItem/ListGetItem(sequence, index) -> PyExecute(sequence[index], ...)
ConvertSequenceGetItem(const CNodePtr & node) const1528 AnfNodePtr ConvertSequenceGetItem(const CNodePtr &node) const { return ConvertSequenceGetItemInner(node); }
1529
// raise(string, keys, values, io) --> PyExecute(string, keys, values, io)
// Rewrites a Raise node into a PyExecute node whose script is
//   __import__('mindspore').common._utils._jit_fallback_raise_func(<exception type>,<message>)
// Returns nullptr (after a warning) when the JIT syntax level is below kCompatible.
// Input layout as used below:
//   inputs[0]                        : the Raise primitive
//   inputs[1]                        : the exception type (its abstract goes to raiseutils::GetExceptionType)
//   inputs[2 .. inputs.size() - 3]   : the message arguments (none when inputs.size() == 4)
//   inputs[inputs.size() - 2]        : also passed to GetExceptionType (presumably key/value info -- confirm)
//   inputs[inputs.size() - 1]        : appended as the last PyExecute input (presumably the io monad -- confirm)
// The type of the original node's abstract is recorded on the result via fallback::SetRealType.
AnfNodePtr ConvertRaise(const CNodePtr &cnode) const {
  const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
  if (!allow_fallback_runtime) {
    MS_LOG(WARNING) << "When using the raise statement, it is best to set jit_syntax_level to LAX, "
                    << "because there is no the real raise operator.\n";
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(cnode);
  const auto &fg = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  MS_LOG(DEBUG) << "Raise node: " << cnode->DebugString();
  const auto &inputs = cnode->inputs();
  // keys/values collect the names and nodes referenced by the generated script. Both start with a MakeTuple
  // primitive so they can be turned into MakeTuple CNodes directly at the end.
  std::shared_ptr<raiseutils::KeyValueInfo> key_value = std::make_shared<raiseutils::KeyValueInfo>();
  key_value->keys = {NewValueNode(prim::kPrimMakeTuple)};
  key_value->values = {NewValueNode(prim::kPrimMakeTuple)};
  // Message arguments occupy [index_begin, index_end); the last two inputs are not message arguments.
  size_t index_begin = 2;
  constexpr auto end_num = 2;
  size_t index_end = inputs.size() - end_num;
  // Input count of a raise without message arguments, e.g. `raise ValueError()`.
  size_t size_if_empty = 4;
  std::string exception_type = raiseutils::GetExceptionType(inputs[1]->abstract(), inputs[index_end], key_value);
  std::string exception_string;
  // Process raise ValueError(): bind an empty string under a generated key and use the key as the message.
  if (inputs.size() == size_if_empty) {
    std::string key = raiseutils::MakeRaiseKey(key_value->num_str);
    (void)key_value->keys.emplace_back(NewValueNode(std::make_shared<StringImm>(key)));
    (void)key_value->values.emplace_back(NewValueNode(std::make_shared<StringImm>("")));
    exception_string = key;
  }
  // Processed in units of nodes. Raise ValueError(xxxx)
  // Each argument's text (from raiseutils::GetExceptionString) is appended, wrapped in single quotes when
  // raiseutils::CheckNeedSymbol asks for it.
  for (size_t index = index_begin; index < index_end; ++index) {
    const auto input = inputs[index];
    auto input_abs = input->abstract();
    MS_EXCEPTION_IF_NULL(input_abs);
    const bool need_symbol = raiseutils::CheckNeedSymbol(input_abs);
    if (need_symbol) {
      exception_string += "'";
    }
    bool need_comma = !IsPrimitiveCNode(input, prim::kPrimMakeTuple);
    exception_string += raiseutils::GetExceptionString(input_abs, input, key_value, need_symbol, need_comma);
    if (need_symbol) {
      exception_string += "'";
    }
    // NOTE(review): index < index_end == inputs.size() - 2, so this condition is always true and ", " is
    // appended after every argument, including the last one. Confirm whether that is intended.
    if (index != inputs.size() - 1) {
      exception_string += ", ";
    }
  }
  // More than one message argument (inputs.size() > 5): wrap the joined arguments in parentheses.
  bool need_out_symbol = inputs.size() > 5;
  if (need_out_symbol) {
    exception_string = "(" + exception_string + ")";
  }
  // If nothing was registered in keys (only the MakeTuple head is present), bind the whole message as a
  // string value under a generated key and reference that key from the script instead.
  if (key_value->keys.size() <= 1) {
    std::string key = raiseutils::MakeRaiseKey(key_value->num_str);
    (void)key_value->keys.emplace_back(NewValueNode(std::make_shared<StringImm>(key)));
    (void)key_value->values.emplace_back(NewValueNode(std::make_shared<StringImm>(exception_string)));
    exception_string = key;
  }
  // Build PyExecute node for raise
  const std::string error_msg =
    "__import__('mindspore').common._utils._jit_fallback_raise_func(" + exception_type + "," + exception_string + ")";
  const auto script_str = std::make_shared<StringImm>(error_msg);
  // Pack local parameter keys
  const auto key_value_name_tuple = fg->NewCNodeInOrder(key_value->keys);
  // Pack local parameter values
  const auto key_value_tuple = fg->NewCNodeInOrder(key_value->values);
  // Build the PyExecute node for raise error.
  const auto raise_pyexecute_node = fallback::CreatePyExecuteCNodeInOrder(
    fg, NewValueNode(script_str), key_value_name_tuple, key_value_tuple, cnode->debug_info());
  // Keep the original last input attached to the new node.
  raise_pyexecute_node->add_input(inputs[inputs.size() - 1]);
  // Record the type of the original Raise node's abstract on the PyExecute node.
  auto old_abs = cnode->abstract();
  MS_EXCEPTION_IF_NULL(old_abs);
  const auto &type = old_abs->BuildType();
  MS_EXCEPTION_IF_NULL(type);
  fallback::SetRealType(raise_pyexecute_node, type);
  MS_LOG(DEBUG) << "Raise convert to PyExecute node: " << raise_pyexecute_node->DebugString();
  return raise_pyexecute_node;
}
1608
1609 // ScalarCast(x, dtype) --> PyExecute(string, keys, values)
ConvertScalarCast(const CNodePtr & cnode) const1610 AnfNodePtr ConvertScalarCast(const CNodePtr &cnode) const {
1611 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1612 if (!allow_fallback_runtime) {
1613 MS_LOG(WARNING) << "When using the ScalarCast statement with some syntaxes that is not supported in graph mode, "
1614 << "it is best to set jit_syntax_level to LAX.\n";
1615 return nullptr;
1616 }
1617 constexpr size_t x_index = 1;
1618 constexpr size_t dtype_index = 2;
1619 auto x_node = cnode->input(x_index);
1620 auto dtype_node = cnode->input(dtype_index);
1621 auto x_abs = GetAbstract<abstract::AbstractAny>(x_node);
1622 if (x_abs == nullptr) {
1623 return nullptr;
1624 }
1625 auto dtype_abs = GetAbstract<abstract::AbstractScalar>(dtype_node);
1626 MS_EXCEPTION_IF_NULL(dtype_abs);
1627 auto dtype_val = dtype_abs->GetValue();
1628 MS_EXCEPTION_IF_NULL(dtype_val);
1629 auto type_id_opt = ops::GetScalarValue<int64_t>(dtype_val);
1630 if (!type_id_opt.has_value()) {
1631 MS_LOG(EXCEPTION) << "the dtype input is invalid!";
1632 }
1633 std::string target_type_str;
1634 auto type_id = type_id_opt.value();
1635 if (type_id == kNumberTypeInt) {
1636 target_type_str = "int";
1637 } else if (type_id == kNumberTypeFloat) {
1638 target_type_str = "float";
1639 } else if (type_id == kNumberTypeBool) {
1640 target_type_str = "bool";
1641 } else {
1642 MS_LOG(EXCEPTION) << "Unsupported type: " << type_id;
1643 }
1644
1645 const auto &fg = cnode->func_graph();
1646 MS_EXCEPTION_IF_NULL(fg);
1647 std::string internal_scalar_arg_str = "__internal_scalar_arg__";
1648 std::string script = target_type_str + "(" + internal_scalar_arg_str + ")";
1649 auto script_node = NewValueNode(std::make_shared<StringImm>(script));
1650 auto arg_name_node = NewValueNode(std::make_shared<StringImm>(internal_scalar_arg_str));
1651 auto keys_tuple_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), arg_name_node});
1652 auto values_tuple_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), x_node});
1653 keys_tuple_node->set_debug_info(cnode->debug_info());
1654 values_tuple_node->set_debug_info(cnode->debug_info());
1655 auto scalar_cast_node =
1656 fallback::CreatePyExecuteCNodeInOrder(cnode, script_node, keys_tuple_node, values_tuple_node);
1657 MS_LOG(DEBUG) << "Convert CastToScalar: " << cnode->DebugString() << " -> " << scalar_cast_node->DebugString();
1658 return scalar_cast_node;
1659 }
1660
ConvertMakeSlice(const CNodePtr & cnode) const1661 AnfNodePtr ConvertMakeSlice(const CNodePtr &cnode) const {
1662 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1663 if (!allow_fallback_runtime) {
1664 MS_LOG(WARNING) << "When using the MakeSlice statement with some syntaxes that is not supported in graph mode, "
1665 << "it is best to set jit_syntax_level to LAX.\n";
1666 return nullptr;
1667 }
1668 MS_EXCEPTION_IF_NULL(cnode);
1669 const auto &fg = cnode->func_graph();
1670 MS_EXCEPTION_IF_NULL(fg);
1671 MS_LOG(DEBUG) << " make_slice node: " << cnode->DebugString();
1672 constexpr size_t slice_size = 4;
1673 if (cnode->size() != slice_size) {
1674 MS_LOG(INTERNAL_EXCEPTION) << "The size of input to make_slice should be " << slice_size << ", but got "
1675 << cnode->size();
1676 }
1677 constexpr size_t start_index = 1;
1678 constexpr size_t stop_index = 2;
1679 constexpr size_t step_index = 3;
1680 bool is_start_none = IsValueNode<None>(cnode->input(start_index));
1681 bool is_stop_none = IsValueNode<None>(cnode->input(stop_index));
1682 bool is_step_none = IsValueNode<None>(cnode->input(step_index));
1683 auto start_str = is_start_none ? "None" : "__start__";
1684 auto stop_str = is_stop_none ? "None" : "__stop__";
1685 auto step_str = is_step_none ? "None" : "__step__";
1686 // Script
1687 std::stringstream script_buffer;
1688 script_buffer << "slice(" << start_str << ", " << stop_str << ", " << step_str << ")";
1689 const std::string &script = script_buffer.str();
1690 const auto script_str = std::make_shared<StringImm>(script);
1691
1692 // Pack local parameters keys and values.
1693 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1694 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1695 if (!is_start_none) {
1696 (void)key_value_names_list.emplace_back(NewValueNode(start_str));
1697 (void)key_value_list.emplace_back(cnode->input(start_index));
1698 }
1699 if (!is_stop_none) {
1700 (void)key_value_names_list.emplace_back(NewValueNode(stop_str));
1701 (void)key_value_list.emplace_back(cnode->input(stop_index));
1702 }
1703 if (!is_step_none) {
1704 (void)key_value_names_list.emplace_back(NewValueNode(step_str));
1705 (void)key_value_list.emplace_back(cnode->input(step_index));
1706 }
1707 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1708 const auto key_value_tuple = fg->NewCNode(key_value_list);
1709
1710 // Build the new slice node.
1711 const auto slice_node =
1712 fallback::CreatePyExecuteCNodeInOrder(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1713 MS_LOG(DEBUG) << "Made slice node: " << slice_node->DebugString();
1714 return slice_node;
1715 }
1716
1717 // Only process the node that have a PyExecute node(the abstract is AbstractAny).
CheckInputsHasAnyType(const CNodePtr & cnode) const1718 bool CheckInputsHasAnyType(const CNodePtr &cnode) const {
1719 bool exist_any_type = false;
1720 for (const auto &weak_input : cnode->weak_inputs()) {
1721 auto input = weak_input.lock();
1722 auto input_abs = input->abstract();
1723 if (fallback::ContainsSequenceAnyType(input_abs)) {
1724 exist_any_type = true;
1725 break;
1726 }
1727 }
1728 return exist_any_type;
1729 }
1730
ConvertIsInstance(const CNodePtr & cnode) const1731 AnfNodePtr ConvertIsInstance(const CNodePtr &cnode) const {
1732 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1733 if (!allow_fallback_runtime) {
1734 MS_LOG(WARNING) << "When using the isinstance statement, it is best to set jit_syntax_level to LAX, "
1735 << "because there is no the real isinstance operator.\n";
1736 return nullptr;
1737 }
1738 const auto &fg = cnode->func_graph();
1739 MS_EXCEPTION_IF_NULL(fg);
1740 if (!CheckInputsHasAnyType(cnode)) {
1741 return nullptr;
1742 }
1743 const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1744 MS_EXCEPTION_IF_NULL(prim);
1745 string name = prim->name();
1746 auto pyexecute_node = fallback::ConvertCNodeToPyExecuteForPrim(cnode, name);
1747 MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << pyexecute_node->DebugString();
1748 return pyexecute_node;
1749 }
1750
1751 // JoinedStr(XXXXXX)
1752 // TO
1753 // A = PyExecute("list(map(str, __inner_convert_object__), ("__inner_convert_object__",), ((XXXXXX,),)")
1754 // B = PyExecute("".join(__inner_str_list__)", ("__inner_str_list__",), (A,)).
1755 // replace(B --> JoinedStr)
ConvertJoinedStr(const CNodePtr & cnode) const1756 AnfNodePtr ConvertJoinedStr(const CNodePtr &cnode) const {
1757 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1758 if (!allow_fallback_runtime) {
1759 MS_LOG(WARNING) << "When using the JoinedStr statement, it is best to set jit_syntax_level to LAX, "
1760 << "because there is no the real JoinedStr operator.\n";
1761 return nullptr;
1762 }
1763 MS_EXCEPTION_IF_NULL(cnode);
1764 const auto &fg = cnode->func_graph();
1765 MS_EXCEPTION_IF_NULL(fg);
1766 MS_LOG(DEBUG) << " make_slice node: " << cnode->DebugString();
1767 // Convert all node to list[str]
1768 constexpr auto kConvertToListString = "list(map(str, __inner_convert_object__))";
1769 constexpr auto kConvertToListKey = "__inner_convert_object__";
1770 const auto make_tuple_value_node = NewValueNode(prim::kPrimMakeTuple);
1771 AnfNodeWeakPtrList list_str_value_list = {make_tuple_value_node};
1772 (void)std::copy(cnode->weak_inputs().cbegin() + 1, cnode->weak_inputs().cend(),
1773 std::back_inserter(list_str_value_list));
1774
1775 const auto make_tuple_key_node = NewValueNode(prim::kPrimMakeTuple);
1776 const auto key_node = NewValueNode(kConvertToListKey);
1777 AnfNodeWeakPtrList list_str_key_list = {make_tuple_key_node, key_node};
1778 auto list_str_key_node = fg->NewCNodeWeak(list_str_key_list);
1779 auto list_str_value_node = fg->NewCNodeWeak(list_str_value_list);
1780 auto convet_list_str_node = fallback::CreatePyExecuteCNodeInOrder(
1781 fg, NewValueNode(kConvertToListString), list_str_key_node,
1782 fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), list_str_value_node}), cnode->debug_info());
1783
1784 // change to string.
1785 constexpr auto eval_string_script = "\"\".join(__inner_str_list__)";
1786 constexpr auto eval_key_string = "__inner_str_list__";
1787 auto eval_key_node = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(eval_key_string)});
1788 auto eval_value_node = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), convet_list_str_node});
1789
1790 auto joined_result_node = fallback::CreatePyExecuteCNode(fg, NewValueNode(eval_string_script), eval_key_node,
1791 eval_value_node, cnode->debug_info());
1792 return joined_result_node;
1793 }
1794
HasPyExecuteInput(const CNodePtr & cnode) const1795 bool HasPyExecuteInput(const CNodePtr &cnode) const {
1796 MS_EXCEPTION_IF_NULL(cnode);
1797 const auto &inputs = cnode->inputs();
1798 for (auto &input : inputs) {
1799 if (IsPrimitiveCNode(input, prim::kPrimPyExecute)) {
1800 return true;
1801 }
1802 }
1803 return false;
1804 }
1805
ConvertPrint(const CNodePtr & cnode) const1806 AnfNodePtr ConvertPrint(const CNodePtr &cnode) const {
1807 const auto &fg = cnode->func_graph();
1808 MS_EXCEPTION_IF_NULL(fg);
1809 if (!CheckInputsHasAnyType(cnode) && !HasPyExecuteInput(cnode)) {
1810 return nullptr;
1811 }
1812 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1813 if (!allow_fallback_runtime) {
1814 MS_LOG(WARNING) << "When using the print statement with some syntaxes that is not supported in graph mode, "
1815 << "it is best to set jit_syntax_level to LAX.\n";
1816 return nullptr;
1817 }
1818 // Skip the io_monad input
1819 auto inputs = cnode->inputs();
1820 if (!HasAbstractMonad(inputs.back())) {
1821 MS_LOG(EXCEPTION) << "The print node has no monad input:" << cnode->DebugString();
1822 }
1823 inputs.pop_back();
1824 auto no_io_print = fg->NewCNode(inputs);
1825 auto pyexecute_node = fallback::ConvertCNodeToPyExecuteForPrim(no_io_print, "print");
1826
1827 // Add io_monad input
1828 auto new_pyexecute_inputs = pyexecute_node->cast<CNodePtr>()->inputs();
1829 (void)new_pyexecute_inputs.emplace_back(cnode->inputs().back());
1830 auto new_pyexecute_node = fg->NewCNode(new_pyexecute_inputs);
1831 MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << new_pyexecute_node->DebugString();
1832 return new_pyexecute_node;
1833 }
1834 // Format(str, XXXX) Convert to PyExecute
1835 // First Spilt XXXX to dict input when the args is KWargs, otherwise push it to a list.And Then Convert To PyExecute
1836 // A = MakeDict(XXXX[KWargs]->keys(), XXXX[KWargs]->values()) --> This Dict will convert to PyExecute use function
1837 // ConvertMakeDict. B = Tuple(XXXX - XXXX[KWargs]) ps: this sub operator is set sub. C =
1838 // PyExecute("__inner_str__.format(*__format_list_str__, **__format_kwargs__str__)"
1839 // , (__inner_str__, __format_list_str__, __format_kwargs__str__), (str, B, A));
1840 // Replace(C -> Format).
ConvertFormat(const CNodePtr & cnode) const1841 AnfNodePtr ConvertFormat(const CNodePtr &cnode) const {
1842 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1843 if (!allow_fallback_runtime) {
1844 MS_LOG(WARNING) << "When using the format statement with some syntaxes that is not supported in graph mode, "
1845 << "it is best to set jit_syntax_level to LAX.\n";
1846 return nullptr;
1847 }
1848 auto fg = cnode->func_graph();
1849 MS_EXCEPTION_IF_NULL(fg);
1850
1851 std::vector<AnfNodePtr> format_list = {NewValueNode(prim::kPrimMakeTuple)};
1852
1853 std::vector<AnfNodePtr> kwargs_keys_node = {NewValueNode(prim::kPrimMakeTuple)};
1854 std::vector<AnfNodePtr> kwargs_values_node = {NewValueNode(prim::kPrimMakeTuple)};
1855 auto inputs = cnode->inputs();
1856 constexpr auto kFormatArgsIndex = 2;
1857 constexpr auto kStringArgsIndex = 1;
1858 for (size_t i = kFormatArgsIndex; i < inputs.size(); ++i) {
1859 auto input = inputs[i];
1860 MS_EXCEPTION_IF_NULL(input);
1861 auto abs = input->abstract();
1862 if (abs != nullptr && abs->isa<abstract::AbstractKeywordArg>()) {
1863 auto [key, arg] = ExtractKwargsNode(input);
1864 (void)kwargs_keys_node.emplace_back(key);
1865 (void)kwargs_values_node.emplace_back(arg);
1866 } else {
1867 format_list.emplace_back(inputs[i]);
1868 }
1869 }
1870 // Construct kwargs node
1871 auto dict_key_node = fg->NewCNode(kwargs_keys_node);
1872 dict_key_node->set_debug_info(cnode->debug_info());
1873 auto dict_value_node = fg->NewCNode(kwargs_values_node);
1874 dict_value_node->set_debug_info(cnode->debug_info());
1875 auto dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), dict_key_node, dict_value_node});
1876 dict_node->set_debug_info(cnode->debug_info());
1877 auto py_exec_dict_node = ConvertMakeDict(dict_node);
1878 // Construct list args node
1879 auto list_node = fg->NewCNode(format_list);
1880 list_node->set_debug_info(cnode->debug_info());
1881 // Construct PyExecute node
1882 constexpr auto inner_str = "__inner_str__";
1883 constexpr auto format_list_str = "__format_list_str__";
1884 constexpr auto format_kwargs_str = "__format_kwargs__str__";
1885 std::stringstream script_buffer;
1886 script_buffer << inner_str << ".format(*" << format_list_str << ", **" << format_kwargs_str << ")";
1887
1888 std::vector<ValuePtr> key_values = {MakeValue(inner_str), MakeValue(format_list_str), MakeValue(format_kwargs_str)};
1889 auto intrepret_node_keys = NewValueNode(std::make_shared<ValueTuple>(key_values));
1890 auto intrepert_node_values =
1891 fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), inputs.at(kStringArgsIndex), list_node, py_exec_dict_node});
1892 intrepert_node_values->set_debug_info(cnode->debug_info());
1893 auto convert_node = fallback::CreatePyExecuteCNode(fg, NewValueNode(MakeValue(script_buffer.str())),
1894 intrepret_node_keys, intrepert_node_values, cnode->debug_info());
1895 return convert_node;
1896 }
1897
ConvertMakeRange(const CNodePtr & cnode) const1898 AnfNodePtr ConvertMakeRange(const CNodePtr &cnode) const {
1899 const auto &fg = cnode->func_graph();
1900 MS_EXCEPTION_IF_NULL(fg);
1901 if (!CheckInputsHasAnyType(cnode) && !HasPyExecuteInput(cnode)) {
1902 return nullptr;
1903 }
1904 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1905 if (!allow_fallback_runtime) {
1906 MS_LOG(WARNING) << "When using the range statement with some syntaxes that is not supported in graph mode, "
1907 << "it is best to set jit_syntax_level to LAX.\n";
1908 return nullptr;
1909 }
1910 auto pyexecute_node = fallback::ConvertCNodeToPyExecuteForPrim(cnode, "range");
1911 MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << pyexecute_node->DebugString();
1912 return pyexecute_node;
1913 }
1914
ConvertIsAndIsNot(const CNodePtr & cnode,bool is) const1915 AnfNodePtr ConvertIsAndIsNot(const CNodePtr &cnode, bool is) const {
1916 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1917 if (!allow_fallback_runtime) {
1918 MS_LOG(WARNING) << "When using the is/is_not statement with some syntaxes that is not supported in graph mode, "
1919 << "it is best to set jit_syntax_level to LAX.\n";
1920 return nullptr;
1921 }
1922 const auto &fg = cnode->func_graph();
1923 MS_EXCEPTION_IF_NULL(fg);
1924
1925 constexpr auto data_str = "__data__";
1926 constexpr auto target_str = "__target__";
1927 std::stringstream script_buffer;
1928 script_buffer << data_str;
1929 if (is) {
1930 script_buffer << " is ";
1931 } else {
1932 script_buffer << " is not ";
1933 }
1934 script_buffer << target_str;
1935 const std::string &script = script_buffer.str();
1936 const auto script_str = std::make_shared<StringImm>(script);
1937 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1938 (void)key_value_names_list.emplace_back(NewValueNode(data_str));
1939 (void)key_value_names_list.emplace_back(NewValueNode(target_str));
1940 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1941
1942 const auto &node_inputs = cnode->inputs();
1943 constexpr size_t inputs_size = 3;
1944 if (node_inputs.size() != inputs_size) {
1945 MS_LOG(INTERNAL_EXCEPTION) << "The size of input to kPrimIs should be " << inputs_size << "but got "
1946 << node_inputs.size();
1947 }
1948 constexpr size_t data_index = 1;
1949 constexpr size_t target_index = 2;
1950 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1951 (void)key_value_list.emplace_back(node_inputs[data_index]);
1952 (void)key_value_list.emplace_back(node_inputs[target_index]);
1953 const auto key_value_tuple = fg->NewCNode(key_value_list);
1954
1955 auto res = fallback::CreatePyExecuteCNode(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1956 res->set_debug_info(cnode->debug_info());
1957 return res;
1958 }
1959
ConvertIs_(const CNodePtr & cnode) const1960 AnfNodePtr ConvertIs_(const CNodePtr &cnode) const {
1961 auto res = ConvertIsAndIsNot(cnode, true);
1962 MS_LOG(DEBUG) << "Convert primitive Is_ to PyExecute node: " << res->DebugString();
1963 return res;
1964 }
1965
ConvertIsNot(const CNodePtr & cnode) const1966 AnfNodePtr ConvertIsNot(const CNodePtr &cnode) const {
1967 auto res = ConvertIsAndIsNot(cnode, false);
1968 MS_LOG(DEBUG) << "Convert primitive IsNot to PyExecute node: " << res->DebugString();
1969 return res;
1970 }
1971
1972 using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &) const;
1973 using ConverterMap = std::unordered_map<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
1974 static inline const ConverterMap converters_{
1975 // SparseProcess: 1.MakeSparse->MakeTuple 2.SparseGetAttr->TupleGetItem
1976 {prim::kPrimMakeRowTensor, &ThisClass::ConvertMakeSparseToMakeTuple},
1977 {prim::kPrimRowTensorGetIndices, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1978 {prim::kPrimRowTensorGetValues, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1979 {prim::kPrimRowTensorGetDenseShape, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1980 {prim::kPrimMakeCSRTensor, &ThisClass::ConvertMakeSparseToMakeTuple},
1981 {prim::kPrimCSRTensorGetIndptr, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1982 {prim::kPrimCSRTensorGetIndices, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1983 {prim::kPrimCSRTensorGetValues, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1984 {prim::kPrimCSRTensorGetDenseShape, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1985 {prim::kPrimMakeCOOTensor, &ThisClass::ConvertMakeSparseToMakeTuple},
1986 {prim::kPrimCOOTensorGetIndices, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1987 {prim::kPrimCOOTensorGetValues, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1988 {prim::kPrimCOOTensorGetDenseShape, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
1989 {prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItem},
1990 {prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItem},
1991 {prim::kPrimListInplaceExtend, &ThisClass::ConvertListInplaceExtend},
1992 {prim::kPrimListInplaceInsert, &ThisClass::ConvertListInplaceInsert},
1993 {prim::kPrimListInplacePop, &ThisClass::ConvertListInplacePop},
1994 {prim::kPrimListInplaceReverse, &ThisClass::ConvertListInplaceReverse},
1995 {prim::kPrimListInplaceClear, &ThisClass::ConvertListInplaceClear},
1996 {prim::kPrimDictInplaceSetItem, &ThisClass::ConvertDictInplaceSetItem},
1997 {prim::kPrimListGetItem, &ThisClass::ConvertSequenceGetItem},
1998 {prim::kPrimTupleGetItem, &ThisClass::ConvertSequenceGetItem},
1999 {prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
2000 {prim::kPrimRaise, &ThisClass::ConvertRaise},
2001 {prim::kPrimScalarCast, &ThisClass::ConvertScalarCast},
2002 {prim::kPrimMakeSlice, &ThisClass::ConvertMakeSlice},
2003 {prim::kPrimIsInstance, &ThisClass::ConvertIsInstance},
2004 {prim::kPrimJoinedStr, &ThisClass::ConvertJoinedStr},
2005 {prim::kPrimPrint, &ThisClass::ConvertPrint},
2006 {prim::kPrimFormat, &ThisClass::ConvertFormat},
2007 {prim::kPrimMakeRange, &ThisClass::ConvertMakeRange},
2008 {prim::kPrimIs_, &ThisClass::ConvertIs_},
2009 {prim::kPrimIsNot, &ThisClass::ConvertIsNot}};
2010
2011 static inline const PrimitiveSet seq_prim_set_{
2012 prim::kPrimInSequence, prim::kPrimSequenceMul, prim::kPrimSequenceCount, prim::kPrimSequenceIndex,
2013 prim::kPrimSequenceLen, prim::kPrimListEqual, prim::kPrimTupleEqual, prim::kPrimTupleGreaterThan,
2014 prim::kPrimListLessEqual, prim::kPrimTupleLessThan, prim::kPrimListLessThan, prim::kPrimTupleLessEqual,
2015 prim::kPrimListGreaterThan, prim::kPrimTupleGreaterEqual, prim::kPrimListGreaterEqual, prim::kPrimSequenceSlice};
2016
2017 // Convert ValueNode<None> to PyExecute("None", ("None"), ("None")).
ConvertNoneToPyExecute(const FuncGraphPtr & func_graph)2018 AnfNodePtr ConvertNoneToPyExecute(const FuncGraphPtr &func_graph) {
2019 MS_EXCEPTION_IF_NULL(func_graph);
2020 auto str_value = std::make_shared<StringImm>("None");
2021 auto script_node = NewValueNode(str_value);
2022
2023 std::vector<ValuePtr> none_value{str_value};
2024 const auto none_tuple = std::make_shared<ValueTuple>(none_value);
2025 auto none_tuple_node = NewValueNode(none_tuple);
2026 AbstractBasePtrList abs_list{std::make_shared<abstract::AbstractScalar>(MakeValue("None"))};
2027 none_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
2028
2029 AnfNodePtr none_execute_node = fallback::CreatePyExecuteCNodeInOrder(
2030 func_graph, script_node, none_tuple_node, none_tuple_node, none_tuple_node->debug_info());
2031 MS_LOG(DEBUG) << "none_execute_node:" << none_execute_node->DebugString();
2032
2033 set_need_renormalized(true);
2034 return none_execute_node;
2035 }
2036
GetPyExecuteFromValueSequence(const FuncGraphPtr & fg,const ValueNodePtr & value_node,const ValueSequencePtr & value_sequence,const PrimitivePtr & prim,bool py_execute_input)2037 AnfNodePtr GetPyExecuteFromValueSequence(const FuncGraphPtr &fg, const ValueNodePtr &value_node,
2038 const ValueSequencePtr &value_sequence, const PrimitivePtr &prim,
2039 bool py_execute_input) {
2040 std::vector<AnfNodePtr> new_inputs;
2041 new_inputs.reserve(value_sequence->size());
2042 (void)new_inputs.emplace_back(NewValueNode(prim));
2043 bool changed = false;
2044 auto abs = value_node->abstract();
2045 if (abs == nullptr) {
2046 for (const auto &v : value_sequence->value()) {
2047 auto v_node = NewValueNode(v);
2048 v_node->set_debug_info(value_node->debug_info());
2049 auto new_node = GetPyExecuteFromValue(fg, v_node, v, py_execute_input);
2050 new_node->set_debug_info(value_node->debug_info());
2051 (void)new_inputs.emplace_back(new_node);
2052 if (new_node != v_node) {
2053 changed = true;
2054 }
2055 }
2056 } else {
2057 auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
2058 MS_EXCEPTION_IF_NULL(abs_seq);
2059 const auto &abs_seq_elements = abs_seq->elements();
2060 const auto &value_sequence_values = value_sequence->value();
2061 if (abs_seq_elements.size() != value_sequence_values.size()) {
2062 MS_LOG(EXCEPTION) << "The size of value sequence should be same as the size of abstract sequence.";
2063 }
2064 for (size_t i = 0; i < value_sequence_values.size(); ++i) {
2065 auto v = value_sequence_values[i];
2066 auto v_node = NewValueNode(v);
2067 v_node->set_debug_info(value_node->debug_info());
2068 v_node->set_abstract(abs_seq_elements[i]);
2069 auto new_node = GetPyExecuteFromValue(fg, v_node, v, py_execute_input);
2070 new_node->set_debug_info(value_node->debug_info());
2071 (void)new_inputs.emplace_back(new_node);
2072 if (new_node != v_node) {
2073 changed = true;
2074 }
2075 }
2076 }
2077 if (changed) {
2078 auto ret = fg->NewCNode(new_inputs);
2079 ret->set_abstract(value_node->abstract());
2080 return ret;
2081 }
2082 return value_node;
2083 }
2084
ConvertTypeToPyExecute(const FuncGraphPtr & fg,const ValueNodePtr & node,const TypePtr & type) const2085 AnfNodePtr ConvertTypeToPyExecute(const FuncGraphPtr &fg, const ValueNodePtr &node, const TypePtr &type) const {
2086 // Support convert type to PyExecute.
2087 const auto py_type = ValueToPyData(type);
2088 MS_LOG(DEBUG) << "py_type: " << py_type;
2089 auto res = fallback::ConvertPyObjectToPyExecute(fg, py::str(py_type).cast<std::string>(), py_type, node, false);
2090 fallback::SetRealType(res, type);
2091 return res;
2092 }
2093
ConvertClassTypeToPyExecute(const FuncGraphPtr & fg,const ValueNodePtr & node,const ClassTypePtr & class_type) const2094 AnfNodePtr ConvertClassTypeToPyExecute(const FuncGraphPtr &fg, const ValueNodePtr &node,
2095 const ClassTypePtr &class_type) const {
2096 // Support convert class type to PyExecute.
2097 const auto py_type = ValueToPyData(class_type);
2098 MS_LOG(DEBUG) << "py_type: " << py_type;
2099 auto res = fallback::ConvertPyObjectToPyExecute(fg, py::str(py_type).cast<std::string>(), py_type, node, true);
2100 fallback::SetRealType(res, class_type);
2101 MS_LOG(DEBUG) << "res: " << res->DebugString();
2102 return res;
2103 }
2104
ConvertNameSpaceToPyExecute(const FuncGraphPtr & fg,const ValueNodePtr & node,const parse::NameSpacePtr & name_space) const2105 AnfNodePtr ConvertNameSpaceToPyExecute(const FuncGraphPtr &fg, const ValueNodePtr &node,
2106 const parse::NameSpacePtr &name_space) const {
2107 // Support convert namespace to PyExecute.
2108 const auto name_space_type = ValueToPyData(name_space);
2109 MS_LOG(DEBUG) << "name_space_type: " << name_space_type;
2110 auto res = fallback::ConvertPyObjectToPyExecute(fg, py::str(name_space_type).cast<std::string>(), name_space_type,
2111 node, true);
2112 fallback::SetRealType(res, name_space);
2113 MS_LOG(DEBUG) << "res: " << res->DebugString();
2114 return res;
2115 }
2116
IsValueListWithInplace(const ValueNodePtr & value_node) const2117 bool IsValueListWithInplace(const ValueNodePtr &value_node) const {
2118 if (!fallback::EnableFallbackListDictInplace()) {
2119 return false;
2120 }
2121
2122 MS_EXCEPTION_IF_NULL(value_node);
2123 auto abs = value_node->abstract();
2124 MS_EXCEPTION_IF_NULL(abs);
2125 auto list_abs = abs->cast<abstract::AbstractListPtr>();
2126 MS_EXCEPTION_IF_NULL(list_abs);
2127 if (!fallback::HasObjInExtraInfoHolder(list_abs)) {
2128 return false;
2129 }
2130 py::list list_object = fallback::GetObjFromExtraInfoHolder(list_abs);
2131 // The value list do not need to convert to PyExecute if:
2132 // 1. The list is created within graph.
2133 // 2. The list and its elements do not perform any inplace operation.
2134 if (fallback::GetCreateInGraphFromExtraInfoHolder(list_abs) && !CheckSeqWithInplace(list_object)) {
2135 return false;
2136 }
2137 return true;
2138 }
2139
ConvertValueSlice(const FuncGraphPtr & func_graph,const AnfNodePtr & slice_node,const ValueSlicePtr & value_slice)2140 AnfNodePtr ConvertValueSlice(const FuncGraphPtr &func_graph, const AnfNodePtr &slice_node,
2141 const ValueSlicePtr &value_slice) {
2142 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
2143 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
2144 bool is_start_none = value_slice->start()->isa<None>();
2145 bool is_stop_none = value_slice->stop()->isa<None>();
2146 bool is_step_none = value_slice->step()->isa<None>();
2147 auto start_str = is_start_none ? "None" : "__start__";
2148 auto stop_str = is_stop_none ? "None" : "__stop__";
2149 auto step_str = is_step_none ? "None" : "__step__";
2150 // Script
2151 std::stringstream script_buffer;
2152 script_buffer << "slice(" << start_str << ", " << stop_str << ", " << step_str << ")";
2153 const std::string &script = script_buffer.str();
2154 const auto script_str = std::make_shared<StringImm>(script);
2155
2156 // Pack local parameters keys and values.
2157 (void)key_value_names_list.emplace_back(NewValueNode(start_str));
2158 (void)key_value_names_list.emplace_back(NewValueNode(stop_str));
2159 (void)key_value_names_list.emplace_back(NewValueNode(step_str));
2160 AnfNodePtr start_node;
2161 AnfNodePtr end_node;
2162 AnfNodePtr step_node;
2163 if (!is_start_none) {
2164 start_node = func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), slice_node, NewValueNode("start")});
2165 } else {
2166 start_node = NewValueNode(start_str);
2167 }
2168 if (!is_stop_none) {
2169 end_node = func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), slice_node, NewValueNode("stop")});
2170 } else {
2171 end_node = NewValueNode(stop_str);
2172 }
2173 if (!is_step_none) {
2174 step_node = func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), slice_node, NewValueNode("step")});
2175 } else {
2176 step_node = NewValueNode(stop_str);
2177 }
2178 (void)key_value_list.emplace_back(start_node);
2179 (void)key_value_list.emplace_back(end_node);
2180 (void)key_value_list.emplace_back(step_node);
2181 const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
2182 const auto key_value_tuple = func_graph->NewCNode(key_value_list);
2183 return fallback::CreatePyExecuteCNodeInOrder(func_graph, NewValueNode(script_str), key_value_name_tuple,
2184 key_value_tuple, key_value_tuple->debug_info());
2185 }
2186
GetPyExecuteFromValue(const FuncGraphPtr & fg,const ValueNodePtr & value_node,const ValuePtr & value,bool py_execute_input)2187 AnfNodePtr GetPyExecuteFromValue(const FuncGraphPtr &fg, const ValueNodePtr &value_node, const ValuePtr &value,
2188 bool py_execute_input) {
2189 MS_EXCEPTION_IF_NULL(fg);
2190 MS_EXCEPTION_IF_NULL(value_node);
2191 MS_EXCEPTION_IF_NULL(value);
2192 if (value->isa<None>()) {
2193 constexpr auto vmap_prefix = "VmapRule";
2194 if (value_node->scope() != nullptr &&
2195 value_node->scope()->name().compare(0, strlen(vmap_prefix), vmap_prefix) == 0) {
2196 return value_node;
2197 }
2198 return ConvertNoneToPyExecute(fg);
2199 }
2200 if (fallback::GetJitSyntaxLevel() == kLax) {
2201 if (value->isa<Type>()) {
2202 return ConvertTypeToPyExecute(fg, value_node, value->cast<TypePtr>());
2203 } else if (value->isa<parse::ClassType>()) {
2204 auto class_type = GetValueNode<ClassTypePtr>(value_node);
2205 MS_EXCEPTION_IF_NULL(class_type);
2206 return ConvertClassTypeToPyExecute(fg, value_node, class_type);
2207 } else if (value->isa<parse::NameSpace>()) {
2208 auto name_space = GetValueNode<parse::NameSpacePtr>(value_node);
2209 MS_EXCEPTION_IF_NULL(name_space);
2210 return ConvertNameSpaceToPyExecute(fg, value_node, name_space);
2211 }
2212 }
2213 if (value->isa<parse::MsClassObject>()) {
2214 return fallback::ConvertMsClassObjectToPyExecute(fg, value, value_node);
2215 }
2216 if (value->isa<parse::InterpretedObject>()) {
2217 const auto interpreted_value = dyn_cast<parse::InterpretedObject>(value);
2218 const std::string &key = interpreted_value->name();
2219 return fallback::ConvertPyObjectToPyExecute(fg, key, interpreted_value->obj(), value_node, true);
2220 }
2221 if (value->isa<ValueTuple>()) {
2222 return GetPyExecuteFromValueSequence(fg, value_node, value->cast<ValueSequencePtr>(), prim::kPrimMakeTuple,
2223 py_execute_input);
2224 }
2225 if (value->isa<ValueList>()) {
2226 if (!IsValueListWithInplace(value_node) && !py_execute_input) {
2227 return GetPyExecuteFromValueSequence(fg, value_node, value->cast<ValueSequencePtr>(), prim::kPrimMakeList,
2228 py_execute_input);
2229 }
2230 return RebuildValueList(fg, value_node);
2231 }
2232 if (value->isa<ValueDictionary>()) {
2233 return RebuildValueDict(fg, value_node, value->cast<ValueDictionaryPtr>());
2234 }
2235 if (value->isa<ValueSlice>()) {
2236 return ConvertValueSlice(fg, value_node, value->cast<ValueSlicePtr>());
2237 }
2238 return value_node;
2239 }
2240
ConvertValueInputToPyExecute(const CNodePtr & cnode)2241 void ConvertValueInputToPyExecute(const CNodePtr &cnode) {
2242 MS_EXCEPTION_IF_NULL(cnode);
2243 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
2244 if (!allow_fallback_runtime) {
2245 return;
2246 }
2247 if (AnfUtils::IsRealKernel(cnode) && !IsOneOfPrimitiveCNode(cnode, inplace_prim_set) &&
2248 !IsOneOfPrimitiveCNode(cnode, seq_prim_set_)) {
2249 return;
2250 }
2251 if (IsOneOfPrimitiveCNode(cnode, seq_prim_set_)) {
2252 const auto &inputs = cnode->inputs();
2253 std::vector<AbstractBasePtr> inputs_abs;
2254 for (size_t i = 1; i < inputs.size(); ++i) {
2255 inputs_abs.push_back(inputs[i]->abstract());
2256 }
2257 auto output_abs = cnode->abstract();
2258 MS_EXCEPTION_IF_NULL(output_abs);
2259 // Only sequence ops with nested sequence input or irregular input (element with different shape/type)
2260 // or the output abstract of sequence node is AbstractAny should be converted to PyExecute node later and
2261 // their sequence input should be converted to PyExecute.
2262 if (!CheckAndConvertUtils::CheckContainNestedOrIrregularSequence(inputs_abs) &&
2263 !output_abs->isa<abstract::AbstractAny>()) {
2264 return;
2265 }
2266 }
2267 const auto &inputs = cnode->inputs();
2268 auto cur_func = cnode->func_graph();
2269 MS_EXCEPTION_IF_NULL(cur_func);
2270 for (const auto &input : inputs) {
2271 auto value_node = dyn_cast<ValueNode>(input);
2272 if (value_node == nullptr) {
2273 continue;
2274 }
2275 const auto &value = value_node->value();
2276 if (fallback::GetJitSyntaxLevel() == kLax) {
2277 // Not convert the 'type' used by Cast primitive.
2278 if (value->isa<Type>() && IsPrimitiveCNode(cnode, prim::kPrimCast)) {
2279 continue;
2280 }
2281 }
2282 auto debug_info = value_node->debug_info();
2283 auto location_info = trace::GetDebugInfoStr(debug_info);
2284 if (location_info.empty()) {
2285 value_node->set_debug_info(cnode->debug_info());
2286 }
2287 auto new_input = GetPyExecuteFromValue(cur_func, value_node, value, false);
2288 if (new_input == input) {
2289 continue;
2290 }
2291 new_input->set_debug_info(value_node->debug_info());
2292 (void)manager_->Replace(input, new_input);
2293 set_need_renormalized(true);
2294 }
2295 }
2296
ConvertSequenceOps(const CNodePtr & cnode) const2297 AnfNodePtr ConvertSequenceOps(const CNodePtr &cnode) const {
2298 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
2299 if (!allow_fallback_runtime) {
2300 return nullptr;
2301 }
2302 const auto &inputs = cnode->inputs();
2303 std::vector<AbstractBasePtr> inputs_abs;
2304 for (size_t i = 1; i < inputs.size(); ++i) {
2305 inputs_abs.push_back(inputs[i]->abstract());
2306 }
2307 auto output_abs = cnode->abstract();
2308 MS_EXCEPTION_IF_NULL(output_abs);
2309 // Only sequence ops with nested sequence input or irregular input (element with different shape/type)
2310 // or the output abstract of sequence node is AbstractAny should be converted to PyExecute node.
2311 if (!CheckAndConvertUtils::CheckContainNestedOrIrregularSequence(inputs_abs) &&
2312 !output_abs->isa<abstract::AbstractAny>()) {
2313 return nullptr;
2314 }
2315
2316 auto prim = GetValueNode<PrimitivePtr>(inputs[0]);
2317 MS_EXCEPTION_IF_NULL(prim);
2318 const auto &prim_name = prim->name();
2319
2320 const auto &fg = cnode->func_graph();
2321 MS_EXCEPTION_IF_NULL(fg);
2322 const std::string seq_ops_dir = "__import__('mindspore').ops.operations._sequence_ops.";
2323 const std::string input_prefix = "__internal_input_";
2324
2325 std::stringstream script_buffer;
2326 script_buffer << seq_ops_dir << prim_name << "()(";
2327 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
2328 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
2329 for (size_t i = 1; i < inputs.size(); ++i) {
2330 auto cur_input_str = input_prefix + std::to_string(i - 1) + "__";
2331 script_buffer << cur_input_str << ",";
2332 (void)key_value_names_list.emplace_back(NewValueNode(cur_input_str));
2333 (void)key_value_list.emplace_back(inputs[i]);
2334 }
2335 script_buffer << ")";
2336 const std::string &script = script_buffer.str();
2337 auto script_node = NewValueNode(std::make_shared<StringImm>(script));
2338 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
2339 const auto key_value_tuple = fg->NewCNode(key_value_list);
2340
2341 auto res =
2342 fallback::CreatePyExecuteCNode(fg, script_node, key_value_name_tuple, key_value_tuple, cnode->debug_info());
2343 MS_LOG(DEBUG) << "Convert sequence node: " << cnode->DebugString() << " to " << res->DebugString();
2344 return res;
2345 }
2346
ConvertPrimitiveCNode(const CNodePtr & cnode)2347 AnfNodePtr ConvertPrimitiveCNode(const CNodePtr &cnode) override {
2348 // Get primitive from cnode.
2349 const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0));
2350 if (prim == nullptr) {
2351 return nullptr;
2352 }
2353 ConvertValueInputToPyExecute(cnode);
2354
2355 // Find cnode converter by primitive.
2356 auto iter = converters_.find(prim);
2357 if (iter != converters_.end()) {
2358 // Call converter.
2359 return (this->*(iter->second))(cnode);
2360 }
2361 if (seq_prim_set_.find(prim) != seq_prim_set_.end()) {
2362 return ConvertSequenceOps(cnode);
2363 }
2364 return nullptr;
2365 }
2366
PackDictValue(const FuncGraphPtr & fg,const ValueNodePtr & value_node,const ValueDictionaryPtr & dict)2367 AnfNodePtr PackDictValue(const FuncGraphPtr &fg, const ValueNodePtr &value_node, const ValueDictionaryPtr &dict) {
2368 const auto &keys_values = dict->value();
2369 auto abs_dict = dyn_cast<abstract::AbstractDictionary>(value_node->abstract());
2370 const auto &abs_keys_values = abs_dict->elements();
2371 if (keys_values.size() != abs_keys_values.size()) {
2372 MS_LOG(INTERNAL_EXCEPTION) << "The size of value dict should be same as the size of abstract dict.";
2373 }
2374 std::vector<AnfNodePtr> value_list{NewValueNode(prim::kPrimMakeTuple)};
2375 for (size_t i = 0; i < keys_values.size(); ++i) {
2376 auto key_value = keys_values[i];
2377 auto new_vnode = NewValueNode(key_value.second);
2378 new_vnode->set_debug_info(value_node->debug_info());
2379 new_vnode->set_abstract(abs_keys_values[i].second);
2380 auto iter_value = GetPyExecuteFromValue(fg, new_vnode, key_value.second, true);
2381 iter_value->set_debug_info(value_node->debug_info());
2382 (void)value_list.emplace_back(iter_value);
2383 }
2384 auto value_tuple_node = fg->NewCNode(value_list);
2385 return value_tuple_node;
2386 }
2387
2388 // If the value dict has attached object:
2389 // dict(k0:v0, k1:v1, ...) --> PyExecute('get_local_variable(dict_key)', ...)
2390 // otherwise:
2391 // dict(k0:v0, k1:v1, ...) --> PyExecute('dict(zip(keys, values))', ...)
RebuildValueDict(const FuncGraphPtr & fg,const ValueNodePtr & value_node,const ValueDictionaryPtr & dict)2392 AnfNodePtr RebuildValueDict(const FuncGraphPtr &fg, const ValueNodePtr &value_node, const ValueDictionaryPtr &dict) {
2393 if (not_convert_jit_) {
2394 return value_node;
2395 }
2396 auto abs = value_node->abstract();
2397 MS_EXCEPTION_IF_NULL(abs);
2398 auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
2399 MS_EXCEPTION_IF_NULL(abs_dict);
2400 if (fallback::HasObjInExtraInfoHolder(abs_dict) && !fallback::GetCreateInGraphFromExtraInfoHolder(abs_dict)) {
2401 // If the abstract of value dict has python object and the python object is created outside the graph,
2402 // the we use the python object to generate pyexecute node.
2403 py::dict dict_object = fallback::GetObjFromExtraInfoHolder(abs_dict);
2404 const std::string dict_obj_str_prefix = "__dict_py_object_";
2405 auto dict_obj_id = fallback::GetPyObjectPtrStr(dict_object);
2406 MS_LOG(DEBUG) << "Current python object id: " << dict_obj_id;
2407 auto dict_obj_str = dict_obj_str_prefix + dict_obj_id + "_";
2408 auto res = fallback::ConvertPyObjectToPyExecute(fg, dict_obj_str, dict_object, value_node, false);
2409 MS_LOG(DEBUG) << "Convert value dict node: " << value_node->DebugString()
2410 << " to inplace pyexecute node: " << res->DebugString();
2411 return res;
2412 }
2413
2414 const auto &keys_values = dict->value();
2415
2416 // Local parameters values.
2417 // Pack the key tuple.
2418 std::vector<ValuePtr> key_list;
2419 key_list.reserve(keys_values.size());
2420 for (const auto &key_value : keys_values) {
2421 (void)key_list.emplace_back(key_value.first);
2422 }
2423 const auto key_tuple = std::make_shared<ValueTuple>(key_list);
2424 auto key_tuple_node = NewValueNode(key_tuple);
2425 key_tuple_node->set_debug_info(value_node->debug_info());
2426 // Pack the value tuple.
2427 auto value_tuple_node = PackDictValue(fg, value_node, dict);
2428
2429 // Generate Make Dict PyExecute Node value
2430 auto make_key_tuple_node = ConstructInternalTupleKeysNode(fg, key_tuple_node);
2431 auto make_value_tuple_node = ConstructInternalTupleValueNode(fg, value_tuple_node);
2432
2433 auto make_dict_node = ConstructNewDictNode(fg, make_key_tuple_node, make_value_tuple_node);
2434 make_dict_node->set_debug_info(value_node->debug_info());
2435 MS_LOG(DEBUG) << "Convert value dict node: " << value_node->DebugString()
2436 << " to non-inplace pyexecute node: " << make_dict_node->DebugString();
2437 return make_dict_node;
2438 }
2439
CheckSeqWithInplace(const py::sequence & seq) const2440 bool CheckSeqWithInplace(const py::sequence &seq) const {
2441 if (py::isinstance<py::list>(seq)) {
2442 const auto &seq_str = fallback::GetPyObjectPtrStr(seq);
2443 if (data_with_inplace_->find(seq_str) != data_with_inplace_->end()) {
2444 return true;
2445 }
2446 }
2447 for (const auto &obj : seq) {
2448 if (py::isinstance<py::list>(obj) && CheckSeqWithInplace(py::list(obj))) {
2449 return true;
2450 }
2451 if (py::isinstance<py::tuple>(obj) && CheckSeqWithInplace(py::tuple(obj))) {
2452 return true;
2453 }
2454 }
2455 return false;
2456 }
2457
RebuildValueList(const FuncGraphPtr & fg,const ValueNodePtr & value_node) const2458 AnfNodePtr RebuildValueList(const FuncGraphPtr &fg, const ValueNodePtr &value_node) const {
2459 MS_EXCEPTION_IF_NULL(value_node);
2460 MS_EXCEPTION_IF_NULL(fg);
2461
2462 auto value = value_node->value();
2463 MS_EXCEPTION_IF_NULL(value);
2464 auto value_list = value->cast<ValueListPtr>();
2465 MS_EXCEPTION_IF_NULL(value_list);
2466
2467 auto abs = value_node->abstract();
2468 MS_EXCEPTION_IF_NULL(abs);
2469 auto list_abs = abs->cast<abstract::AbstractListPtr>();
2470 MS_EXCEPTION_IF_NULL(list_abs);
2471
2472 if (list_abs->dynamic_len()) {
2473 return value_node;
2474 }
2475
2476 bool has_object = fallback::HasObjInExtraInfoHolder(list_abs);
2477 py::list list_object = has_object ? fallback::GetObjFromExtraInfoHolder(list_abs) : ValueToPyData(value);
2478
2479 // Generate PyExecute node: __list_object__
2480 const std::string list_obj_str_prefix = "__list_py_object_";
2481 auto list_obj_id = fallback::GetPyObjectPtrStr(list_object);
2482 MS_LOG(DEBUG) << "Current python object id: " << list_obj_id;
2483 auto list_obj_str = list_obj_str_prefix + list_obj_id + "_";
2484 auto res = fallback::ConvertPyObjectToPyExecute(fg, list_obj_str, list_object, value_node, false);
2485
2486 return res;
2487 }
2488
ConvertInterpretedObjectValue(const ValueNodePtr & node,const parse::InterpretedObjectPtr & value) const2489 AnfNodePtr ConvertInterpretedObjectValue(const ValueNodePtr &node, const parse::InterpretedObjectPtr &value) const {
2490 // Convert InterpretedObject value node to PyExecute CNode.
2491 const auto interpreted_value = dyn_cast<parse::InterpretedObject>(value);
2492 const std::string &key = interpreted_value->name();
2493 return fallback::ConvertPyObjectToPyExecute(root_graph_, key, interpreted_value->obj(), node, true);
2494 }
2495
ConvertValueNode(const ValueNodePtr & value_node,const ValuePtr & value)2496 AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override {
2497 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
2498 if (allow_fallback_runtime) {
2499 if (value->ContainsValueAny()) {
2500 return nullptr;
2501 }
2502 if (value->isa<ValueDictionary>()) {
2503 return RebuildValueDict(root_graph_, value_node, value->cast<ValueDictionaryPtr>());
2504 } else if (value->isa<parse::InterpretedObject>()) {
2505 return ConvertInterpretedObjectValue(value_node, value->cast<parse::InterpretedObjectPtr>());
2506 } else if (value->isa<parse::MsClassObject>()) {
2507 return fallback::ConvertMsClassObjectToPyExecute(root_graph_, value, value_node);
2508 }
2509 }
2510 return nullptr;
2511 }
2512
2513 // AbstractRowTensor --> AbstractTuple.
ConvertToAbstractTuple(const AbstractBasePtr & abs,size_t depth)2514 static AbstractBasePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) {
2515 if (depth > kMaxSeqRecursiveDepth) {
2516 MS_LOG(ERROR) << "abs:" << abs->ToString();
2517 MS_LOG(INTERNAL_EXCEPTION) << "List, tuple and dict nesting is not allowed more than " << kMaxSeqRecursiveDepth
2518 << " levels.";
2519 }
2520 // Convert RowTensor in AbstractSequence to AbstractTuple.
2521 auto abs_seq = abs->cast<AbstractSequencePtr>();
2522 if (abs_seq != nullptr) {
2523 // Dynamic length sequence do not convert.
2524 if (abs_seq->dynamic_len()) {
2525 return nullptr;
2526 }
2527 const auto &seq_elements = abs_seq->elements();
2528 // First we check if elements should be converted,
2529 // changed_elements maps old element to new element.
2530 mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
2531 for (const auto &element : seq_elements) {
2532 auto new_element = ConvertToAbstractTuple(element, depth + 1);
2533 if (new_element != nullptr) {
2534 (void)changed_elements.emplace(element, new_element);
2535 }
2536 }
2537 if (changed_elements.empty()) {
2538 // If no RowTensor in sequence is changed, do not convert.
2539 return nullptr;
2540 }
2541 // Make new abstract sequence.
2542 std::vector<AbstractBasePtr> elements;
2543 elements.reserve(seq_elements.size());
2544 for (const auto &element : seq_elements) {
2545 auto iter = changed_elements.find(element);
2546 if (iter != changed_elements.end()) {
2547 (void)elements.emplace_back(iter->second);
2548 } else {
2549 (void)elements.emplace_back(element);
2550 }
2551 }
2552 if (abs_seq->isa<AbstractList>()) {
2553 return std::make_shared<AbstractList>(std::move(elements));
2554 }
2555 return std::make_shared<AbstractTuple>(std::move(elements));
2556 }
2557 // AbstractRowTensor --> AbstractTuple.
2558 auto abs_row_tensor = abs->cast<std::shared_ptr<AbstractRowTensor>>();
2559 if (abs_row_tensor != nullptr) {
2560 std::vector<AbstractBasePtr> elements{abs_row_tensor->indices(), abs_row_tensor->values(),
2561 abs_row_tensor->dense_shape()};
2562 return std::make_shared<AbstractTuple>(std::move(elements));
2563 }
2564 return nullptr;
2565 }
2566
ConvertAbstract(const AbstractBasePtr & abs)2567 AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
2568 // AbstractSequence, AbstractDict, AbstractRowTensor --> AbstractTuple.
2569 return ConvertToAbstractTuple(abs, 0);
2570 }
2571
private:
// Identity strings (fallback::GetPyObjectPtrStr) of Python list objects known to be modified by inplace
// operations; looked up by CheckSeqWithInplace. Presumably populated via FindValueWithInplace — TODO confirm
// at the construction site.
StringSetPtr data_with_inplace_;
// When true, RebuildValueDict returns constant dict value nodes unchanged.
bool not_convert_jit_{false};
2575 };
2576
FindValueWithInplaceInner(const FuncGraphPtr & graph,const StringSetPtr & value_with_inplace)2577 void FindValueWithInplaceInner(const FuncGraphPtr &graph, const StringSetPtr &value_with_inplace) {
2578 MS_EXCEPTION_IF_NULL(graph);
2579 AnfNodePtr return_node = graph->get_return();
2580 MS_EXCEPTION_IF_NULL(return_node);
2581 std::vector<AnfNodePtr> all_nodes = TopoSort(return_node);
2582 constexpr size_t sequence_index = 1;
2583 for (auto &node : all_nodes) {
2584 MS_EXCEPTION_IF_NULL(node);
2585 if (!IsOneOfPrimitiveCNode(node, inplace_prim_set)) {
2586 continue;
2587 }
2588 auto cnode = node->cast<CNodePtr>();
2589 auto sequence_node = cnode->input(sequence_index);
2590 MS_EXCEPTION_IF_NULL(sequence_node);
2591 if (!IsValueNode<ValueList>(sequence_node)) {
2592 continue;
2593 }
2594 auto abs = sequence_node->abstract();
2595 if (abs == nullptr || !abs->isa<abstract::AbstractList>()) {
2596 continue;
2597 }
2598 auto abs_list = abs->cast<abstract::AbstractListPtr>();
2599 auto list_py_object = fallback::GetObjFromExtraInfoHolder(abs_list);
2600 MS_LOG(DEBUG) << "Found list python object in inplace: " << py::str(list_py_object);
2601 const auto &list_py_object_str = fallback::GetPyObjectPtrStr(list_py_object);
2602 (void)value_with_inplace->insert(list_py_object_str);
2603 }
2604 }
2605
FindValueWithInplace(const FuncGraphPtr & root,const pipeline::ResourcePtr & resource,const StringSetPtr & value_with_inplace)2606 void FindValueWithInplace(const FuncGraphPtr &root, const pipeline::ResourcePtr &resource,
2607 const StringSetPtr &value_with_inplace) {
2608 const auto func_graphs_used_total = root->func_graphs_used_total();
2609 for (const auto &fg : func_graphs_used_total) {
2610 FindValueWithInplaceInner(fg, value_with_inplace);
2611 }
2612 FindValueWithInplaceInner(root, value_with_inplace);
2613 }
2614
ConvertToPyExecuteGetItem(const AnfNodePtr & node)2615 AnfNodePtr ConvertToPyExecuteGetItem(const AnfNodePtr &node) {
2616 MS_EXCEPTION_IF_NULL(node);
2617 if (!IsOneOfPrimitiveCNode(node, sequence_getitem_prim_set)) {
2618 return nullptr;
2619 }
2620 auto abs = node->abstract();
2621 MS_EXCEPTION_IF_NULL(abs);
2622 if (!abs->isa<abstract::AbstractAny>()) {
2623 return nullptr;
2624 }
2625 return ConvertSequenceGetItemInner(node->cast<CNodePtr>());
2626 }
2627
CheckNeedConvertList(const AbstractBasePtr & abs)2628 bool CheckNeedConvertList(const AbstractBasePtr &abs) {
2629 if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
2630 return false;
2631 }
2632 // If abstract has real type/shape, it means the corresponding node is PyExecute.
2633 // Do not covert PyExecute node.
2634 if (fallback::HasRealType(abs) || fallback::HasRealShape(abs)) {
2635 return false;
2636 }
2637 auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
2638 if (seq_abs->dynamic_len()) {
2639 return false;
2640 }
2641 if (seq_abs->isa<abstract::AbstractList>()) {
2642 return true;
2643 }
2644 const auto &elements = seq_abs->elements();
2645 return std::any_of(elements.begin(), elements.end(),
2646 [](const AbstractBasePtr &abs) { return CheckNeedConvertList(abs); });
2647 }
2648
ConvertToPyExecuteListInner(const AnfNodePtr & node,const FuncGraphPtr & fg)2649 AnfNodePtr ConvertToPyExecuteListInner(const AnfNodePtr &node, const FuncGraphPtr &fg) {
2650 MS_EXCEPTION_IF_NULL(node);
2651 auto abs = node->abstract();
2652 if (abs == nullptr || !CheckNeedConvertList(abs)) {
2653 return nullptr;
2654 }
2655 auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
2656 MS_EXCEPTION_IF_NULL(seq_abs);
2657 const auto &elements = seq_abs->elements();
2658 if (abs->isa<abstract::AbstractList>()) {
2659 const std::string element_prefix = "__list_element_";
2660 std::stringstream script_buffer;
2661 script_buffer << "[";
2662 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
2663 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
2664 for (size_t i = 0; i < elements.size(); ++i) {
2665 auto element_abs = elements[i];
2666 auto element_node =
2667 fg->NewCNode({NewValueNode(prim::kPrimListGetItem), node, NewValueNode(MakeValue<int64_t>(i))});
2668 element_node->set_abstract(element_abs);
2669 auto new_element_node = ConvertToPyExecuteListInner(element_node, fg);
2670 if (new_element_node == nullptr) {
2671 new_element_node = element_node;
2672 }
2673 std::string element_name = element_prefix + std::to_string(i) + "__";
2674 script_buffer << element_name << ",";
2675 (void)key_value_names_list.emplace_back(NewValueNode(element_name));
2676 (void)key_value_list.emplace_back(new_element_node);
2677 }
2678 script_buffer << "]";
2679 const std::string &script = script_buffer.str();
2680 const auto script_str = std::make_shared<StringImm>(script);
2681 const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
2682 const auto key_value_tuple = fg->NewCNode(key_value_list);
2683 return fallback::CreatePyExecuteCNode(fg, NewValueNode(script_str), key_value_name_tuple, key_value_tuple,
2684 node->debug_info());
2685 }
2686 std::vector<AnfNodePtr> new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
2687 for (size_t i = 0; i < elements.size(); ++i) {
2688 auto element_abs = elements[i];
2689 auto element_node =
2690 fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(MakeValue<int64_t>(i))});
2691 element_node->set_abstract(element_abs);
2692 auto new_element_node = ConvertToPyExecuteListInner(element_node, fg);
2693 if (new_element_node == nullptr) {
2694 new_element_node = element_node;
2695 }
2696 (void)new_make_tuple_inputs.emplace_back(new_element_node);
2697 }
2698 return fg->NewCNode(new_make_tuple_inputs);
2699 }
2700
ConvertToPyExecuteList(const AnfNodePtr & node)2701 AnfNodePtr ConvertToPyExecuteList(const AnfNodePtr &node) {
2702 MS_EXCEPTION_IF_NULL(node);
2703 if (!IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
2704 return nullptr;
2705 }
2706 constexpr size_t pyexecute_min_len = 4;
2707 auto cnode = node->cast<CNodePtr>();
2708 if (cnode->size() < pyexecute_min_len) {
2709 MS_LOG(INTERNAL_EXCEPTION) << "The minimum len of input to PyExecute should " << pyexecute_min_len << " but got "
2710 << cnode->size() << " for node: " << cnode->DebugString();
2711 }
2712 constexpr size_t pyexecute_value_index = 3;
2713 const auto &fg = cnode->func_graph();
2714 return ConvertToPyExecuteListInner(cnode->input(pyexecute_value_index), fg);
2715 }
2716
ConvertPyExecuteAfterRewriter(const FuncGraphPtr & graph,const FuncGraphManagerPtr & manager)2717 bool ConvertPyExecuteAfterRewriter(const FuncGraphPtr &graph, const FuncGraphManagerPtr &manager) {
2718 MS_EXCEPTION_IF_NULL(graph);
2719 AnfNodePtr return_node = graph->get_return();
2720 MS_EXCEPTION_IF_NULL(return_node);
2721 std::vector<AnfNodePtr> all_nodes = TopoSort(return_node);
2722 bool change = false;
2723 constexpr size_t pyexecute_value_index = 3;
2724 for (auto &node : all_nodes) {
2725 MS_EXCEPTION_IF_NULL(node);
2726 auto tr = manager->Transact();
2727 auto new_node = ConvertToPyExecuteGetItem(node);
2728 if (new_node != nullptr) {
2729 tr.Replace(node, new_node);
2730 tr.Commit();
2731 change = true;
2732 continue;
2733 }
2734 auto new_value_input = ConvertToPyExecuteList(node);
2735 if (new_value_input != nullptr) {
2736 tr.SetEdge(node, pyexecute_value_index, new_value_input);
2737 tr.Commit();
2738 change = true;
2739 continue;
2740 }
2741 }
2742 return change;
2743 }
2744
OrderPyExecuteCNode(const FuncGraphPtr & graph,const FuncGraphManagerPtr & manager)2745 static inline bool OrderPyExecuteCNode(const FuncGraphPtr &graph, const FuncGraphManagerPtr &manager) {
2746 MS_EXCEPTION_IF_NULL(graph);
2747 AnfNodePtr return_node = graph->get_return();
2748 MS_EXCEPTION_IF_NULL(return_node);
2749 std::vector<AnfNodePtr> all_nodes = TopoSort(return_node);
2750 CNodePtr former_node = nullptr;
2751 CNodePtr latter_node = nullptr;
2752 bool change = false;
2753 for (auto &node : all_nodes) {
2754 MS_EXCEPTION_IF_NULL(node);
2755 if (!IsPrimitiveCNode(node, prim::kPrimPyExecute) || node->func_graph() != graph) {
2756 continue;
2757 }
2758 if (former_node == nullptr) {
2759 former_node = dyn_cast<CNode>(node);
2760 continue;
2761 } else {
2762 latter_node = dyn_cast<CNode>(node);
2763 }
2764 MS_EXCEPTION_IF_NULL(former_node);
2765 MS_EXCEPTION_IF_NULL(latter_node);
2766
2767 // Make former node as latter node's input.
2768 auto tr = manager->Transact();
2769 size_t latest_index = latter_node->size() - 1;
2770 const auto &last_input_abs = latter_node->input(latest_index)->abstract();
2771 if (last_input_abs != nullptr && last_input_abs->isa<abstract::AbstractMonad>()) { // Should be IO monad.
2772 const auto &monad_node = latter_node->input(latest_index);
2773 tr.SetEdge(latter_node, latest_index, former_node);
2774 tr.AddEdge(latter_node, monad_node);
2775 } else {
2776 tr.AddEdge(latter_node, former_node);
2777 }
2778 tr.Commit();
2779
2780 former_node = latter_node;
2781 change = true;
2782 }
2783 return change;
2784 }
2785 } // namespace
2786
RewriterBeforeOptA(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)2787 bool RewriterBeforeOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
2788 MS_EXCEPTION_IF_NULL(manager);
2789 manager->AddFuncGraph(root);
2790 BeforeOptARewriter rewriter(root, manager);
2791 return rewriter.Execute();
2792 }
2793
RewriterAfterOptA(const FuncGraphPtr & root,const pipeline::ResourcePtr & resource)2794 bool RewriterAfterOptA(const FuncGraphPtr &root, const pipeline::ResourcePtr &resource) {
2795 MS_EXCEPTION_IF_NULL(root);
2796 MS_EXCEPTION_IF_NULL(resource);
2797 auto manager = resource->manager();
2798 MS_EXCEPTION_IF_NULL(manager);
2799 manager->AddFuncGraph(root);
2800 StringSetPtr value_with_inplace = std::make_shared<StringSet>();
2801 FindValueWithInplace(root, resource, value_with_inplace);
2802 AfterOptARewriter rewriter(root, manager, value_with_inplace);
2803 bool change = rewriter.Execute();
2804 if (rewriter.need_renormalized()) {
2805 abstract::AbstractBasePtrList new_args_spec;
2806 (void)std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),
2807 [](const AnfNodePtr ¶m) -> AbstractBasePtr { return param->abstract(); });
2808 (void)pipeline::Renormalize(resource, root, new_args_spec);
2809 }
2810 return change;
2811 }
2812
ConvertAfterRewriter(const FuncGraphPtr & root,const pipeline::ResourcePtr & resource)2813 bool ConvertAfterRewriter(const FuncGraphPtr &root, const pipeline::ResourcePtr &resource) {
2814 auto manager = resource->manager();
2815 const auto func_graphs_used_total = root->func_graphs_used_total();
2816 bool change = false;
2817 for (const auto &fg : func_graphs_used_total) {
2818 auto cur_change = ConvertPyExecuteAfterRewriter(fg, manager);
2819 change = change || cur_change;
2820 }
2821 bool root_change = ConvertPyExecuteAfterRewriter(root, manager);
2822 change = change || root_change;
2823 if (change) {
2824 abstract::AbstractBasePtrList new_args_spec;
2825 (void)std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),
2826 [](const AnfNodePtr ¶m) -> AbstractBasePtr { return param->abstract(); });
2827 (void)pipeline::Renormalize(resource, root, new_args_spec);
2828 }
2829 return change;
2830 }
2831
OrderPyExecuteAfterRewriter(const FuncGraphPtr & root,const pipeline::ResourcePtr & resource)2832 bool OrderPyExecuteAfterRewriter(const FuncGraphPtr &root, const pipeline::ResourcePtr &resource) {
2833 auto manager = resource->manager();
2834 const auto func_graphs_used_total = root->func_graphs_used_total();
2835 bool change = false;
2836 for (const auto &fg : func_graphs_used_total) {
2837 auto cur_change = OrderPyExecuteCNode(fg, manager);
2838 change = change || cur_change;
2839 }
2840 bool root_change = OrderPyExecuteCNode(root, manager);
2841 change = change || root_change;
2842 if (change) {
2843 abstract::AbstractBasePtrList new_args_spec;
2844 (void)std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),
2845 [](const AnfNodePtr ¶m) -> AbstractBasePtr { return param->abstract(); });
2846 (void)pipeline::Renormalize(resource, root, new_args_spec);
2847 }
2848 return change;
2849 }
2850 } // namespace opt
2851 } // namespace mindspore
2852