• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/operator/composite/map.h"
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "ir/anf.h"
25 #include "ir/func_graph.h"
26 #include "abstract/abstract_value.h"
27 #include "abstract/abstract_function.h"
28 #include "abstract/dshape.h"
29 #include "include/common/fallback.h"
30 #include "include/common/pybind_api/api_register.h"
31 #include "pipeline/jit/ps/debug/trace.h"
32 #include "frontend/operator/ops.h"
33 #include "mindspore/core/utils/ms_context.h"
34 #include "pipeline/jit/ps/fallback.h"
35 
36 namespace mindspore {
37 // namespace to support composite operators definition
38 namespace prim {
39 using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
40 
/// Build the call node for a leaf (non-sequence) application of Map.
///
/// The callee is `fn_arg` when one is supplied; otherwise a value node wrapping `fn_leaf_` is used.
/// All entries of `args` follow the callee, in order.
/// @param func_graph Graph that receives the new CNode; must not be null.
/// @param fn_arg Explicit callee node, or nullptr to fall back to `fn_leaf_`.
/// @param args Arguments appended after the callee.
/// @return The CNode created with `NewCNodeInOrder`.
AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
  MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n";
  MS_EXCEPTION_IF_NULL(func_graph);

  // Pick the callee: an explicit node wins over the bound leaf function.
  AnfNodePtr callee = fn_arg;
  if (callee == nullptr) {
    callee = NewValueNode(fn_leaf_);
  }

  // Slot 0 is the callee, followed by every argument.
  std::vector<AnfNodePtr> call_inputs;
  call_inputs.reserve(args.size() + 1);
  call_inputs.push_back(callee);
  (void)call_inputs.insert(call_inputs.end(), args.begin(), args.end());
  return func_graph->NewCNodeInOrder(call_inputs);
}
53 
GenerateLeafFunc(const size_t & args_size)54 FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
55   // Generate func for leaf nodes
56   FuncGraphPtr res_fg = std::make_shared<FuncGraph>();
57   res_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
58   res_fg->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
59   res_fg->debug_info()->set_name("map");
60   AnfNodePtr fn_param = nullptr;
61   if (fn_leaf_ == nullptr) {
62     fn_param = res_fg->add_parameter();
63   }
64   AnfNodePtrList args;
65   for (size_t i = 0; i < args_size; ++i) {
66     args.emplace_back(res_fg->add_parameter());
67   }
68   res_fg->set_output(FullMakeLeaf(res_fg, fn_param, args));
69   return res_fg;
70 }
71 
GetMapInputIndex(size_t num) const72 std::pair<std::string, std::string> Map::GetMapInputIndex(size_t num) const {
73   std::string error_index;
74   std::string next_index;
75   const size_t first_index = 1;
76   const size_t second_index = 2;
77   if (num == first_index) {
78     // The first element in Map is func_graph
79     error_index = "first";
80     next_index = "second";
81   } else if (num == second_index) {
82     error_index = "second";
83     next_index = "third";
84   } else {
85     error_index = std::to_string(num) + "th";
86     next_index = std::to_string(num + 1) + "th";
87   }
88   return std::pair<std::string, std::string>(error_index, next_index);
89 }
90 
FullMakeList(const std::shared_ptr<List> & type,const FuncGraphPtr & func_graph,const AnfNodePtr & fn_arg,const ArgsPairList & arg_pairs)91 AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
92                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
93   MS_EXCEPTION_IF_NULL(func_graph);
94   MS_EXCEPTION_IF_NULL(type);
95 
96   std::size_t size = type->elements().size();
97   size_t num = 0;
98   std::ostringstream oss;
99   bool is_not_same = false;
100   for (auto &item : arg_pairs) {
101     num++;
102     auto lhs = std::dynamic_pointer_cast<List>(item.second);
103     auto [error_index, next_index] = GetMapInputIndex(num);
104     if (lhs == nullptr) {
105       MS_LOG(EXCEPTION) << "The " << error_index << " element in Map has wrong type, expected a List, but got "
106                         << item.second->ToString() << ".";
107     }
108     if (lhs->dynamic_len()) {
109       MS_LOG(EXCEPTION) << "For 'map', the dynamic length input is unsupported in graph mode";
110     }
111     if (lhs->elements().size() != size) {
112       oss << "\nThe length of the " << error_index << " element in Map is " << size << ", but the length of the "
113           << next_index << " element in Map is " << lhs->elements().size() << ".\n";
114       is_not_same = true;
115       break;
116     }
117   }
118   if (is_not_same) {
119     MS_LOG(EXCEPTION) << "For 'Map', the length of lists must be the same. " << oss.str();
120   }
121 
122   constexpr size_t prim_hold_len = 1;
123   std::vector<AnfNodePtr> inputs;
124   inputs.reserve(size + prim_hold_len);
125   inputs.push_back(NewValueNode(prim::kPrimMakeList));
126 
127   for (size_t i = 0; i < size; i++) {
128     MS_LOG(DEBUG) << "FullMakeList for the " << i << "th arg of the target, reverse_: " << reverse_ << ".";
129     auto res_fg = GenerateLeafFunc(arg_pairs.size());
130     auto fn = NewValueNode(res_fg);
131 
132     std::vector<AnfNodePtr> inputs2;
133     inputs2.push_back(fn);
134     if (fn_arg != nullptr) {
135       inputs2.push_back(fn_arg);
136     }
137 
138     size_t pos = (reverse_ ? (size - 1 - i) : i);
139     (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
140                          [&func_graph, pos](const std::pair<AnfNodePtr, Any> &item) {
141                            return func_graph->NewCNodeInOrder(
142                              {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
143                          });
144 
145     auto call_node = func_graph->NewCNodeInOrder(inputs2);
146     if (reverse_) {
147       (void)inputs.insert(inputs.cbegin() + 1, call_node);
148     } else {
149       inputs.emplace_back(call_node);
150     }
151   }
152   return func_graph->NewCNodeInOrder(inputs);
153 }
154 
FullMakeTuple(const std::shared_ptr<Tuple> & type,const FuncGraphPtr & func_graph,const AnfNodePtr & fn_arg,const ArgsPairList & arg_pairs)155 AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
156                               const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
157   MS_EXCEPTION_IF_NULL(func_graph);
158   MS_EXCEPTION_IF_NULL(type);
159 
160   size_t size = type->elements().size();
161   size_t num = 0;
162   std::ostringstream oss;
163   bool is_not_same = false;
164   for (auto &item : arg_pairs) {
165     num++;
166     auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
167     auto [error_index, next_index] = GetMapInputIndex(num);
168     if (lhs == nullptr) {
169       MS_LOG(EXCEPTION) << "The " << error_index << " element in Map has wrong type, expected a Tuple, but got "
170                         << item.second->ToString() << ".";
171     }
172     if (lhs->dynamic_len()) {
173       MS_LOG(EXCEPTION) << "For 'map', the dynamic length input is unsupported in graph mode";
174     }
175     if (lhs->elements().size() != size) {
176       oss << "\nThe length of the " << error_index << " element in Map is " << size << ", but the length of the "
177           << next_index << " element in Map is " << lhs->elements().size() << ".\n";
178       is_not_same = true;
179       break;
180     }
181   }
182   if (is_not_same) {
183     MS_LOG(EXCEPTION) << "For 'Map', the length of tuples must be the same. " << oss.str();
184   }
185 
186   constexpr size_t prim_hold_len = 1;
187   std::vector<AnfNodePtr> inputs;
188   inputs.reserve(size + prim_hold_len);
189   inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
190 
191   for (size_t i = 0; i < size; i++) {
192     MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th arg of the tuple inputs, reverse_: " << reverse_ << ".";
193     auto res_fg = GenerateLeafFunc(arg_pairs.size());
194     auto fn = NewValueNode(res_fg);
195 
196     std::vector<AnfNodePtr> inputs2;
197     inputs2.push_back(fn);
198     if (fn_arg != nullptr) {
199       inputs2.push_back(fn_arg);
200     }
201 
202     size_t pos = (reverse_ ? (size - 1 - i) : i);
203     (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
204                          [&func_graph, &pos](const std::pair<AnfNodePtr, Any> &item) {
205                            return func_graph->NewCNodeInOrder(
206                              {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
207                          });
208 
209     auto call_node = func_graph->NewCNodeInOrder(inputs2);
210     if (reverse_) {
211       (void)inputs.insert(inputs.cbegin() + 1, call_node);
212     } else {
213       inputs.emplace_back(call_node);
214     }
215   }
216   return func_graph->NewCNodeInOrder(inputs);
217 }
218 
Make(const FuncGraphPtr & func_graph,const AnfNodePtr & fn_arg,const ArgsPairList & arg_pairs)219 AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
220   if (arg_pairs.empty()) {
221     MS_EXCEPTION(TypeError) << "The Map operator must have at least two arguments. But the size of arguments is "
222                             << (arg_pairs.size() + 1) << ".";
223   }
224   bool found = false;
225   TypeId id = kObjectTypeEnd;
226   std::pair<AnfNodePtr, TypePtr> pair;
227   for (auto &arg_pair : arg_pairs) {
228     pair = arg_pair;
229     MS_LOG(DEBUG) << "Map " << pair.second->ToString();
230     id = arg_pair.second->type_id();
231     if (nonleaf_.count(id) != 0) {
232       found = true;
233       break;
234     }
235   }
236 
237   if (found) {
238     // In a nonleaf situation, all arguments must have the same generic.
239     bool is_not_same =
240       std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
241         if (item.first != pair.first) {
242           return item.second->type_id() != pair.second->type_id();
243         }
244         return false;
245       });
246     if (is_not_same) {
247       std::ostringstream oss;
248       oss << "There are " << (arg_pairs.size() + 1) << " inputs of `" << name_ << "`, corresponding type info:\n"
249           << trace::GetDebugInfoStr(func_graph->debug_info()) << ".\n";
250       int64_t idx = 0;
251       std::string str_index = "first";
252       for (auto &item : arg_pairs) {
253         if (idx == 0) {
254           // The first element in HyperMap is func_graph
255           str_index = "second";
256         } else if (idx == 1) {
257           str_index = "third";
258         } else {
259           constexpr auto arg_start_idx = 2;
260           str_index = std::to_string(idx + arg_start_idx) + "th";
261         }
262         ++idx;
263         oss << "The type of the " << str_index << " argument in Map is: " << item.second->ToString() << ".\n";
264       }
265       MS_LOG(EXCEPTION) << "The types of arguments in Map must be consistent, "
266                         << "but the types of arguments are inconsistent.\n"
267                         << oss.str();
268     }
269   }
270 
271   switch (id) {
272     case kObjectTypeList: {
273       auto type = std::static_pointer_cast<List>(pair.second);
274       return FullMakeList(type, func_graph, fn_arg, arg_pairs);
275     }
276     case kObjectTypeTuple: {
277       auto type = std::static_pointer_cast<Tuple>(pair.second);
278       return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
279     }
280     default:
281       MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple, but got " << pair.second->ToString() << ".";
282   }
283 }
284 
GenerateFromTypes(const TypePtrList & args_abs_list)285 FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_abs_list) {
286   bool convert_to_interpret = std::any_of(args_abs_list.begin() + 1, args_abs_list.end(), [](const TypePtr &type) {
287     MS_EXCEPTION_IF_NULL(type);
288     return type->isa<AnyType>() || type->isa<External>();
289   });
290   if (convert_to_interpret) {
291     FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
292     const std::vector<std::string> funcs_str{"map"};
293     auto ret_node = fallback::GeneratePyInterpretWithAbstract(func_graph, funcs_str, args_abs_list.size());
294     func_graph->set_output(ret_node);
295     return func_graph;
296   }
297   FuncGraphPtr res_fg = std::make_shared<FuncGraph>();
298   res_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
299   res_fg->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
300   res_fg->debug_info()->set_name("map");
301 
302   AnfNodePtr fn_param = nullptr;
303   std::size_t i = 0;
304   if (fn_leaf_ == nullptr) {
305     fn_param = res_fg->add_parameter();
306     i = 1;
307   }
308   ArgsPairList arg_pairs;
309   std::size_t size = args_abs_list.size();
310   for (; i < size; ++i) {
311     MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_abs_list[i]->ToString() << ".";
312     arg_pairs.push_back(std::make_pair(res_fg->add_parameter(), args_abs_list[i]));
313   }
314 
315   res_fg->set_output(Make(res_fg, fn_param, arg_pairs));
316   return res_fg;
317 }
318 
NormalizeArgs(const AbstractBasePtrList & args_abs_list) const319 abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_abs_list) const {
320   if (fn_leaf_ == nullptr) {
321     if (args_abs_list.empty()) {
322       MS_LOG(EXCEPTION) << "The arguments of Map operator should not be empty.";
323     }
324     MS_EXCEPTION_IF_NULL(args_abs_list[0]);
325     // Assert that map's function param does not contain free variables
326     if (args_abs_list[0]->isa<FuncGraphAbstractClosure>()) {
327       auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_abs_list[0]);
328       auto func_graph = graph_func->func_graph();
329       if (func_graph->parent() != nullptr) {
330         MS_LOG(EXCEPTION) << "The Map operator don't support Closure with free variable yet.";
331       }
332     }
333   }
334 
335   bool convert_to_interpret =
336     std::any_of(args_abs_list.begin() + 1, args_abs_list.end(), [](const AbstractBasePtr &abs) {
337       MS_EXCEPTION_IF_NULL(abs);
338       return abs->isa<abstract::AbstractAny>() || abs->BuildValue()->isa<parse::InterpretedObject>();
339     });
340   if (convert_to_interpret) {
341     // If the map operation has interpreted object or any object, the map will be converted to PyInterpret node.
342     // So, we can not broaden the args, since the broaden will convert PyInterpret node to PyExecute automatically.
343     return args_abs_list;
344   }
345 
346   AbstractBasePtrList broadened;
347   (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(broadened),
348                        [](const AbstractBasePtr &arg) -> AbstractBasePtr {
349                          MS_EXCEPTION_IF_NULL(arg);
350                          return arg->Broaden();
351                        });
352   return broadened;
353 }
354 }  // namespace prim
355 }  // namespace mindspore
356