1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/operator/composite/map.h"
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22
23 #include "ir/anf.h"
24 #include "ir/func_graph.h"
25 #include "abstract/abstract_value.h"
26 #include "abstract/abstract_function.h"
27 #include "abstract/dshape.h"
28 #include "pybind_api/api_register.h"
29 #include "debug/trace.h"
30 #include "frontend/operator/ops.h"
31
32 namespace mindspore {
33 // namespace to support composite operators definition
34 namespace prim {
// Short local alias for the abstract value describing a FuncGraph closure (checked in Map::NormalizeArgs).
using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
36
// Build the leaf call node `callee(args...)` inside `func_graph`.
//
// The callee is `fn_arg` when it is non-null (a graph node supplying the function to apply); otherwise it is
// a value node wrapping the bound leaf function `fn_leaf_`. Throws if `func_graph` is null.
AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
  MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n";
  MS_EXCEPTION_IF_NULL(func_graph);

  std::vector<AnfNodePtr> call_inputs;
  call_inputs.reserve(args.size() + 1);
  if (fn_arg == nullptr) {
    call_inputs.push_back(NewValueNode(fn_leaf_));
  } else {
    call_inputs.push_back(fn_arg);
  }
  (void)std::copy(args.begin(), args.end(), std::back_inserter(call_inputs));
  return func_graph->NewCNodeInOrder(call_inputs);
}
49
GenerateLeafFunc(const size_t & args_size)50 FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
51 // Generate func for leaf nodes
52 FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
53 ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
54 ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
55 ptrGraph->debug_info()->set_name("map");
56 AnfNodePtr ptrFnArg = nullptr;
57 if (fn_leaf_ == nullptr) {
58 ptrFnArg = ptrGraph->add_parameter();
59 }
60 AnfNodePtrList args;
61 for (size_t i = 0; i < args_size; ++i) {
62 args.emplace_back(ptrGraph->add_parameter());
63 }
64 ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args));
65 return ptrGraph;
66 }
67
// Expand a Map over List arguments into MakeList(call_0, call_1, ...).
//
// For every element position, a leaf graph (GenerateLeafFunc) is called with `fn_arg` (when non-null)
// followed by that element of each argument, fetched with ListGetItem.
// `type` supplies the expected element count. Every entry of `arg_pairs` must be a List of that length,
// otherwise an exception is raised.
// When `reverse_` is set, the element call nodes are created from the last position to the first, but the
// resulting list still keeps the original element order.
AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
                             const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  std::size_t size = type->elements().size();
  size_t num = 0;
  bool is_not_same =
    std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
      num++;
      // The type is printed in the error below; reject a null type explicitly rather than crashing.
      MS_EXCEPTION_IF_NULL(item.second);
      auto lhs = std::dynamic_pointer_cast<List>(item.second);
      if (lhs == nullptr) {
        MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a List, but got "
                          << item.second->ToString();
      }
      if (lhs->elements().size() != size) {
        MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
                      << lhs->elements().size();
        return true;
      }
      return false;
    });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "List in Map should have same length";
  }

  constexpr size_t kPrimHoldLen = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + kPrimHoldLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeList));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeList for the " << i << "th arg of the target, reverse_: " << reverse_;
    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(ptrGraph);

    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    size_t pos = (reverse_ ? (size - 1 - i) : i);
    // Take the element by its real type (same as ArgsPairList's element) so no temporary
    // pair/Any conversion is built per argument.
    (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
                         [&func_graph, pos](const std::pair<AnfNodePtr, TypePtr> &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });

    auto call_node = func_graph->NewCNodeInOrder(inputs2);
    if (reverse_) {
      // Calls are produced last-to-first; inserting right after the primitive restores element order.
      (void)inputs.insert(inputs.begin() + 1, call_node);
    } else {
      inputs.emplace_back(call_node);
    }
  }
  return func_graph->NewCNodeInOrder(inputs);
}
126
// Expand a Map over Tuple arguments into MakeTuple(call_0, call_1, ...).
//
// For every element position, a leaf graph (GenerateLeafFunc) is called with `fn_arg` (when non-null)
// followed by that element of each argument, fetched with TupleGetItem.
// `type` supplies the expected element count. Every entry of `arg_pairs` must be a Tuple of that length,
// otherwise an exception is raised.
// When `reverse_` is set, the element call nodes are created from the last position to the first, but the
// resulting tuple still keeps the original element order.
AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(type);

  size_t size = type->elements().size();
  size_t num = 0;
  bool is_not_same =
    std::any_of(arg_pairs.begin(), arg_pairs.end(), [&num, size](const std::pair<AnfNodePtr, TypePtr> &item) {
      num++;
      // The type is printed in the error below; reject a null type explicitly rather than crashing.
      MS_EXCEPTION_IF_NULL(item.second);
      auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
      if (lhs == nullptr) {
        MS_LOG(EXCEPTION) << "The elements[" << (num - 1) << "] has wrong type, expected a Tuple, but got "
                          << item.second->ToString();
      }
      if (lhs->elements().size() != size) {
        MS_LOG(ERROR) << "The elements[" << (num - 1) << "] has different length, expected " << size << ", but got "
                      << lhs->elements().size();
        return true;
      }
      return false;
    });
  if (is_not_same) {
    MS_LOG(EXCEPTION) << "tuple in Map should have same length";
  }

  constexpr size_t kPrimHoldLen = 1;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(size + kPrimHoldLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeTuple));

  for (size_t i = 0; i < size; i++) {
    MS_LOG(DEBUG) << "FullMakeTuple for the " << i << "th arg of the tuple inputs, reverse_: " << reverse_;
    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(ptrGraph);

    std::vector<AnfNodePtr> inputs2;
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    size_t pos = (reverse_ ? (size - 1 - i) : i);
    // Take the element by its real type (same as ArgsPairList's element) so no temporary
    // pair/Any conversion is built per argument; `pos` is a plain index, captured by value.
    (void)std::transform(arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
                         [&func_graph, pos](const std::pair<AnfNodePtr, TypePtr> &item) {
                           return func_graph->NewCNodeInOrder(
                             {NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(SizeToLong(pos))});
                         });

    auto call_node = func_graph->NewCNodeInOrder(inputs2);
    if (reverse_) {
      // Calls are produced last-to-first; inserting right after the primitive restores element order.
      (void)inputs.insert(inputs.begin() + 1, call_node);
    } else {
      inputs.emplace_back(call_node);
    }
  }
  return func_graph->NewCNodeInOrder(inputs);
}
185
// Expand a Map over Class (record) arguments into MakeRecord(type, call_0, call_1, ...).
//
// For every attribute of `type`, a leaf graph (GenerateLeafFunc) is called with `fn_arg` (when non-null)
// followed by that attribute read from each argument with GetAttr, keyed by the attribute name.
// When `reverse_` is set, the attribute call nodes are created from the last attribute to the first.
// In both modes the record keeps the declared attribute order and the arguments are passed in their
// original order, matching FullMakeList/FullMakeTuple.
// NOTE(review): assumes every argument is a record of the same Class layout as `type`; Make only checks
// that the type ids match.
AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
                              const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(func_graph);

  const auto &attributes = type->GetAttributes();
  const size_t attrSize = attributes.size();
  constexpr size_t kPrimAndTypeLen = 2;
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(attrSize + kPrimAndTypeLen);
  inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
  inputs.push_back(NewValueNode(type));

  for (size_t i = 0; i < attrSize; i++) {
    MS_LOG(DEBUG) << "FullMakeClass for the " << i << "th element of the inputs, reverse_: " << reverse_;
    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
    auto fn = NewValueNode(ptrGraph);

    std::vector<AnfNodePtr> inputs2;
    inputs2.reserve(arg_pairs.size() + kPrimAndTypeLen);
    inputs2.push_back(fn);
    if (fn_arg != nullptr) {
      inputs2.push_back(fn_arg);
    }

    // `reverse_` mirrors which attribute is handled at step i; it must not reorder the arguments.
    const size_t pos = (reverse_ ? (attrSize - 1 - i) : i);
    // GetAttr is keyed by name, so each call reads the same attribute from every argument.
    const auto &attr_name = attributes[pos].first;
    for (const auto &item : arg_pairs) {
      inputs2.push_back(
        func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(attr_name)}));
    }

    auto call_node = func_graph->NewCNodeInOrder(inputs2);
    if (reverse_) {
      // Calls are produced last-to-first; inserting right after (prim, type) restores attribute order.
      constexpr auto kCallNodePosition = 2;
      (void)inputs.insert(inputs.begin() + kCallNodePosition, call_node);
    } else {
      inputs.emplace_back(call_node);
    }
  }
  return func_graph->NewCNodeInOrder(inputs);
}
227
// Dispatch a Map over `arg_pairs` (argument node, argument type) to the matching expansion.
//
// The first argument whose type id is in `nonleaf_` selects the expansion (List, Tuple or Class). All other
// arguments must then share that type id, otherwise an exception listing every input type is raised.
// If no argument has a non-leaf type, the last argument's type id drives the dispatch below.
// Raises TypeError when `arg_pairs` is empty. Raises for unsupported types, null argument types, or a
// null `func_graph`.
AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
  if (arg_pairs.empty()) {
    MS_EXCEPTION(TypeError) << "map() must have at least two arguments";
  }
  // Used below for error reporting and by every FullMake* expansion.
  MS_EXCEPTION_IF_NULL(func_graph);

  bool found = false;
  TypeId id = kObjectTypeEnd;
  std::pair<AnfNodePtr, TypePtr> pair;
  for (const auto &arg_pair : arg_pairs) {
    // The type is dereferenced immediately; fail with a clear exception instead of crashing.
    MS_EXCEPTION_IF_NULL(arg_pair.second);
    pair = arg_pair;
    MS_LOG(DEBUG) << "Map " << pair.second->ToString();
    id = arg_pair.second->type_id();
    if (nonleaf_.count(id) != 0) {
      found = true;
      break;
    }
  }

  if (found) {
    // In a nonleaf situation, all arguments must have the same generic.
    bool is_not_same =
      std::any_of(arg_pairs.begin(), arg_pairs.end(), [&pair](const std::pair<AnfNodePtr, TypePtr> &item) {
        if (item.first != pair.first) {
          return item.second->type_id() != pair.second->type_id();
        }
        return false;
      });
    if (is_not_same) {
      std::ostringstream oss;
      oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
      int64_t idx = 0;
      for (const auto &item : arg_pairs) {
        oss << ++idx << ": " << item.second->ToString() << "\n";
      }
      MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n"
                        << oss.str() << pair.second->ToString() << "\n";
    }
  }

  switch (id) {
    case kObjectTypeList: {
      auto type = std::static_pointer_cast<List>(pair.second);
      return FullMakeList(type, func_graph, fn_arg, arg_pairs);
    }
    case kObjectTypeTuple: {
      auto type = std::static_pointer_cast<Tuple>(pair.second);
      return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
    }
    case kObjectTypeClass: {
      auto type = std::static_pointer_cast<Class>(pair.second);
      return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
    }
    default:
      MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class, but got " << pair.second->ToString();
  }
}
285
GenerateFromTypes(const TypePtrList & args_spec_list)286 FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
287 FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
288 ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
289 ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
290 ptrGraph->debug_info()->set_name("map");
291
292 AnfNodePtr ptrFnArg = nullptr;
293 std::size_t i = 0;
294 if (fn_leaf_ == nullptr) {
295 ptrFnArg = ptrGraph->add_parameter();
296 i = 1;
297 }
298 ArgsPairList arg_pairs;
299 std::size_t size = args_spec_list.size();
300 for (; i < size; ++i) {
301 MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString();
302 arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
303 }
304
305 ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs));
306 return ptrGraph;
307 }
308
NormalizeArgs(const AbstractBasePtrList & args_spec_list) const309 abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
310 if (fn_leaf_ == nullptr) {
311 if (args_spec_list.empty()) {
312 MS_LOG(EXCEPTION) << "The args spec list should not be empty.";
313 }
314 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
315 // Assert that map's function param does not contain free variables
316 if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
317 auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
318 auto func_graph = graph_func->func_graph();
319 if (func_graph->parent() != nullptr) {
320 MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet.";
321 }
322 }
323 }
324
325 AbstractBasePtrList broadened;
326 (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
327 [](const AbstractBasePtr &arg) -> AbstractBasePtr {
328 MS_EXCEPTION_IF_NULL(arg);
329 return arg->Broaden();
330 });
331 return broadened;
332 }
333
__anond2cf421d0702(const py::module *m) 334 REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
335 (void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
336 .def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"),
337 py::arg("ops"))
338 .def(py::init<bool>(), py::arg("reverse"));
339 }));
340 } // namespace prim
341 } // namespace mindspore
342