1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/operator/composite/map.h"
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "ir/anf.h"
25 #include "ir/func_graph.h"
26 #include "abstract/abstract_value.h"
27 #include "abstract/abstract_function.h"
28 #include "abstract/dshape.h"
29 #include "include/common/fallback.h"
30 #include "include/common/pybind_api/api_register.h"
31 #include "pipeline/jit/ps/debug/trace.h"
32 #include "frontend/operator/ops.h"
33 #include "mindspore/core/utils/ms_context.h"
34 #include "pipeline/jit/ps/fallback.h"
35
36 namespace mindspore {
37 // namespace to support composite operators definition
38 namespace prim {
39 using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
40
// Builds a single call node `callee(args...)` inside `func_graph`.
// The callee is `fn_arg` when the caller supplies one; otherwise it is a value node wrapping the
// fixed leaf function `fn_leaf_`. The call is created with NewCNodeInOrder.
AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
  MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n";
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtr callee;
  if (fn_arg == nullptr) {
    callee = NewValueNode(fn_leaf_);
  } else {
    callee = fn_arg;
  }
  std::vector<AnfNodePtr> call_inputs;
  call_inputs.reserve(args.size() + 1);
  call_inputs.push_back(callee);
  (void)call_inputs.insert(call_inputs.end(), args.begin(), args.end());
  return func_graph->NewCNodeInOrder(call_inputs);
}
53
GenerateLeafFunc(const size_t & args_size)54 FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
55 // Generate func for leaf nodes
56 FuncGraphPtr res_fg = std::make_shared<FuncGraph>();
57 res_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
58 res_fg->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
59 res_fg->debug_info()->set_name("map");
60 AnfNodePtr fn_param = nullptr;
61 if (fn_leaf_ == nullptr) {
62 fn_param = res_fg->add_parameter();
63 }
64 AnfNodePtrList args;
65 for (size_t i = 0; i < args_size; ++i) {
66 args.emplace_back(res_fg->add_parameter());
67 }
68 res_fg->set_output(FullMakeLeaf(res_fg, fn_param, args));
69 return res_fg;
70 }
71
GetMapInputIndex(size_t num) const72 std::pair<std::string, std::string> Map::GetMapInputIndex(size_t num) const {
73 std::string error_index;
74 std::string next_index;
75 const size_t first_index = 1;
76 const size_t second_index = 2;
77 if (num == first_index) {
78 // The first element in Map is func_graph
79 error_index = "first";
80 next_index = "second";
81 } else if (num == second_index) {
82 error_index = "second";
83 next_index = "third";
84 } else {
85 error_index = std::to_string(num) + "th";
86 next_index = std::to_string(num + 1) + "th";
87 }
88 return std::pair<std::string, std::string>(error_index, next_index);
89 }
90
// Expands `map(fn, list_0, ..., list_k)` into one MakeList of per-position calls.
// For every position `pos` < `type->elements().size()`, it emits
// `leaf_fg(fn_arg?, ListGetItem(list_0, pos), ..., ListGetItem(list_k, pos))`, where `leaf_fg` comes
// from GenerateLeafFunc. The results are gathered with kPrimMakeList. When `reverse_` is set, the calls
// are created from the last position to the first, but the resulting list keeps the natural element order.
//
// type       : List type whose element count is the required length of every argument.
// func_graph : graph in which all new nodes are created.
// fn_arg     : node of the mapped function, or nullptr when the fixed `fn_leaf_` is used.
// arg_pairs  : (node, type) of every argument being mapped over.
// Raises an exception if an argument is not a List, has dynamic length, or differs in length.
AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                             const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  const std::size_t size = type->elements().size();
  // Every argument is compared against `size`. Report that length under the Map-input position of
  // arg_pairs[0], not under the position of the argument currently being checked.
  const std::string ref_index = GetMapInputIndex(1).second;
  size_t num = 0;
  for (auto &item : arg_pairs) {
    num++;
    MS_EXCEPTION_IF_NULL(item.second);
    // Map's first input is the function, so arg_pairs[num - 1] is Map input number `num + 1`.
    const std::string arg_index = GetMapInputIndex(num).second;
    auto lhs = std::dynamic_pointer_cast<List>(item.second);
    if (lhs == nullptr) {
      MS_LOG(EXCEPTION) << "The " << arg_index << " element in Map has wrong type, expected a List, but got "
                        << item.second->ToString() << ".";
    }
    if (lhs->dynamic_len()) {
      MS_LOG(EXCEPTION) << "For 'map', the dynamic length input is unsupported in graph mode";
    }
    if (lhs->elements().size() != size) {
      MS_LOG(EXCEPTION) << "For 'Map', the length of lists must be the same. "
                        << "\nThe length of the " << ref_index << " element in Map is " << size
                        << ", but the length of the " << arg_index << " element in Map is "
                        << lhs->elements().size() << ".\n";
    }
  }

  constexpr size_t prim_hold_len = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + prim_hold_len);
  inputs.push_back(NewValueNode(prim::kPrimMakeList));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeList for the " << i << "th arg of the target, reverse_: " << reverse_ << ".";
    auto res_fg = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(res_fg);

    std::vector<AnfNodePtr> inputs2;
    inputs2.reserve(arg_pairs.size() + 2);
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    const size_t pos = (reverse_ ? (size - 1 - i) : i);
    // `const auto &` binds directly to the ArgsPairList element (no converting temporary per item).
    (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
                         [&func_graph, pos](const auto &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });

    inputs.emplace_back(func_graph->NewCNodeInOrder(inputs2));
  }
  if (reverse_) {
    // The calls were created for positions size-1 .. 0. Reverse them (O(n), instead of repeated front
    // inserts) so the MakeList elements are in position order 0 .. size-1.
    std::reverse(inputs.begin() + 1, inputs.end());
  }
  return func_graph->NewCNodeInOrder(inputs);
}
154
// Expands `map(fn, tuple_0, ..., tuple_k)` into one MakeTuple of per-position calls.
// For every position `pos` < `type->elements().size()`, it emits
// `leaf_fg(fn_arg?, TupleGetItem(tuple_0, pos), ..., TupleGetItem(tuple_k, pos))`, where `leaf_fg` comes
// from GenerateLeafFunc. The results are gathered with kPrimMakeTuple. When `reverse_` is set, the calls
// are created from the last position to the first, but the resulting tuple keeps the natural element order.
//
// type       : Tuple type whose element count is the required length of every argument.
// func_graph : graph in which all new nodes are created.
// fn_arg     : node of the mapped function, or nullptr when the fixed `fn_leaf_` is used.
// arg_pairs  : (node, type) of every argument being mapped over.
// Raises an exception if an argument is not a Tuple, has dynamic length, or differs in length.
AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  const size_t size = type->elements().size();
  // Every argument is compared against `size`. Report that length under the Map-input position of
  // arg_pairs[0], not under the position of the argument currently being checked.
  const std::string ref_index = GetMapInputIndex(1).second;
  size_t num = 0;
  for (auto &item : arg_pairs) {
    num++;
    MS_EXCEPTION_IF_NULL(item.second);
    // Map's first input is the function, so arg_pairs[num - 1] is Map input number `num + 1`.
    const std::string arg_index = GetMapInputIndex(num).second;
    auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
    if (lhs == nullptr) {
      MS_LOG(EXCEPTION) << "The " << arg_index << " element in Map has wrong type, expected a Tuple, but got "
                        << item.second->ToString() << ".";
    }
    if (lhs->dynamic_len()) {
      MS_LOG(EXCEPTION) << "For 'map', the dynamic length input is unsupported in graph mode";
    }
    if (lhs->elements().size() != size) {
      MS_LOG(EXCEPTION) << "For 'Map', the length of tuples must be the same. "
                        << "\nThe length of the " << ref_index << " element in Map is " << size
                        << ", but the length of the " << arg_index << " element in Map is "
                        << lhs->elements().size() << ".\n";
    }
  }

  constexpr size_t prim_hold_len = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + prim_hold_len);
  inputs.push_back(NewValueNode(prim::kPrimMakeTuple));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th arg of the tuple inputs, reverse_: " << reverse_ << ".";
    auto res_fg = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(res_fg);

    std::vector<AnfNodePtr> inputs2;
    inputs2.reserve(arg_pairs.size() + 2);
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    const size_t pos = (reverse_ ? (size - 1 - i) : i);
    // Capture `pos` by value, as FullMakeList does. `const auto &` avoids a converting temporary per item.
    (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
                         [&func_graph, pos](const auto &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });

    inputs.emplace_back(func_graph->NewCNodeInOrder(inputs2));
  }
  if (reverse_) {
    // The calls were created for positions size-1 .. 0. Reverse them (O(n), instead of repeated front
    // inserts) so the MakeTuple elements are in position order 0 .. size-1.
    std::reverse(inputs.begin() + 1, inputs.end());
  }
  return func_graph->NewCNodeInOrder(inputs);
}
218
// Dispatches Map expansion based on the arguments' types.
// It finds the first argument whose type id is in `nonleaf_`. If one exists, every other argument must
// have that same type id; otherwise a detailed type report is raised. The expansion then goes to
// FullMakeList or FullMakeTuple, and any other type id raises an exception.
// Raises TypeError when `arg_pairs` is empty (Map needs the function plus at least one argument).
AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  if (arg_pairs.empty()) {
    MS_EXCEPTION(TypeError) << "The Map operator must have at least two arguments. But the size of arguments is "
                            << (arg_pairs.size() + 1) << ".";
  }
  bool found = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  for (auto &arg_pair : arg_pairs) {
    pair = arg_pair;
    MS_LOG(DEBUG) << "Map " << pair.second->ToString();
    id = arg_pair.second->type_id();
    if (nonleaf_.count(id) != 0) {
      found = true;
      break;
    }
  }

  if (found) {
    // In a nonleaf situation, all arguments must have the same generic.
    const bool is_not_same = std::any_of(arg_pairs.begin(), arg_pairs.end(), [&pair](const auto &item) {
      return item.first != pair.first && item.second->type_id() != pair.second->type_id();
    });
    if (is_not_same) {
      std::ostringstream oss;
      oss << "There are " << (arg_pairs.size() + 1) << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfoStr(func_graph->debug_info()) << ".\n";
      size_t num = 0;
      for (auto &item : arg_pairs) {
        ++num;
        // Reuse the shared ordinal helper. Map's first input is the function, so arg_pairs[num - 1] is
        // reported as Map input `num + 1` ("second", "third", "4th", ...).
        oss << "The type of the " << GetMapInputIndex(num).second << " argument in Map is: "
            << item.second->ToString() << ".\n";
      }
      MS_LOG(EXCEPTION) << "The types of arguments in Map must be consistent, "
                        << "but the types of arguments are inconsistent.\n"
                        << oss.str();
    }
  }

  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMakeList(type, func_graph, fn_arg, arg_pairs);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
    }
    default:
      MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple, but got " << pair.second->ToString() << ".";
  }
}
284
GenerateFromTypes(const TypePtrList & args_abs_list)285 FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_abs_list) {
286 bool convert_to_interpret = std::any_of(args_abs_list.begin() + 1, args_abs_list.end(), [](const TypePtr &type) {
287 MS_EXCEPTION_IF_NULL(type);
288 return type->isa<AnyType>() || type->isa<External>();
289 });
290 if (convert_to_interpret) {
291 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
292 const std::vector<std::string> funcs_str{"map"};
293 auto ret_node = fallback::GeneratePyInterpretWithAbstract(func_graph, funcs_str, args_abs_list.size());
294 func_graph->set_output(ret_node);
295 return func_graph;
296 }
297 FuncGraphPtr res_fg = std::make_shared<FuncGraph>();
298 res_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
299 res_fg->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
300 res_fg->debug_info()->set_name("map");
301
302 AnfNodePtr fn_param = nullptr;
303 std::size_t i = 0;
304 if (fn_leaf_ == nullptr) {
305 fn_param = res_fg->add_parameter();
306 i = 1;
307 }
308 ArgsPairList arg_pairs;
309 std::size_t size = args_abs_list.size();
310 for (; i < size; ++i) {
311 MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_abs_list[i]->ToString() << ".";
312 arg_pairs.push_back(std::make_pair(res_fg->add_parameter(), args_abs_list[i]));
313 }
314
315 res_fg->set_output(Make(res_fg, fn_param, arg_pairs));
316 return res_fg;
317 }
318
NormalizeArgs(const AbstractBasePtrList & args_abs_list) const319 abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_abs_list) const {
320 if (fn_leaf_ == nullptr) {
321 if (args_abs_list.empty()) {
322 MS_LOG(EXCEPTION) << "The arguments of Map operator should not be empty.";
323 }
324 MS_EXCEPTION_IF_NULL(args_abs_list[0]);
325 // Assert that map's function param does not contain free variables
326 if (args_abs_list[0]->isa<FuncGraphAbstractClosure>()) {
327 auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_abs_list[0]);
328 auto func_graph = graph_func->func_graph();
329 if (func_graph->parent() != nullptr) {
330 MS_LOG(EXCEPTION) << "The Map operator don't support Closure with free variable yet.";
331 }
332 }
333 }
334
335 bool convert_to_interpret =
336 std::any_of(args_abs_list.begin() + 1, args_abs_list.end(), [](const AbstractBasePtr &abs) {
337 MS_EXCEPTION_IF_NULL(abs);
338 return abs->isa<abstract::AbstractAny>() || abs->BuildValue()->isa<parse::InterpretedObject>();
339 });
340 if (convert_to_interpret) {
341 // If the map operation has interpreted object or any object, the map will be converted to PyInterpret node.
342 // So, we can not broaden the args, since the broaden will convert PyInterpret node to PyExecute automatically.
343 return args_abs_list;
344 }
345
346 AbstractBasePtrList broadened;
347 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(broadened),
348 [](const AbstractBasePtr &arg) -> AbstractBasePtr {
349 MS_EXCEPTION_IF_NULL(arg);
350 return arg->Broaden();
351 });
352 return broadened;
353 }
354 } // namespace prim
355 } // namespace mindspore
356