• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <utility>
19 
20 #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
21 
22 // Kernel data headers (in alphabetical order)
23 #include "minddata/dataset/kernels/data/compose_op.h"
24 #ifndef ENABLE_ANDROID
25 #include "minddata/dataset/kernels/data/concatenate_op.h"
26 #endif
27 #include "minddata/dataset/kernels/data/duplicate_op.h"
28 #ifndef ENABLE_ANDROID
29 #include "minddata/dataset/kernels/data/fill_op.h"
30 #include "minddata/dataset/kernels/data/mask_op.h"
31 #endif
32 #include "minddata/dataset/kernels/data/one_hot_op.h"
33 #ifndef ENABLE_ANDROID
34 #include "minddata/dataset/kernels/data/pad_end_op.h"
35 #endif
36 #include "minddata/dataset/kernels/data/random_apply_op.h"
37 #include "minddata/dataset/kernels/data/random_choice_op.h"
38 #ifndef ENABLE_ANDROID
39 #include "minddata/dataset/kernels/data/slice_op.h"
40 #endif
41 #include "minddata/dataset/kernels/data/type_cast_op.h"
42 
43 #ifndef ENABLE_ANDROID
44 #include "minddata/dataset/kernels/data/unique_op.h"
45 #include "minddata/dataset/kernels/plugin_op.h"
46 #endif
47 
48 #include "minddata/dataset/kernels/ir/validators.h"
49 #ifdef ENABLE_PYTHON
50 #include "minddata/dataset/kernels/py_func_op.h"
51 #endif
52 
53 namespace mindspore {
54 namespace dataset {
55 // Transform operations for data.
56 namespace transforms {
57 /* ####################################### Derived TensorOperation classes ################################# */
58 
59 // (In alphabetical order)
60 
61 // ComposeOperation
ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> & transforms)62 ComposeOperation::ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
63     : transforms_(transforms) {}
64 
ValidateParams()65 Status ComposeOperation::ValidateParams() {
66   RETURN_IF_NOT_OK(ValidateVectorTransforms("Compose", transforms_));
67   return Status::OK();
68 }
69 
Build()70 std::shared_ptr<TensorOp> ComposeOperation::Build() {
71   std::vector<std::shared_ptr<TensorOp>> tensor_ops;
72   (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
73                        [](const auto &op) -> std::shared_ptr<TensorOp> { return op->Build(); });
74   return std::make_shared<ComposeOp>(tensor_ops);
75 }
76 
77 #ifndef ENABLE_ANDROID
78 // ConcatenateOperation
ConcatenateOperation(int8_t axis,const std::shared_ptr<Tensor> & prepend,const std::shared_ptr<Tensor> & append)79 ConcatenateOperation::ConcatenateOperation(int8_t axis, const std::shared_ptr<Tensor> &prepend,
80                                            const std::shared_ptr<Tensor> &append)
81     : axis_(axis), prepend_(prepend), append_(append) {}
82 
ValidateParams()83 Status ConcatenateOperation::ValidateParams() {
84   if (axis_ != 0 && axis_ != -1) {
85     std::string err_msg = "Concatenate: Only 1D concatenation supported.";
86     MS_LOG(ERROR) << err_msg;
87     RETURN_STATUS_SYNTAX_ERROR(err_msg);
88   }
89   if (prepend_) {
90     if (prepend_->shape().Size() != 1) {
91       std::string err_msg = "Concatenate: Can only prepend 1D arrays.";
92       MS_LOG(ERROR) << err_msg;
93       RETURN_STATUS_SYNTAX_ERROR(err_msg);
94     }
95   }
96   if (append_) {
97     if (append_->shape().Size() != 1) {
98       std::string err_msg = "Concatenate: Can only append 1D arrays.";
99       MS_LOG(ERROR) << err_msg;
100       RETURN_STATUS_SYNTAX_ERROR(err_msg);
101     }
102   }
103   return Status::OK();
104 }
105 
Build()106 std::shared_ptr<TensorOp> ConcatenateOperation::Build() {
107   return std::make_shared<ConcatenateOp>(axis_, prepend_, append_);
108 }
109 #endif
110 
111 // DuplicateOperation
ValidateParams()112 Status DuplicateOperation::ValidateParams() { return Status::OK(); }
113 
Build()114 std::shared_ptr<TensorOp> DuplicateOperation::Build() { return std::make_shared<DuplicateOp>(); }
115 
116 #ifndef ENABLE_ANDROID
117 
118 // FillOperation
FillOperation(const std::shared_ptr<Tensor> & fill_value)119 FillOperation::FillOperation(const std::shared_ptr<Tensor> &fill_value) : fill_value_(fill_value) {}
120 
ValidateParams()121 Status FillOperation::ValidateParams() {
122   if (fill_value_->shape() != TensorShape::CreateScalar()) {
123     std::string err_msg = "Fill: fill_value is not a scalar tensor.";
124     MS_LOG(ERROR) << err_msg;
125     RETURN_STATUS_SYNTAX_ERROR(err_msg);
126   }
127 
128   return Status::OK();
129 }
130 
Build()131 std::shared_ptr<TensorOp> FillOperation::Build() { return std::make_shared<FillOp>(fill_value_); }
132 
to_json(nlohmann::json * out_json)133 Status FillOperation::to_json(nlohmann::json *out_json) {
134   RETURN_IF_NOT_OK(fill_value_->to_json(out_json));
135   return Status::OK();
136 }
137 
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)138 Status FillOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
139   std::shared_ptr<Tensor> fill_value;
140   RETURN_IF_NOT_OK(Tensor::from_json(op_params, &fill_value));
141   *operation = std::make_shared<transforms::FillOperation>(fill_value);
142   return Status::OK();
143 }
144 
145 // MaskOperation
MaskOperation(RelationalOp op,const std::shared_ptr<Tensor> & constant,const DataType & dtype)146 MaskOperation::MaskOperation(RelationalOp op, const std::shared_ptr<Tensor> &constant, const DataType &dtype)
147     : op_(op), constant_(constant), dtype_(dtype) {}
148 
ValidateParams()149 Status MaskOperation::ValidateParams() {
150   if (!dtype_.IsBool() && !dtype_.IsFloat() && !dtype_.IsInt()) {
151     std::string err_msg = "Mask: Only supports bool or numeric datatype for generated mask type.";
152     MS_LOG(ERROR) << err_msg;
153     RETURN_STATUS_SYNTAX_ERROR(err_msg);
154   }
155   return Status::OK();
156 }
157 
Build()158 std::shared_ptr<TensorOp> MaskOperation::Build() { return std::make_shared<MaskOp>(op_, constant_, dtype_); }
159 #endif
160 
161 // OneHotOperation
OneHotOperation(int32_t num_classes)162 OneHotOperation::OneHotOperation(int32_t num_classes) : num_classes_(num_classes) {}
163 
ValidateParams()164 Status OneHotOperation::ValidateParams() {
165   if (num_classes_ <= 0) {
166     std::string err_msg = "OneHot: Number of classes must be greater than 0, but got: " + std::to_string(num_classes_);
167     MS_LOG(ERROR) << err_msg;
168     RETURN_STATUS_SYNTAX_ERROR(err_msg);
169   }
170 
171   return Status::OK();
172 }
173 
Build()174 std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
175 
to_json(nlohmann::json * out_json)176 Status OneHotOperation::to_json(nlohmann::json *out_json) {
177   nlohmann::json args;
178   args["num_classes"] = num_classes_;
179   *out_json = args;
180   return Status::OK();
181 }
182 
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)183 Status OneHotOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
184   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("num_classes") != op_params.end(), "Failed tofind num_classes");
185   int32_t num_classes = op_params["num_classes"];
186   *operation = std::make_shared<transforms::OneHotOperation>(num_classes);
187   return Status::OK();
188 }
189 
190 #ifndef ENABLE_ANDROID
191 // PadEndOperation
PadEndOperation(const TensorShape & pad_shape,const std::shared_ptr<Tensor> & pad_value)192 PadEndOperation::PadEndOperation(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value)
193     : pad_shape_(pad_shape), pad_value_(pad_value) {}
194 
ValidateParams()195 Status PadEndOperation::ValidateParams() { return Status::OK(); }
196 
Build()197 std::shared_ptr<TensorOp> PadEndOperation::Build() { return std::make_shared<PadEndOp>(pad_shape_, pad_value_); }
198 #endif
199 
200 // PreBuiltOperation
PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op)201 PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(std::move(tensor_op)) {
202 #ifdef ENABLE_PYTHON
203   auto pyfunc_tensor_op = std::dynamic_pointer_cast<PyFuncOp>(tensor_op);
204   if (pyfunc_tensor_op && pyfunc_tensor_op->IsRandom()) random_op_ = true;
205 #endif
206 }
207 
ValidateParams()208 Status PreBuiltOperation::ValidateParams() { return Status::OK(); }
209 
Build()210 std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }
211 
Name() const212 std::string PreBuiltOperation::Name() const { return op_ ? op_->Name() : kPreBuiltOperation; }
213 
to_json(nlohmann::json * out_json)214 Status PreBuiltOperation::to_json(nlohmann::json *out_json) {
215   RETURN_IF_NOT_OK(op_->to_json(out_json));
216   return Status::OK();
217 }
218 
219 // RandomApplyOperation
RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> & transforms,double prob)220 RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
221     : TensorOperation(true), transforms_(transforms), prob_(prob) {}
222 
ValidateParams()223 Status RandomApplyOperation::ValidateParams() {
224   RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_));
225   RETURN_IF_NOT_OK(ValidateProbability("RandomApply", prob_));
226   return Status::OK();
227 }
228 
Build()229 std::shared_ptr<TensorOp> RandomApplyOperation::Build() {
230   std::vector<std::shared_ptr<TensorOp>> tensor_ops;
231   (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
232                        [](std::shared_ptr<TensorOperation> op) -> std::shared_ptr<TensorOp> { return op->Build(); });
233   return std::make_shared<RandomApplyOp>(tensor_ops, prob_);
234 }
235 
236 // RandomChoiceOperation
RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> & transforms)237 RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
238     : TensorOperation(true), transforms_(transforms) {}
239 
ValidateParams()240 Status RandomChoiceOperation::ValidateParams() {
241   RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_));
242   return Status::OK();
243 }
244 
Build()245 std::shared_ptr<TensorOp> RandomChoiceOperation::Build() {
246   std::vector<std::shared_ptr<TensorOp>> tensor_ops;
247   (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
248                        [](const auto &op) -> std::shared_ptr<TensorOp> { return op->Build(); });
249   return std::make_shared<RandomChoiceOp>(tensor_ops);
250 }
251 
252 #ifndef ENABLE_ANDROID
253 // SliceOperation
SliceOperation(const std::vector<SliceOption> & slice_input)254 SliceOperation::SliceOperation(const std::vector<SliceOption> &slice_input) : slice_input_(slice_input) {}
255 
ValidateParams()256 Status SliceOperation::ValidateParams() { return Status::OK(); }
257 
Build()258 std::shared_ptr<TensorOp> SliceOperation::Build() { return std::make_shared<SliceOp>(slice_input_); }
259 #endif
260 
261 // TypeCastOperation
262 // DataType data_type - required for C++ API
TypeCastOperation(const DataType & data_type)263 TypeCastOperation::TypeCastOperation(const DataType &data_type) : data_type_(data_type) {}
264 
265 // std::string data_type - required for Pybind
TypeCastOperation(const std::string & data_type)266 TypeCastOperation::TypeCastOperation(const std::string &data_type) {
267   // Convert from string to DEType
268   DataType temp_data_type(data_type);
269   data_type_ = temp_data_type;
270 }
271 
ValidateParams()272 Status TypeCastOperation::ValidateParams() {
273   if (data_type_ == DataType::DE_UNKNOWN) {
274     std::string err_msg = "TypeCast: Invalid data type";
275     MS_LOG(ERROR) << err_msg;
276     RETURN_STATUS_SYNTAX_ERROR(err_msg);
277   }
278   return Status::OK();
279 }
280 
Build()281 std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }
282 
to_json(nlohmann::json * out_json)283 Status TypeCastOperation::to_json(nlohmann::json *out_json) {
284   nlohmann::json args;
285   args["data_type"] = data_type_.ToString();
286   *out_json = args;
287   return Status::OK();
288 }
289 
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)290 Status TypeCastOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
291   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("data_type") != op_params.end(), "Failed tofind data_type");
292   std::string data_type = op_params["data_type"];
293   *operation = std::make_shared<transforms::TypeCastOperation>(data_type);
294   return Status::OK();
295 }
296 
297 #ifndef ENABLE_ANDROID
298 // UniqueOperation
ValidateParams()299 Status UniqueOperation::ValidateParams() { return Status::OK(); }
300 
Build()301 std::shared_ptr<TensorOp> UniqueOperation::Build() { return std::make_shared<UniqueOp>(); }
ValidateParams()302 Status PluginOperation::ValidateParams() {
303   std::string err_msg;
304   err_msg += lib_path_.empty() ? "lib_path is empty, please specify a path to .so file. " : "";
305   err_msg += func_name_.empty() ? "func_name_ is empty, please specify function name to load." : "";
306   if (!err_msg.empty()) {
307     RETURN_STATUS_SYNTAX_ERROR(err_msg);
308   }
309   return Status::OK();
310 }
Build()311 std::shared_ptr<TensorOp> PluginOperation::Build() {
312   return std::make_shared<PluginOp>(lib_path_, func_name_, user_args_);
313 }
314 #endif
315 }  // namespace transforms
316 }  // namespace dataset
317 }  // namespace mindspore
318