1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <iostream>
17 #include <unordered_map>
18
19 #include "pybind11/pybind11.h"
20
21 #include "transform/transform_base_test.h"
22 #include "common/py_func_graph_fetcher.h"
23 #include "pipeline/jit/parse/parse.h"
24 #include "debug/draw.h"
25 #include "debug/anf_ir_dump.h"
26 #include "pipeline/jit/static_analysis/prim.h"
27 #include "frontend/operator/ops.h"
28 #include "common/common_test.h"
29
30 #define private public
31 #include "transform/graph_ir/types.h"
32 #include "transform/graph_ir/convert.h"
33 #include "securec/include/securec.h"
34 #include "utils/utils.h"
35 using std::cout;
36 using std::endl;
37 using std::string;
38 using std::unordered_map;
39
40 namespace mindspore {
41 namespace transform {
42 using AbstractScalar = abstract::AbstractScalar;
43 using mindspore::parse::ResolveAll;
44
45 class TestConvert : public UT::Common {
46 public:
TestConvert()47 TestConvert() {}
48 virtual void SetUp();
49 virtual void TearDown();
50 static const std::shared_ptr<Float> kF32;
51 };
52
SetUp()53 void TestConvert::SetUp() { UT::InitPythonPath(); }
TearDown()54 void TestConvert::TearDown() {}
55
56 const std::shared_ptr<Float> TestConvert::kF32 = std::make_shared<Float>(32);
57
createAnfGraph()58 AnfGraphPtr createAnfGraph() { return std::make_shared<AnfGraph>(); }
59
// Converting an empty AnfGraph is expected to fail: the convertor must report an
// error code other than SUCCESS.
TEST_F(TestConvert, TestConstruct) {
  auto empty_graph = std::make_shared<AnfGraph>();
  DfGraphConvertor convertor(empty_graph);
  (void)convertor.ConvertAllNode().GetComputeGraph();
  ASSERT_NE(SUCCESS, convertor.ErrCode());
}
66
67 #if (!defined ENABLE_GE)
68
69 namespace {
70
MakeDfGraph(PrimitivePtr prim,unsigned int nparam)71 bool MakeDfGraph(PrimitivePtr prim, unsigned int nparam) {
72 std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, nparam);
73 std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
74
75 DfGraphConvertor converter(anf_graph);
76 auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
77 if (converter.ErrCode() != 0) {
78 MS_LOG(ERROR) << "DfGraphConvertor convert " << prim->name() << " error, error code is: " << converter.ErrCode();
79 return false;
80 }
81 if (df_graph == nullptr) {
82 MS_LOG(ERROR) << "DfGraphConvertor get " << prim->name() << " compute func_graph failed";
83 return false;
84 }
85 return true;
86 }
87
88 } // namespace
89
// Converts a two-input Conv2D graph that carries stride/pad/dilation attributes.
TEST_F(TestConvert, TestConvertConv2d) {
  // Work on a fresh primitive with the same name instead of calling AddAttr on the
  // shared prim::kPrimConv2D: attributes added to that global would persist into
  // every later test in this process.
  PrimitivePtr conv2d = std::make_shared<Primitive>(prim::kPrimConv2D->name());
  conv2d->AddAttr("stride", MakeValue(static_cast<int64_t>(2)));
  conv2d->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  conv2d->AddAttr("dilation", MakeValue(static_cast<int64_t>(0)));

  FuncGraphPtr anf_graph = MakeFuncGraph(conv2d, 2);
  std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  DfGraphConvertor converter(anf_graph);
  auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
103
// Converts a MaxPool graph built with five parameters
// (per the original note: ary, ksize, stride, padding, data_format).
TEST_F(TestConvert, TestConvertMaxpooling) {
  auto max_pool = std::make_shared<Primitive>("MaxPool");
  auto graph = MakeFuncGraph(max_pool, 5);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
114
// Builds a one-input Relu graph, resolves it through a manager and checks that the
// conversion reports no error.
TEST_F(TestConvert, TestReluOps) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimRelu, so the "T" attribute cannot leak into other tests.
  auto prim = std::make_shared<Primitive>(prim::kPrimRelu->name());
  prim->AddAttr("T", MakeValue(static_cast<int64_t>(0)));

  auto func_graph = MakeFuncGraph(prim, 1);
  ASSERT_TRUE(nullptr != func_graph);

  // Attach the graph to a manager so it can be resolved.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  bool ret_ = ResolveAll(manager);
  ASSERT_TRUE(ret_);

  // Convert the (first) managed graph.
  auto anfGraph = *(manager->func_graphs().begin());
  DfGraphConvertor converter(anfGraph);
  converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
}
135
// Builds Return(Relu(TupleGetItem(BatchNorm(p0..p4), 2))) by hand and checks that the
// whole graph converts to a non-null compute graph without error.
TEST_F(TestConvert, TestConvertBatchNorm) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimBatchNorm, so the attributes set here cannot leak into other tests.
  PrimitivePtr batch_norm = std::make_shared<Primitive>(prim::kPrimBatchNorm->name());
  batch_norm->AddAttr("epsilon", MakeValue(0.001f));
  batch_norm->AddAttr("momentum", MakeValue(0.1f));

  FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();

  // BatchNorm node fed by five fresh graph parameters.
  constexpr unsigned int kBatchNormInputNum = 5;
  std::vector<AnfNodePtr> bn_inputs{NewValueNode(batch_norm)};
  for (unsigned int i = 0; i < kBatchNormInputNum; ++i) {
    bn_inputs.push_back(anf_graph->add_parameter());
  }
  CNodePtr cnode_prim = anf_graph->NewCNode(bn_inputs);

  // Pick element 2 of the BatchNorm output tuple.
  std::vector<AnfNodePtr> getitem_inputs{NewValueNode(prim::kPrimTupleGetItem), cnode_prim,
                                         NewValueNode(static_cast<int64_t>(2))};
  CNodePtr cnode_getitem = anf_graph->NewCNode(getitem_inputs);

  std::vector<AnfNodePtr> relu_inputs{NewValueNode(prim::kPrimRelu), cnode_getitem};
  CNodePtr cnode_relu = anf_graph->NewCNode(relu_inputs);

  std::vector<AnfNodePtr> return_inputs{NewValueNode(std::make_shared<Primitive>("Return")), cnode_relu};
  CNodePtr cnode_return = anf_graph->NewCNode(return_inputs);
  anf_graph->set_return(cnode_return);

  std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  DfGraphConvertor converter(anf_graph);
  auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
172
// Builds a three-input Conv2DBackpropInput graph with convolution attributes,
// resolves it through a manager and converts it.
TEST_F(TestConvert, TestConvertConvBackpropInput) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimConv2DBackpropInput, so these attributes cannot leak into other tests.
  auto prim = std::make_shared<Primitive>(prim::kPrimConv2DBackpropInput->name());
  const std::vector<int64_t> stride{1, 1};
  prim->AddAttr("stride", MakeValue(stride));
  prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  // Each attribute is set exactly once ("dilation" used to be written twice with the same value).
  prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
  prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));

  auto func_graph = MakeFuncGraph(prim, 3);
  ASSERT_NE(func_graph, nullptr);

  // Attach the graph to a manager so it can be resolved.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  bool ret_ = ResolveAll(manager);
  ASSERT_TRUE(ret_);

  // Convert the (first) managed graph.
  auto anf_graph = *(manager->func_graphs().begin());
  DfGraphConvertor converter(anf_graph);
  auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
