• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/common.h"
17 #include "gtest/gtest.h"
18 
19 #include "minddata/dataset/include/dataset/constants.h"
20 #include "minddata/dataset/core/tensor.h"
21 
22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
24 #include "utils/log_adapter.h"
25 
26 #include <vector>
27 #include <unordered_set>
28 
29 using namespace mindspore::dataset;
30 using mindspore::LogStream;
31 using mindspore::ExceptionType::NoExceptionType;
32 using mindspore::MsLogLevel::INFO;
33 
34 class MindDataTestDistributedSampler : public UT::Common {
35  public:
36   class DummyRandomAccessOp : public RandomAccessOp {
37    public:
38     DummyRandomAccessOp(uint64_t num_rows) {
39       // row count is in base class as protected member
40       // GetNumRowsInDataset does not need an override, the default from base class is fine.
41       num_rows_ = num_rows;
42     }
43   };
44 };
45 
46 TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) {
47   // num samples to draw.
48   uint64_t num_samples = 7;
49 
50   // create sampler with replacement = true
51   DistributedSamplerRT m_sampler(2, 0, false, num_samples, 0, -1, false);
52   DummyRandomAccessOp dummyRandomAccessOp(num_samples);
53   m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
54 
55   TensorRow row;
56   std::vector<uint64_t> out;
57   ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
58   for (const auto &t : row) {
59     for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
60       out.push_back(*it);
61     }
62   }
63 
64   ASSERT_EQ(4, out.size());
65 
66   ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
67   ASSERT_EQ(row.eoe(), true);
68 }
69 
70 TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) {
71   // num samples to draw.
72   uint64_t num_samples = 7;
73 
74   // create sampler with replacement = true
75   DistributedSamplerRT m_sampler(2, 1, false, num_samples, 0, -1, false);
76   DummyRandomAccessOp dummyRandomAccessOp(num_samples);
77   m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
78 
79   TensorRow row;
80   std::vector<uint64_t> out;
81   ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
82 
83   for (const auto &t : row) {
84     for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
85       out.push_back(*it);
86     }
87   }
88 
89   ASSERT_EQ(3, out.size());
90 
91   ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
92   ASSERT_EQ(row.eoe(), true);
93 }
94 
95 TEST_F(MindDataTestDistributedSampler, TestThreeShards) {
96   // num samples to draw.
97   uint64_t num_samples = 2;
98 
99   // create sampler with replacement = true
100   DistributedSamplerRT m_sampler(3, 2, false, num_samples, 0, -1, false);
101   DummyRandomAccessOp dummyRandomAccessOp(num_samples);
102   m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp);
103 
104   TensorRow row;
105   std::vector<uint64_t> out;
106   ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
107 
108   for (const auto &t : row) {
109     for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); it++) {
110       out.push_back(*it);
111     }
112   }
113 
114   ASSERT_EQ(0, out.size());
115 
116   ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
117   ASSERT_EQ(row.eoe(), true);
118 }
119