1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/common.h" 17 #include "gtest/gtest.h" 18 19 #include "minddata/dataset/include/dataset/constants.h" 20 #include "minddata/dataset/core/tensor.h" 21 22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 23 #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" 24 25 #include <vector> 26 #include <unordered_set> 27 28 using namespace mindspore::dataset; 29 30 class MindDataTestSubsetRandomSampler : public UT::Common { 31 public: 32 class DummyRandomAccessOp : public RandomAccessOp { 33 public: 34 DummyRandomAccessOp(int64_t num_rows) { 35 num_rows_ = num_rows; // base class 36 }; 37 }; 38 }; 39 40 TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { 41 std::vector<int64_t> in({0, 1, 2, 3, 4}); 42 std::unordered_set<int64_t> in_set(in.begin(), in.end()); 43 int64_t num_samples = 0; 44 SubsetRandomSamplerRT sampler(in, num_samples); 45 46 DummyRandomAccessOp dummyRandomAccessOp(5); 47 sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 48 49 TensorRow row; 50 std::vector<int64_t> out; 51 ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); 52 53 for (const auto &t : row) { 54 for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { 55 out.push_back(*it); 56 } 57 } 58 ASSERT_EQ(in.size(), out.size()); 59 for (int i = 0; i < in.size(); i++) { 60 ASSERT_NE(in_set.find(out[i]), in_set.end()); 61 } 62 63 ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); 64 ASSERT_EQ(row.eoe(), true); 65 } 66 67 TEST_F(MindDataTestSubsetRandomSampler, TestGetNextSample) { 68 int64_t total_samples = 100000 - 5; 69 int64_t samples_per_tensor = 10; 70 int64_t num_samples = 0; 71 std::vector<int64_t> input(total_samples, 1); 72 SubsetRandomSamplerRT sampler(input, num_samples, samples_per_tensor); 73 74 DummyRandomAccessOp dummyRandomAccessOp(total_samples); 75 sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 76 77 TensorRow row; 78 std::vector<int64_t> out; 79 80 ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); 81 int epoch = 0; 82 while (!row.eoe()) { 83 epoch++; 84 for (const auto &t : row) { 85 for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { 86 out.push_back(*it); 87 } 88 } 89 90 ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); 91 } 92 93 ASSERT_EQ(epoch, (total_samples + samples_per_tensor - 1) / samples_per_tensor); 94 ASSERT_EQ(input.size(), out.size()); 95 } 96 97 TEST_F(MindDataTestSubsetRandomSampler, TestReset) { 98 std::vector<int64_t> in({0, 1, 2, 3, 4}); 99 std::unordered_set<int64_t> in_set(in.begin(), in.end()); 100 int64_t num_samples = 0; 101 SubsetRandomSamplerRT sampler(in, num_samples); 102 103 DummyRandomAccessOp dummyRandomAccessOp(5); 104 sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 105 106 TensorRow row; 107 std::vector<int64_t> out; 108 109 ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); 110 for (const auto &t : row) { 111 for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { 112 out.push_back(*it); 113 } 114 } 115 ASSERT_EQ(in.size(), out.size()); 116 for (int i = 0; i < in.size(); i++) { 117 ASSERT_NE(in_set.find(out[i]), in_set.end()); 118 } 119 120 sampler.ResetSampler(); 121 122 ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); 123 ASSERT_EQ(row.eoe(), false); 124 out.clear(); 125 for (const auto &t : row) { 126 for (auto it = t->begin<int64_t>(); it != t->end<int64_t>(); it++) { 127 out.push_back(*it); 128 } 129 } 130 ASSERT_EQ(in.size(), out.size()); 131 for (int i = 0; i < in.size(); i++) { 132 ASSERT_NE(in_set.find(out[i]), in_set.end()); 133 } 134 135 ASSERT_EQ(sampler.GetNextSample(&row), Status::OK()); 136 ASSERT_EQ(row.eoe(), true); 137 } 138