1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/common.h" 17 #include "gtest/gtest.h" 18 19 #include "minddata/dataset/include/dataset/constants.h" 20 #include "minddata/dataset/core/tensor.h" 21 22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 23 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" 24 #include "utils/log_adapter.h" 25 26 #include <vector> 27 #include <unordered_set> 28 29 using namespace mindspore::dataset; 30 using mindspore::LogStream; 31 using mindspore::ExceptionType::NoExceptionType; 32 using mindspore::MsLogLevel::INFO; 33 34 class MindDataTestWeightedRandomSampler : public UT::Common { 35 public: 36 class DummyRandomAccessOp : public RandomAccessOp { 37 public: 38 DummyRandomAccessOp(uint64_t num_rows) { 39 // row count is in base class as protected member 40 // GetNumRowsInDataset does not need an override, the default from base class is fine. 41 num_rows_ = num_rows; 42 } 43 }; 44 }; 45 46 TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { 47 // num samples to draw. 48 uint64_t num_samples = 100; 49 uint64_t total_samples = 1000; 50 std::vector<double> weights(total_samples, std::rand() % 100); 51 std::vector<uint64_t> freq(total_samples, 0); 52 53 // create sampler with replacement = true 54 WeightedRandomSamplerRT m_sampler(weights, num_samples, true); 55 DummyRandomAccessOp dummyRandomAccessOp(total_samples); 56 m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 57 58 TensorRow row; 59 std::vector<uint64_t> out; 60 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 61 62 for (const auto &t : row) { 63 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 64 out.push_back(*it); 65 freq[*it]++; 66 } 67 } 68 69 ASSERT_EQ(num_samples, out.size()); 70 71 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 72 ASSERT_EQ(row.eoe(), true); 73 } 74 75 TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { 76 // num samples to draw. 77 uint64_t num_samples = 100; 78 uint64_t total_samples = 1000; 79 std::vector<double> weights(total_samples, std::rand() % 100); 80 std::vector<uint64_t> freq(total_samples, 0); 81 82 // create sampler with replacement = replacement 83 WeightedRandomSamplerRT m_sampler(weights, num_samples, false); 84 DummyRandomAccessOp dummyRandomAccessOp(total_samples); 85 m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 86 87 TensorRow row; 88 std::vector<uint64_t> out; 89 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 90 91 for (const auto &t : row) { 92 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 93 out.push_back(*it); 94 freq[*it]++; 95 } 96 } 97 ASSERT_EQ(num_samples, out.size()); 98 99 // Without replacement, each sample only drawn once. 100 for (int i = 0; i < total_samples; i++) { 101 if (freq[i]) { 102 ASSERT_EQ(freq[i], 1); 103 } 104 } 105 106 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 107 ASSERT_EQ(row.eoe(), true); 108 } 109 110 TEST_F(MindDataTestWeightedRandomSampler, TestGetNextSampleReplacement) { 111 // num samples to draw. 112 uint64_t num_samples = 100; 113 uint64_t total_samples = 1000; 114 uint64_t samples_per_tensor = 10; 115 std::vector<double> weights(total_samples, std::rand() % 100); 116 117 // create sampler with replacement = replacement 118 WeightedRandomSamplerRT m_sampler(weights, num_samples, true, samples_per_tensor); 119 DummyRandomAccessOp dummyRandomAccessOp(total_samples); 120 m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 121 122 TensorRow row; 123 std::vector<uint64_t> out; 124 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 125 int epoch = 0; 126 while (!row.eoe()) { 127 epoch++; 128 129 for (const auto &t : row) { 130 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 131 out.push_back(*it); 132 } 133 } 134 135 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 136 } 137 138 ASSERT_EQ(epoch, (num_samples + samples_per_tensor - 1) / samples_per_tensor); 139 ASSERT_EQ(num_samples, out.size()); 140 } 141 142 TEST_F(MindDataTestWeightedRandomSampler, TestGetNextSampleNoReplacement) { 143 // num samples to draw. 144 uint64_t num_samples = 100; 145 uint64_t total_samples = 100; 146 uint64_t samples_per_tensor = 10; 147 std::vector<double> weights(total_samples, std::rand() % 100); 148 weights[1] = 0; 149 weights[2] = 0; 150 std::vector<uint64_t> freq(total_samples, 0); 151 152 // create sampler with replacement = replacement 153 WeightedRandomSamplerRT m_sampler(weights, num_samples, false, samples_per_tensor); 154 DummyRandomAccessOp dummyRandomAccessOp(total_samples); 155 m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 156 157 TensorRow row; 158 std::vector<uint64_t> out; 159 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 160 int epoch = 0; 161 while (!row.eoe()) { 162 epoch++; 163 164 for (const auto &t : row) { 165 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 166 out.push_back(*it); 167 freq[*it]++; 168 } 169 } 170 171 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 172 } 173 174 // Without replacement, each sample only drawn once. 175 for (int i = 0; i < total_samples; i++) { 176 if (freq[i]) { 177 ASSERT_EQ(freq[i], 1); 178 } 179 } 180 181 ASSERT_EQ(epoch, (num_samples + samples_per_tensor - 1) / samples_per_tensor); 182 ASSERT_EQ(num_samples, out.size()); 183 } 184 185 TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { 186 // num samples to draw. 187 uint64_t num_samples = 1000000; 188 uint64_t total_samples = 1000000; 189 std::vector<double> weights(total_samples, std::rand() % 100); 190 std::vector<uint64_t> freq(total_samples, 0); 191 192 // create sampler with replacement = true 193 WeightedRandomSamplerRT m_sampler(weights, num_samples, true); 194 DummyRandomAccessOp dummyRandomAccessOp(total_samples); 195 m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 196 197 TensorRow row; 198 std::vector<uint64_t> out; 199 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 200 201 for (const auto &t : row) { 202 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 203 out.push_back(*it); 204 freq[*it]++; 205 } 206 } 207 ASSERT_EQ(num_samples, out.size()); 208 209 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 210 ASSERT_EQ(row.eoe(), true); 211 212 m_sampler.ResetSampler(); 213 out.clear(); 214 215 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 216 217 for (const auto &t : row) { 218 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 219 out.push_back(*it); 220 freq[*it]++; 221 } 222 } 223 ASSERT_EQ(num_samples, out.size()); 224 225 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 226 ASSERT_EQ(row.eoe(), true); 227 } 228 229 TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { 230 // num samples to draw. 231 uint64_t num_samples = 1000000; 232 uint64_t total_samples = 1000000; 233 std::vector<double> weights(total_samples, std::rand() % 100); 234 std::vector<uint64_t> freq(total_samples, 0); 235 236 // create sampler with replacement = true 237 WeightedRandomSamplerRT m_sampler(weights, num_samples, false); 238 DummyRandomAccessOp dummyRandomAccessOp(total_samples); 239 m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 240 241 TensorRow row; 242 std::vector<uint64_t> out; 243 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 244 245 for (const auto &t : row) { 246 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 247 out.push_back(*it); 248 freq[*it]++; 249 } 250 } 251 ASSERT_EQ(num_samples, out.size()); 252 253 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 254 ASSERT_EQ(row.eoe(), true); 255 256 m_sampler.ResetSampler(); 257 out.clear(); 258 freq.clear(); 259 freq.resize(total_samples, 0); 260 MS_LOG(INFO) << "Resetting sampler"; 261 262 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 263 264 for (const auto &t : row) { 265 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 266 out.push_back(*it); 267 freq[*it]++; 268 } 269 } 270 ASSERT_EQ(num_samples, out.size()); 271 272 // Without replacement, each sample only drawn once. 273 for (int i = 0; i < total_samples; i++) { 274 if (freq[i]) { 275 ASSERT_EQ(freq[i], 1); 276 } 277 } 278 279 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 280 ASSERT_EQ(row.eoe(), true); 281 } 282