1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/include/dataset/transforms.h"
18
19 #include <algorithm>
20
21 #include "mindspore/ccsrc/minddata/dataset/core/type_id.h"
22 #include "mindspore/core/ir/dtype/type_id.h"
23 #include "minddata/dataset/core/type_id.h"
24 #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
25
26 namespace mindspore {
27 namespace dataset {
28
29 // Transform operations for data.
30 namespace transforms {
31
32 // API CLASS FOR DATA TRANSFORM OPERATIONS
33 // (In alphabetical order)
34
35 // Constructor to Compose.
36 struct Compose::Data {
37 std::vector<std::shared_ptr<TensorOperation>> transforms_;
38 };
39
Compose(const std::vector<TensorTransform * > & transforms)40 Compose::Compose(const std::vector<TensorTransform *> &transforms) : data_(std::make_shared<Data>()) {
41 (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
42 [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> {
43 return op != nullptr ? op->Parse() : nullptr;
44 });
45 }
46
Compose(const std::vector<std::shared_ptr<TensorTransform>> & transforms)47 Compose::Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms) : data_(std::make_shared<Data>()) {
48 (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
49 [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
50 return op != nullptr ? op->Parse() : nullptr;
51 });
52 }
53
Compose(const std::vector<std::reference_wrapper<TensorTransform>> & transforms)54 Compose::Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms)
55 : data_(std::make_shared<Data>()) {
56 (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
57 [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
58 }
59
Parse()60 std::shared_ptr<TensorOperation> Compose::Parse() { return std::make_shared<ComposeOperation>(data_->transforms_); }
61
62 // Constructor to Concatenate
63 struct Concatenate::Data {
Datamindspore::dataset::transforms::Concatenate::Data64 explicit Data(int8_t axis, const MSTensor &prepend, const MSTensor &append)
65 : axis_(axis), prepend_(prepend), append_(append) {}
66 int8_t axis_;
67 MSTensor prepend_;
68 MSTensor append_;
69 };
70
Concatenate(int8_t axis,const MSTensor & prepend,const MSTensor & append)71 Concatenate::Concatenate(int8_t axis, const MSTensor &prepend, const MSTensor &append)
72 : data_(std::make_shared<Data>(axis, prepend, append)) {}
73
Parse()74 std::shared_ptr<TensorOperation> Concatenate::Parse() {
75 #ifndef ENABLE_ANDROID
76 std::shared_ptr<Tensor> out_prepend, out_append;
77 Status rc = Tensor::CreateFromMSTensor(data_->prepend_, &out_prepend);
78 if (rc.IsError()) {
79 MS_LOG(ERROR) << "Error creating prepend constant tensor. " << rc;
80 return nullptr;
81 }
82 rc = Tensor::CreateFromMSTensor(data_->append_, &out_append);
83 if (rc.IsError()) {
84 MS_LOG(ERROR) << "Error creating append constant tensor. " << rc;
85 return nullptr;
86 }
87 return std::make_shared<ConcatenateOperation>(data_->axis_, out_prepend, out_append);
88 #else
89 MS_LOG(ERROR) << "Concatenate op is not supported for Android.";
90 return nullptr;
91 #endif // not ENABLE_ANDROID
92 }
93
94 // Constructor to Duplicate
Duplicate()95 Duplicate::Duplicate() {}
96
Parse()97 std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<DuplicateOperation>(); }
98
99 // Constructor to Fill
100 struct Fill::Data {
Datamindspore::dataset::transforms::Fill::Data101 explicit Data(const MSTensor &fill_value) : fill_value_(fill_value) {}
102 MSTensor fill_value_;
103 };
104
Fill(const MSTensor & fill_value)105 Fill::Fill(const MSTensor &fill_value) : data_(std::make_shared<Data>(fill_value)) {}
106
Parse()107 std::shared_ptr<TensorOperation> Fill::Parse() {
108 #ifndef ENABLE_ANDROID
109 std::shared_ptr<Tensor> out_fill_value;
110 Status rc = Tensor::CreateFromMSTensor(data_->fill_value_, &out_fill_value);
111 if (rc.IsError()) {
112 MS_LOG(ERROR) << "Error creating fill value tensor. " << rc;
113 return nullptr;
114 }
115 return std::make_shared<FillOperation>(out_fill_value);
116 #else
117 MS_LOG(ERROR) << "Fill op is not supported for Android.";
118 return nullptr;
119 #endif // not ENABLE_ANDROID
120 }
121
122 // Constructor to Mask
123 struct Mask::Data {
Datamindspore::dataset::transforms::Mask::Data124 explicit Data(RelationalOp op, const MSTensor &constant, mindspore::DataType ms_type)
125 : op_(op), constant_(constant), ms_type_(ms_type) {}
126 RelationalOp op_;
127 MSTensor constant_;
128 mindspore::DataType ms_type_;
129 };
130
Mask(RelationalOp op,const MSTensor & constant,mindspore::DataType ms_type)131 Mask::Mask(RelationalOp op, const MSTensor &constant, mindspore::DataType ms_type)
132 : data_(std::make_shared<Data>(op, constant, ms_type)) {}
133
Parse()134 std::shared_ptr<TensorOperation> Mask::Parse() {
135 #ifndef ENABLE_ANDROID
136 std::shared_ptr<Tensor> out_constant;
137 Status rc = Tensor::CreateFromMSTensor(data_->constant_, &out_constant);
138 if (rc.IsError()) {
139 MS_LOG(ERROR) << "Error creating constant tensor. " << rc;
140 return nullptr;
141 }
142
143 DataType de_type = dataset::MSTypeToDEType(static_cast<TypeId>(data_->ms_type_));
144 return std::make_shared<MaskOperation>(data_->op_, out_constant, de_type);
145 #else
146 MS_LOG(ERROR) << "Mask op is not supported for Android.";
147 return nullptr;
148 #endif // not ENABLE_ANDROID
149 }
150
151 // Constructor to OneHot
152 struct OneHot::Data {
Datamindspore::dataset::transforms::OneHot::Data153 explicit Data(int32_t num_classes) : num_classes_(num_classes) {}
154 int32_t num_classes_;
155 };
156
OneHot(int32_t num_classes)157 OneHot::OneHot(int32_t num_classes) : data_(std::make_shared<Data>(num_classes)) {}
158
Parse()159 std::shared_ptr<TensorOperation> OneHot::Parse() { return std::make_shared<OneHotOperation>(data_->num_classes_); }
160
161 // Constructor to PadEnd
162 struct PadEnd::Data {
Datamindspore::dataset::transforms::PadEnd::Data163 explicit Data(const std::vector<dsize_t> &pad_shape, const MSTensor &pad_value)
164 : pad_shape_(pad_shape), pad_value_(pad_value) {}
165 std::vector<dsize_t> pad_shape_;
166 MSTensor pad_value_;
167 };
168
PadEnd(const std::vector<dsize_t> & pad_shape,const MSTensor & pad_value)169 PadEnd::PadEnd(const std::vector<dsize_t> &pad_shape, const MSTensor &pad_value)
170 : data_(std::make_shared<Data>(pad_shape, pad_value)) {}
171
Parse()172 std::shared_ptr<TensorOperation> PadEnd::Parse() {
173 #ifndef ENABLE_ANDROID
174 std::shared_ptr<Tensor> pad_value;
175 Status rc = Tensor::CreateFromMSTensor(data_->pad_value_, &pad_value);
176 if (rc.IsError()) {
177 MS_LOG(ERROR) << "Error creating value constant tensor. " << rc;
178 return nullptr;
179 }
180 return std::make_shared<PadEndOperation>(TensorShape(data_->pad_shape_), pad_value);
181 #else
182 MS_LOG(ERROR) << "PadEnd op is not supported for Android.";
183 return nullptr;
184 #endif // not ENABLE_ANDROID
185 }
186
187 // Constructor to RandomApply.
188 struct RandomApply::Data {
189 std::vector<std::shared_ptr<TensorOperation>> transforms_;
190 double prob_;
191 };
192
RandomApply(const std::vector<TensorTransform * > & transforms,double prob)193 RandomApply::RandomApply(const std::vector<TensorTransform *> &transforms, double prob)
194 : data_(std::make_shared<Data>()) {
195 (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
196 [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> {
197 return op != nullptr ? op->Parse() : nullptr;
198 });
199 data_->prob_ = prob;
200 }
201
RandomApply(const std::vector<std::shared_ptr<TensorTransform>> & transforms,double prob)202 RandomApply::RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob)
203 : data_(std::make_shared<Data>()) {
204 (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
205 [](std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
206 return op != nullptr ? op->Parse() : nullptr;
207 });
208 data_->prob_ = prob;
209 }
210
RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> & transforms,double prob)211 RandomApply::RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob)
212 : data_(std::make_shared<Data>()) {
213 (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
214 [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
215 data_->prob_ = prob;
216 }
217
Parse()218 std::shared_ptr<TensorOperation> RandomApply::Parse() {
219 return std::make_shared<RandomApplyOperation>(data_->transforms_, data_->prob_);
220 }
221
222 // Constructor to RandomChoice.
223 struct RandomChoice::Data {
224 std::vector<std::shared_ptr<TensorOperation>> transforms_;
225 };
226
RandomChoice(const std::vector<TensorTransform * > & transforms)227 RandomChoice::RandomChoice(const std::vector<TensorTransform *> &transforms) : data_(std::make_shared<Data>()) {
228 (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
229 [](TensorTransform *const op) -> std::shared_ptr<TensorOperation> {
230 return op != nullptr ? op->Parse() : nullptr;
231 });
232 }
233
RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> & transforms)234 RandomChoice::RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms)
235 : data_(std::make_shared<Data>()) {
236 (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
237 [](const std::shared_ptr<TensorTransform> op) -> std::shared_ptr<TensorOperation> {
238 return op != nullptr ? op->Parse() : nullptr;
239 });
240 }
241
RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> & transforms)242 RandomChoice::RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms)
243 : data_(std::make_shared<Data>()) {
244 (void)std::transform(transforms.begin(), transforms.end(), std::back_inserter(data_->transforms_),
245 [](TensorTransform &op) -> std::shared_ptr<TensorOperation> { return op.Parse(); });
246 }
247
Parse()248 std::shared_ptr<TensorOperation> RandomChoice::Parse() {
249 return std::make_shared<RandomChoiceOperation>(data_->transforms_);
250 }
251
252 // Constructor to Slice
253 struct Slice::Data {
Datamindspore::dataset::transforms::Slice::Data254 explicit Data(const std::vector<SliceOption> &slice_input) : slice_input_(slice_input) {}
255 std::vector<SliceOption> slice_input_;
256 };
257
Slice(const std::vector<SliceOption> & slice_input)258 Slice::Slice(const std::vector<SliceOption> &slice_input) : data_(std::make_shared<Data>(slice_input)) {}
259
Parse()260 std::shared_ptr<TensorOperation> Slice::Parse() {
261 #ifndef ENABLE_ANDROID
262 return std::make_shared<SliceOperation>(data_->slice_input_);
263 #else
264 MS_LOG(ERROR) << "Slice op is not supported for Android.";
265 return nullptr;
266 #endif // not ENABLE_ANDROID
267 }
268
269 // Constructor to TypeCast
270 struct TypeCast::Data {
271 dataset::DataType data_type_;
272 };
273
TypeCast(mindspore::DataType data_type)274 TypeCast::TypeCast(mindspore::DataType data_type) : data_(std::make_shared<Data>()) {
275 data_->data_type_ = dataset::MSTypeToDEType(static_cast<TypeId>(data_type));
276 }
277
Parse()278 std::shared_ptr<TensorOperation> TypeCast::Parse() { return std::make_shared<TypeCastOperation>(data_->data_type_); }
279
280 // Constructor to Unique
Unique()281 Unique::Unique() {}
282
283 #ifndef ENABLE_ANDROID
Parse()284 std::shared_ptr<TensorOperation> Unique::Parse() { return std::make_shared<UniqueOperation>(); }
285 #else
Parse()286 std::shared_ptr<TensorOperation> Unique::Parse() {
287 MS_LOG(ERROR) << "Unique op is not supported for Android.";
288 return nullptr;
289 }
290 #endif
291 } // namespace transforms
292 } // namespace dataset
293 } // namespace mindspore
294