1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <fstream> 17 #include <iostream> 18 #include <memory> 19 #include <string> 20 21 #include "utils/ms_utils.h" 22 #include "common/common.h" 23 #include "minddata/dataset/core/client.h" 24 #include "minddata/dataset/core/global_context.h" 25 #include "minddata/dataset/engine/datasetops/source/mnist_op.h" 26 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" 27 #include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" 28 #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" 29 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 30 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" 31 #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" 32 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" 33 #include "minddata/dataset/include/dataset/datasets.h" 34 #include "minddata/dataset/util/path.h" 35 #include "minddata/dataset/util/status.h" 36 #include "gtest/gtest.h" 37 #include "utils/log_adapter.h" 38 #include "securec.h" 39 40 namespace common = mindspore::common; 41 using namespace mindspore::dataset; 42 using mindspore::LogStream; 43 using mindspore::ExceptionType::NoExceptionType; 44 using mindspore::MsLogLevel::ERROR; 45 46 std::shared_ptr<RepeatOp> Repeat(int repeat_cnt); 47 48 std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); 49 50 class MindDataTestMnistSampler : public UT::DatasetOpTesting { 51 protected: 52 }; 53 54 TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) { 55 std::string folder_path = datasets_root_path_ + "/testMnistData/"; 56 int64_t num_samples = 10; 57 int64_t start_index = 0; 58 std::shared_ptr<Dataset> ds = 59 Mnist(folder_path, "all", std::make_shared<SequentialSampler>(start_index, num_samples)); 60 EXPECT_NE(ds, nullptr); 61 ds = ds->Repeat(2); 62 EXPECT_NE(ds, nullptr); 63 std::shared_ptr<Iterator> iter = ds->CreateIterator(); 64 EXPECT_NE(iter, nullptr); 65 uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; 66 std::unordered_map<std::string, mindspore::MSTensor> row; 67 ASSERT_OK(iter->GetNextRow(&row)); 68 uint32_t label_idx; 69 uint64_t i = 0; 70 while (row.size() != 0) { 71 auto image = row["image"]; 72 auto label = row["label"]; 73 MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); 74 std::shared_ptr<Tensor> de_label; 75 ASSERT_OK(Tensor::CreateFromMSTensor(label, &de_label)); 76 ASSERT_OK(de_label->GetItemAt<uint32_t>(&label_idx, {})); 77 MS_LOG(INFO) << "Tensor label value: " << label_idx; 78 EXPECT_EQ(label_idx, res[i % 10]); 79 ASSERT_OK(iter->GetNextRow(&row)); 80 i++; 81 } 82 83 EXPECT_EQ(i, 20); 84 iter->Stop(); 85 } 86 87 TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) { 88 std::string folder_path = datasets_root_path_ + "/testMnistData/"; 89 int64_t num_samples = 10; 90 int64_t start_index = 0; 91 std::shared_ptr<Dataset> ds = 92 Mnist(folder_path, "all", std::make_shared<SequentialSampler>(start_index, num_samples)); 93 EXPECT_NE(ds, nullptr); 94 ds = ds->Repeat(2); 95 EXPECT_NE(ds, nullptr); 96 ds = ds->Batch(5); 97 EXPECT_NE(ds, nullptr); 98 std::shared_ptr<Iterator> iter = ds->CreateIterator(); 99 EXPECT_NE(iter, nullptr); 100 std::vector<std::vector<uint32_t>> expected = {{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}}; 101 std::unordered_map<std::string, mindspore::MSTensor> row; 102 ASSERT_OK(iter->GetNextRow(&row)); 103 uint64_t i = 0; 104 while (row.size() != 0) { 105 auto image = row["image"]; 106 auto label = row["label"]; 107 MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); 108 TEST_MS_LOG_MSTENSOR(INFO, "Tensor label: ", label); 109 std::shared_ptr<Tensor> de_expected_label; 110 ASSERT_OK(Tensor::CreateFromVector(expected[i % 4], &de_expected_label)); 111 mindspore::MSTensor expected_label = 112 mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_label)); 113 EXPECT_MSTENSOR_EQ(label, expected_label); 114 ASSERT_OK(iter->GetNextRow(&row)); 115 i++; 116 } 117 EXPECT_EQ(i, 4); 118 iter->Stop(); 119 } 120