1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <memory>
17
18 #include "abstract/abstract_function.h"
19 #include "mindspore/core/ops/sequence_ops.h"
20 #include "mindspore/core/ops/nn_optimizer_ops.h"
21 #include "mindspore/core/ops/framework_ops.h"
22 #include "common/common_test.h"
23 #include "frontend/operator/composite/composite.h"
24 #include "frontend/operator/ops.h"
25 #include "ir/anf.h"
26 #include "ir/value.h"
27 #include "pipeline/jit/ps/debug/trace.h"
28 #include "pipeline/jit/ps/static_analysis/prim.h"
29
30 namespace mindspore {
31 using Shape = abstract::Shape;
32
33 using AbstractScalar = abstract::AbstractScalar;
34 using AbstractScalarPtr = abstract::AbstractScalarPtr;
35
36 using AbstractSlice = abstract::AbstractSlice;
37 using AbstractSlicePtr = abstract::AbstractSlicePtr;
38
39 using AbstractTuple = abstract::AbstractTuple;
40 using AbstractTuplePtr = abstract::AbstractTuplePtr;
41 using AbstractList = abstract::AbstractList;
42 using AbstractListPtr = abstract::AbstractListPtr;
43
44 using AbstractTensor = abstract::AbstractTensor;
45 using AbstractTensorPtr = abstract::AbstractTensorPtr;
46
47 using AbstractNone = abstract::AbstractNone;
48 using AbstractAttribute = abstract::AbstractElementPair;
49 using AnalysisEngine = abstract::AnalysisEngine;
50 using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
51
52 class TestComposite : public UT::Common {
53 public:
54 virtual void SetUp();
55 virtual void TearDown();
56
57 AnalysisEnginePtr engine_;
58 };
59
SetUp()60 void TestComposite::SetUp() {
61 // init resource
62 std::shared_ptr<FuncGraphManager> graph_manager = MakeManager();
63 engine_ = std::make_shared<AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), graph_manager);
64 }
65
TearDown()66 void TestComposite::TearDown() {
67 // destroy resource
68 }
69
70 class UTCompositeUtils {
71 public:
ArrayInt32Of(std::initializer_list<int64_t> shp)72 static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int64_t> shp) {
73 auto ele = std::make_shared<AbstractScalar>(kValueAny, kInt64);
74 return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
75 }
MakeFuncGraph(const MetaFuncGraphPtr & metaFuncGraphPtr,size_t nparam)76 static FuncGraphPtr MakeFuncGraph(const MetaFuncGraphPtr &metaFuncGraphPtr, size_t nparam) {
77 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
78 std::vector<AnfNodePtr> inputs;
79 inputs.push_back(NewValueNode(metaFuncGraphPtr));
80 for (size_t i = 0; i < nparam; i++) {
81 inputs.push_back(func_graph->add_parameter());
82 }
83 CNodePtr cnode_prim = func_graph->NewCNode(inputs);
84 inputs.clear();
85 inputs.push_back(NewValueNode(prim::kPrimReturn));
86 inputs.push_back(cnode_prim);
87 CNodePtr cnode_return = func_graph->NewCNode(inputs);
88 func_graph->set_return(cnode_return);
89 return func_graph;
90 }
91 };
92
TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
  // TupleSlice(tuple, 1, 5): three inputs where the meta graph accepts two, so inference
  // is expected to throw a runtime_error reporting the input count.
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("TupleSlice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);

  AbstractBasePtrList eles;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  size_t tuple_size = 6;
  for (size_t i = 0; i < tuple_size; i++) {
    eles.push_back(tensor);
  }
  auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  AbstractBasePtrList args_spec_list = {tuple_tensor, start_index, stop_index};

  try {
    engine_->Run(tupleSliceGraphPtr, args_spec_list);
    FAIL() << "Excepted exception :Args type is wrong";
  } catch (std::runtime_error const &err) {
    const std::string message = err.what();
    // ASSERT_NE plus the streamed message shows the actual error text when the check fails.
    ASSERT_NE(message.find("For 'TupleSlice', the number of input should be 2, but got 3"), std::string::npos)
      << "Unexpected error message: " << message;
  } catch (...) {
    FAIL() << "Excepted exception :Args type is wrong";
  }
}
119
TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
  // tuple_slice(tuple, 1): a bare scalar where a slice is expected must raise a TypeError,
  // accepted either as pybind11::type_error or as a runtime_error whose text mentions "TypeError".
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  const size_t element_count = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList elements(element_count, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(elements);
  auto index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  AbstractBasePtrList args_spec_list = {tuple_tensor, index};

  try {
    trace::ClearTraceStack();
    engine_->Run(tupleSliceGraphPtr, args_spec_list);
    FAIL() << "Excepted exception: Args type is wrong";
  } catch (pybind11::type_error const &) {
    // Must stay ahead of the runtime_error handler, which would otherwise also match it.
    SUCCEED();
  } catch (std::runtime_error const &err) {
    const bool mentions_type_error = std::strstr(err.what(), "TypeError") != nullptr;
    if (!mentions_type_error) {
      FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
    }
  } catch (...) {
    FAIL() << "Excepted exception: Args type is wrong";
  }
}
151
TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  // Keep the scoped Python interpreter from set_python_scoped() alive for the whole test.
  std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  // tuple[1:6:2] over 6 elements keeps indices 1, 3, 5.
  const size_t element_count = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList elements(element_count, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(elements);
  auto start = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  auto stop = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
  auto stride = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
  auto slice = std::make_shared<AbstractSlice>(start, stop, stride);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr sliced =
    dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(sliced, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expected_size = 3;
  ASSERT_EQ(sliced->size(), expected_size);
}
180
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  // tuple[1:5] (step None) over 6 elements is expected to yield 4 elements.
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  const size_t element_count = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList elements(element_count, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(elements);
  auto start = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  auto stop = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  auto stride = std::make_shared<AbstractNone>();
  auto slice = std::make_shared<AbstractSlice>(start, stop, stride);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr sliced =
    dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(sliced, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expected_size = 4;
  ASSERT_EQ(sliced->size(), expected_size);
}
208
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  // tuple[::-1] over 6 elements (a full reversal) is expected to keep all 6 elements.
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  const size_t element_count = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList elements(element_count, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(elements);
  auto start = std::make_shared<AbstractNone>();
  auto stop = std::make_shared<AbstractNone>();
  auto stride = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  auto slice = std::make_shared<AbstractSlice>(start, stop, stride);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr sliced =
    dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(sliced, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expected_size = 6;
  ASSERT_EQ(sliced->size(), expected_size);
}
236
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  // tuple[-2::-1] over 6 elements walks indices 4, 3, 2, 1, 0, so 5 elements are expected.
  // NOTE(review): despite the test name, the step used here is -1 (negative).
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  const size_t element_count = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList elements(element_count, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(elements);
  auto start = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
  auto stop = std::make_shared<AbstractNone>();
  auto stride = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  auto slice = std::make_shared<AbstractSlice>(start, stop, stride);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr sliced =
    dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(sliced, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expected_size = 5;
  ASSERT_EQ(sliced->size(), expected_size);
}
264
265 /// Feature: Test list slice
266 /// Description: The second input is a scalar
267 /// Expectation: Throw type error
TEST_F(TestComposite,test_ListSlice_arg_one_number)268 TEST_F(TestComposite, test_ListSlice_arg_one_number) {
269 MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
270 FuncGraphPtr list_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 3);
271
272 AbstractBasePtrList eles;
273 auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
274 size_t list_size = 6;
275 for (size_t i = 0; i < list_size; i++) {
276 eles.push_back(tensor);
277 }
278 auto list_tensor = std::make_shared<AbstractList>(eles);
279 auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
280 AbstractBasePtrList args_spec_list = {list_tensor, start_index};
281
282 try {
283 trace::ClearTraceStack();
284 engine_->Run(list_graph, args_spec_list);
285 FAIL() << "Excepted exception: Args type is wrong";
286 } catch (pybind11::type_error const &err) {
287 ASSERT_TRUE(true);
288 } catch (std::runtime_error const &err) {
289 if (std::strstr(err.what(), "TypeError") != nullptr) {
290 ASSERT_TRUE(true);
291 } else {
292 FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
293 }
294 } catch (...) {
295 FAIL() << "Excepted exception: Args type is wrong";
296 }
297 }
298
299 /// Feature: Test list slice
300 /// Description: Test List slice
301 /// Expectation: No Expectation
TEST_F(TestComposite,test_ListSlice_arg_slice)302 TEST_F(TestComposite, test_ListSlice_arg_slice) {
303 std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
304 MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
305 FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
306
307 AbstractBasePtrList eles;
308 auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
309 size_t list_size = 6;
310 for (size_t i = 0; i < list_size; i++) {
311 eles.push_back(tensor);
312 }
313 auto list_tensor = std::make_shared<AbstractList>(eles);
314 auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
315 auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
316 auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
317 auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
318 AbstractBasePtrList args_spec_list = {list_tensor, slice};
319
320 AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
321 if (ret == nullptr) {
322 FAIL() << "Cast ret to abstract list failed.";
323 }
324 size_t real = ret->size();
325 size_t expect = 3;
326 ASSERT_EQ(real, expect);
327 }
328
329 /// Feature: Test list slice
330 /// Description: Test List slice the step is none
331 /// Expectation: No Expectation
TEST_F(TestComposite,test_ListSlice_arg_slice_step_none)332 TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
333 MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
334 FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
335
336 AbstractBasePtrList eles;
337 auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
338 size_t list_size = 6;
339 for (size_t i = 0; i < list_size; i++) {
340 eles.push_back(tensor);
341 }
342 auto list_tensor = std::make_shared<AbstractList>(eles);
343 auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
344 auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
345 auto step = std::make_shared<AbstractNone>();
346 auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
347 AbstractBasePtrList args_spec_list = {list_tensor, slice};
348
349 AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
350 if (ret == nullptr) {
351 FAIL() << "Cast ret to abstract list failed.";
352 }
353 size_t real = ret->size();
354 size_t expect = 4;
355 ASSERT_EQ(real, expect);
356 }
357
358 /// Feature: Test list slice
359 /// Description: Test List slice the step is negative
360 /// Expectation: No Expectation
TEST_F(TestComposite,test_ListSlice_arg_slice_step_negative)361 TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
362 MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
363 FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
364
365 AbstractBasePtrList eles;
366 auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
367 size_t list_size = 6;
368 for (size_t i = 0; i < list_size; i++) {
369 eles.push_back(tensor);
370 }
371 auto list_tensor = std::make_shared<AbstractList>(eles);
372 auto start_index = std::make_shared<AbstractNone>();
373 auto stop_index = std::make_shared<AbstractNone>();
374 auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
375 auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
376 AbstractBasePtrList args_spec_list = {list_tensor, slice};
377
378 AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
379 if (ret == nullptr) {
380 FAIL() << "Cast ret to abstract list failed.";
381 }
382 size_t real = ret->size();
383 size_t expect = 6;
384 ASSERT_EQ(real, expect);
385 }
386
387 /// Feature: Test list slice
388 /// Description: Test List slice the step is positive
389 /// Expectation: No Expectation
TEST_F(TestComposite,test_ListSlice_arg_slice_step_positive)390 TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
391 MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
392 FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
393
394 AbstractBasePtrList eles;
395 auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
396 size_t list_size = 6;
397 for (size_t i = 0; i < list_size; i++) {
398 eles.push_back(tensor);
399 }
400 auto list_tensor = std::make_shared<AbstractList>(eles);
401 auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
402 auto stop_index = std::make_shared<AbstractNone>();
403 auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
404 auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
405 AbstractBasePtrList args_spec_list = {list_tensor, slice};
406
407 AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
408 if (ret == nullptr) {
409 FAIL() << "Cast ret to abstract list failed.";
410 }
411 size_t real = ret->size();
412 size_t expect = 5;
413 ASSERT_EQ(real, expect);
414 }
415
TEST_F(TestComposite, test_UnpackCall_3args) {
  // UnpackCall(MakeTuple, tuple_of_6, dict_of_3): the test expects a 9-element tuple (6 + 3).
  MetaFuncGraphPtr unpackCallPtr = std::make_shared<prim::UnpackCall>("UnpackCall");
  FuncGraphPtr unpackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unpackCallPtr, 3);

  // Callee closure wrapping the MakeTuple primitive.
  auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);

  // Positional pack: six references to the same tensor abstract.
  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList positional(6, tensor);
  AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(positional);

  // Keyword pack: keys "x", "y", "z", each mapped to its own tensor abstract.
  std::vector<AbstractAttribute> tensor_map;
  for (const char *key_name : {"x", "y", "z"}) {
    AbstractTensorPtr value = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
    auto key = std::make_shared<AbstractScalar>(std::string(key_name));
    tensor_map.push_back({key, value});
  }
  abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);

  AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  AbstractTuplePtr unpacked =
    dyn_cast<AbstractTuple>(engine_->Run(unpackCallGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(unpacked, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expected_size = 9;
  ASSERT_EQ(unpacked->size(), expected_size);
}
446
TEST_F(TestComposite, test_UnpackCall_5args) {
  // UnpackCall(MakeTuple, dict, tuple, dict, tuple): the test expects 18 elements, 2 x (3 + 6).
  MetaFuncGraphPtr unpackCallPtr = std::make_shared<prim::UnpackCall>("UnpackCall");
  FuncGraphPtr unpackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unpackCallPtr, 5);

  // Callee closure wrapping the MakeTuple primitive.
  auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);

  // Positional pack: six references to the same tensor abstract.
  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList positional(6, tensor);
  AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(positional);

  // Keyword pack: keys "x", "y", "z", each mapped to its own tensor abstract.
  std::vector<AbstractAttribute> tensor_map;
  for (const char *key_name : {"x", "y", "z"}) {
    AbstractTensorPtr value = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
    auto key = std::make_shared<AbstractScalar>(std::string(key_name));
    tensor_map.push_back({key, value});
  }
  abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);

  // The same dict and tuple abstracts are each passed twice, alternating.
  AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  AbstractTuplePtr unpacked =
    dyn_cast<AbstractTuple>(engine_->Run(unpackCallGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(unpacked, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expected_size = 18;
  ASSERT_EQ(unpacked->size(), expected_size);
}
477
TEST_F(TestComposite, test_ZipOperation) {
  // zip over a single 3-element tuple of tensors; the test expects a 3-element tuple back.
  MetaFuncGraphPtr zip_op = std::make_shared<prim::ZipOperation>("zip_op");
  FuncGraphPtr zip_op_graph = UTCompositeUtils::MakeFuncGraph(zip_op, 1);

  const size_t element_count = 3;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList elements(element_count, tensor);
  auto tuple = std::make_shared<AbstractTuple>(elements);
  AbstractBasePtrList args_spec_list = {tuple};

  AbstractTuplePtr zipped = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).eval_result->abstract());
  ASSERT_NE(zipped, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expected_size = 3;
  ASSERT_EQ(zipped->size(), expected_size);
}
499
500 /// Feature: Shard operation.
501 /// Description: Test the func_graph generation of Shard op and the inference of the Shard caller.
502 /// Expectation: Generate and the infer successfully.
TEST_F(TestComposite,test_shard)503 TEST_F(TestComposite, test_shard) {
504 // Make origin func_graph which includes a relu node.
505 FuncGraphPtr origin_func_graph = std::make_shared<FuncGraph>();
506 std::vector<AnfNodePtr> inputs;
507 inputs.push_back(NewValueNode(prim::kPrimReLU));
508 inputs.push_back(origin_func_graph->add_parameter());
509 CNodePtr relu = origin_func_graph->NewCNode(inputs);
510 inputs.clear();
511 inputs.push_back(NewValueNode(prim::kPrimReturn));
512 inputs.push_back(relu);
513 CNodePtr origin_return = origin_func_graph->NewCNode(inputs);
514 origin_func_graph->set_return(origin_return);
515
516 // Make the func_graph which includes a Shard meta_func_graph.
517 FuncGraphPtr shard_func_graph = std::make_shared<FuncGraph>();
518 MetaFuncGraphPtr shard_op = std::make_shared<prim::Shard>("shard_op");
519 inputs.clear();
520 inputs.push_back(NewValueNode(shard_op));
521 inputs.push_back(NewValueNode(origin_func_graph));
522 for (size_t i = 0; i < 4; ++i) {
523 inputs.push_back(NewValueNode(MakeValue(0)));
524 }
525 CNodePtr shard = shard_func_graph->NewCNode(inputs);
526 inputs.clear();
527 inputs.push_back(shard);
528 inputs.push_back(shard_func_graph->add_parameter());
529 CNodePtr shard_user = shard_func_graph->NewCNode(inputs);
530 inputs.clear();
531 inputs.push_back(NewValueNode(prim::kPrimReturn));
532 inputs.push_back(shard_user);
533 CNodePtr shard_return = shard_func_graph->NewCNode(inputs);
534 shard_func_graph->set_return(shard_return);
535
536 auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
537 AbstractBasePtrList args_spec_list = {tensor};
538
539 auto ret = engine_->Run(shard_func_graph, args_spec_list).eval_result->abstract();
540 ASSERT_NE(ret, nullptr);
541 ASSERT_TRUE(ret->isa<abstract::AbstractTensor>());
542 auto build_shape = ret->BuildShape();
543 EXPECT_TRUE(build_shape->isa<abstract::Shape>());
544 auto shape = build_shape->cast<abstract::ShapePtr>();
545 ASSERT_EQ(shape->shape(), std::vector<int64_t>({2, 3, 4}));
546 }
547 } // namespace mindspore
548