1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "common/common.h" 18 #include "minddata/dataset/core/client.h" 19 #include "minddata/dataset/core/global_context.h" 20 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" 21 #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" 22 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 23 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" 24 #include "minddata/dataset/util/status.h" 25 #include "gtest/gtest.h" 26 #include "utils/log_adapter.h" 27 #include "securec.h" 28 29 using namespace mindspore::dataset; 30 31 Status CreateINT64Tensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements, unsigned char *data = nullptr) { 32 TensorShape shape(std::vector<int64_t>(1, num_elements)); 33 RETURN_IF_NOT_OK(Tensor::CreateFromMemory(shape, DataType(DataType::DE_INT64), data, sample_ids)); 34 35 return Status::OK(); 36 } 37 38 class MindDataTestStandAloneSampler : public UT::DatasetOpTesting { 39 protected: 40 class MockStorageOp : public RandomAccessOp { 41 public: 42 MockStorageOp(int64_t val) { 43 // row count is in base class as protected member 44 // GetNumRowsInDataset does not need an override, the default from base class is fine. 45 num_rows_ = val; 46 } 47 }; 48 }; 49 50 TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { 51 std::vector<std::shared_ptr<Tensor>> row; 52 uint64_t res[6][7] = {{0, 3, 6, 9, 12, 15, 18}, {1, 4, 7, 10, 13, 16, 19}, {2, 5, 8, 11, 14, 17, 0}, 53 {0, 17, 4, 10, 14, 8, 15}, {13, 9, 16, 3, 2, 19, 12}, {1, 11, 6, 18, 7, 5, 0}}; 54 for (int i = 0; i < 6; i++) { 55 std::shared_ptr<Tensor> t; 56 Tensor::CreateFromMemory(TensorShape({7}), DataType(DataType::DE_INT64), (unsigned char *)(res[i]), &t); 57 row.push_back(t); 58 } 59 MockStorageOp mock(20); 60 std::shared_ptr<Tensor> tensor; 61 int64_t num_samples = 0; 62 TensorRow sample_row; 63 for (int i = 0; i < 6; i++) { 64 std::shared_ptr<SamplerRT> sampler = 65 std::make_shared<DistributedSamplerRT>(3, i % 3, (i < 3 ? false : true), num_samples); 66 sampler->HandshakeRandomAccessOp(&mock); 67 sampler->GetNextSample(&sample_row); 68 tensor = sample_row[0]; 69 MS_LOG(DEBUG) << (*tensor); 70 if (i < 3) { // This is added due to std::shuffle() 71 EXPECT_TRUE((*tensor) == (*row[i])); 72 } 73 } 74 } 75 76 TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { 77 std::vector<std::shared_ptr<Tensor>> row; 78 MockStorageOp mock(5); 79 uint64_t res[5] = {0, 1, 2, 3, 4}; 80 std::shared_ptr<Tensor> label1, label2; 81 CreateINT64Tensor(&label1, 3, reinterpret_cast<unsigned char *>(res)); 82 CreateINT64Tensor(&label2, 2, reinterpret_cast<unsigned char *>(res + 3)); 83 int64_t num_samples = 0; 84 int64_t start_index = 0; 85 std::shared_ptr<SamplerRT> sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples, 3); 86 87 std::shared_ptr<Tensor> tensor; 88 TensorRow sample_row; 89 sampler->HandshakeRandomAccessOp(&mock); 90 sampler->GetNextSample(&sample_row); 91 tensor = sample_row[0]; 92 EXPECT_TRUE((*tensor) == (*label1)); 93 sampler->GetNextSample(&sample_row); 94 tensor = sample_row[0]; 95 EXPECT_TRUE((*tensor) == (*label2)); 96 sampler->ResetSampler(); 97 sampler->GetNextSample(&sample_row); 98 tensor = sample_row[0]; 99 EXPECT_TRUE((*tensor) == (*label1)); 100 sampler->GetNextSample(&sample_row); 101 tensor = sample_row[0]; 102 EXPECT_TRUE((*tensor) == (*label2)); 103 } 104