1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "common/common.h"
17 #include "gtest/gtest.h"
18
19 #include "minddata/dataset/include/dataset/constants.h"
20 #include "minddata/dataset/core/tensor.h"
21
22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
24
25 #include <vector>
26 #include <unordered_set>
27
28 using namespace mindspore::dataset;
29
30 class MindDataTestSubsetSampler : public UT::Common {
31 public:
32 class DummyRandomAccessOp : public RandomAccessOp {
33 public:
DummyRandomAccessOp(int64_t num_rows)34 DummyRandomAccessOp(int64_t num_rows) {
35 num_rows_ = num_rows; // base class
36 };
37 };
38 };
39
/// Feature: SubsetSamplerRT
/// Description: Build a sampler over the ids {3, 1, 4, 0, 1} with num_samples = 0 and no explicit
///     samples-per-tensor argument, handshake it with a 5-row dummy op, then fetch two rows.
/// Expectation: The first row carries exactly the input ids in their original order (duplicates
///     included); the second fetch succeeds and yields an end-of-epoch row.
TEST_F(MindDataTestSubsetSampler, TestAllAtOnce) {
  std::vector<int64_t> in({3, 1, 4, 0, 1});
  int64_t num_samples = 0;
  SubsetSamplerRT sampler(in, num_samples);

  DummyRandomAccessOp dummyRandomAccessOp(5);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());

  // Flatten every tensor of the row into a plain vector of sampled ids.
  std::vector<int64_t> out;
  out.reserve(in.size());
  for (const auto &t : row) {
    for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); ++it) {
      out.push_back(*it);
    }
  }

  // Comparing the vectors checks both the length and every element, and avoids
  // the signed/unsigned index comparison of a hand-written loop.
  ASSERT_EQ(in, out);

  // Everything was delivered in the first row, so the next one must be EOE.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
66
/// Feature: SubsetSamplerRT
/// Description: Build a sampler over 99995 ids (all equal to 1) with num_samples = 0 and
///     10 samples per tensor, then keep fetching rows until an end-of-epoch row arrives.
/// Expectation: Every fetch succeeds, the number of non-EOE rows equals
///     ceil(total_samples / samples_per_tensor), and the collected ids equal the input ids.
TEST_F(MindDataTestSubsetSampler, TestGetNextSample) {
  int64_t total_samples = 100000 - 5;  // deliberately not a multiple of samples_per_tensor
  int64_t samples_per_tensor = 10;
  int64_t num_samples = 0;
  std::vector<int64_t> input(total_samples, 1);
  SubsetSamplerRT sampler(input, num_samples, samples_per_tensor);

  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<int64_t> out;
  out.reserve(input.size());

  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  // Counts the rows fetched before EOE (one row per GetNextSample call).
  int64_t num_rows_fetched = 0;
  while (!row.eoe()) {
    ++num_rows_fetched;

    for (const auto &t : row) {
      for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); ++it) {
        out.push_back(*it);
      }
    }

    ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  }

  // Ceiling division: the last row is allowed to be partially filled.
  ASSERT_EQ(num_rows_fetched, (total_samples + samples_per_tensor - 1) / samples_per_tensor);
  // Check the sampled values as well as their count.
  ASSERT_EQ(input, out);
}
97
/// Feature: SubsetSamplerRT
/// Description: Build a sampler over the ids {0, 1, 2, 3, 4} with num_samples = 0, fetch one row,
///     call ResetSampler(), then fetch two more rows.
/// Expectation: The row fetched before the reset and the first row fetched after it both carry
///     the input ids in order; the row after that is an end-of-epoch row.
TEST_F(MindDataTestSubsetSampler, TestReset) {
  std::vector<int64_t> in({0, 1, 2, 3, 4});
  int64_t num_samples = 0;
  SubsetSamplerRT sampler(in, num_samples);

  DummyRandomAccessOp dummyRandomAccessOp(5);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<int64_t> out;
  out.reserve(in.size());

  // First pass: all ids arrive in a single row.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  for (const auto &t : row) {
    for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); ++it) {
      out.push_back(*it);
    }
  }
  ASSERT_EQ(in, out);

  // NOTE(review): assumes ResetSampler() returns Status like GetNextSample() does;
  // asserting on it keeps a failed reset from being silently ignored.
  ASSERT_EQ(sampler.ResetSampler(), Status::OK());

  // Second pass: after the reset the sampler must start over rather than report EOE.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  ASSERT_FALSE(row.eoe());

  out.clear();
  for (const auto &t : row) {
    for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); ++it) {
      out.push_back(*it);
    }
  }
  ASSERT_EQ(in, out);

  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
141