• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/py_interpret_to_execute.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <unordered_map>
23 
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "abstract/abstract_function.h"
27 #include "include/common/utils/convert_utils_py.h"
28 #include "include/common/utils/utils.h"
29 #include "utils/anf_utils.h"
30 #include "utils/interpret_node_recorder.h"
31 #include "utils/symbolic.h"
32 #include "pipeline/jit/ps/parse/resolve.h"
33 #include "pipeline/jit/ps/fallback.h"
34 
35 namespace mindspore {
36 /* namespace to support opt */
37 namespace opt {
38 namespace {
39 CNodePtr Transform(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
40                    std::map<AnfNodePtr, AnfNodePtr> *has_converted_nodes);
NewValueNodeWithAbstract(const ValuePtr & value,const AbstractBasePtr & abs=nullptr)41 AnfNodePtr NewValueNodeWithAbstract(const ValuePtr &value, const AbstractBasePtr &abs = nullptr) {
42   auto value_node = NewValueNode(value);
43   if (abs != nullptr) {
44     value_node->set_abstract(abs->Clone());
45   } else {
46     value_node->set_abstract(value->ToAbstract());
47   }
48   return value_node;
49 }
50 
// If `node` is a value node holding a FuncGraph whose python_obj() is a PyObjectWrapper, return a
// new value node holding the wrapped Python object as a parse::InterpretedObject.
// In every other case `node` is returned unchanged.
AnfNodePtr FuncGraphToPyData(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<ValueNode>()) {
    return node;
  }
  const auto &value = node->cast_ptr<ValueNode>()->value();
  if (!value->IsFromTypeId(FuncGraph::kTypeId)) {
    return node;
  }
  auto func_graph = value->cast_ptr<FuncGraph>();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto python_obj = func_graph->python_obj();
  if (python_obj == nullptr || !python_obj->isa<parse::PyObjectWrapper>()) {
    return node;
  }
  auto *wrapper = python_obj->cast_ptr<parse::PyObjectWrapper>();
  return NewValueNode(std::make_shared<parse::InterpretedObject>(wrapper->obj()));
}
69 
ConvertValueTupleToList(const AnfNodePtr & node)70 std::vector<AnfNodePtr> ConvertValueTupleToList(const AnfNodePtr &node) {
71   if ((!IsValueNode<ValueTuple>(node) && !IsPrimitiveCNode(node, prim::kPrimMakeTuple))) {
72     MS_LOG(INTERNAL_EXCEPTION) << "The dictionary's keys and values should be a tuple, but got " << node->DebugString();
73   }
74   std::vector<AnfNodePtr> node_list;
75   if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
76     auto cnode = node->cast_ptr<CNode>();
77     auto inputs = cnode->inputs();
78     std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(node_list));
79     return node_list;
80   }
81   auto tuple_value = GetValueNode<ValueTuplePtr>(node);
82   auto value_list = tuple_value->value();
83   std::transform(value_list.begin(), value_list.end(), std::back_inserter(node_list),
84                  [](const ValuePtr &value) -> AnfNodePtr { return NewValueNodeWithAbstract(value); });
85   return node_list;
86 }
87 
UnzipGlobalDict(const AnfNodePtr & dict_node)88 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> UnzipGlobalDict(const AnfNodePtr &dict_node) {
89   MS_EXCEPTION_IF_NULL(dict_node);
90   std::vector<AnfNodePtr> keys;
91   std::vector<AnfNodePtr> values;
92   if (!dict_node->isa<ValueNode>()) {
93     MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret global dict should be a InterpretedObject value node, but got "
94                                << dict_node->DebugString();
95   }
96   // Process the PyInterpret operator defined by the frontend and its global information is empty.
97   if (IsValueNode<ValueDictionary>(dict_node)) {
98     auto dict_input = GetValueNode<ValueDictionaryPtr>(dict_node);
99     if (dict_input->value().empty()) {
100       return std::make_pair(keys, values);
101     }
102   }
103   auto interpreted_object = GetValueNode<parse::InterpretedObjectPtr>(dict_node);
104   MS_EXCEPTION_IF_NULL(interpreted_object);
105   ValuePtr converted_value = nullptr;
106   if (!parse::ConvertData(interpreted_object->obj(), &converted_value)) {
107     MS_LOG(INTERNAL_EXCEPTION) << "Convert data failed";
108   }
109   MS_EXCEPTION_IF_NULL(converted_value);
110   auto dict_value = dyn_cast<ValueDictionary>(converted_value);
111   if (dict_value == nullptr) {
112     MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict or global dict should be a dictionary, but got "
113                                << converted_value->ToString();
114   }
115   for (auto item : dict_value->value()) {
116     (void)keys.emplace_back(NewValueNodeWithAbstract(item.first));
117     (void)values.emplace_back(NewValueNodeWithAbstract(item.second));
118   }
119   return std::make_pair(keys, values);
120 }
121 
UnzipLocalDict(const AnfNodePtr & dict_node)122 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> UnzipLocalDict(const AnfNodePtr &dict_node) {
123   MS_EXCEPTION_IF_NULL(dict_node);
124   if (dict_node->isa<ValueNode>()) {
125     auto dict_value = GetValueNode<ValueDictionaryPtr>(dict_node);
126     if (dict_value == nullptr) {
127       MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict should be a dictionary, but got "
128                                  << dict_node->DebugString();
129     }
130 
131     auto abs = dict_node->abstract();
132     MS_EXCEPTION_IF_NULL(abs);
133     auto dict_abs = abs->cast<abstract::AbstractDictionaryPtr>();
134     MS_EXCEPTION_IF_NULL(dict_abs);
135     const auto &elements_pair = dict_abs->elements();
136     const auto &dict_value_value = dict_value->value();
137     if (elements_pair.size() != dict_value_value.size()) {
138       MS_LOG(INTERNAL_EXCEPTION) << "For node: " << dict_node->DebugString()
139                                  << ", the abstract elements size is: " << elements_pair.size()
140                                  << " and the value elements size is: " << dict_value_value.size()
141                                  << ". Size not matched.";
142     }
143 
144     std::vector<AnfNodePtr> keys;
145     std::vector<AnfNodePtr> values;
146     for (size_t i = 0; i < dict_value_value.size(); ++i) {
147       auto item = dict_value_value[i];
148       (void)keys.emplace_back(NewValueNodeWithAbstract(item.first));
149       // Key element may contain ExtraInfoHolder, need to clone the abstract.
150       (void)values.emplace_back(NewValueNodeWithAbstract(item.second, elements_pair[i].second));
151     }
152     return std::make_pair(keys, values);
153   }
154 
155   if (!IsPrimitiveCNode(dict_node, prim::kPrimMakeDict)) {
156     MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict should be a dictionary, but got "
157                                << dict_node->DebugString();
158   }
159   auto make_dict_node = dict_node->cast_ptr<CNode>();
160   constexpr auto kMakeDictKeysInputIndex = 1;
161   constexpr auto kMakeDictValueInputIndex = 2;
162   auto keys_input = make_dict_node->input(kMakeDictKeysInputIndex);
163   auto values_input = make_dict_node->input(kMakeDictValueInputIndex);
164 
165   auto keys_list = ConvertValueTupleToList(keys_input);
166   auto values_list = ConvertValueTupleToList(values_input);
167   return std::make_pair(keys_list, values_list);
168 }
169 
GetLocalKeySet(const std::vector<AnfNodePtr> & key_node_list)170 std::set<std::string> GetLocalKeySet(const std::vector<AnfNodePtr> &key_node_list) {
171   std::set<std::string> key_set;
172   std::transform(key_node_list.begin(), key_node_list.end(), std::inserter(key_set, key_set.begin()),
173                  [](const AnfNodePtr &node) -> std::string {
174                    auto abs = node->abstract();
175                    MS_EXCEPTION_IF_NULL(abs);
176                    auto value = abs->BuildValue();
177                    MS_EXCEPTION_IF_NULL(value);
178                    return GetValue<std::string>(value);
179                  });
180   return key_set;
181 }
182 
183 // Merge global dict to local dict and return merged key and value
MergeGlobalDictToLocal(const AnfNodePtr & global_dict_node,const AnfNodePtr & local_dict_node,const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager,std::map<AnfNodePtr,AnfNodePtr> * has_converted_nodes)184 std::pair<AnfNodePtr, AnfNodePtr> MergeGlobalDictToLocal(const AnfNodePtr &global_dict_node,
185                                                          const AnfNodePtr &local_dict_node,
186                                                          const FuncGraphPtr &func_graph,
187                                                          const FuncGraphManagerPtr &manager,
188                                                          std::map<AnfNodePtr, AnfNodePtr> *has_converted_nodes) {
189   MS_EXCEPTION_IF_NULL(global_dict_node);
190   MS_EXCEPTION_IF_NULL(local_dict_node);
191   auto [global_keys, global_values] = UnzipGlobalDict(global_dict_node);
192   auto [local_keys, local_values] = UnzipLocalDict(local_dict_node);
193 
194   auto local_dict_keys_set = GetLocalKeySet(local_keys);
195 
196   std::vector<AnfNodePtr> local_keys_inputs{NewValueNode(prim::kPrimMakeTuple)};
197   std::vector<AnfNodePtr> local_value_inputs{NewValueNode(prim::kPrimMakeTuple)};
198   for (size_t index = 0; index < global_keys.size(); ++index) {
199     auto global_key = global_keys.at(index);
200     MS_EXCEPTION_IF_NULL(global_key);
201     auto key = GetValueNode<StringImmPtr>(global_key);
202     if (local_dict_keys_set.find(GetValue<std::string>(key)) != local_dict_keys_set.end()) {
203       MS_LOG(INFO) << "The global dict has the same name with local dict.:" << key->ToString();
204       continue;
205     }
206     MS_LOG(DEBUG) << "The global key " << global_key->DebugString() << ", value "
207                   << global_values.at(index)->DebugString() << ". merged in local dict.";
208     (void)local_keys_inputs.emplace_back(global_key);
209     (void)local_value_inputs.emplace_back(FuncGraphToPyData(global_values.at(index)));
210   }
211   std::copy(local_keys.begin(), local_keys.end(), std::back_inserter(local_keys_inputs));
212 
213   for (size_t i = 0; i < local_values.size(); ++i) {
214     auto local_value_node = local_values[i];
215     if (!IsPrimitiveCNode(local_value_node, prim::kPrimPyInterpret)) {
216       (void)local_value_inputs.emplace_back(local_value_node);
217     } else if (has_converted_nodes->find(local_value_node) != has_converted_nodes->end()) {
218       (void)local_value_inputs.emplace_back((*has_converted_nodes)[local_value_node]);
219     } else {
220       auto trans_node = Transform(local_value_node->cast<CNodePtr>(), manager, has_converted_nodes);
221       (void)manager->Replace(local_value_node, trans_node);
222       (void)local_value_inputs.emplace_back(trans_node);
223     }
224   }
225   return std::make_pair(func_graph->NewCNode(local_keys_inputs), func_graph->NewCNode(local_value_inputs));
226 }
227 
// Convert one PyInterpret CNode into an equivalent PyExecute CNode:
//   PyInterpret(script, global_dict, local_dict, ...) -> PyExecute(script_str, merged_keys, merged_values, ...)
// The new node is a copy of `cnode` (user data included) with:
//   - the primitive swapped to PyExecute;
//   - a parse::Script first input replaced by a StringImm of its text (a StringImm is kept as-is);
//   - the global/local dicts merged via MergeGlobalDictToLocal.
// The new node is pushed to InterpretNodeRecorder and recorded in `has_converted_nodes`. `cnode`
// itself is not replaced here; that is left to the caller.
CNodePtr Transform(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
                   std::map<AnfNodePtr, AnfNodePtr> *has_converted_nodes) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(has_converted_nodes);
  constexpr size_t input_index_one = 1;
  constexpr size_t input_index_two = 2;
  constexpr size_t input_index_three = 3;
  // Inputs 1..3 (script, global dict, local dict) are read below; fail with a clear message if absent.
  if (cnode->inputs().size() <= input_index_three) {
    MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret node should have script, global dict and local dict inputs, "
                               << "but got " << cnode->DebugString();
  }
  auto func_graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);

  auto new_cnode = std::make_shared<CNode>(*cnode);
  new_cnode->CloneUserData(cnode);
  new_cnode->set_input(0, NewValueNode(prim::kPrimPyExecute));
  const auto &first_input = cnode->input(input_index_one);
  if (IsValueNode<parse::Script>(first_input)) {
    const auto &script = GetValueNode<std::shared_ptr<parse::Script>>(first_input);
    const auto &script_str = script->script();
    const auto &script_strimm_node = NewValueNode(std::make_shared<StringImm>(script_str));
    new_cnode->set_input(input_index_one, script_strimm_node);
  } else if (!IsValueNode<StringImm>(first_input)) {
    MS_LOG(INTERNAL_EXCEPTION) << "The first input should be a Script or string, but got "
                               << first_input->DebugString();
  }
  auto global_dict_node = cnode->input(input_index_two);
  auto local_dict_node = cnode->input(input_index_three);

  auto [local_dict_keys, local_dict_values] =
    MergeGlobalDictToLocal(global_dict_node, local_dict_node, func_graph, manager, has_converted_nodes);

  new_cnode->set_input(input_index_two, local_dict_keys);
  new_cnode->set_input(input_index_three, local_dict_values);

  // Record the PyExecute node.
  InterpretNodeRecorder::GetInstance().PushPyExecuteNode(new_cnode);
  (void)has_converted_nodes->emplace(cnode, new_cnode);
  return new_cnode;
}
260 }  // namespace
261 
262 // Convert PyInterpret into PyExecute:
263 //   PyInterpret(script, global_dict, local_dict)
264 //   -->
265 //   PyExecute(script, local_dict_keys, local_dict_values),
266 //   with side-effect operation:
267 //     Merge global_dict to local dict.
268 //     If there are arguments in global dict and local dict use local dict argument instead of global dict.
PyInterpretToExecute(const pipeline::ResourcePtr & resource)269 bool PyInterpretToExecute(const pipeline::ResourcePtr &resource) {
270   MS_EXCEPTION_IF_NULL(resource);
271   auto manager = resource->manager();
272   MS_EXCEPTION_IF_NULL(manager);
273   auto transact = manager->Transact();
274   const auto all_nodes = manager->all_nodes();
275   std::map<AnfNodePtr, AnfNodePtr> has_converted_nodes;
276   for (const auto &node : all_nodes) {
277     if (IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
278       auto trans_node = Transform(node->cast<CNodePtr>(), manager, &has_converted_nodes);
279       (void)transact.Replace(node, trans_node);
280     }
281   }
282   transact.Commit();
283   return true;
284 }
285 }  // namespace opt
286 }  // namespace mindspore
287