1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "common/common.h"
17 #include "gtest/gtest.h"
18
19 #include "minddata/dataset/include/dataset/constants.h"
20 #include "minddata/dataset/core/tensor.h"
21
22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
24 #include "utils/log_adapter.h"
25
26 #include <vector>
27 #include <unordered_set>
28
29 using namespace mindspore::dataset;
30 using mindspore::LogStream;
31 using mindspore::ExceptionType::NoExceptionType;
32 using mindspore::MsLogLevel::INFO;
33
34 class MindDataTestDistributedSampler : public UT::Common {
35 public:
36 class DummyRandomAccessOp : public RandomAccessOp {
37 public:
DummyRandomAccessOp(uint64_t num_rows)38 DummyRandomAccessOp(uint64_t num_rows) {
39 // row count is in base class as protected member
40 // GetNumRowsInDataset does not need an override, the default from base class is fine.
41 num_rows_ = num_rows;
42 }
43 };
44 };
45
// Shard 0 of 2 over a 7-row dataset without shuffling: the first batch is
// expected to hold 4 distinct, in-range row ids, and the next call yields EOE.
TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) {
  // Number of rows in the dummy dataset; also passed to the sampler as its sample count.
  const uint64_t num_samples = 7;

  // NOTE(review): arguments are presumably (num_dev=2, dev_id=0, shuffle=false,
  // num_samples, seed=0, offset=-1, even_dist=false) -- confirm against
  // distributed_sampler.h.
  DistributedSamplerRT m_sampler(2, 0, false, num_samples, 0, -1, false);
  DummyRandomAccessOp dummyRandomAccessOp(num_samples);
  ASSERT_EQ(m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp), Status::OK());

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      out.push_back(*it);
    }
  }

  ASSERT_EQ(4u, out.size());

  // Every sampled id must address an existing row, and a shard must not repeat rows.
  for (const uint64_t id : out) {
    ASSERT_LT(id, num_samples);
  }
  const std::unordered_set<uint64_t> unique_ids(out.begin(), out.end());
  ASSERT_EQ(unique_ids.size(), out.size());

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
69
// Shard 1 of 2 over a 7-row dataset without shuffling: the first batch is
// expected to hold 3 distinct, in-range row ids, and the next call yields EOE.
TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) {
  // Number of rows in the dummy dataset; also passed to the sampler as its sample count.
  const uint64_t num_samples = 7;

  // NOTE(review): arguments are presumably (num_dev=2, dev_id=1, shuffle=false,
  // num_samples, seed=0, offset=-1, even_dist=false) -- confirm against
  // distributed_sampler.h.
  DistributedSamplerRT m_sampler(2, 1, false, num_samples, 0, -1, false);
  DummyRandomAccessOp dummyRandomAccessOp(num_samples);
  ASSERT_EQ(m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp), Status::OK());

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      out.push_back(*it);
    }
  }

  ASSERT_EQ(3u, out.size());

  // Every sampled id must address an existing row, and a shard must not repeat rows.
  for (const uint64_t id : out) {
    ASSERT_LT(id, num_samples);
  }
  const std::unordered_set<uint64_t> unique_ids(out.begin(), out.end());
  ASSERT_EQ(unique_ids.size(), out.size());

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
94
// Shard 2 of 3 over a 2-row dataset without shuffling: there are fewer rows
// than shards, so this shard is expected to receive no row ids, followed by EOE.
TEST_F(MindDataTestDistributedSampler, TestThreeShards) {
  // Number of rows in the dummy dataset; also passed to the sampler as its sample count.
  const uint64_t num_samples = 2;

  // NOTE(review): arguments are presumably (num_dev=3, dev_id=2, shuffle=false,
  // num_samples, seed=0, offset=-1, even_dist=false) -- confirm against
  // distributed_sampler.h.
  DistributedSamplerRT m_sampler(3, 2, false, num_samples, 0, -1, false);
  DummyRandomAccessOp dummyRandomAccessOp(num_samples);
  ASSERT_EQ(m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp), Status::OK());

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      out.push_back(*it);
    }
  }

  ASSERT_EQ(0u, out.size());

  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_TRUE(row.eoe());
}
119