1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "utils/counter.h" 17 #include "common/common_test.h" 18 19 namespace mindspore { 20 class TestCounter : public UT::Common { 21 public: 22 TestCounter() { 23 std::string s1 = "abcdeedfrgbhrtfsfd"; 24 std::string s2 = "shceufhvogawrycawr"; 25 26 for (auto c : s1) { 27 std::string key(1, c); 28 counter_a[key] += 1; 29 } 30 31 for (auto c : s2) { 32 std::string key(1, c); 33 counter_b[key] += 1; 34 } 35 } 36 37 public: 38 Counter<std::string> counter_a; 39 Counter<std::string> counter_b; 40 }; 41 42 TEST_F(TestCounter, test_constructor) { 43 assert(counter_a.size() == 11); 44 assert(counter_b.size() == 13); 45 } 46 47 TEST_F(TestCounter, test_subtitle) { 48 std::string s = "d"; 49 assert(counter_a[s] == 3); 50 s = "f"; 51 assert(counter_a[s] == 3); 52 s = "h"; 53 assert(counter_b[s] = 2); 54 s = "c"; 55 assert(counter_b[s] = 2); 56 } 57 58 TEST_F(TestCounter, test_contains) { 59 std::string s = "d"; 60 assert(counter_a.contains(s) == true); 61 s = "z"; 62 assert(counter_a.contains(s) == false); 63 s = "q"; 64 assert(counter_b.contains(s) == false); 65 } 66 67 TEST_F(TestCounter, test_add) { 68 auto counter_add = counter_a + counter_b; 69 assert(counter_add.size() == 16); 70 std::string s = "f"; 71 assert(counter_add[s] == 4); 72 s = "r"; 73 assert(counter_add[s] == 4); 74 s = "y"; 75 assert(counter_add[s] == 1); 76 } 77 78 TEST_F(TestCounter, test_minus) { 79 auto counter_minus = counter_a - counter_b; 80 assert(counter_minus.size() == 5); 81 std::string s = "d"; 82 assert(counter_minus[s] == 3); 83 s = "t"; 84 assert(counter_minus[s] == 1); 85 s = "a"; 86 assert(counter_minus.contains(s) == false); 87 } 88 89 struct MyStruct { 90 int a = 0; 91 int b = 0; 92 }; 93 94 struct MyHash { 95 std::size_t operator()(const MyStruct &e) const noexcept { // 96 return (static_cast<std::size_t>(e.a) << 16) + e.b; 97 } 98 }; 99 100 struct MyEqual { 101 bool operator()(const MyStruct &lhs, const MyStruct &rhs) const noexcept { // 102 return lhs.a == rhs.a && lhs.b == rhs.b; 103 } 104 }; 105 106 TEST_F(TestCounter, test_struct) { 107 using MyCounter = Counter<MyStruct, MyHash, MyEqual>; 108 MyCounter counter; 109 counter.add(MyStruct{100, 1}); 110 counter.add(MyStruct{100, 2}); 111 counter.add(MyStruct{100, 2}); 112 counter.add(MyStruct{100, 3}); 113 counter.add(MyStruct{100, 3}); 114 counter.add(MyStruct{100, 3}); 115 ASSERT_EQ(1, (counter[MyStruct{100, 1}])); 116 ASSERT_EQ(2, (counter[MyStruct{100, 2}])); 117 ASSERT_EQ(3, (counter[MyStruct{100, 3}])); 118 119 MyCounter counter2; 120 counter2.add(MyStruct{100, 2}); 121 counter2.add(MyStruct{100, 3}); 122 counter2.add(MyStruct{100, 3}); 123 counter2.add(MyStruct{100, 3}); 124 counter2.add(MyStruct{100, 4}); 125 126 auto result = counter.subtract(counter2); 127 ASSERT_EQ(2, result.size()); 128 ASSERT_TRUE((MyEqual{}(MyStruct{100, 1}, result[0]))); 129 ASSERT_TRUE((MyEqual{}(MyStruct{100, 2}, result[1]))); 130 131 counter2 = counter; 132 ASSERT_EQ(3, counter2.size()); 133 ASSERT_EQ(1, (counter2[MyStruct{100, 1}])); 134 ASSERT_EQ(2, (counter2[MyStruct{100, 2}])); 135 ASSERT_EQ(3, (counter2[MyStruct{100, 3}])); 136 } 137 138 } // namespace mindspore 139