1 /**
2 * Copyright 2022-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/py_interpret_to_execute.h"
18
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <unordered_map>
23
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "abstract/abstract_function.h"
27 #include "include/common/utils/convert_utils_py.h"
28 #include "include/common/utils/utils.h"
29 #include "utils/anf_utils.h"
30 #include "utils/interpret_node_recorder.h"
31 #include "utils/symbolic.h"
32 #include "pipeline/jit/ps/parse/resolve.h"
33 #include "pipeline/jit/ps/fallback.h"
34
35 namespace mindspore {
36 /* namespace to support opt */
37 namespace opt {
38 namespace {
39 CNodePtr Transform(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
40 std::map<AnfNodePtr, AnfNodePtr> *has_converted_nodes);
NewValueNodeWithAbstract(const ValuePtr & value,const AbstractBasePtr & abs=nullptr)41 AnfNodePtr NewValueNodeWithAbstract(const ValuePtr &value, const AbstractBasePtr &abs = nullptr) {
42 auto value_node = NewValueNode(value);
43 if (abs != nullptr) {
44 value_node->set_abstract(abs->Clone());
45 } else {
46 value_node->set_abstract(value->ToAbstract());
47 }
48 return value_node;
49 }
50
// If `node` is a ValueNode holding a FuncGraph whose python_obj() is a parse::PyObjectWrapper,
// return a new ValueNode holding a parse::InterpretedObject built from the wrapped Python
// object (no abstract is attached to the new node). Any other node is returned unchanged.
AnfNodePtr FuncGraphToPyData(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<ValueNode>()) {
    return node;
  }
  const auto &node_value = node->cast_ptr<ValueNode>()->value();
  if (!node_value->IsFromTypeId(FuncGraph::kTypeId)) {
    return node;
  }
  auto graph = node_value->cast_ptr<FuncGraph>();
  MS_EXCEPTION_IF_NULL(graph);
  const auto &wrapper = graph->python_obj();
  const bool wraps_py_object = (wrapper != nullptr) && wrapper->isa<parse::PyObjectWrapper>();
  if (!wraps_py_object) {
    return node;
  }
  const auto &py_obj = wrapper->cast_ptr<parse::PyObjectWrapper>()->obj();
  return NewValueNode(std::make_shared<parse::InterpretedObject>(py_obj));
}
69
ConvertValueTupleToList(const AnfNodePtr & node)70 std::vector<AnfNodePtr> ConvertValueTupleToList(const AnfNodePtr &node) {
71 if ((!IsValueNode<ValueTuple>(node) && !IsPrimitiveCNode(node, prim::kPrimMakeTuple))) {
72 MS_LOG(INTERNAL_EXCEPTION) << "The dictionary's keys and values should be a tuple, but got " << node->DebugString();
73 }
74 std::vector<AnfNodePtr> node_list;
75 if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
76 auto cnode = node->cast_ptr<CNode>();
77 auto inputs = cnode->inputs();
78 std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(node_list));
79 return node_list;
80 }
81 auto tuple_value = GetValueNode<ValueTuplePtr>(node);
82 auto value_list = tuple_value->value();
83 std::transform(value_list.begin(), value_list.end(), std::back_inserter(node_list),
84 [](const ValuePtr &value) -> AnfNodePtr { return NewValueNodeWithAbstract(value); });
85 return node_list;
86 }
87
UnzipGlobalDict(const AnfNodePtr & dict_node)88 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> UnzipGlobalDict(const AnfNodePtr &dict_node) {
89 MS_EXCEPTION_IF_NULL(dict_node);
90 std::vector<AnfNodePtr> keys;
91 std::vector<AnfNodePtr> values;
92 if (!dict_node->isa<ValueNode>()) {
93 MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret global dict should be a InterpretedObject value node, but got "
94 << dict_node->DebugString();
95 }
96 // Process the PyInterpret operator defined by the frontend and its global information is empty.
97 if (IsValueNode<ValueDictionary>(dict_node)) {
98 auto dict_input = GetValueNode<ValueDictionaryPtr>(dict_node);
99 if (dict_input->value().empty()) {
100 return std::make_pair(keys, values);
101 }
102 }
103 auto interpreted_object = GetValueNode<parse::InterpretedObjectPtr>(dict_node);
104 MS_EXCEPTION_IF_NULL(interpreted_object);
105 ValuePtr converted_value = nullptr;
106 if (!parse::ConvertData(interpreted_object->obj(), &converted_value)) {
107 MS_LOG(INTERNAL_EXCEPTION) << "Convert data failed";
108 }
109 MS_EXCEPTION_IF_NULL(converted_value);
110 auto dict_value = dyn_cast<ValueDictionary>(converted_value);
111 if (dict_value == nullptr) {
112 MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict or global dict should be a dictionary, but got "
113 << converted_value->ToString();
114 }
115 for (auto item : dict_value->value()) {
116 (void)keys.emplace_back(NewValueNodeWithAbstract(item.first));
117 (void)values.emplace_back(NewValueNodeWithAbstract(item.second));
118 }
119 return std::make_pair(keys, values);
120 }
121
UnzipLocalDict(const AnfNodePtr & dict_node)122 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> UnzipLocalDict(const AnfNodePtr &dict_node) {
123 MS_EXCEPTION_IF_NULL(dict_node);
124 if (dict_node->isa<ValueNode>()) {
125 auto dict_value = GetValueNode<ValueDictionaryPtr>(dict_node);
126 if (dict_value == nullptr) {
127 MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict should be a dictionary, but got "
128 << dict_node->DebugString();
129 }
130
131 auto abs = dict_node->abstract();
132 MS_EXCEPTION_IF_NULL(abs);
133 auto dict_abs = abs->cast<abstract::AbstractDictionaryPtr>();
134 MS_EXCEPTION_IF_NULL(dict_abs);
135 const auto &elements_pair = dict_abs->elements();
136 const auto &dict_value_value = dict_value->value();
137 if (elements_pair.size() != dict_value_value.size()) {
138 MS_LOG(INTERNAL_EXCEPTION) << "For node: " << dict_node->DebugString()
139 << ", the abstract elements size is: " << elements_pair.size()
140 << " and the value elements size is: " << dict_value_value.size()
141 << ". Size not matched.";
142 }
143
144 std::vector<AnfNodePtr> keys;
145 std::vector<AnfNodePtr> values;
146 for (size_t i = 0; i < dict_value_value.size(); ++i) {
147 auto item = dict_value_value[i];
148 (void)keys.emplace_back(NewValueNodeWithAbstract(item.first));
149 // Key element may contain ExtraInfoHolder, need to clone the abstract.
150 (void)values.emplace_back(NewValueNodeWithAbstract(item.second, elements_pair[i].second));
151 }
152 return std::make_pair(keys, values);
153 }
154
155 if (!IsPrimitiveCNode(dict_node, prim::kPrimMakeDict)) {
156 MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret local dict should be a dictionary, but got "
157 << dict_node->DebugString();
158 }
159 auto make_dict_node = dict_node->cast_ptr<CNode>();
160 constexpr auto kMakeDictKeysInputIndex = 1;
161 constexpr auto kMakeDictValueInputIndex = 2;
162 auto keys_input = make_dict_node->input(kMakeDictKeysInputIndex);
163 auto values_input = make_dict_node->input(kMakeDictValueInputIndex);
164
165 auto keys_list = ConvertValueTupleToList(keys_input);
166 auto values_list = ConvertValueTupleToList(values_input);
167 return std::make_pair(keys_list, values_list);
168 }
169
GetLocalKeySet(const std::vector<AnfNodePtr> & key_node_list)170 std::set<std::string> GetLocalKeySet(const std::vector<AnfNodePtr> &key_node_list) {
171 std::set<std::string> key_set;
172 std::transform(key_node_list.begin(), key_node_list.end(), std::inserter(key_set, key_set.begin()),
173 [](const AnfNodePtr &node) -> std::string {
174 auto abs = node->abstract();
175 MS_EXCEPTION_IF_NULL(abs);
176 auto value = abs->BuildValue();
177 MS_EXCEPTION_IF_NULL(value);
178 return GetValue<std::string>(value);
179 });
180 return key_set;
181 }
182
183 // Merge global dict to local dict and return merged key and value
MergeGlobalDictToLocal(const AnfNodePtr & global_dict_node,const AnfNodePtr & local_dict_node,const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & manager,std::map<AnfNodePtr,AnfNodePtr> * has_converted_nodes)184 std::pair<AnfNodePtr, AnfNodePtr> MergeGlobalDictToLocal(const AnfNodePtr &global_dict_node,
185 const AnfNodePtr &local_dict_node,
186 const FuncGraphPtr &func_graph,
187 const FuncGraphManagerPtr &manager,
188 std::map<AnfNodePtr, AnfNodePtr> *has_converted_nodes) {
189 MS_EXCEPTION_IF_NULL(global_dict_node);
190 MS_EXCEPTION_IF_NULL(local_dict_node);
191 auto [global_keys, global_values] = UnzipGlobalDict(global_dict_node);
192 auto [local_keys, local_values] = UnzipLocalDict(local_dict_node);
193
194 auto local_dict_keys_set = GetLocalKeySet(local_keys);
195
196 std::vector<AnfNodePtr> local_keys_inputs{NewValueNode(prim::kPrimMakeTuple)};
197 std::vector<AnfNodePtr> local_value_inputs{NewValueNode(prim::kPrimMakeTuple)};
198 for (size_t index = 0; index < global_keys.size(); ++index) {
199 auto global_key = global_keys.at(index);
200 MS_EXCEPTION_IF_NULL(global_key);
201 auto key = GetValueNode<StringImmPtr>(global_key);
202 if (local_dict_keys_set.find(GetValue<std::string>(key)) != local_dict_keys_set.end()) {
203 MS_LOG(INFO) << "The global dict has the same name with local dict.:" << key->ToString();
204 continue;
205 }
206 MS_LOG(DEBUG) << "The global key " << global_key->DebugString() << ", value "
207 << global_values.at(index)->DebugString() << ". merged in local dict.";
208 (void)local_keys_inputs.emplace_back(global_key);
209 (void)local_value_inputs.emplace_back(FuncGraphToPyData(global_values.at(index)));
210 }
211 std::copy(local_keys.begin(), local_keys.end(), std::back_inserter(local_keys_inputs));
212
213 for (size_t i = 0; i < local_values.size(); ++i) {
214 auto local_value_node = local_values[i];
215 if (!IsPrimitiveCNode(local_value_node, prim::kPrimPyInterpret)) {
216 (void)local_value_inputs.emplace_back(local_value_node);
217 } else if (has_converted_nodes->find(local_value_node) != has_converted_nodes->end()) {
218 (void)local_value_inputs.emplace_back((*has_converted_nodes)[local_value_node]);
219 } else {
220 auto trans_node = Transform(local_value_node->cast<CNodePtr>(), manager, has_converted_nodes);
221 (void)manager->Replace(local_value_node, trans_node);
222 (void)local_value_inputs.emplace_back(trans_node);
223 }
224 }
225 return std::make_pair(func_graph->NewCNode(local_keys_inputs), func_graph->NewCNode(local_value_inputs));
226 }
227
// Build a PyExecute CNode equivalent to the PyInterpret node `cnode`:
//   PyInterpret(script, global_dict, local_dict, ...) -> PyExecute(script, keys, values, ...)
// The new node is a copy of `cnode` (user data cloned). Input 0 becomes PyExecute. A
// parse::Script script is replaced by its string; a StringImm script is kept. Inputs 2 and 3
// become the merged key/value MakeTuple nodes from MergeGlobalDictToLocal. Any further inputs
// are kept unchanged.
// The new node is pushed to InterpretNodeRecorder and recorded in `has_converted_nodes`
// (keyed by `cnode`). The caller is responsible for replacing `cnode` in the graph.
CNodePtr Transform(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
                   std::map<AnfNodePtr, AnfNodePtr> *has_converted_nodes) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(has_converted_nodes);
  constexpr size_t input_index_one = 1;
  constexpr size_t input_index_two = 2;
  constexpr size_t input_index_three = 3;
  // Primitive + script + global dict + local dict.
  constexpr size_t min_input_size = 4;
  if (cnode->inputs().size() < min_input_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "The PyInterpret node should have at least " << min_input_size
                               << " inputs (primitive, script, global dict, local dict), but got "
                               << cnode->DebugString();
  }
  auto new_cnode = std::make_shared<CNode>(*cnode);
  new_cnode->CloneUserData(cnode);
  new_cnode->set_input(0, NewValueNode(prim::kPrimPyExecute));

  // PyExecute takes the script as a plain string; unwrap a parse::Script if necessary.
  const auto &first_input = cnode->input(input_index_one);
  if (IsValueNode<parse::Script>(first_input)) {
    const auto script = GetValueNode<std::shared_ptr<parse::Script>>(first_input);
    new_cnode->set_input(input_index_one, NewValueNode(std::make_shared<StringImm>(script->script())));
  } else if (!IsValueNode<StringImm>(first_input)) {
    MS_LOG(INTERNAL_EXCEPTION) << "The first input should be a Script or string, but got "
                               << cnode->input(input_index_one)->DebugString();
  }
  auto global_dict_node = cnode->input(input_index_two);
  auto local_dict_node = cnode->input(input_index_three);

  auto [local_dict_keys, local_dict_values] =
    MergeGlobalDictToLocal(global_dict_node, local_dict_node, cnode->func_graph(), manager, has_converted_nodes);

  new_cnode->set_input(input_index_two, local_dict_keys);
  new_cnode->set_input(input_index_three, local_dict_values);

  // Record the PyExecute node.
  InterpretNodeRecorder::GetInstance().PushPyExecuteNode(new_cnode);
  (void)has_converted_nodes->emplace(cnode, new_cnode);
  return new_cnode;
}
260 } // namespace
261
262 // Convert PyInterpret into PyExecute:
263 // PyInterpret(script, global_dict, local_dict)
264 // -->
265 // PyExecute(script, local_dict_keys, local_dict_values),
266 // with side-effect operation:
267 // Merge global_dict to local dict.
268 // If there are arguments in global dict and local dict use local dict argument instead of global dict.
PyInterpretToExecute(const pipeline::ResourcePtr & resource)269 bool PyInterpretToExecute(const pipeline::ResourcePtr &resource) {
270 MS_EXCEPTION_IF_NULL(resource);
271 auto manager = resource->manager();
272 MS_EXCEPTION_IF_NULL(manager);
273 auto transact = manager->Transact();
274 const auto all_nodes = manager->all_nodes();
275 std::map<AnfNodePtr, AnfNodePtr> has_converted_nodes;
276 for (const auto &node : all_nodes) {
277 if (IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
278 auto trans_node = Transform(node->cast<CNodePtr>(), manager, &has_converted_nodes);
279 (void)transact.Replace(node, trans_node);
280 }
281 }
282 transact.Commit();
283 return true;
284 }
285 } // namespace opt
286 } // namespace mindspore
287