/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "common/common.h"
17 #include "gtest/gtest.h"
18 
19 #include "minddata/dataset/include/dataset/constants.h"
20 #include "minddata/dataset/core/tensor.h"
21 
22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
24 
25 #include <vector>
26 #include <unordered_set>
27 
28 using namespace mindspore::dataset;
29 
30 class MindDataTestSubsetRandomSampler : public UT::Common {
31  public:
32   class DummyRandomAccessOp : public RandomAccessOp {
33    public:
DummyRandomAccessOp(int64_t num_rows)34     DummyRandomAccessOp(int64_t num_rows) {
35       num_rows_ = num_rows;  // base class
36     };
37   };
38 };
39 
/// Draws one sample row from a sampler built over five indices and checks that
/// it yields exactly as many ids as were supplied, that every id belongs to the
/// supplied set, and that the following row is flagged end-of-epoch.
TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) {
  std::vector<int64_t> in({0, 1, 2, 3, 4});
  std::unordered_set<int64_t> in_set(in.begin(), in.end());
  int64_t num_samples = 0;
  SubsetRandomSamplerRT sampler(in, num_samples);

  DummyRandomAccessOp dummyRandomAccessOp(5);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<int64_t> out;
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());

  // Flatten every tensor of the row into a single list of sampled ids.
  for (const auto &t : row) {
    for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); ++it) {
      out.push_back(*it);
    }
  }
  ASSERT_EQ(in.size(), out.size());
  // Sizes are equal at this point, so walking `out` directly visits the same
  // elements as an index loop bounded by in.size(), without mixing a signed
  // index with an unsigned size.
  for (const int64_t id : out) {
    ASSERT_NE(in_set.find(id), in_set.end());
  }

  // The next request must report end-of-epoch.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  ASSERT_EQ(row.eoe(), true);
}
66 
/// Pulls rows from the sampler until one is flagged end-of-epoch, then checks
/// both the number of rows fetched and the total number of ids collected.
TEST_F(MindDataTestSubsetRandomSampler, TestGetNextSample) {
  int64_t total_samples = 100000 - 5;
  int64_t samples_per_tensor = 10;
  int64_t num_samples = 0;
  std::vector<int64_t> input(total_samples, 1);
  SubsetRandomSamplerRT sampler(input, num_samples, samples_per_tensor);

  DummyRandomAccessOp dummyRandomAccessOp(total_samples);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<int64_t> collected;

  // Prime the loop with the first row; each iteration consumes the current row
  // and then requests the next one.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  int rows_seen = 0;
  for (; !row.eoe(); ++rows_seen) {
    for (const auto &tensor : row) {
      for (auto pos = tensor->begin<int64_t>(); pos != tensor->end<int64_t>(); pos++) {
        collected.push_back(*pos);
      }
    }

    ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  }

  // Ceiling division: the last row may hold fewer than samples_per_tensor ids.
  ASSERT_EQ(rows_seen, (total_samples + samples_per_tensor - 1) / samples_per_tensor);
  ASSERT_EQ(input.size(), collected.size());
}
96 
/// Draws one full row, resets the sampler, and checks that a second full row
/// (not flagged end-of-epoch) is produced before end-of-epoch is reported.
TEST_F(MindDataTestSubsetRandomSampler, TestReset) {
  std::vector<int64_t> in({0, 1, 2, 3, 4});
  std::unordered_set<int64_t> in_set(in.begin(), in.end());
  int64_t num_samples = 0;
  SubsetRandomSamplerRT sampler(in, num_samples);

  DummyRandomAccessOp dummyRandomAccessOp(5);
  sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);

  TensorRow row;
  std::vector<int64_t> out;

  // First pass: every returned id must come from the supplied set.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  for (const auto &t : row) {
    for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); ++it) {
      out.push_back(*it);
    }
  }
  ASSERT_EQ(in.size(), out.size());
  // Sizes match, so iterating `out` covers the same elements as an index loop
  // bounded by in.size() while avoiding a signed/unsigned comparison.
  for (const int64_t id : out) {
    ASSERT_NE(in_set.find(id), in_set.end());
  }

  // NOTE(review): assumes ResetSampler() returns Status, like GetNextSample();
  // checking it stops a failed reset from being silently ignored.
  ASSERT_EQ(sampler.ResetSampler(), Status::OK());

  // Second pass after the reset: a regular row is expected again.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  ASSERT_EQ(row.eoe(), false);
  out.clear();
  for (const auto &t : row) {
    for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); ++it) {
      out.push_back(*it);
    }
  }
  ASSERT_EQ(in.size(), out.size());
  for (const int64_t id : out) {
    ASSERT_NE(in_set.find(id), in_set.end());
  }

  // The next request must report end-of-epoch.
  ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
  ASSERT_EQ(row.eoe(), true);
}
138