1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_ 19 20 #include <memory> 21 #include <vector> 22 23 #include "include/api/types.h" 24 #include "include/api/status.h" 25 26 namespace mindspore { 27 namespace dataset { 28 // Forward declare 29 class SamplerObj; 30 31 // Abstract class to represent a sampler in the data pipeline. 32 /// \class Sampler samplers.h 33 /// \brief An abstract base class to represent a sampler in the data pipeline. 34 class DATASET_API Sampler : std::enable_shared_from_this<Sampler> { 35 friend class AlbumDataset; 36 friend class Caltech256Dataset; 37 friend class CelebADataset; 38 friend class Cifar10Dataset; 39 friend class Cifar100Dataset; 40 friend class CityscapesDataset; 41 friend class CLUEDataset; 42 friend class CMUArcticDataset; 43 friend class CocoDataset; 44 friend class CSVDataset; 45 friend class DIV2KDataset; 46 friend class EMnistDataset; 47 friend class FakeImageDataset; 48 friend class FashionMnistDataset; 49 friend class FlickrDataset; 50 friend class Food101Dataset; 51 friend class GTZANDataset; 52 friend class ImageFolderDataset; 53 friend class IMDBDataset; 54 friend class KITTIDataset; 55 friend class KMnistDataset; 56 friend class LFWDataset; 57 friend class LibriTTSDataset; 58 friend class LJSpeechDataset; 59 friend class LSUNDataset; 60 friend class ManifestDataset; 61 friend class MindDataDataset; 62 friend class MnistDataset; 63 friend class OmniglotDataset; 64 friend class PhotoTourDataset; 65 friend class Places365Dataset; 66 friend class QMnistDataset; 67 friend class RandomDataDataset; 68 friend class RenderedSST2Dataset; 69 friend class SBUDataset; 70 friend class SemeionDataset; 71 friend class SpeechCommandsDataset; 72 friend class SST2Dataset; 73 friend class STL10Dataset; 74 friend class SUN397Dataset; 75 friend class TedliumDataset; 76 friend class TextFileDataset; 77 friend class TFRecordDataset; 78 friend class USPSDataset; 79 friend class VOCDataset; 80 friend class WIDERFaceDataset; 81 friend class YesNoDataset; 82 friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, 83 int32_t shard_id); 84 85 public: 86 /// \brief Constructor 87 Sampler() = default; 88 89 /// \brief Destructor 90 virtual ~Sampler() = default; 91 92 /// \brief A virtual function to add a child sampler. 93 /// \param[in] child The child sampler to be added as a children of this sampler. AddChild(const std::shared_ptr<Sampler> & child)94 virtual void AddChild(const std::shared_ptr<Sampler> &child) { children_.push_back(child); } 95 96 protected: 97 /// \brief Pure virtual function to convert a Sampler class into an IR Sampler object. 98 /// \return shared pointer to the newly created TensorOperation. 99 virtual std::shared_ptr<SamplerObj> Parse() const = 0; 100 101 /// \brief A function that calls Parse() on the children of this sampler 102 /// \param[in] sampler The samplerIR object built from this sampler 103 /// \return the Status code returned 104 Status BuildChildren(std::shared_ptr<SamplerObj> *const sampler) const; 105 106 std::vector<std::shared_ptr<Sampler>> children_; 107 }; 108 109 /// \brief A class to represent a Distributed Sampler in the data pipeline. 110 /// \note A Sampler that accesses a shard of the dataset. 111 class DATASET_API DistributedSampler final : public Sampler { 112 friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, 113 int32_t shard_id); 114 115 public: 116 /// \brief Constructor 117 /// \param[in] num_shards Number of shards to divide the dataset into. 118 /// \param[in] shard_id Shard ID of the current shard within num_shards. 119 /// \param[in] shuffle If true, the indices are shuffled (default=true). 120 /// \param[in] num_samples The number of samples to draw (default=0, return all samples). 121 /// \param[in] seed The seed in use when shuffle is true (default=1). 122 /// \param[in] offset The starting position where access to elements in the dataset begins (default=-1). 123 /// \param[in] even_dist If true, each shard would return the same number of rows (default=true). 124 /// If false the total rows returned by all the shards would not have overlap. 125 /// \par Example 126 /// \code 127 /// /* creates a distributed sampler with 2 shards in total. This shard is shard 0 */ 128 /// std::string file_path = "/path/to/test.mindrecord"; 129 /// std::shared_ptr<Dataset> ds = MindData(file_path, {}, std::make_shared<DistributedSampler>(2, 0, false)); 130 /// \endcode 131 DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0, 132 uint32_t seed = 1, int64_t offset = -1, bool even_dist = true); 133 /// \brief Destructor. 134 ~DistributedSampler() override = default; 135 136 protected: 137 /// \brief The function to convert a Sampler into an IR SamplerObj. 138 /// \return shared pointer to the newly created SamplerObj. 139 std::shared_ptr<SamplerObj> Parse() const override; 140 141 private: 142 int64_t num_shards_; 143 int64_t shard_id_; 144 bool shuffle_; 145 int64_t num_samples_; 146 uint32_t seed_; 147 int64_t offset_; 148 bool even_dist_; 149 }; 150 151 /// \brief A class to represent a PK Sampler in the data pipeline. 152 /// \note Samples K elements for each P class in the dataset. 153 /// This will sample all classes. 154 class DATASET_API PKSampler final : public Sampler { 155 friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, 156 int32_t shard_id); 157 158 public: 159 /// \brief Constructor 160 /// \param[in] num_val Number of elements to sample for each class. 161 /// \param[in] shuffle If true, the class IDs are shuffled (default=false). 162 /// \param[in] num_samples The number of samples to draw (default=0, return all samples). 163 /// \par Example 164 /// \code 165 /// /* creates a PKSampler that will get 3 samples from every class. */ 166 /// std::string folder_path = "/path/to/image/folder"; 167 /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<PKSampler>(3)); 168 /// \endcode 169 explicit PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0); 170 171 /// \brief Destructor. 172 ~PKSampler() override = default; 173 174 protected: 175 /// \brief The function to convert a Sampler into an IR SamplerObj. 176 /// \return shared pointer to the newly created SamplerObj. 177 std::shared_ptr<SamplerObj> Parse() const override; 178 179 private: 180 int64_t num_val_; 181 bool shuffle_; 182 int64_t num_samples_; 183 }; 184 185 /// \brief A class to represent a Random Sampler in the data pipeline. 186 /// \note Samples the elements randomly. 187 class DATASET_API RandomSampler final : public Sampler { 188 friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, 189 int32_t shard_id); 190 191 public: 192 /// \brief Constructor 193 /// \param[in] replacement If true, put the sample ID back for the next draw (default=false). 194 /// \param[in] num_samples The number of samples to draw (default=0, return all samples). 195 /// \par Example 196 /// \code 197 /// /* creates a RandomSampler that will get 10 samples randomly */ 198 /// std::string folder_path = "/path/to/image/folder"; 199 /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10)); 200 /// \endcode 201 explicit RandomSampler(bool replacement = false, int64_t num_samples = 0); 202 203 /// \brief Destructor. 204 ~RandomSampler() override = default; 205 206 protected: 207 /// \brief The function to convert a Sampler into an IR SamplerObj. 208 /// \return shared pointer to the newly created SamplerObj. 209 std::shared_ptr<SamplerObj> Parse() const override; 210 211 private: 212 bool replacement_; 213 int64_t num_samples_; 214 }; 215 216 /// \brief A class to represent a Sequential Sampler in the data pipeline. 217 /// \note Samples the dataset elements sequentially, same as not having a sampler. 218 class DATASET_API SequentialSampler final : public Sampler { 219 friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, 220 int32_t shard_id); 221 222 public: 223 /// \brief Constructor 224 /// \param[in] start_index Index to start sampling at (default=0, start at first id). 225 /// \param[in] num_samples The number of samples to draw (default=0, return all samples). 226 /// \par Example 227 /// \code 228 /// /* creates a SequentialSampler that will get 2 samples sequentially in original dataset */ 229 /// std::string folder_path = "/path/to/image/folder"; 230 /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(0, 2)); 231 /// \endcode 232 explicit SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0); 233 234 /// \brief Destructor. 235 ~SequentialSampler() override = default; 236 237 protected: 238 /// \brief The function to convert a Sampler into an IR SamplerObj. 239 /// \return shared pointer to the newly created SamplerObj. 240 std::shared_ptr<SamplerObj> Parse() const override; 241 242 private: 243 int64_t start_index_; 244 int64_t num_samples_; 245 }; 246 247 /// \brief A class to represent a Subset Sampler in the data pipeline. 248 /// \note Samples the elements from a sequence of indices. 249 class DATASET_API SubsetSampler : public Sampler { 250 friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, 251 int32_t shard_id); 252 253 public: 254 /// \brief Constructor 255 /// \param[in] indices A vector sequence of indices. 256 /// \param[in] num_samples The number of samples to draw (default=0, return all samples). 257 /// \par Example 258 /// \code 259 /// /* creates a SubsetSampler, will sample from the provided indices */ 260 /// std::string folder_path = "/path/to/image/folder"; 261 /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetSampler>({0, 2, 5})); 262 /// \endcode 263 explicit SubsetSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0); 264 265 /// \brief Destructor. 266 ~SubsetSampler() override = default; 267 268 protected: 269 /// \brief The function to convert a Sampler into an IR SamplerObj. 270 /// \return shared pointer to the newly created SamplerObj. 271 std::shared_ptr<SamplerObj> Parse() const override; 272 273 std::vector<int64_t> indices_; 274 int64_t num_samples_; 275 }; 276 277 /// \brief A class to represent a Subset Random Sampler in the data pipeline. 278 /// \note Samples the elements randomly from a sequence of indices. 279 class DATASET_API SubsetRandomSampler final : public SubsetSampler { 280 friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, 281 int32_t shard_id); 282 283 public: 284 /// \brief Constructor 285 /// \param[in] indices A vector sequence of indices. 286 /// \param[in] num_samples The number of samples to draw (default=0, return all samples). 287 /// \par Example 288 /// \code 289 /// /* create a SubsetRandomSampler, will sample from the provided indices */ 290 /// std::string folder_path = "/path/to/image/folder"; 291 /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>({2, 7})); 292 /// \endcode 293 explicit SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0); 294 295 /// \brief Destructor. 296 ~SubsetRandomSampler() override = default; 297 298 protected: 299 /// \brief The function to convert a Sampler into an IR SamplerObj. 300 /// \return shared pointer to the newly created SamplerObj. 301 std::shared_ptr<SamplerObj> Parse() const override; 302 }; 303 304 /// \brief A class to represent a Weighted Random Sampler in the data pipeline. 305 /// \note Samples the elements from [0, len(weights) - 1] randomly with the given 306 /// weights (probabilities). 307 class DATASET_API WeightedRandomSampler final : public Sampler { 308 friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, 309 int32_t shard_id); 310 311 public: 312 /// \brief Constructor 313 /// \param[in] weights A vector sequence of weights, not necessarily summing up to 1. 314 /// \param[in] num_samples The number of samples to draw (default=0, return all samples). 315 /// \param[in] replacement If true, put the sample ID back for the next draw (default=true). 316 /// \par Example 317 /// \code 318 /// /* creates a WeightedRandomSampler that will sample 4 elements without replacement */ 319 /// std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1}; 320 /// sampler = std::make_shared<WeightedRandomSampler>(weights, 4); 321 /// std::string folder_path = "/path/to/image/folder"; 322 /// std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler); 323 /// \endcode 324 explicit WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples = 0, bool replacement = true); 325 326 /// \brief Destructor. 327 ~WeightedRandomSampler() override = default; 328 329 protected: 330 /// \brief The function to convert a Sampler into an IR SamplerObj. 331 /// \return shared pointer to the newly created SamplerObj. 332 std::shared_ptr<SamplerObj> Parse() const override; 333 334 private: 335 std::vector<double> weights_; 336 int64_t num_samples_; 337 bool replacement_; 338 }; 339 } // namespace dataset 340 } // namespace mindspore 341 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_ 342