• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/common.h"
17 #include "gtest/gtest.h"
18 
19 #include "minddata/dataset/include/dataset/constants.h"
20 #include "minddata/dataset/core/tensor.h"
21 
22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
24 #include "utils/log_adapter.h"
25 
26 #include <vector>
27 #include <unordered_set>
28 
29 using namespace mindspore::dataset;
30 using mindspore::LogStream;
31 using mindspore::ExceptionType::NoExceptionType;
32 using mindspore::MsLogLevel::INFO;
33 
34 class MindDataTestWeightedRandomSampler : public UT::Common {
35  public:
36   class DummyRandomAccessOp : public RandomAccessOp {
37    public:
DummyRandomAccessOp(uint64_t num_rows)38     DummyRandomAccessOp(uint64_t num_rows) {
39       // row count is in base class as protected member
40       // GetNumRowsInDataset does not need an override, the default from base class is fine.
41       num_rows_ = num_rows;
42     }
43   };
44 };
45 
// Sampling with replacement and no samples_per_tensor argument: the test expects the
// first GetNextSample call to deliver all num_samples ids and the second call to
// deliver the end-of-epoch row.
TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) {
  // num samples to draw.
  const uint64_t num_samples = 100;
  const uint64_t total_samples = 1000;
  // Every row shares one weight. The +1 keeps it strictly positive; a draw of 0 from
  // std::rand() % 100 would otherwise turn the whole weight vector into zeros.
  std::vector<double> weights(total_samples, std::rand() % 100 + 1);

  // create sampler with replacement = true
  WeightedRandomSamplerRT m_sampler(weights, num_samples, true);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  uint64_t num_drawn = 0;
  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      const uint64_t id = *it;
      // Every sampled id must refer to one of the total_samples rows.
      ASSERT_LT(id, total_samples);
      ++num_drawn;
    }
  }

  ASSERT_EQ(num_samples, num_drawn);

  // The next fetch must signal end of epoch.
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
74 
// Sampling without replacement and no samples_per_tensor argument: the test expects
// all num_samples ids in the first GetNextSample call, no id drawn more than once,
// and the end-of-epoch row on the second call.
TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) {
  // num samples to draw.
  const uint64_t num_samples = 100;
  const uint64_t total_samples = 1000;
  // Every row shares one weight. The +1 keeps it strictly positive; a draw of 0 from
  // std::rand() % 100 would otherwise turn the whole weight vector into zeros.
  std::vector<double> weights(total_samples, std::rand() % 100 + 1);
  // freq[id] counts how many times row `id` was sampled.
  std::vector<uint64_t> freq(total_samples, 0);

  // create sampler with replacement = false
  WeightedRandomSamplerRT m_sampler(weights, num_samples, false);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  uint64_t num_drawn = 0;
  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      const uint64_t id = *it;
      // Guard the histogram access: an out-of-range id fails the test instead of
      // writing past the end of freq.
      ASSERT_LT(id, total_samples);
      ++freq[id];
      ++num_drawn;
    }
  }
  ASSERT_EQ(num_samples, num_drawn);

  // Without replacement, each sample only drawn once.
  for (const uint64_t count : freq) {
    ASSERT_LE(count, uint64_t{1});
  }

  // The next fetch must signal end of epoch.
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
109 
// Sampling with replacement in batches of samples_per_tensor ids: the test fetches
// rows until end of epoch and checks both the number of non-EOE rows and the total
// number of ids delivered.
TEST_F(MindDataTestWeightedRandomSampler, TestGetNextSampleReplacement) {
  // num samples to draw.
  const uint64_t num_samples = 100;
  const uint64_t total_samples = 1000;
  const uint64_t samples_per_tensor = 10;
  // Every row shares one weight. The +1 keeps it strictly positive; a draw of 0 from
  // std::rand() % 100 would otherwise turn the whole weight vector into zeros.
  std::vector<double> weights(total_samples, std::rand() % 100 + 1);

  // create sampler with replacement = true
  WeightedRandomSamplerRT m_sampler(weights, num_samples, true, samples_per_tensor);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  // Unsigned counters so the final comparisons do not mix signed and unsigned values.
  uint64_t num_fetches = 0;
  uint64_t num_drawn = 0;
  while (!row.eoe()) {
    ++num_fetches;

    for (const auto &t : row) {
      for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
        const uint64_t id = *it;
        // Every sampled id must refer to one of the total_samples rows.
        ASSERT_LT(id, total_samples);
        ++num_drawn;
      }
    }

    ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  }

  // Expected fetch count is ceil(num_samples / samples_per_tensor).
  const uint64_t expected_fetches = (num_samples + samples_per_tensor - 1) / samples_per_tensor;
  ASSERT_EQ(num_fetches, expected_fetches);
  ASSERT_EQ(num_samples, num_drawn);
}
141 
// Sampling without replacement in batches of samples_per_tensor ids, with two rows
// given weight 0: the test fetches rows until end of epoch and checks that no id
// repeats, plus the number of non-EOE rows and the total number of ids delivered.
TEST_F(MindDataTestWeightedRandomSampler, TestGetNextSampleNoReplacement) {
  // num samples to draw.
  const uint64_t num_samples = 100;
  const uint64_t total_samples = 100;
  const uint64_t samples_per_tensor = 10;
  // Every row starts with the same weight. The +1 keeps it strictly positive; a draw
  // of 0 from std::rand() % 100 would otherwise turn the whole weight vector into zeros.
  std::vector<double> weights(total_samples, std::rand() % 100 + 1);
  // Two rows are deliberately given zero weight.
  weights[1] = 0;
  weights[2] = 0;
  // freq[id] counts how many times row `id` was sampled.
  std::vector<uint64_t> freq(total_samples, 0);

  // create sampler with replacement = false
  WeightedRandomSamplerRT m_sampler(weights, num_samples, false, samples_per_tensor);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  // Unsigned counters so the final comparisons do not mix signed and unsigned values.
  uint64_t num_fetches = 0;
  uint64_t num_drawn = 0;
  while (!row.eoe()) {
    ++num_fetches;

    for (const auto &t : row) {
      for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
        const uint64_t id = *it;
        // Guard the histogram access: an out-of-range id fails the test instead of
        // writing past the end of freq.
        ASSERT_LT(id, total_samples);
        ++freq[id];
        ++num_drawn;
      }
    }

    ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  }

  // Without replacement, each sample only drawn once.
  for (const uint64_t count : freq) {
    ASSERT_LE(count, uint64_t{1});
  }

  // Expected fetch count is ceil(num_samples / samples_per_tensor).
  const uint64_t expected_fetches = (num_samples + samples_per_tensor - 1) / samples_per_tensor;
  ASSERT_EQ(num_fetches, expected_fetches);
  ASSERT_EQ(num_samples, num_drawn);
}
184 
// Sampling with replacement across a reset: the test runs two epochs separated by
// ResetSampler() and expects each epoch to deliver num_samples ids in one fetch,
// followed by the end-of-epoch row.
TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) {
  // num samples to draw.
  const uint64_t num_samples = 1000000;
  const uint64_t total_samples = 1000000;
  // Every row shares one weight. The +1 keeps it strictly positive; a draw of 0 from
  // std::rand() % 100 would otherwise turn the whole weight vector into zeros.
  std::vector<double> weights(total_samples, std::rand() % 100 + 1);

  // create sampler with replacement = true
  WeightedRandomSamplerRT m_sampler(weights, num_samples, true);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  constexpr int kNumEpochs = 2;
  for (int epoch = 0; epoch < kNumEpochs; ++epoch) {
    // The sampler is reset between epochs, not before the first one.
    if (epoch > 0) {
      m_sampler.ResetSampler();
    }

    ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

    // Only the count is verified here, so the ids are tallied rather than stored.
    uint64_t num_drawn = 0;
    for (const auto &t : row) {
      for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
        const uint64_t id = *it;
        // Every sampled id must refer to one of the total_samples rows.
        ASSERT_LT(id, total_samples);
        ++num_drawn;
      }
    }
    ASSERT_EQ(num_samples, num_drawn);

    // The next fetch must signal end of epoch.
    ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
    ASSERT_TRUE(row.eoe());
  }
}
228 
// Sampling without replacement across a reset: the test runs two epochs separated by
// ResetSampler() and expects each epoch to deliver num_samples ids in one fetch with
// no id repeated inside that epoch, followed by the end-of-epoch row.
TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) {
  // num samples to draw.
  const uint64_t num_samples = 1000000;
  const uint64_t total_samples = 1000000;
  // Every row shares one weight. The +1 keeps it strictly positive; a draw of 0 from
  // std::rand() % 100 would otherwise turn the whole weight vector into zeros.
  std::vector<double> weights(total_samples, std::rand() % 100 + 1);
  // freq[id] counts how many times row `id` was sampled in the current epoch.
  std::vector<uint64_t> freq(total_samples, 0);

  // create sampler with replacement = false
  WeightedRandomSamplerRT m_sampler(weights, num_samples, false);
  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  constexpr int kNumEpochs = 2;
  for (int epoch = 0; epoch < kNumEpochs; ++epoch) {
    if (epoch > 0) {
      // Start the next epoch from a clean sampler and a zeroed histogram.
      m_sampler.ResetSampler();
      freq.assign(total_samples, 0);
      MS_LOG(INFO) << "Resetting sampler";
    }

    ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

    uint64_t num_drawn = 0;
    for (const auto &t : row) {
      for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
        const uint64_t id = *it;
        // Guard the histogram access: an out-of-range id fails the test instead of
        // writing past the end of freq.
        ASSERT_LT(id, total_samples);
        ++freq[id];
        ++num_drawn;
      }
    }
    ASSERT_EQ(num_samples, num_drawn);

    // Without replacement, each sample only drawn once.
    for (const uint64_t count : freq) {
      ASSERT_LE(count, uint64_t{1});
    }

    // The next fetch must signal end of epoch.
    ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
    ASSERT_TRUE(row.eoe());
  }
}
282