• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
19 
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
22 
23 namespace mindspore {
24 namespace dataset {
25 
// Forward declare
class SamplerObj;

/// \class Sampler samplers.h
/// \brief An abstract base class to represent a sampler in the data pipeline.
/// \note std::enable_shared_from_this is inherited publicly on purpose: std::shared_ptr only
///     records the object's weak self-reference when that base is public and unambiguous, so a
///     private base would make every shared_from_this() call fail.
class Sampler : public std::enable_shared_from_this<Sampler> {
  friend class AlbumDataset;
  friend class CelebADataset;
  friend class Cifar10Dataset;
  friend class Cifar100Dataset;
  friend class CityscapesDataset;
  friend class CLUEDataset;
  friend class CocoDataset;
  friend class CSVDataset;
  friend class DIV2KDataset;
  friend class FlickrDataset;
  friend class ImageFolderDataset;
  friend class ManifestDataset;
  friend class MindDataDataset;
  friend class MnistDataset;
  friend class RandomDataDataset;
  friend class SBUDataset;
  friend class TextFileDataset;
  friend class TFRecordDataset;
  friend class USPSDataset;
  friend class VOCDataset;
  friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

 public:
  /// \brief Constructor.
  Sampler() = default;

  /// \brief Destructor.
  /// \note Virtual because Sampler has virtual member functions and is used through base-class
  ///     pointers (e.g. the children_ list). Deleting a derived sampler through a Sampler pointer
  ///     must run the derived destructor.
  virtual ~Sampler() = default;

  // Declaring the destructor suppresses the implicit move operations, so they are restored
  // explicitly here (together with copy) to keep children_ cheaply movable.
  Sampler(const Sampler &) = default;
  Sampler &operator=(const Sampler &) = default;
  Sampler(Sampler &&) = default;
  Sampler &operator=(Sampler &&) = default;

  /// \brief A virtual function to add a child sampler.
  /// \param[in] child The child sampler to be appended to this sampler's list of children.
  virtual void AddChild(std::shared_ptr<Sampler> child) {
    // `child` is already our own copy (taken by value), so move it instead of bumping the
    // reference count a second time.
    children_.push_back(std::move(child));
  }

 protected:
  /// \brief Pure virtual function to convert a Sampler class into an IR Sampler object.
  /// \return Shared pointer to the newly created SamplerObj.
  virtual std::shared_ptr<SamplerObj> Parse() const = 0;

  /// Child samplers registered through AddChild(), in insertion order.
  std::vector<std::shared_ptr<Sampler>> children_;
};
73 
74 /// \brief A class to represent a Distributed Sampler in the data pipeline.
75 /// \note A Sampler that accesses a shard of the dataset.
76 class DistributedSampler final : public Sampler {
77   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
78 
79  public:
80   /// \brief Constructor
81   /// \param[in] num_shards Number of shards to divide the dataset into.
82   /// \param[in] shard_id Shard ID of the current shard within num_shards.
83   /// \param[in] shuffle If true, the indices are shuffled (default=true).
84   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
85   /// \param[in] seed The seed in use when shuffle is true (default=1).
86   /// \param[in] offset The starting position where access to elements in the dataset begins (default=-1).
87   /// \param[in] even_dist If true, each shard would return the same number of rows (default=true).
88   ///     If false the total rows returned by all the shards would not have overlap.
89   DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
90                      uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);
91   /// \brief Destructor.
92   ~DistributedSampler() = default;
93 
94  protected:
95   /// \brief The function to convert a Sampler into an IR SamplerObj.
96   /// \return shared pointer to the newly created SamplerObj.
97   std::shared_ptr<SamplerObj> Parse() const override;
98 
99  private:
100   int64_t num_shards_;
101   int64_t shard_id_;
102   bool shuffle_;
103   int64_t num_samples_;
104   uint32_t seed_;
105   int64_t offset_;
106   bool even_dist_;
107 };
108 
109 /// \brief A class to represent a PK Sampler in the data pipeline.
110 /// \note Samples K elements for each P class in the dataset.
111 ///        This will sample all classes.
112 class PKSampler final : public Sampler {
113   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
114 
115  public:
116   /// \brief Constructor
117   /// \param[in] num_val Number of elements to sample for each class.
118   /// \param[in] shuffle If true, the class IDs are shuffled (default=false).
119   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
120   explicit PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);
121 
122   /// \brief Destructor.
123   ~PKSampler() = default;
124 
125  protected:
126   /// \brief The function to convert a Sampler into an IR SamplerObj.
127   /// \return shared pointer to the newly created SamplerObj.
128   std::shared_ptr<SamplerObj> Parse() const override;
129 
130  private:
131   int64_t num_val_;
132   bool shuffle_;
133   int64_t num_samples_;
134 };
135 
136 /// \brief A class to represent a Random Sampler in the data pipeline.
137 /// \note Samples the elements randomly.
138 class RandomSampler final : public Sampler {
139   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
140 
141  public:
142   /// \brief Constructor
143   /// \param[in] replacement If true, put the sample ID back for the next draw (default=false).
144   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
145   explicit RandomSampler(bool replacement = false, int64_t num_samples = 0);
146 
147   /// \brief Destructor.
148   ~RandomSampler() = default;
149 
150  protected:
151   /// \brief The function to convert a Sampler into an IR SamplerObj.
152   /// \return shared pointer to the newly created SamplerObj.
153   std::shared_ptr<SamplerObj> Parse() const override;
154 
155  private:
156   bool replacement_;
157   int64_t num_samples_;
158 };
159 
160 /// \brief A class to represent a Sequential Sampler in the data pipeline.
161 /// \note Samples the dataset elements sequentially, same as not having a sampler.
162 class SequentialSampler final : public Sampler {
163   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
164 
165  public:
166   /// \brief Constructor
167   /// \param[in] start_index Index to start sampling at (default=0, start at first id).
168   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
169   explicit SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);
170 
171   /// \brief Destructor.
172   ~SequentialSampler() = default;
173 
174  protected:
175   /// \brief The function to convert a Sampler into an IR SamplerObj.
176   /// \return shared pointer to the newly created SamplerObj.
177   std::shared_ptr<SamplerObj> Parse() const override;
178 
179  private:
180   int64_t start_index_;
181   int64_t num_samples_;
182 };
183 
184 /// \brief A class to represent a Subset Sampler in the data pipeline.
185 /// \note Samples the elements from a sequence of indices.
186 class SubsetSampler : public Sampler {
187   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
188 
189  public:
190   /// \brief Constructor
191   /// \param[in] indices A vector sequence of indices.
192   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
193   explicit SubsetSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
194 
195   /// \brief Destructor.
196   ~SubsetSampler() = default;
197 
198  protected:
199   /// \brief The function to convert a Sampler into an IR SamplerObj.
200   /// \return shared pointer to the newly created SamplerObj.
201   std::shared_ptr<SamplerObj> Parse() const override;
202 
203   std::vector<int64_t> indices_;
204   int64_t num_samples_;
205 };
206 
207 /// \brief A class to represent a Subset Random Sampler in the data pipeline.
208 /// \note Samples the elements randomly from a sequence of indices.
209 class SubsetRandomSampler final : public SubsetSampler {
210   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
211 
212  public:
213   /// \brief Constructor
214   /// \param[in] indices A vector sequence of indices.
215   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
216   explicit SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
217 
218   /// \brief Destructor.
219   ~SubsetRandomSampler() = default;
220 
221  protected:
222   /// \brief The function to convert a Sampler into an IR SamplerObj.
223   /// \return shared pointer to the newly created SamplerObj.
224   std::shared_ptr<SamplerObj> Parse() const override;
225 };
226 
227 /// \brief A class to represent a Weighted Random Sampler in the data pipeline.
228 /// \note Samples the elements from [0, len(weights) - 1] randomly with the given
229 ///        weights (probabilities).
230 class WeightedRandomSampler final : public Sampler {
231   friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
232 
233  public:
234   /// \brief Constructor
235   /// \param[in] weights A vector sequence of weights, not necessarily summing up to 1.
236   /// \param[in] num_samples The number of samples to draw (default=0, return all samples).
237   /// \param[in] replacement If true, put the sample ID back for the next draw (default=true).
238   explicit WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
239 
240   /// \brief Destructor.
241   ~WeightedRandomSampler() = default;
242 
243  protected:
244   /// \brief The function to convert a Sampler into an IR SamplerObj.
245   /// \return shared pointer to the newly created SamplerObj.
246   std::shared_ptr<SamplerObj> Parse() const override;
247 
248  private:
249   std::vector<double> weights_;
250   int64_t num_samples_;
251   bool replacement_;
252 };
253 
254 }  // namespace dataset
255 }  // namespace mindspore
256 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_SAMPLERS_H_
257