1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include <utility>
19
20 #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
21
22 // Kernel data headers (in alphabetical order)
23 #include "minddata/dataset/kernels/data/compose_op.h"
24 #ifndef ENABLE_ANDROID
25 #include "minddata/dataset/kernels/data/concatenate_op.h"
26 #endif
27 #include "minddata/dataset/kernels/data/duplicate_op.h"
28 #ifndef ENABLE_ANDROID
29 #include "minddata/dataset/kernels/data/fill_op.h"
30 #include "minddata/dataset/kernels/data/mask_op.h"
31 #endif
32 #include "minddata/dataset/kernels/data/one_hot_op.h"
33 #ifndef ENABLE_ANDROID
34 #include "minddata/dataset/kernels/data/pad_end_op.h"
35 #endif
36 #include "minddata/dataset/kernels/data/random_apply_op.h"
37 #include "minddata/dataset/kernels/data/random_choice_op.h"
38 #ifndef ENABLE_ANDROID
39 #include "minddata/dataset/kernels/data/slice_op.h"
40 #endif
41 #include "minddata/dataset/kernels/data/type_cast_op.h"
42
43 #ifndef ENABLE_ANDROID
44 #include "minddata/dataset/kernels/data/unique_op.h"
45 #include "minddata/dataset/kernels/plugin_op.h"
46 #endif
47
48 #include "minddata/dataset/kernels/ir/validators.h"
49 #ifdef ENABLE_PYTHON
50 #include "minddata/dataset/kernels/py_func_op.h"
51 #endif
52
53 namespace mindspore {
54 namespace dataset {
55 // Transform operations for data.
56 namespace transforms {
57 /* ####################################### Derived TensorOperation classes ################################# */
58
59 // (In alphabetical order)
60
61 // ComposeOperation
ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> & transforms)62 ComposeOperation::ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
63 : transforms_(transforms) {}
64
ValidateParams()65 Status ComposeOperation::ValidateParams() {
66 RETURN_IF_NOT_OK(ValidateVectorTransforms("Compose", transforms_));
67 return Status::OK();
68 }
69
Build()70 std::shared_ptr<TensorOp> ComposeOperation::Build() {
71 std::vector<std::shared_ptr<TensorOp>> tensor_ops;
72 (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
73 [](const auto &op) -> std::shared_ptr<TensorOp> { return op->Build(); });
74 return std::make_shared<ComposeOp>(tensor_ops);
75 }
76
77 #ifndef ENABLE_ANDROID
78 // ConcatenateOperation
ConcatenateOperation(int8_t axis,const std::shared_ptr<Tensor> & prepend,const std::shared_ptr<Tensor> & append)79 ConcatenateOperation::ConcatenateOperation(int8_t axis, const std::shared_ptr<Tensor> &prepend,
80 const std::shared_ptr<Tensor> &append)
81 : axis_(axis), prepend_(prepend), append_(append) {}
82
ValidateParams()83 Status ConcatenateOperation::ValidateParams() {
84 if (axis_ != 0 && axis_ != -1) {
85 std::string err_msg = "Concatenate: Only 1D concatenation supported.";
86 MS_LOG(ERROR) << err_msg;
87 RETURN_STATUS_SYNTAX_ERROR(err_msg);
88 }
89 if (prepend_) {
90 if (prepend_->shape().Size() != 1) {
91 std::string err_msg = "Concatenate: Can only prepend 1D arrays.";
92 MS_LOG(ERROR) << err_msg;
93 RETURN_STATUS_SYNTAX_ERROR(err_msg);
94 }
95 }
96 if (append_) {
97 if (append_->shape().Size() != 1) {
98 std::string err_msg = "Concatenate: Can only append 1D arrays.";
99 MS_LOG(ERROR) << err_msg;
100 RETURN_STATUS_SYNTAX_ERROR(err_msg);
101 }
102 }
103 return Status::OK();
104 }
105
Build()106 std::shared_ptr<TensorOp> ConcatenateOperation::Build() {
107 return std::make_shared<ConcatenateOp>(axis_, prepend_, append_);
108 }
109 #endif
110
111 // DuplicateOperation
ValidateParams()112 Status DuplicateOperation::ValidateParams() { return Status::OK(); }
113
Build()114 std::shared_ptr<TensorOp> DuplicateOperation::Build() { return std::make_shared<DuplicateOp>(); }
115
116 #ifndef ENABLE_ANDROID
117
118 // FillOperation
FillOperation(const std::shared_ptr<Tensor> & fill_value)119 FillOperation::FillOperation(const std::shared_ptr<Tensor> &fill_value) : fill_value_(fill_value) {}
120
ValidateParams()121 Status FillOperation::ValidateParams() {
122 if (fill_value_->shape() != TensorShape::CreateScalar()) {
123 std::string err_msg = "Fill: fill_value is not a scalar tensor.";
124 MS_LOG(ERROR) << err_msg;
125 RETURN_STATUS_SYNTAX_ERROR(err_msg);
126 }
127
128 return Status::OK();
129 }
130
Build()131 std::shared_ptr<TensorOp> FillOperation::Build() { return std::make_shared<FillOp>(fill_value_); }
132
to_json(nlohmann::json * out_json)133 Status FillOperation::to_json(nlohmann::json *out_json) {
134 RETURN_IF_NOT_OK(fill_value_->to_json(out_json));
135 return Status::OK();
136 }
137
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)138 Status FillOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
139 std::shared_ptr<Tensor> fill_value;
140 RETURN_IF_NOT_OK(Tensor::from_json(op_params, &fill_value));
141 *operation = std::make_shared<transforms::FillOperation>(fill_value);
142 return Status::OK();
143 }
144
145 // MaskOperation
MaskOperation(RelationalOp op,const std::shared_ptr<Tensor> & constant,const DataType & dtype)146 MaskOperation::MaskOperation(RelationalOp op, const std::shared_ptr<Tensor> &constant, const DataType &dtype)
147 : op_(op), constant_(constant), dtype_(dtype) {}
148
ValidateParams()149 Status MaskOperation::ValidateParams() {
150 if (!dtype_.IsBool() && !dtype_.IsFloat() && !dtype_.IsInt()) {
151 std::string err_msg = "Mask: Only supports bool or numeric datatype for generated mask type.";
152 MS_LOG(ERROR) << err_msg;
153 RETURN_STATUS_SYNTAX_ERROR(err_msg);
154 }
155 return Status::OK();
156 }
157
Build()158 std::shared_ptr<TensorOp> MaskOperation::Build() { return std::make_shared<MaskOp>(op_, constant_, dtype_); }
159 #endif
160
161 // OneHotOperation
OneHotOperation(int32_t num_classes)162 OneHotOperation::OneHotOperation(int32_t num_classes) : num_classes_(num_classes) {}
163
ValidateParams()164 Status OneHotOperation::ValidateParams() {
165 if (num_classes_ <= 0) {
166 std::string err_msg = "OneHot: Number of classes must be greater than 0, but got: " + std::to_string(num_classes_);
167 MS_LOG(ERROR) << err_msg;
168 RETURN_STATUS_SYNTAX_ERROR(err_msg);
169 }
170
171 return Status::OK();
172 }
173
Build()174 std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
175
to_json(nlohmann::json * out_json)176 Status OneHotOperation::to_json(nlohmann::json *out_json) {
177 nlohmann::json args;
178 args["num_classes"] = num_classes_;
179 *out_json = args;
180 return Status::OK();
181 }
182
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)183 Status OneHotOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
184 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("num_classes") != op_params.end(), "Failed tofind num_classes");
185 int32_t num_classes = op_params["num_classes"];
186 *operation = std::make_shared<transforms::OneHotOperation>(num_classes);
187 return Status::OK();
188 }
189
190 #ifndef ENABLE_ANDROID
191 // PadEndOperation
PadEndOperation(const TensorShape & pad_shape,const std::shared_ptr<Tensor> & pad_value)192 PadEndOperation::PadEndOperation(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value)
193 : pad_shape_(pad_shape), pad_value_(pad_value) {}
194
ValidateParams()195 Status PadEndOperation::ValidateParams() { return Status::OK(); }
196
Build()197 std::shared_ptr<TensorOp> PadEndOperation::Build() { return std::make_shared<PadEndOp>(pad_shape_, pad_value_); }
198 #endif
199
200 // PreBuiltOperation
PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op)201 PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(std::move(tensor_op)) {
202 #ifdef ENABLE_PYTHON
203 auto pyfunc_tensor_op = std::dynamic_pointer_cast<PyFuncOp>(tensor_op);
204 if (pyfunc_tensor_op && pyfunc_tensor_op->IsRandom()) random_op_ = true;
205 #endif
206 }
207
ValidateParams()208 Status PreBuiltOperation::ValidateParams() { return Status::OK(); }
209
Build()210 std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }
211
Name() const212 std::string PreBuiltOperation::Name() const { return op_ ? op_->Name() : kPreBuiltOperation; }
213
to_json(nlohmann::json * out_json)214 Status PreBuiltOperation::to_json(nlohmann::json *out_json) {
215 RETURN_IF_NOT_OK(op_->to_json(out_json));
216 return Status::OK();
217 }
218
219 // RandomApplyOperation
RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> & transforms,double prob)220 RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
221 : TensorOperation(true), transforms_(transforms), prob_(prob) {}
222
ValidateParams()223 Status RandomApplyOperation::ValidateParams() {
224 RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_));
225 RETURN_IF_NOT_OK(ValidateProbability("RandomApply", prob_));
226 return Status::OK();
227 }
228
Build()229 std::shared_ptr<TensorOp> RandomApplyOperation::Build() {
230 std::vector<std::shared_ptr<TensorOp>> tensor_ops;
231 (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
232 [](std::shared_ptr<TensorOperation> op) -> std::shared_ptr<TensorOp> { return op->Build(); });
233 return std::make_shared<RandomApplyOp>(tensor_ops, prob_);
234 }
235
236 // RandomChoiceOperation
RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> & transforms)237 RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
238 : TensorOperation(true), transforms_(transforms) {}
239
ValidateParams()240 Status RandomChoiceOperation::ValidateParams() {
241 RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_));
242 return Status::OK();
243 }
244
Build()245 std::shared_ptr<TensorOp> RandomChoiceOperation::Build() {
246 std::vector<std::shared_ptr<TensorOp>> tensor_ops;
247 (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
248 [](const auto &op) -> std::shared_ptr<TensorOp> { return op->Build(); });
249 return std::make_shared<RandomChoiceOp>(tensor_ops);
250 }
251
252 #ifndef ENABLE_ANDROID
253 // SliceOperation
SliceOperation(const std::vector<SliceOption> & slice_input)254 SliceOperation::SliceOperation(const std::vector<SliceOption> &slice_input) : slice_input_(slice_input) {}
255
ValidateParams()256 Status SliceOperation::ValidateParams() { return Status::OK(); }
257
Build()258 std::shared_ptr<TensorOp> SliceOperation::Build() { return std::make_shared<SliceOp>(slice_input_); }
259 #endif
260
261 // TypeCastOperation
262 // DataType data_type - required for C++ API
TypeCastOperation(const DataType & data_type)263 TypeCastOperation::TypeCastOperation(const DataType &data_type) : data_type_(data_type) {}
264
265 // std::string data_type - required for Pybind
TypeCastOperation(const std::string & data_type)266 TypeCastOperation::TypeCastOperation(const std::string &data_type) {
267 // Convert from string to DEType
268 DataType temp_data_type(data_type);
269 data_type_ = temp_data_type;
270 }
271
ValidateParams()272 Status TypeCastOperation::ValidateParams() {
273 if (data_type_ == DataType::DE_UNKNOWN) {
274 std::string err_msg = "TypeCast: Invalid data type";
275 MS_LOG(ERROR) << err_msg;
276 RETURN_STATUS_SYNTAX_ERROR(err_msg);
277 }
278 return Status::OK();
279 }
280
Build()281 std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }
282
to_json(nlohmann::json * out_json)283 Status TypeCastOperation::to_json(nlohmann::json *out_json) {
284 nlohmann::json args;
285 args["data_type"] = data_type_.ToString();
286 *out_json = args;
287 return Status::OK();
288 }
289
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)290 Status TypeCastOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
291 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("data_type") != op_params.end(), "Failed tofind data_type");
292 std::string data_type = op_params["data_type"];
293 *operation = std::make_shared<transforms::TypeCastOperation>(data_type);
294 return Status::OK();
295 }
296
297 #ifndef ENABLE_ANDROID
298 // UniqueOperation
ValidateParams()299 Status UniqueOperation::ValidateParams() { return Status::OK(); }
300
Build()301 std::shared_ptr<TensorOp> UniqueOperation::Build() { return std::make_shared<UniqueOp>(); }
ValidateParams()302 Status PluginOperation::ValidateParams() {
303 std::string err_msg;
304 err_msg += lib_path_.empty() ? "lib_path is empty, please specify a path to .so file. " : "";
305 err_msg += func_name_.empty() ? "func_name_ is empty, please specify function name to load." : "";
306 if (!err_msg.empty()) {
307 RETURN_STATUS_SYNTAX_ERROR(err_msg);
308 }
309 return Status::OK();
310 }
Build()311 std::shared_ptr<TensorOp> PluginOperation::Build() {
312 return std::make_shared<PluginOp>(lib_path_, func_name_, user_args_);
313 }
314 #endif
315 } // namespace transforms
316 } // namespace dataset
317 } // namespace mindspore
318