• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TRANSFORMS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TRANSFORMS_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 
25 #include "include/api/dual_abi_helper.h"
26 #include "include/api/status.h"
27 #include "include/api/types.h"
28 #include "include/dataset/constants.h"
29 
30 namespace mindspore {
31 namespace dataset {
32 
33 class TensorOperation;
34 
// The following two groups of forward declarations are needed so that these classes can be
// declared as friends of class TensorTransform.
36 namespace transforms {
37 class Compose;
38 class RandomApply;
39 class RandomChoice;
40 }  // namespace transforms
41 
42 namespace vision {
43 class BoundingBoxAugment;
44 class RandomSelectSubpolicy;
45 class UniformAugment;
46 }  // namespace vision
47 
48 // Abstract class to represent a tensor transform operation in the data pipeline.
49 /// \class TensorTransform transforms.h
50 /// \brief A base class to represent a tensor transform operation in the data pipeline.
51 class TensorTransform : public std::enable_shared_from_this<TensorTransform> {
52   friend class Dataset;
53   friend class Execute;
54   friend class transforms::Compose;
55   friend class transforms::RandomApply;
56   friend class transforms::RandomChoice;
57   friend class vision::BoundingBoxAugment;
58   friend class vision::RandomSelectSubpolicy;
59   friend class vision::UniformAugment;
60 
61  public:
62   /// \brief Constructor
TensorTransform()63   TensorTransform() {}
64 
65   /// \brief Destructor
66   ~TensorTransform() = default;
67 
68  protected:
69   /// \brief Pure virtual function to convert a TensorTransform class into a IR TensorOperation object.
70   /// \return shared pointer to the newly created TensorOperation.
71   virtual std::shared_ptr<TensorOperation> Parse() = 0;
72 
73   /// \brief Virtual function to convert a TensorTransform class into a IR TensorOperation object.
74   /// \param[in] env A string to determine the running environment
75   /// \return shared pointer to the newly created TensorOperation.
Parse(const MapTargetDevice & env)76   virtual std::shared_ptr<TensorOperation> Parse(const MapTargetDevice &env) { return nullptr; }
77 };
78 
79 /// \brief Slice object used in SliceOption.
80 class Slice {
81  public:
82   /// \brief Constructor, with start, stop and step default to 0.
Slice()83   Slice() : start_(0), stop_(0), step_(0) {}
84   /// \brief Constructor.
85   /// \param[in] start Starting integer specifying where to start the slicing.
86   /// \param[in] stop Ending integer specifying where to stop the slicing.
87   /// \param[in] step An integer specifying the step of the slicing.
Slice(dsize_t start,dsize_t stop,dsize_t step)88   Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {}
89   /// \brief Constructor, with step=1
90   /// \param[in] start Starting integer specifying where to start the slicing.
91   /// \param[in] stop Ending integer specifying where to stop the slicing.
Slice(dsize_t start,dsize_t stop)92   Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
93   /// \brief Constructor, with start=0 and step=1
94   /// \param[in] stop Ending integer specifying where to stop the slicing.
Slice(dsize_t stop)95   explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}
96   Slice(Slice const &slice) = default;
97 
98   ~Slice() = default;
99 
valid()100   bool valid() const { return step_ != 0; }
101   dsize_t start_;
102   dsize_t stop_;
103   dsize_t step_;
104 };
105 
106 /// \brief SliceOption used in Slice TensorTransform.
107 class SliceOption {
108  public:
109   /// \param[in] all Slice the whole dimension
SliceOption(bool all)110   explicit SliceOption(bool all) : all_(all) {}
111   /// \param[in] indices Slice these indices along the dimension. Negative indices are supported.
SliceOption(std::vector<dsize_t> indices)112   explicit SliceOption(std::vector<dsize_t> indices) : indices_(indices) {}
113   /// \param[in] slice Slice the generated indices from the slice object along the dimension.
SliceOption(Slice slice)114   explicit SliceOption(Slice slice) : slice_(slice) {}
115   SliceOption(SliceOption const &slice) = default;
116 
117   ~SliceOption() = default;
118 
119   // only one of the following will be valid
120   // given indices to slice the Tensor.
121   std::vector<dsize_t> indices_ = {};
122   // Slice object. All start, stop and step are 0 if invalid.
123   Slice slice_;
124   bool all_ = false;
125 };
126 
127 // Transform operations for performing data transformation.
128 namespace transforms {
129 
130 /// \brief Compose a list of transforms into a single transform.
131 class Compose final : public TensorTransform {
132  public:
133   /// \brief Constructor.
134   /// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
135   explicit Compose(const std::vector<TensorTransform *> &transforms);
136   /// \brief Constructor.
137   /// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
138   explicit Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms);
139   /// \brief Constructor.
140   /// \param[in] transforms A vector of TensorTransform objects to be applied.
141   explicit Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);
142 
143   /// \brief Destructor
144   ~Compose() = default;
145 
146  protected:
147   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
148   /// \return Shared pointer to TensorOperation object.
149   std::shared_ptr<TensorOperation> Parse() override;
150 
151  private:
152   struct Data;
153   std::shared_ptr<Data> data_;
154 };
155 
156 /// \brief Concatenate all tensors into a single tensor.
157 class Concatenate final : public TensorTransform {
158  public:
159   /// \brief Constructor.
160   /// \param[in] axis Concatenate the tensors along given axis, only support 0 or -1 so far (default=0).
161   /// \param[in] prepend MSTensor to be prepended to the concatenated tensors (default={}).
162   /// \param[in] append MSTensor to be appended to the concatenated tensors (default={}).
163   explicit Concatenate(int8_t axis = 0, const MSTensor &prepend = {}, const MSTensor &append = {});
164 
165   /// \brief Destructor
166   ~Concatenate() = default;
167 
168  protected:
169   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
170   /// \return Shared pointer to TensorOperation object.
171   std::shared_ptr<TensorOperation> Parse() override;
172 
173  private:
174   struct Data;
175   std::shared_ptr<Data> data_;
176 };
177 
178 /// \brief Duplicate the input tensor to a new output tensor.
179 ///     The input tensor is carried over to the output list.
180 class Duplicate final : public TensorTransform {
181  public:
182   /// \brief Constructor.
183   Duplicate();
184 
185   /// \brief Destructor
186   ~Duplicate() = default;
187 
188  protected:
189   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
190   /// \return Shared pointer to TensorOperation object.
191   std::shared_ptr<TensorOperation> Parse() override;
192 };
193 
194 /// \brief Fill all elements in the tensor with the specified value.
195 ///    The output tensor will have the same shape and type as the input tensor.
196 class Fill final : public TensorTransform {
197  public:
198   /// \brief Constructor.
199   /// \param[in] fill_value Scalar value to fill the tensor with.
200   ///               It can only be MSTensor of the following types from mindspore::DataType:
201   ///               String, Bool, Int8/16/32/64, UInt8/16/32/64, Float16/32/64.
202   explicit Fill(const MSTensor &fill_value);
203 
204   /// \brief Destructor
205   ~Fill() = default;
206 
207  protected:
208   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
209   /// \return Shared pointer to TensorOperation object.
210   std::shared_ptr<TensorOperation> Parse() override;
211 
212  private:
213   struct Data;
214   std::shared_ptr<Data> data_;
215 };
216 
217 /// \brief Mask content of the input tensor with the given predicate.
218 ///     Any element of the tensor that matches the predicate will be evaluated to True, otherwise False.
219 class Mask final : public TensorTransform {
220  public:
221   /// \brief Constructor.
222   /// \param[in] op One of the relational operators: EQ, NE LT, GT, LE or GE.
223   /// \param[in] constant Constant to be compared to. It can only be MSTensor of the following types
224   ///                from mindspore::DataType: String, Int, Float, Bool.
225   /// \param[in] de_type Type of the generated mask. It can only be numeric or boolean datatype.
226   ///               (default=mindspore::DataType::kNumberTypeBool)
227   explicit Mask(RelationalOp op, const MSTensor &constant,
228                 mindspore::DataType ms_type = mindspore::DataType(mindspore::DataType::kNumberTypeBool));
229 
230   /// \brief Destructor
231   ~Mask() = default;
232 
233  protected:
234   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
235   /// \return Shared pointer to TensorOperation object.
236   std::shared_ptr<TensorOperation> Parse() override;
237 
238  private:
239   struct Data;
240   std::shared_ptr<Data> data_;
241 };
242 
243 /// \brief Convert the labels into OneHot format.
244 class OneHot final : public TensorTransform {
245  public:
246   /// \brief Constructor.
247   /// \param[in] num_classes number of classes.
248   explicit OneHot(int32_t num_classes);
249 
250   /// \brief Destructor
251   ~OneHot() = default;
252 
253  protected:
254   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
255   /// \return Shared pointer to TensorOperation object.
256   std::shared_ptr<TensorOperation> Parse() override;
257 
258  private:
259   struct Data;
260   std::shared_ptr<Data> data_;
261 };
262 
263 /// \brief Pad input tensor according to pad_shape
264 class PadEnd final : public TensorTransform {
265  public:
266   /// \brief Constructor.
267   /// \param[in] pad_shape List of integers representing the shape needed, need to have same rank with input tensor.
268   ///               Dimensions that set to `-1` will not be padded (i.e., original dim will be used).
269   ///               Shorter dimensions will truncate the values.
270   /// \param[in] pad_value Value used to pad (default={}).
271   explicit PadEnd(const std::vector<dsize_t> &pad_shape, const MSTensor &pad_value = {});
272 
273   /// \brief Destructor
274   ~PadEnd() = default;
275 
276  protected:
277   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
278   /// \return Shared pointer to TensorOperation object.
279   std::shared_ptr<TensorOperation> Parse() override;
280 
281  private:
282   struct Data;
283   std::shared_ptr<Data> data_;
284 };
285 
286 /// \brief Randomly perform a series of transforms with a given probability.
287 class RandomApply final : public TensorTransform {
288  public:
289   /// \brief Constructor.
290   /// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
291   /// \param[in] prob The probability to apply the transformation list (default=0.5).
292   explicit RandomApply(const std::vector<TensorTransform *> &transforms, double prob = 0.5);
293   /// \brief Constructor.
294   /// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
295   /// \param[in] prob The probability to apply the transformation list (default=0.5).
296   explicit RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob = 0.5);
297   /// \brief Constructor.
298   /// \param[in] transforms A vector of TensorTransform objects to be applied.
299   /// \param[in] prob The probability to apply the transformation list (default=0.5).
300   explicit RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob = 0.5);
301 
302   /// \brief Destructor
303   ~RandomApply() = default;
304 
305  protected:
306   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
307   /// \return Shared pointer to TensorOperation object.
308   std::shared_ptr<TensorOperation> Parse() override;
309 
310  private:
311   struct Data;
312   std::shared_ptr<Data> data_;
313 };
314 
315 /// \brief Randomly select one transform from a list of transforms to perform on the input tensor.
316 class RandomChoice final : public TensorTransform {
317  public:
318   /// \brief Constructor.
319   /// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
320   explicit RandomChoice(const std::vector<TensorTransform *> &transforms);
321   /// \brief Constructor.
322   /// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
323   explicit RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms);
324   /// \brief Constructor.
325   /// \param[in] transforms A vector of TensorTransform objects to be applied.
326   explicit RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);
327 
328   /// \brief Destructor
329   ~RandomChoice() = default;
330 
331  protected:
332   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
333   /// \return Shared pointer to TensorOperation object.
334   std::shared_ptr<TensorOperation> Parse() override;
335 
336  private:
337   struct Data;
338   std::shared_ptr<Data> data_;
339 };
340 
341 /// \brief Extract a tensor out using the given n slices.
342 ///     The functionality of Slice is similar to the feature of indexing of NumPy.
343 ///     (Currently only rank-1 tensors are supported).
344 class Slice final : public TensorTransform {
345  public:
346   /// \brief Constructor.
347   /// \param[in] slice_input Vector of SliceOption
348   explicit Slice(const std::vector<SliceOption> &slice_input);
349 
350   /// \brief Destructor
351   ~Slice() = default;
352 
353  protected:
354   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
355   /// \return Shared pointer to TensorOperation object.
356   std::shared_ptr<TensorOperation> Parse() override;
357 
358  private:
359   struct Data;
360   std::shared_ptr<Data> data_;
361 };
362 
363 /// \brief Cast the MindSpore data type of a tensor to another.
364 class TypeCast final : public TensorTransform {
365  public:
366   /// \brief Constructor.
367   /// \param[in] data_type mindspore::DataType to be cast to.
368   explicit TypeCast(mindspore::DataType data_type);
369 
370   /// \brief Destructor
371   ~TypeCast() = default;
372 
373  protected:
374   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
375   /// \return Shared pointer to TensorOperation object.
376   std::shared_ptr<TensorOperation> Parse() override;
377 
378  private:
379   struct Data;
380   std::shared_ptr<Data> data_;
381 };
382 
383 /// \brief Return an output tensor that contains all the unique elements of the input tensor in
384 ///     the same order as they appear in the input tensor.
385 class Unique final : public TensorTransform {
386  public:
387   /// \brief Constructor.
388   Unique();
389 
390   /// \brief Destructor
391   ~Unique() = default;
392 
393  protected:
394   /// \brief The function to convert a TensorTransform object into a TensorOperation object.
395   /// \return Shared pointer to TensorOperation object.
396   std::shared_ptr<TensorOperation> Parse() override;
397 };
398 }  // namespace transforms
399 }  // namespace dataset
400 }  // namespace mindspore
401 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TRANSFORMS_H_
402