• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2024 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "frontend/optimizer/fallback_rewriter.h"
20 #include <iterator>
21 #include <string>
22 #include <algorithm>
23 #include <functional>
24 #include <utility>
25 #include <memory>
26 #include <vector>
27 #include <set>
28 #include <unordered_map>
29 #include "ops/structure_ops.h"
30 #include "ops/sparse_tensor_ops.h"
31 #include "ops/sequence_ops.h"
32 #include "ops/array_ops.h"
33 #include "ops/arithmetic_ops.h"
34 #include "ops/framework_ops.h"
35 #include "ops/auto_generate/gen_ops_primitive.h"
36 #include "ops/op_utils.h"
37 #include "abstract/abstract_value.h"
38 #include "base/base.h"
39 #include "pipeline/jit/ps/debug/trace.h"
40 #include "pipeline/jit/ps/action.h"
41 #include "pipeline/jit/ps/parse/parse_base.h"
42 #include "frontend/optimizer/opt.h"
43 #include "frontend/operator/composite/composite.h"
44 #include "include/common/fallback.h"
45 #include "include/common/utils/convert_utils_py.h"
46 #include "ir/anf.h"
47 #include "ir/value.h"
48 #include "pipeline/jit/ps/fallback.h"
49 #include "pipeline/jit/ps/parse/resolve.h"
50 #include "utils/hash_map.h"
51 #include "utils/anf_utils.h"
52 #include "utils/compile_config.h"
53 #include "utils/check_convert_utils.h"
54 #include "utils/tensor_construct_utils.h"
55 
56 namespace mindspore {
57 /* namespace to support opt */
58 namespace opt {
59 using mindspore::abstract::AbstractBase;
60 using mindspore::abstract::AbstractBasePtr;
61 using mindspore::abstract::AbstractDictionary;
62 using mindspore::abstract::AbstractDictionaryPtr;
63 using mindspore::abstract::AbstractElementPair;
64 using mindspore::abstract::AbstractList;
65 using mindspore::abstract::AbstractListPtr;
66 using mindspore::abstract::AbstractRowTensor;
67 using mindspore::abstract::AbstractScalar;
68 using mindspore::abstract::AbstractSequence;
69 using mindspore::abstract::AbstractSequencePtr;
70 using mindspore::abstract::AbstractTuple;
71 using mindspore::abstract::AbstractTuplePtr;
72 using ClassTypePtr = std::shared_ptr<parse::ClassType>;
73 using StringSet = std::set<std::string>;
74 using StringSetPtr = std::shared_ptr<StringSet>;
75 
76 constexpr auto kInternalDictSelfStr = "__internal_dict_self__";
77 constexpr auto kInternalDictKeyStr = "__internal_dict_key__";
78 constexpr auto kInternalDictValueStr = "__internal_dict_value__";
79 static const PrimitiveSet inplace_prim_set{prim::kPrimPyExecute,          prim::kPrimListInplaceAppend,
80                                            prim::kPrimListInplaceReverse, prim::kPrimListInplaceExtend,
81                                            prim::kPrimListInplaceInsert,  prim::kPrimListInplacePop,
82                                            prim::kPrimDictInplaceSetItem};
83 static const PrimitiveSet sequence_getitem_prim_set{prim::kPrimListGetItem, prim::kPrimTupleGetItem,
84                                                     prim::kPrimDictGetItem};
85 
86 namespace {
// Deepest list/tuple/dict nesting the recursive converters in this file accept; they raise an internal
// exception once their depth argument exceeds this value.
constexpr size_t kMaxSeqRecursiveDepth = 6;
CheckInputsSize(const CNodePtr & cnode,size_t expect_size)88 void CheckInputsSize(const CNodePtr &cnode, size_t expect_size) {
89   if (cnode->size() != expect_size) {
90     std::string op_name = GetCNodeFuncName(cnode);
91     MS_LOG(INTERNAL_EXCEPTION) << op_name << " should have " << expect_size << " inputs, but got " << cnode->size();
92   }
93 }
94 
95 template <typename T>
GetAbstract(const AnfNodePtr & node)96 std::shared_ptr<T> GetAbstract(const AnfNodePtr &node) {
97   auto abs = node->abstract();
98   if (abs == nullptr) {
99     return nullptr;
100   }
101   return dyn_cast<T>(abs);
102 }
103 
CheckContainsDict(const AbstractBasePtr & abs)104 bool CheckContainsDict(const AbstractBasePtr &abs) {
105   if (abs == nullptr) {
106     return false;
107   }
108   if (abs->isa<AbstractDictionary>()) {
109     return true;
110   }
111   auto from_dict = abs->user_data<bool>("from_dict");
112   if (from_dict != nullptr && *from_dict) {
113     return true;
114   }
115   if (abs->isa<AbstractSequence>()) {
116     auto abs_seq = abs->cast<AbstractSequencePtr>();
117     const auto &elements = abs_seq->elements();
118     if (std::any_of(elements.begin(), elements.end(),
119                     [](const AbstractBasePtr &element) { return CheckContainsDict(element); })) {
120       return true;
121     }
122   }
123   return false;
124 }
125 
126 // ===========================================================================
127 // BaseRewriter provides a common framework for data struct simplify.
128 // ===========================================================================
129 class BaseRewriter : protected SimpleRewriter {
130  public:
BaseRewriter(const FuncGraphPtr & root_graph,const FuncGraphManagerPtr & manager)131   BaseRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
132       : SimpleRewriter(root_graph, manager) {}
133   ~BaseRewriter() override = default;
134 
need_renormalized() const135   bool need_renormalized() const { return need_renormalized_; }
136 
set_need_renormalized(bool need_renormalized)137   void set_need_renormalized(bool need_renormalized) { need_renormalized_ = need_renormalized; }
138 
Execute()139   virtual bool Execute() {
140     bool changed = Run();
141     if (changed) {
142       UpdateAbstracts();
143     }
144     return changed;
145   }
146 
147  protected:
148   virtual AnfNodePtr ConvertPrimitiveCNode(const CNodePtr &cnode) = 0;
149   virtual AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) = 0;
150   virtual AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) = 0;
151 
NodeRewrite(const AnfNodePtr & node)152   AnfNodePtr NodeRewrite(const AnfNodePtr &node) override {
153     auto new_node = ConvertNode(node);
154     if (IsPrimitiveCNode(new_node, prim::kPrimPyExecute)) {
155       need_renormalized_ = true;
156       return new_node;
157     }
158     if (new_node != nullptr) {
159       new_node->set_abstract(node->abstract());
160     }
161     return new_node;
162   }
163 
ConvertNode(const AnfNodePtr & node)164   AnfNodePtr ConvertNode(const AnfNodePtr &node) {
165     auto cnode = node->cast<CNodePtr>();
166     if (cnode != nullptr) {
167       if (cnode->size() == 0) {
168         return nullptr;
169       }
170       // Call primitive cnode converter.
171       return ConvertPrimitiveCNode(cnode);
172     }
173     auto value_node = node->cast<ValueNodePtr>();
174     if (value_node != nullptr) {
175       const auto &value = value_node->value();
176       if (value == nullptr) {
177         return nullptr;
178       }
179       // Call value node converter.
180       return ConvertValueNode(value_node, value);
181     }
182     return nullptr;
183   }
184 
UpdateAbstracts()185   virtual void UpdateAbstracts() {
186     const auto &nodes = manager_->all_nodes();
187     for (const auto &node : nodes) {
188       const auto &abs = node->abstract();
189       if (abs == nullptr) {
190         continue;
191       }
192       bool is_interpret_dict = false;
193       // Do not convert the abstract of Interpret node(AbstractDictionary) to AbstractSequence.
194       if (abs->isa<AbstractDictionary>()) {
195         AbstractDictionaryPtr abs_dict = abs->cast<AbstractDictionaryPtr>();
196         auto &dict_elements = abs_dict->elements();
197         for (auto &element : dict_elements) {
198           TypePtr type = element.second->GetTypeTrack();
199           MS_EXCEPTION_IF_NULL(type);
200           auto value = element.second->BuildValue();
201           MS_EXCEPTION_IF_NULL(value);
202           if (type->type_id() == kMetaTypeExternal && value->isa<parse::InterpretedObject>()) {
203             is_interpret_dict = true;
204             break;
205           }
206         }
207       }
208       if (is_interpret_dict) {
209         continue;
210       }
211       // Call abstract converter.
212       auto new_abs = ConvertAbstract(abs);
213       if (new_abs != nullptr) {
214         node->set_abstract(new_abs);
215       }
216     }
217   }
218 
GetElementIndex(const std::vector<AbstractElementPair> & attrs,const AnfNodePtr & name)219   static int64_t GetElementIndex(const std::vector<AbstractElementPair> &attrs, const AnfNodePtr &name) {
220     auto n_attrs = attrs.size();
221     auto name_abstract = GetAbstract<AbstractBase>(name);
222     MS_EXCEPTION_IF_NULL(name_abstract);
223     auto name_value = name_abstract->BuildValue();
224     MS_EXCEPTION_IF_NULL(name_value);
225     for (size_t i = 0; i < n_attrs; ++i) {
226       if (*name_value == *attrs[i].first->BuildValue()) {
227         return SizeToLong(i);
228       }
229     }
230     return SizeToLong(n_attrs);
231   }
232 
233  private:
234   bool need_renormalized_{false};
235 };
236 
237 // ===========================================================================
238 // BeforeOptARewriter convert ObjectClass, Dictionary to Tuple.
239 // ===========================================================================
240 class BeforeOptARewriter : public BaseRewriter {
241  public:
242   using ThisClass = BeforeOptARewriter;
BeforeOptARewriter(const FuncGraphPtr & root_graph,const FuncGraphManagerPtr & manager)243   BeforeOptARewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
244       : BaseRewriter(root_graph, manager), is_dict_output_(HasDictOutput()), has_dict_inplace_(HasDictInplace()) {}
245   ~BeforeOptARewriter() override = default;
246 
Execute()247   bool Execute() override {
248     bool changed = Run();
249     if (changed) {
250       UpdateAbstracts();
251     }
252     ConvertParameter();
253     return changed;
254   }
255 
256  protected:
ConvertParameter()257   void ConvertParameter() {
258     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
259     for (const auto &para : root_graph_->parameters()) {
260       auto abs = para->abstract();
261       MS_EXCEPTION_IF_NULL(abs);
262       if (abs->isa<abstract::AbstractKeywordArg>()) {
263         auto kw_abs = abs->cast_ptr<abstract::AbstractKeywordArg>();
264         para->set_abstract(kw_abs->get_arg());
265       }
266       // If the dict input is not used in graph, convert it to tuple directly.
267       auto dict_param_not_used =
268         abs->isa<abstract::AbstractDictionary>() && manager_->node_users().find(para) == manager_->node_users().end();
269       if ((!allow_fallback_runtime || !is_dict_output_) && !dict_param_not_used) {
270         continue;
271       }
272       auto new_node_and_abs = ConvertParameterDictAbstract(para, para->abstract());
273       new_node_and_abs.first->set_abstract(new_node_and_abs.second);
274       if (new_node_and_abs.first == para) {
275         continue;
276       }
277       (void)manager_->Replace(para, new_node_and_abs.first);
278       para->set_abstract(new_node_and_abs.second);
279     }
280   }
281 
ConvertParameterDictAbstract(const AnfNodePtr & cur_node,const AbstractBasePtr & cur_abs)282   std::pair<AnfNodePtr, AbstractBasePtr> ConvertParameterDictAbstract(const AnfNodePtr &cur_node,
283                                                                       const AbstractBasePtr &cur_abs) {
284     MS_EXCEPTION_IF_NULL(cur_abs);
285     auto seq_abs = cur_abs->cast_ptr<AbstractSequence>();
286     if (seq_abs != nullptr) {
287       bool is_tuple = seq_abs->isa<AbstractTuple>();
288       auto seq_prim = is_tuple ? prim::kPrimMakeTuple : prim::kPrimMakeList;
289       std::vector<AnfNodePtr> seq_inputs{NewValueNode(seq_prim)};
290       AbstractBasePtrList abs_list;
291       for (size_t i = 0; i < seq_abs->elements().size(); ++i) {
292         auto getitem_prim = is_tuple ? prim::kPrimTupleGetItem : prim::kPrimListGetItem;
293         auto next_node =
294           root_graph_->NewCNodeInOrder({NewValueNode(getitem_prim), cur_node, NewValueNode(SizeToLong(i))});
295         auto node_and_abs = ConvertParameterDictAbstract(next_node, seq_abs->elements()[i]);
296         (void)seq_inputs.emplace_back(node_and_abs.first);
297         (void)abs_list.emplace_back(node_and_abs.second);
298       }
299       if (is_tuple) {
300         return std::make_pair(root_graph_->NewCNodeInOrder(seq_inputs), std::make_shared<AbstractTuple>(abs_list));
301       }
302       return std::make_pair(root_graph_->NewCNodeInOrder(seq_inputs), std::make_shared<AbstractList>(abs_list));
303     }
304     auto dict_abs = cur_abs->cast_ptr<AbstractDictionary>();
305     if (dict_abs != nullptr) {
306       std::vector<AnfNodePtr> key_inputs{NewValueNode(prim::kPrimMakeTuple)};
307       std::vector<AnfNodePtr> value_inputs{NewValueNode(prim::kPrimMakeTuple)};
308       AbstractBasePtrList abs_list;
309       for (size_t i = 0; i < dict_abs->elements().size(); ++i) {
310         auto next_node =
311           root_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), cur_node, NewValueNode(SizeToLong(i))});
312         auto node_and_abs = ConvertParameterDictAbstract(next_node, dict_abs->elements()[i].second);
313         (void)key_inputs.emplace_back(NewValueNode(dict_abs->elements()[i].first->BuildValue()));
314         (void)value_inputs.emplace_back(node_and_abs.first);
315         (void)abs_list.emplace_back(node_and_abs.second);
316       }
317       auto make_dict =
318         root_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), root_graph_->NewCNodeInOrder(key_inputs),
319                                       root_graph_->NewCNodeInOrder(value_inputs)});
320       return std::make_pair(make_dict, std::make_shared<AbstractTuple>(abs_list));
321     }
322     return std::make_pair(cur_node, cur_abs);
323   }
324 
GetStringValue(const AnfNodePtr & node)325   static std::string GetStringValue(const AnfNodePtr &node) {
326     auto str = GetValueNode<StringImmPtr>(node);
327     if (str == nullptr) {
328       return "";
329     }
330     return str->value();
331   }
332 
NewTupleGetCNode(const AnfNodePtr & cnode,const AnfNodePtr & data_node,const std::vector<AbstractElementPair> & elements,const AnfNodePtr & name_node)333   static CNodePtr NewTupleGetCNode(const AnfNodePtr &cnode, const AnfNodePtr &data_node,
334                                    const std::vector<AbstractElementPair> &elements, const AnfNodePtr &name_node) {
335     int64_t index = GetElementIndex(elements, name_node);
336     auto index_node = NewValueNode(index);
337     auto prim_node = NewValueNode(prim::kPrimTupleGetItem);
338     return cnode->func_graph()->NewCNode({prim_node, data_node, index_node});
339   }
340 
341   // From:
342   //   DictGetItem(data:AbstractDictionary, key:AbstractBase)
343   // To:
344   //   TupleGetItem(data, index:Int64Imm)
ConvertDictGetItemToTupleGetItem(const CNodePtr & node) const345   AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) const {
346     MS_EXCEPTION_IF_NULL(node);
347     MS_EXCEPTION_IF_NULL(node->func_graph());
348 
349     // Inputs should be [dict_getitem, dict, item]
350     const size_t expect_inputs_size = 3;
351     CheckInputsSize(node, expect_inputs_size);
352 
353     constexpr size_t data_index = 1;
354     constexpr size_t key_index = 2;
355     const auto &inputs = node->inputs();
356     auto &data = inputs[data_index];
357     auto &key = inputs[key_index];
358     MS_EXCEPTION_IF_NULL(data);
359     MS_EXCEPTION_IF_NULL(key);
360 
361     auto abs_dict = GetAbstract<AbstractDictionary>(data);
362     if (abs_dict == nullptr) {
363       return nullptr;
364     }
365     return NewTupleGetCNode(node, data, abs_dict->elements(), key);
366   }
367 
ConvertDictGetItem(const CNodePtr & node) const368   AnfNodePtr ConvertDictGetItem(const CNodePtr &node) const {
369     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
370     if (!allow_fallback_runtime || (!is_dict_output_ && !has_dict_inplace_)) {
371       return ConvertDictGetItemToTupleGetItem(node);
372     }
373     return nullptr;
374   }
375 
376   // From:
377   //   DictSetItem(data:AbstractDictionary, key:AbstractBase, value)
378   // To:
379   //   TupleSetItem(data, index:Int64Imm, value)
380   // Or:
381   //   tuple_add(data, value)
ConvertDictSetItemToTupleSetItem(const CNodePtr & node) const382   AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) const {
383     MS_EXCEPTION_IF_NULL(node);
384     MS_EXCEPTION_IF_NULL(node->func_graph());
385 
386     // Inputs should be [dict_setitem, dict, item, value]
387     const size_t expect_inputs_size = 4;
388     CheckInputsSize(node, expect_inputs_size);
389 
390     const size_t data_index = 1;
391     const size_t cons_index = 2;
392     const size_t item_value_index = 3;
393     const auto &inputs = node->inputs();
394     auto &data = inputs[data_index];
395     auto &key = inputs[cons_index];
396     auto &item_value = inputs[item_value_index];
397     MS_EXCEPTION_IF_NULL(data);
398     MS_EXCEPTION_IF_NULL(key);
399 
400     auto abs_dict = GetAbstract<AbstractDictionary>(data);
401     if (abs_dict == nullptr) {
402       return nullptr;
403     }
404     int64_t index = GetElementIndex(abs_dict->elements(), key);
405     auto func_graph = node->func_graph();
406     MS_EXCEPTION_IF_NULL(func_graph);
407     if (index >= static_cast<int64_t>(abs_dict->elements().size())) {
408       // For dictionary set, if the key does not exist, we should create a new item.
409       std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
410       for (size_t i = 0; i < abs_dict->elements().size(); ++i) {
411         auto tuple_getitem_i =
412           func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, NewValueNode(SizeToLong(i))});
413         (void)make_tuple_inputs.emplace_back(tuple_getitem_i);
414       }
415       (void)make_tuple_inputs.emplace_back(item_value);
416       auto new_node = func_graph->NewCNode(make_tuple_inputs);
417       new_node->set_debug_info(node->debug_info());
418       return new_node;
419     }
420     auto index_node = NewValueNode(index);
421     auto new_node = func_graph->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, index_node, item_value});
422     new_node->set_debug_info(node->debug_info());
423     return new_node;
424   }
425 
HasDictOutput() const426   bool HasDictOutput() const {
427     const AnfNodePtr &output = root_graph_->output();
428     return CheckContainsDict(output->abstract());
429   }
430 
HasDictInplace() const431   bool HasDictInplace() const {
432     const auto &all_nodes = manager_->all_nodes();
433     return std::any_of(all_nodes.cbegin(), all_nodes.cend(),
434                        [](const auto &node) { return IsPrimitiveCNode(node, prim::kPrimDictInplaceSetItem); });
435   }
436 
ConvertDictSetItem(const CNodePtr & node) const437   AnfNodePtr ConvertDictSetItem(const CNodePtr &node) const {
438     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
439     if (!allow_fallback_runtime || ConvertDictToTuple(node, node->func_graph())) {
440       return ConvertDictSetItemToTupleSetItem(node);
441     }
442     return nullptr;
443   }
444 
445   // From:
446   //   MakeDict(name, input)
447   // To:
448   //   input
EraseMakeDictNode(const CNodePtr & node) const449   AnfNodePtr EraseMakeDictNode(const CNodePtr &node) const {
450     MS_EXCEPTION_IF_NULL(node);
451     constexpr size_t expect_inputs_size = 3;
452     constexpr size_t input_index = 2;
453     CheckInputsSize(node, expect_inputs_size);
454     return node->input(input_index);
455   }
456 
CheckUserHasPyExecute(const AnfNodePtr & node,const FuncGraphPtr & func) const457   bool CheckUserHasPyExecute(const AnfNodePtr &node, const FuncGraphPtr &func) const {
458     MS_EXCEPTION_IF_NULL(node);
459     MS_EXCEPTION_IF_NULL(func);
460     auto mng = func->manager();
461     auto &users = mng->node_users()[node];
462     for (auto &user : users) {
463       if (IsPrimitiveCNode(user.first, prim::kPrimPyExecute) || IsPrimitiveCNode(user.first, prim::kPrimPyInterpret)) {
464         return true;
465       } else if (IsPrimitiveCNode(user.first, prim::kPrimMakeTuple)) {
466         if (CheckUserHasPyExecute(user.first, user.first->func_graph())) {
467           return true;
468         }
469       }
470     }
471     return false;
472   }
473 
CheckDictUserHasFuncGraph(const AnfNodePtr & node,const FuncGraphPtr & func) const474   bool CheckDictUserHasFuncGraph(const AnfNodePtr &node, const FuncGraphPtr &func) const {
475     MS_EXCEPTION_IF_NULL(node);
476     MS_EXCEPTION_IF_NULL(func);
477     if (!IsValueNode<ValueDictionary>(node)) {
478       return false;
479     }
480     auto mng = func->manager();
481     auto &users = mng->node_users()[node];
482     for (auto &user : users) {
483       if (user.first->isa<CNode>()) {
484         auto cnode = user.first->cast<CNodePtr>();
485         auto input = cnode->input(0);
486         if (IsValueNode<FuncGraph>(input)) {
487           return true;
488         }
489       }
490     }
491     return false;
492   }
493 
ConvertMakeDict(const CNodePtr & node) const494   AnfNodePtr ConvertMakeDict(const CNodePtr &node) const {
495     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
496     if (!allow_fallback_runtime || ConvertDictToTuple(node, node->func_graph())) {
497       auto new_node = EraseMakeDictNode(node);
498       return new_node;
499     }
500     return nullptr;
501   }
502 
503   // From:
504   //   DictGetValues(dict:AbstractDictionary)
505   // To:
506   //   dict
EraseDictGetValues(const CNodePtr & node) const507   AnfNodePtr EraseDictGetValues(const CNodePtr &node) const {
508     MS_EXCEPTION_IF_NULL(node);
509     constexpr size_t expect_inputs_size = 2;
510     CheckInputsSize(node, expect_inputs_size);
511     auto input = node->input(1);
512     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
513     if (!allow_fallback_runtime || ConvertDictToTuple(node, node->func_graph())) {
514       return input;
515     }
516     auto abs_dict = GetAbstract<AbstractDictionary>(input);
517     if (abs_dict == nullptr) {
518       return nullptr;
519     }
520     const auto &elements = abs_dict->elements();
521     std::vector<AnfNodePtr> new_inputs;
522     new_inputs.reserve(elements.size() + 1);
523     (void)new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
524     auto fg = node->func_graph();
525     MS_EXCEPTION_IF_NULL(fg);
526     for (const auto &element : elements) {
527       MS_EXCEPTION_IF_NULL(element.first->BuildValue());
528       AnfNodePtr value_node =
529         fg->NewCNode({NewValueNode(prim::kPrimDictGetItem), input, NewValueNode(element.first->BuildValue())});
530       (void)new_inputs.emplace_back(value_node);
531     }
532     return fg->NewCNode(std::move(new_inputs));
533   }
534 
535   // From:
536   //   DictItems(dict:AbstractDictionary)
537   // To:
538   //   kPrimMakeList(MakeTuple(key0, TupleGetItem(dict, 0)), ...)
EraseDictItems(const CNodePtr & node) const539   AnfNodePtr EraseDictItems(const CNodePtr &node) const {
540     MS_EXCEPTION_IF_NULL(node);
541     auto fg = node->func_graph();
542     MS_EXCEPTION_IF_NULL(fg);
543     constexpr size_t expect_inputs_size = 2;
544     CheckInputsSize(node, expect_inputs_size);
545 
546     const auto &input = node->input(1);
547     auto abs_dict = GetAbstract<AbstractDictionary>(input);
548     if (abs_dict == nullptr) {
549       return nullptr;
550     }
551     const auto &elements = abs_dict->elements();
552     std::vector<AnfNodePtr> new_inputs;
553     new_inputs.reserve(elements.size() + 1);
554     (void)new_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
555     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
556     bool convert_to_tuple = !allow_fallback_runtime || ConvertDictToTuple(node, node->func_graph());
557     for (size_t i = 0; i < elements.size(); ++i) {
558       auto index_node = NewValueNode(static_cast<int64_t>(i));
559       MS_EXCEPTION_IF_NULL(elements[i].first->BuildValue());
560       auto key_node = NewValueNode(elements[i].first->BuildValue());
561       AnfNodePtr value_node;
562       if (convert_to_tuple) {
563         value_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, index_node});
564       } else {
565         value_node =
566           fg->NewCNode({NewValueNode(prim::kPrimDictGetItem), input, NewValueNode(elements[i].first->BuildValue())});
567       }
568       auto tuple_node = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), key_node, value_node});
569       (void)new_inputs.emplace_back(tuple_node);
570     }
571     return fg->NewCNode(std::move(new_inputs));
572   }
573 
574   // From:
575   //   MakeKeywordArg(key, value)
576   // To:
577   //   value
EraseMakeKeywordArgNode(const CNodePtr & node) const578   AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) const {
579     MS_EXCEPTION_IF_NULL(node);
580     // Inputs should be [make_keyword_arg, key, value]
581     constexpr size_t expect_input_size = 3;
582     constexpr size_t value_inputs_index = 2;
583     CheckInputsSize(node, expect_input_size);
584     return node->input(value_inputs_index);
585   }
586 
587   // From:
588   //   ExtractKeywordArg(arg, key)
589   // To:
590   //   key
EraseExtractKeywordArg(const CNodePtr & node) const591   AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) const {
592     MS_EXCEPTION_IF_NULL(node);
593     // Inputs should be [extract_keyword_arg, arg, key]
594     const size_t expect_inputs_size = 3;
595     // Inputs should be [extract_keyword_arg, arg, key, monad]
596     const size_t expect_inputs_has_side_effect_size = 4;
597     if (node->size() != expect_inputs_size && node->size() != expect_inputs_has_side_effect_size) {
598       MS_LOG(INTERNAL_EXCEPTION) << "The extract_keyword_arg should have 3 or 4 inputs, but got " << node->size();
599     }
600     constexpr size_t key_index = 2;
601     return node->input(key_index);
602   }
603 
604   using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &) const;
605   using ConverterMap = std::unordered_map<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
606   static inline const ConverterMap converters_{
607     {prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItem},
608     {prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItem},
609     {prim::kPrimDictGetValues, &ThisClass::EraseDictGetValues},
610     {prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
611     {prim::kPrimMakeKeywordArg, &ThisClass::EraseMakeKeywordArgNode},
612     {prim::kPrimExtractKeywordArg, &ThisClass::EraseExtractKeywordArg},
613     {prim::kPrimDictItems, &ThisClass::EraseDictItems},
614   };
615 
ConvertPrimitiveCNode(const CNodePtr & cnode)616   AnfNodePtr ConvertPrimitiveCNode(const CNodePtr &cnode) override {
617     // Get primitive from cnode.
618     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
619     if (prim == nullptr) {
620       return nullptr;
621     }
622     // Find cnode converter by primitive.
623     auto iter = converters_.find(prim);
624     if (iter == converters_.end()) {
625       return nullptr;
626     }
627     // Call converter.
628     return (this->*(iter->second))(cnode);
629   }
630 
// Recursively converts dictionary values nested inside `value` to tuples of their values.
//   value        - value to convert; must not be null.
//   depth        - current nesting level; exceeding kMaxSeqRecursiveDepth raises an internal exception.
//   convert_dict - when false, dictionaries are returned unchanged.
//   need_convert - set to true when the returned value differs from the input; never reset to false.
// A sequence keeps its kind (tuple stays tuple, anything else becomes a list) and is only rebuilt when
// one of its elements changed; otherwise the original `value` is returned.
ValuePtr ConvertDictValue(const ValuePtr &value, size_t depth, bool convert_dict, bool *need_convert) const {
  MS_EXCEPTION_IF_NULL(value);
  if (depth > kMaxSeqRecursiveDepth) {
    MS_LOG(ERROR) << "value:" << value->ToString();
    MS_LOG(INTERNAL_EXCEPTION) << "List, tuple and dict nesting is not allowed more than " << kMaxSeqRecursiveDepth
                               << " levels.";
  }

  if (const auto sequence = value->cast<ValueSequencePtr>(); sequence != nullptr) {
    bool any_element_changed = false;
    std::vector<ValuePtr> converted_elements;
    converted_elements.reserve(sequence->size());
    for (const auto &element : sequence->value()) {
      (void)converted_elements.emplace_back(ConvertDictValue(element, depth + 1, convert_dict, &any_element_changed));
    }
    if (!any_element_changed) {
      return value;
    }
    *need_convert = true;
    if (value->isa<ValueTuple>()) {
      return std::make_shared<ValueTuple>(converted_elements);
    }
    return std::make_shared<ValueList>(converted_elements);
  }

  if (!convert_dict || !value->isa<ValueDictionary>()) {
    return value;
  }

  // dict(k0:v0, k1:v1, ...) --> tuple(v0, v1, ...)
  *need_convert = true;
  const auto &key_value_pairs = value->cast<ValueDictionaryPtr>()->value();
  std::vector<ValuePtr> converted_values;
  converted_values.reserve(key_value_pairs.size());
  for (const auto &key_value : key_value_pairs) {
    (void)converted_values.emplace_back(ConvertDictValue(key_value.second, depth + 1, convert_dict, need_convert));
  }
  return std::make_shared<ValueTuple>(converted_values);
}
668 
ConvertValueNode(const ValueNodePtr & value_node,const ValuePtr & value)669   AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override {
670     // Convert Dictionary value node.
671     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
672     bool convert_dict = !allow_fallback_runtime || ConvertDictToTuple(value_node, root_graph_);
673     bool need_convert = false;
674     auto new_value = ConvertDictValue(value, 0, convert_dict, &need_convert);
675     if (need_convert) {
676       auto new_node = NewValueNode(new_value);
677       new_node->set_debug_info(value_node->debug_info());
678       return new_node;
679     }
680     return nullptr;
681   }
682 
MakeAbstractTuple(const std::vector<AbstractElementPair> & attrs)683   static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractElementPair> &attrs) {
684     std::vector<AbstractBasePtr> elements;
685     elements.reserve(attrs.size());
686     (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(elements),
687                          [](const auto &item) { return item.second; });
688     return std::make_shared<AbstractTuple>(std::move(elements));
689   }
690 
691   // AbstractDictionary --> AbstractSequence.
ConvertToAbstractSequence(const AbstractBasePtr & abs,size_t depth)692   AbstractSequencePtr ConvertToAbstractSequence(const AbstractBasePtr &abs, size_t depth) {
693     if (depth > kMaxSeqRecursiveDepth) {
694       MS_LOG(ERROR) << "abs:" << abs->ToString();
695       MS_LOG(INTERNAL_EXCEPTION) << "List, tuple and dict nesting is not allowed more than " << kMaxSeqRecursiveDepth
696                                  << " levels.";
697     }
698     auto abs_seq = abs->cast<AbstractSequencePtr>();
699     if (abs_seq != nullptr) {
700       const auto &seq_elements = abs_seq->elements();
701       // First we check if elements should be converted,
702       // changed_elements maps old element to new element.
703       mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
704       for (const auto &element : seq_elements) {
705         auto new_element = ConvertToAbstractSequence(element, depth + 1);
706         if (new_element != nullptr) {
707           (void)changed_elements.emplace(element, new_element);
708         }
709       }
710       if (changed_elements.empty()) {
711         // Here the AbstractList don't need to convert to AbstractTuple.
712         return nullptr;
713       }
714       // Always make new AbstractSequence when elements changed.
715       std::vector<AbstractBasePtr> elements;
716       elements.reserve(seq_elements.size());
717       for (const auto &element : seq_elements) {
718         auto iter = changed_elements.find(element);
719         if (iter != changed_elements.end()) {
720           (void)elements.emplace_back(iter->second);
721         } else {
722           (void)elements.emplace_back(element);
723         }
724       }
725       // Here the AbstractList don't need to convert to AbstractTuple.
726       if (abs_seq->isa<AbstractList>()) {
727         return std::make_shared<AbstractList>(std::move(elements));
728       } else {
729         return std::make_shared<AbstractTuple>(std::move(elements));
730       }
731     }
732     // AbstractDictionary --> AbstractTuple.
733     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
734     bool convert_to_tuple = !allow_fallback_runtime || (!is_dict_output_ && !has_dict_inplace_);
735     auto abs_dict = abs->cast<AbstractDictionaryPtr>();
736     if (abs_dict != nullptr && convert_to_tuple) {
737       const auto &dict_elements = abs_dict->elements();
738       std::vector<AbstractBasePtr> elements;
739       elements.reserve(dict_elements.size());
740       for (const auto &element : dict_elements) {
741         auto new_element = ConvertToAbstractSequence(element.second, depth + 1);
742         if (new_element != nullptr) {
743           (void)elements.emplace_back(new_element);
744         } else {
745           (void)elements.emplace_back(element.second);
746         }
747       }
748       return std::make_shared<AbstractTuple>(elements);
749     }
750     return nullptr;
751   }
752 
ConvertAbstract(const AbstractBasePtr & abs)753   AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
754     // AbstractDictionary --> AbstractSequence.
755     return ConvertToAbstractSequence(abs, 0);
756   }
757 
ConvertDictToTuple(const AnfNodePtr & node,const FuncGraphPtr & fg) const758   bool ConvertDictToTuple(const AnfNodePtr &node, const FuncGraphPtr &fg) const {
759     return !is_dict_output_ && !has_dict_inplace_ && !CheckUserHasPyExecute(node, fg) &&
760            !CheckDictUserHasFuncGraph(node, fg);
761   }
762 
763  private:
764   bool is_dict_output_{false};
765   bool has_dict_inplace_{false};
766 };
767 
ExtractKwargsNode(const AnfNodePtr & node)768 std::pair<AnfNodePtr, AnfNodePtr> ExtractKwargsNode(const AnfNodePtr &node) {
769   MS_EXCEPTION_IF_NULL(node);
770   if (node->isa<ValueNode>()) {
771     auto kwargs = GetValueNode<KeywordArgPtr>(node);
772     if (kwargs != nullptr) {
773       auto key = MakeValue(kwargs->get_key());
774       auto arg = kwargs->get_value();
775       return std::make_pair(NewValueNode(key), NewValueNode(arg));
776     }
777   } else if (node->isa<CNode>() && IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {
778     auto kwarg_node = node->cast_ptr<CNode>();
779     constexpr auto kMakeKwargsKeyIndex = 1;
780     constexpr auto kMakeKwargsArgIndex = 2;
781     return std::make_pair(kwarg_node->input(kMakeKwargsKeyIndex), kwarg_node->input(kMakeKwargsArgIndex));
782   }
783   MS_LOG(EXCEPTION) << "Extract kwargs only can be used to CNode[make_keyword_arg] or ValueNode(KeywordArg), but got "
784                     << node->DebugString();
785 }
786 
787 // TupleGetItem/ListGetItem(sequence, index) -> PyExecute(sequence[index], ...)
ConvertSequenceGetItemInner(const CNodePtr & node)788 AnfNodePtr ConvertSequenceGetItemInner(const CNodePtr &node) {
789   const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
790   if (!allow_fallback_runtime) {
791     return nullptr;
792   }
793 
794   constexpr size_t prim_index = 0;
795   constexpr size_t sequence_index = 1;
796   constexpr size_t target_index = 2;
797   constexpr size_t node_inputs_size = 3;
798   const auto &node_inputs = node->inputs();
799   auto prim = GetValueNode<PrimitivePtr>(node_inputs[prim_index]);
800   MS_EXCEPTION_IF_NULL(prim);
801   const auto &prim_name = prim->name();
802   if (node_inputs.size() != node_inputs_size) {
803     MS_LOG(EXCEPTION) << "The size of input to " << prim_name << " should be " << node_inputs_size << " but got "
804                       << node_inputs.size();
805   }
806 
807   std::vector<AbstractBasePtr> inputs_abs;
808   for (size_t i = 1; i < node_inputs.size(); ++i) {
809     inputs_abs.push_back(node_inputs[i]->abstract());
810   }
811 
812   auto output_abs = node->abstract();
813   MS_EXCEPTION_IF_NULL(output_abs);
814 
815   auto sequence_node = node_inputs[sequence_index];
816   MS_EXCEPTION_IF_NULL(sequence_node);
817   auto sequence_abs = sequence_node->abstract();
818   // If the sequence is any, then the sequence getitem should be converted to PyExecute node.
819   if (sequence_abs == nullptr || !sequence_abs->isa<abstract::AbstractAny>()) {
820     if (!CheckAndConvertUtils::CheckContainNestedOrIrregularSequence(inputs_abs) &&
821         !output_abs->isa<abstract::AbstractAny>()) {
822       return nullptr;
823     }
824     if (!IsPrimitiveCNode(node, prim::kPrimDictGetItem)) {
825       auto target_node = node_inputs[target_index];
826       auto target_abs = target_node->abstract();
827       if (target_abs == nullptr || !target_abs->BuildValue()->ContainsValueAny()) {
828         return nullptr;
829       }
830     }
831   }
832 
833   const auto &fg = node->func_graph();
834   MS_EXCEPTION_IF_NULL(fg);
835 
836   const std::string internal_sequence_input = "__iternal_sequence_input__";
837   const std::string internal_sequence_target = "__internal_sequence_index__";
838 
839   std::stringstream script_buffer;
840   script_buffer << internal_sequence_input << "[" << internal_sequence_target << "]";
841   const std::string &script = script_buffer.str();
842   const auto script_str = std::make_shared<StringImm>(script);
843 
844   std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
845   (void)key_value_names_list.emplace_back(NewValueNode(internal_sequence_input));
846   (void)key_value_names_list.emplace_back(NewValueNode(internal_sequence_target));
847   const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
848   std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
849   (void)key_value_list.emplace_back(node_inputs[sequence_index]);
850   (void)key_value_list.emplace_back(node_inputs[target_index]);
851   const auto key_value_tuple = fg->NewCNode(key_value_list);
852   auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
853 
854   MS_LOG(DEBUG) << "Convert sequence getitem node to PyExecute node: " << res->DebugString();
855   return res;
856 }
857 
858 // ==================================================================
859 // AfterOptARewriter converts List, Sparse, RowTensor to Tuple.
860 // ==================================================================
861 class AfterOptARewriter : public BaseRewriter {
862  public:
863   using ThisClass = AfterOptARewriter;
AfterOptARewriter(const FuncGraphPtr & root_graph,const FuncGraphManagerPtr & manager,const StringSetPtr & value_with_inplace)864   AfterOptARewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager,
865                     const StringSetPtr &value_with_inplace)
866       : BaseRewriter(root_graph, manager), data_with_inplace_(value_with_inplace) {
867     auto context = MsContext::GetInstance();
868     MS_EXCEPTION_IF_NULL(context);
869     not_convert_jit_ = context->not_convert_jit();
870   }
871   ~AfterOptARewriter() override = default;
872 
873  protected:
874   // From:
875   //   MakeSparseTensor(indices, values, dense_shape)
876   // To:
877   //   MakeTuple(indices, values, dense_shape)
ConvertMakeSparseToMakeTuple(const CNodePtr & node) const878   AnfNodePtr ConvertMakeSparseToMakeTuple(const CNodePtr &node) const {
879     MS_EXCEPTION_IF_NULL(node);
880     MS_EXCEPTION_IF_NULL(node->func_graph());
881 
882     AnfNodeWeakPtrList inputs;
883     inputs.reserve(node->size());
884     const auto make_tuple_node = NewValueNode(prim::kPrimMakeTuple);
885     (void)inputs.emplace_back(make_tuple_node);
886     // Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items.
887     (void)inputs.insert(inputs.cend(), node->weak_inputs().cbegin() + 1, node->weak_inputs().cend());
888     auto new_node = node->func_graph()->NewCNodeWeak(std::move(inputs));
889     new_node->set_abstract(node->abstract());
890     return new_node;
891   }
892 
893   static inline const mindspore::HashMap<std::string, int64_t> sparse_attr_map = {
894     {kCSRTensorGetIndptrOpName, 0},     {kCSRTensorGetIndicesOpName, 1}, {kCSRTensorGetValuesOpName, 2},
895     {kCSRTensorGetDenseShapeOpName, 3}, {kCOOTensorGetIndicesOpName, 0}, {kCOOTensorGetValuesOpName, 1},
896     {kCOOTensorGetDenseShapeOpName, 2}, {kRowTensorGetIndicesOpName, 0}, {kRowTensorGetValuesOpName, 1},
897     {kRowTensorGetDenseShapeOpName, 2}};
898 
899   // From:
900   //   SparseTensorGetXXX(sparse) # index
901   // To:
902   //   TupleGetItem(sparse, index)
ConvertSparseGetAttrToTupleGetItem(const CNodePtr & node) const903   AnfNodePtr ConvertSparseGetAttrToTupleGetItem(const CNodePtr &node) const {
904     MS_EXCEPTION_IF_NULL(node);
905     MS_EXCEPTION_IF_NULL(node->func_graph());
906 
907     constexpr size_t kExpectInputSize = 2;
908     constexpr size_t kSparseAttrIndex = 1;
909     CheckInputsSize(node, kExpectInputSize);
910 
911     auto prim = GetValueNode<PrimitivePtr>(node->input(0));
912     if (prim != nullptr) {
913       auto iter = sparse_attr_map.find(prim->name());
914       if (iter != sparse_attr_map.end()) {
915         const auto &sparse = node->input(kSparseAttrIndex);
916         auto index_node = NewValueNode(iter->second);
917         auto new_node = node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), sparse, index_node});
918         new_node->set_abstract(node->abstract());
919         return new_node;
920       }
921     }
922     return nullptr;
923   }
924 
925   // DictGetItem --> PyExecute()
ConvertDictGetItem(const CNodePtr & cnode) const926   AnfNodePtr ConvertDictGetItem(const CNodePtr &cnode) const {
927     if (not_convert_jit_) {
928       return cnode;
929     }
930     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
931     if (!allow_fallback_runtime) {
932       MS_LOG(WARNING) << "When using the DictGetItem statement with some syntaxes that is not supported in graph mode, "
933                       << "it is best to set jit_syntax_level to LAX.\n";
934       return nullptr;
935     }
936     MS_EXCEPTION_IF_NULL(cnode);
937     // Inputs should be [dict_setitem, dict, item]
938     const size_t expect_inputs_size = 3;
939     CheckInputsSize(cnode, expect_inputs_size);
940 
941     const size_t data_index = 1;
942     const size_t item_key_index = 2;
943     const auto &inputs = cnode->inputs();
944     auto &data = inputs[data_index];
945     auto &key = inputs[item_key_index];
946     MS_EXCEPTION_IF_NULL(data);
947     MS_EXCEPTION_IF_NULL(key);
948 
949     auto func_graph = cnode->func_graph();
950     MS_EXCEPTION_IF_NULL(func_graph);
951 
952     // Script
953     std::stringstream script_buffer;
954     script_buffer << kInternalDictSelfStr << "[" << kInternalDictKeyStr << "]";
955     const std::string &script = script_buffer.str();
956     const auto script_str = std::make_shared<StringImm>(script);
957 
958     // Pack local parameters keys.
959     const auto script_dict_self_name = std::make_shared<StringImm>(kInternalDictSelfStr);
960     const auto script_dict_key_name = std::make_shared<StringImm>(kInternalDictKeyStr);
961     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
962     (void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
963     (void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
964     const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
965 
966     // Pack the local parameters values, not support list, tuple, or dict.
967     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
968     (void)key_value_list.emplace_back(data);
969     (void)key_value_list.emplace_back(key);
970     const auto key_value_tuple = func_graph->NewCNode(key_value_list);
971 
972     // Build the new dict node.
973     const auto dict_getitem_node =
974       fallback::CreatePyExecuteCNodeInOrder(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
975     auto abs_dict = GetAbstract<AbstractDictionary>(data);
976     if (abs_dict != nullptr) {
977       size_t index = GetElementIndex(abs_dict->elements(), key);
978       const auto &elements = abs_dict->elements();
979       if (elements.size() > index) {
980         const auto &val = elements[index].second;
981         const auto &tensor_val = dyn_cast<abstract::AbstractTensor>(val);
982         if (tensor_val != nullptr) {
983           const auto &tensor_type = tensor_val->element()->BuildType();
984           fallback::SetRealType<AnfNode, Type>(dict_getitem_node, tensor_val->BuildType());
985           const auto &tensor_shape = dyn_cast<abstract::Shape>(tensor_val->BuildShape());
986           MS_EXCEPTION_IF_NULL(tensor_shape);
987           fallback::SetRealShape<AnfNode, abstract::BaseShape>(dict_getitem_node, tensor_shape);
988           MS_LOG(DEBUG) << "key: " << key->abstract()->BuildValue()->ToString() << ", type: " << tensor_type->ToString()
989                         << ", shape: " << tensor_shape->ToString() << ", val: " << tensor_val->ToString();
990         }
991       }
992     }
993     MS_LOG(DEBUG) << "Made dict getitem node: " << dict_getitem_node->DebugString();
994     return dict_getitem_node;
995   }
996 
997   // DictSetItem --> PyExecute()
ConvertDictSetItem(const CNodePtr & cnode) const998   AnfNodePtr ConvertDictSetItem(const CNodePtr &cnode) const {
999     if (not_convert_jit_) {
1000       return cnode;
1001     }
1002     MS_EXCEPTION_IF_NULL(cnode);
1003     // Inputs should be [dict_setitem, dict, item, value]
1004     const size_t expect_inputs_size = 4;
1005     CheckInputsSize(cnode, expect_inputs_size);
1006 
1007     const size_t data_index = 1;
1008     const size_t item_key_index = 2;
1009     const size_t item_value_index = 3;
1010     const auto &inputs = cnode->inputs();
1011     auto &data = inputs[data_index];
1012     auto &key = inputs[item_key_index];
1013     auto &item_value = inputs[item_value_index];
1014     MS_EXCEPTION_IF_NULL(data);
1015     MS_EXCEPTION_IF_NULL(key);
1016 
1017     auto abs_dict = GetAbstract<AbstractDictionary>(data);
1018     if (abs_dict == nullptr) {
1019       return nullptr;
1020     }
1021     auto func_graph = cnode->func_graph();
1022     MS_EXCEPTION_IF_NULL(func_graph);
1023 
1024     // Script
1025     std::stringstream script_buffer;
1026     script_buffer << "__import__('mindspore').common._jit_fallback_utils.dict_setitem(" << kInternalDictSelfStr << ", "
1027                   << kInternalDictKeyStr << ", " << kInternalDictValueStr << ")";
1028     const std::string &script = script_buffer.str();
1029     const auto script_str = std::make_shared<StringImm>(script);
1030 
1031     // Pack local parameters keys.
1032     const auto script_dict_self_name = std::make_shared<StringImm>(kInternalDictSelfStr);
1033     const auto script_dict_key_name = std::make_shared<StringImm>(kInternalDictKeyStr);
1034     const auto script_dict_value_name = std::make_shared<StringImm>(kInternalDictValueStr);
1035     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1036     (void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
1037     (void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
1038     (void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
1039     const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
1040 
1041     // Pack the local parameters values, not support list, tuple, or dict.
1042     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1043     (void)key_value_list.emplace_back(data);
1044     (void)key_value_list.emplace_back(key);
1045     (void)key_value_list.emplace_back(item_value);
1046     const auto key_value_tuple = func_graph->NewCNode(key_value_list);
1047 
1048     // Build the new dict node.
1049     const auto dict_setitem_node =
1050       fallback::CreatePyExecuteCNodeInOrder(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1051     MS_LOG(DEBUG) << "Made dict setitem node: " << dict_setitem_node->DebugString();
1052     return dict_setitem_node;
1053   }
1054 
ConstructInternalTupleKeysNode(const FuncGraphPtr & fg,const AnfNodePtr & keys_node) const1055   AnfNodePtr ConstructInternalTupleKeysNode(const FuncGraphPtr &fg, const AnfNodePtr &keys_node) const {
1056     constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
1057     MS_EXCEPTION_IF_NULL(fg);
1058     const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
1059     auto dict_py_exec_key = std::make_shared<ValueTuple>(std::vector<ValuePtr>{script_key_tuple_str});
1060     auto dict_tuple_key_value = fg->NewCNode({std::make_shared<ValueNode>(prim::kPrimMakeTuple), keys_node});
1061     const auto make_key_tuple_node =
1062       fallback::CreatePyExecuteCNode(fg, NewValueNode(script_key_tuple_str), NewValueNode(dict_py_exec_key),
1063                                      dict_tuple_key_value, keys_node->debug_info());
1064     return make_key_tuple_node;
1065   }
1066 
ConstructInternalTupleValueNode(const FuncGraphPtr & fg,const AnfNodePtr & values_node) const1067   AnfNodePtr ConstructInternalTupleValueNode(const FuncGraphPtr &fg, const AnfNodePtr &values_node) const {
1068     constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
1069     MS_EXCEPTION_IF_NULL(fg);
1070     const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
1071     auto dict_py_exec_value = std::make_shared<ValueTuple>(std::vector<ValuePtr>{script_value_tuple_str});
1072     auto dict_tuple_node = fg->NewCNode({std::make_shared<ValueNode>(prim::kPrimMakeTuple), values_node});
1073     const auto make_value_tuple_node =
1074       fallback::CreatePyExecuteCNode(fg, NewValueNode(script_value_tuple_str), NewValueNode(dict_py_exec_value),
1075                                      dict_tuple_node, values_node->debug_info());
1076     return make_value_tuple_node;
1077   }
1078 
ConstructNewDictNode(const FuncGraphPtr & fg,const AnfNodePtr & make_key_tuple_node,const AnfNodePtr & make_value_tuple_node) const1079   AnfNodePtr ConstructNewDictNode(const FuncGraphPtr &fg, const AnfNodePtr &make_key_tuple_node,
1080                                   const AnfNodePtr &make_value_tuple_node) const {
1081     constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
1082     constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";
1083     // Pack the local parameters values
1084     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1085     (void)key_value_list.emplace_back(make_key_tuple_node);
1086     (void)key_value_list.emplace_back(make_value_tuple_node);
1087     const auto key_value_tuple = fg->NewCNode(key_value_list);
1088 
1089     // Pack local parameters keys.
1090     const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_zip_keys_str);
1091     const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_zip_values_str);
1092     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1093     (void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
1094     (void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
1095     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1096 
1097     // Construct Script Node
1098     std::stringstream script_buffer;
1099     script_buffer << "dict(zip(" << internal_dict_zip_keys_str << "," << internal_dict_zip_values_str << "),)";
1100     const std::string &script = script_buffer.str();
1101     const auto script_str = std::make_shared<StringImm>(script);
1102 
1103     // Build the new dict node.
1104     const auto make_dict_node = fallback::CreatePyExecuteCNodeInOrder(
1105       fg, NewValueNode(script_str), key_value_name_tuple, key_value_tuple, make_key_tuple_node->debug_info());
1106     MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
1107     return make_dict_node;
1108   }
1109 
1110   // MakeDict(keys, values) --> PyExecute('dict(zip(keys, values))', ...)
ConvertMakeDict(const CNodePtr & node) const1111   AnfNodePtr ConvertMakeDict(const CNodePtr &node) const {
1112     if (not_convert_jit_) {
1113       return node;
1114     }
1115     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1116     if (!allow_fallback_runtime) {
1117       MS_LOG(WARNING) << "When using the MakeDict statement with some syntaxes that is not supported in graph mode, "
1118                       << "it is best to set jit_syntax_level to LAX.\n";
1119       return nullptr;
1120     }
1121     const auto &fg = node->func_graph();
1122     MS_EXCEPTION_IF_NULL(fg);
1123     // Local parameters values.
1124     // Get the key tuple.
1125     constexpr size_t keys_input_index = 1;
1126     auto keys_node = node->input(keys_input_index);
1127     const auto make_key_tuple_node = ConstructInternalTupleKeysNode(fg, keys_node);
1128     make_key_tuple_node->set_debug_info(node->input(keys_input_index)->debug_info());
1129     // Get the value tuple.
1130     constexpr size_t values_input_index = 2;
1131     auto values_node = node->input(values_input_index);
1132     const auto make_value_tuple_node = ConstructInternalTupleValueNode(fg, values_node);
1133     make_value_tuple_node->set_debug_info(node->input(values_input_index)->debug_info());
1134 
1135     auto new_dict_node = ConstructNewDictNode(fg, make_key_tuple_node, make_value_tuple_node);
1136     new_dict_node->set_debug_info(node->debug_info());
1137     return new_dict_node;
1138   }
1139 
GenerateTupleInput(const CNodePtr & node) const1140   AnfNodePtr GenerateTupleInput(const CNodePtr &node) const {
1141     const auto &fg = node->func_graph();
1142     MS_EXCEPTION_IF_NULL(fg);
1143     const auto &inputs = node->inputs();
1144     constexpr auto internal_element_str_prefix = "__internal_list_element_";
1145     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1146     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1147     std::stringstream script_buffer;
1148     script_buffer << "(";
1149     for (size_t i = 1; i < inputs.size(); ++i) {
1150       if (IsValueNode<None>(inputs[i])) {
1151         script_buffer << "None, ";
1152         continue;
1153       }
1154       std::string cur_element = internal_element_str_prefix + std::to_string(i) + "_";
1155       (void)key_value_names_list.emplace_back(NewValueNode(cur_element));
1156       (void)key_value_list.emplace_back(inputs[i]);
1157       script_buffer << cur_element << ", ";
1158     }
1159     script_buffer << ")";
1160     const std::string &script = script_buffer.str();
1161     const auto script_str = std::make_shared<StringImm>(script);
1162     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1163     const auto key_value_tuple = fg->NewCNode(key_value_list);
1164     auto list_node =
1165       fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1166     return list_node;
1167   }
1168 
1169   // MakeList(x1, x2, ...) --> PyExecute('[x1, x2, ...]', ...)
ConvertMakeList(const CNodePtr & node) const1170   AnfNodePtr ConvertMakeList(const CNodePtr &node) const {
1171     if (!fallback::EnableFallbackListDictInplace()) {
1172       return nullptr;
1173     }
1174 
1175     const auto &fg = node->func_graph();
1176     MS_EXCEPTION_IF_NULL(fg);
1177 
1178     auto list_node_input = GenerateTupleInput(node);
1179 
1180     if (!fallback::HasObjInExtraInfoHolder(node->abstract())) {
1181       MS_LOG(EXCEPTION) << "MakeList node: " << node->DebugString() << " do not have python list object.";
1182     }
1183     auto object = fallback::GetObjFromExtraInfoHolder(node->abstract());
1184     if (!py::isinstance<py::list>(object)) {
1185       MS_INTERNAL_EXCEPTION(TypeError) << "For MakeList node: " << node->DebugString()
1186                                        << ", the corresponding python object should be list but got: " << object;
1187     }
1188     py::list list_object = py::list(object);
1189     const std::string list_obj_str_prefix = "__list_py_object_";
1190     auto list_obj_id = fallback::GetPyObjectPtrStr(list_object);
1191     MS_LOG(DEBUG) << "Current python object id: " << list_obj_id;
1192     auto list_obj_str = list_obj_str_prefix + list_obj_id + "_";
1193     fallback::SetPyObjectToLocalVariable(list_obj_str, list_object);
1194 
1195     const auto list_key_input = "__internal_list_key__";
1196     const auto list_value_input = "__internal_list_value__";
1197     std::stringstream script_buffer;
1198     script_buffer << "__import__('mindspore').common._jit_fallback_utils.generate_list(" << list_key_input << ", "
1199                   << list_value_input << ")";
1200     const std::string &script = script_buffer.str();
1201     const auto script_str = std::make_shared<StringImm>(script);
1202 
1203     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1204     (void)key_value_names_list.emplace_back(NewValueNode(list_key_input));
1205     (void)key_value_names_list.emplace_back(NewValueNode(list_value_input));
1206     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1207     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1208     (void)key_value_list.emplace_back(NewValueNode(list_obj_str));
1209     (void)key_value_list.emplace_back(list_node_input);
1210     const auto key_value_tuple = fg->NewCNode(key_value_list);
1211     auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1212 
1213     auto abs = node->abstract();
1214     MS_EXCEPTION_IF_NULL(abs);
1215     auto list_abs = abs->cast<abstract::AbstractListPtr>();
1216     MS_EXCEPTION_IF_NULL(list_abs);
1217 
1218     res->set_debug_info(node->debug_info());
1219 
1220     MS_LOG(DEBUG) << "Convert make_list node to PyExecute node: " << res->DebugString();
1221     return res;
1222   }
1223 
1224   // x.extend(y) --> PyExecute(_jit_fallback_list_inplace_extend(x, y))
ConvertListInplaceExtend(const CNodePtr & node) const1225   AnfNodePtr ConvertListInplaceExtend(const CNodePtr &node) const {
1226     if (!fallback::EnableFallbackListDictInplace()) {
1227       return nullptr;
1228     }
1229 
1230     const auto &fg = node->func_graph();
1231     MS_EXCEPTION_IF_NULL(fg);
1232     constexpr auto internal_list_input = "__internal_list_input__";
1233     constexpr auto internal_target_input = "__internal_target_input__";
1234     std::stringstream script_buffer;
1235     script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_extend(" << internal_list_input
1236                   << ", " << internal_target_input << ")";
1237     const std::string &script = script_buffer.str();
1238     const auto script_str = std::make_shared<StringImm>(script);
1239     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1240     (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1241     (void)key_value_names_list.emplace_back(NewValueNode(internal_target_input));
1242     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1243 
1244     const auto &node_inputs = node->inputs();
1245     constexpr size_t min_node_inputs_size = 3;
1246     constexpr size_t max_node_inputs_size = 4;
1247     size_t inputs_size = node_inputs.size();
1248     if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1249       MS_LOG(EXCEPTION) << "The size of input to ListInplaceExtend should be " << min_node_inputs_size << " or "
1250                         << max_node_inputs_size << " but got " << inputs_size;
1251     }
1252     constexpr size_t node_list_index = 1;
1253     constexpr size_t node_target_index = 2;
1254     auto list_input_node = node_inputs[node_list_index];
1255     if (IsPrimitiveCNode(list_input_node, prim::kPrimMakeList)) {
1256       TraceGuard trace_guard(std::make_shared<TraceCopy>(list_input_node->debug_info()));
1257       auto new_node = ConvertMakeList(list_input_node->cast<CNodePtr>());
1258       (void)manager_->Replace(list_input_node, new_node);
1259       list_input_node = new_node;
1260     }
1261     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1262     (void)key_value_list.emplace_back(list_input_node);
1263     (void)key_value_list.emplace_back(node_inputs[node_target_index]);
1264     const auto key_value_tuple = fg->NewCNode(key_value_list);
1265 
1266     auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1267 
1268     if (inputs_size == max_node_inputs_size) {
1269       res->add_input(node_inputs[max_node_inputs_size - 1]);
1270     }
1271     res->set_debug_info(node->debug_info());
1272 
1273     MS_LOG(DEBUG) << "Convert list inplace append node to PyExecute node: " << res->DebugString();
1274     return res;
1275   }
1276 
1277   // x.insert(index, y) --> PyExecute(_jit_fallback_list_inplace_insert(x, index, y))
ConvertDictInplaceSetItem(const CNodePtr & node) const1278   AnfNodePtr ConvertDictInplaceSetItem(const CNodePtr &node) const {
1279     if (!fallback::EnableFallbackListDictInplace()) {
1280       return nullptr;
1281     }
1282 
1283     const auto &fg = node->func_graph();
1284     MS_EXCEPTION_IF_NULL(fg);
1285     constexpr auto internal_dict_input = "__internal_dict_input__";
1286     constexpr auto internal_key_input = "__internal_key_input__";
1287     constexpr auto internal_target_input = "__internal_target_input__";
1288     std::stringstream script_buffer;
1289     script_buffer << "__import__('mindspore').common._jit_fallback_utils.dict_inplace_setitem(" << internal_dict_input
1290                   << ", " << internal_key_input << ", " << internal_target_input << ")";
1291     const std::string &script = script_buffer.str();
1292     const auto script_str = std::make_shared<StringImm>(script);
1293     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1294     (void)key_value_names_list.emplace_back(NewValueNode(internal_dict_input));
1295     (void)key_value_names_list.emplace_back(NewValueNode(internal_key_input));
1296     (void)key_value_names_list.emplace_back(NewValueNode(internal_target_input));
1297     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1298 
1299     const auto &node_inputs = node->inputs();
1300     constexpr size_t min_node_inputs_size = 4;
1301     constexpr size_t max_node_inputs_size = 5;
1302     size_t inputs_size = node_inputs.size();
1303     if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1304       MS_LOG(EXCEPTION) << "The size of input to DictInplaceSetItem should be " << min_node_inputs_size << " or "
1305                         << max_node_inputs_size << " but got " << inputs_size;
1306     }
1307     constexpr size_t node_list_index = 1;
1308     constexpr size_t node_index_index = 2;
1309     constexpr size_t node_target_index = 3;
1310     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1311     (void)key_value_list.emplace_back(node_inputs[node_list_index]);
1312     (void)key_value_list.emplace_back(node_inputs[node_index_index]);
1313     (void)key_value_list.emplace_back(node_inputs[node_target_index]);
1314     const auto key_value_tuple = fg->NewCNode(key_value_list);
1315 
1316     auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1317 
1318     if (inputs_size == max_node_inputs_size) {
1319       res->add_input(node_inputs[max_node_inputs_size - 1]);
1320     }
1321 
1322     res->set_debug_info(node->debug_info());
1323 
1324     MS_LOG(DEBUG) << "Convert dict inplace setitem node to PyExecute node: " << res->DebugString();
1325     return res;
1326   }
1327 
  // x.pop(index) --> PyExecute(_jit_fallback_list_inplace_pop(x, index))
ConvertListInplacePop(const CNodePtr & node) const1329   AnfNodePtr ConvertListInplacePop(const CNodePtr &node) const {
1330     if (!fallback::EnableFallbackListDictInplace()) {
1331       return nullptr;
1332     }
1333 
1334     const auto &fg = node->func_graph();
1335     MS_EXCEPTION_IF_NULL(fg);
1336     constexpr auto internal_list_input = "__internal_list_input__";
1337     constexpr auto internal_index_input = "__internal_index_input__";
1338     std::stringstream script_buffer;
1339     script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_pop(" << internal_list_input
1340                   << ", " << internal_index_input << ")";
1341     const std::string &script = script_buffer.str();
1342     const auto script_str = std::make_shared<StringImm>(script);
1343     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1344     (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1345     (void)key_value_names_list.emplace_back(NewValueNode(internal_index_input));
1346     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1347 
1348     const auto &node_inputs = node->inputs();
1349     constexpr size_t min_node_inputs_size = 3;
1350     constexpr size_t max_node_inputs_size = 4;
1351     size_t inputs_size = node_inputs.size();
1352     if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1353       MS_LOG(EXCEPTION) << "The size of input to ListInplacePop should be " << min_node_inputs_size << " or "
1354                         << max_node_inputs_size << " but got " << inputs_size;
1355     }
1356     constexpr size_t node_list_index = 1;
1357     constexpr size_t node_index_index = 2;
1358     auto list_input_node = node_inputs[node_list_index];
1359     if (IsPrimitiveCNode(list_input_node, prim::kPrimMakeList)) {
1360       TraceGuard trace_guard(std::make_shared<TraceCopy>(list_input_node->debug_info()));
1361       auto new_node = ConvertMakeList(list_input_node->cast<CNodePtr>());
1362       (void)manager_->Replace(list_input_node, new_node);
1363       list_input_node = new_node;
1364     }
1365     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1366     (void)key_value_list.emplace_back(list_input_node);
1367     (void)key_value_list.emplace_back(node_inputs[node_index_index]);
1368     const auto key_value_tuple = fg->NewCNode(key_value_list);
1369 
1370     auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1371 
1372     if (inputs_size == max_node_inputs_size) {
1373       res->add_input(node_inputs[max_node_inputs_size - 1]);
1374     }
1375     res->set_debug_info(node->debug_info());
1376 
1377     MS_LOG(DEBUG) << "Convert list inplace pop node to PyExecute node: " << res->DebugString();
1378     return res;
1379   }
1380 
1381   // x.reverse() --> PyExecute(_jit_fallback_list_inplace_reverse(x))
ConvertListInplaceReverse(const CNodePtr & node) const1382   AnfNodePtr ConvertListInplaceReverse(const CNodePtr &node) const {
1383     if (!fallback::EnableFallbackListDictInplace()) {
1384       return nullptr;
1385     }
1386 
1387     const auto &fg = node->func_graph();
1388     MS_EXCEPTION_IF_NULL(fg);
1389     constexpr auto internal_list_input = "__internal_list_input__";
1390     std::stringstream script_buffer;
1391     script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_reverse(" << internal_list_input
1392                   << ")";
1393     const std::string &script = script_buffer.str();
1394     const auto script_str = std::make_shared<StringImm>(script);
1395     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1396     (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1397     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1398 
1399     const auto &node_inputs = node->inputs();
1400     constexpr size_t min_node_inputs_size = 2;
1401     constexpr size_t max_node_inputs_size = 3;
1402     size_t inputs_size = node_inputs.size();
1403     if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1404       MS_LOG(EXCEPTION) << "The size of input to ListInplaceAppend should be " << min_node_inputs_size << " or "
1405                         << max_node_inputs_size << " but got " << inputs_size;
1406     }
1407     constexpr size_t node_list_index = 1;
1408     auto list_input_node = node_inputs[node_list_index];
1409     if (IsPrimitiveCNode(list_input_node, prim::kPrimMakeList)) {
1410       TraceGuard trace_guard(std::make_shared<TraceCopy>(list_input_node->debug_info()));
1411       auto new_node = ConvertMakeList(list_input_node->cast<CNodePtr>());
1412       (void)manager_->Replace(list_input_node, new_node);
1413       list_input_node = new_node;
1414     }
1415     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1416     (void)key_value_list.emplace_back(list_input_node);
1417     const auto key_value_tuple = fg->NewCNode(key_value_list);
1418     auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1419 
1420     if (inputs_size == max_node_inputs_size) {
1421       res->add_input(node_inputs[max_node_inputs_size - 1]);
1422     }
1423     res->set_debug_info(node->debug_info());
1424 
1425     MS_LOG(DEBUG) << "Convert list inplace reverse node to PyExecute node: " << res->DebugString();
1426     return res;
1427   }
1428 
1429   // x.clear() --> PyExecute(_jit_fallback_list_inplace_clear(x))
ConvertListInplaceClear(const CNodePtr & node) const1430   AnfNodePtr ConvertListInplaceClear(const CNodePtr &node) const {
1431     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1432     if (!allow_fallback_runtime) {
1433       return nullptr;
1434     }
1435     static const auto allow_inplace_ops = common::GetCompileConfig("FALLBACK_SUPPORT_LIST_DICT_INPLACE") == "1";
1436     if (!allow_inplace_ops) {
1437       return nullptr;
1438     }
1439 
1440     const auto &fg = node->func_graph();
1441     MS_EXCEPTION_IF_NULL(fg);
1442     constexpr auto internal_list_input = "__internal_list_input__";
1443     std::stringstream script_buffer;
1444     script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_clear(" << internal_list_input
1445                   << ")";
1446     const std::string &script = script_buffer.str();
1447     const auto script_str = std::make_shared<StringImm>(script);
1448     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1449     (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1450     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1451 
1452     const auto &node_inputs = node->inputs();
1453     constexpr size_t node_inputs_size = 2;
1454     if (node_inputs.size() != node_inputs_size) {
1455       MS_LOG(EXCEPTION) << "The size of input to ListInplaceClear should be " << node_inputs_size << " but got "
1456                         << node_inputs.size();
1457     }
1458     constexpr size_t node_list_index = 1;
1459     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1460     (void)key_value_list.emplace_back(node_inputs[node_list_index]);
1461     const auto key_value_tuple = fg->NewCNode(key_value_list);
1462 
1463     auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1464     res->set_debug_info(node->debug_info());
1465 
1466     MS_LOG(DEBUG) << "Convert list inplace clear node to PyExecute node: " << res->DebugString();
1467     return res;
1468   }
1469 
  // x.insert(index, y) --> PyExecute(_jit_fallback_list_inplace_insert(x, index, y))
ConvertListInplaceInsert(const CNodePtr & node) const1471   AnfNodePtr ConvertListInplaceInsert(const CNodePtr &node) const {
1472     if (!fallback::EnableFallbackListDictInplace()) {
1473       return nullptr;
1474     }
1475 
1476     const auto &fg = node->func_graph();
1477     MS_EXCEPTION_IF_NULL(fg);
1478     constexpr auto internal_list_input = "__internal_list_input__";
1479     constexpr auto internal_index_input = "__internal_index_input__";
1480     constexpr auto internal_target_input = "__internal_target_input__";
1481     std::stringstream script_buffer;
1482     script_buffer << "__import__('mindspore').common._jit_fallback_utils.list_inplace_insert(" << internal_list_input
1483                   << ", " << internal_index_input << ", " << internal_target_input << ")";
1484     const std::string &script = script_buffer.str();
1485     const auto script_str = std::make_shared<StringImm>(script);
1486     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1487     (void)key_value_names_list.emplace_back(NewValueNode(internal_list_input));
1488     (void)key_value_names_list.emplace_back(NewValueNode(internal_index_input));
1489     (void)key_value_names_list.emplace_back(NewValueNode(internal_target_input));
1490     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1491 
1492     const auto &node_inputs = node->inputs();
1493     constexpr size_t min_node_inputs_size = 4;
1494     constexpr size_t max_node_inputs_size = 5;
1495     size_t inputs_size = node_inputs.size();
1496     if (inputs_size != min_node_inputs_size && inputs_size != max_node_inputs_size) {
1497       MS_LOG(EXCEPTION) << "The size of input to ListInplaceInsert should be " << min_node_inputs_size << " or "
1498                         << max_node_inputs_size << " but got " << inputs_size;
1499     }
1500     constexpr size_t node_list_index = 1;
1501     constexpr size_t node_index_index = 2;
1502     constexpr size_t node_target_index = 3;
1503     auto list_input_node = node_inputs[node_list_index];
1504     if (IsPrimitiveCNode(list_input_node, prim::kPrimMakeList)) {
1505       TraceGuard trace_guard(std::make_shared<TraceCopy>(list_input_node->debug_info()));
1506       auto new_node = ConvertMakeList(list_input_node->cast<CNodePtr>());
1507       (void)manager_->Replace(list_input_node, new_node);
1508       list_input_node = new_node;
1509     }
1510     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1511     (void)key_value_list.emplace_back(list_input_node);
1512     (void)key_value_list.emplace_back(node_inputs[node_index_index]);
1513     (void)key_value_list.emplace_back(node_inputs[node_target_index]);
1514     const auto key_value_tuple = fg->NewCNode(key_value_list);
1515 
1516     auto res = fallback::CreatePyExecuteCNode(node, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1517 
1518     if (inputs_size == max_node_inputs_size) {
1519       res->add_input(node_inputs[max_node_inputs_size - 1]);
1520     }
1521     res->set_debug_info(node->debug_info());
1522 
1523     MS_LOG(DEBUG) << "Convert list inplace insert node to PyExecute node: " << res->DebugString();
1524     return res;
1525   }
1526 
1527   // TupleGetItem/ListGetItem(sequence, index) -> PyExecute(sequence[index], ...)
ConvertSequenceGetItem(const CNodePtr & node) const1528   AnfNodePtr ConvertSequenceGetItem(const CNodePtr &node) const { return ConvertSequenceGetItemInner(node); }
1529 
1530   // raise(string, keys, values, io) --> PyExecute(string, keys, values, io)
  // Rewrite a raise CNode into a PyExecute node whose script calls
  // mindspore.common._utils._jit_fallback_raise_func(<exception_type>,<exception_string>).
  //
  // Input layout as used below (n = inputs.size()):
  //   inputs[1]          - node whose abstract is handed to GetExceptionType
  //   inputs[2 .. n-3]   - the message parts, converted one by one
  //   inputs[n-2]        - second-to-last input, also handed to GetExceptionType
  //   inputs[n-1]        - last input, appended unchanged to the new PyExecute node
  //
  // Returns nullptr (after logging a warning) when the JIT syntax level is below kCompatible.
  // Raises via MS_EXCEPTION_IF_NULL when cnode, its func graph, an input abstract, the old abstract
  // or its type is null.
  // NOTE(review): the input count is not validated here; the index arithmetic presumably relies on
  // at least 4 inputs being present - confirm against the producer of raise nodes.
  AnfNodePtr ConvertRaise(const CNodePtr &cnode) const {
    const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
    if (!allow_fallback_runtime) {
      MS_LOG(WARNING) << "When using the raise statement, it is best to set jit_syntax_level to LAX, "
                      << "because there is no the real raise operator.\n";
      return nullptr;
    }
    MS_EXCEPTION_IF_NULL(cnode);
    const auto &fg = cnode->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    MS_LOG(DEBUG) << "Raise node: " << cnode->DebugString();
    const auto &inputs = cnode->inputs();
    // key_value collects the script-local names (keys) and their bound nodes (values); both lists
    // start with the MakeTuple primitive so they can later be turned into tuple CNodes directly.
    std::shared_ptr<raiseutils::KeyValueInfo> key_value = std::make_shared<raiseutils::KeyValueInfo>();
    key_value->keys = {NewValueNode(prim::kPrimMakeTuple)};
    key_value->values = {NewValueNode(prim::kPrimMakeTuple)};
    // Message parts occupy the half-open index range [index_begin, index_end).
    size_t index_begin = 2;
    constexpr auto end_num = 2;
    size_t index_end = inputs.size() - end_num;
    // With exactly 4 inputs the range above is empty, i.e. the raise carries no message argument.
    size_t size_if_empty = 4;
    std::string exception_type = raiseutils::GetExceptionType(inputs[1]->abstract(), inputs[index_end], key_value);
    std::string exception_string;
    // Process raise ValueError()
    if (inputs.size() == size_if_empty) {
      // Bind a generated key to the empty string and use that key as the whole message expression.
      std::string key = raiseutils::MakeRaiseKey(key_value->num_str);
      (void)key_value->keys.emplace_back(NewValueNode(std::make_shared<StringImm>(key)));
      (void)key_value->values.emplace_back(NewValueNode(std::make_shared<StringImm>("")));
      exception_string = key;
    }
    // Processed in units of nodes. Raise ValueError(xxxx)
    for (size_t index = index_begin; index < index_end; ++index) {
      const auto input = inputs[index];
      auto input_abs = input->abstract();
      MS_EXCEPTION_IF_NULL(input_abs);
      // When CheckNeedSymbol says so, the generated text for this part is wrapped in single quotes.
      const bool need_symbol = raiseutils::CheckNeedSymbol(input_abs);
      if (need_symbol) {
        exception_string += "'";
      }
      bool need_comma = !IsPrimitiveCNode(input, prim::kPrimMakeTuple);
      exception_string += raiseutils::GetExceptionString(input_abs, input, key_value, need_symbol, need_comma);
      if (need_symbol) {
        exception_string += "'";
      }
      // NOTE(review): index < inputs.size() - 2 inside this loop, so this condition always holds and
      // ", " is appended after every part, including the last one - confirm this is intended.
      if (index != inputs.size() - 1) {
        exception_string += ", ";
      }
    }
    // More than 5 inputs means at least two message parts; they are wrapped in parentheses.
    bool need_out_symbol = inputs.size() > 5;
    if (need_out_symbol) {
      exception_string = "(" + exception_string + ")";
    }
    // Condition has variable but script does not.
    // keys.size() <= 1 means only the leading MakeTuple is present, i.e. no key was registered so far;
    // the whole message text is then bound to a generated key and the script refers to that key.
    if (key_value->keys.size() <= 1) {
      std::string key = raiseutils::MakeRaiseKey(key_value->num_str);
      (void)key_value->keys.emplace_back(NewValueNode(std::make_shared<StringImm>(key)));
      (void)key_value->values.emplace_back(NewValueNode(std::make_shared<StringImm>(exception_string)));
      exception_string = key;
    }
    // Build PyExecute node for raise
    const std::string error_msg =
      "__import__('mindspore').common._utils._jit_fallback_raise_func(" + exception_type + "," + exception_string + ")";
    const auto script_str = std::make_shared<StringImm>(error_msg);
    // Pack local parameter keys
    const auto key_value_name_tuple = fg->NewCNodeInOrder(key_value->keys);
    // Pack local parameter values
    const auto key_value_tuple = fg->NewCNodeInOrder(key_value->values);
    // Build the PyExecute node for raise error.
    const auto raise_pyexecute_node = fallback::CreatePyExecuteCNodeInOrder(
      fg, NewValueNode(script_str), key_value_name_tuple, key_value_tuple, cnode->debug_info());
    // Carry over the last input of the original raise node.
    raise_pyexecute_node->add_input(inputs[inputs.size() - 1]);
    // Record the type built from the original node's abstract on the new node.
    auto old_abs = cnode->abstract();
    MS_EXCEPTION_IF_NULL(old_abs);
    const auto &type = old_abs->BuildType();
    MS_EXCEPTION_IF_NULL(type);
    fallback::SetRealType(raise_pyexecute_node, type);
    MS_LOG(DEBUG) << "Raise convert to PyExecute node: " << raise_pyexecute_node->DebugString();
    return raise_pyexecute_node;
  }
1608 
1609   // ScalarCast(x, dtype) --> PyExecute(string, keys, values)
ConvertScalarCast(const CNodePtr & cnode) const1610   AnfNodePtr ConvertScalarCast(const CNodePtr &cnode) const {
1611     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1612     if (!allow_fallback_runtime) {
1613       MS_LOG(WARNING) << "When using the ScalarCast statement with some syntaxes that is not supported in graph mode, "
1614                       << "it is best to set jit_syntax_level to LAX.\n";
1615       return nullptr;
1616     }
1617     constexpr size_t x_index = 1;
1618     constexpr size_t dtype_index = 2;
1619     auto x_node = cnode->input(x_index);
1620     auto dtype_node = cnode->input(dtype_index);
1621     auto x_abs = GetAbstract<abstract::AbstractAny>(x_node);
1622     if (x_abs == nullptr) {
1623       return nullptr;
1624     }
1625     auto dtype_abs = GetAbstract<abstract::AbstractScalar>(dtype_node);
1626     MS_EXCEPTION_IF_NULL(dtype_abs);
1627     auto dtype_val = dtype_abs->GetValue();
1628     MS_EXCEPTION_IF_NULL(dtype_val);
1629     auto type_id_opt = ops::GetScalarValue<int64_t>(dtype_val);
1630     if (!type_id_opt.has_value()) {
1631       MS_LOG(EXCEPTION) << "the dtype input is invalid!";
1632     }
1633     std::string target_type_str;
1634     auto type_id = type_id_opt.value();
1635     if (type_id == kNumberTypeInt) {
1636       target_type_str = "int";
1637     } else if (type_id == kNumberTypeFloat) {
1638       target_type_str = "float";
1639     } else if (type_id == kNumberTypeBool) {
1640       target_type_str = "bool";
1641     } else {
1642       MS_LOG(EXCEPTION) << "Unsupported type: " << type_id;
1643     }
1644 
1645     const auto &fg = cnode->func_graph();
1646     MS_EXCEPTION_IF_NULL(fg);
1647     std::string internal_scalar_arg_str = "__internal_scalar_arg__";
1648     std::string script = target_type_str + "(" + internal_scalar_arg_str + ")";
1649     auto script_node = NewValueNode(std::make_shared<StringImm>(script));
1650     auto arg_name_node = NewValueNode(std::make_shared<StringImm>(internal_scalar_arg_str));
1651     auto keys_tuple_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), arg_name_node});
1652     auto values_tuple_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), x_node});
1653     keys_tuple_node->set_debug_info(cnode->debug_info());
1654     values_tuple_node->set_debug_info(cnode->debug_info());
1655     auto scalar_cast_node =
1656       fallback::CreatePyExecuteCNodeInOrder(cnode, script_node, keys_tuple_node, values_tuple_node);
1657     MS_LOG(DEBUG) << "Convert CastToScalar: " << cnode->DebugString() << " -> " << scalar_cast_node->DebugString();
1658     return scalar_cast_node;
1659   }
1660 
ConvertMakeSlice(const CNodePtr & cnode) const1661   AnfNodePtr ConvertMakeSlice(const CNodePtr &cnode) const {
1662     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1663     if (!allow_fallback_runtime) {
1664       MS_LOG(WARNING) << "When using the MakeSlice statement with some syntaxes that is not supported in graph mode, "
1665                       << "it is best to set jit_syntax_level to LAX.\n";
1666       return nullptr;
1667     }
1668     MS_EXCEPTION_IF_NULL(cnode);
1669     const auto &fg = cnode->func_graph();
1670     MS_EXCEPTION_IF_NULL(fg);
1671     MS_LOG(DEBUG) << " make_slice node: " << cnode->DebugString();
1672     constexpr size_t slice_size = 4;
1673     if (cnode->size() != slice_size) {
1674       MS_LOG(INTERNAL_EXCEPTION) << "The size of input to make_slice should be " << slice_size << ", but got "
1675                                  << cnode->size();
1676     }
1677     constexpr size_t start_index = 1;
1678     constexpr size_t stop_index = 2;
1679     constexpr size_t step_index = 3;
1680     bool is_start_none = IsValueNode<None>(cnode->input(start_index));
1681     bool is_stop_none = IsValueNode<None>(cnode->input(stop_index));
1682     bool is_step_none = IsValueNode<None>(cnode->input(step_index));
1683     auto start_str = is_start_none ? "None" : "__start__";
1684     auto stop_str = is_stop_none ? "None" : "__stop__";
1685     auto step_str = is_step_none ? "None" : "__step__";
1686     // Script
1687     std::stringstream script_buffer;
1688     script_buffer << "slice(" << start_str << ", " << stop_str << ", " << step_str << ")";
1689     const std::string &script = script_buffer.str();
1690     const auto script_str = std::make_shared<StringImm>(script);
1691 
1692     // Pack local parameters keys and values.
1693     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1694     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1695     if (!is_start_none) {
1696       (void)key_value_names_list.emplace_back(NewValueNode(start_str));
1697       (void)key_value_list.emplace_back(cnode->input(start_index));
1698     }
1699     if (!is_stop_none) {
1700       (void)key_value_names_list.emplace_back(NewValueNode(stop_str));
1701       (void)key_value_list.emplace_back(cnode->input(stop_index));
1702     }
1703     if (!is_step_none) {
1704       (void)key_value_names_list.emplace_back(NewValueNode(step_str));
1705       (void)key_value_list.emplace_back(cnode->input(step_index));
1706     }
1707     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1708     const auto key_value_tuple = fg->NewCNode(key_value_list);
1709 
1710     // Build the new slice node.
1711     const auto slice_node =
1712       fallback::CreatePyExecuteCNodeInOrder(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1713     MS_LOG(DEBUG) << "Made slice node: " << slice_node->DebugString();
1714     return slice_node;
1715   }
1716 
1717   // Only process the node that have a PyExecute node(the abstract is AbstractAny).
CheckInputsHasAnyType(const CNodePtr & cnode) const1718   bool CheckInputsHasAnyType(const CNodePtr &cnode) const {
1719     bool exist_any_type = false;
1720     for (const auto &weak_input : cnode->weak_inputs()) {
1721       auto input = weak_input.lock();
1722       auto input_abs = input->abstract();
1723       if (fallback::ContainsSequenceAnyType(input_abs)) {
1724         exist_any_type = true;
1725         break;
1726       }
1727     }
1728     return exist_any_type;
1729   }
1730 
ConvertIsInstance(const CNodePtr & cnode) const1731   AnfNodePtr ConvertIsInstance(const CNodePtr &cnode) const {
1732     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1733     if (!allow_fallback_runtime) {
1734       MS_LOG(WARNING) << "When using the isinstance statement, it is best to set jit_syntax_level to LAX, "
1735                       << "because there is no the real isinstance operator.\n";
1736       return nullptr;
1737     }
1738     const auto &fg = cnode->func_graph();
1739     MS_EXCEPTION_IF_NULL(fg);
1740     if (!CheckInputsHasAnyType(cnode)) {
1741       return nullptr;
1742     }
1743     const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1744     MS_EXCEPTION_IF_NULL(prim);
1745     string name = prim->name();
1746     auto pyexecute_node = fallback::ConvertCNodeToPyExecuteForPrim(cnode, name);
1747     MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << pyexecute_node->DebugString();
1748     return pyexecute_node;
1749   }
1750 
  // JoinedStr(XXXXXX)
  // TO
  // A = PyExecute("list(map(str, __inner_convert_object__))", ("__inner_convert_object__",), ((XXXXXX,),))
  // B = PyExecute("\"\".join(__inner_str_list__)", ("__inner_str_list__",), (A,)).
  // replace(B --> JoinedStr)
ConvertJoinedStr(const CNodePtr & cnode) const1756   AnfNodePtr ConvertJoinedStr(const CNodePtr &cnode) const {
1757     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1758     if (!allow_fallback_runtime) {
1759       MS_LOG(WARNING) << "When using the JoinedStr statement, it is best to set jit_syntax_level to LAX, "
1760                       << "because there is no the real JoinedStr operator.\n";
1761       return nullptr;
1762     }
1763     MS_EXCEPTION_IF_NULL(cnode);
1764     const auto &fg = cnode->func_graph();
1765     MS_EXCEPTION_IF_NULL(fg);
1766     MS_LOG(DEBUG) << " make_slice node: " << cnode->DebugString();
1767     // Convert all node to list[str]
1768     constexpr auto kConvertToListString = "list(map(str, __inner_convert_object__))";
1769     constexpr auto kConvertToListKey = "__inner_convert_object__";
1770     const auto make_tuple_value_node = NewValueNode(prim::kPrimMakeTuple);
1771     AnfNodeWeakPtrList list_str_value_list = {make_tuple_value_node};
1772     (void)std::copy(cnode->weak_inputs().cbegin() + 1, cnode->weak_inputs().cend(),
1773                     std::back_inserter(list_str_value_list));
1774 
1775     const auto make_tuple_key_node = NewValueNode(prim::kPrimMakeTuple);
1776     const auto key_node = NewValueNode(kConvertToListKey);
1777     AnfNodeWeakPtrList list_str_key_list = {make_tuple_key_node, key_node};
1778     auto list_str_key_node = fg->NewCNodeWeak(list_str_key_list);
1779     auto list_str_value_node = fg->NewCNodeWeak(list_str_value_list);
1780     auto convet_list_str_node = fallback::CreatePyExecuteCNodeInOrder(
1781       fg, NewValueNode(kConvertToListString), list_str_key_node,
1782       fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), list_str_value_node}), cnode->debug_info());
1783 
1784     // change to string.
1785     constexpr auto eval_string_script = "\"\".join(__inner_str_list__)";
1786     constexpr auto eval_key_string = "__inner_str_list__";
1787     auto eval_key_node = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(eval_key_string)});
1788     auto eval_value_node = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), convet_list_str_node});
1789 
1790     auto joined_result_node = fallback::CreatePyExecuteCNode(fg, NewValueNode(eval_string_script), eval_key_node,
1791                                                              eval_value_node, cnode->debug_info());
1792     return joined_result_node;
1793   }
1794 
HasPyExecuteInput(const CNodePtr & cnode) const1795   bool HasPyExecuteInput(const CNodePtr &cnode) const {
1796     MS_EXCEPTION_IF_NULL(cnode);
1797     const auto &inputs = cnode->inputs();
1798     for (auto &input : inputs) {
1799       if (IsPrimitiveCNode(input, prim::kPrimPyExecute)) {
1800         return true;
1801       }
1802     }
1803     return false;
1804   }
1805 
ConvertPrint(const CNodePtr & cnode) const1806   AnfNodePtr ConvertPrint(const CNodePtr &cnode) const {
1807     const auto &fg = cnode->func_graph();
1808     MS_EXCEPTION_IF_NULL(fg);
1809     if (!CheckInputsHasAnyType(cnode) && !HasPyExecuteInput(cnode)) {
1810       return nullptr;
1811     }
1812     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1813     if (!allow_fallback_runtime) {
1814       MS_LOG(WARNING) << "When using the print statement with some syntaxes that is not supported in graph mode, "
1815                       << "it is best to set jit_syntax_level to LAX.\n";
1816       return nullptr;
1817     }
1818     // Skip the io_monad input
1819     auto inputs = cnode->inputs();
1820     if (!HasAbstractMonad(inputs.back())) {
1821       MS_LOG(EXCEPTION) << "The print node has no monad input:" << cnode->DebugString();
1822     }
1823     inputs.pop_back();
1824     auto no_io_print = fg->NewCNode(inputs);
1825     auto pyexecute_node = fallback::ConvertCNodeToPyExecuteForPrim(no_io_print, "print");
1826 
1827     // Add io_monad input
1828     auto new_pyexecute_inputs = pyexecute_node->cast<CNodePtr>()->inputs();
1829     (void)new_pyexecute_inputs.emplace_back(cnode->inputs().back());
1830     auto new_pyexecute_node = fg->NewCNode(new_pyexecute_inputs);
1831     MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << new_pyexecute_node->DebugString();
1832     return new_pyexecute_node;
1833   }
  // Format(str, XXXX) is converted to PyExecute.
  // First split XXXX: keyword arguments (KWargs) go into a dict, all other arguments are pushed into a tuple.
  // A = MakeDict(XXXX[KWargs]->keys(), XXXX[KWargs]->values()) --> this dict is converted to PyExecute by
  //     ConvertMakeDict.
  // B = Tuple(XXXX - XXXX[KWargs]), where "-" is the set difference.
  // C = PyExecute("__inner_str__.format(*__format_list_str__, **__format_kwargs__str__)",
  //               (__inner_str__, __format_list_str__, __format_kwargs__str__), (str, B, A));
  // Replace(C -> Format).
