• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/common.h"
18 #include "minddata/dataset/core/client.h"
19 #include "minddata/dataset/core/global_context.h"
20 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
21 #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
23 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
24 #include "minddata/dataset/util/status.h"
25 #include "gtest/gtest.h"
26 #include "utils/log_adapter.h"
27 #include "securec.h"
28 
29 using namespace mindspore::dataset;
30 
31 Status CreateINT64Tensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements, unsigned char *data = nullptr) {
32   TensorShape shape(std::vector<int64_t>(1, num_elements));
33   RETURN_IF_NOT_OK(Tensor::CreateFromMemory(shape, DataType(DataType::DE_INT64), data, sample_ids));
34 
35   return Status::OK();
36 }
37 
38 class MindDataTestStandAloneSampler : public UT::DatasetOpTesting {
39  protected:
40   class MockStorageOp : public RandomAccessOp {
41    public:
42     MockStorageOp(int64_t val) {
43       // row count is in base class as protected member
44       // GetNumRowsInDataset does not need an override, the default from base class is fine.
45       num_rows_ = val;
46     }
47   };
48 };
49 
50 TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) {
51   std::vector<std::shared_ptr<Tensor>> row;
52   uint64_t res[6][7] = {{0, 3, 6, 9, 12, 15, 18},  {1, 4, 7, 10, 13, 16, 19}, {2, 5, 8, 11, 14, 17, 0},
53                         {0, 17, 4, 10, 14, 8, 15}, {13, 9, 16, 3, 2, 19, 12}, {1, 11, 6, 18, 7, 5, 0}};
54   for (int i = 0; i < 6; i++) {
55     std::shared_ptr<Tensor> t;
56     Tensor::CreateFromMemory(TensorShape({7}), DataType(DataType::DE_INT64), (unsigned char *)(res[i]), &t);
57     row.push_back(t);
58   }
59   MockStorageOp mock(20);
60   std::shared_ptr<Tensor> tensor;
61   int64_t num_samples = 0;
62   TensorRow sample_row;
63   for (int i = 0; i < 6; i++) {
64     std::shared_ptr<SamplerRT> sampler =
65       std::make_shared<DistributedSamplerRT>(3, i % 3, (i < 3 ? false : true), num_samples);
66     sampler->HandshakeRandomAccessOp(&mock);
67     sampler->GetNextSample(&sample_row);
68     tensor = sample_row[0];
69     MS_LOG(DEBUG) << (*tensor);
70     if (i < 3) {  // This is added due to std::shuffle()
71       EXPECT_TRUE((*tensor) == (*row[i]));
72     }
73   }
74 }
75 
76 TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) {
77   std::vector<std::shared_ptr<Tensor>> row;
78   MockStorageOp mock(5);
79   uint64_t res[5] = {0, 1, 2, 3, 4};
80   std::shared_ptr<Tensor> label1, label2;
81   CreateINT64Tensor(&label1, 3, reinterpret_cast<unsigned char *>(res));
82   CreateINT64Tensor(&label2, 2, reinterpret_cast<unsigned char *>(res + 3));
83   int64_t num_samples = 0;
84   int64_t start_index = 0;
85   std::shared_ptr<SamplerRT> sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples, 3);
86 
87   std::shared_ptr<Tensor> tensor;
88   TensorRow sample_row;
89   sampler->HandshakeRandomAccessOp(&mock);
90   sampler->GetNextSample(&sample_row);
91   tensor = sample_row[0];
92   EXPECT_TRUE((*tensor) == (*label1));
93   sampler->GetNextSample(&sample_row);
94   tensor = sample_row[0];
95   EXPECT_TRUE((*tensor) == (*label2));
96   sampler->ResetSampler();
97   sampler->GetNextSample(&sample_row);
98   tensor = sample_row[0];
99   EXPECT_TRUE((*tensor) == (*label1));
100   sampler->GetNextSample(&sample_row);
101   tensor = sample_row[0];
102   EXPECT_TRUE((*tensor) == (*label2));
103 }
104