• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/operator/composite/map.h"
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "ir/anf.h"
24 #include "ir/func_graph.h"
25 #include "abstract/abstract_value.h"
26 #include "abstract/abstract_function.h"
27 #include "abstract/dshape.h"
28 #include "pybind_api/api_register.h"
29 #include "debug/trace.h"
30 #include "frontend/operator/ops.h"
31 
32 namespace mindspore {
33 // namespace to support composite operators definition
34 namespace prim {
35 using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
36 
// Creates the call node applied to one group of leaf arguments.
//
// func_graph: graph in which the call node is created; must not be null.
// fn_arg:     node used as the callee; when null, a value node wrapping fn_leaf_ is used instead.
// args:       nodes appended after the callee, in order.
// Returns the new CNode built from {callee, args...}.
AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
  MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n";
  MS_EXCEPTION_IF_NULL(func_graph);

  std::vector<AnfNodePtr> call_inputs;
  call_inputs.reserve(args.size() + 1);
  if (fn_arg == nullptr) {
    // No explicit function node was supplied: fall back to the stored leaf function.
    call_inputs.push_back(NewValueNode(fn_leaf_));
  } else {
    call_inputs.push_back(fn_arg);
  }
  (void)std::copy(args.begin(), args.end(), std::back_inserter(call_inputs));
  return func_graph->NewCNodeInOrder(call_inputs);
}
49 
GenerateLeafFunc(const size_t & args_size)50 FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
51   // Generate func for leaf nodes
52   FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
53   ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
54   ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
55   ptrGraph->debug_info()->set_name("map");
56   AnfNodePtr ptrFnArg = nullptr;
57   if (fn_leaf_ == nullptr) {
58     ptrFnArg = ptrGraph->add_parameter();
59   }
60   AnfNodePtrList args;
61   for (size_t i = 0; i < args_size; ++i) {
62     args.emplace_back(ptrGraph->add_parameter());
63   }
64   ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args));
65   return ptrGraph;
66 }
67 
FullMakeList(const std::shared_ptr<List> & type,const FuncGraphPtr & func_graph,const AnfNodePtr & fn_arg,const ArgsPairList & arg_pairs)68 AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
69                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
70   MS_EXCEPTION_IF_NULL(func_graph);
71   MS_EXCEPTION_IF_NULL(type);
72 
73   std::size_t size = type->elements().size();
74   size_t num = 0;
75   bool is_not_same =
76     std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
77       num++;
78       auto lhs = std::dynamic_pointer_cast<List>(item.second);
79       if (lhs == nullptr) {
80         MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a List, but got "
81                           << item.second->ToString();
82       }
83       if (lhs->elements().size() != size) {
84         MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
85                       << lhs->elements().size();
86         return true;
87       }
88       return false;
89     });
90   if (is_not_same) {
91     MS_LOG(EXCEPTION) << "List in Map should have same length";
92   }
93 
94   constexpr size_t kPrimHoldLen = 1;
95   std::vector<AnfNodePtr> inputs;
96   inputs.reserve(size + kPrimHoldLen);
97   inputs.push_back(NewValueNode(prim::kPrimMakeList));
98 
99   for (size_t i = 0; i < size; i++) {
100     MS_LOG(DEBUG) << "FullMakeList for the " << i << "th arg of the target, reverse_: " << reverse_;
101     auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
102     auto fn = NewValueNode(ptrGraph);
103 
104     std::vector<AnfNodePtr> inputs2;
105     inputs2.push_back(fn);
106     if (fn_arg != nullptr) {
107       inputs2.push_back(fn_arg);
108     }
109 
110     size_t pos = (reverse_ ? (size - 1 - i) : i);
111     (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
112                          [&func_graph, pos](const std::pair<AnfNodePtr, Any> &item) {
113                            return func_graph->NewCNodeInOrder(
114                              {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
115                          });
116 
117     auto call_node = func_graph->NewCNodeInOrder(inputs2);
118     if (reverse_) {
119       (void)inputs.insert(inputs.begin() + 1, call_node);
120     } else {
121       inputs.emplace_back(call_node);
122     }
123   }
124   return func_graph->NewCNodeInOrder(inputs);
125 }
126 
FullMakeTuple(const std::shared_ptr<Tuple> & type,const FuncGraphPtr & func_graph,const AnfNodePtr & fn_arg,const ArgsPairList & arg_pairs)127 AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
128                               const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
129   MS_EXCEPTION_IF_NULL(func_graph);
130   MS_EXCEPTION_IF_NULL(type);
131 
132   size_t size = type->elements().size();
133   size_t num = 0;
134   bool is_not_same =
135     std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
136       num++;
137       auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
138       if (lhs == nullptr) {
139         MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a Tuple, but got "
140                           << item.second->ToString();
141       }
142       if (lhs->elements().size() != size) {
143         MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
144                       << lhs->elements().size();
145         return true;
146       }
147       return false;
148     });
149   if (is_not_same) {
150     MS_LOG(EXCEPTION) << "tuple in Map should have same length";
151   }
152 
153   constexpr size_t kPrimHoldLen = 1;
154   std::vector<AnfNodePtr> inputs;
155   inputs.reserve(size + kPrimHoldLen);
156   inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
157 
158   for (size_t i = 0; i < size; i++) {
159     MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th arg of the tuple inputs, reverse_: " << reverse_;
160     auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
161     auto fn = NewValueNode(ptrGraph);
162 
163     std::vector<AnfNodePtr> inputs2;
164     inputs2.push_back(fn);
165     if (fn_arg != nullptr) {
166       inputs2.push_back(fn_arg);
167     }
168 
169     size_t pos = (reverse_ ? (size - 1 - i) : i);
170     (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
171                          [&func_graph, &pos](const std::pair<AnfNodePtr, Any> &item) {
172                            return func_graph->NewCNodeInOrder(
173                              {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
174                          });
175 
176     auto call_node = func_graph->NewCNodeInOrder(inputs2);
177     if (reverse_) {
178       (void)inputs.insert(inputs.begin() + 1, call_node);
179     } else {
180       inputs.emplace_back(call_node);
181     }
182   }
183   return func_graph->NewCNodeInOrder(inputs);
184 }
185 
// Builds a MakeRecord node for a Class type with one GenerateLeafFunc call per attribute of `type`.
//
// type:       Class type; its attribute count decides how many calls are created. Must not be null.
// func_graph: graph in which all new nodes are created; must not be null.
// fn_arg:     optional function node forwarded as the first call argument; skipped when null.
// arg_pairs:  (node, type) pairs of the arguments.
//
// NOTE(review): the GetAttr index passed below is the argument's position in arg_pairs, not the attribute
// index of the outer loop - this is kept exactly as in the existing code; confirm it is intended.
AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(func_graph);

  // The record inputs start with the primitive followed by the class type.
  constexpr size_t kPrimAndTypeLen = 2;
  const size_t attr_count = type->GetAttributes().size();
  std::vector<AnfNodePtr> record_inputs;
  record_inputs.reserve(attr_count + kPrimAndTypeLen);
  record_inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  record_inputs.push_back(NewValueNode(type));

  for (size_t attr_idx = 0; attr_idx < attr_count; ++attr_idx) {
    MS_LOG(DEBUG) << "FullMakeClass for the " << attr_idx << "th element of the inputs, reverse_: " << reverse_;
    auto leaf_graph = GenerateLeafFunc(arg_pairs.size());
    auto leaf_fn = NewValueNode(leaf_graph);

    std::vector<AnfNodePtr> call_inputs;
    call_inputs.push_back(leaf_fn);
    if (fn_arg != nullptr) {
      call_inputs.push_back(fn_arg);
    }

    // Walk the arguments front-to-back, or back-to-front when reverse_ is set.
    const size_t arg_count = arg_pairs.size();
    for (size_t step = 0; step < arg_count; ++step) {
      const size_t pos = reverse_ ? (arg_count - 1 - step) : step;
      const auto &arg = arg_pairs[pos];
      auto attr_node =
        func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), arg.first, NewValueNode(SizeToLong(pos))});
      call_inputs.push_back(attr_node);
    }

    auto call_node = func_graph->NewCNodeInOrder(call_inputs);
    if (!reverse_) {
      record_inputs.emplace_back(call_node);
    } else {
      // In reverse mode each new call is placed right after the primitive and the type.
      constexpr auto kCallNodePosition = 2;
      (void)record_inputs.insert(record_inputs.begin() + kCallNodePosition, call_node);
    }
  }
  return func_graph->NewCNodeInOrder(record_inputs);
}
227 
Make(const FuncGraphPtr & func_graph,const AnfNodePtr & fn_arg,const ArgsPairList & arg_pairs)228 AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
229   if (arg_pairs.empty()) {
230     MS_EXCEPTION(TypeError) << "map() must have at least two arguments";
231   }
232   bool found = false;
233   TypeId id = kObjectTypeEnd;
234   std::pair<AnfNodePtr, TypePtr> pair;
235   for (auto &arg_pair : arg_pairs) {
236     pair = arg_pair;
237     MS_LOG(DEBUG) << "Map " << pair.second->ToString();
238     id = arg_pair.second->type_id();
239     if (nonleaf_.count(id)) {
240       found = true;
241       break;
242     }
243   }
244 
245   if (found) {
246     // In a nonleaf situation, all arguments must have the same generic.
247     bool is_not_same =
248       std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
249         if (item.first != pair.first) {
250           return item.second->type_id() != pair.second->type_id();
251         }
252         return false;
253       });
254     if (is_not_same) {
255       std::ostringstream oss;
256       oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
257           << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
258       int64_t idx = 0;
259       for (auto &item : arg_pairs) {
260         oss << ++idx << ": " << item.second->ToString() << "\n";
261       }
262       MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n"
263                         << oss.str() << pair.second->ToString() << "\n";
264     }
265   }
266 
267   switch (id) {
268     case kObjectTypeList: {
269       auto type = std::static_pointer_cast<List>(pair.second);
270       return FullMakeList(type, func_graph, fn_arg, arg_pairs);
271     }
272     case kObjectTypeTuple: {
273       auto type = std::static_pointer_cast<Tuple>(pair.second);
274       return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
275     }
276     case kObjectTypeClass: {
277       auto type = std::static_pointer_cast<Class>(pair.second);
278       return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
279     }
280     default:
281       MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class "
282                         << ", but got " << pair.second->ToString();
283   }
284 }
285 
GenerateFromTypes(const TypePtrList & args_spec_list)286 FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
287   FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
288   ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
289   ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
290   ptrGraph->debug_info()->set_name("map");
291 
292   AnfNodePtr ptrFnArg = nullptr;
293   std::size_t i = 0;
294   if (fn_leaf_ == nullptr) {
295     ptrFnArg = ptrGraph->add_parameter();
296     i = 1;
297   }
298   ArgsPairList arg_pairs;
299   std::size_t size = args_spec_list.size();
300   for (; i < size; ++i) {
301     MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString();
302     arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
303   }
304 
305   ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs));
306   return ptrGraph;
307 }
308 
NormalizeArgs(const AbstractBasePtrList & args_spec_list) const309 abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
310   if (fn_leaf_ == nullptr) {
311     if (args_spec_list.empty()) {
312       MS_LOG(EXCEPTION) << "The args spec list should not be empty.";
313     }
314     MS_EXCEPTION_IF_NULL(args_spec_list[0]);
315     // Assert that map's function param does not contain free variables
316     if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
317       auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
318       auto func_graph = graph_func->func_graph();
319       if (func_graph->parent() != nullptr) {
320         MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet.";
321       }
322     }
323   }
324 
325   AbstractBasePtrList broadened;
326   (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
327                        [](const AbstractBasePtr &arg) -> AbstractBasePtr {
328                          MS_EXCEPTION_IF_NULL(arg);
329                          return arg->Broaden();
330                        });
331   return broadened;
332 }
333 
__anond2cf421d0702(const py::module *m) 334 REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
335                          (void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
336                            .def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"),
337                                 py::arg("ops"))
338                            .def(py::init<bool>(), py::arg("reverse"));
339                        }));
340 }  // namespace prim
341 }  // namespace mindspore
342