• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <cstring>
#include <initializer_list>
#include <memory>
#include <string>
#include <vector>

#include "abstract/abstract_function.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "common/common_test.h"
#include "frontend/operator/composite/composite.h"
#include "frontend/operator/ops.h"
#include "ir/anf.h"
#include "ir/value.h"
#include "pipeline/jit/ps/debug/trace.h"
#include "pipeline/jit/ps/static_analysis/prim.h"
29 
30 namespace mindspore {
31 using Shape = abstract::Shape;
32 
33 using AbstractScalar = abstract::AbstractScalar;
34 using AbstractScalarPtr = abstract::AbstractScalarPtr;
35 
36 using AbstractSlice = abstract::AbstractSlice;
37 using AbstractSlicePtr = abstract::AbstractSlicePtr;
38 
39 using AbstractTuple = abstract::AbstractTuple;
40 using AbstractTuplePtr = abstract::AbstractTuplePtr;
41 using AbstractList = abstract::AbstractList;
42 using AbstractListPtr = abstract::AbstractListPtr;
43 
44 using AbstractTensor = abstract::AbstractTensor;
45 using AbstractTensorPtr = abstract::AbstractTensorPtr;
46 
47 using AbstractNone = abstract::AbstractNone;
48 using AbstractAttribute = abstract::AbstractElementPair;
49 using AnalysisEngine = abstract::AnalysisEngine;
50 using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
51 
52 class TestComposite : public UT::Common {
53  public:
54   virtual void SetUp();
55   virtual void TearDown();
56 
57   AnalysisEnginePtr engine_;
58 };
59 
SetUp()60 void TestComposite::SetUp() {
61   // init resource
62   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager();
63   engine_ = std::make_shared<AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), graph_manager);
64 }
65 
TearDown()66 void TestComposite::TearDown() {
67   // destroy resource
68 }
69 
70 class UTCompositeUtils {
71  public:
ArrayInt32Of(std::initializer_list<int64_t> shp)72   static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int64_t> shp) {
73     auto ele = std::make_shared<AbstractScalar>(kValueAny, kInt64);
74     return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
75   }
MakeFuncGraph(const MetaFuncGraphPtr & metaFuncGraphPtr,size_t nparam)76   static FuncGraphPtr MakeFuncGraph(const MetaFuncGraphPtr &metaFuncGraphPtr, size_t nparam) {
77     FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
78     std::vector<AnfNodePtr> inputs;
79     inputs.push_back(NewValueNode(metaFuncGraphPtr));
80     for (size_t i = 0; i < nparam; i++) {
81       inputs.push_back(func_graph->add_parameter());
82     }
83     CNodePtr cnode_prim = func_graph->NewCNode(inputs);
84     inputs.clear();
85     inputs.push_back(NewValueNode(prim::kPrimReturn));
86     inputs.push_back(cnode_prim);
87     CNodePtr cnode_return = func_graph->NewCNode(inputs);
88     func_graph->set_return(cnode_return);
89     return func_graph;
90   }
91 };
92 
/// Feature: Test tuple slice
/// Description: Call the tuple slice meta graph with three inputs (a tuple and two scalars) instead of two
/// Expectation: Throw an error reporting the wrong number of inputs
TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("TupleSlice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);

  constexpr size_t kTupleSize = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList eles(kTupleSize, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  AbstractBasePtrList args_spec_list = {tuple_tensor, start_index, stop_index};

  try {
    // Reset the trace stack before running, as the other negative tests in this file do.
    trace::ClearTraceStack();
    engine_->Run(tupleSliceGraphPtr, args_spec_list);
    FAIL() << "Expected exception: wrong number of inputs for 'TupleSlice'";
  } catch (const std::runtime_error &err) {
    const std::string message = err.what();
    ASSERT_NE(message.find("For 'TupleSlice', the number of input should be 2, but got 3"), std::string::npos)
      << "Unexpected error message: " << message;
  } catch (...) {
    FAIL() << "Expected std::runtime_error for the wrong number of inputs, but another exception type was thrown";
  }
}
119 
/// Feature: Test tuple slice
/// Description: The second input is a scalar instead of a slice
/// Expectation: Throw type error
TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  constexpr size_t kTupleSize = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList eles(kTupleSize, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  AbstractBasePtrList args_spec_list = {tuple_tensor, start_index};

  try {
    trace::ClearTraceStack();
    engine_->Run(tupleSliceGraphPtr, args_spec_list);
    FAIL() << "Expected exception: Args type is wrong";
  } catch (const pybind11::type_error &) {
    // Caught before std::runtime_error on purpose: pybind11::type_error is itself a std::runtime_error.
    SUCCEED();
  } catch (const std::runtime_error &err) {
    const std::string message = err.what();
    ASSERT_NE(message.find("TypeError"), std::string::npos)
      << "Expected exception: Args type is wrong, message: " << message;
  } catch (...) {
    FAIL() << "Expected exception: Args type is wrong";
  }
}
151 
/// Feature: Test tuple slice
/// Description: Slice a 6-element tuple with slice(1, 6, 2)
/// Expectation: The result is a tuple with 3 elements (indices 1, 3, 5)
TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  // Held for the whole test; presumably keeps an embedded Python interpreter alive during inference.
  std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  constexpr size_t kTupleSize = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList eles(kTupleSize, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
  auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
  auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr ret =
    dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(ret, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expect = 3;
  ASSERT_EQ(ret->size(), expect);
}
180 
/// Feature: Test tuple slice
/// Description: Slice a 6-element tuple with slice(1, 5, None)
/// Expectation: The result is a tuple with 4 elements (indices 1 to 4)
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  constexpr size_t kTupleSize = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList eles(kTupleSize, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
  auto step = std::make_shared<AbstractNone>();
  auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr ret =
    dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(ret, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expect = 4;
  ASSERT_EQ(ret->size(), expect);
}
208 
/// Feature: Test tuple slice
/// Description: Slice a 6-element tuple with slice(None, None, -1), i.e. a full reversal
/// Expectation: The result is a tuple with all 6 elements
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  constexpr size_t kTupleSize = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList eles(kTupleSize, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  auto start_index = std::make_shared<AbstractNone>();
  auto stop_index = std::make_shared<AbstractNone>();
  auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr ret =
    dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(ret, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expect = 6;
  ASSERT_EQ(ret->size(), expect);
}
236 
/// Feature: Test tuple slice
/// Description: Slice a 6-element tuple with slice(-2, None, -1). Despite the test name, the step
///              is negative; it is the start index that is -2.
/// Expectation: The result is a tuple with 5 elements (indices 4, 3, 2, 1, 0)
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  MetaFuncGraphPtr tupleSlicePtr =
    std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
  FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);

  constexpr size_t kTupleSize = 6;
  auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList eles(kTupleSize, tensor);
  auto tuple_tensor = std::make_shared<AbstractTuple>(eles);
  auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
  auto stop_index = std::make_shared<AbstractNone>();
  auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
  auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr ret =
    dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).eval_result->abstract());
  ASSERT_NE(ret, nullptr) << "Cast ret to abstract tuple failed.";
  const size_t expect = 5;
  ASSERT_EQ(ret->size(), expect);
}
264 
265 /// Feature: Test list slice
266 /// Description: The second input is a scalar
267 /// Expectation: Throw type error
TEST_F(TestComposite,test_ListSlice_arg_one_number)268 TEST_F(TestComposite, test_ListSlice_arg_one_number) {
269   MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
270   FuncGraphPtr list_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 3);
271 
272   AbstractBasePtrList eles;
273   auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
274   size_t list_size = 6;
275   for (size_t i = 0; i < list_size; i++) {
276     eles.push_back(tensor);
277   }
278   auto list_tensor = std::make_shared<AbstractList>(eles);
279   auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
280   AbstractBasePtrList args_spec_list = {list_tensor, start_index};
281 
282   try {
283     trace::ClearTraceStack();
284     engine_->Run(list_graph, args_spec_list);
285     FAIL() << "Excepted exception: Args type is wrong";
286   } catch (pybind11::type_error const &err) {
287     ASSERT_TRUE(true);
288   } catch (std::runtime_error const &err) {
289     if (std::strstr(err.what(), "TypeError") != nullptr) {
290       ASSERT_TRUE(true);
291     } else {
292       FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
293     }
294   } catch (...) {
295     FAIL() << "Excepted exception: Args type is wrong";
296   }
297 }
298 
299 /// Feature: Test list slice
300 /// Description: Test List slice
301 /// Expectation: No Expectation
TEST_F(TestComposite,test_ListSlice_arg_slice)302 TEST_F(TestComposite, test_ListSlice_arg_slice) {
303   std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
304   MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
305   FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
306 
307   AbstractBasePtrList eles;
308   auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
309   size_t list_size = 6;
310   for (size_t i = 0; i < list_size; i++) {
311     eles.push_back(tensor);
312   }
313   auto list_tensor = std::make_shared<AbstractList>(eles);
314   auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
315   auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
316   auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
317   auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
318   AbstractBasePtrList args_spec_list = {list_tensor, slice};
319 
320   AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
321   if (ret == nullptr) {
322     FAIL() << "Cast ret to abstract list failed.";
323   }
324   size_t real = ret->size();
325   size_t expect = 3;
326   ASSERT_EQ(real, expect);
327 }
328 
329 /// Feature: Test list slice
330 /// Description: Test List slice the step is none
331 /// Expectation: No Expectation
TEST_F(TestComposite,test_ListSlice_arg_slice_step_none)332 TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
333   MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
334   FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
335 
336   AbstractBasePtrList eles;
337   auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
338   size_t list_size = 6;
339   for (size_t i = 0; i < list_size; i++) {
340     eles.push_back(tensor);
341   }
342   auto list_tensor = std::make_shared<AbstractList>(eles);
343   auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
344   auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
345   auto step = std::make_shared<AbstractNone>();
346   auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
347   AbstractBasePtrList args_spec_list = {list_tensor, slice};
348 
349   AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
350   if (ret == nullptr) {
351     FAIL() << "Cast ret to abstract list failed.";
352   }
353   size_t real = ret->size();
354   size_t expect = 4;
355   ASSERT_EQ(real, expect);
356 }
357 
358 /// Feature: Test list slice
359 /// Description: Test List slice the step is negative
360 /// Expectation: No Expectation
TEST_F(TestComposite,test_ListSlice_arg_slice_step_negative)361 TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
362   MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
363   FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
364 
365   AbstractBasePtrList eles;
366   auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
367   size_t list_size = 6;
368   for (size_t i = 0; i < list_size; i++) {
369     eles.push_back(tensor);
370   }
371   auto list_tensor = std::make_shared<AbstractList>(eles);
372   auto start_index = std::make_shared<AbstractNone>();
373   auto stop_index = std::make_shared<AbstractNone>();
374   auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
375   auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
376   AbstractBasePtrList args_spec_list = {list_tensor, slice};
377 
378   AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
379   if (ret == nullptr) {
380     FAIL() << "Cast ret to abstract list failed.";
381   }
382   size_t real = ret->size();
383   size_t expect = 6;
384   ASSERT_EQ(real, expect);
385 }
386 
387 /// Feature: Test list slice
388 /// Description: Test List slice the step is positive
389 /// Expectation: No Expectation
TEST_F(TestComposite,test_ListSlice_arg_slice_step_positive)390 TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
391   MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
392   FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
393 
394   AbstractBasePtrList eles;
395   auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
396   size_t list_size = 6;
397   for (size_t i = 0; i < list_size; i++) {
398     eles.push_back(tensor);
399   }
400   auto list_tensor = std::make_shared<AbstractList>(eles);
401   auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
402   auto stop_index = std::make_shared<AbstractNone>();
403   auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
404   auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
405   AbstractBasePtrList args_spec_list = {list_tensor, slice};
406 
407   AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
408   if (ret == nullptr) {
409     FAIL() << "Cast ret to abstract list failed.";
410   }
411   size_t real = ret->size();
412   size_t expect = 5;
413   ASSERT_EQ(real, expect);
414 }
415 
/// Unpack-calls MakeTuple with a tuple of 6 tensors and a dict of 3 tensors; the inferred
/// result is expected to be a tuple holding one element per unpacked entry (6 + 3).
TEST_F(TestComposite, test_UnpackCall_3args) {
  MetaFuncGraphPtr unpack_call = std::make_shared<prim::UnpackCall>("UnpackCall");
  FuncGraphPtr unpack_call_graph = UTCompositeUtils::MakeFuncGraph(unpack_call, 3);

  // Callee: the MakeTuple primitive.
  auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);

  // Positional part: six copies of the same abstract tensor.
  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList tuple_items(6, tensor);
  AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(tuple_items);

  // Keyword part: {"x": tensor, "y": tensor, "z": tensor}.
  auto make_entry = [](const std::string &key) {
    return AbstractAttribute{std::make_shared<AbstractScalar>(key), UTCompositeUtils::ArrayInt32Of({2, 3, 4})};
  };
  std::vector<AbstractAttribute> tensor_map{make_entry("x"), make_entry("y"), make_entry("z")};
  abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);

  AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  auto result = dyn_cast<AbstractTuple>(engine_->Run(unpack_call_graph, args_spec_list).eval_result->abstract());
  if (!result) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }
  const size_t expected_size = tuple_items.size() + tensor_map.size();
  ASSERT_EQ(result->size(), expected_size);
}
446 
/// Unpack-calls MakeTuple with (dict, tuple, dict, tuple), where each tuple holds 6 tensors and each
/// dict 3 tensors; the inferred result is expected to hold 2 * (6 + 3) elements.
TEST_F(TestComposite, test_UnpackCall_5args) {
  MetaFuncGraphPtr unpack_call = std::make_shared<prim::UnpackCall>("UnpackCall");
  FuncGraphPtr unpack_call_graph = UTCompositeUtils::MakeFuncGraph(unpack_call, 5);

  // Callee: the MakeTuple primitive.
  auto fn_arg = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimMakeTuple);

  // Positional part: six copies of the same abstract tensor.
  AbstractTensorPtr tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
  AbstractBasePtrList tuple_items(6, tensor);
  AbstractTuplePtr tensor_tuple = std::make_shared<AbstractTuple>(tuple_items);

  // Keyword part: {"x": tensor, "y": tensor, "z": tensor}.
  auto make_entry = [](const std::string &key) {
    return AbstractAttribute{std::make_shared<AbstractScalar>(key), UTCompositeUtils::ArrayInt32Of({2, 3, 4})};
  };
  std::vector<AbstractAttribute> tensor_map{make_entry("x"), make_entry("y"), make_entry("z")};
  abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);

  // The same dict and tuple are each passed twice, interleaved.
  AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  auto result = dyn_cast<AbstractTuple>(engine_->Run(unpack_call_graph, args_spec_list).eval_result->abstract());
  if (!result) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }
  const size_t expected_size = 2 * (tuple_items.size() + tensor_map.size());
  ASSERT_EQ(result->size(), expected_size);
}
477 
/// Zips a single tuple of three tensors and checks the inferred result is a tuple with 3 entries.
TEST_F(TestComposite, test_ZipOperation) {
  MetaFuncGraphPtr zip_op = std::make_shared<prim::ZipOperation>("zip_op");
  FuncGraphPtr zip_op_graph = UTCompositeUtils::MakeFuncGraph(zip_op, 1);

  constexpr size_t kTupleSize = 3;
  const AbstractBasePtrList elements(kTupleSize, UTCompositeUtils::ArrayInt32Of({2, 3, 4}));
  AbstractBasePtrList args_spec_list{std::make_shared<AbstractTuple>(elements)};

  auto result = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).eval_result->abstract());
  if (!result) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }
  ASSERT_EQ(result->size(), kTupleSize);
}
499 
500 /// Feature: Shard operation.
501 /// Description: Test the func_graph generation of Shard op and the inference of the Shard caller.
502 /// Expectation: Generate and the infer successfully.
TEST_F(TestComposite,test_shard)503 TEST_F(TestComposite, test_shard) {
504   // Make origin func_graph which includes a relu node.
505   FuncGraphPtr origin_func_graph = std::make_shared<FuncGraph>();
506   std::vector<AnfNodePtr> inputs;
507   inputs.push_back(NewValueNode(prim::kPrimReLU));
508   inputs.push_back(origin_func_graph->add_parameter());
509   CNodePtr relu = origin_func_graph->NewCNode(inputs);
510   inputs.clear();
511   inputs.push_back(NewValueNode(prim::kPrimReturn));
512   inputs.push_back(relu);
513   CNodePtr origin_return = origin_func_graph->NewCNode(inputs);
514   origin_func_graph->set_return(origin_return);
515 
516   // Make the func_graph which includes a Shard meta_func_graph.
517   FuncGraphPtr shard_func_graph = std::make_shared<FuncGraph>();
518   MetaFuncGraphPtr shard_op = std::make_shared<prim::Shard>("shard_op");
519   inputs.clear();
520   inputs.push_back(NewValueNode(shard_op));
521   inputs.push_back(NewValueNode(origin_func_graph));
522   for (size_t i = 0; i < 4; ++i) {
523     inputs.push_back(NewValueNode(MakeValue(0)));
524   }
525   CNodePtr shard = shard_func_graph->NewCNode(inputs);
526   inputs.clear();
527   inputs.push_back(shard);
528   inputs.push_back(shard_func_graph->add_parameter());
529   CNodePtr shard_user = shard_func_graph->NewCNode(inputs);
530   inputs.clear();
531   inputs.push_back(NewValueNode(prim::kPrimReturn));
532   inputs.push_back(shard_user);
533   CNodePtr shard_return = shard_func_graph->NewCNode(inputs);
534   shard_func_graph->set_return(shard_return);
535 
536   auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
537   AbstractBasePtrList args_spec_list = {tensor};
538 
539   auto ret = engine_->Run(shard_func_graph, args_spec_list).eval_result->abstract();
540   ASSERT_NE(ret, nullptr);
541   ASSERT_TRUE(ret->isa<abstract::AbstractTensor>());
542   auto build_shape = ret->BuildShape();
543   EXPECT_TRUE(build_shape->isa<abstract::Shape>());
544   auto shape = build_shape->cast<abstract::ShapePtr>();
545   ASSERT_EQ(shape->shape(), std::vector<int64_t>({2, 3, 4}));
546 }
547 }  // namespace mindspore
548