1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/common.h" 17 #include "gtest/gtest.h" 18 19 #include "minddata/dataset/include/dataset/constants.h" 20 #include "minddata/dataset/core/tensor.h" 21 22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 23 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" 24 #include "utils/log_adapter.h" 25 26 #include <vector> 27 #include <unordered_set> 28 29 using namespace mindspore::dataset; 30 using mindspore::LogStream; 31 using mindspore::ExceptionType::NoExceptionType; 32 using mindspore::MsLogLevel::INFO; 33 34 class MindDataTestDistributedSampler : public UT::Common { 35 public: 36 class DummyRandomAccessOp : public RandomAccessOp { 37 public: 38 DummyRandomAccessOp(uint64_t num_rows) { 39 // row count is in base class as protected member 40 // GetNumRowsInDataset does not need an override, the default from base class is fine. 41 num_rows_ = num_rows; 42 } 43 }; 44 }; 45 46 TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) { 47 // num samples to draw. 48 uint64_t num_samples = 7; 49 50 // create sampler with replacement = true 51 DistributedSamplerRT m_sampler(2, 0, false, num_samples, 0, -1, false); 52 DummyRandomAccessOp dummyRandomAccessOp(num_samples); 53 m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 54 55 TensorRow row; 56 std::vector<uint64_t> out; 57 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 58 for (const auto &t : row) { 59 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 60 out.push_back(*it); 61 } 62 } 63 64 ASSERT_EQ(4, out.size()); 65 66 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 67 ASSERT_EQ(row.eoe(), true); 68 } 69 70 TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) { 71 // num samples to draw. 72 uint64_t num_samples = 7; 73 74 // create sampler with replacement = true 75 DistributedSamplerRT m_sampler(2, 1, false, num_samples, 0, -1, false); 76 DummyRandomAccessOp dummyRandomAccessOp(num_samples); 77 m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 78 79 TensorRow row; 80 std::vector<uint64_t> out; 81 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 82 83 for (const auto &t : row) { 84 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 85 out.push_back(*it); 86 } 87 } 88 89 ASSERT_EQ(3, out.size()); 90 91 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 92 ASSERT_EQ(row.eoe(), true); 93 } 94 95 TEST_F(MindDataTestDistributedSampler, TestThreeShards) { 96 // num samples to draw. 97 uint64_t num_samples = 2; 98 99 // create sampler with replacement = true 100 DistributedSamplerRT m_sampler(3, 2, false, num_samples, 0, -1, false); 101 DummyRandomAccessOp dummyRandomAccessOp(num_samples); 102 m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); 103 104 TensorRow row; 105 std::vector<uint64_t> out; 106 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 107 108 for (const auto &t : row) { 109 for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) { 110 out.push_back(*it); 111 } 112 } 113 114 ASSERT_EQ(0, out.size()); 115 116 ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK()); 117 ASSERT_EQ(row.eoe(), true); 118 } 119