• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
19 
20 #include <memory>
21 #include <vector>
22 
23 #include "include/api/types.h"
24 #include "include/api/status.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 // Forward declare
29 class SamplerObj;
30 
31 // Abstract class to represent a sampler in the data pipeline.
32 /// \class Sampler samplers.h
33 /// \brief An abstract base class to represent a sampler in the data pipeline.
34 class DATASET_API Sampler : std::enable_shared_from_this<Sampler> {
35   friend class AlbumDataset;
36   friend class Caltech256Dataset;
37   friend class CelebADataset;
38   friend class Cifar10Dataset;
39   friend class Cifar100Dataset;
40   friend class CityscapesDataset;
41   friend class CLUEDataset;
42   friend class CMUArcticDataset;
43   friend class CocoDataset;
44   friend class CSVDataset;
45   friend class DIV2KDataset;
46   friend class EMnistDataset;
47   friend class FakeImageDataset;
48   friend class FashionMnistDataset;
49   friend class FlickrDataset;
50   friend class Food101Dataset;
51   friend class GTZANDataset;
52   friend class ImageFolderDataset;
53   friend class IMDBDataset;
54   friend class KITTIDataset;
55   friend class KMnistDataset;
56   friend class LFWDataset;
57   friend class LibriTTSDataset;
58   friend class LJSpeechDataset;
59   friend class LSUNDataset;
60   friend class ManifestDataset;
61   friend class MindDataDataset;
62   friend class MnistDataset;
63   friend class OmniglotDataset;
64   friend class PhotoTourDataset;
65   friend class Places365Dataset;
66   friend class QMnistDataset;
67   friend class RandomDataDataset;
68   friend class RenderedSST2Dataset;
69   friend class SBUDataset;
70   friend class SemeionDataset;
71   friend class SpeechCommandsDataset;
72   friend class SST2Dataset;
73   friend class STL10Dataset;
74   friend class SUN397Dataset;
75   friend class TedliumDataset;
76   friend class TextFileDataset;
77   friend class TFRecordDataset;
78   friend class USPSDataset;
79   friend class VOCDataset;
80   friend class WIDERFaceDataset;
81   friend class YesNoDataset;
82   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
83                                                    int32_t shard_id);
84 
85  public:
86   /// \brief Constructor
87   Sampler() = default;
88 
89   /// \brief Destructor
90   virtual ~Sampler() = default;
91 
92   /// \brief A virtual function to add a child sampler.
93   /// \param[in] child The child sampler to be added as a children of this sampler.
AddChild(const std::shared_ptr<Sampler> & child)94   virtual void AddChild(const std::shared_ptr<Sampler> &child) { children_.push_back(child); }
95 
96  protected:
97   /// \brief Pure virtual function to convert a Sampler class into an IR Sampler object.
98   /// \return shared pointer to the newly created TensorOperation.
99   virtual std::shared_ptr<SamplerObj> Parse() const = 0;
100 
101   /// \brief A function that calls Parse() on the children of this sampler
102   /// \param[in] sampler The samplerIR object built from this sampler
103   /// \return the Status code returned
104   Status BuildChildren(std::shared_ptr<SamplerObj> *const sampler) const;
105 
106   std::vector<std::shared_ptr<Sampler>> children_;
107 };
108 
109 /// \brief A class to represent a Distributed Sampler in the data pipeline.
110 /// \note A Sampler that accesses a shard of the dataset.
111 class DATASET_API DistributedSampler final : public Sampler {
112   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
113                                                    int32_t shard_id);
114 
115  public:
116   /// \brief Constructor
117   /// \param[in] num_shards Number of shards to divide the dataset into.
118   /// \param[in] shard_id Shard ID of the current shard within num_shards.
119   /// \param[in] shuffle If true, the indices are shuffled (default=true).
120   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
121   /// \param[in] seed The seed in use when shuffle is true (default=1).
122   /// \param[in] offset The starting position where access to elements in the dataset begins (default=-1).
123   /// \param[in] even_dist If true, each shard would return the same number of rows (default=true).
124   ///     If false the total rows returned by all the shards would not have overlap.
125   /// \par Example
126   /// \code
127   ///      /* creates a distributed sampler with 2 shards in total. This shard is shard 0 */
128   ///      std::string file_path = "/path/to/test.mindrecord";
129   ///      std::shared_ptr<Dataset> ds = MindData(file_path, {}, std::make_shared<DistributedSampler>(2, 0, false));
130   /// \endcode
131   DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
132                      uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);
133   /// \brief Destructor.
134   ~DistributedSampler() override = default;
135 
136  protected:
137   /// \brief The function to convert a Sampler into an IR SamplerObj.
138   /// \return shared pointer to the newly created SamplerObj.
139   std::shared_ptr<SamplerObj> Parse() const override;
140 
141  private:
142   int64_t num_shards_;
143   int64_t shard_id_;
144   bool shuffle_;
145   int64_t num_samples_;
146   uint32_t seed_;
147   int64_t offset_;
148   bool even_dist_;
149 };
150 
151 /// \brief A class to represent a PK Sampler in the data pipeline.
152 /// \note Samples K elements for each P class in the dataset.
153 ///        This will sample all classes.
154 class DATASET_API PKSampler final : public Sampler {
155   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
156                                                    int32_t shard_id);
157 
158  public:
159   /// \brief Constructor
160   /// \param[in] num_val Number of elements to sample for each class.
161   /// \param[in] shuffle If true, the class IDs are shuffled (default=false).
162   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
163   /// \par Example
164   /// \code
165   ///      /* creates a PKSampler that will get 3 samples from every class. */
166   ///      std::string folder_path = "/path/to/image/folder";
167   ///      std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<PKSampler>(3));
168   /// \endcode
169   explicit PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);
170 
171   /// \brief Destructor.
172   ~PKSampler() override = default;
173 
174  protected:
175   /// \brief The function to convert a Sampler into an IR SamplerObj.
176   /// \return shared pointer to the newly created SamplerObj.
177   std::shared_ptr<SamplerObj> Parse() const override;
178 
179  private:
180   int64_t num_val_;
181   bool shuffle_;
182   int64_t num_samples_;
183 };
184 
185 /// \brief A class to represent a Random Sampler in the data pipeline.
186 /// \note Samples the elements randomly.
187 class DATASET_API RandomSampler final : public Sampler {
188   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
189                                                    int32_t shard_id);
190 
191  public:
192   /// \brief Constructor
193   /// \param[in] replacement If true, put the sample ID back for the next draw (default=false).
194   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
195   /// \par Example
196   /// \code
197   ///      /* creates a RandomSampler that will get 10 samples randomly */
198   ///      std::string folder_path = "/path/to/image/folder";
199   ///      std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
200   /// \endcode
201   explicit RandomSampler(bool replacement = false, int64_t num_samples = 0);
202 
203   /// \brief Destructor.
204   ~RandomSampler() override = default;
205 
206  protected:
207   /// \brief The function to convert a Sampler into an IR SamplerObj.
208   /// \return shared pointer to the newly created SamplerObj.
209   std::shared_ptr<SamplerObj> Parse() const override;
210 
211  private:
212   bool replacement_;
213   int64_t num_samples_;
214 };
215 
216 /// \brief A class to represent a Sequential Sampler in the data pipeline.
217 /// \note Samples the dataset elements sequentially, same as not having a sampler.
218 class DATASET_API SequentialSampler final : public Sampler {
219   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
220                                                    int32_t shard_id);
221 
222  public:
223   /// \brief Constructor
224   /// \param[in] start_index Index to start sampling at (default=0, start at first id).
225   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
226   /// \par Example
227   /// \code
228   ///      /* creates a SequentialSampler that will get 2 samples sequentially in original dataset */
229   ///      std::string folder_path = "/path/to/image/folder";
230   ///      std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SequentialSampler>(0, 2));
231   /// \endcode
232   explicit SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);
233 
234   /// \brief Destructor.
235   ~SequentialSampler() override = default;
236 
237  protected:
238   /// \brief The function to convert a Sampler into an IR SamplerObj.
239   /// \return shared pointer to the newly created SamplerObj.
240   std::shared_ptr<SamplerObj> Parse() const override;
241 
242  private:
243   int64_t start_index_;
244   int64_t num_samples_;
245 };
246 
247 /// \brief A class to represent a Subset Sampler in the data pipeline.
248 /// \note Samples the elements from a sequence of indices.
249 class DATASET_API SubsetSampler : public Sampler {
250   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
251                                                    int32_t shard_id);
252 
253  public:
254   /// \brief Constructor
255   /// \param[in] indices A vector sequence of indices.
256   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
257   /// \par Example
258   /// \code
259   ///      /* creates a SubsetSampler, will sample from the provided indices */
260   ///      std::string folder_path = "/path/to/image/folder";
261   ///      std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetSampler>({0, 2, 5}));
262   /// \endcode
263   explicit SubsetSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);
264 
265   /// \brief Destructor.
266   ~SubsetSampler() override = default;
267 
268  protected:
269   /// \brief The function to convert a Sampler into an IR SamplerObj.
270   /// \return shared pointer to the newly created SamplerObj.
271   std::shared_ptr<SamplerObj> Parse() const override;
272 
273   std::vector<int64_t> indices_;
274   int64_t num_samples_;
275 };
276 
277 /// \brief A class to represent a Subset Random Sampler in the data pipeline.
278 /// \note Samples the elements randomly from a sequence of indices.
279 class DATASET_API SubsetRandomSampler final : public SubsetSampler {
280   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
281                                                    int32_t shard_id);
282 
283  public:
284   /// \brief Constructor
285   /// \param[in] indices A vector sequence of indices.
286   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
287   /// \par Example
288   /// \code
289   ///      /* create a SubsetRandomSampler, will sample from the provided indices */
290   ///      std::string folder_path = "/path/to/image/folder";
291   ///      std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, std::make_shared<SubsetRandomSampler>({2, 7}));
292   /// \endcode
293   explicit SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples = 0);
294 
295   /// \brief Destructor.
296   ~SubsetRandomSampler() override = default;
297 
298  protected:
299   /// \brief The function to convert a Sampler into an IR SamplerObj.
300   /// \return shared pointer to the newly created SamplerObj.
301   std::shared_ptr<SamplerObj> Parse() const override;
302 };
303 
304 /// \brief A class to represent a Weighted Random Sampler in the data pipeline.
305 /// \note Samples the elements from [0, len(weights) - 1] randomly with the given
306 ///        weights (probabilities).
307 class DATASET_API WeightedRandomSampler final : public Sampler {
308   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards,
309                                                    int32_t shard_id);
310 
311  public:
312   /// \brief Constructor
313   /// \param[in] weights A vector sequence of weights, not necessarily summing up to 1.
314   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
315   /// \param[in] replacement If true, put the sample ID back for the next draw (default=true).
316   /// \par Example
317   /// \code
318   ///      /* creates a WeightedRandomSampler that will sample 4 elements without replacement */
319   ///      std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
320   ///      sampler = std::make_shared<WeightedRandomSampler>(weights, 4);
321   ///      std::string folder_path = "/path/to/image/folder";
322   ///      std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, sampler);
323   /// \endcode
324   explicit WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples = 0, bool replacement = true);
325 
326   /// \brief Destructor.
327   ~WeightedRandomSampler() override = default;
328 
329  protected:
330   /// \brief The function to convert a Sampler into an IR SamplerObj.
331   /// \return shared pointer to the newly created SamplerObj.
332   std::shared_ptr<SamplerObj> Parse() const override;
333 
334  private:
335   std::vector<double> weights_;
336   int64_t num_samples_;
337   bool replacement_;
338 };
339 }  // namespace dataset
340 }  // namespace mindspore
341 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
342