• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <unordered_map>
18 
19 #include "pybind11/pybind11.h"
20 
21 #include "transform/transform_base_test.h"
22 #include "common/py_func_graph_fetcher.h"
23 #include "pipeline/jit/parse/parse.h"
24 #include "debug/draw.h"
25 #include "debug/anf_ir_dump.h"
26 #include "pipeline/jit/static_analysis/prim.h"
27 #include "frontend/operator/ops.h"
28 #include "common/common_test.h"
29 
30 #define private public
31 #include "transform/graph_ir/types.h"
32 #include "transform/graph_ir/convert.h"
33 #include "securec/include/securec.h"
34 #include "utils/utils.h"
35 using std::cout;
36 using std::endl;
37 using std::string;
38 using std::unordered_map;
39 
40 namespace mindspore {
41 namespace transform {
42 using AbstractScalar = abstract::AbstractScalar;
43 using mindspore::parse::ResolveAll;
44 
45 class TestConvert : public UT::Common {
46  public:
47   TestConvert() {}
48   virtual void SetUp();
49   virtual void TearDown();
50   static const std::shared_ptr<Float> kF32;
51 };
52 
53 void TestConvert::SetUp() { UT::InitPythonPath(); }
54 void TestConvert::TearDown() {}
55 
56 const std::shared_ptr<Float> TestConvert::kF32 = std::make_shared<Float>(32);
57 
58 AnfGraphPtr createAnfGraph() { return std::make_shared<AnfGraph>(); }
59 
60 TEST_F(TestConvert, TestConstruct) {
61   AnfGraphPtr func_graph = std::make_shared<AnfGraph>();
62   DfGraphConvertor converter(func_graph);
63   converter.ConvertAllNode().GetComputeGraph();
64   ASSERT_NE(converter.ErrCode(), SUCCESS);
65 }
66 
67 #if (!defined ENABLE_GE)
68 
69 namespace {
70 
71 bool MakeDfGraph(PrimitivePtr prim, unsigned int nparam) {
72   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, nparam);
73   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
74 
75   DfGraphConvertor converter(anf_graph);
76   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
77   if (converter.ErrCode() != 0) {
78     MS_LOG(ERROR) << "DfGraphConvertor convert " << prim->name() << " error, error code is: " << converter.ErrCode();
79     return false;
80   }
81   if (df_graph == nullptr) {
82     MS_LOG(ERROR) << "DfGraphConvertor get " << prim->name() << " compute func_graph failed";
83     return false;
84   }
85   return true;
86 }
87 
88 }  // namespace
89 
90 TEST_F(TestConvert, TestConvertConv2d) {
91   PrimitivePtr conv2d = prim::kPrimConv2D;
92   conv2d->AddAttr("stride", MakeValue(static_cast<int64_t>(2)));
93   conv2d->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
94   conv2d->AddAttr("dilation", MakeValue(static_cast<int64_t>(0)));
95 
96   FuncGraphPtr anf_graph = MakeFuncGraph(conv2d, 2);
97   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
98   DfGraphConvertor converter(anf_graph);
99   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
100   ASSERT_EQ(converter.ErrCode(), 0);
101   ASSERT_NE(df_graph, nullptr);
102 }
103 
104 TEST_F(TestConvert, TestConvertMaxpooling) {
105   auto prim = std::make_shared<Primitive>("MaxPool");
106   FuncGraphPtr anf_graph = MakeFuncGraph(prim, 5);  // ary, ksize, stride, padding, data_format
107 
108   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
109   DfGraphConvertor converter(anf_graph);
110   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
111   ASSERT_EQ(converter.ErrCode(), 0);
112   ASSERT_NE(df_graph, nullptr);
113 }
114 
115 TEST_F(TestConvert, TestReluOps) {
116   auto prim = prim::kPrimRelu;
117   prim->AddAttr("T", MakeValue(static_cast<int64_t>(0)));
118 
119   auto func_graph = MakeFuncGraph(prim, 1);
120   ASSERT_TRUE(nullptr != func_graph);
121 
122   // save the func_graph to manager
123   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
124 
125   // call resolve
126   bool ret_ = ResolveAll(manager);
127   ASSERT_TRUE(ret_);
128 
129   // draw graph
130   auto anfGraph = *(manager->func_graphs().begin());
131   DfGraphConvertor converter(anfGraph);
132   converter.ConvertAllNode().BuildGraph().GetComputeGraph();
133   ASSERT_EQ(converter.ErrCode(), 0);
134 }
135 
136 TEST_F(TestConvert, TestConvertBatchNorm) {
137   PrimitivePtr batch_norm = prim::kPrimBatchNorm;
138   batch_norm->AddAttr("epsilon", MakeValue(0.001f));
139   batch_norm->AddAttr("momentum", MakeValue(0.1f));
140 
141   FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();
142   std::vector<AnfNodePtr> inputs;
143   inputs.push_back(NewValueNode(batch_norm));
144   for (unsigned int i = 0; i < 5; i++) {
145     inputs.push_back(anf_graph->add_parameter());
146   }
147   CNodePtr cnode_prim = anf_graph->NewCNode(inputs);
148   inputs.clear();
149 
150   inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
151   inputs.push_back(cnode_prim);
152   inputs.push_back(NewValueNode(static_cast<int64_t>(2)));
153   CNodePtr cnode_getitem = anf_graph->NewCNode(inputs);
154   inputs.clear();
155 
156   inputs.push_back(NewValueNode(prim::kPrimRelu));
157   inputs.push_back(cnode_getitem);
158   CNodePtr cnode_relu = anf_graph->NewCNode(inputs);
159   inputs.clear();
160 
161   inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
162   inputs.push_back(cnode_relu);
163   CNodePtr cnode_return = anf_graph->NewCNode(inputs);
164   anf_graph->set_return(cnode_return);
165 
166   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
167   DfGraphConvertor converter(anf_graph);
168   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
169   ASSERT_EQ(converter.ErrCode(), 0);
170   ASSERT_NE(df_graph, nullptr);
171 }
172 
173 TEST_F(TestConvert, TestConvertConvBackpropInput) {
174   auto prim = prim::kPrimConv2DBackpropInput;
175   const std::vector<int64_t> list{1,1};
176   prim->AddAttr("stride", MakeValue(list));
177   prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
178   prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
179   prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
180   prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
181   prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
182   prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
183 
184   auto func_graph = MakeFuncGraph(prim, 3);
185   ASSERT_NE(func_graph, nullptr);
186   // save the func_graph to manager
187   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
188 
189   // call resolve
190   bool ret_ = ResolveAll(manager);
191   ASSERT_TRUE(ret_);
192 
193   // draw graph
194   auto anf_graph = *(manager->func_graphs().begin());
195   DfGraphConvertor converter(anf_graph);
196   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
197   ASSERT_EQ(converter.ErrCode(), 0);
198   ASSERT_NE(df_graph, nullptr);
199 }
200 
201 TEST_F(TestConvert, TestConvertConvBackpropFilter) {
202   auto prim = prim::kPrimConv2DBackpropFilter;
203   const std::vector<int64_t> list{1,1};
204   prim->AddAttr("stride", MakeValue(list));
205   prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
206   prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
207   prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
208   prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
209   prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
210   prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
211 
212   auto func_graph = MakeFuncGraph(prim, 3);
213   ASSERT_NE(func_graph, nullptr);
214   // save the func_graph to manager
215   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
216 
217   // call resolve
218   bool ret_ = ResolveAll(manager);
219   ASSERT_TRUE(ret_);
220 
221   // draw graph
222   auto anf_graph = *(manager->func_graphs().begin());
223   DfGraphConvertor converter(anf_graph);
224   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
225   ASSERT_EQ(converter.ErrCode(), 0);
226   ASSERT_NE(df_graph, nullptr);
227 }
228 
229 TEST_F(TestConvert, TestConvertReluGrad) {
230   auto prim = prim::kPrimReluGrad;
231   prim->AddAttr("alpha", MakeValue(0.1f));
232   prim->AddAttr("beta", MakeValue(0.1f));
233   prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));
234 
235   auto func_graph = MakeFuncGraph(prim, 2);
236   ASSERT_NE(func_graph, nullptr);
237   // save the func_graph to manager
238   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
239 
240   // call resolve
241   bool ret_ = ResolveAll(manager);
242   ASSERT_TRUE(ret_);
243 
244   // draw graph
245   auto anf_graph = *(manager->func_graphs().begin());
246   DfGraphConvertor converter(anf_graph);
247   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
248   ASSERT_EQ(converter.ErrCode(), 0);
249   ASSERT_NE(df_graph, nullptr);
250 }
251 
252 TEST_F(TestConvert, TestConvertBiasAdd) {
253   auto prim = std::make_shared<Primitive>("BiasAdd");
254   prim->AddAttr("alpha", MakeValue(0.0f));
255   prim->AddAttr("beta", MakeValue(1.0f));
256 
257   auto func_graph = MakeFuncGraph(prim, 2);
258   ASSERT_NE(func_graph, nullptr);
259   // save the func_graph to manager
260   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
261 
262   // call resolve
263   bool ret_ = ResolveAll(manager);
264   ASSERT_TRUE(ret_);
265 
266   // draw graph
267   auto anf_graph = *(manager->func_graphs().begin());
268   DfGraphConvertor converter(anf_graph);
269   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
270   ASSERT_EQ(converter.ErrCode(), 0);
271   ASSERT_NE(df_graph, nullptr);
272 }
273 
274 TEST_F(TestConvert, TestConvertBiasAddGrad) {
275   auto prim = prim::kPrimBiasAddGrad;
276   prim->AddAttr("alpha", MakeValue(0.0f));
277   prim->AddAttr("beta", MakeValue(1.0f));
278 
279   auto func_graph = MakeFuncGraph(prim, 2);
280   ASSERT_NE(func_graph, nullptr);
281   // save the func_graph to manager
282   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
283 
284   // call resolve
285   bool ret_ = ResolveAll(manager);
286   ASSERT_TRUE(ret_);
287 
288   // draw graph
289   auto anf_graph = *(manager->func_graphs().begin());
290   DfGraphConvertor converter(anf_graph);
291   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
292   ASSERT_EQ(converter.ErrCode(), 0);
293   ASSERT_NE(df_graph, nullptr);
294 }
295 
296 TEST_F(TestConvert, TestConvertMaxPoolGradWithArgmax) {
297   auto prim = std::make_shared<Primitive>("MaxPoolGradWithArgmax");
298   prim->AddAttr("alpha", MakeValue(0.0f));
299   prim->AddAttr("beta", MakeValue(1.0f));
300   prim->AddAttr("window", MakeValue(static_cast<int64_t>(2)));
301   prim->AddAttr("stride", MakeValue(static_cast<int64_t>(1)));
302   prim->AddAttr("ceil_mode", MakeValue(static_cast<int64_t>(0)));
303   prim->AddAttr("data_mode", MakeValue(static_cast<int64_t>(0)));
304   prim->AddAttr("alpha", MakeValue(0.1f));
305   prim->AddAttr("beta", MakeValue(1.0f));
306 
307   auto func_graph = MakeFuncGraph(prim, 2);
308   ASSERT_NE(func_graph, nullptr);
309   // save the func_graph to manager
310   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
311 
312   // call resolve
313   bool ret_ = ResolveAll(manager);
314   ASSERT_TRUE(ret_);
315 
316   // draw graph
317   auto anf_graph = *(manager->func_graphs().begin());
318   DfGraphConvertor converter(anf_graph);
319   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
320   ASSERT_EQ(converter.ErrCode(), 0);
321   ASSERT_NE(df_graph, nullptr);
322 }
323 
324 TEST_F(TestConvert, TestConcat) {
325   auto prim = prim::kPrimConcat;
326 
327   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
328   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
329   DfGraphConvertor converter(anf_graph);
330   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
331   ASSERT_EQ(converter.ErrCode(), 0);
332   ASSERT_NE(df_graph, nullptr);
333 }
334 
335 TEST_F(TestConvert, TestGatherV2) {
336   auto prim = prim::kPrimGather;
337 
338   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 3);
339   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
340   DfGraphConvertor converter(anf_graph);
341   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
342   ASSERT_EQ(converter.ErrCode(), 0);
343   ASSERT_NE(df_graph, nullptr);
344 }
345 
346 TEST_F(TestConvert, TestCast) {
347   auto prim = prim::kPrimCast;
348 
349   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
350   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
351   DfGraphConvertor converter(anf_graph);
352   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
353   ASSERT_EQ(converter.ErrCode(), 0);
354   ASSERT_NE(df_graph, nullptr);
355 }
356 
357 TEST_F(TestConvert, TestExp) {
358   auto prim = std::make_shared<Primitive>("Exp");
359 
360   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
361   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
362   DfGraphConvertor converter(anf_graph);
363   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
364   ASSERT_EQ(converter.ErrCode(), 0);
365   ASSERT_NE(df_graph, nullptr);
366 }
367 
368 TEST_F(TestConvert, TestFloor) {
369   auto prim = std::make_shared<Primitive>("Floor");
370 
371   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
372   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
373   DfGraphConvertor converter(anf_graph);
374   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
375   ASSERT_EQ(converter.ErrCode(), 0);
376   ASSERT_NE(df_graph, nullptr);
377 }
378 
379 TEST_F(TestConvert, TestGreaterEqual) {
380   auto prim = std::make_shared<Primitive>("GreaterEqual");
381 
382   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
383   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
384 
385   DfGraphConvertor converter(anf_graph);
386   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
387   ASSERT_EQ(converter.ErrCode(), 0);
388   ASSERT_NE(df_graph, nullptr);
389 }
390 
391 TEST_F(TestConvert, TestLess) {
392   auto prim = std::make_shared<Primitive>("Less");
393   prim->AddAttr("T", MakeValue(kFloat32));
394 
395   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
396   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
397 
398   DfGraphConvertor converter(anf_graph);
399   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
400   ASSERT_EQ(converter.ErrCode(), 0);
401   ASSERT_NE(df_graph, nullptr);
402 }
403 
404 TEST_F(TestConvert, TestLessEqual) {
405   auto prim = std::make_shared<Primitive>("LessEqual");
406 
407   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
408   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
409 
410   DfGraphConvertor converter(anf_graph);
411   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
412   ASSERT_EQ(converter.ErrCode(), 0);
413   ASSERT_NE(df_graph, nullptr);
414 }
415 
416 TEST_F(TestConvert, TestLogicalNot) {
417   auto prim = std::make_shared<Primitive>("LogicalNot");
418 
419   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
420   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
421 
422   DfGraphConvertor converter(anf_graph);
423   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
424   ASSERT_EQ(converter.ErrCode(), 0);
425   ASSERT_NE(df_graph, nullptr);
426 }
427 
428 TEST_F(TestConvert, TestAssignAdd) {
429   auto prim = prim::kPrimAssignAdd;
430   prim->AddAttr("use_locking", MakeValue(true));
431 
432   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
433   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
434 
435   DfGraphConvertor converter(anf_graph);
436   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
437   ASSERT_EQ(converter.ErrCode(), 0);
438   ASSERT_NE(df_graph, nullptr);
439 }
440 
441 TEST_F(TestConvert, LogSoftmax) {
442   auto prim = prim::kPrimLogSoftmax;
443   prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));
444 
445   std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
446   std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
447 
448   DfGraphConvertor converter(anf_graph);
449   auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
450   ASSERT_EQ(converter.ErrCode(), 0);
451   ASSERT_NE(df_graph, nullptr);
452 }
453 
454 TEST_F(TestConvert, TestMaximumOps) {
455   auto prim = prim::kPrimMaximum;
456   bool ret = MakeDfGraph(prim, 2);
457   ASSERT_TRUE(ret);
458 }
459 
460 TEST_F(TestConvert, TestReduceMeanOps) {
461   auto prim = prim::kPrimReduceMean;
462   prim->AddAttr("keepdims", MakeValue(true));
463   bool ret = MakeDfGraph(prim, 2);
464   ASSERT_TRUE(ret);
465 }
466 
467 TEST_F(TestConvert, TestMinimumOps) {
468   auto prim = prim::kPrimMinimum;
469   bool ret = MakeDfGraph(prim, 2);
470   ASSERT_TRUE(ret);
471 }
472 
473 TEST_F(TestConvert, TestFusedMinOrMaxGradOps) {
474   // Add infer step to this test case
475   ASSERT_TRUE(true);
476 }
477 
478 TEST_F(TestConvert, TestSqueezeOps) {
479   auto prim = prim::kPrimSqueeze;
480   bool ret = MakeDfGraph(prim, 2);
481   ASSERT_TRUE(ret);
482 }
483 
484 TEST_F(TestConvert, TestMulOps) {
485   auto prim = prim::kPrimMul;
486   bool ret = MakeDfGraph(prim, 2);
487   ASSERT_TRUE(ret);
488 }
489 
490 TEST_F(TestConvert, TestNegOps) {
491   auto prim = prim::kPrimNeg;
492   bool ret = MakeDfGraph(prim, 1);
493   ASSERT_TRUE(ret);
494 }
495 
496 TEST_F(TestConvert, TestOneHotOps) {
497   auto prim = prim::kPrimOneHot;
498   prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));
499   bool ret = MakeDfGraph(prim, 4);
500   ASSERT_TRUE(ret);
501 }
502 
503 TEST_F(TestConvert, TestPowOps) {
504   auto prim = std::make_shared<Primitive>("Pow");
505   bool ret = MakeDfGraph(prim, 2);
506   ASSERT_TRUE(ret);
507 }
508 
509 TEST_F(TestConvert, TestReciprocalOps) {
510   auto prim = std::make_shared<Primitive>("Reciprocal");
511   bool ret = MakeDfGraph(prim, 1);
512   ASSERT_TRUE(ret);
513 }
514 
515 TEST_F(TestConvert, TestSelectOps) {
516   auto prim = prim::kPrimSelect;
517   bool ret = MakeDfGraph(prim, 3);
518   ASSERT_TRUE(ret);
519 }
520 
521 TEST_F(TestConvert, TestSqrtOps) {
522   auto prim = std::make_shared<Primitive>("Sqrt");
523   bool ret = MakeDfGraph(prim, 1);
524   ASSERT_TRUE(ret);
525 }
526 
527 TEST_F(TestConvert, TestSquareOps) {
528   auto prim = std::make_shared<Primitive>("Square");
529   bool ret = MakeDfGraph(prim, 1);
530   ASSERT_TRUE(ret);
531 }
532 
533 #ifndef ENABLE_SECURITY
534 TEST_F(TestConvert, TestScalarSummaryOps) {
535   auto prim = prim::kPrimScalarSummary;
536   // should have only 1 input.
537   bool ret = MakeDfGraph(prim, 2);
538   ASSERT_TRUE(ret);
539 }
540 
541 TEST_F(TestConvert, TestTensorSummaryOps) {
542   auto prim = prim::kPrimTensorSummary;
543   bool ret = MakeDfGraph(prim, 2);
544   ASSERT_TRUE(ret);
545 }
546 
547 TEST_F(TestConvert, TestHistogramSummaryOps) {
548   auto prim = prim::kPrimHistogramSummary;
549   bool ret = MakeDfGraph(prim, 2);
550   ASSERT_TRUE(ret);
551 }
552 #endif
553 
554 TEST_F(TestConvert, TestGreaterOps) {
555   auto prim = std::make_shared<Primitive>("Greater");
556   bool ret = MakeDfGraph(prim, 2);
557   ASSERT_TRUE(ret);
558 }
559 
560 TEST_F(TestConvert, TestEqualOps) {
561   auto prim = std::make_shared<Primitive>("Equal");
562   bool ret = MakeDfGraph(prim, 2);
563   ASSERT_TRUE(ret);
564 }
565 
566 TEST_F(TestConvert, TestArgMaxiOps) {
567   auto prim = std::make_shared<Primitive>("Argmax");
568   bool ret = MakeDfGraph(prim, 2);
569   ASSERT_TRUE(ret);
570 }
571 
572 TEST_F(TestConvert, TestResizeNearestNeighborOps) {
573   auto prim = std::make_shared<Primitive>("ResizeNearestNeighbor");
574   bool ret = MakeDfGraph(prim, 1);
575   ASSERT_TRUE(ret);
576 }
577 
578 TEST_F(TestConvert, TestApplyMomentumOps) {
579   auto prim = std::make_shared<Primitive>("ApplyMomentum");
580   bool ret = MakeDfGraph(prim, 5);
581   ASSERT_TRUE(ret);
582 }
583 
584 TEST_F(TestConvert, TestNPUGetFloatStatusOps) {
585   auto prim = std::make_shared<Primitive>("NPUGetFloatStatus");
586   bool ret = MakeDfGraph(prim, 1);
587   ASSERT_TRUE(ret);
588 }
589 
590 TEST_F(TestConvert, TestNPUAllocFloatStatusOps) {
591   auto prim = std::make_shared<Primitive>("NPUAllocFloatStatus");
592   bool ret = MakeDfGraph(prim, 0);
593   ASSERT_TRUE(ret);
594 }
595 
596 TEST_F(TestConvert, TestNPUClearFloatStatusOps) {
597   auto prim = std::make_shared<Primitive>("NPUClearFloatStatus");
598   bool ret = MakeDfGraph(prim, 1);
599   ASSERT_TRUE(ret);
600 }
601 
602 #endif
603 
604 TEST_F(TestConvert, TestAddOps) {
605   auto prim = std::make_shared<Primitive>("Add");
606   auto func_graph = MakeFuncGraph(prim, 2);
607   ASSERT_TRUE(nullptr != func_graph);
608 
609   // save the func_graph to manager
610   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
611 
612   // call resolve
613   bool ret_ = ResolveAll(manager);
614   ASSERT_TRUE(ret_);
615 
616   // draw graph
617   auto anfGraph = *(manager->func_graphs().begin());
618   DfGraphConvertor converter(anfGraph);
619   converter.ConvertAllNode().BuildGraph().GetComputeGraph();
620   ASSERT_EQ(converter.ErrCode(), 0);
621 }
622 
623 TEST_F(TestConvert, TestConvertTensor) {
624   float data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
625   // Create a tensor with wanted data type and shape
626   std::vector<int64_t> dims{2, 2, 3};
627   std::vector<int64_t> ge_dims{2, 2, 3};
628   auto type_id = kNumberTypeFloat32;
629   MeTensor me_tensor(type_id, dims);
630   // Get the writable data pointer of the tensor and cast it to its data type
631   uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c());
632   // Copy or use the writable data pointer of the ME tensor
633   memcpy_s(me_data_ptr, me_tensor.data().nbytes(), data, 12 * sizeof(float));
634   auto me_tensor_ptr = std::make_shared<MeTensor>(me_tensor);
635   auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW);
636   ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetFormat(), GeFormat::FORMAT_NCHW);
637   ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetDataType(), GeDataType::DT_FLOAT);
638   // ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().array().GetDims(), ge_dims);
639   int i = 0;
640   for (i = 0; i < ge_dims.size(); i++) {
641     ASSERT_EQ(ge_dims[i], ge_tensor_ptr->GetTensorDesc().GetShape().GetDims()[i]);
642   }
643   for (i = 0; i < ge_tensor_ptr->GetTensorDesc().GetShape().GetShapeSize(); i++) {
644     ASSERT_EQ(data[i], (reinterpret_cast<float*>(ge_tensor_ptr->GetData()))[i]);
645   }
646 }
647 
648 TEST_F(TestConvert, TestConvertTensor0Dims) {
649   // shape with 0 dims is also valid
650   std::vector<int64_t> dims{};
651   auto type_id = kNumberTypeFloat32;
652   auto me_tensor_ptr = std::make_shared<MeTensor>(type_id, dims);
653   ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW), nullptr);
654 }
655 
656 TEST_F(TestConvert, TestConvertTensorError) {
657   std::vector<int64_t> dims2{2, 3, 4};
658   auto type_id_2 = kNumberTypeFloat32;
659   auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2);
660   ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
661 }
662 
663 TEST_F(TestConvert, TestUtilsConvertDataType) {
664   ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat16), GeDataType::DT_FLOAT16);
665   ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat32), GeDataType::DT_FLOAT);
666   ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeFloat64), GeDataType::DT_DOUBLE);
667   ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt8), GeDataType::DT_INT8);
668   ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt16), GeDataType::DT_INT16);
669   ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt32), GeDataType::DT_INT32);
670   ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeInt64), GeDataType::DT_INT64);
671   ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeUInt32), GeDataType::DT_UINT32);
672   ASSERT_EQ(TransformUtil::ConvertDataType(MeDataType::kNumberTypeBool), GeDataType::DT_BOOL);
673 }
674 
675 TEST_F(TestConvert, TestUtilsConvertFormat) {
676   ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NCHW), GeFormat::FORMAT_NCHW);
677   ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NC1HWC0), GeFormat::FORMAT_NC1HWC0);
678   ASSERT_EQ(TransformUtil::ConvertFormat(kOpFormat_NHWC), GeFormat::FORMAT_NHWC);
679   ASSERT_EQ(TransformUtil::ConvertFormat("xyz"), GeFormat::FORMAT_ND);
680 }
681 
682 TEST_F(TestConvert, TestUtilsDataSize) {
683   ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat32), 4);
684   ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat16), 2);
685   ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeFloat64), 8);
686   ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt8), 1);
687   ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt16), 2);
688   ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt32), 4);
689   ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeInt64), 8);
690   ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeUInt32), 4);
691   ASSERT_EQ(TransformUtil::GetDataTypeSize(MeDataType::kNumberTypeBool), 1);
692 }
693 
694 TEST_F(TestConvert, TestConvertGeTensor) {
695 #define DTYPE float
696   ge::DataType dt = ge::DataType::DT_FLOAT;
697 
698   std::vector<float> data1 = {1.1, 2.2, 3.3, 4.4, 6.6, 7.7, 8.8, 9.9};
699   std::vector<DTYPE> data2 = {1, 2, 3, 4, 6, 7, 8, 9};
700   auto data = data1;
701   ge::Shape shape({2, 2, 2});
702   ge::Format format = ge::Format::FORMAT_NCHW;
703   ge::TensorDesc desc(shape, format, dt);
704   GeTensorPtr ge_tensor_ptr =
705     std::make_shared<GeTensor>(desc, reinterpret_cast<uint8_t*>(data.data()), data.size() * sizeof(DTYPE));
706   GeTensor& ge_tensor = *ge_tensor_ptr;
707   const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensor.GetData());
708 
709   // make sure GetData()'s return is a reference
710   assert(ge_data == reinterpret_cast<DTYPE*>(ge_tensor.GetData()));
711 
712   cout << "ge data size is: " << std::dec << ge_tensor.GetSize() << " bytes" << endl;
713   for (int i = 0; i < ge_tensor.GetSize() / sizeof(DTYPE); i++) {
714     cout << "ge data is: " << static_cast<DTYPE>(*(ge_data + i)) << endl;
715   }
716 
717   MeTensorPtr me_tensor_ptr = TransformUtil::ConvertGeTensor(ge_tensor_ptr);
718   MeTensor& me_tensor = *me_tensor_ptr;
719   cout << "after convert ge tensor to me tensor" << endl;
720   DTYPE* me_data = reinterpret_cast<DTYPE*>(me_tensor.data_c());
721   PrintMeTensor(&me_tensor);
722 
723   assert(ge_tensor.GetSize() == me_tensor.data().nbytes());
724   assert(memcmp(ge_data, me_data, ge_tensor.GetSize()) == 0);
725 }
726 
727 TEST_F(TestConvert, TestConvertMakeTuple) {
728   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
729   std::vector<AnfNodePtr> inputs;
730   inputs.push_back(NewValueNode(std::make_shared<Primitive>("MakeTuple")));
731   for (int i = 0; i < 3; i++) {
732     auto input = func_graph->add_parameter();
733     input->set_name("x" + std::to_string(i));
734     inputs.push_back(input);
735   }
736   CNodePtr cnode_prim = func_graph->NewCNode(inputs);
737   inputs.clear();
738   inputs.push_back(NewValueNode(std::make_shared<Primitive>("Return")));
739   inputs.push_back(cnode_prim);
740   CNodePtr cnode_return = func_graph->NewCNode(inputs);
741   func_graph->set_return(cnode_return);
742 
743   // save the func_graph to manager
744   std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
745 
746   // call resolve
747   bool ret_ = ResolveAll(manager);
748   ASSERT_TRUE(ret_);
749 
750   // draw graph
751   auto anfGraph = *(manager->func_graphs().begin());
752   DfGraphConvertor converter(anfGraph);
753   converter.ConvertAllNode().BuildGraph().GetComputeGraph();
754   ASSERT_EQ(converter.ErrCode(), 0);
755 }
756 
757 TEST_F(TestConvert, TestConvertInputTensors) {
758 #define DTYPE float
759   std::initializer_list<int64_t> list0 = {1, 1, 4, 4};
760   std::initializer_list<int64_t> list1 = {2, 3, 4, 5};
761   std::initializer_list<int64_t> list2 = {9, 9, 1, 1};
762   MeTensorPtr input_ptr1 = MakeTensor(kF32, list0);
763   MeTensorPtr input_ptr2 = MakeTensor(kF32, list1);
764   MeTensorPtr input_ptr3 = MakeTensor(kF32, list2);
765   std::vector<MeTensorPtr> me_inputs;
766   me_inputs.emplace_back(input_ptr1);
767   me_inputs.emplace_back(input_ptr2);
768   me_inputs.emplace_back(input_ptr3);
769 
770   std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(me_inputs, kOpFormat_NCHW);
771 
772   for (int i = 0; i < ge_tensors.size(); i++) {
773     DTYPE* me_data = reinterpret_cast<DTYPE*>(me_inputs[i]->data_c());
774     const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
775     ASSERT_TRUE(ge_tensors[i]->GetSize() == me_inputs[i]->data().nbytes());
776     ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
777     ASSERT_TRUE(ge_tensors[i]->GetTensorDesc().GetShape().GetDims() ==
778                 TransformUtil::ConvertMeShape(me_inputs[i]->shape_c()).GetDims());
779   }
780 }
781 
782 TEST_F(TestConvert, TestConvertGeTensors) {
783 #define DTYPE float
784   ge::DataType dt = ge::DataType::DT_FLOAT;
785 
786   std::vector<float> data1(16);
787   std::vector<float> data2(120);
788   std::vector<float> data3(81);
789   ge::Shape shape1({1, 1, 4, 4});
790   ge::Shape shape2({2, 3, 4, 5});
791   ge::Shape shape3({9, 9, 1, 1});
792   ge::Format format = ge::Format::FORMAT_NCHW;
793   ge::TensorDesc desc1(shape1, format, dt);
794   ge::TensorDesc desc2(shape2, format, dt);
795   ge::TensorDesc desc3(shape3, format, dt);
796   GeTensorPtr ge_tensor_ptr1 =
797     std::make_shared<GeTensor>(desc1, reinterpret_cast<uint8_t*>(data1.data()), data1.size() * sizeof(DTYPE));
798   GeTensorPtr ge_tensor_ptr2 =
799     std::make_shared<GeTensor>(desc2, reinterpret_cast<uint8_t*>(data2.data()), data2.size() * sizeof(DTYPE));
800   GeTensorPtr ge_tensor_ptr3 =
801     std::make_shared<GeTensor>(desc3, reinterpret_cast<uint8_t*>(data3.data()), data3.size() * sizeof(DTYPE));
802 
803   std::vector<GeTensorPtr> ge_tensors;
804   ge_tensors.emplace_back(ge_tensor_ptr1);
805   ge_tensors.emplace_back(ge_tensor_ptr2);
806   ge_tensors.emplace_back(ge_tensor_ptr3);
807 
808   std::vector<std::vector<int64_t>> request_dims;
809   std::vector<int64_t> dims1 = {1, 1, 4, 4};
810   std::vector<int64_t> dims2 = {2, 3, 4, 5};
811   std::vector<int64_t> dims3 = {9, 9, 1, 1};
812   request_dims.emplace_back(dims1);
813   request_dims.emplace_back(dims2);
814   request_dims.emplace_back(dims3);
815 
816   std::vector<MeTensorPtr> me_outputs = TransformUtil::ConvertGeTensors(ge_tensors, request_dims);
817 
818   for (int i = 0; i < ge_tensors.size(); i++) {
819     DTYPE* me_data = reinterpret_cast<DTYPE*>(me_outputs[i]->data_c());
820     const DTYPE* ge_data = reinterpret_cast<DTYPE*>(ge_tensors[i]->GetData());
821     ASSERT_TRUE(ge_tensors[i]->GetSize() == me_outputs[i]->data().nbytes());
822     ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
823     ASSERT_TRUE(request_dims[i] == me_outputs[i]->shape_c());
824   }
825 }
826 
827 TEST_F(TestConvert, TestConvertGeShape1) {
828   GeShape ge_shape({10, 1, 1, 1});
829   std::vector<int64_t> request_dims{10};
830   ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
831 }
832 
833 TEST_F(TestConvert, TestConvertGeShape2) {
834   GeShape ge_shape({10, 15, 1, 1});
835   std::vector<int64_t> request_dims{10, 15};
836   ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
837 }
838 
839 TEST_F(TestConvert, TestConvertGeShape3) {
840   GeShape ge_shape({10, 13, 18, 1});
841   std::vector<int64_t> request_dims{10, 13, 18};
842   ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
843 }
844 
845 TEST_F(TestConvert, TestConvertGeShape4) {
846   GeShape ge_shape({1, 10, 1, 1});
847   std::vector<int64_t> request_dims{10};
848   ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == request_dims);
849 }
850 
851 TEST_F(TestConvert, TestConvertGeShape5) {
852   GeShape ge_shape({10, 1, 1, 2});
853   std::vector<int64_t> request_dims{10};
854   ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
855 }
856 
857 TEST_F(TestConvert, TestConvertGeShape6) {
858   GeShape ge_shape({5, 2, 1, 1});
859   std::vector<int64_t> request_dims{10};
860   ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
861 }
862 
863 TEST_F(TestConvert, TestConvertGeShape7) {
864   GeShape ge_shape({10});
865   std::vector<int64_t> request_dims{10, 1};
866   ASSERT_TRUE(TransformUtil::ConvertGeShape(ge_shape, request_dims) == TransformUtil::ConvertGeShape(ge_shape));
867 }
868 }  // namespace transform
869 }  // namespace mindspore
870