1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/common.h" 17 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 18 #include "minddata/dataset/engine/ir/datasetops/source/samplers/distributed_sampler_ir.h" 19 #include "minddata/dataset/engine/ir/datasetops/source/samplers/pk_sampler_ir.h" 20 #include "minddata/dataset/engine/ir/datasetops/source/samplers/prebuilt_sampler_ir.h" 21 #include "minddata/dataset/engine/ir/datasetops/source/samplers/random_sampler_ir.h" 22 #include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h" 23 #include "minddata/dataset/engine/ir/datasetops/source/samplers/sequential_sampler_ir.h" 24 #include "minddata/dataset/engine/ir/datasetops/source/samplers/subset_random_sampler_ir.h" 25 #include "minddata/dataset/engine/ir/datasetops/source/samplers/subset_sampler_ir.h" 26 #include "minddata/dataset/engine/ir/datasetops/source/samplers/weighted_random_sampler_ir.h" 27 #include "minddata/dataset/core/tensor.h" 28 29 using namespace mindspore::dataset; 30 using mindspore::dataset::Tensor; 31 32 class MindDataTestIrSampler : public UT::DatasetOpTesting { 33 protected: 34 }; 35 36 TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) { 37 int64_t num_rows = 30; // dummy variable for number of rows in the dataset 38 std::shared_ptr<SamplerObj> sampl = std::make_shared<DistributedSamplerObj>(2, 1, false, 6, 1, -1, true); 39 EXPECT_NE(sampl, nullptr); 40 std::shared_ptr<SamplerRT> sampler_rt; 41 sampl->SamplerBuild(&sampler_rt); 42 EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6); 43 44 sampl = std::make_shared<PKSamplerObj>(3, false, 0); 45 EXPECT_NE(sampl, nullptr); 46 sampl->SamplerBuild(&sampler_rt); 47 EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), -1); 48 49 sampl = std::make_shared<RandomSamplerObj>(false, 12); 50 EXPECT_NE(sampl, nullptr); 51 sampl->SamplerBuild(&sampler_rt); 52 EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); 53 54 sampl = std::make_shared<SequentialSamplerObj>(0, 10); 55 EXPECT_NE(sampl, nullptr); 56 sampl->SamplerBuild(&sampler_rt); 57 EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10); 58 59 std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1}; 60 sampl = std::make_shared<WeightedRandomSamplerObj>(weights, 12); 61 EXPECT_NE(sampl, nullptr); 62 sampl->SamplerBuild(&sampler_rt); 63 EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12); 64 65 std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21}; 66 sampl = std::make_shared<SubsetRandomSamplerObj>(indices, 11); 67 EXPECT_NE(sampl, nullptr); 68 sampl->SamplerBuild(&sampler_rt); 69 EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11); 70 71 // Testing chains 72 // Parent and child have num_samples 73 std::shared_ptr<SamplerObj> sampl1 = std::make_shared<WeightedRandomSamplerObj>(weights, 12); 74 EXPECT_NE(sampl1, nullptr); 75 std::shared_ptr<SamplerRT> sampler_rt1; 76 sampl1->SamplerBuild(&sampler_rt1); 77 78 std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SequentialSamplerObj>(0, 10); 79 EXPECT_NE(sampl2, nullptr); 80 std::shared_ptr<SamplerRT> sampler_rt2; 81 sampl2->SamplerBuild(&sampler_rt2); 82 sampler_rt2->AddChild(sampler_rt1); 83 EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10); 84 85 // Parent doesn't have num_samples 86 std::shared_ptr<SamplerObj> sampl3 = std::make_shared<WeightedRandomSamplerObj>(weights, 12); 87 EXPECT_NE(sampl3, nullptr); 88 std::shared_ptr<SamplerRT> sampler_rt3; 89 sampl3->SamplerBuild(&sampler_rt3); 90 91 std::shared_ptr<SamplerObj> sampl4 = std::make_shared<SubsetRandomSamplerObj>(indices, 0); 92 EXPECT_NE(sampl4, nullptr); 93 std::shared_ptr<SamplerRT> sampler_rt4; 94 sampl4->SamplerBuild(&sampler_rt4); 95 sampler_rt4->AddChild(sampler_rt3); 96 EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 11); 97 98 // Child doesn't have num_samples 99 std::shared_ptr<SamplerObj> sampl5 = std::make_shared<RandomSamplerObj>(false, 0); 100 EXPECT_NE(sampl5, nullptr); 101 std::shared_ptr<SamplerRT> sampler_rt5; 102 sampl5->SamplerBuild(&sampler_rt5); 103 104 std::shared_ptr<SamplerObj> sampl6 = std::make_shared<PKSamplerObj>(3, false, 7); 105 EXPECT_NE(sampl6, nullptr); 106 std::shared_ptr<SamplerRT> sampler_rt6; 107 sampl6->SamplerBuild(&sampler_rt6); 108 sampler_rt6->AddChild(sampler_rt5); 109 EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), -1); 110 } 111 112 TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) { 113 std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; 114 std::shared_ptr<SamplerObj> sampl1 = std::make_shared<SubsetRandomSamplerObj>(indices, 0); 115 EXPECT_FALSE(indices.empty()); 116 std::shared_ptr<SamplerRT> sampler_rt = nullptr; 117 sampl1->SamplerBuild(&sampler_rt); 118 EXPECT_NE(sampler_rt, nullptr); 119 std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), 0); 120 EXPECT_TRUE(indices.empty()); 121 std::shared_ptr<SamplerRT> sampler_rt2 = nullptr; 122 sampl2->SamplerBuild(&sampler_rt2); 123 EXPECT_NE(sampler_rt, nullptr); 124 } 125