1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <memory> 18 #include <vector> 19 20 #include "common/common_test.h" 21 #include "frontend/operator/cc_implementations.h" 22 23 namespace mindspore { 24 namespace prim { 25 26 class TestImplementations : public UT::Common { 27 public: 28 TestImplementations() {} 29 virtual void SetUp() {} 30 }; 31 32 TEST_F(TestImplementations, ScalarAddTest) { 33 ValuePtrList list; 34 list.push_back(MakeValue(static_cast<int64_t>(1))); 35 list.push_back(MakeValue(static_cast<int64_t>(2))); 36 ASSERT_EQ(ScalarAdd(list)->cast<Int64ImmPtr>()->value(), 3); 37 list.clear(); 38 39 list.push_back(MakeValue(1.0f)); 40 list.push_back(MakeValue(1.5f)); 41 ASSERT_EQ(ScalarAdd(list)->cast<FP32ImmPtr>()->value(), 2.5f); 42 list.clear(); 43 44 list.push_back(MakeValue(3.0)); 45 list.push_back(MakeValue(0.5)); 46 ASSERT_EQ(ScalarAdd(list)->cast<FP64ImmPtr>()->value(), 3.5); 47 list.clear(); 48 49 list.push_back(MakeValue(INT64_MAX)); 50 list.push_back(MakeValue(static_cast<int64_t>(2))); 51 try { 52 ScalarAdd(list); 53 FAIL(); 54 } catch (std::runtime_error const &err) { 55 ASSERT_TRUE(std::string(err.what()).find("Overflow of the sum of two signed number") != std::string::npos); 56 } 57 list.clear(); 58 59 list.push_back(MakeValue(INT64_MIN)); 60 list.push_back(MakeValue(static_cast<int64_t>(-1))); 61 try { 62 ScalarAdd(list); 63 FAIL(); 64 } catch (std::runtime_error const &err) { 65 ASSERT_TRUE(std::string(err.what()).find("Overflow of the sum of two signed number") != std::string::npos); 66 } 67 list.clear(); 68 } 69 70 TEST_F(TestImplementations, ScalarSubTest) { 71 ValuePtrList list; 72 list.push_back(MakeValue(static_cast<int64_t>(1))); 73 list.push_back(MakeValue(static_cast<int64_t>(3))); 74 ASSERT_EQ(ScalarSub(list)->cast<Int64ImmPtr>()->value(), -2); 75 list.clear(); 76 77 list.push_back(MakeValue(1.0f)); 78 list.push_back(MakeValue(1.5f)); 79 ASSERT_EQ(ScalarSub(list)->cast<FP32ImmPtr>()->value(), -0.5f); 80 list.clear(); 81 82 list.push_back(MakeValue(3.0)); 83 list.push_back(MakeValue(0.5)); 84 ASSERT_EQ(ScalarSub(list)->cast<FP64ImmPtr>()->value(), 2.5); 85 list.clear(); 86 87 list.push_back(MakeValue(INT64_MAX)); 88 list.push_back(MakeValue(static_cast<int64_t>(-1))); 89 try { 90 ScalarSub(list); 91 FAIL(); 92 } catch (std::runtime_error const &err) { 93 ASSERT_TRUE(std::string(err.what()).find("Overflow of the sub of two signed number") != std::string::npos); 94 } 95 list.clear(); 96 97 list.push_back(MakeValue(INT64_MIN)); 98 list.push_back(MakeValue(static_cast<int64_t>(1))); 99 try { 100 ScalarSub(list); 101 FAIL(); 102 } catch (std::runtime_error const &err) { 103 ASSERT_TRUE(std::string(err.what()).find("Overflow of the sub of two signed number") != std::string::npos); 104 } 105 list.clear(); 106 } 107 108 TEST_F(TestImplementations, ScalarMulTest) { 109 ValuePtrList list; 110 list.push_back(MakeValue(static_cast<int64_t>(2))); 111 list.push_back(MakeValue(static_cast<int64_t>(3))); 112 ASSERT_EQ(ScalarMul(list)->cast<Int64ImmPtr>()->value(), 6); 113 list.clear(); 114 115 list.push_back(MakeValue(2.0f)); 116 list.push_back(MakeValue(1.5f)); 117 ASSERT_EQ(ScalarMul(list)->cast<FP32ImmPtr>()->value(), 3.0f); 118 list.clear(); 119 120 list.push_back(MakeValue(-2.0)); 121 list.push_back(MakeValue(-4.0)); 122 ASSERT_EQ(ScalarMul(list)->cast<FP64ImmPtr>()->value(), 8.0); 123 list.clear(); 124 125 list.push_back(MakeValue(static_cast<int64_t>(10))); 126 list.push_back(MakeValue(INT64_MAX)); 127 try { 128 ScalarMul(list); 129 FAIL(); 130 } catch (std::runtime_error const &err) { 131 ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos); 132 } 133 list.clear(); 134 135 list.push_back(MakeValue(INT64_MIN)); 136 list.push_back(MakeValue(static_cast<int64_t>(-1))); 137 try { 138 ScalarMul(list); 139 FAIL(); 140 } catch (std::runtime_error const &err) { 141 ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos); 142 } 143 list.clear(); 144 145 list.push_back(MakeValue(static_cast<int64_t>(-2))); 146 list.push_back(MakeValue(INT64_MAX)); 147 try { 148 ScalarMul(list); 149 FAIL(); 150 } catch (std::runtime_error const &err) { 151 ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos); 152 } 153 list.clear(); 154 155 list.push_back(MakeValue(static_cast<int64_t>(2))); 156 list.push_back(MakeValue(INT64_MIN)); 157 try { 158 ScalarMul(list); 159 FAIL(); 160 } catch (std::runtime_error const &err) { 161 ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos); 162 } 163 list.clear(); 164 165 list.push_back(MakeValue(static_cast<int64_t>(0))); 166 list.push_back(MakeValue(INT64_MIN)); 167 ASSERT_EQ(ScalarDiv(list)->cast<Int64ImmPtr>()->value(), 0); 168 list.clear(); 169 } 170 171 TEST_F(TestImplementations, ScalarDivTest) { 172 ValuePtrList list; 173 list.push_back(MakeValue(static_cast<int64_t>(6))); 174 list.push_back(MakeValue(static_cast<int64_t>(3))); 175 ASSERT_EQ(ScalarDiv(list)->cast<Int64ImmPtr>()->value(), 2); 176 list.clear(); 177 178 list.push_back(MakeValue(3.0f)); 179 list.push_back(MakeValue(1.5f)); 180 ASSERT_EQ(ScalarDiv(list)->cast<FP32ImmPtr>()->value(), 2.0f); 181 list.clear(); 182 183 list.push_back(MakeValue(-4.0)); 184 list.push_back(MakeValue(2.0)); 185 ASSERT_EQ(ScalarDiv(list)->cast<FP64ImmPtr>()->value(), -2.0); 186 list.clear(); 187 188 list.push_back(MakeValue(INT64_MAX)); 189 list.push_back(MakeValue(static_cast<int64_t>(0))); 190 try { 191 ScalarDiv(list); 192 FAIL(); 193 } catch (std::runtime_error const &err) { 194 ASSERT_TRUE(std::string(err.what()).find("Divisor could not be zero") != std::string::npos); 195 } 196 list.clear(); 197 198 list.push_back(MakeValue(INT64_MIN)); 199 list.push_back(MakeValue(static_cast<int64_t>(-1))); 200 try { 201 ScalarDiv(list); 202 FAIL(); 203 } catch (std::runtime_error const &err) { 204 ASSERT_TRUE(std::string(err.what()).find("Overflow of the div of two signed number") != std::string::npos); 205 } 206 list.clear(); 207 208 list.push_back(MakeValue(static_cast<int64_t>(-1))); 209 list.push_back(MakeValue(INT64_MIN)); 210 ASSERT_EQ(ScalarDiv(list)->cast<Int64ImmPtr>()->value(), 0); 211 list.clear(); 212 } 213 214 TEST_F(TestImplementations, ScalarModTest) { 215 ValuePtrList list; 216 list.push_back(MakeValue(static_cast<int64_t>(7))); 217 list.push_back(MakeValue(static_cast<int64_t>(3))); 218 ASSERT_EQ(ScalarMod(list)->cast<Int64ImmPtr>()->value(), 1); 219 list.clear(); 220 221 list.push_back(MakeValue(static_cast<int64_t>(-8))); 222 list.push_back(MakeValue(static_cast<int64_t>(3))); 223 ASSERT_EQ(ScalarMod(list)->cast<Int64ImmPtr>()->value(), -2); 224 list.clear(); 225 226 list.push_back(MakeValue(static_cast<int64_t>(-9))); 227 list.push_back(MakeValue(static_cast<int64_t>(2))); 228 ASSERT_EQ(ScalarMod(list)->cast<Int64ImmPtr>()->value(), -1); 229 list.clear(); 230 231 list.push_back(MakeValue(INT64_MIN)); 232 list.push_back(MakeValue(static_cast<int64_t>(0))); 233 try { 234 ScalarMod(list); 235 FAIL(); 236 } catch (std::runtime_error const &err) { 237 ASSERT_TRUE(std::string(err.what()).find("Could not mod to zero") != std::string::npos); 238 } 239 list.clear(); 240 241 list.push_back(MakeValue(INT64_MIN)); 242 list.push_back(MakeValue(static_cast<int64_t>(-1))); 243 try { 244 ScalarMod(list); 245 FAIL(); 246 } catch (std::runtime_error const &err) { 247 ASSERT_TRUE(std::string(err.what()).find("Overflow of the mod of two signed number") != std::string::npos); 248 } 249 list.clear(); 250 } 251 252 TEST_F(TestImplementations, ScalarUAddTest) { 253 ValuePtrList list; 254 list.push_back(MakeValue((uint64_t)1)); 255 ASSERT_EQ(ScalarUAdd(list)->cast<UInt64ImmPtr>()->value(), 1); 256 list.clear(); 257 } 258 259 TEST_F(TestImplementations, ScalarLogTest) { 260 ValuePtrList list; 261 list.push_back(MakeValue(static_cast<double>(7.3890560989306495))); 262 ASSERT_EQ(ScalarLog(list)->cast<FP64ImmPtr>()->value(), 2.0); 263 list.clear(); 264 } 265 266 TEST_F(TestImplementations, ScalarUSubTest) { 267 ValuePtrList list; 268 list.push_back(MakeValue(static_cast<int64_t>(1))); 269 ASSERT_EQ(ScalarUSub(list)->cast<Int64ImmPtr>()->value(), -1); 270 list.clear(); 271 } 272 273 TEST_F(TestImplementations, ScalarEqTest) { 274 ValuePtrList list; 275 list.push_back(MakeValue(1.0f)); 276 list.push_back(MakeValue(1.0f)); 277 ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), true); 278 list.clear(); 279 280 list.push_back(MakeValue(1.0f)); 281 list.push_back(MakeValue(-1.0f)); 282 ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), false); 283 list.clear(); 284 285 list.push_back(MakeValue(1.0f)); 286 list.push_back(MakeValue(1.0)); 287 ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), true); 288 list.clear(); 289 290 list.push_back(MakeValue(1.0)); 291 list.push_back(MakeValue(1.0)); 292 ASSERT_EQ(ScalarEq(list)->cast<BoolImmPtr>()->value(), true); 293 list.clear(); 294 } 295 296 TEST_F(TestImplementations, ScalarLtTest) { 297 ValuePtrList list; 298 list.push_back(MakeValue(1.0f)); 299 list.push_back(MakeValue(1.0f)); 300 ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), false); 301 list.clear(); 302 303 list.push_back(MakeValue(1.0f)); 304 list.push_back(MakeValue(-1.0f)); 305 ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), false); 306 list.clear(); 307 308 list.push_back(MakeValue(1.0f)); 309 list.push_back(MakeValue(2.5)); 310 ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), true); 311 list.clear(); 312 313 list.push_back(MakeValue(2.5)); 314 list.push_back(MakeValue(3.0)); 315 ASSERT_EQ(ScalarLt(list)->cast<BoolImmPtr>()->value(), true); 316 list.clear(); 317 } 318 319 TEST_F(TestImplementations, ScalarGtTest) { 320 ValuePtrList list; 321 list.push_back(MakeValue(1.0f)); 322 list.push_back(MakeValue(2.0f)); 323 ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), false); 324 list.clear(); 325 326 list.push_back(MakeValue(2.0f)); 327 list.push_back(MakeValue(-1.0f)); 328 ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), true); 329 list.clear(); 330 331 list.push_back(MakeValue(2.0f)); 332 list.push_back(MakeValue(2.0)); 333 ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), false); 334 list.clear(); 335 336 list.push_back(MakeValue(2.5)); 337 list.push_back(MakeValue(2.0)); 338 ASSERT_EQ(ScalarGt(list)->cast<BoolImmPtr>()->value(), true); 339 list.clear(); 340 } 341 342 TEST_F(TestImplementations, ScalarNeTest) { 343 ValuePtrList list; 344 list.push_back(MakeValue(1.0f)); 345 list.push_back(MakeValue(1.0f)); 346 ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), false); 347 list.clear(); 348 349 list.push_back(MakeValue(1.0f)); 350 list.push_back(MakeValue(-1.0f)); 351 ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), true); 352 list.clear(); 353 354 list.push_back(MakeValue(1.0f)); 355 list.push_back(MakeValue(2.0)); 356 ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), true); 357 list.clear(); 358 359 list.push_back(MakeValue(2.0)); 360 list.push_back(MakeValue(2.0)); 361 ASSERT_EQ(ScalarNe(list)->cast<BoolImmPtr>()->value(), false); 362 list.clear(); 363 } 364 365 TEST_F(TestImplementations, ScalarLeTest) { 366 ValuePtrList list; 367 list.push_back(MakeValue(1.0f)); 368 list.push_back(MakeValue(1.0f)); 369 ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), true); 370 list.clear(); 371 372 list.push_back(MakeValue(1.0f)); 373 list.push_back(MakeValue(-1.0f)); 374 ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), false); 375 list.clear(); 376 377 list.push_back(MakeValue(1.0f)); 378 list.push_back(MakeValue(2.0)); 379 ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), true); 380 list.clear(); 381 382 list.push_back(MakeValue(6.0)); 383 list.push_back(MakeValue(-1.0f)); 384 ASSERT_EQ(ScalarLe(list)->cast<BoolImmPtr>()->value(), false); 385 list.clear(); 386 } 387 388 TEST_F(TestImplementations, ScalarGeTest) { 389 ValuePtrList list; 390 list.push_back(MakeValue(1.0f)); 391 list.push_back(MakeValue(1.0f)); 392 ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), true); 393 list.clear(); 394 395 list.push_back(MakeValue(1.0f)); 396 list.push_back(MakeValue(-1.0f)); 397 ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), true); 398 list.clear(); 399 400 list.push_back(MakeValue(1.0f)); 401 list.push_back(MakeValue(2.0)); 402 ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), false); 403 list.clear(); 404 405 list.push_back(MakeValue(6.0)); 406 list.push_back(MakeValue(-1.0f)); 407 ASSERT_EQ(ScalarGe(list)->cast<BoolImmPtr>()->value(), true); 408 list.clear(); 409 } 410 411 TEST_F(TestImplementations, BoolNotTest) { 412 ValuePtrList list; 413 list.push_back(MakeValue(true)); 414 ASSERT_EQ(BoolNot(list)->cast<BoolImmPtr>()->value(), false); 415 list.clear(); 416 417 list.push_back(MakeValue(false)); 418 ASSERT_EQ(BoolNot(list)->cast<BoolImmPtr>()->value(), true); 419 list.clear(); 420 } 421 422 TEST_F(TestImplementations, BoolAndTest) { 423 ValuePtrList list; 424 list.push_back(MakeValue(true)); 425 list.push_back(MakeValue(false)); 426 ASSERT_EQ(BoolAnd(list)->cast<BoolImmPtr>()->value(), false); 427 list.clear(); 428 429 list.push_back(MakeValue(true)); 430 list.push_back(MakeValue(true)); 431 ASSERT_EQ(BoolAnd(list)->cast<BoolImmPtr>()->value(), true); 432 list.clear(); 433 434 list.push_back(MakeValue(false)); 435 list.push_back(MakeValue(false)); 436 ASSERT_EQ(BoolAnd(list)->cast<BoolImmPtr>()->value(), false); 437 list.clear(); 438 } 439 440 TEST_F(TestImplementations, BoolOrTest) { 441 ValuePtrList list; 442 list.push_back(MakeValue(true)); 443 list.push_back(MakeValue(false)); 444 ASSERT_EQ(BoolOr(list)->cast<BoolImmPtr>()->value(), true); 445 list.clear(); 446 447 list.push_back(MakeValue(true)); 448 list.push_back(MakeValue(true)); 449 ASSERT_EQ(BoolOr(list)->cast<BoolImmPtr>()->value(), true); 450 list.clear(); 451 452 list.push_back(MakeValue(false)); 453 list.push_back(MakeValue(false)); 454 ASSERT_EQ(BoolOr(list)->cast<BoolImmPtr>()->value(), false); 455 list.clear(); 456 } 457 458 TEST_F(TestImplementations, BoolEqTest) { 459 ValuePtrList list; 460 list.push_back(MakeValue(true)); 461 list.push_back(MakeValue(false)); 462 ASSERT_EQ(BoolEq(list)->cast<BoolImmPtr>()->value(), false); 463 list.clear(); 464 465 list.push_back(MakeValue(true)); 466 list.push_back(MakeValue(true)); 467 ASSERT_EQ(BoolEq(list)->cast<BoolImmPtr>()->value(), true); 468 list.clear(); 469 470 list.push_back(MakeValue(false)); 471 list.push_back(MakeValue(false)); 472 ASSERT_EQ(BoolEq(list)->cast<BoolImmPtr>()->value(), true); 473 list.clear(); 474 } 475 476 } // namespace prim 477 } // namespace mindspore 478