/**
 * Copyright 2020-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/kernels/ir/data/transforms_ir.h"

#include <algorithm>
#include <utility>

#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/serdes.h"
#endif

// Kernel data headers (in alphabetical order)
#include "minddata/dataset/kernels/data/compose_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/data/concatenate_op.h"
#endif
#include "minddata/dataset/kernels/data/duplicate_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/data/fill_op.h"
#include "minddata/dataset/kernels/data/mask_op.h"
#endif
#include "minddata/dataset/kernels/data/one_hot_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/data/pad_end_op.h"
#include "minddata/dataset/kernels/data/parse_example_op.h"
#endif
#include "minddata/dataset/kernels/data/random_apply_op.h"
#include "minddata/dataset/kernels/data/random_choice_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/data/slice_op.h"
#endif
#include "minddata/dataset/kernels/data/type_cast_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/data/unique_op.h"
#endif

#include "minddata/dataset/kernels/ir/validators.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/plugin_op.h"
#endif
#ifdef ENABLE_PYTHON
#include "minddata/dataset/kernels/py_func_op.h"
#endif
#include "minddata/dataset/util/validators.h"

namespace mindspore {
namespace dataset {
// Transform operations for data.
namespace transforms {
/* ####################################### Derived TensorOperation classes ################################# */
// (In alphabetical order)

// ComposeOperation
ComposeOperation::ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
    : transforms_(transforms) {}

Status ComposeOperation::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateVectorTransforms("Compose", transforms_));
  return Status::OK();
}

std::shared_ptr<TensorOp> ComposeOperation::Build() {
  std::vector<std::shared_ptr<TensorOp>> tensor_ops;
  (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
                       [](const auto &op) -> std::shared_ptr<TensorOp> { return op->Build(); });
  return std::make_shared<ComposeOp>(tensor_ops);
}

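// Serialized layout (illustrative): each child transform contributes one object carrying its name and
// parameters, e.g. {"transforms": [{"tensor_op_name": "OneHot", "tensor_op_params": {...}}, ...]}.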
Status ComposeOperation::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  auto transforms = nlohmann::json::array();
  for (auto &tensor_operation : transforms_) {
    nlohmann::json tensor_op, args;
    RETURN_IF_NOT_OK(tensor_operation->to_json(&args));
    tensor_op["tensor_op_params"] = args;
    tensor_op["tensor_op_name"] = tensor_operation->Name();
    transforms.push_back(tensor_op);
  }
  (*out_json)["transforms"] = transforms;
  return Status::OK();
}

#ifndef ENABLE_ANDROID
Status ComposeOperation::from_json(const nlohmann::json &op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "transforms", kComposeOperation));
  nlohmann::json transforms = op_params["transforms"];
  std::vector<std::shared_ptr<TensorOperation>> operations;
  RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(transforms, &operations));
  *operation = std::make_shared<transforms::ComposeOperation>(operations);
  return Status::OK();
}

// ConcatenateOperation
ConcatenateOperation::ConcatenateOperation(int8_t axis, const std::shared_ptr<Tensor> &prepend,
                                           const std::shared_ptr<Tensor> &append)
    : axis_(axis), prepend_(prepend), append_(append) {}

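// Only 1-D concatenation is supported: axis must be 0 or -1, and prepend/append tensors, if provided,
// must be rank 1.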
Status ConcatenateOperation::ValidateParams() {
  if (axis_ != 0 && axis_ != -1) {
    std::string err_msg =
      "Concatenate: Only 1D concatenation supported, input 'axis' should be 0 or -1, but got:" + std::to_string(axis_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (prepend_) {
    if (prepend_->shape().Size() != 1) {
      std::string err_msg = "Concatenate: Can only prepend 1D arrays, rank of input 'prepend' should be 1, but got:" +
                            std::to_string(prepend_->shape().Size());
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }
  if (append_) {
    if (append_->shape().Size() != 1) {
      std::string err_msg = "Concatenate: Can only append 1D arrays, rank of input 'append' should be 1, but got:" +
                            std::to_string(append_->shape().Size());
      LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> ConcatenateOperation::Build() {
  return std::make_shared<ConcatenateOp>(axis_, prepend_, append_);
}

Status ConcatenateOperation::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json args;
  args["axis"] = axis_;
  nlohmann::json prepend;
  nlohmann::json append;
  RETURN_IF_NOT_OK(prepend_->to_json(&prepend));
  RETURN_IF_NOT_OK(append_->to_json(&append));
  args["prepend"] = prepend;
  args["append"] = append;
  *out_json = args;
  return Status::OK();
}

Status ConcatenateOperation::from_json(const nlohmann::json &op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "axis", kConcatenateOperation));
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "prepend", kConcatenateOperation));
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "append", kConcatenateOperation));
  int8_t axis = op_params["axis"];
  std::shared_ptr<Tensor> prepend;
  std::shared_ptr<Tensor> append;
  RETURN_IF_NOT_OK(Tensor::from_json(op_params["prepend"], &prepend));
  RETURN_IF_NOT_OK(Tensor::from_json(op_params["append"], &append));
  *operation = std::make_shared<transforms::ConcatenateOperation>(axis, prepend, append);
  return Status::OK();
}
#endif

// DuplicateOperation
Status DuplicateOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> DuplicateOperation::Build() { return std::make_shared<DuplicateOp>(); }

Status DuplicateOperation::from_json(const nlohmann::json &op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  *operation = std::make_shared<transforms::DuplicateOperation>();
  return Status::OK();
}

#ifndef ENABLE_ANDROID
// FillOperation
FillOperation::FillOperation(const std::shared_ptr<Tensor> &fill_value) : fill_value_(fill_value) {}

Status FillOperation::ValidateParams() {
  if (fill_value_->shape() != TensorShape::CreateScalar()) {
    std::string err_msg = "Fill: fill_value is not a scalar tensor, got shape:" + fill_value_->shape().ToString();
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

std::shared_ptr<TensorOp> FillOperation::Build() { return std::make_shared<FillOp>(fill_value_); }

Status FillOperation::to_json(nlohmann::json *out_json) {
  RETURN_IF_NOT_OK(fill_value_->to_json(out_json));
  return Status::OK();
}

Status FillOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  std::shared_ptr<Tensor> fill_value;
  RETURN_IF_NOT_OK(Tensor::from_json(op_params, &fill_value));
  *operation = std::make_shared<transforms::FillOperation>(fill_value);
  return Status::OK();
}

// MaskOperation
MaskOperation::MaskOperation(RelationalOp op, const std::shared_ptr<Tensor> &constant, const DataType &dtype)
    : op_(op), constant_(constant), dtype_(dtype) {}

Status MaskOperation::ValidateParams() {
  if (!dtype_.IsBool() && !dtype_.IsFloat() && !dtype_.IsInt()) {
    std::string err_msg =
      "Mask: Only supports bool or numeric datatype for generated mask type, but got:" + dtype_.ToString();
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> MaskOperation::Build() { return std::make_shared<MaskOp>(op_, constant_, dtype_); }

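// Note: dtype is serialized as its numeric DataType::Type value and restored with a static_cast in from_json.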
Status MaskOperation::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json args;
  args["op"] = op_;
  nlohmann::json constant;
  RETURN_IF_NOT_OK(constant_->to_json(&constant));
  args["constant"] = constant;
  args["dtype"] = dtype_.value();
  *out_json = args;
  return Status::OK();
}

Status MaskOperation::from_json(const nlohmann::json &op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "op", kMaskOperation));
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "constant", kMaskOperation));
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "dtype", kMaskOperation));
  RelationalOp op = op_params["op"];
  std::shared_ptr<Tensor> constant;
  RETURN_IF_NOT_OK(Tensor::from_json(op_params["constant"], &constant));
  auto dtype = DataType(static_cast<DataType::Type>(op_params["dtype"]));
  *operation = std::make_shared<transforms::MaskOperation>(op, constant, dtype);
  return Status::OK();
}
#endif

// OneHotOperation
OneHotOperation::OneHotOperation(int32_t num_classes, double smoothing_rate)
    : num_classes_(num_classes), smoothing_rate_(smoothing_rate) {}

Status OneHotOperation::ValidateParams() {
  if (num_classes_ <= 0) {
    std::string err_msg = "OneHot: Number of classes must be greater than 0, but got: " + std::to_string(num_classes_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (smoothing_rate_ < 0.0 || smoothing_rate_ > 1.0) {
    std::string err_msg = "OneHot: Smoothing rate must be between 0 and 1, but got: " + std::to_string(smoothing_rate_);
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_, smoothing_rate_); }

Status OneHotOperation::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json args;
  args["num_classes"] = num_classes_;
  args["smoothing_rate"] = smoothing_rate_;

  *out_json = args;
  return Status::OK();
}

Status OneHotOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "num_classes", kOneHotOperation));
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "smoothing_rate", kOneHotOperation));
  int32_t num_classes = op_params["num_classes"];
  double smoothing_rate = op_params["smoothing_rate"];
  *operation = std::make_shared<transforms::OneHotOperation>(num_classes, smoothing_rate);
  return Status::OK();
}

#ifndef ENABLE_ANDROID
// PadEndOperation
PadEndOperation::PadEndOperation(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value)
    : pad_shape_(pad_shape), pad_value_(pad_value) {}

Status PadEndOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> PadEndOperation::Build() { return std::make_shared<PadEndOp>(pad_shape_, pad_value_); }

Status PadEndOperation::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json args;
  args["pad_shape"] = pad_shape_.AsVector();
  nlohmann::json pad_value;
  RETURN_IF_NOT_OK(pad_value_->to_json(&pad_value));
  args["pad_value"] = pad_value;
  *out_json = args;
  return Status::OK();
}

Status PadEndOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "pad_shape", kPadEndOperation));
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "pad_value", kPadEndOperation));
  std::vector<dsize_t> shape_vector = op_params["pad_shape"];
  TensorShape pad_shape = TensorShape(shape_vector);
  std::shared_ptr<Tensor> pad_value;
  RETURN_IF_NOT_OK(Tensor::from_json(op_params["pad_value"], &pad_value));
  *operation = std::make_shared<transforms::PadEndOperation>(pad_shape, pad_value);
  return Status::OK();
}

#if !defined(_WIN32) && !defined(_WIN64)
// ParseExampleOperation
ParseExampleOperation::ParseExampleOperation(DataSchema schema, std::vector<std::string> column_list,
                                             bool parallel_parse)
    : schema_(std::move(schema)), column_list_(std::move(column_list)), parallel_parse_(parallel_parse) {}

std::shared_ptr<TensorOp> ParseExampleOperation::Build() {
  return std::make_shared<ParseExampleOp>(schema_, column_list_, parallel_parse_);
}
#endif
#endif

// PreBuiltOperation
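// Wraps an already-constructed TensorOp so it can be used where a TensorOperation is expected; with Python
// enabled, a wrapped PyFuncOp that reports itself as random marks this operation as random too.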
PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(std::move(tensor_op)) {
#ifdef ENABLE_PYTHON
  auto pyfunc_tensor_op = std::dynamic_pointer_cast<PyFuncOp>(op_);
  if (pyfunc_tensor_op && pyfunc_tensor_op->IsRandom()) {
    random_op_ = true;
  }
#endif
}

Status PreBuiltOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }

std::string PreBuiltOperation::Name() const { return op_ ? op_->Name() : kPreBuiltOperation; }

Status PreBuiltOperation::to_json(nlohmann::json *out_json) {
  RETURN_IF_NOT_OK(op_->to_json(out_json));
  return Status::OK();
}

// RandomApplyOperation
RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
    : TensorOperation(true), transforms_(transforms), prob_(prob) {}

Status RandomApplyOperation::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_));
  RETURN_IF_NOT_OK(ValidateProbability("RandomApply", prob_));
  return Status::OK();
}

std::shared_ptr<TensorOp> RandomApplyOperation::Build() {
  std::vector<std::shared_ptr<TensorOp>> tensor_ops;
  (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
                       [](std::shared_ptr<TensorOperation> op) -> std::shared_ptr<TensorOp> { return op->Build(); });
  return std::make_shared<RandomApplyOp>(tensor_ops, prob_);
}

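// Same serialized layout as Compose, plus a top-level "prob" field,
// e.g. {"transforms": [...], "prob": 0.5}.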
Status RandomApplyOperation::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  auto transforms = nlohmann::json::array();
  for (auto &tensor_operation : transforms_) {
    nlohmann::json tensor_op, args;
    RETURN_IF_NOT_OK(tensor_operation->to_json(&args));
    tensor_op["tensor_op_params"] = args;
    tensor_op["tensor_op_name"] = tensor_operation->Name();
    transforms.push_back(tensor_op);
  }
  (*out_json)["transforms"] = transforms;
  (*out_json)["prob"] = prob_;
  return Status::OK();
}

#ifndef ENABLE_ANDROID
Status RandomApplyOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "transforms", kRandomApplyOperation));
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "prob", kRandomApplyOperation));
  nlohmann::json transforms = op_params["transforms"];
  std::vector<std::shared_ptr<TensorOperation>> operations;
  RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(transforms, &operations));
  double prob = op_params["prob"];
  *operation = std::make_shared<transforms::RandomApplyOperation>(operations, prob);
  return Status::OK();
}
#endif

// RandomChoiceOperation
RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
    : TensorOperation(true), transforms_(transforms) {}

Status RandomChoiceOperation::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_));
  return Status::OK();
}

std::shared_ptr<TensorOp> RandomChoiceOperation::Build() {
  std::vector<std::shared_ptr<TensorOp>> tensor_ops;
  (void)std::transform(transforms_.begin(), transforms_.end(), std::back_inserter(tensor_ops),
                       [](const auto &op) -> std::shared_ptr<TensorOp> { return op->Build(); });
  return std::make_shared<RandomChoiceOp>(tensor_ops);
}

Status RandomChoiceOperation::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  auto transforms = nlohmann::json::array();
  for (auto &tensor_operation : transforms_) {
    nlohmann::json tensor_op, args;
    RETURN_IF_NOT_OK(tensor_operation->to_json(&args));
    tensor_op["tensor_op_params"] = args;
    tensor_op["tensor_op_name"] = tensor_operation->Name();
    transforms.push_back(tensor_op);
  }
  (*out_json)["transforms"] = transforms;
  return Status::OK();
}

#ifndef ENABLE_ANDROID
Status RandomChoiceOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "transforms", kRandomChoiceOperation));
  nlohmann::json transforms = op_params["transforms"];
  std::vector<std::shared_ptr<TensorOperation>> operations;
  RETURN_IF_NOT_OK(Serdes::ConstructTensorOps(transforms, &operations));
  *operation = std::make_shared<transforms::RandomChoiceOperation>(operations);
  return Status::OK();
}

// SliceOperation
SliceOperation::SliceOperation(const std::vector<SliceOption> &slice_input) : slice_input_(slice_input) {}

Status SliceOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> SliceOperation::Build() { return std::make_shared<SliceOp>(slice_input_); }
#endif

// TypeCastOperation
// DataType data_type - required for C++ API
TypeCastOperation::TypeCastOperation(const DataType &data_type) : data_type_(data_type) {}

// std::string data_type - required for Pybind
TypeCastOperation::TypeCastOperation(const std::string &data_type) {
  // Convert from string to DEType
  DataType temp_data_type(data_type);
  data_type_ = temp_data_type;
}

Status TypeCastOperation::ValidateParams() {
  if (data_type_ == DataType::DE_UNKNOWN) {
    std::string err_msg = "TypeCast: Invalid data type";
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }

Status TypeCastOperation::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json args;
  args["data_type"] = data_type_.ToString();
  *out_json = args;
  return Status::OK();
}

Status TypeCastOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "data_type", kTypeCastOperation));
  std::string data_type = op_params["data_type"];
  *operation = std::make_shared<transforms::TypeCastOperation>(data_type);
  return Status::OK();
}

#ifndef ENABLE_ANDROID
// UniqueOperation
Status UniqueOperation::ValidateParams() { return Status::OK(); }

std::shared_ptr<TensorOp> UniqueOperation::Build() { return std::make_shared<UniqueOp>(); }

Status UniqueOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  *operation = std::make_shared<transforms::UniqueOperation>();
  return Status::OK();
}

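// PluginOperation
// Runs a user plugin: lib_path should point to the shared library (.so) and func_name names the
// plugin function to load; both must be non-empty.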
Status PluginOperation::ValidateParams() {
  std::string err_msg;
  err_msg += lib_path_.empty() ? "lib_path is empty, please specify a path to .so file. " : "";
  err_msg += func_name_.empty() ? "func_name_ is empty, please specify function name to load." : "";
  if (!err_msg.empty()) {
    LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

std::shared_ptr<TensorOp> PluginOperation::Build() {
  return std::make_shared<PluginOp>(lib_path_, func_name_, user_args_);
}

Status PluginOperation::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json args;
  args["lib_path"] = lib_path_;
  args["func_name"] = func_name_;
  args["user_args"] = user_args_;
  *out_json = args;
  return Status::OK();
}

Status PluginOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
  RETURN_UNEXPECTED_IF_NULL(operation);
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "lib_path", kPluginOperation));
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "func_name", kPluginOperation));
  RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "user_args", kPluginOperation));
  std::string lib_path = op_params["lib_path"];
  std::string func_name = op_params["func_name"];
  std::string user_args = op_params["user_args"];
  *operation = std::make_shared<transforms::PluginOperation>(lib_path, func_name, user_args);
  return Status::OK();
}
#endif
}  // namespace transforms
}  // namespace dataset
}  // namespace mindspore