1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "common/common.h"
17 #include "gtest/gtest.h"
18
19 #include "minddata/dataset/include/dataset/constants.h"
20 #include "minddata/dataset/core/tensor.h"
21
22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
24 #include "utils/log_adapter.h"
25
26 #include <vector>
27 #include <unordered_set>
28
29 using namespace mindspore::dataset;
30 using mindspore::LogStream;
31 using mindspore::ExceptionType::NoExceptionType;
32 using mindspore::MsLogLevel::INFO;
33
34 class MindDataTestWeightedRandomSampler : public UT::Common {
35 public:
36 class DummyRandomAccessOp : public RandomAccessOp {
37 public:
DummyRandomAccessOp(uint64_t num_rows)38 DummyRandomAccessOp(uint64_t num_rows) {
39 // row count is in base class as protected member
40 // GetNumRowsInDataset does not need an override, the default from base class is fine.
41 num_rows_ = num_rows;
42 }
43 };
44 };
45
// Draws every requested sample in one fetch with replacement enabled, then expects the next
// fetch to yield a row whose eoe() flag is set.
TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
  // num samples to draw.
  uint64_t num_samples = 100;
  uint64_t total_samples = 1000;
  // All rows share one random weight; the "1 +" keeps it strictly positive so the sampler is
  // never handed an all-zero weight vector.
  std::vector<double> weights(total_samples, 1 + std::rand() % 100);

  // create sampler with replacement = true
  WeightedRandomSamplerRT m_sampler(weights, num_samples, true);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      // Every sampled id has to address one of the total_samples rows.
      ASSERT_LT(*it, total_samples);
      out.push_back(*it);
    }
  }

  ASSERT_EQ(num_samples, out.size());

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
74
// Draws every requested sample in one fetch with replacement disabled and verifies that no row
// id is returned more than once.
TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
  // num samples to draw.
  uint64_t num_samples = 100;
  uint64_t total_samples = 1000;
  // All rows share one random weight; the "1 +" keeps it strictly positive so the sampler is
  // never handed an all-zero weight vector.
  std::vector<double> weights(total_samples, 1 + std::rand() % 100);
  // freq[id] counts how many times row `id` was drawn.
  std::vector<uint64_t> freq(total_samples, 0);

  // create sampler with replacement = false
  WeightedRandomSamplerRT m_sampler(weights, num_samples, false);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      // Guard the freq[] write below against an out-of-range id.
      ASSERT_LT(*it, total_samples);
      out.push_back(*it);
      freq[*it]++;
    }
  }
  ASSERT_EQ(num_samples, out.size());

  // Without replacement, each sample only drawn once.
  for (const uint64_t count : freq) {
    ASSERT_LE(count, 1u);
  }

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
109
// Fetches samples in chunks of samples_per_tensor with replacement enabled and checks both the
// number of fetches needed and the total number of ids returned.
TEST_F(MindDataTestWeightedRandomSampler, TestGetNextSampleReplacement) {
  // num samples to draw.
  uint64_t num_samples = 100;
  uint64_t total_samples = 1000;
  uint64_t samples_per_tensor = 10;
  // All rows share one random weight; the "1 +" keeps it strictly positive so the sampler is
  // never handed an all-zero weight vector.
  std::vector<double> weights(total_samples, 1 + std::rand() % 100);

  // create sampler with replacement = true
  WeightedRandomSamplerRT m_sampler(weights, num_samples, true, samples_per_tensor);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  // Number of non-eoe rows fetched; unsigned so it compares cleanly with the expected count.
  uint64_t num_fetches = 0;
  while (!row.eoe()) {
    ++num_fetches;

    for (const auto &t : row) {
      for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
        // Every sampled id has to address one of the total_samples rows.
        ASSERT_LT(*it, total_samples);
        out.push_back(*it);
      }
    }

    ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  }

  // Expected fetch count is ceil(num_samples / samples_per_tensor).
  ASSERT_EQ(num_fetches, (num_samples + samples_per_tensor - 1) / samples_per_tensor);
  ASSERT_EQ(num_samples, out.size());
}
141
// Fetches samples in chunks of samples_per_tensor with replacement disabled, with two rows given
// a zero weight, and checks fetch count, total ids returned and that no id repeats.
TEST_F(MindDataTestWeightedRandomSampler, TestGetNextSampleNoReplacement) {
  // num samples to draw.
  uint64_t num_samples = 100;
  uint64_t total_samples = 100;
  uint64_t samples_per_tensor = 10;
  // All rows share one random weight; the "1 +" keeps it strictly positive so that only the two
  // rows zeroed below carry a zero weight.
  std::vector<double> weights(total_samples, 1 + std::rand() % 100);
  weights[1] = 0;
  weights[2] = 0;
  // freq[id] counts how many times row `id` was drawn.
  std::vector<uint64_t> freq(total_samples, 0);

  // create sampler with replacement = false
  WeightedRandomSamplerRT m_sampler(weights, num_samples, false, samples_per_tensor);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  // Number of non-eoe rows fetched; unsigned so it compares cleanly with the expected count.
  uint64_t num_fetches = 0;
  while (!row.eoe()) {
    ++num_fetches;

    for (const auto &t : row) {
      for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
        // Guard the freq[] write below against an out-of-range id.
        ASSERT_LT(*it, total_samples);
        out.push_back(*it);
        freq[*it]++;
      }
    }

    ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  }

  // Without replacement, each sample only drawn once.
  for (const uint64_t count : freq) {
    ASSERT_LE(count, 1u);
  }

  // Expected fetch count is ceil(num_samples / samples_per_tensor).
  ASSERT_EQ(num_fetches, (num_samples + samples_per_tensor - 1) / samples_per_tensor);
  ASSERT_EQ(num_samples, out.size());
}
184
// Runs a full pass with replacement enabled, calls ResetSampler(), and checks that a second pass
// again yields num_samples ids followed by a row whose eoe() flag is set.
TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
  // num samples to draw.
  uint64_t num_samples = 1000000;
  uint64_t total_samples = 1000000;
  // All rows share one random weight; the "1 +" keeps it strictly positive so the sampler is
  // never handed an all-zero weight vector.
  std::vector<double> weights(total_samples, 1 + std::rand() % 100);

  // create sampler with replacement = true
  WeightedRandomSamplerRT m_sampler(weights, num_samples, true);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<uint64_t> out;

  // First pass.
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      // Every sampled id has to address one of the total_samples rows.
      ASSERT_LT(*it, total_samples);
      out.push_back(*it);
    }
  }
  ASSERT_EQ(num_samples, out.size());

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());

  // Second pass after a reset.
  m_sampler.ResetSampler();
  out.clear();

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      ASSERT_LT(*it, total_samples);
      out.push_back(*it);
    }
  }
  ASSERT_EQ(num_samples, out.size());

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
228
// Runs a full pass with replacement disabled, calls ResetSampler(), and runs a second pass. Both
// passes must return num_samples ids with no id repeated, followed by a row whose eoe() flag is
// set.
TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
  // num samples to draw.
  uint64_t num_samples = 1000000;
  uint64_t total_samples = 1000000;
  // All rows share one random weight; the "1 +" keeps it strictly positive so the sampler is
  // never handed an all-zero weight vector.
  std::vector<double> weights(total_samples, 1 + std::rand() % 100);
  // freq[id] counts how many times row `id` was drawn in the current pass.
  std::vector<uint64_t> freq(total_samples, 0);

  // create sampler with replacement = false
  WeightedRandomSamplerRT m_sampler(weights, num_samples, false);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<uint64_t> out;

  // First pass.
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      // Guard the freq[] write below against an out-of-range id.
      ASSERT_LT(*it, total_samples);
      out.push_back(*it);
      freq[*it]++;
    }
  }
  ASSERT_EQ(num_samples, out.size());

  // Without replacement, each sample only drawn once.
  for (const uint64_t count : freq) {
    ASSERT_LE(count, 1u);
  }

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());

  // Second pass after a reset, starting from cleared bookkeeping.
  m_sampler.ResetSampler();
  out.clear();
  freq.assign(total_samples, 0);
  MS_LOG(INFO) << "Resetting sampler";

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      ASSERT_LT(*it, total_samples);
      out.push_back(*it);
      freq[*it]++;
    }
  }
  ASSERT_EQ(num_samples, out.size());

  // Without replacement, each sample only drawn once.
  for (const uint64_t count : freq) {
    ASSERT_LE(count, 1u);
  }

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
282