• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <memory>
18 
19 #include "pybind11/pybind11.h"
20 
21 #include "common/common_test.h"
22 #include "common/py_func_graph_fetcher.h"
23 #include "ir/manager.h"
24 #include "pipeline/jit/static_analysis/prim.h"
25 #include "pipeline/static_analysis/helper.h"
26 #include "frontend/operator/ops.h"
27 #include "debug/draw.h"
28 #include "ir/tensor.h"
29 #include "utils/symbolic.h"
30 #include "base/core_ops.h"
31 
32 namespace mindspore {
33 namespace abstract {
34 namespace py = pybind11;
35 namespace python_adapter = mindspore::parse::python_adapter;
36 class UTPrimUtils {
37  public:
38   using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
39   using AbstractTuplePtr = std::shared_ptr<AbstractTuple>;
40 
41   static const std::shared_ptr<Float> kF32;
42   static const std::shared_ptr<Float> kF64;
43   static const std::shared_ptr<Int> kI16;
44   static const std::shared_ptr<Int> kI64;
45   static const std::shared_ptr<UInt> kU64;
46 
TypeToAbstract(TypePtr t)47   static std::shared_ptr<AbstractType> TypeToAbstract(TypePtr t) { return std::make_shared<AbstractType>(t); }
48 
ArrayFloat64Of(std::initializer_list<int64_t> shp)49   static AbstractTensorPtr ArrayFloat64Of(std::initializer_list<int64_t> shp) {
50     auto ele = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
51     return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
52   }
53 
ArrayFloat32Of(std::initializer_list<int64_t> shp)54   static AbstractTensorPtr ArrayFloat32Of(std::initializer_list<int64_t> shp) {
55     auto ele = std::make_shared<AbstractScalar>(kAnyValue, kFloat32);
56     return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
57   }
58 
ArrayInt32Of(std::initializer_list<int64_t> shp)59   static AbstractTensorPtr ArrayInt32Of(std::initializer_list<int64_t> shp) {
60     auto ele = std::make_shared<AbstractScalar>(kAnyValue, kInt64);
61     return std::make_shared<AbstractTensor>(ele, std::make_shared<Shape>(shp));
62   }
63 
ShapeOf(std::initializer_list<int64_t> vals)64   static AbstractTuplePtr ShapeOf(std::initializer_list<int64_t> vals) {
65     AbstractBasePtrList te;
66     for (auto v : vals) {
67       te.push_back(std::make_shared<AbstractScalar>(v));
68     }
69     return std::make_shared<AbstractTuple>(te);
70   }
71 
ListShapeOf(std::initializer_list<int64_t> vals)72   static AbstractListPtr ListShapeOf(std::initializer_list<int64_t> vals) {
73     AbstractBasePtrList te;
74     for (auto v : vals) {
75       te.push_back(std::make_shared<AbstractScalar>(v));
76     }
77     return std::make_shared<AbstractList>(te);
78   }
79 };
80 const std::shared_ptr<Float> UTPrimUtils::kF64 = std::make_shared<Float>(64);
81 const std::shared_ptr<Float> UTPrimUtils::kF32 = std::make_shared<Float>(32);
82 const std::shared_ptr<Int> UTPrimUtils::kI16 = std::make_shared<Int>(16);
83 const std::shared_ptr<Int> UTPrimUtils::kI64 = std::make_shared<Int>(64);
84 const std::shared_ptr<UInt> UTPrimUtils::kU64 = std::make_shared<UInt>(64);
namespace {
/* Disabled for now: this helper is only referenced by UT cases that are temporarily skipped.
AbstractBasePtr ArrayOfTensor(const TypePtr &t, std::initializer_list<int64_t> shp) {
  auto shape = std::vector<int64_t>(shp);
  auto tensor = std::make_shared<tensor::Tensor>(t->type_id(), shape);
  return ToAbstract(tensor);
}
*/
}  // namespace
94 
95 class TestPrim : public UT::Common {
96  public:
TestPrim()97   TestPrim() : getPyFun("gtest_input.pipeline.infer", true) {}
98   void SetUp();
99   void TearDown();
100   AnalysisEnginePtr engine_;
101   UT::PyFuncGraphFetcher getPyFun;
102 };
103 
SetUp()104 void TestPrim::SetUp() { engine_ = SetupAnalysisEngine(); }
105 
TearDown()106 void TestPrim::TearDown() {
107   // destroy resource
108 }
109 
MakeFuncGraph(const PrimitivePtr prim,uint64_t nparam)110 static FuncGraphPtr MakeFuncGraph(const PrimitivePtr prim, uint64_t nparam) {
111   // build the func_graph manually, eg:
112   // MakeFuncGraph(std::make_shared<Primitive>("scalar_add"), 2) means:
113   /* python source code:
114    * @mindspore
115    * def f(x, y):
116    *     return x + y
117    */
118   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
119   std::vector<AnfNodePtr> inputs;
120   inputs.push_back(NewValueNode(prim));
121   for (uint64_t i = 0; i < nparam; i++) {
122     inputs.push_back(func_graph->add_parameter());
123   }
124   CNodePtr cnode_prim = func_graph->NewCNode(inputs);
125   inputs.clear();
126   inputs.push_back(NewValueNode(prim::kPrimReturn));
127   inputs.push_back(cnode_prim);
128   CNodePtr cnode_return = func_graph->NewCNode(inputs);
129   func_graph->set_return(cnode_return);
130   return func_graph;
131 }
132 
// Runs "typeof" on an int64 scalar and expects the inferred value to be the type Int(64).
TEST_F(TestPrim, test_typeof) {
  AbstractBasePtrList args_spec_list;
  int64_t v1 = 1;

  AbstractBasePtr abstract_v1 = FromValue(v1, false);
  args_spec_list.push_back(abstract_v1);

  auto prim_typeof = std::make_shared<Primitive>("typeof");
  FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1);
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(res != nullptr);
  res->dump();
  TypePtr res_value = res->GetValueTrack()->cast<TypePtr>();
  // Fail the test cleanly (instead of dereferencing null) when the value is not a type.
  ASSERT_TRUE(res_value != nullptr);
  res_value->dump();
  ASSERT_TRUE(*res_value == Int(64));
}
148 
// Runs list_map(scalar_add, [1, 1], [2, 2]) and expects the inferred list [3, 3].
TEST_F(TestPrim, test_list_map) {
  // First operand list: two scalars of value 1.
  AbstractBasePtr lhs_first = FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr lhs_second = FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtrList lhs_elements = {lhs_first, lhs_second};
  auto lhs_list = std::make_shared<AbstractList>(lhs_elements);

  // Second operand list: two scalars of value 2.
  AbstractBasePtr rhs_first = FromValue(static_cast<int64_t>(2), false);
  AbstractBasePtr rhs_second = FromValue(static_cast<int64_t>(2), false);
  AbstractBasePtrList rhs_elements = {rhs_first, rhs_second};
  auto rhs_list = std::make_shared<AbstractList>(rhs_elements);

  // The function mapped over both lists.
  auto add_prim = std::make_shared<Primitive>(prim::kScalarAdd);
  AbstractBasePtr add_abstract = ToAbstract(add_prim);

  AbstractBasePtrList args_spec_list = {add_abstract, lhs_list, rhs_list};

  auto list_map_prim = std::make_shared<Primitive>("list_map");
  FuncGraphPtr func_graph = MakeFuncGraph(list_map_prim, 3);
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();

  AbstractBasePtrList expected_elements = {FromValue(static_cast<int64_t>(3), false),
                                           FromValue(static_cast<int64_t>(3), false)};
  auto expected = std::make_shared<AbstractList>(expected_elements);
  res->dump();
  MS_LOG(INFO) << "result res: " << res->ToString();
  MS_LOG(INFO) << "result expected: " << expected->ToString();
  ASSERT_TRUE(*res == *expected);
}
175 
// Runs list_reduce(scalar_add, [1, 1], 1) and expects the inferred type to be Int(64).
TEST_F(TestPrim, test_list_reduce) {
  const int64_t one = 1;

  AbstractBasePtr first_elem = FromValue(one, false);
  AbstractBasePtr second_elem = FromValue(one, false);
  AbstractBasePtrList list_elements = {first_elem, second_elem};
  auto operand_list = std::make_shared<AbstractList>(list_elements);

  auto add_prim = std::make_shared<Primitive>(prim::kScalarAdd);
  AbstractBasePtr add_abstract = ToAbstract(add_prim);

  // Arguments: reducer function, list, initial value (reuses the first element's abstract).
  AbstractBasePtrList args_spec_list = {add_abstract, operand_list, first_elem};

  auto list_reduce_prim = std::make_shared<Primitive>("list_reduce");
  FuncGraphPtr func_graph = MakeFuncGraph(list_reduce_prim, 3);
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  res->dump();

  TypePtr inferred_type = res->GetTypeTrack();
  inferred_type->dump();
  ASSERT_TRUE(*inferred_type == Int(64));
}
198 
// Runs scalar_to_array on an int64 scalar and expects a TensorType with Int(64) elements.
TEST_F(TestPrim, test_scalar_to_array) {
  AbstractBasePtr scalar_arg = FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtrList args_spec_list = {scalar_arg};

  auto to_array_prim = std::make_shared<Primitive>("scalar_to_array");
  FuncGraphPtr func_graph = MakeFuncGraph(to_array_prim, 1);
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  res->dump();

  TypePtr built_type = res->BuildType();
  built_type->dump();
  ASSERT_TRUE(*built_type == TensorType(std::make_shared<Int>(64)));
}
215 
// Runs array_to_scalar on a tensor built from an int64 scalar element and a default Shape,
// and expects the built type to be Int(64).
TEST_F(TestPrim, test_array_to_scalar) {
  AbstractBasePtr element = FromValue(static_cast<int64_t>(1), false);
  auto tensor_arg = std::make_shared<AbstractTensor>(element, std::make_shared<Shape>());
  AbstractBasePtrList args_spec_list = {tensor_arg};

  auto to_scalar_prim = std::make_shared<Primitive>("array_to_scalar");
  FuncGraphPtr func_graph = MakeFuncGraph(to_scalar_prim, 1);
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  res->dump();

  TypePtr built_type = res->BuildType();
  built_type->dump();
  ASSERT_TRUE(*built_type == Int(64));
}
233 
// Applies J to an int64 scalar and expects an AbstractJTagged whose element equals the input.
TEST_F(TestPrim, test_J_1) {
  AbstractBasePtr scalar_arg = FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtrList args_spec_list = {scalar_arg};

  auto j_prim = std::make_shared<Primitive>("J");
  FuncGraphPtr func_graph = MakeFuncGraph(j_prim, 1);
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();

  AbstractJTaggedPtr tagged = dyn_cast<AbstractJTagged>(res);
  ASSERT_TRUE(tagged != nullptr);
  ASSERT_TRUE(*(tagged->element()) == *scalar_arg);
}
248 
// Applies J to a graph and calls the result:
//   def add(x):
//     return x + x
//   def f(x):
//     return J(add)(x)
// Expects a tuple whose first element equals the scalar 2 and whose second element is a function.
TEST_F(TestPrim, test_J_2) {
  // Graph for `add`.
  std::vector<AnfNodePtr> inputs;
  FuncGraphPtr func_graph1 = std::make_shared<FuncGraph>();
  inputs.push_back(NewValueNode(prim::kPrimScalarAdd));
  auto x = func_graph1->add_parameter();
  inputs.push_back(x);
  inputs.push_back(x);
  CNodePtr cnode1 = func_graph1->NewCNode(inputs);
  // NOTE(review): the scalar_add node itself is installed as the return node here (it is not
  // wrapped in a kPrimReturn call as MakeFuncGraph does) - kept as in the original test.
  func_graph1->set_return(cnode1);

  // Graph for `f`: J(add)(x1).
  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  auto x1 = func_graph->add_parameter();
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimJ));
  inputs.push_back(NewValueNode(func_graph1));
  CNodePtr jf = func_graph->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(jf);
  inputs.push_back(x1);
  CNodePtr jf_jx = func_graph->NewCNode(inputs);
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimReturn));
  inputs.push_back(jf_jx);
  CNodePtr cnode_return = func_graph->NewCNode(inputs);
  func_graph->set_return(cnode_return);

  int64_t v1 = 1;
  AbstractBasePtr abstract_v1 = FromValue(v1, false);
  AbstractBasePtrList args_spec_list = {abstract_v1};
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(res != nullptr);
  res->dump();
  AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res);
  ASSERT_TRUE(res_J != nullptr);
  // Guard the element accesses below so a short tuple fails the test instead of reading
  // out of bounds.
  ASSERT_TRUE(res_J->elements().size() >= 2);
  auto res_J_0 = res_J->elements()[0];
  ASSERT_TRUE(res_J_0 != nullptr);
  ASSERT_TRUE(*res_J_0 == *(FromValue(static_cast<int64_t>(2), false)));
  AbstractFunctionPtr res_J_1 = dyn_cast<AbstractFunction>(res_J->elements()[1]);
  ASSERT_TRUE(res_J_1 != nullptr);
}
293 
294 // tail half
// Runs Switch(true, 1, 2) and expects the inferred result to equal the first branch value.
TEST_F(TestPrim, test_switch1) {
  PrimitivePtr switch_prim = std::make_shared<Primitive>("Switch");
  FuncGraphPtr func_graph = MakeFuncGraph(switch_prim, 3);

  AbstractBasePtr condition = FromValue(true, false);
  AbstractBasePtr true_branch = FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr false_branch = FromValue(static_cast<int64_t>(2), false);

  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(condition);
  args_spec_list.push_back(true_branch);
  args_spec_list.push_back(false_branch);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*res == *true_branch);
}
307 
// Runs Switch(false, 1, 2) and expects the inferred result to equal the second branch value.
TEST_F(TestPrim, test_switch2) {
  PrimitivePtr switch_prim = std::make_shared<Primitive>("Switch");
  FuncGraphPtr func_graph = MakeFuncGraph(switch_prim, 3);

  AbstractBasePtr condition = FromValue(false, false);
  AbstractBasePtr true_branch = FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr false_branch = FromValue(static_cast<int64_t>(2), false);

  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(condition);
  args_spec_list.push_back(true_branch);
  args_spec_list.push_back(false_branch);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "make result res: " << res->ToString();
  MS_LOG(INFO) << "make result arg2: " << false_branch->ToString();
  ASSERT_TRUE(*res == *false_branch);
}
322 
// Runs identity on an int64 scalar and expects the inferred result to equal the argument.
TEST_F(TestPrim, test_identity) {
  PrimitivePtr identity_prim = std::make_shared<Primitive>("identity");
  FuncGraphPtr func_graph = MakeFuncGraph(identity_prim, 1);

  AbstractBasePtr scalar_arg = FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(scalar_arg);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*res == *scalar_arg);
}
333 
// Runs broadcast_shape on the shape tuples (SHP_ANY, SHP_ANY) and (SHP_ANY) and expects the
// resulting value tuple to be (SHP_ANY, SHP_ANY).
TEST_F(TestPrim, test_broadcast_shape) {
  PrimitivePtr broadcast_shape = std::make_shared<Primitive>("broadcast_shape");
  FuncGraphPtr func_graph = MakeFuncGraph(broadcast_shape, 2);

  auto a = UTPrimUtils::ShapeOf({Shape::SHP_ANY, Shape::SHP_ANY});
  auto b = UTPrimUtils::ShapeOf({Shape::SHP_ANY});

  AbstractBasePtrList args_spec_list = {a, b};

  AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
  // dyn_cast yields null when the inferred abstract is not a tuple; fail instead of crashing.
  ASSERT_TRUE(res != nullptr);

  auto value_tuple = res->BuildValue()->cast<ValueTuplePtr>();
  // Likewise guard the cast of the built value before reading its elements.
  ASSERT_TRUE(value_tuple != nullptr);
  const auto &ret = value_tuple->value();

  std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)};
  ASSERT_TRUE(ret.size() == element_list.size());
  // Unsigned index avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < element_list.size(); i++) {
    ASSERT_TRUE(*ret[i] == *element_list[i]);
  }
}
353 
// Runs Partial(scalar_add, 1, 1) and expects the result to print the same as a
// PartialAbstractClosure built over scalar_add with the two bound arguments.
TEST_F(TestPrim, test_partial) {
  PrimitivePtr partial_prim = prim::kPrimPartial;
  FuncGraphPtr func_graph = MakeFuncGraph(partial_prim, 3);

  PrimitivePtr add_prim = prim::kPrimScalarAdd;
  AbstractBasePtr add_abstract = ToAbstract(add_prim);
  AbstractBasePtr bound_first = FromValue(static_cast<int64_t>(1), false);
  AbstractBasePtr bound_second = FromValue(static_cast<int64_t>(1), false);

  AbstractBasePtrList args_spec_list;
  args_spec_list.push_back(add_abstract);
  args_spec_list.push_back(bound_first);
  args_spec_list.push_back(bound_second);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();

  AbstractBasePtrList bound_args = {bound_first, bound_second};
  auto add_closure = std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd);
  auto expected = std::make_shared<PartialAbstractClosure>(add_closure, bound_args);
  MS_LOG(INFO) << "result: " << res->ToString();
  MS_LOG(INFO) << "expected: " << expected->ToString();
  // The comparison is done on the string representations.
  ASSERT_TRUE(res->ToString() == expected->ToString());
}
372 
373 // def test_env(x, y):
374 //     return env_setitem(newenv, embed(x), y)
TEST_F(TestPrim,test_env_setitem)375 TEST_F(TestPrim, test_env_setitem) {
376   FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
377   AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
378   AbstractBasePtrList args_spec_list = {abstract_x};
379   AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
380 
381   FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
382 
383   AbstractBasePtr abstract_env = ToAbstract(newenv);
384   AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
385   args_spec_list = {abstract_env, embed_x, abstract_y};
386 
387   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
388   AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
389   ASSERT_TRUE(*res == *exp);
390 }
391 
392 // def test_env(x, y, z):
393 //     e = env_setitem(newenv, embed(x), y)
394 //     return env_getitem(e, embed(x), z)
TEST_F(TestPrim,test_env_getitem)395 TEST_F(TestPrim, test_env_getitem) {
396   FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
397   AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
398   AbstractBasePtrList args_spec_list = {abstract_x};
399   AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
400 
401   FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
402 
403   AbstractBasePtr abstract_env = ToAbstract(newenv);
404   AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
405   args_spec_list = {abstract_env, embed_x, abstract_y};
406 
407   AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
408   AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
409   ASSERT_TRUE(*res == *exp);
410 
411   FuncGraphPtr graph_getitem = MakeFuncGraph(prim::kPrimEnvGetItem, 3);
412 
413   AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
414   args_spec_list = {res, embed_x, abstract_z};
415 
416   res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract();
417 
418   ASSERT_TRUE(*res == *abstract_x);
419 }
420 
421 // def test_env(x, y, z):
422 //     e1 = env_setitem(newenv, embed(x), y)
423 //     e2 = env_setitem(newenv, embed(x), z)
424 //     return env_add(e1, e2)
TEST_F(TestPrim,test_env_add)425 TEST_F(TestPrim, test_env_add) {
426   FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
427   AbstractBasePtr abstract_x = FromValue(static_cast<int64_t>(1), false);
428   AbstractBasePtrList args_spec_list = {abstract_x};
429   AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
430 
431   FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
432 
433   AbstractBasePtr abstract_env = ToAbstract(newenv);
434   AbstractBasePtr abstract_y = FromValue(static_cast<int64_t>(2), false);
435   args_spec_list = {abstract_env, embed_x, abstract_y};
436 
437   AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
438   AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
439   ASSERT_TRUE(*abstract_e1 == *exp);
440 
441   AbstractBasePtr abstract_z = FromValue(static_cast<int64_t>(3), false);
442   args_spec_list = {abstract_env, embed_x, abstract_z};
443 
444   AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
445   ASSERT_TRUE(*abstract_e2 == *exp);
446 
447   FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2);
448   args_spec_list = {abstract_e1, abstract_e2};
449   AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract();
450 
451   ASSERT_TRUE(*res == *exp);
452 }
453 
// Runs the relu primitive on a float64 tensor of shape {2, 2, 2, 3} and expects the inferred
// abstract to equal the input abstract.
TEST_F(TestPrim, test_relu) {
  // Use a test-local primitive carrying the same name as prim::kPrimRelu instead of adding the
  // attribute to the shared global object, so the "T" attribute does not leak into other tests.
  // NOTE(review): relies on primitives being resolved by name, as the other tests in this file
  // do when they construct primitives from string names - confirm for relu.
  PrimitivePtr relu = std::make_shared<Primitive>(prim::kPrimRelu->name());
  relu->AddAttr("T", MakeValue(static_cast<int64_t>(kNumberTypeFloat64)));
  FuncGraphPtr func_graph = MakeFuncGraph(relu, 1);

  AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3});  // NCHW
  AbstractBasePtrList args_spec_list = {expected};

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*res == *expected);
}
465 
466 /*
467 TEST_F(TestPrim, test_relu2) {
468   FuncGraphPtr func_graph = getPyFun("get_relu");
469   ASSERT_TRUE(func_graph != nullptr);
470 
471   auto arr = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
472   auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
473 
474   AbstractBasePtrList args_spec_list = {arr};
475   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
476   auto res = dyn_cast<AbstractTensor>(ret);
477   ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
478 }
479 
480 TEST_F(TestPrim, test_conv2d1) {
481   std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
482   py::tuple kernel_size(2);
483   kernel_size[0] = 5;
484   kernel_size[1] = 5;
485   std::shared_ptr<FuncGraph> func_graph = getPyFun.CallAndParseRet("test_conv2d", 64, kernel_size, 0, 2, 1);
486 
487   // NCHW
488   std::vector<int64_t> inputs_dims = {2, 20, 32, 32};
489   std::vector<int64_t> weight_dims = {64, 20, 5, 5};
490 
491   tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
492   inputs->set_data_type(kNumberTypeInt32);
493   inputs->set_shape(inputs_dims);
494   // Cout, Cin, kernel_size
495   tensor::TensorPtr weight = std::make_shared<tensor::Tensor>();
496   weight->set_data_type(kNumberTypeInt32);
497   weight->set_shape(weight_dims);
498 
499   AbstractBasePtr abstract_inputs = FromValue(inputs, true);
500   AbstractBasePtr abstract_weight = FromValue(weight, true);
501   AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_weight};
502 
503   AbstractBasePtr expected = abstract_inputs->Clone();
504   // NCHW
505   std::vector<int64_t> shape = {2, 64, 14, 14};
506   expected->set_shape(std::make_shared<Shape>(shape));
507 
508   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
509   MS_LOG(INFO) << "result: " << res->ToString();
510   MS_LOG(INFO) << "expected: " << expected->ToString();
511 
512   auto res_ptr = dyn_cast<AbstractTensor>(res);
513   auto expected_ptr = dyn_cast<AbstractTensor>(expected);
514   ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
515   ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
516 }
517 
518 TEST_F(TestPrim, test_conv2d) {
519   FuncGraphPtr func_graph = getPyFun("get_conv2d");
520   ASSERT_TRUE(func_graph != nullptr);
521 
522   auto input = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
523   auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3});
524 
525   AbstractBasePtrList args_spec_list = {input, weight};
526   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
527   auto res = dyn_cast<AbstractTensor>(ret);
528   auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16});
529   MS_LOG(INFO) << "result: " << res->ToString();
530   MS_LOG(INFO) << "expected: " << expected->ToString();
531   ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
532 }
533 
534 TEST_F(TestPrim, test_conv2d_native) {
535   FuncGraphPtr func_graph = getPyFun("get_conv2d_native");
536   ASSERT_TRUE(func_graph != nullptr);
537 
538   auto input = ArrayOfTensor(UTPrimUtils::kF64, {10, 32, 32, 32});
539   auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3});
540 
541   AbstractBasePtrList args_spec_list = {input, weight};
542   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
543   auto res = dyn_cast<AbstractTensor>(ret);
544   auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16});
545   MS_LOG(INFO) << "result: " << res->ToString();
546   MS_LOG(INFO) << "expected: " << expected->ToString();
547   ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
548 }
549 
550 TEST_F(TestPrim, test_biasAdd) {
551   FuncGraphPtr func_graph = getPyFun("get_bias_add");
552   ASSERT_TRUE(func_graph != nullptr);
553 
554   auto value = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
555   auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32});
556 
557   AbstractBasePtrList args_spec_list = {value, bias};
558   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
559   auto res = dyn_cast<AbstractTensor>(ret);
560   auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
561   MS_LOG(INFO) << "result: " << res->ToString();
562   MS_LOG(INFO) << "expected: " << expected->ToString();
563   ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
564 }
565 
566 TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) {
567   FuncGraphPtr func_graph = getPyFun("get_softmax_cross_entropy_with_logits");
568   ASSERT_TRUE(func_graph != nullptr);
569 
570   auto logits = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
571   auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
572 
573   AbstractBasePtrList args_spec_list = {logits, labels};
574   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
575   ASSERT_NE(ret, nullptr);
576   auto res = dyn_cast<AbstractTuple>(ret);
577   auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64});
578   auto dLogits = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
579   AbstractBasePtrList expected_list = {loss, dLogits};
580   auto expected = std::make_shared<AbstractTuple>(expected_list);
581   MS_LOG(INFO) << "result: " << res->ToString();
582   MS_LOG(INFO) << "expected: " << expected->ToString();
583 
584   auto res_ptr0 = dyn_cast<AbstractTuple>(res);
585   auto expected_ptr0 = dyn_cast<AbstractTuple>(expected);
586 
587   ASSERT_GT((*res_ptr0).size(), 1);
588   auto res_ptr = dyn_cast<AbstractTensor>((*res_ptr0)[1]);
589   ASSERT_GT((*expected_ptr0).size(), 1);
590   auto expected_ptr = dyn_cast<AbstractTensor>((*expected_ptr0)[1]);
591   ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
592   ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
593 }
594 
595 TEST_F(TestPrim, test_tensor_to_scalar_prim) {
596   FuncGraphPtr func_graph = getPyFun("get_tensor_to_scalar");
597   ASSERT_TRUE(func_graph != nullptr);
598 
599   auto logits = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
600   auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
601 
602   AbstractBasePtrList args_spec_list = {logits, labels};
603   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
604   auto res = dyn_cast<AbstractScalar>(ret);
605   AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
606   expected->set_type(UTPrimUtils::kF64);
607   MS_LOG(INFO) << "result: " << res->ToString();
608   MS_LOG(INFO) << "expected: " << expected->ToString();
609   ASSERT_TRUE(*res == *expected);
610 }
611 
612 TEST_F(TestPrim, test_pooling) {
613   PrimitivePtr pooling = prim::kPrimPooling;
614   pooling->AddAttr("mode", MakeValue(std::string("avg")));
615   pooling->AddAttr("pad_mode", MakeValue(std::string("valid")));
616   pooling->AddAttr("nan_opt", MakeValue(0));
617   pooling->AddAttr("window", MakeValue(2));
618   pooling->AddAttr("pad", MakeValue(1));
619   pooling->AddAttr("stride", MakeValue(1));
620   pooling->AddAttr("data_mode", MakeValue(1));
621   pooling->AddAttr("ceil_mode", MakeValue(0));
622   FuncGraphPtr func_graph = MakeFuncGraph(pooling, 1);
623 
624   std::vector<int64_t> inputs_dims = {8, 64, 3, 3};
625   auto inputs = std::make_shared<tensor::Tensor>();
626   inputs->set_data_type(kNumberTypeFloat32);
627   inputs->set_shape(inputs_dims);
628   AbstractBasePtr abstract_input = FromValue(inputs, false);
629   AbstractBasePtrList args_spec_list = {abstract_input};
630   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
631 
632   AbstractBasePtr expected = abstract_input->Clone()->Broaden();
633   std::vector<int64_t> expected_dims = {8, 64, 2, 2};
634   expected->set_shape(std::make_shared<Shape>(expected_dims));
635   MS_LOG(INFO) << "result: " << res->ToString();
636   MS_LOG(INFO) << "expected: " << expected->ToString();
637   ASSERT_TRUE(*res == *expected);
638 }
639 
640 TEST_F(TestPrim, test_hastype) {
641   AbstractBasePtrList args_spec_list;
642   int64_t v1 = 1;
643   TypePtr v2 = std::make_shared<Number>();
644 
645   AbstractBasePtr abstract_v1 = FromValue(v1, false);
646   AbstractTypePtr abstract_v2 = UTPrimUtils::TypeToAbstract(v2);
647   AbstractBasePtr expected = FromValue(true, false);
648 
649   args_spec_list.push_back(abstract_v1);
650   args_spec_list.push_back(abstract_v2);
651 
652   auto prim = std::make_shared<Primitive>("hastype");
653   FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
654 
655   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
656   ASSERT_TRUE(*res == *expected);
657 }
658 
659 TEST_F(TestPrim, test_array_len) {
660   AbstractBasePtrList args_spec_list;
661   auto v1 = UTPrimUtils::ArrayFloat64Of({3, 4, 0, 2});
662   auto expected = std::make_shared<AbstractScalar>(kAnyValue, kInt32);
663 
664   args_spec_list.push_back(v1);
665 
666   auto prim = std::make_shared<Primitive>("array_len");
667   FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
668 
669   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
670   ASSERT_TRUE(*res == *expected);
671 }
672 
673 TEST_F(TestPrim, test_list_len) {
674   AbstractBasePtrList args_spec_list;
675   auto v1 = UTPrimUtils::ListShapeOf({3, 4, 0, 2});
676   auto expected = std::make_shared<AbstractScalar>(4);
677 
678   args_spec_list.push_back(v1);
679 
680   auto prim = std::make_shared<Primitive>("list_len");
681   FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
682 
683   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
684   ASSERT_TRUE(*res == *expected);
685 }
686 
687 TEST_F(TestPrim, test_tuple_len) {
688   AbstractBasePtrList args_spec_list;
689   auto v1 = UTPrimUtils::ShapeOf({3, 4, 0, 2});
690   auto expected = std::make_shared<AbstractScalar>(4);
691 
692   args_spec_list.push_back(v1);
693 
694   auto prim = std::make_shared<Primitive>("tuple_len");
695   FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
696 
697   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
698   ASSERT_TRUE(*res == *expected);
699 }
700 
701 TEST_F(TestPrim, test_tuple_reversed) {
702   AbstractBasePtrList args_spec_list;
703   auto v1 = UTPrimUtils::ShapeOf({0, 1, 2, 3});
704   auto expected = UTPrimUtils::ShapeOf({3, 2, 1, 0});
705 
706   args_spec_list.push_back(v1);
707 
708   auto prim = std::make_shared<Primitive>("tuple_reversed");
709   FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
710 
711   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
712   MS_LOG(INFO) << "expect=" << expected->ToString();
713   ASSERT_TRUE(*res == *expected);
714 }
715 
716 TEST_F(TestPrim, test_list_getitem) {
717   AbstractBasePtrList args_spec_list;
718   int64_t v1 = 2;
719   int64_t v2 = 1;
720 
721   AbstractBasePtr elem = FromValue(v1, false);
722   AbstractBasePtr elem2 = FromValue(v2, false);
723   AbstractBasePtrList elems = {elem, elem};
724   auto abstract_v1 = std::make_shared<AbstractList>(elems);
725   AbstractBasePtr abstract_v2 = FromValue(v2, false);
726 
727   args_spec_list.push_back(abstract_v1);
728   args_spec_list.push_back(abstract_v2);
729 
730   auto prim = std::make_shared<Primitive>("list_getitem");
731   FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
732 
733   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
734   ASSERT_TRUE(*res == *elem);
735 }
736 
737 TEST_F(TestPrim, test_list_setitem) {
738   int64_t v1 = 1;
739   int64_t v2 = 2;
740 
741   AbstractBasePtr elem1 = FromValue(v1, false);
742   AbstractBasePtr elem2 = FromValue(v2, false);
743   AbstractBasePtrList elems = {elem1, elem1};
744   auto abstract_tuple = std::make_shared<AbstractList>(elems);
745   AbstractBasePtr abstract_v2 = FromValue(v1, false);
746   AbstractBasePtr abstract_v3 = FromValue(v2, false);
747   AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2, abstract_v3};
748 
749   auto prim = std::make_shared<Primitive>("list_setitem");
750   FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
751 
752   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
753   MS_LOG(INFO) << "result: " << res->ToString();
754   AbstractBasePtrList elems_exp = {elem1, elem2};
755   auto expected = std::make_shared<AbstractList>(elems_exp);
756   MS_LOG(INFO) << "expected: " << expected->ToString();
757 
758   auto res_list = dyn_cast<AbstractList>(res);
759   ASSERT_TRUE(*expected == *res_list);
760 }
761 
762 TEST_F(TestPrim, test_list_append) {
763   int64_t v1 = 1;
764 
765   AbstractBasePtr elem1 = FromValue(v1, false);
766   AbstractBasePtr elem2 = FromValue(v1, false);
767   auto abstract_tuple = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
768   AbstractBasePtr abstract_v2 = FromValue(v1, false);
769   AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2};
770 
771   auto prim = std::make_shared<Primitive>("list_append");
772   FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
773 
774   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
775   MS_LOG(INFO) << "result: " << res->ToString();
776   auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
777   MS_LOG(INFO) << "expected: " << expected->ToString();
778 
779   auto res_list = dyn_cast<AbstractList>(res);
780   ASSERT_TRUE(*res_list == *expected);
781 }
782 
783 TEST_F(TestPrim, test_tuple_setitem) {
784   int64_t v1 = 1;
785   int64_t v2 = 2;
786 
787   AbstractBasePtr elem1 = FromValue(v1, false);
788   AbstractBasePtr elem2 = FromValue(v2, false);
789   AbstractBasePtrList elems = {elem1, elem1};
790   auto abstract_tuple = std::make_shared<AbstractTuple>(elems);
791   AbstractBasePtr abstract_v2 = FromValue(v1, false);
792   AbstractBasePtr abstract_v3 = FromValue(v2, false);
793   AbstractBasePtrList args_spec_list = {abstract_tuple, abstract_v2, abstract_v3};
794 
795   auto prim = std::make_shared<Primitive>("tuple_setitem");
796   FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
797 
798   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
799   MS_LOG(INFO) << "result: " << res->ToString();
800   AbstractBasePtrList elems_exp = {elem1, elem2};
801   auto expected = std::make_shared<AbstractTuple>(elems_exp);
802   MS_LOG(INFO) << "expected: " << expected->ToString();
803 
804   auto res_tuple = dyn_cast<AbstractTuple>(res);
805   ASSERT_TRUE(*res == *expected);
806 }
807 
808 TEST_F(TestPrim, test_make_list) {
809   AbstractBasePtrList args_spec_list;
810   int64_t v1 = 2;
811   int64_t v2 = 2;
812 
813   AbstractBasePtr abstract_v1 = FromValue(v1, false);
814   AbstractBasePtr abstract_v2 = FromValue(v2, false);
815 
816   auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({abstract_v1, abstract_v2}));
817 
818   args_spec_list.push_back(abstract_v1);
819   args_spec_list.push_back(abstract_v2);
820 
821   auto prim = std::make_shared<Primitive>("make_list");
822   FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
823 
824   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
825   ASSERT_TRUE(*res == *expected);
826 }
827 
828 TEST_F(TestPrim, test_make_range) {
829   AbstractBasePtrList args_spec_list;
830   int64_t v1 = 1;
831   int64_t v2 = 4;
832 
833   AbstractBasePtr abstract_v1 = FromValue(v1);
834   AbstractBasePtr abstract_v2 = FromValue(v2);
835   args_spec_list.push_back(abstract_v1);
836   args_spec_list.push_back(abstract_v2);
837 
838   auto prim = std::make_shared<Primitive>("make_range");
839   std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(prim, 2);
840 
841   AbstractBasePtr ele1 = FromValue(1);
842   AbstractBasePtr ele2 = FromValue(2);
843   AbstractBasePtr ele3 = FromValue(3);
844   AbstractBasePtrList elem_list({ele1, ele2, ele3});
845   AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list);
846 
847   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
848   MS_LOG(INFO) << "res=" << res->ToString();
849   MS_LOG(INFO) << "expected=" << expected->ToString();
850   ASSERT_TRUE(*res == *expected);
851 }
852 
853 TEST_F(TestPrim, test_layernorm) {
854   PrimitivePtr layerNorm = prim::kPrimLayerNorm;
855   layerNorm->AddAttr("begin_norm_axis", MakeValue(1));
856   layerNorm->AddAttr("begin_params_axis", MakeValue(1));
857 
858   std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(layerNorm, 3);
859 
860   std::vector<int64_t> inputs_dims = {128, 64, 32, 64};
861   std::vector<int64_t> mean_var_dims = {128, 64, 32, 1};
862   std::vector<int64_t> params_dims = {64, 32, 64};
863 
864   tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
865   inputs->set_data_type(kNumberTypeFloat32);
866   inputs->set_shape(inputs_dims);
867 
868   tensor::TensorPtr mean_var = std::make_shared<tensor::Tensor>();
869   mean_var->set_data_type(kNumberTypeFloat32);
870   mean_var->set_shape(mean_var_dims);
871 
872   tensor::TensorPtr gamma = std::make_shared<tensor::Tensor>();
873   gamma->set_data_type(kNumberTypeFloat32);
874   gamma->set_shape(params_dims);
875 
876   tensor::TensorPtr beta = std::make_shared<tensor::Tensor>();
877   beta->set_data_type(kNumberTypeFloat32);
878   beta->set_shape(params_dims);
879 
880   AbstractBasePtr abstract_inputs = FromValue(inputs, true);
881   AbstractBasePtr abstract_mean_var = FromValue(mean_var, true);
882   AbstractBasePtr abstract_gamma = FromValue(gamma, true);
883   AbstractBasePtr abstract_beta = FromValue(beta, true);
884   AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_gamma, abstract_beta};
885 
886   AbstractBasePtr expected0 = abstract_inputs->Clone();
887   AbstractBasePtr expected1 = abstract_mean_var->Clone();
888   AbstractBasePtr expected2 = abstract_mean_var->Clone();
889 
890   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
891   MS_LOG(INFO) << "result: " << res->ToString();
892   MS_LOG(INFO) << "expected0: " << expected0->ToString();
893   MS_LOG(INFO) << "expected1: " << expected1->ToString();
894   MS_LOG(INFO) << "expected2: " << expected2->ToString();
895 
896   std::shared_ptr<AbstractTuple> abs_tuple = dyn_cast<AbstractTuple>(res);
897   ASSERT_TRUE(abs_tuple != nullptr);
898 
899   auto res_ptr0 = dyn_cast<AbstractTensor>(abs_tuple->elements()[0]);
900   auto expected_ptr0 = dyn_cast<AbstractTensor>(expected0);
901   ASSERT_TRUE(*res_ptr0->shape() == *expected_ptr0->shape());
902   ASSERT_TRUE(*res_ptr0->element() == *expected_ptr0->element());
903 
904   auto res_ptr1 = dyn_cast<AbstractTensor>(abs_tuple->elements()[1]);
905   auto expected_ptr1 = dyn_cast<AbstractTensor>(expected1);
906   ASSERT_TRUE(*res_ptr1->shape() == *expected_ptr1->shape());
907   ASSERT_TRUE(*res_ptr1->element() == *expected_ptr1->element());
908 
909   auto res_ptr2 = dyn_cast<AbstractTensor>(abs_tuple->elements()[2]);
910   auto expected_ptr2 = dyn_cast<AbstractTensor>(expected2);
911   ASSERT_TRUE(*res_ptr2->shape() == *expected_ptr2->shape());
912   ASSERT_TRUE(*res_ptr2->element() == *expected_ptr2->element());
913 }
914 
915 TEST_F(TestPrim, test_DropoutGenMask) {
916   AbstractBasePtrList args_spec_list;
917 
918   auto arg0 = UTPrimUtils::ShapeOf({5, 5, 5, 5});
919 
920   std::vector<int64_t> keep_prob_shape = {};
921   tensor::TensorPtr keep_prob = std::make_shared<tensor::Tensor>(0.5f);
922   keep_prob->set_data_type(kNumberTypeFloat32);
923   keep_prob->set_shape(keep_prob_shape);
924   AbstractBasePtr abstract_keep_prob = FromValue(keep_prob);
925 
926   auto prim = std::make_shared<Primitive>("DropoutGenMask");
927   std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(prim, 2);
928 
929   args_spec_list.push_back(arg0);
930   args_spec_list.push_back(abstract_keep_prob);
931 
932   // should return a tensor with on dimension of 79 elements
933   AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
934                                                               std::make_shared<Shape>(std::vector<int64_t>{79}));
935 
936   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
937   MS_LOG(INFO) << "res=" << res->ToString();
938   MS_LOG(INFO) << "expected=" << expected->ToString();
939   ASSERT_TRUE(*res == *expected);
940 }
941 
942 TEST_F(TestPrim, test_dropout) {
943   std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
944   std::shared_ptr<FuncGraph> func_graph = getPyFun.CallAndParseRet("test_dropout");
945 
946   std::vector<int64_t> inputs_dims = {2, 20, 32, 32};
947 
948   tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
949   inputs->set_data_type(kNumberTypeFloat32);
950   inputs->set_shape(inputs_dims);
951 
952   AbstractBasePtr abstract_inputs = FromValue(inputs, true);
953   std::vector<int64_t> keep_prob_shape = {};
954   tensor::TensorPtr keep_prob = std::make_shared<tensor::Tensor>(0.5f);
955   keep_prob->set_data_type(kNumberTypeFloat32);
956   keep_prob->set_shape(keep_prob_shape);
957   AbstractBasePtr abstract_keep_prob = FromValue(keep_prob);
958 
959   AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_keep_prob};
960   AbstractBasePtr expected = abstract_inputs->Clone();
961 
962   // NCHW
963   std::vector<int64_t> shape = {2, 20, 32, 32};
964   expected->set_shape(std::make_shared<Shape>(shape));
965 
966   AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
967   MS_LOG(INFO) << "result: " << res->ToString();
968   MS_LOG(INFO) << "expected: " << expected->ToString();
969 
970   auto res_ptr = dyn_cast<AbstractTensor>(res);
971   auto expected_ptr = dyn_cast<AbstractTensor>(expected);
972   ASSERT_TRUE(*res_ptr->shape() == *expected_ptr->shape());
973   ASSERT_TRUE(*res_ptr->element() == *expected_ptr->element());
974 }
975 
976 TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) {
977   PrimitivePtr broadcatGradientArgs = prim::kPrimBroadcastGradientArgs;
978   std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(broadcatGradientArgs, 2);
979 
980   // broadcast shape: x: 8,5,3, y:3
981   // output: ((),(0, 1))
982   AbstractBasePtrList x_arg_list({abstract::FromValue(8), abstract::FromValue(5), abstract::FromValue(3)});
983   AbstractBasePtrList y_arg_list({abstract::FromValue(3)});
984   auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
985   auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
986   AbstractBasePtrList args_spec_list = {x_input, y_input};
987   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
988   auto res = dyn_cast<AbstractTuple>(ret);
989   AbstractBasePtrList x_idx_list;
990   auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
991   AbstractBasePtrList y_idx_list({abstract::FromValue(0), abstract::FromValue(1)});
992   auto r_y = std::make_shared<AbstractTuple>(y_idx_list);
993   AbstractBasePtrList elem_list({r_x, r_y});
994   auto expected = std::make_shared<AbstractTuple>(elem_list);
995   MS_LOG(INFO) << "result: " << res->ToString();
996   MS_LOG(INFO) << "expected: " << expected->ToString();
997   ASSERT_TRUE(*res == *expected);
998 }
999 
1000 TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) {
1001   PrimitivePtr broadcatGradientArgs = prim::kPrimBroadcastGradientArgs;
1002   std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(broadcatGradientArgs, 2);
1003 
1004   // broadcast shape: x: 8,1,3, y:8 5 3
1005   // output: ((1),())
1006   AbstractBasePtrList x_arg_list({abstract::FromValue(8), abstract::FromValue(1), abstract::FromValue(3)});
1007   AbstractBasePtrList y_arg_list({abstract::FromValue(8), abstract::FromValue(5), abstract::FromValue(3)});
1008   auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
1009   auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
1010   AbstractBasePtrList args_spec_list = {x_input, y_input};
1011   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
1012   auto res = dyn_cast<AbstractTuple>(ret);
1013   AbstractBasePtrList x_idx_list({abstract::FromValue(1)});
1014   auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
1015   AbstractBasePtrList y_idx_list;
1016   auto r_y = std::make_shared<AbstractTuple>(y_idx_list);
1017   AbstractBasePtrList elem_list({r_x, r_y});
1018   auto expected = std::make_shared<AbstractTuple>(elem_list);
1019   MS_LOG(INFO) << "result: " << res->ToString();
1020   MS_LOG(INFO) << "expected: " << expected->ToString();
1021   ASSERT_TRUE(*res == *expected);
1022 }
1023 
1024 TEST_F(TestPrim, test_DictGetItem) {
1025   PrimitivePtr dictGetItem = prim::kPrimDictGetItem;
1026   std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(dictGetItem, 2);
1027 
1028   std::vector<std::pair<std::string, ValuePtr>> tensor_map = {
1029     {"x", std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>{2, 3, 4})},
1030     {"y", std::make_shared<tensor::Tensor>(kNumberTypeInt32, std::vector<int64_t>{2, 1, 4})}};
1031   ValueDictionary value_dict(tensor_map);
1032   AbstractBasePtr array_dict = value_dict.ToAbstract();
1033   AbstractBasePtr key = abstract::FromValue("x");
1034   AbstractBasePtrList args_spec_list = {array_dict, key};
1035 
1036   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
1037   AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
1038   AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second));
1039 
1040   ASSERT_TRUE(*tensor_ret == *expect);
1041 }
1042 
1043 TEST_F(TestPrim, test_DictGetItem2) {
1044   PrimitivePtr dictGetItem = prim::kPrimDictGetItem;
1045   std::shared_ptr<FuncGraph> func_graph = MakeFuncGraph(dictGetItem, 2);
1046 
1047   AbstractBasePtr arr_x = ArrayOfTensor(UTPrimUtils::kF64, {3, 4, 5});
1048   AbstractBasePtr arr_y = ArrayOfTensor(UTPrimUtils::kF64, {1, 4, 5});
1049   AbstractBasePtr arr_z = ArrayOfTensor(UTPrimUtils::kF64, {3, 1, 5});
1050   std::vector<AbstractAttribute> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
1051   AbstractDictionaryPtr array_dict = std::make_shared<AbstractDictionary>(array_map);
1052   AbstractBasePtr key = abstract::FromValue("x");
1053   AbstractBasePtrList args_spec_list = {array_dict, key};
1054 
1055   AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
1056   AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
1057   AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x);
1058 
1059   ASSERT_TRUE(*tensor_ret == *expect);
1060 }
1061 */
1062 
1063 }  // namespace abstract
1064 }  // namespace mindspore
1065