• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <iostream>
17 #include <memory>
18 #include <vector>
19 
20 #include "common/common_test.h"
21 #include "frontend/operator/cc_implementations.h"
22 
23 namespace mindspore {
24 namespace prim {
25 
26 class TestImplementations : public UT::Common {
27  public:
28   TestImplementations() {}
29   virtual void SetUp() {}
30 };
31 
32 TEST_F(TestImplementations, ScalarAddTest) {
33   ValuePtrList list;
34   list.push_back(MakeValue(static_cast<int64_t>(1)));
35   list.push_back(MakeValue(static_cast<int64_t>(2)));
36   ASSERT_EQ(ScalarAdd(list)->cast<Int64ImmPtr>()->value(), 3);
37   list.clear();
38 
39   list.push_back(MakeValue(1.0f));
40   list.push_back(MakeValue(1.5f));
41   ASSERT_EQ(ScalarAdd(list)->cast<FP32ImmPtr>()->value(), 2.5f);
42   list.clear();
43 
44   list.push_back(MakeValue(3.0));
45   list.push_back(MakeValue(0.5));
46   ASSERT_EQ(ScalarAdd(list)->cast<FP64ImmPtr>()->value(), 3.5);
47   list.clear();
48 
49   list.push_back(MakeValue(INT64_MAX));
50   list.push_back(MakeValue(static_cast<int64_t>(2)));
51   try {
52     ScalarAdd(list);
53     FAIL();
54   } catch (std::runtime_error const &err) {
55     ASSERT_TRUE(std::string(err.what()).find("Overflow of the sum of two signed number") != std::string::npos);
56   }
57   list.clear();
58 
59   list.push_back(MakeValue(INT64_MIN));
60   list.push_back(MakeValue(static_cast<int64_t>(-1)));
61   try {
62     ScalarAdd(list);
63     FAIL();
64   } catch (std::runtime_error const &err) {
65     ASSERT_TRUE(std::string(err.what()).find("Overflow of the sum of two signed number") != std::string::npos);
66   }
67   list.clear();
68 }
69 
70 TEST_F(TestImplementations, ScalarSubTest) {
71   ValuePtrList list;
72   list.push_back(MakeValue(static_cast<int64_t>(1)));
73   list.push_back(MakeValue(static_cast<int64_t>(3)));
74   ASSERT_EQ(ScalarSub(list)->cast<Int64ImmPtr>()->value(), -2);
75   list.clear();
76 
77   list.push_back(MakeValue(1.0f));
78   list.push_back(MakeValue(1.5f));
79   ASSERT_EQ(ScalarSub(list)->cast<FP32ImmPtr>()->value(), -0.5f);
80   list.clear();
81 
82   list.push_back(MakeValue(3.0));
83   list.push_back(MakeValue(0.5));
84   ASSERT_EQ(ScalarSub(list)->cast<FP64ImmPtr>()->value(), 2.5);
85   list.clear();
86 
87   list.push_back(MakeValue(INT64_MAX));
88   list.push_back(MakeValue(static_cast<int64_t>(-1)));
89   try {
90     ScalarSub(list);
91     FAIL();
92   } catch (std::runtime_error const &err) {
93     ASSERT_TRUE(std::string(err.what()).find("Overflow of the sub of two signed number") != std::string::npos);
94   }
95   list.clear();
96 
97   list.push_back(MakeValue(INT64_MIN));
98   list.push_back(MakeValue(static_cast<int64_t>(1)));
99   try {
100     ScalarSub(list);
101     FAIL();
102   } catch (std::runtime_error const &err) {
103     ASSERT_TRUE(std::string(err.what()).find("Overflow of the sub of two signed number") != std::string::npos);
104   }
105   list.clear();
106 }
107 
108 TEST_F(TestImplementations, ScalarMulTest) {
109   ValuePtrList list;
110   list.push_back(MakeValue(static_cast<int64_t>(2)));
111   list.push_back(MakeValue(static_cast<int64_t>(3)));
112   ASSERT_EQ(ScalarMul(list)->cast<Int64ImmPtr>()->value(), 6);
113   list.clear();
114 
115   list.push_back(MakeValue(2.0f));
116   list.push_back(MakeValue(1.5f));
117   ASSERT_EQ(ScalarMul(list)->cast<FP32ImmPtr>()->value(), 3.0f);
118   list.clear();
119 
120   list.push_back(MakeValue(-2.0));
121   list.push_back(MakeValue(-4.0));
122   ASSERT_EQ(ScalarMul(list)->cast<FP64ImmPtr>()->value(), 8.0);
123   list.clear();
124 
125   list.push_back(MakeValue(static_cast<int64_t>(10)));
126   list.push_back(MakeValue(INT64_MAX));
127   try {
128     ScalarMul(list);
129     FAIL();
130   } catch (std::runtime_error const &err) {
131     ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos);
132   }
133   list.clear();
134 
135   list.push_back(MakeValue(INT64_MIN));
136   list.push_back(MakeValue(static_cast<int64_t>(-1)));
137   try {
138     ScalarMul(list);
139     FAIL();
140   } catch (std::runtime_error const &err) {
141     ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos);
142   }
143   list.clear();
144 
145   list.push_back(MakeValue(static_cast<int64_t>(-2)));
146   list.push_back(MakeValue(INT64_MAX));
147   try {
148     ScalarMul(list);
149     FAIL();
150   } catch (std::runtime_error const &err) {
151     ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos);
152   }
153   list.clear();
154 
155   list.push_back(MakeValue(static_cast<int64_t>(2)));
156   list.push_back(MakeValue(INT64_MIN));
157   try {
158     ScalarMul(list);
159     FAIL();
160   } catch (std::runtime_error const &err) {
161     ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos);
162   }
163   list.clear();
164 
165   list.push_back(MakeValue(static_cast<int64_t>(0)));
166   list.push_back(MakeValue(INT64_MIN));
167   ASSERT_EQ(ScalarDiv(list)->cast<Int64ImmPtr>()->value(), 0);
168   list.clear();
169 }
170 
171 TEST_F(TestImplementations, ScalarDivTest) {
172   ValuePtrList list;
173   list.push_back(MakeValue(static_cast<int64_t>(6)));
174   list.push_back(MakeValue(static_cast<int64_t>(3)));
175   ASSERT_EQ(ScalarDiv(list)->cast<Int64ImmPtr>()->value(), 2);
176   list.clear();
177 
178   list.push_back(MakeValue(3.0f));
179   list.push_back(MakeValue(1.5f));
180   ASSERT_EQ(ScalarDiv(list)->cast<FP32ImmPtr>()->value(), 2.0f);
181   list.clear();
182 
183   list.push_back(MakeValue(-4.0));
184   list.push_back(MakeValue(2.0));
185   ASSERT_EQ(ScalarDiv(list)->cast<FP64ImmPtr>()->value(), -2.0);
186   list.clear();
187 
188   list.push_back(MakeValue(INT64_MAX));
189   list.push_back(MakeValue(static_cast<int64_t>(0)));
190   try {
191     ScalarDiv(list);
192     FAIL();
193   } catch (std::runtime_error const &err) {
194     ASSERT_TRUE(std::string(err.what()).find("Divisor could not be zero") != std::string::npos);
195   }
196   list.clear();
197 
198   list.push_back(MakeValue(INT64_MIN));
199   list.push_back(MakeValue(static_cast<int64_t>(-1)));
200   try {
201     ScalarDiv(list);
202     FAIL();
203   } catch (std::runtime_error const &err) {
204     ASSERT_TRUE(std::string(err.what()).find("Overflow of the div of two signed number") != std::string::npos);
205   }
206   list.clear();
207 
208   list.push_back(MakeValue(static_cast<int64_t>(-1)));
209   list.push_back(MakeValue(INT64_MIN));
210   ASSERT_EQ(ScalarDiv(list)->cast<Int64ImmPtr>()->value(), 0);
211   list.clear();
212 }
213 
214 TEST_F(TestImplementations, ScalarModTest) {
215   ValuePtrList list;
216   list.push_back(MakeValue(static_cast<int64_t>(7)));
217   list.push_back(MakeValue(static_cast<int64_t>(3)));
218   ASSERT_EQ(ScalarMod(list)->cast<Int64ImmPtr>()->value(), 1);
219   list.clear();
220 
221   list.push_back(MakeValue(static_cast<int64_t>(-8)));
222   list.push_back(MakeValue(static_cast<int64_t>(3)));
223   ASSERT_EQ(ScalarMod(list)->cast<Int64ImmPtr>()->value(), -2);
224   list.clear();
225 
226   list.push_back(MakeValue(static_cast<int64_t>(-9)));
227   list.push_back(MakeValue(static_cast<int64_t>(2)));
228   ASSERT_EQ(ScalarMod(list)->cast<Int64ImmPtr>()->value(), -1);
229   list.clear();
230 
231   list.push_back(MakeValue(INT64_MIN));
232   list.push_back(MakeValue(static_cast<int64_t>(0)));
233   try {
234     ScalarMod(list);
235     FAIL();
236   } catch (std::runtime_error const &err) {
237     ASSERT_TRUE(std::string(err.what()).find("Could not mod to zero") != std::string::npos);
238   }
239   list.clear();
240 
241   list.push_back(MakeValue(INT64_MIN));
242   list.push_back(MakeValue(static_cast<int64_t>(-1)));
243   try {
244     ScalarMod(list);
245     FAIL();
246   } catch (std::runtime_error const &err) {
247     ASSERT_TRUE(std::string(err.what()).find("Overflow of the mod of two signed number") != std::string::npos);
248   }
249   list.clear();
250 }
251 
252 TEST_F(TestImplementations, ScalarUAddTest) {
253   ValuePtrList list;
254   list.push_back(MakeValue((uint64_t)1));
255   ASSERT_EQ(ScalarUAdd(list)->cast<UInt64ImmPtr>()->value(), 1);
256   list.clear();
257 }
258 
259 TEST_F(TestImplementations, ScalarLogTest) {
260   ValuePtrList list;
261   list.push_back(MakeValue(static_cast<double>(7.3890560989306495)));
262   ASSERT_EQ(ScalarLog(list)->cast<FP64ImmPtr>()->value(), 2.0);
263   list.clear();
264 }
265 
266 TEST_F(TestImplementations, ScalarUSubTest) {
267   ValuePtrList list;
268   list.push_back(MakeValue(static_cast<int64_t>(1)));
269   ASSERT_EQ(ScalarUSub(list)->cast<Int64ImmPtr>()->value(), -1);
270   list.clear();
271 }
272 
273 TEST_F(TestImplementations, ScalarEqTest) {
274   ValuePtrList list;
275   list.push_back(MakeValue(1.0f));
276   list.push_back(MakeValue(1.0f));
277   ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), true);
278   list.clear();
279 
280   list.push_back(MakeValue(1.0f));
281   list.push_back(MakeValue(-1.0f));
282   ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), false);
283   list.clear();
284 
285   list.push_back(MakeValue(1.0f));
286   list.push_back(MakeValue(1.0));
287   ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), true);
288   list.clear();
289 
290   list.push_back(MakeValue(1.0));
291   list.push_back(MakeValue(1.0));
292   ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), true);
293   list.clear();
294 }
295 
296 TEST_F(TestImplementations, ScalarLtTest) {
297   ValuePtrList list;
298   list.push_back(MakeValue(1.0f));
299   list.push_back(MakeValue(1.0f));
300   ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), false);
301   list.clear();
302 
303   list.push_back(MakeValue(1.0f));
304   list.push_back(MakeValue(-1.0f));
305   ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), false);
306   list.clear();
307 
308   list.push_back(MakeValue(1.0f));
309   list.push_back(MakeValue(2.5));
310   ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), true);
311   list.clear();
312 
313   list.push_back(MakeValue(2.5));
314   list.push_back(MakeValue(3.0));
315   ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), true);
316   list.clear();
317 }
318 
319 TEST_F(TestImplementations, ScalarGtTest) {
320   ValuePtrList list;
321   list.push_back(MakeValue(1.0f));
322   list.push_back(MakeValue(2.0f));
323   ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), false);
324   list.clear();
325 
326   list.push_back(MakeValue(2.0f));
327   list.push_back(MakeValue(-1.0f));
328   ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), true);
329   list.clear();
330 
331   list.push_back(MakeValue(2.0f));
332   list.push_back(MakeValue(2.0));
333   ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), false);
334   list.clear();
335 
336   list.push_back(MakeValue(2.5));
337   list.push_back(MakeValue(2.0));
338   ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), true);
339   list.clear();
340 }
341 
342 TEST_F(TestImplementations, ScalarNeTest) {
343   ValuePtrList list;
344   list.push_back(MakeValue(1.0f));
345   list.push_back(MakeValue(1.0f));
346   ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), false);
347   list.clear();
348 
349   list.push_back(MakeValue(1.0f));
350   list.push_back(MakeValue(-1.0f));
351   ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), true);
352   list.clear();
353 
354   list.push_back(MakeValue(1.0f));
355   list.push_back(MakeValue(2.0));
356   ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), true);
357   list.clear();
358 
359   list.push_back(MakeValue(2.0));
360   list.push_back(MakeValue(2.0));
361   ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), false);
362   list.clear();
363 }
364 
365 TEST_F(TestImplementations, ScalarLeTest) {
366   ValuePtrList list;
367   list.push_back(MakeValue(1.0f));
368   list.push_back(MakeValue(1.0f));
369   ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), true);
370   list.clear();
371 
372   list.push_back(MakeValue(1.0f));
373   list.push_back(MakeValue(-1.0f));
374   ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), false);
375   list.clear();
376 
377   list.push_back(MakeValue(1.0f));
378   list.push_back(MakeValue(2.0));
379   ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), true);
380   list.clear();
381 
382   list.push_back(MakeValue(6.0));
383   list.push_back(MakeValue(-1.0f));
384   ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), false);
385   list.clear();
386 }
387 
388 TEST_F(TestImplementations, ScalarGeTest) {
389   ValuePtrList list;
390   list.push_back(MakeValue(1.0f));
391   list.push_back(MakeValue(1.0f));
392   ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), true);
393   list.clear();
394 
395   list.push_back(MakeValue(1.0f));
396   list.push_back(MakeValue(-1.0f));
397   ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), true);
398   list.clear();
399 
400   list.push_back(MakeValue(1.0f));
401   list.push_back(MakeValue(2.0));
402   ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), false);
403   list.clear();
404 
405   list.push_back(MakeValue(6.0));
406   list.push_back(MakeValue(-1.0f));
407   ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), true);
408   list.clear();
409 }
410 
411 TEST_F(TestImplementations, BoolNotTest) {
412   ValuePtrList list;
413   list.push_back(MakeValue(true));
414   ASSERT_EQ(BoolNot(list)->cast<BoolImmPtr>()->value(), false);
415   list.clear();
416 
417   list.push_back(MakeValue(false));
418   ASSERT_EQ(BoolNot(list)->cast<BoolImmPtr>()->value(), true);
419   list.clear();
420 }
421 
422 TEST_F(TestImplementations, BoolAndTest) {
423   ValuePtrList list;
424   list.push_back(MakeValue(true));
425   list.push_back(MakeValue(false));
426   ASSERT_EQ(BoolAnd(list)->cast<BoolImmPtr>()->value(), false);
427   list.clear();
428 
429   list.push_back(MakeValue(true));
430   list.push_back(MakeValue(true));
431   ASSERT_EQ(BoolAnd(list)->cast<BoolImmPtr>()->value(), true);
432   list.clear();
433 
434   list.push_back(MakeValue(false));
435   list.push_back(MakeValue(false));
436   ASSERT_EQ(BoolAnd(list)->cast<BoolImmPtr>()->value(), false);
437   list.clear();
438 }
439 
440 TEST_F(TestImplementations, BoolOrTest) {
441   ValuePtrList list;
442   list.push_back(MakeValue(true));
443   list.push_back(MakeValue(false));
444   ASSERT_EQ(BoolOr(list)->cast<BoolImmPtr>()->value(), true);
445   list.clear();
446 
447   list.push_back(MakeValue(true));
448   list.push_back(MakeValue(true));
449   ASSERT_EQ(BoolOr(list)->cast<BoolImmPtr>()->value(), true);
450   list.clear();
451 
452   list.push_back(MakeValue(false));
453   list.push_back(MakeValue(false));
454   ASSERT_EQ(BoolOr(list)->cast<BoolImmPtr>()->value(), false);
455   list.clear();
456 }
457 
458 TEST_F(TestImplementations, BoolEqTest) {
459   ValuePtrList list;
460   list.push_back(MakeValue(true));
461   list.push_back(MakeValue(false));
462   ASSERT_EQ(BoolEq(list)->cast<BoolImmPtr>()->value(), false);
463   list.clear();
464 
465   list.push_back(MakeValue(true));
466   list.push_back(MakeValue(true));
467   ASSERT_EQ(BoolEq(list)->cast<BoolImmPtr>()->value(), true);
468   list.clear();
469 
470   list.push_back(MakeValue(false));
471   list.push_back(MakeValue(false));
472   ASSERT_EQ(BoolEq(list)->cast<BoolImmPtr>()->value(), true);
473   list.clear();
474 }
475 
476 }  // namespace prim
477 }  // namespace mindspore
478