/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

namespace mindspore {
namespace dataset {

// Forward declare
class SamplerObj;

/// \class Sampler samplers.h
/// \brief An abstract base class to represent a sampler in the data pipeline.
/// \note The enable_shared_from_this base must be inherited publicly: std::shared_ptr only initializes the
///     internal weak reference when that base is accessible, otherwise shared_from_this() cannot be used.
class Sampler : public std::enable_shared_from_this<Sampler> {
  // The classes and the function below are granted access to the protected members (Parse(), children_).
  friend class AlbumDataset;
  friend class CelebADataset;
  friend class Cifar10Dataset;
  friend class Cifar100Dataset;
  friend class CityscapesDataset;
  friend class CLUEDataset;
  friend class CocoDataset;
  friend class CSVDataset;
  friend class DIV2KDataset;
  friend class FlickrDataset;
  friend class ImageFolderDataset;
  friend class ManifestDataset;
  friend class MindDataDataset;
  friend class MnistDataset;
  friend class RandomDataDataset;
  friend class SBUDataset;
  friend class TextFileDataset;
  friend class TFRecordDataset;
  friend class USPSDataset;
  friend class VOCDataset;
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor
  Sampler() = default;

  /// \brief Destructor
  /// \note Virtual so that a derived sampler destroyed through a Sampler pointer is cleaned up correctly.
  virtual ~Sampler() = default;

  /// \brief A virtual function to add a child sampler.
  /// \param[in] child The child sampler to be added as a child of this sampler.
  virtual void AddChild(std::shared_ptr<Sampler> child) {
    // The parameter is already a private copy, so hand it to the vector instead of copying it again.
    children_.push_back(std::move(child));
  }

 protected:
  /// \brief Pure virtual function to convert a Sampler class into an IR Sampler object.
  /// \return shared pointer to the newly created SamplerObj.
  virtual std::shared_ptr<SamplerObj> Parse() const = 0;

  /// Child samplers registered through AddChild(), in insertion order.
  std::vector<std::shared_ptr<Sampler>> children_;
};

/// \brief A class to represent a Distributed Sampler in the data pipeline.
/// \note A Sampler that accesses a shard of the dataset.
class DistributedSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor
  /// \param[in] num_shards Number of shards to divide the dataset into.
  /// \param[in] shard_id Shard ID of the current shard within num_shards.
  /// \param[in] shuffle If true, the indices are shuffled (default=true).
  /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
  /// \param[in] seed The seed in use when shuffle is true (default=1).
  /// \param[in] offset The starting position where access to elements in the dataset begins (default=-1).
  /// \param[in] even_dist If true, each shard would return the same number of rows (default=true).
  ///     If false the total rows returned by all the shards would not have overlap.
  DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
                     uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);

  /// \brief Destructor.
  ~DistributedSampler() override = default;

 protected:
  /// \brief The function to convert a Sampler into an IR SamplerObj.
  /// \return shared pointer to the newly created SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  int64_t num_shards_;
  int64_t shard_id_;
  bool shuffle_;
  int64_t num_samples_;
  uint32_t seed_;
  int64_t offset_;
  bool even_dist_;
};

/// \brief A class to represent a PK Sampler in the data pipeline.
/// \note Samples K elements for each P class in the dataset.
///     This will sample all classes.
class PKSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor
  /// \param[in] num_val Number of elements to sample for each class.
  /// \param[in] shuffle If true, the class IDs are shuffled (default=false).
  /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
  explicit PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);

  /// \brief Destructor.
  ~PKSampler() override = default;

 protected:
  /// \brief The function to convert a Sampler into an IR SamplerObj.
  /// \return shared pointer to the newly created SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  int64_t num_val_;
  bool shuffle_;
  int64_t num_samples_;
};

/// \brief A class to represent a Random Sampler in the data pipeline.
/// \note Samples the elements randomly.
class RandomSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor
  /// \param[in] replacement If true, put the sample ID back for the next draw (default=false).
  /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
  explicit RandomSampler(bool replacement = false, int64_t num_samples = 0);

  /// \brief Destructor.
  ~RandomSampler() override = default;

 protected:
  /// \brief The function to convert a Sampler into an IR SamplerObj.
  /// \return shared pointer to the newly created SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  bool replacement_;
  int64_t num_samples_;
};

/// \brief A class to represent a Sequential Sampler in the data pipeline.
/// \note Samples the dataset elements sequentially, same as not having a sampler.
class SequentialSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor
  /// \param[in] start_index Index to start sampling at (default=0, start at first id).
  /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
  explicit SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);

  /// \brief Destructor.
  ~SequentialSampler() override = default;

 protected:
  /// \brief The function to convert a Sampler into an IR SamplerObj.
  /// \return shared pointer to the newly created SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  int64_t start_index_;
  int64_t num_samples_;
};

/// \brief A class to represent a Subset Sampler in the data pipeline.
/// \note Samples the elements from a sequence of indices.
class SubsetSampler : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor
  /// \param[in] indices A vector sequence of indices.
  /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
  explicit SubsetSampler(std::vector<int64_t> indices, int64_t num_samples = 0);

  /// \brief Destructor.
  ~SubsetSampler() override = default;

 protected:
  /// \brief The function to convert a Sampler into an IR SamplerObj.
  /// \return shared pointer to the newly created SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

  // Kept protected (not private) so that the derived SubsetRandomSampler can read them.
  std::vector<int64_t> indices_;
  int64_t num_samples_;
};

/// \brief A class to represent a Subset Random Sampler in the data pipeline.
/// \note Samples the elements randomly from a sequence of indices.
class SubsetRandomSampler final : public SubsetSampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor
  /// \param[in] indices A vector sequence of indices.
  /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
  explicit SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);

  /// \brief Destructor.
  ~SubsetRandomSampler() override = default;

 protected:
  /// \brief The function to convert a Sampler into an IR SamplerObj.
  /// \return shared pointer to the newly created SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;
};

/// \brief A class to represent a Weighted Random Sampler in the data pipeline.
/// \note Samples the elements from [0, len(weights) - 1] randomly with the given
///     weights (probabilities).
class WeightedRandomSampler final : public Sampler {
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor
  /// \param[in] weights A vector sequence of weights, not necessarily summing up to 1.
  /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
  /// \param[in] replacement If true, put the sample ID back for the next draw (default=true).
  explicit WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);

  /// \brief Destructor.
  ~WeightedRandomSampler() override = default;

 protected:
  /// \brief The function to convert a Sampler into an IR SamplerObj.
  /// \return shared pointer to the newly created SamplerObj.
  std::shared_ptr<SamplerObj> Parse() const override;

 private:
  std::vector<double> weights_;
  int64_t num_samples_;
  bool replacement_;
};

}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_