• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "minddata/dataset/include/dataset/transforms.h"
18 
19 #include <algorithm>
20 
21 #include "minddata/dataset/core/type_id.h"
22 #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
23 #include "mindspore/core/ir/dtype/type_id.h"
24 
25 namespace mindspore {
26 namespace dataset {
27 // Transform operations for data.
28 namespace transforms {
29 // API CLASS FOR DATA TRANSFORM OPERATIONS
30 // (In alphabetical order)
31 
32 // Constructor to Compose.
33 struct Compose::Data {
34   std::vector<std::shared_ptr<TensorOperation>> transforms_;
35 };
36 
Compose(const std::vector<TensorTransform * > & transforms)37 Compose::Compose(const std::vector<TensorTransform *> &transforms) : data_(std::make_shared<Data>()) {
38   (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
39                        [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> {
40                          return op != nullptr ? op->Parse() : nullptr;
41                        });
42 }
43 
Compose(const std::vector<std::shared_ptr<TensorTransform>> & transforms)44 Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) : data_(std::make_shared<Data>()) {
45   (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
46                        [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
47                          return op != nullptr ? op->Parse() : nullptr;
48                        });
49 }
50 
Compose(const std::vector<std::reference_wrapper<TensorTransform>> & transforms)51 Compose::Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms)
52     : data_(std::make_shared<Data>()) {
53   (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
54                        [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
55 }
56 
Parse()57 std::shared_ptr<TensorOperation> Compose::Parse() { return std::make_shared<ComposeOperation>(data_->transforms_); }
58 
59 // Constructor to Concatenate
60 struct Concatenate::Data {
Datamindspore::dataset::transforms::Concatenate::Data61   explicit Data(int8_t axis, const MSTensor &prepend, const MSTensor &append)
62       : axis_(axis), prepend_(prepend), append_(append) {}
63   int8_t axis_;
64   MSTensor prepend_;
65   MSTensor append_;
66 };
67 
Concatenate(int8_t axis,const MSTensor & prepend,const MSTensor & append)68 Concatenate::Concatenate(int8_t axis, const MSTensor &prepend, const MSTensor &append)
69     : data_(std::make_shared<Data>(axis, prepend, append)) {}
70 
Parse()71 std::shared_ptr<TensorOperation> Concatenate::Parse() {
72 #ifndef ENABLE_ANDROID
73   std::shared_ptr<Tensor> out_prepend, out_append;
74   Status rc = Tensor::CreateFromMSTensor(data_->prepend_, &out_prepend);
75   if (rc.IsError()) {
76     MS_LOG(ERROR) << "Error creating prepend constant tensor. " << rc;
77     return nullptr;
78   }
79   rc = Tensor::CreateFromMSTensor(data_->append_, &out_append);
80   if (rc.IsError()) {
81     MS_LOG(ERROR) << "Error creating append constant tensor. " << rc;
82     return nullptr;
83   }
84   return std::make_shared<ConcatenateOperation>(data_->axis_, out_prepend, out_append);
85 #else
86   MS_LOG(ERROR) << "Concatenate op is not supported for Android.";
87   return nullptr;
88 #endif  // not ENABLE_ANDROID
89 }
90 
91 // Constructor to Duplicate
92 Duplicate::Duplicate() = default;
93 
Parse()94 std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); }
95 
96 // Constructor to Fill
97 struct Fill::Data {
Datamindspore::dataset::transforms::Fill::Data98   explicit Data(const MSTensor &fill_value) : fill_value_(fill_value) {}
99   MSTensor fill_value_;
100 };
101 
Fill(const MSTensor & fill_value)102 Fill::Fill(const MSTensor &fill_value) : data_(std::make_shared<Data>(fill_value)) {}
103 
Parse()104 std::shared_ptr<TensorOperation> Fill::Parse() {
105 #ifndef ENABLE_ANDROID
106   std::shared_ptr<Tensor> out_fill_value;
107   Status rc = Tensor::CreateFromMSTensor(data_->fill_value_, &out_fill_value);
108   if (rc.IsError()) {
109     MS_LOG(ERROR) << "Error creating fill value tensor. " << rc;
110     return nullptr;
111   }
112   return std::make_shared<FillOperation>(out_fill_value);
113 #else
114   MS_LOG(ERROR) << "Fill op is not supported for Android.";
115   return nullptr;
116 #endif  // not ENABLE_ANDROID
117 }
118 
119 // Constructor to Mask
120 struct Mask::Data {
Datamindspore::dataset::transforms::Mask::Data121   explicit Data(RelationalOp op, const MSTensor &constant, mindspore::DataType ms_type)
122       : op_(op), constant_(constant), ms_type_(ms_type) {}
123   RelationalOp op_;
124   MSTensor constant_;
125   mindspore::DataType ms_type_;
126 };
127 
Mask(RelationalOp op,const MSTensor & constant,mindspore::DataType ms_type)128 Mask::Mask(RelationalOp op, const MSTensor &constant, mindspore::DataType ms_type)
129     : data_(std::make_shared<Data>(op, constant, ms_type)) {}
130 
Parse()131 std::shared_ptr<TensorOperation> Mask::Parse() {
132 #ifndef ENABLE_ANDROID
133   std::shared_ptr<Tensor> out_constant;
134   Status rc = Tensor::CreateFromMSTensor(data_->constant_, &out_constant);
135   if (rc.IsError()) {
136     MS_LOG(ERROR) << "Error creating constant tensor. " << rc;
137     return nullptr;
138   }
139 
140   DataType de_type = dataset::MSTypeToDEType(static_cast<TypeId>(data_->ms_type_));
141   return std::make_shared<MaskOperation>(data_->op_, out_constant, de_type);
142 #else
143   MS_LOG(ERROR) << "Mask op is not supported for Android.";
144   return nullptr;
145 #endif  // not ENABLE_ANDROID
146 }
147 
148 // Constructor to OneHot
149 struct OneHot::Data {
Datamindspore::dataset::transforms::OneHot::Data150   explicit Data(int32_t num_classes, double smoothing_rate)
151       : num_classes_(num_classes), smoothing_rate_(smoothing_rate) {}
152   int32_t num_classes_;
153   double smoothing_rate_;
154 };
155 
OneHot(int32_t num_classes,double smoothing_rate)156 OneHot::OneHot(int32_t num_classes, double smoothing_rate)
157     : data_(std::make_shared<Data>(num_classes, smoothing_rate)) {}
158 
Parse()159 std::shared_ptr<TensorOperation> OneHot::Parse() {
160   return std::make_shared<OneHotOperation>(data_->num_classes_, data_->smoothing_rate_);
161 }
162 
163 // Constructor to PadEnd
164 struct PadEnd::Data {
Datamindspore::dataset::transforms::PadEnd::Data165   explicit Data(const std::vector<dsize_t> &pad_shape, const MSTensor &pad_value)
166       : pad_shape_(pad_shape), pad_value_(pad_value) {}
167   std::vector<dsize_t> pad_shape_;
168   MSTensor pad_value_;
169 };
170 
PadEnd(const std::vector<dsize_t> & pad_shape,const MSTensor & pad_value)171 PadEnd::PadEnd(const std::vector<dsize_t> &pad_shape, const MSTensor &pad_value)
172     : data_(std::make_shared<Data>(pad_shape, pad_value)) {}
173 
Parse()174 std::shared_ptr<TensorOperation> PadEnd::Parse() {
175 #ifndef ENABLE_ANDROID
176   std::shared_ptr<Tensor> pad_value;
177   Status rc = Tensor::CreateFromMSTensor(data_->pad_value_, &pad_value);
178   if (rc.IsError()) {
179     MS_LOG(ERROR) << "Error creating pad_value constant tensor. " << rc;
180     return nullptr;
181   }
182   return std::make_shared<PadEndOperation>(TensorShape(data_->pad_shape_), pad_value);
183 #else
184   MS_LOG(ERROR) << "PadEnd op is not supported for Android.";
185   return nullptr;
186 #endif  // not ENABLE_ANDROID
187 }
188 
189 // Constructor to RandomApply.
190 struct RandomApply::Data {
191   std::vector<std::shared_ptr<TensorOperation>> transforms_;
192   double prob_;
193 };
194 
RandomApply(const std::vector<TensorTransform * > & transforms,double prob)195 RandomApply::RandomApply(const std::vector<TensorTransform *> &transforms, double prob)
196     : data_(std::make_shared<Data>()) {
197   (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
198                        [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> {
199                          return op != nullptr ? op->Parse() : nullptr;
200                        });
201   data_->prob_ = prob;
202 }
203 
RandomApply(const std::vector<std::shared_ptr<TensorTransform>> & transforms,double prob)204 RandomApply::RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob)
205     : data_(std::make_shared<Data>()) {
206   (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
207                        [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
208                          return op != nullptr ? op->Parse() : nullptr;
209                        });
210   data_->prob_ = prob;
211 }
212 
RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> & transforms,double prob)213 RandomApply::RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob)
214     : data_(std::make_shared<Data>()) {
215   (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
216                        [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
217   data_->prob_ = prob;
218 }
219 
Parse()220 std::shared_ptr<TensorOperation> RandomApply::Parse() {
221   return std::make_shared<RandomApplyOperation>(data_->transforms_, data_->prob_);
222 }
223 
224 // Constructor to RandomChoice.
225 struct RandomChoice::Data {
226   std::vector<std::shared_ptr<TensorOperation>> transforms_;
227 };
228 
RandomChoice(const std::vector<TensorTransform * > & transforms)229 RandomChoice::RandomChoice(const std::vector<TensorTransform *> &transforms) : data_(std::make_shared<Data>()) {
230   (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
231                        [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> {
232                          return op != nullptr ? op->Parse() : nullptr;
233                        });
234 }
235 
RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> & transforms)236 RandomChoice::RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms)
237     : data_(std::make_shared<Data>()) {
238   (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
239                        [](const std::shared_ptr<TensorTransform> &op) -> std::shared_ptr<TensorOperation> {
240                          return op != nullptr ? op->Parse() : nullptr;
241                        });
242 }
243 
RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> & transforms)244 RandomChoice::RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms)
245     : data_(std::make_shared<Data>()) {
246   (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
247                        [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
248 }
249 
Parse()250 std::shared_ptr<TensorOperation> RandomChoice::Parse() {
251   return std::make_shared<RandomChoiceOperation>(data_->transforms_);
252 }
253 
254 // Constructor to Slice
255 struct Slice::Data {
Datamindspore::dataset::transforms::Slice::Data256   explicit Data(const std::vector<SliceOption> &slice_input) : slice_input_(slice_input) {}
257   std::vector<SliceOption> slice_input_;
258 };
259 
Slice(const std::vector<SliceOption> & slice_input)260 Slice::Slice(const std::vector<SliceOption> &slice_input) : data_(std::make_shared<Data>(slice_input)) {}
261 
Parse()262 std::shared_ptr<TensorOperation> Slice::Parse() {
263 #ifndef ENABLE_ANDROID
264   return std::make_shared<SliceOperation>(data_->slice_input_);
265 #else
266   MS_LOG(ERROR) << "Slice op is not supported for Android.";
267   return nullptr;
268 #endif  // not ENABLE_ANDROID
269 }
270 
271 // Constructor to TypeCast
272 struct TypeCast::Data {
273   dataset::DataType data_type_;
274 };
275 
TypeCast(mindspore::DataType data_type)276 TypeCast::TypeCast(mindspore::DataType data_type) : data_(std::make_shared<Data>()) {
277   data_->data_type_ = dataset::MSTypeToDEType(static_cast<TypeId>(data_type));
278 }
279 
Parse()280 std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); }
281 
282 // Constructor to Unique
283 Unique::Unique() = default;
284 
285 #ifndef ENABLE_ANDROID
Parse()286 std::shared_ptr<TensorOperation> Unique::Parse() { return std::make_shared<UniqueOperation>(); }
287 #else
Parse()288 std::shared_ptr<TensorOperation> Unique::Parse() {
289   MS_LOG(ERROR) << "Unique op is not supported for Android.";
290   return nullptr;
291 }
292 #endif
293 }  // namespace transforms
294 }  // namespace dataset
295 }  // namespace mindspore
296