ConvertFormat(const CNodePtr & cnode) const1841   AnfNodePtr ConvertFormat(const CNodePtr &cnode) const {
1842     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1843     if (!allow_fallback_runtime) {
1844       MS_LOG(WARNING) << "When using the format statement with some syntaxes that is not supported in graph mode, "
1845                       << "it is best to set jit_syntax_level to LAX.\n";
1846       return nullptr;
1847     }
1848     auto fg = cnode->func_graph();
1849     MS_EXCEPTION_IF_NULL(fg);
1850 
1851     std::vector<AnfNodePtr> format_list = {NewValueNode(prim::kPrimMakeTuple)};
1852 
1853     std::vector<AnfNodePtr> kwargs_keys_node = {NewValueNode(prim::kPrimMakeTuple)};
1854     std::vector<AnfNodePtr> kwargs_values_node = {NewValueNode(prim::kPrimMakeTuple)};
1855     auto inputs = cnode->inputs();
1856     constexpr auto kFormatArgsIndex = 2;
1857     constexpr auto kStringArgsIndex = 1;
1858     for (size_t i = kFormatArgsIndex; i < inputs.size(); ++i) {
1859       auto input = inputs[i];
1860       MS_EXCEPTION_IF_NULL(input);
1861       auto abs = input->abstract();
1862       if (abs != nullptr && abs->isa<abstract::AbstractKeywordArg>()) {
1863         auto [key, arg] = ExtractKwargsNode(input);
1864         (void)kwargs_keys_node.emplace_back(key);
1865         (void)kwargs_values_node.emplace_back(arg);
1866       } else {
1867         format_list.emplace_back(inputs[i]);
1868       }
1869     }
1870     // Construct kwargs node
1871     auto dict_key_node = fg->NewCNode(kwargs_keys_node);
1872     dict_key_node->set_debug_info(cnode->debug_info());
1873     auto dict_value_node = fg->NewCNode(kwargs_values_node);
1874     dict_value_node->set_debug_info(cnode->debug_info());
1875     auto dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), dict_key_node, dict_value_node});
1876     dict_node->set_debug_info(cnode->debug_info());
1877     auto py_exec_dict_node = ConvertMakeDict(dict_node);
1878     // Construct list args node
1879     auto list_node = fg->NewCNode(format_list);
1880     list_node->set_debug_info(cnode->debug_info());
1881     // Construct PyExecute node
1882     constexpr auto inner_str = "__inner_str__";
1883     constexpr auto format_list_str = "__format_list_str__";
1884     constexpr auto format_kwargs_str = "__format_kwargs__str__";
1885     std::stringstream script_buffer;
1886     script_buffer << inner_str << ".format(*" << format_list_str << ", **" << format_kwargs_str << ")";
1887 
1888     std::vector<ValuePtr> key_values = {MakeValue(inner_str), MakeValue(format_list_str), MakeValue(format_kwargs_str)};
1889     auto intrepret_node_keys = NewValueNode(std::make_shared<ValueTuple>(key_values));
1890     auto intrepert_node_values =
1891       fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), inputs.at(kStringArgsIndex), list_node, py_exec_dict_node});
1892     intrepert_node_values->set_debug_info(cnode->debug_info());
1893     auto convert_node = fallback::CreatePyExecuteCNode(fg, NewValueNode(MakeValue(script_buffer.str())),
1894                                                        intrepret_node_keys, intrepert_node_values, cnode->debug_info());
1895     return convert_node;
1896   }
1897 
ConvertMakeRange(const CNodePtr & cnode) const1898   AnfNodePtr ConvertMakeRange(const CNodePtr &cnode) const {
1899     const auto &fg = cnode->func_graph();
1900     MS_EXCEPTION_IF_NULL(fg);
1901     if (!CheckInputsHasAnyType(cnode) && !HasPyExecuteInput(cnode)) {
1902       return nullptr;
1903     }
1904     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1905     if (!allow_fallback_runtime) {
1906       MS_LOG(WARNING) << "When using the range statement with some syntaxes that is not supported in graph mode, "
1907                       << "it is best to set jit_syntax_level to LAX.\n";
1908       return nullptr;
1909     }
1910     auto pyexecute_node = fallback::ConvertCNodeToPyExecuteForPrim(cnode, "range");
1911     MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << pyexecute_node->DebugString();
1912     return pyexecute_node;
1913   }
1914 
ConvertIsAndIsNot(const CNodePtr & cnode,bool is) const1915   AnfNodePtr ConvertIsAndIsNot(const CNodePtr &cnode, bool is) const {
1916     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
1917     if (!allow_fallback_runtime) {
1918       MS_LOG(WARNING) << "When using the is/is_not statement with some syntaxes that is not supported in graph mode, "
1919                       << "it is best to set jit_syntax_level to LAX.\n";
1920       return nullptr;
1921     }
1922     const auto &fg = cnode->func_graph();
1923     MS_EXCEPTION_IF_NULL(fg);
1924 
1925     constexpr auto data_str = "__data__";
1926     constexpr auto target_str = "__target__";
1927     std::stringstream script_buffer;
1928     script_buffer << data_str;
1929     if (is) {
1930       script_buffer << " is ";
1931     } else {
1932       script_buffer << " is not ";
1933     }
1934     script_buffer << target_str;
1935     const std::string &script = script_buffer.str();
1936     const auto script_str = std::make_shared<StringImm>(script);
1937     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
1938     (void)key_value_names_list.emplace_back(NewValueNode(data_str));
1939     (void)key_value_names_list.emplace_back(NewValueNode(target_str));
1940     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
1941 
1942     const auto &node_inputs = cnode->inputs();
1943     constexpr size_t inputs_size = 3;
1944     if (node_inputs.size() != inputs_size) {
1945       MS_LOG(INTERNAL_EXCEPTION) << "The size of input to kPrimIs should be " << inputs_size << "but got "
1946                                  << node_inputs.size();
1947     }
1948     constexpr size_t data_index = 1;
1949     constexpr size_t target_index = 2;
1950     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
1951     (void)key_value_list.emplace_back(node_inputs[data_index]);
1952     (void)key_value_list.emplace_back(node_inputs[target_index]);
1953     const auto key_value_tuple = fg->NewCNode(key_value_list);
1954 
1955     auto res = fallback::CreatePyExecuteCNode(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
1956     res->set_debug_info(cnode->debug_info());
1957     return res;
1958   }
1959 
ConvertIs_(const CNodePtr & cnode) const1960   AnfNodePtr ConvertIs_(const CNodePtr &cnode) const {
1961     auto res = ConvertIsAndIsNot(cnode, true);
1962     MS_LOG(DEBUG) << "Convert primitive Is_ to PyExecute node: " << res->DebugString();
1963     return res;
1964   }
1965 
ConvertIsNot(const CNodePtr & cnode) const1966   AnfNodePtr ConvertIsNot(const CNodePtr &cnode) const {
1967     auto res = ConvertIsAndIsNot(cnode, false);
1968     MS_LOG(DEBUG) << "Convert primitive IsNot to PyExecute node: " << res->DebugString();
1969     return res;
1970   }
1971 
  // Pointer to a const member function that converts one primitive CNode and returns the replacement
  // node. The converters visible in this file return nullptr when no conversion applies.
  using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &) const;
  // Primitive -> converter table, hashed and compared with PrimitiveHasher / PrimitiveEqual.
  using ConverterMap = std::unordered_map<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
  // Dispatch table consulted by ConvertPrimitiveCNode: the primitive at input 0 of a CNode selects
  // the converter that is invoked on that CNode.
  static inline const ConverterMap converters_{
    // SparseProcess: 1.MakeSparse->MakeTuple 2.SparseGetAttr->TupleGetItem
    {prim::kPrimMakeRowTensor, &ThisClass::ConvertMakeSparseToMakeTuple},
    {prim::kPrimRowTensorGetIndices, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    {prim::kPrimRowTensorGetValues, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    {prim::kPrimRowTensorGetDenseShape, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    {prim::kPrimMakeCSRTensor, &ThisClass::ConvertMakeSparseToMakeTuple},
    {prim::kPrimCSRTensorGetIndptr, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    {prim::kPrimCSRTensorGetIndices, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    {prim::kPrimCSRTensorGetValues, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    {prim::kPrimCSRTensorGetDenseShape, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    {prim::kPrimMakeCOOTensor, &ThisClass::ConvertMakeSparseToMakeTuple},
    {prim::kPrimCOOTensorGetIndices, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    {prim::kPrimCOOTensorGetValues, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    {prim::kPrimCOOTensorGetDenseShape, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
    // Dict / list / sequence access and in-place mutation primitives.
    {prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItem},
    {prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItem},
    {prim::kPrimListInplaceExtend, &ThisClass::ConvertListInplaceExtend},
    {prim::kPrimListInplaceInsert, &ThisClass::ConvertListInplaceInsert},
    {prim::kPrimListInplacePop, &ThisClass::ConvertListInplacePop},
    {prim::kPrimListInplaceReverse, &ThisClass::ConvertListInplaceReverse},
    {prim::kPrimListInplaceClear, &ThisClass::ConvertListInplaceClear},
    {prim::kPrimDictInplaceSetItem, &ThisClass::ConvertDictInplaceSetItem},
    {prim::kPrimListGetItem, &ThisClass::ConvertSequenceGetItem},
    {prim::kPrimTupleGetItem, &ThisClass::ConvertSequenceGetItem},
    {prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
    // Remaining primitives, each with a dedicated converter.
    {prim::kPrimRaise, &ThisClass::ConvertRaise},
    {prim::kPrimScalarCast, &ThisClass::ConvertScalarCast},
    {prim::kPrimMakeSlice, &ThisClass::ConvertMakeSlice},
    {prim::kPrimIsInstance, &ThisClass::ConvertIsInstance},
    {prim::kPrimJoinedStr, &ThisClass::ConvertJoinedStr},
    {prim::kPrimPrint, &ThisClass::ConvertPrint},
    {prim::kPrimFormat, &ThisClass::ConvertFormat},
    {prim::kPrimMakeRange, &ThisClass::ConvertMakeRange},
    {prim::kPrimIs_, &ThisClass::ConvertIs_},
    {prim::kPrimIsNot, &ThisClass::ConvertIsNot}};
2010 
  // Sequence primitives that have no entry in converters_. ConvertPrimitiveCNode routes them to
  // ConvertSequenceOps, and ConvertValueInputToPyExecute uses this set to decide whether the value
  // inputs of such a node should be converted.
  static inline const PrimitiveSet seq_prim_set_{
    prim::kPrimInSequence,      prim::kPrimSequenceMul,       prim::kPrimSequenceCount,    prim::kPrimSequenceIndex,
    prim::kPrimSequenceLen,     prim::kPrimListEqual,         prim::kPrimTupleEqual,       prim::kPrimTupleGreaterThan,
    prim::kPrimListLessEqual,   prim::kPrimTupleLessThan,     prim::kPrimListLessThan,     prim::kPrimTupleLessEqual,
    prim::kPrimListGreaterThan, prim::kPrimTupleGreaterEqual, prim::kPrimListGreaterEqual, prim::kPrimSequenceSlice};
2016 
2017   // Convert ValueNode<None> to PyExecute("None", ("None"), ("None")).
ConvertNoneToPyExecute(const FuncGraphPtr & func_graph)2018   AnfNodePtr ConvertNoneToPyExecute(const FuncGraphPtr &func_graph) {
2019     MS_EXCEPTION_IF_NULL(func_graph);
2020     auto str_value = std::make_shared<StringImm>("None");
2021     auto script_node = NewValueNode(str_value);
2022 
2023     std::vector<ValuePtr> none_value{str_value};
2024     const auto none_tuple = std::make_shared<ValueTuple>(none_value);
2025     auto none_tuple_node = NewValueNode(none_tuple);
2026     AbstractBasePtrList abs_list{std::make_shared<abstract::AbstractScalar>(MakeValue("None"))};
2027     none_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
2028 
2029     AnfNodePtr none_execute_node = fallback::CreatePyExecuteCNodeInOrder(
2030       func_graph, script_node, none_tuple_node, none_tuple_node, none_tuple_node->debug_info());
2031     MS_LOG(DEBUG) << "none_execute_node:" << none_execute_node->DebugString();
2032 
2033     set_need_renormalized(true);
2034     return none_execute_node;
2035   }
2036 
GetPyExecuteFromValueSequence(const FuncGraphPtr & fg,const ValueNodePtr & value_node,const ValueSequencePtr & value_sequence,const PrimitivePtr & prim,bool py_execute_input)2037   AnfNodePtr GetPyExecuteFromValueSequence(const FuncGraphPtr &fg, const ValueNodePtr &value_node,
2038                                            const ValueSequencePtr &value_sequence, const PrimitivePtr &prim,
2039                                            bool py_execute_input) {
2040     std::vector<AnfNodePtr> new_inputs;
2041     new_inputs.reserve(value_sequence->size());
2042     (void)new_inputs.emplace_back(NewValueNode(prim));
2043     bool changed = false;
2044     auto abs = value_node->abstract();
2045     if (abs == nullptr) {
2046       for (const auto &v : value_sequence->value()) {
2047         auto v_node = NewValueNode(v);
2048         v_node->set_debug_info(value_node->debug_info());
2049         auto new_node = GetPyExecuteFromValue(fg, v_node, v, py_execute_input);
2050         new_node->set_debug_info(value_node->debug_info());
2051         (void)new_inputs.emplace_back(new_node);
2052         if (new_node != v_node) {
2053           changed = true;
2054         }
2055       }
2056     } else {
2057       auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
2058       MS_EXCEPTION_IF_NULL(abs_seq);
2059       const auto &abs_seq_elements = abs_seq->elements();
2060       const auto &value_sequence_values = value_sequence->value();
2061       if (abs_seq_elements.size() != value_sequence_values.size()) {
2062         MS_LOG(EXCEPTION) << "The size of value sequence should be same as the size of abstract sequence.";
2063       }
2064       for (size_t i = 0; i < value_sequence_values.size(); ++i) {
2065         auto v = value_sequence_values[i];
2066         auto v_node = NewValueNode(v);
2067         v_node->set_debug_info(value_node->debug_info());
2068         v_node->set_abstract(abs_seq_elements[i]);
2069         auto new_node = GetPyExecuteFromValue(fg, v_node, v, py_execute_input);
2070         new_node->set_debug_info(value_node->debug_info());
2071         (void)new_inputs.emplace_back(new_node);
2072         if (new_node != v_node) {
2073           changed = true;
2074         }
2075       }
2076     }
2077     if (changed) {
2078       auto ret = fg->NewCNode(new_inputs);
2079       ret->set_abstract(value_node->abstract());
2080       return ret;
2081     }
2082     return value_node;
2083   }
2084 
ConvertTypeToPyExecute(const FuncGraphPtr & fg,const ValueNodePtr & node,const TypePtr & type) const2085   AnfNodePtr ConvertTypeToPyExecute(const FuncGraphPtr &fg, const ValueNodePtr &node, const TypePtr &type) const {
2086     // Support convert type to PyExecute.
2087     const auto py_type = ValueToPyData(type);
2088     MS_LOG(DEBUG) << "py_type: " << py_type;
2089     auto res = fallback::ConvertPyObjectToPyExecute(fg, py::str(py_type).cast<std::string>(), py_type, node, false);
2090     fallback::SetRealType(res, type);
2091     return res;
2092   }
2093 
ConvertClassTypeToPyExecute(const FuncGraphPtr & fg,const ValueNodePtr & node,const ClassTypePtr & class_type) const2094   AnfNodePtr ConvertClassTypeToPyExecute(const FuncGraphPtr &fg, const ValueNodePtr &node,
2095                                          const ClassTypePtr &class_type) const {
2096     // Support convert class type to PyExecute.
2097     const auto py_type = ValueToPyData(class_type);
2098     MS_LOG(DEBUG) << "py_type: " << py_type;
2099     auto res = fallback::ConvertPyObjectToPyExecute(fg, py::str(py_type).cast<std::string>(), py_type, node, true);
2100     fallback::SetRealType(res, class_type);
2101     MS_LOG(DEBUG) << "res: " << res->DebugString();
2102     return res;
2103   }
2104 
ConvertNameSpaceToPyExecute(const FuncGraphPtr & fg,const ValueNodePtr & node,const parse::NameSpacePtr & name_space) const2105   AnfNodePtr ConvertNameSpaceToPyExecute(const FuncGraphPtr &fg, const ValueNodePtr &node,
2106                                          const parse::NameSpacePtr &name_space) const {
2107     // Support convert namespace to PyExecute.
2108     const auto name_space_type = ValueToPyData(name_space);
2109     MS_LOG(DEBUG) << "name_space_type: " << name_space_type;
2110     auto res = fallback::ConvertPyObjectToPyExecute(fg, py::str(name_space_type).cast<std::string>(), name_space_type,
2111                                                     node, true);
2112     fallback::SetRealType(res, name_space);
2113     MS_LOG(DEBUG) << "res: " << res->DebugString();
2114     return res;
2115   }
2116 
IsValueListWithInplace(const ValueNodePtr & value_node) const2117   bool IsValueListWithInplace(const ValueNodePtr &value_node) const {
2118     if (!fallback::EnableFallbackListDictInplace()) {
2119       return false;
2120     }
2121 
2122     MS_EXCEPTION_IF_NULL(value_node);
2123     auto abs = value_node->abstract();
2124     MS_EXCEPTION_IF_NULL(abs);
2125     auto list_abs = abs->cast<abstract::AbstractListPtr>();
2126     MS_EXCEPTION_IF_NULL(list_abs);
2127     if (!fallback::HasObjInExtraInfoHolder(list_abs)) {
2128       return false;
2129     }
2130     py::list list_object = fallback::GetObjFromExtraInfoHolder(list_abs);
2131     // The value list  do not need to convert to PyExecute if:
2132     //   1. The list is created within graph.
2133     //   2. The list and its elements do not perform any inplace operation.
2134     if (fallback::GetCreateInGraphFromExtraInfoHolder(list_abs) && !CheckSeqWithInplace(list_object)) {
2135       return false;
2136     }
2137     return true;
2138   }
2139 
ConvertValueSlice(const FuncGraphPtr & func_graph,const AnfNodePtr & slice_node,const ValueSlicePtr & value_slice)2140   AnfNodePtr ConvertValueSlice(const FuncGraphPtr &func_graph, const AnfNodePtr &slice_node,
2141                                const ValueSlicePtr &value_slice) {
2142     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
2143     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
2144     bool is_start_none = value_slice->start()->isa<None>();
2145     bool is_stop_none = value_slice->stop()->isa<None>();
2146     bool is_step_none = value_slice->step()->isa<None>();
2147     auto start_str = is_start_none ? "None" : "__start__";
2148     auto stop_str = is_stop_none ? "None" : "__stop__";
2149     auto step_str = is_step_none ? "None" : "__step__";
2150     // Script
2151     std::stringstream script_buffer;
2152     script_buffer << "slice(" << start_str << ", " << stop_str << ", " << step_str << ")";
2153     const std::string &script = script_buffer.str();
2154     const auto script_str = std::make_shared<StringImm>(script);
2155 
2156     // Pack local parameters keys and values.
2157     (void)key_value_names_list.emplace_back(NewValueNode(start_str));
2158     (void)key_value_names_list.emplace_back(NewValueNode(stop_str));
2159     (void)key_value_names_list.emplace_back(NewValueNode(step_str));
2160     AnfNodePtr start_node;
2161     AnfNodePtr end_node;
2162     AnfNodePtr step_node;
2163     if (!is_start_none) {
2164       start_node = func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), slice_node, NewValueNode("start")});
2165     } else {
2166       start_node = NewValueNode(start_str);
2167     }
2168     if (!is_stop_none) {
2169       end_node = func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), slice_node, NewValueNode("stop")});
2170     } else {
2171       end_node = NewValueNode(stop_str);
2172     }
2173     if (!is_step_none) {
2174       step_node = func_graph->NewCNode({NewValueNode(prim::kPrimSliceGetItem), slice_node, NewValueNode("step")});
2175     } else {
2176       step_node = NewValueNode(stop_str);
2177     }
2178     (void)key_value_list.emplace_back(start_node);
2179     (void)key_value_list.emplace_back(end_node);
2180     (void)key_value_list.emplace_back(step_node);
2181     const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
2182     const auto key_value_tuple = func_graph->NewCNode(key_value_list);
2183     return fallback::CreatePyExecuteCNodeInOrder(func_graph, NewValueNode(script_str), key_value_name_tuple,
2184                                                  key_value_tuple, key_value_tuple->debug_info());
2185   }
2186 
GetPyExecuteFromValue(const FuncGraphPtr & fg,const ValueNodePtr & value_node,const ValuePtr & value,bool py_execute_input)2187   AnfNodePtr GetPyExecuteFromValue(const FuncGraphPtr &fg, const ValueNodePtr &value_node, const ValuePtr &value,
2188                                    bool py_execute_input) {
2189     MS_EXCEPTION_IF_NULL(fg);
2190     MS_EXCEPTION_IF_NULL(value_node);
2191     MS_EXCEPTION_IF_NULL(value);
2192     if (value->isa<None>()) {
2193       constexpr auto vmap_prefix = "VmapRule";
2194       if (value_node->scope() != nullptr &&
2195           value_node->scope()->name().compare(0, strlen(vmap_prefix), vmap_prefix) == 0) {
2196         return value_node;
2197       }
2198       return ConvertNoneToPyExecute(fg);
2199     }
2200     if (fallback::GetJitSyntaxLevel() == kLax) {
2201       if (value->isa<Type>()) {
2202         return ConvertTypeToPyExecute(fg, value_node, value->cast<TypePtr>());
2203       } else if (value->isa<parse::ClassType>()) {
2204         auto class_type = GetValueNode<ClassTypePtr>(value_node);
2205         MS_EXCEPTION_IF_NULL(class_type);
2206         return ConvertClassTypeToPyExecute(fg, value_node, class_type);
2207       } else if (value->isa<parse::NameSpace>()) {
2208         auto name_space = GetValueNode<parse::NameSpacePtr>(value_node);
2209         MS_EXCEPTION_IF_NULL(name_space);
2210         return ConvertNameSpaceToPyExecute(fg, value_node, name_space);
2211       }
2212     }
2213     if (value->isa<parse::MsClassObject>()) {
2214       return fallback::ConvertMsClassObjectToPyExecute(fg, value, value_node);
2215     }
2216     if (value->isa<parse::InterpretedObject>()) {
2217       const auto interpreted_value = dyn_cast<parse::InterpretedObject>(value);
2218       const std::string &key = interpreted_value->name();
2219       return fallback::ConvertPyObjectToPyExecute(fg, key, interpreted_value->obj(), value_node, true);
2220     }
2221     if (value->isa<ValueTuple>()) {
2222       return GetPyExecuteFromValueSequence(fg, value_node, value->cast<ValueSequencePtr>(), prim::kPrimMakeTuple,
2223                                            py_execute_input);
2224     }
2225     if (value->isa<ValueList>()) {
2226       if (!IsValueListWithInplace(value_node) && !py_execute_input) {
2227         return GetPyExecuteFromValueSequence(fg, value_node, value->cast<ValueSequencePtr>(), prim::kPrimMakeList,
2228                                              py_execute_input);
2229       }
2230       return RebuildValueList(fg, value_node);
2231     }
2232     if (value->isa<ValueDictionary>()) {
2233       return RebuildValueDict(fg, value_node, value->cast<ValueDictionaryPtr>());
2234     }
2235     if (value->isa<ValueSlice>()) {
2236       return ConvertValueSlice(fg, value_node, value->cast<ValueSlicePtr>());
2237     }
2238     return value_node;
2239   }
2240 
ConvertValueInputToPyExecute(const CNodePtr & cnode)2241   void ConvertValueInputToPyExecute(const CNodePtr &cnode) {
2242     MS_EXCEPTION_IF_NULL(cnode);
2243     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
2244     if (!allow_fallback_runtime) {
2245       return;
2246     }
2247     if (AnfUtils::IsRealKernel(cnode) && !IsOneOfPrimitiveCNode(cnode, inplace_prim_set) &&
2248         !IsOneOfPrimitiveCNode(cnode, seq_prim_set_)) {
2249       return;
2250     }
2251     if (IsOneOfPrimitiveCNode(cnode, seq_prim_set_)) {
2252       const auto &inputs = cnode->inputs();
2253       std::vector<AbstractBasePtr> inputs_abs;
2254       for (size_t i = 1; i < inputs.size(); ++i) {
2255         inputs_abs.push_back(inputs[i]->abstract());
2256       }
2257       auto output_abs = cnode->abstract();
2258       MS_EXCEPTION_IF_NULL(output_abs);
2259       // Only sequence ops with nested sequence input or irregular input (element with different shape/type)
2260       // or the output abstract of sequence node is AbstractAny should be converted to PyExecute node later and
2261       // their sequence input should be converted to PyExecute.
2262       if (!CheckAndConvertUtils::CheckContainNestedOrIrregularSequence(inputs_abs) &&
2263           !output_abs->isa<abstract::AbstractAny>()) {
2264         return;
2265       }
2266     }
2267     const auto &inputs = cnode->inputs();
2268     auto cur_func = cnode->func_graph();
2269     MS_EXCEPTION_IF_NULL(cur_func);
2270     for (const auto &input : inputs) {
2271       auto value_node = dyn_cast<ValueNode>(input);
2272       if (value_node == nullptr) {
2273         continue;
2274       }
2275       const auto &value = value_node->value();
2276       if (fallback::GetJitSyntaxLevel() == kLax) {
2277         // Not convert the 'type' used by Cast primitive.
2278         if (value->isa<Type>() && IsPrimitiveCNode(cnode, prim::kPrimCast)) {
2279           continue;
2280         }
2281       }
2282       auto debug_info = value_node->debug_info();
2283       auto location_info = trace::GetDebugInfoStr(debug_info);
2284       if (location_info.empty()) {
2285         value_node->set_debug_info(cnode->debug_info());
2286       }
2287       auto new_input = GetPyExecuteFromValue(cur_func, value_node, value, false);
2288       if (new_input == input) {
2289         continue;
2290       }
2291       new_input->set_debug_info(value_node->debug_info());
2292       (void)manager_->Replace(input, new_input);
2293       set_need_renormalized(true);
2294     }
2295   }
2296 
ConvertSequenceOps(const CNodePtr & cnode) const2297   AnfNodePtr ConvertSequenceOps(const CNodePtr &cnode) const {
2298     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
2299     if (!allow_fallback_runtime) {
2300       return nullptr;
2301     }
2302     const auto &inputs = cnode->inputs();
2303     std::vector<AbstractBasePtr> inputs_abs;
2304     for (size_t i = 1; i < inputs.size(); ++i) {
2305       inputs_abs.push_back(inputs[i]->abstract());
2306     }
2307     auto output_abs = cnode->abstract();
2308     MS_EXCEPTION_IF_NULL(output_abs);
2309     // Only sequence ops with nested sequence input or irregular input (element with different shape/type)
2310     // or the output abstract of sequence node is AbstractAny should be converted to PyExecute node.
2311     if (!CheckAndConvertUtils::CheckContainNestedOrIrregularSequence(inputs_abs) &&
2312         !output_abs->isa<abstract::AbstractAny>()) {
2313       return nullptr;
2314     }
2315 
2316     auto prim = GetValueNode<PrimitivePtr>(inputs[0]);
2317     MS_EXCEPTION_IF_NULL(prim);
2318     const auto &prim_name = prim->name();
2319 
2320     const auto &fg = cnode->func_graph();
2321     MS_EXCEPTION_IF_NULL(fg);
2322     const std::string seq_ops_dir = "__import__('mindspore').ops.operations._sequence_ops.";
2323     const std::string input_prefix = "__internal_input_";
2324 
2325     std::stringstream script_buffer;
2326     script_buffer << seq_ops_dir << prim_name << "()(";
2327     std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
2328     std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
2329     for (size_t i = 1; i < inputs.size(); ++i) {
2330       auto cur_input_str = input_prefix + std::to_string(i - 1) + "__";
2331       script_buffer << cur_input_str << ",";
2332       (void)key_value_names_list.emplace_back(NewValueNode(cur_input_str));
2333       (void)key_value_list.emplace_back(inputs[i]);
2334     }
2335     script_buffer << ")";
2336     const std::string &script = script_buffer.str();
2337     auto script_node = NewValueNode(std::make_shared<StringImm>(script));
2338     const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
2339     const auto key_value_tuple = fg->NewCNode(key_value_list);
2340 
2341     auto res =
2342       fallback::CreatePyExecuteCNode(fg, script_node, key_value_name_tuple, key_value_tuple, cnode->debug_info());
2343     MS_LOG(DEBUG) << "Convert sequence node: " << cnode->DebugString() << " to " << res->DebugString();
2344     return res;
2345   }
2346 
ConvertPrimitiveCNode(const CNodePtr & cnode)2347   AnfNodePtr ConvertPrimitiveCNode(const CNodePtr &cnode) override {
2348     // Get primitive from cnode.
2349     const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0));
2350     if (prim == nullptr) {
2351       return nullptr;
2352     }
2353     ConvertValueInputToPyExecute(cnode);
2354 
2355     // Find cnode converter by primitive.
2356     auto iter = converters_.find(prim);
2357     if (iter != converters_.end()) {
2358       // Call converter.
2359       return (this->*(iter->second))(cnode);
2360     }
2361     if (seq_prim_set_.find(prim) != seq_prim_set_.end()) {
2362       return ConvertSequenceOps(cnode);
2363     }
2364     return nullptr;
2365   }
2366 
PackDictValue(const FuncGraphPtr & fg,const ValueNodePtr & value_node,const ValueDictionaryPtr & dict)2367   AnfNodePtr PackDictValue(const FuncGraphPtr &fg, const ValueNodePtr &value_node, const ValueDictionaryPtr &dict) {
2368     const auto &keys_values = dict->value();
2369     auto abs_dict = dyn_cast<abstract::AbstractDictionary>(value_node->abstract());
2370     const auto &abs_keys_values = abs_dict->elements();
2371     if (keys_values.size() != abs_keys_values.size()) {
2372       MS_LOG(INTERNAL_EXCEPTION) << "The size of value dict should be same as the size of abstract dict.";
2373     }
2374     std::vector<AnfNodePtr> value_list{NewValueNode(prim::kPrimMakeTuple)};
2375     for (size_t i = 0; i < keys_values.size(); ++i) {
2376       auto key_value = keys_values[i];
2377       auto new_vnode = NewValueNode(key_value.second);
2378       new_vnode->set_debug_info(value_node->debug_info());
2379       new_vnode->set_abstract(abs_keys_values[i].second);
2380       auto iter_value = GetPyExecuteFromValue(fg, new_vnode, key_value.second, true);
2381       iter_value->set_debug_info(value_node->debug_info());
2382       (void)value_list.emplace_back(iter_value);
2383     }
2384     auto value_tuple_node = fg->NewCNode(value_list);
2385     return value_tuple_node;
2386   }
2387 
2388   // If the value dict has attached object:
2389   //   dict(k0:v0, k1:v1, ...) --> PyExecute('get_local_variable(dict_key)', ...)
2390   // otherwise:
2391   //   dict(k0:v0, k1:v1, ...) --> PyExecute('dict(zip(keys, values))', ...)
RebuildValueDict(const FuncGraphPtr & fg,const ValueNodePtr & value_node,const ValueDictionaryPtr & dict)2392   AnfNodePtr RebuildValueDict(const FuncGraphPtr &fg, const ValueNodePtr &value_node, const ValueDictionaryPtr &dict) {
2393     if (not_convert_jit_) {
2394       return value_node;
2395     }
2396     auto abs = value_node->abstract();
2397     MS_EXCEPTION_IF_NULL(abs);
2398     auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
2399     MS_EXCEPTION_IF_NULL(abs_dict);
2400     if (fallback::HasObjInExtraInfoHolder(abs_dict) && !fallback::GetCreateInGraphFromExtraInfoHolder(abs_dict)) {
2401       // If the abstract of value dict has python object and the python object is created outside the graph,
2402       // the we use the python object to generate pyexecute node.
2403       py::dict dict_object = fallback::GetObjFromExtraInfoHolder(abs_dict);
2404       const std::string dict_obj_str_prefix = "__dict_py_object_";
2405       auto dict_obj_id = fallback::GetPyObjectPtrStr(dict_object);
2406       MS_LOG(DEBUG) << "Current python object id: " << dict_obj_id;
2407       auto dict_obj_str = dict_obj_str_prefix + dict_obj_id + "_";
2408       auto res = fallback::ConvertPyObjectToPyExecute(fg, dict_obj_str, dict_object, value_node, false);
2409       MS_LOG(DEBUG) << "Convert value dict node: " << value_node->DebugString()
2410                     << " to inplace pyexecute node: " << res->DebugString();
2411       return res;
2412     }
2413 
2414     const auto &keys_values = dict->value();
2415 
2416     // Local parameters values.
2417     // Pack the key tuple.
2418     std::vector<ValuePtr> key_list;
2419     key_list.reserve(keys_values.size());
2420     for (const auto &key_value : keys_values) {
2421       (void)key_list.emplace_back(key_value.first);
2422     }
2423     const auto key_tuple = std::make_shared<ValueTuple>(key_list);
2424     auto key_tuple_node = NewValueNode(key_tuple);
2425     key_tuple_node->set_debug_info(value_node->debug_info());
2426     // Pack the value tuple.
2427     auto value_tuple_node = PackDictValue(fg, value_node, dict);
2428 
2429     // Generate Make Dict PyExecute Node value
2430     auto make_key_tuple_node = ConstructInternalTupleKeysNode(fg, key_tuple_node);
2431     auto make_value_tuple_node = ConstructInternalTupleValueNode(fg, value_tuple_node);
2432 
2433     auto make_dict_node = ConstructNewDictNode(fg, make_key_tuple_node, make_value_tuple_node);
2434     make_dict_node->set_debug_info(value_node->debug_info());
2435     MS_LOG(DEBUG) << "Convert value dict node: " << value_node->DebugString()
2436                   << " to non-inplace pyexecute node: " << make_dict_node->DebugString();
2437     return make_dict_node;
2438   }
2439 
CheckSeqWithInplace(const py::sequence & seq) const2440   bool CheckSeqWithInplace(const py::sequence &seq) const {
2441     if (py::isinstance<py::list>(seq)) {
2442       const auto &seq_str = fallback::GetPyObjectPtrStr(seq);
2443       if (data_with_inplace_->find(seq_str) != data_with_inplace_->end()) {
2444         return true;
2445       }
2446     }
2447     for (const auto &obj : seq) {
2448       if (py::isinstance<py::list>(obj) && CheckSeqWithInplace(py::list(obj))) {
2449         return true;
2450       }
2451       if (py::isinstance<py::tuple>(obj) && CheckSeqWithInplace(py::tuple(obj))) {
2452         return true;
2453       }
2454     }
2455     return false;
2456   }
2457 
RebuildValueList(const FuncGraphPtr & fg,const ValueNodePtr & value_node) const2458   AnfNodePtr RebuildValueList(const FuncGraphPtr &fg, const ValueNodePtr &value_node) const {
2459     MS_EXCEPTION_IF_NULL(value_node);
2460     MS_EXCEPTION_IF_NULL(fg);
2461 
2462     auto value = value_node->value();
2463     MS_EXCEPTION_IF_NULL(value);
2464     auto value_list = value->cast<ValueListPtr>();
2465     MS_EXCEPTION_IF_NULL(value_list);
2466 
2467     auto abs = value_node->abstract();
2468     MS_EXCEPTION_IF_NULL(abs);
2469     auto list_abs = abs->cast<abstract::AbstractListPtr>();
2470     MS_EXCEPTION_IF_NULL(list_abs);
2471 
2472     if (list_abs->dynamic_len()) {
2473       return value_node;
2474     }
2475 
2476     bool has_object = fallback::HasObjInExtraInfoHolder(list_abs);
2477     py::list list_object = has_object ? fallback::GetObjFromExtraInfoHolder(list_abs) : ValueToPyData(value);
2478 
2479     // Generate PyExecute node: __list_object__
2480     const std::string list_obj_str_prefix = "__list_py_object_";
2481     auto list_obj_id = fallback::GetPyObjectPtrStr(list_object);
2482     MS_LOG(DEBUG) << "Current python object id: " << list_obj_id;
2483     auto list_obj_str = list_obj_str_prefix + list_obj_id + "_";
2484     auto res = fallback::ConvertPyObjectToPyExecute(fg, list_obj_str, list_object, value_node, false);
2485 
2486     return res;
2487   }
2488 
ConvertInterpretedObjectValue(const ValueNodePtr & node,const parse::InterpretedObjectPtr & value) const2489   AnfNodePtr ConvertInterpretedObjectValue(const ValueNodePtr &node, const parse::InterpretedObjectPtr &value) const {
2490     // Convert InterpretedObject value node to PyExecute CNode.
2491     const auto interpreted_value = dyn_cast<parse::InterpretedObject>(value);
2492     const std::string &key = interpreted_value->name();
2493     return fallback::ConvertPyObjectToPyExecute(root_graph_, key, interpreted_value->obj(), node, true);
2494   }
2495 
ConvertValueNode(const ValueNodePtr & value_node,const ValuePtr & value)2496   AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override {
2497     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
2498     if (allow_fallback_runtime) {
2499       if (value->ContainsValueAny()) {
2500         return nullptr;
2501       }
2502       if (value->isa<ValueDictionary>()) {
2503         return RebuildValueDict(root_graph_, value_node, value->cast<ValueDictionaryPtr>());
2504       } else if (value->isa<parse::InterpretedObject>()) {
2505         return ConvertInterpretedObjectValue(value_node, value->cast<parse::InterpretedObjectPtr>());
2506       } else if (value->isa<parse::MsClassObject>()) {
2507         return fallback::ConvertMsClassObjectToPyExecute(root_graph_, value, value_node);
2508       }
2509     }
2510     return nullptr;
2511   }
2512 
2513   // AbstractRowTensor --> AbstractTuple.
// Recursively rewrites every AbstractRowTensor reachable from `abs` into an AbstractTuple of
// (indices, values, dense_shape).
//   - A RowTensor becomes that three-element tuple.
//   - A fixed-length sequence is rebuilt (as AbstractList or AbstractTuple, matching its kind) only when at
//     least one of its elements was rewritten.
//   - Anything else, including a dynamic-length sequence, yields nullptr, i.e. nothing was converted.
// `depth` is the current nesting level; exceeding kMaxSeqRecursiveDepth raises an internal exception.
static AbstractBasePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) {
  if (depth > kMaxSeqRecursiveDepth) {
    MS_LOG(ERROR) << "abs:" << abs->ToString();
    MS_LOG(INTERNAL_EXCEPTION) << "List, tuple and dict nesting is not allowed more than " << kMaxSeqRecursiveDepth
                               << " levels.";
  }

  const auto sequence = abs->cast<AbstractSequencePtr>();
  if (sequence == nullptr) {
    // Not a sequence: only a RowTensor is rewritten.
    const auto row_tensor = abs->cast<std::shared_ptr<AbstractRowTensor>>();
    if (row_tensor == nullptr) {
      return nullptr;
    }
    std::vector<AbstractBasePtr> fields{row_tensor->indices(), row_tensor->values(), row_tensor->dense_shape()};
    return std::make_shared<AbstractTuple>(std::move(fields));
  }

  // Dynamic length sequences are left alone.
  if (sequence->dynamic_len()) {
    return nullptr;
  }

  // Pass 1: convert the elements; remember old element -> new element for those that changed.
  const auto &old_elements = sequence->elements();
  mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> replacements;
  for (const auto &old_element : old_elements) {
    auto converted = ConvertToAbstractTuple(old_element, depth + 1);
    if (converted == nullptr) {
      continue;
    }
    (void)replacements.emplace(old_element, converted);
  }
  if (replacements.empty()) {
    // No element changed, so the sequence itself does not need to be rebuilt.
    return nullptr;
  }

  // Pass 2: assemble the element list, substituting the converted elements.
  std::vector<AbstractBasePtr> new_elements;
  new_elements.reserve(old_elements.size());
  (void)std::transform(old_elements.begin(), old_elements.end(), std::back_inserter(new_elements),
                       [&replacements](const AbstractBasePtr &old_element) -> AbstractBasePtr {
                         const auto found = replacements.find(old_element);
                         if (found == replacements.end()) {
                           return old_element;
                         }
                         return found->second;
                       });
  if (sequence->isa<AbstractList>()) {
    return std::make_shared<AbstractList>(std::move(new_elements));
  }
  return std::make_shared<AbstractTuple>(std::move(new_elements));
}
2566 
// Abstract conversion hook: starts the ConvertToAbstractTuple rewrite of `abs` at nesting depth 0 and returns
// its result unchanged (nullptr when nothing was converted).
AbstractBasePtr ConvertAbstract(const AbstractBasePtr &abs) override {
  constexpr size_t top_level_depth = 0;
  return ConvertToAbstractTuple(abs, top_level_depth);
}
2571 
2572  private:
2573   StringSetPtr data_with_inplace_;
2574   bool not_convert_jit_{false};
2575 };
2576 
FindValueWithInplaceInner(const FuncGraphPtr & graph,const StringSetPtr & value_with_inplace)2577 void FindValueWithInplaceInner(const FuncGraphPtr &graph, const StringSetPtr &value_with_inplace) {
2578   MS_EXCEPTION_IF_NULL(graph);
2579   AnfNodePtr return_node = graph->get_return();
2580   MS_EXCEPTION_IF_NULL(return_node);
2581   std::vector<AnfNodePtr> all_nodes = TopoSort(return_node);
2582   constexpr size_t sequence_index = 1;
2583   for (auto &node : all_nodes) {
2584     MS_EXCEPTION_IF_NULL(node);
2585     if (!IsOneOfPrimitiveCNode(node, inplace_prim_set)) {
2586       continue;
2587     }
2588     auto cnode = node->cast<CNodePtr>();
2589     auto sequence_node = cnode->input(sequence_index);
2590     MS_EXCEPTION_IF_NULL(sequence_node);
2591     if (!IsValueNode<ValueList>(sequence_node)) {
2592       continue;
2593     }
2594     auto abs = sequence_node->abstract();
2595     if (abs == nullptr || !abs->isa<abstract::AbstractList>()) {
2596       continue;
2597     }
2598     auto abs_list = abs->cast<abstract::AbstractListPtr>();
2599     auto list_py_object = fallback::GetObjFromExtraInfoHolder(abs_list);
2600     MS_LOG(DEBUG) << "Found list python object in inplace: " << py::str(list_py_object);
2601     const auto &list_py_object_str = fallback::GetPyObjectPtrStr(list_py_object);
2602     (void)value_with_inplace->insert(list_py_object_str);
2603   }
2604 }
2605 
FindValueWithInplace(const FuncGraphPtr & root,const pipeline::ResourcePtr & resource,const StringSetPtr & value_with_inplace)2606 void FindValueWithInplace(const FuncGraphPtr &root, const pipeline::ResourcePtr &resource,
2607                           const StringSetPtr &value_with_inplace) {
2608   const auto func_graphs_used_total = root->func_graphs_used_total();
2609   for (const auto &fg : func_graphs_used_total) {
2610     FindValueWithInplaceInner(fg, value_with_inplace);
2611   }
2612   FindValueWithInplaceInner(root, value_with_inplace);
2613 }
2614 
// Returns the node built by ConvertSequenceGetItemInner for a call to a primitive in
// `sequence_getitem_prim_set` whose abstract is an AbstractAny. Returns nullptr for every other node.
// A null `node`, or a matching call without an abstract, raises through MS_EXCEPTION_IF_NULL.
AnfNodePtr ConvertToPyExecuteGetItem(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const bool is_sequence_getitem = IsOneOfPrimitiveCNode(node, sequence_getitem_prim_set);
  if (!is_sequence_getitem) {
    return nullptr;
  }
  auto abs = node->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractAny>()) {
    return ConvertSequenceGetItemInner(node->cast<CNodePtr>());
  }
  return nullptr;
}
2627 
CheckNeedConvertList(const AbstractBasePtr & abs)2628 bool CheckNeedConvertList(const AbstractBasePtr &abs) {
2629   if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
2630     return false;
2631   }
2632   // If abstract has real type/shape, it means the corresponding node is PyExecute.
2633   // Do not covert PyExecute node.
2634   if (fallback::HasRealType(abs) || fallback::HasRealShape(abs)) {
2635     return false;
2636   }
2637   auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
2638   if (seq_abs->dynamic_len()) {
2639     return false;
2640   }
2641   if (seq_abs->isa<abstract::AbstractList>()) {
2642     return true;
2643   }
2644   const auto &elements = seq_abs->elements();
2645   return std::any_of(elements.begin(), elements.end(),
2646                      [](const AbstractBasePtr &abs) { return CheckNeedConvertList(abs); });
2647 }
2648 
// Rebuilds `node` so that every list it contains is produced by a PyExecute node.
// Returns nullptr when `node` has no abstract or CheckNeedConvertList rejects it. Otherwise:
//   - for a list, returns a PyExecute node (via fallback::CreatePyExecuteCNode) whose script is
//     "[__list_element_0__,__list_element_1__,...,]" and whose local values are the (recursively converted)
//     ListGetItem nodes of `node`;
//   - for any other sequence, returns a MakeTuple of the (recursively converted) TupleGetItem nodes of `node`.
// All new nodes are created in `fg`.
AnfNodePtr ConvertToPyExecuteListInner(const AnfNodePtr &node, const FuncGraphPtr &fg) {
  MS_EXCEPTION_IF_NULL(node);
  auto abs = node->abstract();
  if (abs == nullptr || !CheckNeedConvertList(abs)) {
    return nullptr;
  }
  auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(seq_abs);
  const auto &elements = seq_abs->elements();
  const size_t element_count = elements.size();

  // Creates `getitem_prim(node, index)`, tags it with the element abstract and converts it recursively.
  // The plain getitem node is used when the element itself needs no conversion.
  const auto convert_element = [&node, &fg, &elements](const auto &getitem_prim, size_t index) -> AnfNodePtr {
    auto getitem_node = fg->NewCNode({NewValueNode(getitem_prim), node, NewValueNode(MakeValue<int64_t>(index))});
    getitem_node->set_abstract(elements[index]);
    auto converted = ConvertToPyExecuteListInner(getitem_node, fg);
    if (converted != nullptr) {
      return converted;
    }
    return getitem_node;
  };

  if (abs->isa<abstract::AbstractList>()) {
    std::string script = "[";
    std::vector<AnfNodePtr> local_name_inputs{NewValueNode(prim::kPrimMakeTuple)};
    std::vector<AnfNodePtr> local_value_inputs{NewValueNode(prim::kPrimMakeTuple)};
    for (size_t index = 0; index < element_count; ++index) {
      auto local_value = convert_element(prim::kPrimListGetItem, index);
      std::string local_name = "__list_element_" + std::to_string(index) + "__";
      script += local_name;
      script += ",";
      (void)local_name_inputs.emplace_back(NewValueNode(local_name));
      (void)local_value_inputs.emplace_back(local_value);
    }
    script += "]";
    const auto script_value = std::make_shared<StringImm>(script);
    const auto local_name_tuple = fg->NewCNode(local_name_inputs);
    const auto local_value_tuple = fg->NewCNode(local_value_inputs);
    return fallback::CreatePyExecuteCNode(fg, NewValueNode(script_value), local_name_tuple, local_value_tuple,
                                          node->debug_info());
  }

  // Tuple: keep the tuple structure and only convert the elements.
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
  for (size_t index = 0; index < element_count; ++index) {
    (void)make_tuple_inputs.emplace_back(convert_element(prim::kPrimTupleGetItem, index));
  }
  return fg->NewCNode(make_tuple_inputs);
}
2700 
// For a PyExecute call, returns the converted form of its input 3 as computed by ConvertToPyExecuteListInner,
// or nullptr when that input needs no conversion. Returns nullptr for any node that is not a PyExecute call.
// A PyExecute call with fewer than 4 inputs raises an internal exception.
AnfNodePtr ConvertToPyExecuteList(const AnfNodePtr &node) {
  constexpr size_t pyexecute_min_len = 4;
  constexpr size_t pyexecute_value_index = 3;
  MS_EXCEPTION_IF_NULL(node);
  if (!IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
    return nullptr;
  }
  const auto cnode = node->cast<CNodePtr>();
  const auto input_count = cnode->size();
  if (input_count < pyexecute_min_len) {
    MS_LOG(INTERNAL_EXCEPTION) << "The minimum len of input to PyExecute should " << pyexecute_min_len << " but got "
                               << input_count << " for node: " << cnode->DebugString();
  }
  return ConvertToPyExecuteListInner(cnode->input(pyexecute_value_index), cnode->func_graph());
}
2716 
ConvertPyExecuteAfterRewriter(const FuncGraphPtr & graph,const FuncGraphManagerPtr & manager)2717 bool ConvertPyExecuteAfterRewriter(const FuncGraphPtr &graph, const FuncGraphManagerPtr &manager) {
2718   MS_EXCEPTION_IF_NULL(graph);
2719   AnfNodePtr return_node = graph->get_return();
2720   MS_EXCEPTION_IF_NULL(return_node);
2721   std::vector<AnfNodePtr> all_nodes = TopoSort(return_node);
2722   bool change = false;
2723   constexpr size_t pyexecute_value_index = 3;
2724   for (auto &node : all_nodes) {
2725     MS_EXCEPTION_IF_NULL(node);
2726     auto tr = manager->Transact();
2727     auto new_node = ConvertToPyExecuteGetItem(node);
2728     if (new_node != nullptr) {
2729       tr.Replace(node, new_node);
2730       tr.Commit();
2731       change = true;
2732       continue;
2733     }
2734     auto new_value_input = ConvertToPyExecuteList(node);
2735     if (new_value_input != nullptr) {
2736       tr.SetEdge(node, pyexecute_value_index, new_value_input);
2737       tr.Commit();
2738       change = true;
2739       continue;
2740     }
2741   }
2742   return change;
2743 }
2744 
OrderPyExecuteCNode(const FuncGraphPtr & graph,const FuncGraphManagerPtr & manager)2745 static inline bool OrderPyExecuteCNode(const FuncGraphPtr &graph, const FuncGraphManagerPtr &manager) {
2746   MS_EXCEPTION_IF_NULL(graph);
2747   AnfNodePtr return_node = graph->get_return();
2748   MS_EXCEPTION_IF_NULL(return_node);
2749   std::vector<AnfNodePtr> all_nodes = TopoSort(return_node);
2750   CNodePtr former_node = nullptr;
2751   CNodePtr latter_node = nullptr;
2752   bool change = false;
2753   for (auto &node : all_nodes) {
2754     MS_EXCEPTION_IF_NULL(node);
2755     if (!IsPrimitiveCNode(node, prim::kPrimPyExecute) || node->func_graph() != graph) {
2756       continue;
2757     }
2758     if (former_node == nullptr) {
2759       former_node = dyn_cast<CNode>(node);
2760       continue;
2761     } else {
2762       latter_node = dyn_cast<CNode>(node);
2763     }
2764     MS_EXCEPTION_IF_NULL(former_node);
2765     MS_EXCEPTION_IF_NULL(latter_node);
2766 
2767     // Make former node as latter node's input.
2768     auto tr = manager->Transact();
2769     size_t latest_index = latter_node->size() - 1;
2770     const auto &last_input_abs = latter_node->input(latest_index)->abstract();
2771     if (last_input_abs != nullptr && last_input_abs->isa<abstract::AbstractMonad>()) {  // Should be IO monad.
2772       const auto &monad_node = latter_node->input(latest_index);
2773       tr.SetEdge(latter_node, latest_index, former_node);
2774       tr.AddEdge(latter_node, monad_node);
2775     } else {
2776       tr.AddEdge(latter_node, former_node);
2777     }
2778     tr.Commit();
2779 
2780     former_node = latter_node;
2781     change = true;
2782   }
2783   return change;
2784 }
2785 }  // namespace
2786 
RewriterBeforeOptA(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)2787 bool RewriterBeforeOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
2788   MS_EXCEPTION_IF_NULL(manager);
2789   manager->AddFuncGraph(root);
2790   BeforeOptARewriter rewriter(root, manager);
2791   return rewriter.Execute();
2792 }
2793 
RewriterAfterOptA(const FuncGraphPtr & root,const pipeline::ResourcePtr & resource)2794 bool RewriterAfterOptA(const FuncGraphPtr &root, const pipeline::ResourcePtr &resource) {
2795   MS_EXCEPTION_IF_NULL(root);
2796   MS_EXCEPTION_IF_NULL(resource);
2797   auto manager = resource->manager();
2798   MS_EXCEPTION_IF_NULL(manager);
2799   manager->AddFuncGraph(root);
2800   StringSetPtr value_with_inplace = std::make_shared<StringSet>();
2801   FindValueWithInplace(root, resource, value_with_inplace);
2802   AfterOptARewriter rewriter(root, manager, value_with_inplace);
2803   bool change = rewriter.Execute();
2804   if (rewriter.need_renormalized()) {
2805     abstract::AbstractBasePtrList new_args_spec;
2806     (void)std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),
2807                          [](const AnfNodePtr &param) -> AbstractBasePtr { return param->abstract(); });
2808     (void)pipeline::Renormalize(resource, root, new_args_spec);
2809   }
2810   return change;
2811 }
2812 
ConvertAfterRewriter(const FuncGraphPtr & root,const pipeline::ResourcePtr & resource)2813 bool ConvertAfterRewriter(const FuncGraphPtr &root, const pipeline::ResourcePtr &resource) {
2814   auto manager = resource->manager();
2815   const auto func_graphs_used_total = root->func_graphs_used_total();
2816   bool change = false;
2817   for (const auto &fg : func_graphs_used_total) {
2818     auto cur_change = ConvertPyExecuteAfterRewriter(fg, manager);
2819     change = change || cur_change;
2820   }
2821   bool root_change = ConvertPyExecuteAfterRewriter(root, manager);
2822   change = change || root_change;
2823   if (change) {
2824     abstract::AbstractBasePtrList new_args_spec;
2825     (void)std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),
2826                          [](const AnfNodePtr &param) -> AbstractBasePtr { return param->abstract(); });
2827     (void)pipeline::Renormalize(resource, root, new_args_spec);
2828   }
2829   return change;
2830 }
2831 
OrderPyExecuteAfterRewriter(const FuncGraphPtr & root,const pipeline::ResourcePtr & resource)2832 bool OrderPyExecuteAfterRewriter(const FuncGraphPtr &root, const pipeline::ResourcePtr &resource) {
2833   auto manager = resource->manager();
2834   const auto func_graphs_used_total = root->func_graphs_used_total();
2835   bool change = false;
2836   for (const auto &fg : func_graphs_used_total) {
2837     auto cur_change = OrderPyExecuteCNode(fg, manager);
2838     change = change || cur_change;
2839   }
2840   bool root_change = OrderPyExecuteCNode(root, manager);
2841   change = change || root_change;
2842   if (change) {
2843     abstract::AbstractBasePtrList new_args_spec;
2844     (void)std::transform(root->parameters().begin(), root->parameters().end(), std::back_inserter(new_args_spec),
2845                          [](const AnfNodePtr &param) -> AbstractBasePtr { return param->abstract(); });
2846     (void)pipeline::Renormalize(resource, root, new_args_spec);
2847   }
2848   return change;
2849 }
2850 }  // namespace opt
2851 }  // namespace mindspore
2852