1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "common/common.h"
17 #include "gtest/gtest.h"
18
19 #include "minddata/dataset/include/dataset/constants.h"
20 #include "minddata/dataset/core/tensor.h"
21
22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
24
25 #include <vector>
26 #include <unordered_set>
27
28 using namespace mindspore::dataset;
29
30 class MindDataTestSubsetRandomSampler : public UT::Common {
31 public:
32 class DummyRandomAccessOp : public RandomAccessOp {
33 public:
DummyRandomAccessOp(int64_t num_rows)34 DummyRandomAccessOp(int64_t num_rows) {
35 num_rows_ = num_rows; // base class
36 };
37 };
38 };
39
// Draws the whole 5-element subset with a single GetNextSample() call, checks
// that the ids returned are a permutation of the subset, and that the next
// call yields an EOE row.
TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
  std::vector<int64_t> in({0, 1, 2, 3, 4});
  const std::unordered_set<int64_t> in_set(in.begin(), in.end());
  // NOTE(review): 0 appears to mean "sample the whole subset" here — confirm
  // against SubsetRandomSamplerRT.
  const int64_t num_samples = 0;
  SubsetRandomSamplerRT sampler(in, num_samples);

  DummyRandomAccessOp dummyRandomAccessOp(5);
  ASSERT_EQ(sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp), Status::OK());

  TensorRow row;
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());

  // Flatten every int64 element of every tensor in the row.
  std::vector<int64_t> out;
  for (const auto &tensor : row) {
    for (auto it = tensor->begin<int64_t>(); it != tensor->end<int64_t>(); ++it) {
      out.push_back(*it);
    }
  }

  ASSERT_EQ(in.size(), out.size());
  for (const int64_t id : out) {
    ASSERT_NE(in_set.find(id), in_set.end());
  }
  // Membership plus equal size is not enough: duplicates would slip through.
  // `in` holds distinct ids, so require `out` to be distinct as well.
  ASSERT_EQ(std::unordered_set<int64_t>(out.begin(), out.end()).size(), in.size());

  // The subset has been fully consumed, so the next row must be EOE.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
66
// Samples a large subset (every entry is index 1) in chunks of at most
// `samples_per_tensor` ids per row. Checks the number of rows returned before
// EOE, the total number of ids, and that every id equals 1.
TEST_F(MindDataTestSubsetRandomSampler, TestGetNextSample) {
  // 99995 is not a multiple of samples_per_tensor, so the ceil-division check
  // below exercises a final, partial row.
  const int64_t total_samples = 100000 - 5;
  const int64_t samples_per_tensor = 10;
  const int64_t num_samples = 0;
  std::vector<int64_t> input(total_samples, 1);
  SubsetRandomSamplerRT sampler(input, num_samples, samples_per_tensor);

  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  ASSERT_EQ(sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp), Status::OK());

  TensorRow row;
  std::vector<int64_t> out;
  // Number of non-EOE rows (batches) returned within this single epoch.
  int64_t num_batches = 0;

  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  while (!row.eoe()) {
    ++num_batches;
    for (const auto &tensor : row) {
      for (auto it = tensor->begin<int64_t>(); it != tensor->end<int64_t>(); ++it) {
        out.push_back(*it);
      }
    }
    ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  }

  ASSERT_EQ(num_batches, (total_samples + samples_per_tensor - 1) / samples_per_tensor);
  ASSERT_EQ(input.size(), out.size());
  // Every requested index is 1, so every sampled id must be 1.
  for (const int64_t id : out) {
    ASSERT_EQ(id, 1);
  }
}
96
// Consumes one full row, resets the sampler, and checks that the second pass
// again yields a permutation of the subset, followed by an EOE row.
TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
  std::vector<int64_t> in({0, 1, 2, 3, 4});
  const std::unordered_set<int64_t> in_set(in.begin(), in.end());
  const int64_t num_samples = 0;
  SubsetRandomSamplerRT sampler(in, num_samples);

  DummyRandomAccessOp dummyRandomAccessOp(5);
  ASSERT_EQ(sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp), Status::OK());

  // Flattens every int64 element of every tensor in `r` into one vector.
  auto collect_ids = [](TensorRow &r) {
    std::vector<int64_t> ids;
    for (const auto &tensor : r) {
      for (auto it = tensor->begin<int64_t>(); it != tensor->end<int64_t>(); ++it) {
        ids.push_back(*it);
      }
    }
    return ids;
  };

  TensorRow row;

  // First pass: all ids belong to the subset and are distinct.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  std::vector<int64_t> out = collect_ids(row);
  ASSERT_EQ(in.size(), out.size());
  for (const int64_t id : out) {
    ASSERT_NE(in_set.find(id), in_set.end());
  }
  ASSERT_EQ(std::unordered_set<int64_t>(out.begin(), out.end()).size(), in.size());

  ASSERT_EQ(sampler.ResetSampler(), Status::OK());

  // Second pass after reset: a data row again, not EOE.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  ASSERT_FALSE(row.eoe());
  out = collect_ids(row);
  ASSERT_EQ(in.size(), out.size());
  for (const int64_t id : out) {
    ASSERT_NE(in_set.find(id), in_set.end());
  }
  ASSERT_EQ(std::unordered_set<int64_t>(out.begin(), out.end()).size(), in.size());

  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
138