200
// Builds a three-input Conv2DBackpropFilter graph with convolution attributes,
// resolves it through a manager and converts it.
TEST_F(TestConvert, TestConvertConvBackpropFilter) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimConv2DBackpropFilter, so these attributes cannot leak into other tests.
  auto prim = std::make_shared<Primitive>(prim::kPrimConv2DBackpropFilter->name());
  const std::vector<int64_t> stride{1, 1};
  prim->AddAttr("stride", MakeValue(stride));
  prim->AddAttr("pad", MakeValue(static_cast<int64_t>(0)));
  prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  // Each attribute is set exactly once ("dilation" used to be written twice with the same value).
  prim->AddAttr("dilation", MakeValue(static_cast<int64_t>(1)));
  prim->AddAttr("group", MakeValue(static_cast<int64_t>(1)));
  prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));

  auto func_graph = MakeFuncGraph(prim, 3);
  ASSERT_NE(func_graph, nullptr);

  // Attach the graph to a manager so it can be resolved.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  bool ret_ = ResolveAll(manager);
  ASSERT_TRUE(ret_);

  // Convert the (first) managed graph.
  auto anf_graph = *(manager->func_graphs().begin());
  DfGraphConvertor converter(anf_graph);
  auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
228
// Builds a two-input ReluGrad graph, resolves it through a manager and converts it.
TEST_F(TestConvert, TestConvertReluGrad) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimReluGrad, so these attributes cannot leak into other tests.
  auto prim = std::make_shared<Primitive>(prim::kPrimReluGrad->name());
  prim->AddAttr("alpha", MakeValue(0.1f));
  prim->AddAttr("beta", MakeValue(0.1f));
  prim->AddAttr("mode", MakeValue(static_cast<int64_t>(1)));

  auto func_graph = MakeFuncGraph(prim, 2);
  ASSERT_NE(func_graph, nullptr);

  // Attach the graph to a manager so it can be resolved.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  bool ret_ = ResolveAll(manager);
  ASSERT_TRUE(ret_);

  // Convert the (first) managed graph.
  auto anf_graph = *(manager->func_graphs().begin());
  DfGraphConvertor converter(anf_graph);
  auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
