• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/common.h"
17 #include "gtest/gtest.h"
18 
19 #include "minddata/dataset/include/dataset/constants.h"
20 #include "minddata/dataset/core/tensor.h"
21 
22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h"
24 
25 #include <vector>
26 #include <unordered_set>
27 
28 using namespace mindspore::dataset;
29 
30 class MindDataTestSubsetSampler : public UT::Common {
31  public:
32   class DummyRandomAccessOp : public RandomAccessOp {
33    public:
34     DummyRandomAccessOp(int64_t num_rows) {
35       num_rows_ = num_rows;  // base class
36     };
37   };
38 };
39 
40 TEST_F(MindDataTestSubsetSampler, TestAllAtOnce) {
41   std::vector<int64_t> in({3, 1, 4, 0, 1});
42   std::unordered_set<int64_t> in_set(in.begin(), in.end());
43   int64_t num_samples = 0;
44   SubsetSamplerRT sampler(in, num_samples);
45 
46   DummyRandomAccessOp dummyRandomAccessOp(5);
47   sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
48 
49   TensorRow row;
50   std::vector<int64_t> out;
51   ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
52 
53   for (const auto &t : row) {
54     for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
55       out.push_back(*it);
56     }
57   }
58   ASSERT_EQ(in.size(), out.size());
59   for (int i = 0; i < in.size(); i++) {
60     ASSERT_EQ(in[i], out[i]);
61   }
62 
63   ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
64   ASSERT_EQ(row.eoe(), true);
65 }
66 
67 TEST_F(MindDataTestSubsetSampler, TestGetNextSample) {
68   int64_t total_samples = 100000 - 5;
69   int64_t samples_per_tensor = 10;
70   int64_t num_samples = 0;
71   std::vector<int64_t> input(total_samples, 1);
72   SubsetSamplerRT sampler(input, num_samples, samples_per_tensor);
73 
74   DummyRandomAccessOp dummyRandomAccessOp(total_samples);
75   sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
76 
77   TensorRow row;
78   std::vector<int64_t> out;
79 
80   ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
81   int epoch = 0;
82   while (!row.eoe()) {
83     epoch++;
84 
85     for (const auto &t : row) {
86       for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
87         out.push_back(*it);
88       }
89     }
90 
91     ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
92   }
93 
94   ASSERT_EQ(epoch, (total_samples + samples_per_tensor - 1) / samples_per_tensor);
95   ASSERT_EQ(input.size(), out.size());
96 }
97 
98 TEST_F(MindDataTestSubsetSampler, TestReset) {
99   std::vector<int64_t> in({0, 1, 2, 3, 4});
100   std::unordered_set<int64_t> in_set(in.begin(), in.end());
101   int64_t num_samples = 0;
102   SubsetSamplerRT sampler(in, num_samples);
103 
104   DummyRandomAccessOp dummyRandomAccessOp(5);
105   sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
106 
107   TensorRow row;
108   std::vector<int64_t> out;
109 
110   ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
111 
112   for (const auto &t : row) {
113     for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
114       out.push_back(*it);
115     }
116   }
117   ASSERT_EQ(in.size(), out.size());
118   for (int i = 0; i < in.size(); i++) {
119     ASSERT_EQ(in[i], out[i]);
120   }
121 
122   sampler.ResetSampler();
123 
124   ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
125   ASSERT_EQ(row.eoe(), false);
126 
127   out.clear();
128   for (const auto &t : row) {
129     for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) {
130       out.push_back(*it);
131     }
132   }
133   ASSERT_EQ(in.size(), out.size());
134   for (int i = 0; i < in.size(); i++) {
135     ASSERT_EQ(in[i], out[i]);
136   }
137 
138   ASSERT_EQ(sampler.GetNextSample(&row), Status::OK());
139   ASSERT_EQ(row.eoe(), true);
140 }
141