251
// Builds a two-input BiasAdd graph, resolves it through a manager and checks that it
// converts to a non-null compute graph without error.
TEST_F(TestConvert, TestConvertBiasAdd) {
  auto bias_add = std::make_shared<Primitive>("BiasAdd");
  bias_add->AddAttr("alpha", MakeValue(0.0f));
  bias_add->AddAttr("beta", MakeValue(1.0f));

  auto graph = MakeFuncGraph(bias_add, 2);
  ASSERT_TRUE(graph != nullptr);

  std::shared_ptr<FuncGraphManager> mgr = Manage(graph);
  ASSERT_TRUE(ResolveAll(mgr));

  auto resolved = *mgr->func_graphs().begin();
  DfGraphConvertor convertor(resolved);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
273
// Builds a two-input BiasAddGrad graph, resolves it through a manager and converts it.
TEST_F(TestConvert, TestConvertBiasAddGrad) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimBiasAddGrad, so these attributes cannot leak into other tests.
  auto prim = std::make_shared<Primitive>(prim::kPrimBiasAddGrad->name());
  prim->AddAttr("alpha", MakeValue(0.0f));
  prim->AddAttr("beta", MakeValue(1.0f));

  auto func_graph = MakeFuncGraph(prim, 2);
  ASSERT_NE(func_graph, nullptr);

  // Attach the graph to a manager so it can be resolved.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  bool ret_ = ResolveAll(manager);
  ASSERT_TRUE(ret_);

  // Convert the (first) managed graph.
  auto anf_graph = *(manager->func_graphs().begin());
  DfGraphConvertor converter(anf_graph);
  auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
295
// Builds a two-input MaxPoolGradWithArgmax graph with pooling attributes, resolves it
// through a manager and converts it.
TEST_F(TestConvert, TestConvertMaxPoolGradWithArgmax) {
  auto prim = std::make_shared<Primitive>("MaxPoolGradWithArgmax");
  // Each attribute is set exactly once, using the values that were in effect after
  // the earlier duplicate writes (alpha = 0.1, beta = 1.0).
  prim->AddAttr("alpha", MakeValue(0.1f));
  prim->AddAttr("beta", MakeValue(1.0f));
  prim->AddAttr("window", MakeValue(static_cast<int64_t>(2)));
  prim->AddAttr("stride", MakeValue(static_cast<int64_t>(1)));
  prim->AddAttr("ceil_mode", MakeValue(static_cast<int64_t>(0)));
  prim->AddAttr("data_mode", MakeValue(static_cast<int64_t>(0)));

  auto func_graph = MakeFuncGraph(prim, 2);
  ASSERT_NE(func_graph, nullptr);

  // Attach the graph to a manager so it can be resolved.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  bool ret_ = ResolveAll(manager);
  ASSERT_TRUE(ret_);

  // Convert the (first) managed graph.
  auto anf_graph = *(manager->func_graphs().begin());
  DfGraphConvertor converter(anf_graph);
  auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
323
// Converts a Concat graph built with two parameters.
TEST_F(TestConvert, TestConcat) {
  auto concat = prim::kPrimConcat;
  auto graph = MakeFuncGraph(concat, 2);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
334
// Converts a Gather graph built with three parameters.
TEST_F(TestConvert, TestGatherV2) {
  auto gather = prim::kPrimGather;
  auto graph = MakeFuncGraph(gather, 3);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
345
// Converts a Cast graph built with two parameters.
TEST_F(TestConvert, TestCast) {
  auto cast = prim::kPrimCast;
  auto graph = MakeFuncGraph(cast, 2);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
356
// Converts a one-parameter Exp graph.
TEST_F(TestConvert, TestExp) {
  auto exp_prim = std::make_shared<Primitive>("Exp");
  auto graph = MakeFuncGraph(exp_prim, 1);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
367
// Converts a one-parameter Floor graph.
TEST_F(TestConvert, TestFloor) {
  auto floor_prim = std::make_shared<Primitive>("Floor");
  auto graph = MakeFuncGraph(floor_prim, 1);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
378
// Converts a two-parameter GreaterEqual graph.
TEST_F(TestConvert, TestGreaterEqual) {
  auto greater_equal = std::make_shared<Primitive>("GreaterEqual");
  auto graph = MakeFuncGraph(greater_equal, 2);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
390
// Converts a two-parameter Less graph whose primitive carries a "T" = float32 attribute.
TEST_F(TestConvert, TestLess) {
  auto less = std::make_shared<Primitive>("Less");
  less->AddAttr("T", MakeValue(kFloat32));

  auto graph = MakeFuncGraph(less, 2);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
403
// Converts a two-parameter LessEqual graph.
TEST_F(TestConvert, TestLessEqual) {
  auto less_equal = std::make_shared<Primitive>("LessEqual");
  auto graph = MakeFuncGraph(less_equal, 2);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
415
// Converts a one-parameter LogicalNot graph.
TEST_F(TestConvert, TestLogicalNot) {
  auto logical_not = std::make_shared<Primitive>("LogicalNot");
  auto graph = MakeFuncGraph(logical_not, 1);
  auto manager = MakeManager({graph});

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
  ASSERT_TRUE(compute_graph != nullptr);
}
427
// Converts a two-parameter AssignAdd graph with use_locking = true.
TEST_F(TestConvert, TestAssignAdd) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimAssignAdd, so "use_locking" cannot leak into other tests.
  auto prim = std::make_shared<Primitive>(prim::kPrimAssignAdd->name());
  prim->AddAttr("use_locking", MakeValue(true));

  std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 2);
  std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});

  DfGraphConvertor converter(anf_graph);
  auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
440
// Converts a one-parameter LogSoftmax graph with axis = 0.
TEST_F(TestConvert, LogSoftmax) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimLogSoftmax, so "axis" cannot leak into other tests.
  auto prim = std::make_shared<Primitive>(prim::kPrimLogSoftmax->name());
  prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));

  std::shared_ptr<FuncGraph> anf_graph = MakeFuncGraph(prim, 1);
  std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});

  DfGraphConvertor converter(anf_graph);
  auto df_graph = converter.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(converter.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
453
// Maximum with two parameters must convert cleanly.
TEST_F(TestConvert, TestMaximumOps) {
  ASSERT_TRUE(MakeDfGraph(prim::kPrimMaximum, 2));
}
459
// ReduceMean with keepdims = true and two parameters must convert cleanly.
TEST_F(TestConvert, TestReduceMeanOps) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimReduceMean, so "keepdims" cannot leak into other tests.
  auto prim = std::make_shared<Primitive>(prim::kPrimReduceMean->name());
  prim->AddAttr("keepdims", MakeValue(true));
  bool ret = MakeDfGraph(prim, 2);
  ASSERT_TRUE(ret);
}
466
// Minimum with two parameters must convert cleanly.
TEST_F(TestConvert, TestMinimumOps) {
  ASSERT_TRUE(MakeDfGraph(prim::kPrimMinimum, 2));
}
472
// Placeholder: the fused Min/Max grad conversion is not exercised yet, so this test
// always passes.
// TODO: add an infer step to this test case and convert the op for real.
TEST_F(TestConvert, TestFusedMinOrMaxGradOps) {
  constexpr bool kPlaceholderPassed = true;
  ASSERT_TRUE(kPlaceholderPassed);
}
477
// Squeeze with two parameters must convert cleanly.
TEST_F(TestConvert, TestSqueezeOps) {
  ASSERT_TRUE(MakeDfGraph(prim::kPrimSqueeze, 2));
}
483
// Mul with two parameters must convert cleanly.
TEST_F(TestConvert, TestMulOps) {
  ASSERT_TRUE(MakeDfGraph(prim::kPrimMul, 2));
}
489
// Neg with one parameter must convert cleanly.
TEST_F(TestConvert, TestNegOps) {
  ASSERT_TRUE(MakeDfGraph(prim::kPrimNeg, 1));
}
495
// OneHot with axis = 0 and four parameters must convert cleanly.
TEST_F(TestConvert, TestOneHotOps) {
  // Use a fresh primitive with the same name rather than mutating the shared
  // prim::kPrimOneHot, so "axis" cannot leak into other tests.
  auto prim = std::make_shared<Primitive>(prim::kPrimOneHot->name());
  prim->AddAttr("axis", MakeValue(static_cast<int64_t>(0)));
  bool ret = MakeDfGraph(prim, 4);
  ASSERT_TRUE(ret);
}
502
// Pow with two parameters must convert cleanly.
TEST_F(TestConvert, TestPowOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("Pow"), 2));
}
508
// Reciprocal with one parameter must convert cleanly.
TEST_F(TestConvert, TestReciprocalOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("Reciprocal"), 1));
}
514
// Select with three parameters must convert cleanly.
TEST_F(TestConvert, TestSelectOps) {
  ASSERT_TRUE(MakeDfGraph(prim::kPrimSelect, 3));
}
520
// Sqrt with one parameter must convert cleanly.
TEST_F(TestConvert, TestSqrtOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("Sqrt"), 1));
}
526
// Square with one parameter must convert cleanly.
TEST_F(TestConvert, TestSquareOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("Square"), 1));
}
532
533 #ifndef ENABLE_SECURITY
// ScalarSummary built with two parameters must convert cleanly.
// NOTE(review): an earlier note claimed this op "should have only 1 input", yet two
// are passed here -- confirm the intended arity.
TEST_F(TestConvert, TestScalarSummaryOps) {
  ASSERT_TRUE(MakeDfGraph(prim::kPrimScalarSummary, 2));
}
540
// TensorSummary built with two parameters must convert cleanly.
TEST_F(TestConvert, TestTensorSummaryOps) {
  ASSERT_TRUE(MakeDfGraph(prim::kPrimTensorSummary, 2));
}
546
// HistogramSummary built with two parameters must convert cleanly.
TEST_F(TestConvert, TestHistogramSummaryOps) {
  ASSERT_TRUE(MakeDfGraph(prim::kPrimHistogramSummary, 2));
}
552 #endif
553
// Greater with two parameters must convert cleanly.
TEST_F(TestConvert, TestGreaterOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("Greater"), 2));
}
559
// Equal with two parameters must convert cleanly.
TEST_F(TestConvert, TestEqualOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("Equal"), 2));
}
565
// Argmax with two parameters must convert cleanly.
TEST_F(TestConvert, TestArgMaxiOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("Argmax"), 2));
}
571
// ResizeNearestNeighbor with one parameter must convert cleanly.
TEST_F(TestConvert, TestResizeNearestNeighborOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("ResizeNearestNeighbor"), 1));
}
577
// ApplyMomentum with five parameters must convert cleanly.
TEST_F(TestConvert, TestApplyMomentumOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("ApplyMomentum"), 5));
}
583
// NPUGetFloatStatus with one parameter must convert cleanly.
TEST_F(TestConvert, TestNPUGetFloatStatusOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("NPUGetFloatStatus"), 1));
}
589
// NPUAllocFloatStatus takes no parameters and must convert cleanly.
TEST_F(TestConvert, TestNPUAllocFloatStatusOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("NPUAllocFloatStatus"), 0));
}
595
// NPUClearFloatStatus with one parameter must convert cleanly.
TEST_F(TestConvert, TestNPUClearFloatStatusOps) {
  ASSERT_TRUE(MakeDfGraph(std::make_shared<Primitive>("NPUClearFloatStatus"), 1));
}
601
602 #endif
603
// Builds a two-input Add graph, resolves it through a manager and checks that the
// conversion reports no error. The compute graph itself is not inspected.
TEST_F(TestConvert, TestAddOps) {
  auto add = std::make_shared<Primitive>("Add");
  auto graph = MakeFuncGraph(add, 2);
  ASSERT_TRUE(graph != nullptr);

  std::shared_ptr<FuncGraphManager> mgr = Manage(graph);
  const bool resolved_ok = ResolveAll(mgr);
  ASSERT_TRUE(resolved_ok);

  auto resolved = *mgr->func_graphs().begin();
  DfGraphConvertor convertor(resolved);
  (void)convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
}
622
// Converts a 2x2x3 float32 ME tensor to a GE tensor in NCHW format. Checks the
// resulting format, data type, leading dims and element payload.
TEST_F(TestConvert, TestConvertTensor) {
  constexpr size_t kElemNum = 12;
  float data[kElemNum] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  // Create a tensor with the wanted data type and shape.
  std::vector<int64_t> dims{2, 2, 3};
  auto type_id = kNumberTypeFloat32;
  MeTensor me_tensor(type_id, dims);
  // Fill the ME tensor's writable buffer. Fail loudly if memcpy_s rejects the copy
  // instead of silently comparing against uninitialized memory later.
  uint8_t *me_data_ptr = reinterpret_cast<uint8_t *>(me_tensor.data_c());
  ASSERT_EQ(memcpy_s(me_data_ptr, me_tensor.data().nbytes(), data, sizeof(data)), EOK);

  auto me_tensor_ptr = std::make_shared<MeTensor>(me_tensor);
  auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW);
  ASSERT_NE(ge_tensor_ptr, nullptr);
  auto ge_desc = ge_tensor_ptr->GetTensorDesc();
  ASSERT_EQ(ge_desc.GetFormat(), GeFormat::FORMAT_NCHW);
  ASSERT_EQ(ge_desc.GetDataType(), GeDataType::DT_FLOAT);

  // Only the leading dims are compared. The GE shape may carry extra trailing dims
  // (TODO confirm), but it must be at least as long as the source, so indexing it
  // below cannot go out of bounds.
  const std::vector<int64_t> ge_shape_dims = ge_desc.GetShape().GetDims();
  ASSERT_GE(ge_shape_dims.size(), dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    ASSERT_EQ(dims[i], ge_shape_dims[i]);
  }

  // The GE payload must hold exactly the elements copied in above. This also keeps
  // the data[] reads below in bounds.
  ASSERT_EQ(ge_desc.GetShape().GetShapeSize(), static_cast<int64_t>(kElemNum));
  const float *ge_data = reinterpret_cast<const float *>(ge_tensor_ptr->GetData());
  for (size_t i = 0; i < kElemNum; ++i) {
    ASSERT_EQ(data[i], ge_data[i]);
  }
}
647
// A float32 tensor whose shape has zero dims is valid input and must still convert to
// a non-null GE tensor.
TEST_F(TestConvert, TestConvertTensor0Dims) {
  std::vector<int64_t> empty_dims{};
  auto zero_dim_tensor = std::make_shared<MeTensor>(kNumberTypeFloat32, empty_dims);
  auto converted = TransformUtil::ConvertTensor(zero_dim_tensor, kOpFormat_NCHW);
  ASSERT_NE(converted, nullptr);
}
655
// Despite its name, this test expects success: converting with an unrecognized format
// string ("xyz") still yields a non-null GE tensor.
TEST_F(TestConvert, TestConvertTensorError) {
  std::vector<int64_t> shape{2, 3, 4};
  auto tensor = std::make_shared<MeTensor>(kNumberTypeFloat32, shape);
  auto converted = TransformUtil::ConvertTensor(tensor, "xyz");
  ASSERT_NE(converted, nullptr);
}
662
// Table-driven check of the ME -> GE data type mapping.
TEST_F(TestConvert, TestUtilsConvertDataType) {
  struct Case {
    MeDataType me_type;
    GeDataType ge_type;
  };
  const Case cases[] = {
      {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16},
      {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
      {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE},
      {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
      {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16},
      {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
      {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64},
      {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
      {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL},
  };
  for (const auto &c : cases) {
    ASSERT_EQ(TransformUtil::ConvertDataType(c.me_type), c.ge_type);
  }
}
674
// Table-driven check of the format-string -> GE format mapping. An unknown string
// ("xyz") maps to FORMAT_ND.
TEST_F(TestConvert, TestUtilsConvertFormat) {
  struct Case {
    std::string format;
    GeFormat expected;
  };
  const Case cases[] = {
      {kOpFormat_NCHW, GeFormat::FORMAT_NCHW},
      {kOpFormat_NC1HWC0, GeFormat::FORMAT_NC1HWC0},
      {kOpFormat_NHWC, GeFormat::FORMAT_NHWC},
      {"xyz", GeFormat::FORMAT_ND},
  };
  for (const auto &c : cases) {
    ASSERT_EQ(TransformUtil::ConvertFormat(c.format), c.expected);
  }
}
681
// Table-driven check of the per-element byte size reported for each ME data type.
TEST_F(TestConvert, TestUtilsDataSize) {
  struct Case {
    MeDataType type;
    int bytes;
  };
  const Case cases[] = {
      {MeDataType::kNumberTypeFloat32, 4}, {MeDataType::kNumberTypeFloat16, 2},
      {MeDataType::kNumberTypeFloat64, 8}, {MeDataType::kNumberTypeInt8, 1},
      {MeDataType::kNumberTypeInt16, 2},   {MeDataType::kNumberTypeInt32, 4},
      {MeDataType::kNumberTypeInt64, 8},   {MeDataType::kNumberTypeUInt32, 4},
      {MeDataType::kNumberTypeBool, 1},
  };
  for (const auto &c : cases) {
    ASSERT_EQ(TransformUtil::GetDataTypeSize(c.type), c.bytes);
  }
}
693
// Builds a 2x2x2 float GE tensor, converts it to an ME tensor, and checks that the
// byte size and raw payload survive the round trip.
TEST_F(TestConvert, TestConvertGeTensor) {
  // Local alias instead of a #define, so nothing leaks into the rest of the file.
  using DType = float;
  ge::DataType dt = ge::DataType::DT_FLOAT;

  std::vector<DType> data = {1.1, 2.2, 3.3, 4.4, 6.6, 7.7, 8.8, 9.9};
  ge::Shape shape({2, 2, 2});
  ge::Format format = ge::Format::FORMAT_NCHW;
  ge::TensorDesc desc(shape, format, dt);
  GeTensorPtr ge_tensor_ptr =
    std::make_shared<GeTensor>(desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(DType));
  GeTensor &ge_tensor = *ge_tensor_ptr;
  const DType *ge_data = reinterpret_cast<const DType *>(ge_tensor.GetData());

  // GetData() must return the tensor's own buffer, not a fresh copy. Use a gtest
  // assertion: assert() is compiled out under NDEBUG and would check nothing.
  ASSERT_EQ(ge_data, reinterpret_cast<const DType *>(ge_tensor.GetData()));

  cout << "ge data size is: " << std::dec << ge_tensor.GetSize() << " bytes" << endl;
  for (size_t i = 0; i < ge_tensor.GetSize() / sizeof(DType); ++i) {
    cout << "ge data is: " << static_cast<DType>(*(ge_data + i)) << endl;
  }

  MeTensorPtr me_tensor_ptr = TransformUtil::ConvertGeTensor(ge_tensor_ptr);
  ASSERT_NE(me_tensor_ptr, nullptr);
  MeTensor &me_tensor = *me_tensor_ptr;
  cout << "after convert ge tensor to me tensor" << endl;
  DType *me_data = reinterpret_cast<DType *>(me_tensor.data_c());
  PrintMeTensor(&me_tensor);

  ASSERT_EQ(ge_tensor.GetSize(), static_cast<size_t>(me_tensor.data().nbytes()));
  ASSERT_EQ(memcmp(ge_data, me_data, ge_tensor.GetSize()), 0);
}
726
// Builds Return(MakeTuple(x0, x1, x2)) by hand, resolves it through a manager and
// checks that the conversion reports no error.
TEST_F(TestConvert, TestConvertMakeTuple) {
  auto graph = std::make_shared<FuncGraph>();

  std::vector<AnfNodePtr> tuple_inputs{NewValueNode(std::make_shared<Primitive>("MakeTuple"))};
  for (int idx = 0; idx < 3; ++idx) {
    auto param = graph->add_parameter();
    param->set_name("x" + std::to_string(idx));
    tuple_inputs.push_back(param);
  }
  CNodePtr make_tuple = graph->NewCNode(tuple_inputs);

  std::vector<AnfNodePtr> return_inputs{NewValueNode(std::make_shared<Primitive>("Return")), make_tuple};
  graph->set_return(graph->NewCNode(return_inputs));

  std::shared_ptr<FuncGraphManager> mgr = Manage(graph);
  ASSERT_TRUE(ResolveAll(mgr));

  auto resolved = *mgr->func_graphs().begin();
  DfGraphConvertor convertor(resolved);
  (void)convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(0, convertor.ErrCode());
}
756
// Converts three float32 ME tensors to GE tensors in NCHW format. Checks that every
// input produces a GE tensor with the same byte size, payload and shape.
TEST_F(TestConvert, TestConvertInputTensors) {
  // Local alias instead of a #define, so nothing leaks into the rest of the file.
  using DType = float;
  std::initializer_list<int64_t> list0 = {1, 1, 4, 4};
  std::initializer_list<int64_t> list1 = {2, 3, 4, 5};
  std::initializer_list<int64_t> list2 = {9, 9, 1, 1};
  std::vector<MeTensorPtr> me_inputs;
  me_inputs.emplace_back(MakeTensor(kF32, list0));
  me_inputs.emplace_back(MakeTensor(kF32, list1));
  me_inputs.emplace_back(MakeTensor(kF32, list2));

  std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(me_inputs, kOpFormat_NCHW);

  // One GE tensor per input is required. Otherwise the loop below would pass
  // vacuously (too few) or index me_inputs out of range (too many).
  ASSERT_EQ(ge_tensors.size(), me_inputs.size());
  for (size_t i = 0; i < ge_tensors.size(); ++i) {
    ASSERT_NE(ge_tensors[i], nullptr);
    DType *me_data = reinterpret_cast<DType *>(me_inputs[i]->data_c());
    const DType *ge_data = reinterpret_cast<const DType *>(ge_tensors[i]->GetData());
    ASSERT_EQ(ge_tensors[i]->GetSize(), static_cast<size_t>(me_inputs[i]->data().nbytes()));
    ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
    ASSERT_TRUE(ge_tensors[i]->GetTensorDesc().GetShape().GetDims() ==
                TransformUtil::ConvertMeShape(me_inputs[i]->shape_c()).GetDims());
  }
}
781
// Converts three float GE tensors back to ME tensors with explicitly requested shapes.
// Checks byte size, payload and the resulting ME shape for each.
TEST_F(TestConvert, TestConvertGeTensors) {
  // Local alias instead of a #define, so nothing leaks into the rest of the file.
  using DType = float;
  ge::DataType dt = ge::DataType::DT_FLOAT;

  std::vector<DType> data1(16);
  std::vector<DType> data2(120);
  std::vector<DType> data3(81);
  ge::Shape shape1({1, 1, 4, 4});
  ge::Shape shape2({2, 3, 4, 5});
  ge::Shape shape3({9, 9, 1, 1});
  ge::Format format = ge::Format::FORMAT_NCHW;
  ge::TensorDesc desc1(shape1, format, dt);
  ge::TensorDesc desc2(shape2, format, dt);
  ge::TensorDesc desc3(shape3, format, dt);

  std::vector<GeTensorPtr> ge_tensors;
  ge_tensors.emplace_back(
    std::make_shared<GeTensor>(desc1, reinterpret_cast<uint8_t *>(data1.data()), data1.size() * sizeof(DType)));
  ge_tensors.emplace_back(
    std::make_shared<GeTensor>(desc2, reinterpret_cast<uint8_t *>(data2.data()), data2.size() * sizeof(DType)));
  ge_tensors.emplace_back(
    std::make_shared<GeTensor>(desc3, reinterpret_cast<uint8_t *>(data3.data()), data3.size() * sizeof(DType)));

  std::vector<std::vector<int64_t>> request_dims;
  request_dims.emplace_back(std::vector<int64_t>{1, 1, 4, 4});
  request_dims.emplace_back(std::vector<int64_t>{2, 3, 4, 5});
  request_dims.emplace_back(std::vector<int64_t>{9, 9, 1, 1});

  std::vector<MeTensorPtr> me_outputs = TransformUtil::ConvertGeTensors(ge_tensors, request_dims);

  // One ME tensor per GE tensor is required. Otherwise me_outputs[i] below could be
  // out of range, or the loop could pass without checking anything.
  ASSERT_EQ(me_outputs.size(), ge_tensors.size());
  for (size_t i = 0; i < ge_tensors.size(); ++i) {
    ASSERT_NE(me_outputs[i], nullptr);
    DType *me_data = reinterpret_cast<DType *>(me_outputs[i]->data_c());
    const DType *ge_data = reinterpret_cast<const DType *>(ge_tensors[i]->GetData());
    ASSERT_EQ(ge_tensors[i]->GetSize(), static_cast<size_t>(me_outputs[i]->data().nbytes()));
    ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
    ASSERT_TRUE(request_dims[i] == me_outputs[i]->shape_c());
  }
}
826
// GE shape {10, 1, 1, 1} with request {10} must convert to exactly {10}.
TEST_F(TestConvert, TestConvertGeShape1) {
  GeShape padded({10, 1, 1, 1});
  std::vector<int64_t> wanted{10};
  const auto converted = TransformUtil::ConvertGeShape(padded, wanted);
  ASSERT_TRUE(converted == wanted);
}
832
// GE shape {10, 15, 1, 1} with request {10, 15} must convert to exactly {10, 15}.
TEST_F(TestConvert, TestConvertGeShape2) {
  GeShape padded({10, 15, 1, 1});
  std::vector<int64_t> wanted{10, 15};
  const auto converted = TransformUtil::ConvertGeShape(padded, wanted);
  ASSERT_TRUE(converted == wanted);
}
838
// GE shape {10, 13, 18, 1} with request {10, 13, 18} must convert to exactly that request.
TEST_F(TestConvert, TestConvertGeShape3) {
  GeShape padded({10, 13, 18, 1});
  std::vector<int64_t> wanted{10, 13, 18};
  const auto converted = TransformUtil::ConvertGeShape(padded, wanted);
  ASSERT_TRUE(converted == wanted);
}
844
// GE shape {1, 10, 1, 1} with request {10} must also convert to exactly {10}.
TEST_F(TestConvert, TestConvertGeShape4) {
  GeShape padded({1, 10, 1, 1});
  std::vector<int64_t> wanted{10};
  const auto converted = TransformUtil::ConvertGeShape(padded, wanted);
  ASSERT_TRUE(converted == wanted);
}
850
// GE shape {10, 1, 1, 2} with request {10}: the result must equal the conversion made
// without any requested dims.
TEST_F(TestConvert, TestConvertGeShape5) {
  GeShape ge_shape({10, 1, 1, 2});
  std::vector<int64_t> wanted{10};
  const auto with_request = TransformUtil::ConvertGeShape(ge_shape, wanted);
  const auto without_request = TransformUtil::ConvertGeShape(ge_shape);
  ASSERT_TRUE(with_request == without_request);
}
856
// GE shape {5, 2, 1, 1} with request {10}: the result must equal the conversion made
// without any requested dims.
TEST_F(TestConvert, TestConvertGeShape6) {
  GeShape ge_shape({5, 2, 1, 1});
  std::vector<int64_t> wanted{10};
  const auto with_request = TransformUtil::ConvertGeShape(ge_shape, wanted);
  const auto without_request = TransformUtil::ConvertGeShape(ge_shape);
  ASSERT_TRUE(with_request == without_request);
}
862
// GE shape {10} with request {10, 1}: the result must equal the conversion made
// without any requested dims.
TEST_F(TestConvert, TestConvertGeShape7) {
  GeShape ge_shape({10});
  std::vector<int64_t> wanted{10, 1};
  const auto with_request = TransformUtil::ConvertGeShape(ge_shape, wanted);
  const auto without_request = TransformUtil::ConvertGeShape(ge_shape);
  ASSERT_TRUE(with_request == without_request);
}
868 } // namespace transform
869 } // namespace mindspore
870